// SocketServerImpl.h (11 KB)

  1. #ifndef SOCKETSERVERIMPL_H
  2. #define SOCKETSERVERIMPL_H
  3. #pragma once
  4. #pragma warning(push)
  5. #pragma warning(disable:4995)
  6. #include <vector>
  7. #include <list>
  8. #pragma warning(pop)
  9. #include "CritSection.h"
  10. #include "ThreadPool.hpp"
  11. #include "SocketHandle.h"
// List of SOCKET handles. SocketServerImpl uses it (member _sockets) to track the
// accepted TCP connections it adds/removes via AddConnection/RemoveConnection.
typedef std::list<SOCKET> SocketList;
  13. /**
  14. * ISocketServerHandler
  15. * Event handler that SocketServerImpl<T> must implement
  16. * This class is not required, you can do the same thing as long your class exposes these functions.
  17. * (These functions are not pure to save you some typing)
  18. */
  19. class ISocketServerHandler
  20. {
  21. public:
  22. virtual void OnThreadBegin(CSocketHandle* ) {}
  23. virtual void OnThreadExit(CSocketHandle* ) {}
  24. virtual void OnThreadLoopEnter(CSocketHandle* ) {}
  25. virtual void OnThreadLoopLeave(CSocketHandle* ) {}
  26. virtual void OnAddConnection(CSocketHandle* , SOCKET ) {}
  27. virtual void OnRemoveConnection(CSocketHandle* , SOCKET ) {}
  28. //virtual void OnDataReceived(CSocketHandle* , const BYTE* , DWORD , const SockAddrIn& ) {}
  29. virtual void OnDataReceived(LPWSAOVERLAPPED , const BYTE* , DWORD , const SockAddrIn& ) {}
  30. virtual void OnConnectionFailure(CSocketHandle*, SOCKET) {}
  31. virtual void OnConnectionDropped(CSocketHandle* ) {}
  32. virtual void OnConnectionError(CSocketHandle* , DWORD ) {}
  33. };
  34. /**
  35. * SocketServerImpl<T, tBufferSize>
  36. * Because <typename T> may refer to any class of your choosing,
  37. * Server Communication wrapper
  38. */
  39. template <typename T, size_t tBufferSize = 2048>
  40. class SocketServerImpl
  41. {
  42. typedef SocketServerImpl<T, tBufferSize> thisClass;
  43. public:
  44. SocketServerImpl()
  45. : _pInterface(0)
  46. , _thread(0)
  47. {
  48. }
  49. void SetInterface(T* pInterface)
  50. {
  51. ::InterlockedExchangePointer(reinterpret_cast<void**>(&_pInterface), pInterface);
  52. }
  53. operator CSocketHandle*() throw()
  54. {
  55. return( &_socket );
  56. }
  57. CSocketHandle* operator->() throw()
  58. {
  59. return( &_socket );
  60. }
  61. bool IsOpen() const
  62. {
  63. return _socket.IsOpen();
  64. }
  65. bool CreateSocket(LPCTSTR pszHost, LPCTSTR pszServiceName, int nFamily, int nType, UINT uOptions = 0)
  66. {
  67. return _socket.CreateSocket(pszHost, pszServiceName, nFamily, nType, uOptions);
  68. }
  69. void Close()
  70. {
  71. _socket.Close();
  72. }
  73. DWORD Read(LPBYTE lpBuffer, DWORD dwSize, LPSOCKADDR lpAddrIn = NULL, DWORD dwTimeout = INFINITE)
  74. {
  75. return _socket.Read(lpBuffer, dwSize, lpAddrIn, dwTimeout);
  76. }
  77. DWORD Write(const LPBYTE lpBuffer, DWORD dwCount, const LPSOCKADDR lpAddrIn = NULL, DWORD dwTimeout = INFINITE)
  78. {
  79. return _socket.Write(lpBuffer, dwCount, lpAddrIn, dwTimeout);
  80. }
  81. const SocketList& GetSocketList() const
  82. {
  83. // direct access! - use Lock/Unlock to protect
  84. return _sockets;
  85. }
  86. bool Lock()
  87. {
  88. return _critSection.Lock();
  89. }
  90. bool Unlock()
  91. {
  92. return _critSection.Unlock();
  93. }
  94. void ResetConnectionList()
  95. {
  96. AutoThreadSection aSection(&_critSection);
  97. _sockets.clear();
  98. }
  99. size_t GetConnectionCount() const
  100. {
  101. AutoThreadSection aSection(&_critSection);
  102. return _sockets.size();
  103. }
  104. void AddConnection(SOCKET sock)
  105. {
  106. AutoThreadSection aSection(&_critSection);
  107. _sockets.push_back( sock );
  108. }
  109. void RemoveConnection(SOCKET sock)
  110. {
  111. AutoThreadSection aSection(&_critSection);
  112. _sockets.remove( sock );
  113. }
  114. bool CloseConnection(SOCKET sock)
  115. {
  116. return CSocketHandle::ShutdownConnection( sock );
  117. }
  118. void CloseAllConnections();
  119. bool StartServer(LPCTSTR pszHost, LPCTSTR pszServiceName, int nFamily, int nType, UINT uOptions = 0);
  120. void Terminate(DWORD dwTimeout = INFINITE);
  121. static bool IsConnectionDropped(DWORD dwError);
  122. protected:
  123. void Run();
  124. void OnConnection(ULONG_PTR s);
  125. static DWORD WINAPI SocketServerProc(thisClass* _this);
  126. T* _pInterface;
  127. HANDLE _thread;
  128. ThreadSection _critSection;
  129. CSocketHandle _socket;
  130. SocketList _sockets;
  131. };
  132. template <typename T, size_t tBufferSize>
  133. void SocketServerImpl<T, tBufferSize>::CloseAllConnections()
  134. {
  135. AutoThreadSection aSection(&_critSection);
  136. if ( !_sockets.empty() )
  137. {
  138. // NOTE(elaurentin): this function closes all connections but handles are kept inside of list
  139. // (socket handles are removed by the pooling thread)
  140. SocketList::iterator iter;
  141. for(iter = _sockets.begin(); iter != _sockets.end(); ++iter)
  142. {
  143. CloseConnection( (*iter) );
  144. }
  145. }
  146. }
  147. template <typename T, size_t tBufferSize>
  148. bool SocketServerImpl<T, tBufferSize>::StartServer(LPCTSTR pszHost, LPCTSTR pszServiceName, int nFamily, int nType, UINT uOptions)
  149. {
  150. // must be closed first...
  151. if ( IsOpen() ) return false;
  152. bool result = false;
  153. result = _socket.CreateSocket(pszHost, pszServiceName, nFamily, nType, uOptions);
  154. if ( result )
  155. {
  156. _thread = AtlCreateThread(SocketServerProc, this);
  157. if ( _thread == NULL )
  158. {
  159. DWORD dwError = GetLastError();
  160. _socket.Close();
  161. SetLastError(dwError);
  162. result = false;
  163. }
  164. }
  165. return result;
  166. }
  167. template <typename T, size_t tBufferSize>
  168. void SocketServerImpl<T, tBufferSize>::OnConnection(ULONG_PTR s)
  169. {
  170. SockAddrIn addrIn;
  171. std::vector<unsigned char> data;
  172. data.resize( tBufferSize );
  173. DWORD dwRead;
  174. DWORD dwError;
  175. SOCKET sock = static_cast<SOCKET>(static_cast<ULONG>(s));
  176. CSocketHandle sockHandle;
  177. sockHandle.Attach(sock);
  178. sockHandle.GetPeerName( addrIn );
  179. int type = sockHandle.GetSocketType();
  180. // Notification: OnThreadLoopEnter
  181. if ( _pInterface != NULL ) {
  182. _pInterface->OnThreadLoopEnter(*this);
  183. }
  184. if (type == SOCK_STREAM) {
  185. AddConnection( sock );
  186. // Notification: OnAddConnection
  187. if ( _pInterface != NULL ) {
  188. _pInterface->OnAddConnection(*this, sock);
  189. }
  190. }
  191. if (type == SOCK_STREAM) {
  192. _socket.GetPeerName( addrIn );
  193. }
  194. // Connection loop
  195. while ( sockHandle.IsOpen() )
  196. {
  197. if (type == SOCK_STREAM)
  198. {
  199. dwRead = sockHandle.Read(&data[0], tBufferSize, NULL, INFINITE);
  200. }
  201. else
  202. {
  203. dwRead = sockHandle.Read(&data[0], tBufferSize, addrIn, INFINITE);
  204. }
  205. if ( ( dwRead != -1L ) && (dwRead > 0))
  206. {
  207. // Notification: OnDataReceived
  208. if ( _pInterface != NULL ) {
  209. _pInterface->OnDataReceived(*this, &data[0], dwRead, addrIn);
  210. }
  211. }
  212. else if (type == SOCK_STREAM && dwRead == 0L )
  213. {
  214. // connection broken
  215. if ( _pInterface != NULL ) {
  216. _pInterface->OnConnectionDropped(*this);
  217. }
  218. break;
  219. }
  220. else if ( dwRead == -1L)
  221. {
  222. dwError = GetLastError();
  223. if ( _pInterface != NULL )
  224. {
  225. if (IsConnectionDropped( dwError) ) {
  226. // Notification: OnConnectionDropped
  227. if (type == SOCK_STREAM || (dwError == WSAENOTSOCK || dwError == WSAENETDOWN))
  228. {
  229. _pInterface->OnConnectionDropped(*this);
  230. break;
  231. }
  232. }
  233. // Notification: OnConnectionError
  234. _pInterface->OnConnectionError(*this, dwError);
  235. }
  236. else {
  237. break;
  238. }
  239. }
  240. }
  241. // remove this connection from our list
  242. if (type == SOCK_STREAM) {
  243. RemoveConnection( sock );
  244. // Notification: OnRemoveConnection
  245. if ( _pInterface != NULL ) {
  246. _pInterface->OnRemoveConnection(*this, sock);
  247. }
  248. }
  249. // Detach or Close this socket (TCP-mode only)
  250. if (type != SOCK_STREAM ) {
  251. sockHandle.Detach();
  252. }
  253. data.clear();
  254. // Notification: OnThreadLoopLeave
  255. if ( _pInterface != NULL ) {
  256. _pInterface->OnThreadLoopLeave(*this);
  257. }
  258. }
  259. template <typename T, size_t tBufferSize>
  260. void SocketServerImpl<T, tBufferSize>::Run()
  261. {
  262. _ASSERTE( _pInterface != NULL && "Need an interface to pass events");
  263. SOCKET sock = _socket.GetSocket();
  264. int type = _socket.GetSocketType();
  265. // Notification: OnThreadBegin
  266. if ( _pInterface != NULL ) {
  267. _pInterface->OnThreadBegin(*this);
  268. }
  269. if (type == SOCK_STREAM)
  270. {
  271. // In TCP mode, we need one thread per connection
  272. while( _socket.IsOpen() )
  273. {
  274. SOCKET newSocket = CSocketHandle::WaitForConnection(sock);
  275. if (!_socket.IsOpen())
  276. break;
  277. // run a new client thread for each connection
  278. // report failure if not a valid socket or threadpool failed
  279. if ((newSocket == INVALID_SOCKET) ||
  280. !ThreadPool::QueueWorkItem(&SocketServerImpl<T, tBufferSize>::OnConnection,
  281. this,
  282. static_cast<ULONG_PTR>(newSocket))
  283. )
  284. {
  285. // Notification: OnConnectionFailure
  286. if ( _pInterface != NULL ) {
  287. _pInterface->OnConnectionFailure(*this, newSocket);
  288. }
  289. }
  290. }
  291. // close all connections
  292. CloseAllConnections();
  293. }
  294. else
  295. {
  296. // UDP - only one instance
  297. OnConnection( sock );
  298. }
  299. // Notification: OnThreadExit
  300. if ( _pInterface != NULL ) {
  301. _pInterface->OnThreadExit(*this);
  302. }
  303. }
  304. template <typename T, size_t tBufferSize>
  305. void SocketServerImpl<T, tBufferSize>::Terminate(DWORD dwTimeout /*= INFINITE*/)
  306. {
  307. _socket.Close();
  308. if ( _thread != NULL )
  309. {
  310. if ( WaitForSingleObject(_thread, dwTimeout) == WAIT_TIMEOUT ) {
  311. TerminateThread(_thread, 1);
  312. }
  313. CloseHandle(_thread);
  314. _thread = NULL;
  315. }
  316. }
  317. template <typename T, size_t tBufferSize>
  318. DWORD WINAPI SocketServerImpl<T, tBufferSize>::SocketServerProc(thisClass* _this)
  319. {
  320. if ( _this != NULL )
  321. {
  322. _this->Run();
  323. }
  324. return 0;
  325. }
  326. template <typename T, size_t tBufferSize>
  327. bool SocketServerImpl<T, tBufferSize>::IsConnectionDropped(DWORD dwError)
  328. {
  329. // see: winerror.h for definition
  330. switch( dwError )
  331. {
  332. case WSAENOTSOCK:
  333. case WSAENETDOWN:
  334. case WSAENETUNREACH:
  335. case WSAENETRESET:
  336. case WSAECONNABORTED:
  337. case WSAECONNRESET:
  338. case WSAESHUTDOWN:
  339. case WSAEHOSTDOWN:
  340. case WSAEHOSTUNREACH:
  341. return true;
  342. default:
  343. break;
  344. }
  345. return false;
  346. }
  347. #endif //SOCKETSERVERIMPL_H