diff --git a/src/winhttppal.cpp b/src/winhttppal.cpp index 77648e0..947bd29 100644 --- a/src/winhttppal.cpp +++ b/src/winhttppal.cpp @@ -38,6 +38,7 @@ #include #include #include +#include #ifdef UNICODE #include @@ -57,17 +58,12 @@ #include "winhttppal.h" #endif +#include "winhttppal_imp.h" + #ifndef WINHTTP_CURL_MAX_WRITE_SIZE #define WINHTTP_CURL_MAX_WRITE_SIZE CURL_MAX_WRITE_SIZE #endif -#include -extern "C" -{ -#include -#include -} - class WinHttpSessionImp; class WinHttpRequestImp; @@ -210,43 +206,14 @@ static void TRACE_INTERNAL(const char *fmt, ...) return (retval); \ } -typedef void (*CompletionCb)(std::shared_ptr &, DWORD status); - #define MUTEX_TYPE std::mutex #define MUTEX_SETUP(x) #define MUTEX_CLEANUP(x) #define MUTEX_LOCK(x) x.lock() #define MUTEX_UNLOCK(x) x.unlock() -#ifdef WIN32 -#define THREAD_ID GetCurrentThreadId() -#define THREAD_HANDLE HANDLE -#define THREADPARAM LPVOID -#define CREATETHREAD(func, param, id) \ - CreateThread( \ - NULL, \ - 0, \ - (LPTHREAD_START_ROUTINE)func, \ - param, \ - 0, \ - id) -#define THREADJOIN(h) WaitForSingleObject(h, INFINITE); -#define THREADRETURN DWORD -#else -#define THREADPARAM void* -#define THREADRETURN void* -#define THREAD_ID pthread_self() -#define THREAD_HANDLE pthread_t -#define THREADJOIN(x) pthread_join(x, NULL) - -typedef void* (*LPTHREAD_START_ROUTINE)(void *); -static inline THREAD_HANDLE CREATETHREAD(LPTHREAD_START_ROUTINE func, LPVOID param, pthread_t *id) -{ - pthread_t inc_x_thread; - pthread_create(&inc_x_thread, NULL, func, param); - return inc_x_thread; -} -#endif +class WinHttpBase; +class WinHttpConnectImp; #if OPENSSL_VERSION_NUMBER < 0x10100000L /* This array will store all of the mutexes available to OpenSSL. */ @@ -299,75 +266,6 @@ static int thread_cleanup() return 1; } -class UserCallbackContext -{ - std::shared_ptr m_request; - DWORD m_dwInternetStatus; - DWORD m_dwStatusInformationLength; - DWORD m_dwNotificationFlags; - WINHTTP_STATUS_CALLBACK m_cb; - LPVOID m_userdata; - LPVOID m_StatusInformationVal = NULL; - BYTE* m_StatusInformation = NULL; - bool m_allocate = FALSE; - BOOL m_AsyncResultValid = false; - CompletionCb m_requestCompletionCb; - - BOOL SetAsyncResult(LPVOID statusInformation, DWORD statusInformationCopySize, bool allocate) { - - if (allocate) - { - m_StatusInformation = new BYTE[statusInformationCopySize]; - if (m_StatusInformation) - { - memcpy(m_StatusInformation, statusInformation, statusInformationCopySize); - } - else - return FALSE; - } - else - { - m_StatusInformationVal = statusInformation; - } - m_AsyncResultValid = true; - m_allocate = allocate; - - return TRUE; - } - -public: - UserCallbackContext(std::shared_ptr &request, - DWORD dwInternetStatus, - DWORD dwStatusInformationLength, - DWORD dwNotificationFlags, - WINHTTP_STATUS_CALLBACK cb, - LPVOID userdata, - LPVOID statusInformation, - DWORD statusInformationCopySize, - bool allocate, - CompletionCb completion); - ~UserCallbackContext(); - std::shared_ptr &GetRequestRef() { return m_request; } - WinHttpRequestImp *GetRequest() { return m_request.get(); } - DWORD GetInternetStatus() const { return m_dwInternetStatus; } - DWORD GetStatusInformationLength() const { return m_dwStatusInformationLength; } - DWORD GetNotificationFlags() const { return m_dwNotificationFlags; } - WINHTTP_STATUS_CALLBACK &GetCb() { return m_cb; } - CompletionCb &GetRequestCompletionCb() { return m_requestCompletionCb; } - LPVOID GetUserdata() { return m_userdata; } - LPVOID GetStatusInformation() { - if (m_allocate) - return m_StatusInformation; - else - return m_StatusInformationVal; - } - BOOL GetStatusInformationValid() const { return m_AsyncResultValid; } - -private: - UserCallbackContext(const UserCallbackContext&); - UserCallbackContext& operator=(const UserCallbackContext&); -}; - static void ConvertCstrAssign(const TCHAR *lpstr, size_t cLen, std::string &target) { #ifdef UNICODE @@ -419,38 +317,93 @@ static BOOL SizeCheck(LPVOID lpBuffer, LPDWORD lpdwBufferLength, DWORD Required) return TRUE; } -class WinHttpBase +static DWORD ConvertSecurityProtocol(DWORD offered) { -public: - virtual ~WinHttpBase() {} -}; + DWORD min = 0; + DWORD max = 0; -class WinHttpConnectImp :public WinHttpBase -{ - WinHttpSessionImp *m_Handle = NULL; - void *m_UserBuffer = NULL; - -public: - void SetHandle(WinHttpSessionImp *session) { m_Handle = session; } - WinHttpSessionImp *GetHandle() { return m_Handle; } - - BOOL SetUserData(void **data) - { - if (!data) - return FALSE; - - m_UserBuffer = *data; - return TRUE; + if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL2) { + min = CURL_SSLVERSION_SSLv2; } - void *GetUserData() { return m_UserBuffer; } -}; + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL3) { + min = CURL_SSLVERSION_SSLv3; + } + else + min = CURL_SSLVERSION_DEFAULT; -struct BufferRequest + if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1) { + min = CURL_SSLVERSION_TLSv1_0; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_1) { + min = CURL_SSLVERSION_TLSv1_1; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2) { + min = CURL_SSLVERSION_TLSv1_2; + } + +#if (LIBCURL_VERSION_MAJOR>=7) && (LIBCURL_VERSION_MINOR >= 61) + if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2) { + max = CURL_SSLVERSION_MAX_TLSv1_2; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_1) { + max = CURL_SSLVERSION_MAX_TLSv1_1; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1) { + max = CURL_SSLVERSION_MAX_TLSv1_0; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL3) { + max = CURL_SSLVERSION_SSLv3; + } + else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL2) { + max = CURL_SSLVERSION_SSLv2; + } + else + max = CURL_SSLVERSION_MAX_DEFAULT; +#endif + + return min | max; +} + +template +static BOOL CallMemberFunction(WinHttpBase *base, std::function fn, LPVOID lpBuffer) { - LPVOID m_Buffer = NULL; - size_t m_Length = 0; - size_t m_Used = 0; -}; + T *obj; + + if (!lpBuffer) + return FALSE; + + if ((obj = dynamic_cast(base))) + { + if (fn(obj, static_cast(lpBuffer))) + return TRUE; + } + return FALSE; +} + +WinHttpSessionImp *GetImp(WinHttpBase *base) +{ + WinHttpConnectImp *connect; + WinHttpSessionImp *session; + WinHttpRequestImp *request; + + if ((connect = dynamic_cast(base))) + { + session = connect->GetHandle(); + } + else if ((request = dynamic_cast(base))) + { + connect = request->GetSession(); + if (!connect) + return NULL; + session = connect->GetHandle(); + } + else + { + session = dynamic_cast(base); + } + + return session; +} static void QueueBufferRequest(std::vector &store, LPVOID buffer, size_t length) { @@ -475,800 +428,339 @@ static BufferRequest &PeekBufferRequest(std::vector &store) return store.front(); } -class WinHttpRequestImp :public WinHttpBase, public std::enable_shared_from_this +EnvInit::EnvInit() { - CURL *m_curl = NULL; - std::vector m_ResponseString; - std::string m_HeaderString; - std::string m_Header; - std::string m_FullPath; - std::string m_OptionalData; - size_t m_TotalSize = 0; - size_t m_TotalReceiveSize = 0; - std::vector m_ReadData; + if (const char* env_p = std::getenv("WINHTTP_PAL_DEBUG")) + winhttp_tracing = std::stoi(std::string(env_p)); - std::mutex m_ReadDataEventMtx; - DWORD m_ReadDataEventCounter = 0; - THREAD_HANDLE m_UploadCallbackThread; - bool m_UploadThreadExitStatus = false; + if (const char* env_p = std::getenv("WINHTTP_PAL_DEBUG_VERBOSE")) + winhttp_tracing_verbose = std::stoi(std::string(env_p)); +} - DWORD m_SecureProtocol = 0; - DWORD m_MaxConnections = 0; +static EnvInit envinit; - std::string m_Type; - LPVOID m_UserBuffer = NULL; - bool m_HeaderReceiveComplete = false; - - struct curl_slist *m_HeaderList = NULL; - WinHttpConnectImp *m_Session = NULL; - std::mutex m_HeaderStringMutex; - std::mutex m_BodyStringMutex; - - std::mutex m_QueryDataEventMtx; - bool m_QueryDataEventState = false; - std::atomic m_QueryDataPending; - - std::mutex m_ReceiveCompletionEventMtx; - - std::mutex m_ReceiveResponseMutex; - std::atomic m_ReceiveResponsePending; - std::atomic m_ReceiveResponseEventCounter; - int m_ReceiveResponseSendCounter = 0; - - std::atomic m_RedirectPending; - - bool m_closing = false; - bool m_closed = false; - bool m_Completion = false; - bool m_Async = false; - CURLcode m_CompletionCode = CURLE_OK; - long m_VerifyPeer = 1; - long m_VerifyHost = 2; - bool m_Uploading = false; - - WINHTTP_STATUS_CALLBACK m_InternetCallback = NULL; - DWORD m_NotificationFlags = 0; - std::vector m_OutstandingWrites; - std::vector m_OutstandingReads; - bool m_Secure = false; - -public: - bool &GetSecure() { return m_Secure; } - std::vector &GetOutstandingWrites() { return m_OutstandingWrites; } - std::vector &GetOutstandingReads() { return m_OutstandingReads; } - - long &VerifyPeer() { return m_VerifyPeer; } - long &VerifyHost() { return m_VerifyHost; } - - static size_t WriteHeaderFunction(void *ptr, size_t size, size_t nmemb, void* rqst); - static size_t WriteBodyFunction(void *ptr, size_t size, size_t nmemb, void* rqst); - void ConsumeIncoming(std::shared_ptr &srequest, void* &ptr, size_t &available, size_t &read); - void SetCallback(WINHTTP_STATUS_CALLBACK lpfnInternetCallback, DWORD dwNotificationFlags) { - m_InternetCallback = lpfnInternetCallback; - m_NotificationFlags = dwNotificationFlags; - } - WINHTTP_STATUS_CALLBACK GetCallback(DWORD *dwNotificationFlags) - { - if (dwNotificationFlags) - *dwNotificationFlags = m_NotificationFlags; - return m_InternetCallback; - } - - void SetAsync() { m_Async = TRUE; } - BOOL GetAsync() { return m_Async; } - - WinHttpRequestImp(); - - bool &GetQueryDataEventState() { return m_QueryDataEventState; } - std::mutex &GetQueryDataEventMtx() { return m_QueryDataEventMtx; } - - bool HandleQueryDataNotifications(std::shared_ptr &, size_t available); - void WaitAsyncQueryDataCompletion(std::shared_ptr &); - - void HandleReceiveNotifications(std::shared_ptr &srequest); - void WaitAsyncReceiveCompletion(std::shared_ptr &srequest); - - CURLcode &GetCompletionCode() { return m_CompletionCode; } - bool &GetCompletionStatus() { return m_Completion; } - bool &GetClosing() { return m_closing; } - bool &GetClosed() { return m_closed; } - void CleanUp(); - ~WinHttpRequestImp(); - - bool &GetUploadThreadExitStatus() { return m_UploadThreadExitStatus; } - std::atomic &GetQueryDataPending() { return m_QueryDataPending; } - - THREAD_HANDLE &GetUploadThread() { - return m_UploadCallbackThread; - } - - static THREADRETURN UploadThreadFunction(THREADPARAM lpThreadParameter) - { - WinHttpRequestImp *request = static_cast(lpThreadParameter); - if (!request) - return NULL; - - CURL *curl = request->GetCurl(); - CURLcode res; - - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)request); - /* Perform the request, res will get the return code */ - res = curl_easy_perform(curl); - /* Check for errors */ - CURL_BAILOUT_ONERROR(res, request, NULL); - - TRACE("%-35s:%-8d:%-16p res:%d\n", __func__, __LINE__, (void*)request, res); - request->GetUploadThreadExitStatus() = true; - - #if OPENSSL_VERSION_NUMBER >= 0x10000000L && OPENSSL_VERSION_NUMBER < 0x10100000L - ERR_remove_thread_state(NULL); - #elif OPENSSL_VERSION_NUMBER < 0x10000000L - ERR_remove_state(0); - #endif - return 0; - } - - // used to wake up CURL ReadCallback triggered on a upload request - std::mutex &GetReadDataEventMtx() { return m_ReadDataEventMtx; } - DWORD &GetReadDataEventCounter() { return m_ReadDataEventCounter; } - - std::mutex &GetReceiveCompletionEventMtx() { return m_ReceiveCompletionEventMtx; } - - // counters used to see which one of these events are observed: ResponseCallbackEventCounter - // counters used to see which one of these events are broadcasted: ResponseCallbackSendCounter - // WINHTTP_CALLBACK_STATUS_RECEIVING_RESPONSE - // WINHTTP_CALLBACK_STATUS_RESPONSE_RECEIVED - // WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE - // - // These counters need to be 3 and identical under normal circumstances - std::atomic &ResponseCallbackEventCounter() { return m_ReceiveResponseEventCounter; } - int &ResponseCallbackSendCounter() { return m_ReceiveResponseSendCounter; } - std::atomic &GetReceiveResponsePending() { return m_ReceiveResponsePending; } - std::mutex &GetReceiveResponseMutex() { return m_ReceiveResponseMutex; } - - std::atomic &GetRedirectPending() { return m_RedirectPending; } - - std::mutex &GetHeaderStringMutex() { return m_HeaderStringMutex; } - std::mutex &GetBodyStringMutex() { return m_HeaderStringMutex; } - - BOOL AsyncQueue(std::shared_ptr &, - DWORD dwInternetStatus, size_t statusInformationLength, - LPVOID statusInformation, DWORD statusInformationCopySize, bool allocate); - - void SetHeaderReceiveComplete() { m_HeaderReceiveComplete = TRUE; } - - BOOL SetSecureProtocol(DWORD *data) - { - if (!data) - return FALSE; - - m_SecureProtocol = *data; - return TRUE; - } - DWORD GetSecureProtocol() { return m_SecureProtocol; } - - BOOL SetMaxConnections(DWORD *data) - { - if (!data) - return FALSE; - - m_MaxConnections = *data; - return TRUE; - } - DWORD GetMaxConnections() { return m_MaxConnections; } - - BOOL SetUserData(void **data) - { - if (!data) - return FALSE; - - m_UserBuffer = *data; - return TRUE; - } - LPVOID GetUserData() { return m_UserBuffer; } - - void SetFullPath(std::string &server, std::string &relative) - { - m_FullPath.append(server); - m_FullPath.append(relative); - } - std::string &GetFullPath() { return m_FullPath; } - std::string &GetType() { return m_Type; } - bool &Uploading() { return m_Uploading; } - - void SetSession(WinHttpConnectImp *connect) { m_Session = connect; } - WinHttpConnectImp *GetSession() { return m_Session; } - - struct curl_slist *GetHeaderList() { return m_HeaderList; } - - void SetHeaderList(struct curl_slist *list) { m_HeaderList = list; } - std::string &GetOutgoingHeaderList() { return m_Header; } - void AddHeader(std::string &headers) { - std::stringstream check1(headers); - std::string str; - - while(getline(check1, str, '\n')) - { - str.erase(std::remove(str.begin(), str.end(), '\r'), str.end()); - SetHeaderList(curl_slist_append(GetHeaderList(), str.c_str())); - } - } - - std::vector &GetResponseString() { return m_ResponseString; } - std::string &GetHeaderString() { return m_HeaderString; } - CURL *GetCurl() { return m_curl; } - - BOOL SetProxy(std::vector &proxies) - { - std::vector::iterator it; - - for (it = proxies.begin(); it != proxies.end(); it++) - { - std::string &urlstr = (*it); - CURLcode res; - - res = curl_easy_setopt(GetCurl(), CURLOPT_PROXY, urlstr.c_str()); - CURL_BAILOUT_ONERROR(res, this, FALSE); - } - - return TRUE; - } - - BOOL SetServer(std::string &ServerName, int nServerPort) - { - CURLcode res; - - res = curl_easy_setopt(GetCurl(), CURLOPT_URL, ServerName.c_str()); - CURL_BAILOUT_ONERROR(res, this, FALSE); - - res = curl_easy_setopt(GetCurl(), CURLOPT_PORT, nServerPort); - CURL_BAILOUT_ONERROR(res, this, FALSE); - - return TRUE; - } - - size_t &GetTotalLength() { return m_TotalSize; } - size_t &GetReadLength() { return m_TotalReceiveSize; } - - std::string &GetOptionalData() { return m_OptionalData; } - BOOL SetOptionalData(void *lpOptional, size_t dwOptionalLength) - { - if (!lpOptional || !dwOptionalLength) - return FALSE; - - m_OptionalData.assign(&(static_cast(lpOptional))[0], dwOptionalLength); - return TRUE; - } - - std::vector &GetReadData() { return m_ReadData; } - - void AppendReadData(const void *data, size_t len) - { - std::lock_guard lck(GetReadDataEventMtx()); - m_ReadData.insert(m_ReadData.end(), static_cast(data), static_cast(data) + len); - } - - static int SocketCallback(CURL *handle, curl_infotype type, - char *data, size_t size, - void *userp); - - static size_t ReadCallback(void *ptr, size_t size, size_t nmemb, void *userp); -}; - -typedef std::vector UserCallbackQueue; - -class UserCallbackContainer +THREADRETURN UserCallbackContainer::UserCallbackThreadFunction(LPVOID lpThreadParameter) { - UserCallbackQueue m_Queue; + UserCallbackContainer *cbContainer = static_cast(lpThreadParameter); - // used to protect GetCallbackMap() - std::mutex m_MapMutex; - - THREAD_HANDLE m_hThread; - - // used by UserCallbackThreadFunction as a wake-up signal - std::mutex m_hEventMtx; - std::condition_variable m_hEvent; - - // cleared by UserCallbackThreadFunction, set by multiple event providers - std::atomic m_EventCounter; - - bool m_closing = false; - UserCallbackQueue &GetCallbackQueue() { return m_Queue; } - - -public: - - bool GetClosing() const { return m_closing; } - - UserCallbackContext* GetNext() + while (true) { - UserCallbackContext *ctx = NULL; - // show content: - if (!GetCallbackQueue().empty()) { - ctx = GetCallbackQueue().front(); - GetCallbackQueue().erase(GetCallbackQueue().begin()); + std::unique_lock GetEventHndlMtx(cbContainer->m_hEventMtx); + while (!cbContainer->m_EventCounter) + cbContainer->m_hEvent.wait(GetEventHndlMtx); + GetEventHndlMtx.unlock(); + + cbContainer->m_EventCounter--; } - return ctx; - } - - static THREADRETURN UserCallbackThreadFunction(LPVOID lpThreadParameter); - - BOOL Queue(UserCallbackContext *ctx) - { - if (!ctx) - return FALSE; - + if (cbContainer->GetClosing()) { - std::lock_guard lck(m_MapMutex); - - TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p userdata = %p dwInternetStatus = %p\n", - __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), - ctx->GetUserdata(), ctx->GetStatusInformation()); - GetCallbackQueue().push_back(ctx); + TRACE("%s", "exiting\n"); + break; } - { - std::lock_guard lck(m_hEventMtx); - m_EventCounter++; - m_hEvent.notify_all(); - } - return TRUE; - } - void DrainQueue() - { while (true) { - m_MapMutex.lock(); - UserCallbackContext *ctx = GetNext(); - m_MapMutex.unlock(); - + cbContainer->m_MapMutex.lock(); + UserCallbackContext *ctx = NULL; + ctx = cbContainer->GetNext(); + cbContainer->m_MapMutex.unlock(); if (!ctx) break; + void *statusInformation = NULL; + + if (ctx->GetStatusInformationValid()) + statusInformation = ctx->GetStatusInformation(); + + if (!ctx->GetRequestRef()->GetClosed() && ctx->GetCb()) { + TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p ctx->GetUserdata() = %p dwInternetStatus:0x%lx statusInformation=%p refcount:%lu\n", + __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), + ctx->GetUserdata(), ctx->GetInternetStatus(), statusInformation, ctx->GetRequestRef().use_count()); + + ctx->GetCb()(ctx->GetRequest(), + (DWORD_PTR)(ctx->GetUserdata()), + ctx->GetInternetStatus(), + statusInformation, + ctx->GetStatusInformationLength()); + + TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p ctx->GetUserdata() = %p dwInternetStatus:0x%lx statusInformation=%p refcount:%lu\n", + __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), + ctx->GetUserdata(), ctx->GetInternetStatus(), statusInformation, ctx->GetRequestRef().use_count()); + } + ctx->GetRequestCompletionCb()(ctx->GetRequestRef(), ctx->GetInternetStatus()); delete ctx; } } - UserCallbackContainer(): m_EventCounter(0) - { - DWORD thread_id; - m_hThread = CREATETHREAD( - UserCallbackThreadFunction, // thread function name - this, // argument to thread function - &thread_id - ); - } - ~UserCallbackContainer() - { - m_closing = true; - { - std::lock_guard lck(m_hEventMtx); - m_EventCounter++; - m_hEvent.notify_all(); - } - THREADJOIN(m_hThread); - DrainQueue(); - } - static UserCallbackContainer &GetInstance() - { - static UserCallbackContainer the_instance; - return the_instance; - } -private: - UserCallbackContainer(const UserCallbackContainer&); - UserCallbackContainer& operator=(const UserCallbackContainer&); -}; +#if OPENSSL_VERSION_NUMBER >= 0x10000000L && OPENSSL_VERSION_NUMBER < 0x10100000L + ERR_remove_thread_state(NULL); +#elif OPENSSL_VERSION_NUMBER < 0x10000000L + ERR_remove_state(0); +#endif -class EnvInit + return 0; +} + +UserCallbackContext* UserCallbackContainer::GetNext() { -public: - EnvInit() + UserCallbackContext *ctx = NULL; + // show content: + if (!GetCallbackQueue().empty()) { - if (const char* env_p = std::getenv("WINHTTP_PAL_DEBUG")) - winhttp_tracing = std::stoi(std::string(env_p)); - - if (const char* env_p = std::getenv("WINHTTP_PAL_DEBUG_VERBOSE")) - winhttp_tracing_verbose = std::stoi(std::string(env_p)); + ctx = GetCallbackQueue().front(); + GetCallbackQueue().erase(GetCallbackQueue().begin()); } -}; -static EnvInit envinit; + return ctx; +} -class ComContainer +BOOL UserCallbackContainer::Queue(UserCallbackContext *ctx) { - THREAD_HANDLE m_hAsyncThread; - std::mutex m_hAsyncEventMtx; + if (!ctx) + return FALSE; - // used to wake up the Async Thread - // Set by external components, cleared by the Async thread - std::atomic m_hAsyncEventCounter; - std::condition_variable m_hAsyncEvent; - std::mutex m_MultiMutex; - CURLM *m_curlm = NULL; - std::vector< CURL *> m_ActiveCurl; - std::vector> m_ActiveRequests; - BOOL m_closing = FALSE; - - // used to protect CURLM data structures - BOOL GetThreadClosing() const { return m_closing; } - -public: - CURL *AllocCURL() { - CURL *ptr; + std::lock_guard lck(m_MapMutex); - m_MultiMutex.lock(); - ptr = curl_easy_init(); - m_MultiMutex.unlock(); - - return ptr; + TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p userdata = %p dwInternetStatus = %p\n", + __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), + ctx->GetUserdata(), ctx->GetStatusInformation()); + GetCallbackQueue().push_back(ctx); } - - void FreeCURL(CURL *ptr) { - m_MultiMutex.lock(); - curl_easy_cleanup(ptr); - m_MultiMutex.unlock(); + std::lock_guard lck(m_hEventMtx); + m_EventCounter++; + m_hEvent.notify_all(); } + return TRUE; +} - static ComContainer &GetInstance() +void UserCallbackContainer::DrainQueue() +{ + while (true) { - static ComContainer the_instance; - return the_instance; + m_MapMutex.lock(); + UserCallbackContext *ctx = GetNext(); + m_MapMutex.unlock(); + + if (!ctx) + break; + + delete ctx; } +} - void ResumeTransfer(CURL *handle, int bitmask) - { - int still_running; +CURL *ComContainer::AllocCURL() +{ + CURL *ptr; - m_MultiMutex.lock(); - curl_easy_pause(handle, bitmask); - curl_multi_perform(m_curlm, &still_running); - m_MultiMutex.unlock(); - } + m_MultiMutex.lock(); + ptr = curl_easy_init(); + m_MultiMutex.unlock(); - BOOL AddHandle(std::shared_ptr &srequest, CURL *handle) - { - CURLMcode mres = CURLM_OK; + return ptr; +} - m_MultiMutex.lock(); - if (std::find(m_ActiveCurl.begin(), m_ActiveCurl.end(), handle) != m_ActiveCurl.end()) { - mres = curl_multi_remove_handle(m_curlm, handle); - m_ActiveCurl.erase(std::remove(m_ActiveCurl.begin(), m_ActiveCurl.end(), handle), m_ActiveCurl.end()); - } - if (std::find(m_ActiveRequests.begin(), m_ActiveRequests.end(), srequest) != m_ActiveRequests.end()) { - m_ActiveRequests.erase(std::remove(m_ActiveRequests.begin(), m_ActiveRequests.end(), srequest), m_ActiveRequests.end()); - } - m_MultiMutex.unlock(); - if (mres != CURLM_OK) - { - TRACE("curl_multi_remove_handle() failed: %s\n", curl_multi_strerror(mres)); - return FALSE; - } +void ComContainer::FreeCURL(CURL *ptr) +{ + m_MultiMutex.lock(); + curl_easy_cleanup(ptr); + m_MultiMutex.unlock(); +} - m_MultiMutex.lock(); - mres = curl_multi_add_handle(m_curlm, handle); - m_ActiveCurl.push_back(handle); - m_ActiveRequests.push_back(srequest); - m_MultiMutex.unlock(); - if (mres != CURLM_OK) - { - TRACE("curl_multi_add_handle() failed: %s\n", curl_multi_strerror(mres)); - return FALSE; - } +ComContainer &ComContainer::GetInstance() +{ + static ComContainer the_instance; + return the_instance; +} - return TRUE; - } - BOOL RemoveHandle(std::shared_ptr &srequest, CURL *handle, bool clearPrivate) - { - CURLMcode mres; +void ComContainer::ResumeTransfer(CURL *handle, int bitmask) +{ + int still_running; - m_MultiMutex.lock(); + m_MultiMutex.lock(); + curl_easy_pause(handle, bitmask); + curl_multi_perform(m_curlm, &still_running); + m_MultiMutex.unlock(); +} + +BOOL ComContainer::AddHandle(std::shared_ptr &srequest, CURL *handle) +{ + CURLMcode mres = CURLM_OK; + + m_MultiMutex.lock(); + if (std::find(m_ActiveCurl.begin(), m_ActiveCurl.end(), handle) != m_ActiveCurl.end()) { mres = curl_multi_remove_handle(m_curlm, handle); - - if (clearPrivate) - curl_easy_getinfo(handle, CURLINFO_PRIVATE, NULL); - m_ActiveCurl.erase(std::remove(m_ActiveCurl.begin(), m_ActiveCurl.end(), handle), m_ActiveCurl.end()); + } + if (std::find(m_ActiveRequests.begin(), m_ActiveRequests.end(), srequest) != m_ActiveRequests.end()) { m_ActiveRequests.erase(std::remove(m_ActiveRequests.begin(), m_ActiveRequests.end(), srequest), m_ActiveRequests.end()); - m_MultiMutex.unlock(); - if (mres != CURLM_OK) - { - TRACE("curl_multi_add_handle() failed: %s\n", curl_multi_strerror(mres)); - return FALSE; - } - - return TRUE; } - - long GetTimeout() + m_MultiMutex.unlock(); + if (mres != CURLM_OK) { - long curl_timeo; - - m_MultiMutex.lock(); - curl_multi_timeout(m_curlm, &curl_timeo); - m_MultiMutex.unlock(); - - return curl_timeo; + TRACE("curl_multi_remove_handle() failed: %s\n", curl_multi_strerror(mres)); + return FALSE; } - void KickStart() + m_MultiMutex.lock(); + mres = curl_multi_add_handle(m_curlm, handle); + m_ActiveCurl.push_back(handle); + m_ActiveRequests.push_back(srequest); + m_MultiMutex.unlock(); + if (mres != CURLM_OK) + { + TRACE("curl_multi_add_handle() failed: %s\n", curl_multi_strerror(mres)); + return FALSE; + } + + return TRUE; +} + +BOOL ComContainer::RemoveHandle(std::shared_ptr &srequest, CURL *handle, bool clearPrivate) +{ + CURLMcode mres; + + m_MultiMutex.lock(); + mres = curl_multi_remove_handle(m_curlm, handle); + + if (clearPrivate) + curl_easy_getinfo(handle, CURLINFO_PRIVATE, NULL); + + m_ActiveCurl.erase(std::remove(m_ActiveCurl.begin(), m_ActiveCurl.end(), handle), m_ActiveCurl.end()); + m_ActiveRequests.erase(std::remove(m_ActiveRequests.begin(), m_ActiveRequests.end(), srequest), m_ActiveRequests.end()); + m_MultiMutex.unlock(); + if (mres != CURLM_OK) + { + TRACE("curl_multi_add_handle() failed: %s\n", curl_multi_strerror(mres)); + return FALSE; + } + + return TRUE; +} + +long ComContainer::GetTimeout() +{ + long curl_timeo; + + m_MultiMutex.lock(); + curl_multi_timeout(m_curlm, &curl_timeo); + m_MultiMutex.unlock(); + + return curl_timeo; +} + +void ComContainer::KickStart() +{ + std::lock_guard lck(m_hAsyncEventMtx); + m_hAsyncEventCounter++; + m_hAsyncEvent.notify_all(); +} + +ComContainer::ComContainer(): m_hAsyncEventCounter(0) +{ + curl_global_init(CURL_GLOBAL_ALL); + thread_setup(); + + m_curlm = curl_multi_init(); + + DWORD thread_id; + m_hAsyncThread = CREATETHREAD( + AsyncThreadFunction, // thread function name + this, &thread_id); // argument to thread function +} + +ComContainer::~ComContainer() +{ + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); + m_closing = true; { std::lock_guard lck(m_hAsyncEventMtx); m_hAsyncEventCounter++; m_hAsyncEvent.notify_all(); } + THREADJOIN(m_hAsyncThread); + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); - ComContainer(): m_hAsyncEventCounter(0) - { - curl_global_init(CURL_GLOBAL_ALL); - thread_setup(); - - m_curlm = curl_multi_init(); - - DWORD thread_id; - m_hAsyncThread = CREATETHREAD( - AsyncThreadFunction, // thread function name - this, &thread_id); // argument to thread function - } - - ~ComContainer() - { - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); - m_closing = true; - { - std::lock_guard lck(m_hAsyncEventMtx); - m_hAsyncEventCounter++; - m_hAsyncEvent.notify_all(); - } - THREADJOIN(m_hAsyncThread); - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); - - curl_multi_cleanup(m_curlm); - curl_global_cleanup(); - thread_cleanup(); - } - - - static THREADRETURN AsyncThreadFunction(THREADPARAM lpThreadParameter); - - int QueryData(int *still_running) - { - if (!still_running) - return 0; - - int rc = 0; - struct timeval timeout; - - fd_set fdread; - fd_set fdwrite; - fd_set fdexcep; - int maxfd = -1; - - long curl_timeo = -1; - - FD_ZERO(&fdread); - FD_ZERO(&fdwrite); - FD_ZERO(&fdexcep); - - /* set a suitable timeout to play around with */ - timeout.tv_sec = 1; - timeout.tv_usec = 0; - - m_MultiMutex.lock(); - curl_multi_timeout(m_curlm, &curl_timeo); - m_MultiMutex.unlock(); - if (curl_timeo < 0) - /* no set timeout, use a default */ - curl_timeo = 10000; - - if (curl_timeo > 0) { - timeout.tv_sec = curl_timeo / 1000; - if (timeout.tv_sec > 1) - timeout.tv_sec = 1; - else - timeout.tv_usec = (curl_timeo % 1000) * 1000; - } - - /* get file descriptors from the transfers */ - m_MultiMutex.lock(); - rc = curl_multi_fdset(m_curlm, &fdread, &fdwrite, &fdexcep, &maxfd); - m_MultiMutex.unlock(); - - if ((maxfd == -1) || (rc != CURLM_OK)) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } - else - rc = select(maxfd + 1, &fdread, &fdwrite, &fdexcep, &timeout); - - switch (rc) { - case -1: - /* select error */ - *still_running = 0; - TRACE("%s\n", "select() returns error, this is badness\n"); - break; - case 0: - default: - /* timeout or readable/writable sockets */ - m_MultiMutex.lock(); - curl_multi_perform(m_curlm, still_running); - m_MultiMutex.unlock(); - break; - } - - return rc; - } -private: - ComContainer(const ComContainer&); - ComContainer& operator=(const ComContainer&); -}; + curl_multi_cleanup(m_curlm); + curl_global_cleanup(); + thread_cleanup(); +} template -class WinHttpHandleContainer +void WinHttpHandleContainer::UnRegister(T *val) { - std::mutex m_ActiveRequestMtx; - std::vector> m_ActiveRequests; - -public: - static WinHttpHandleContainer &Instance() - { - static WinHttpHandleContainer the_instance; - return the_instance; - } - typedef std::shared_ptr t_shared; - void UnRegister(T *val) + typename std::vector::iterator findit; + bool found = false; + std::lock_guard lck(m_ActiveRequestMtx); + + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)val); + + for (auto it = m_ActiveRequests.begin(); it != m_ActiveRequests.end(); ++it) { - typename std::vector::iterator findit; - bool found = false; - std::lock_guard lck(m_ActiveRequestMtx); - - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)val); - - for (auto it = m_ActiveRequests.begin(); it != m_ActiveRequests.end(); ++it) + auto v = (*it); + if (v.get() == val) { - auto v = (*it); - if (v.get() == val) - { - findit = it; - found = true; - break; - } + findit = it; + found = true; + break; } - if (found) - m_ActiveRequests.erase(findit); } + if (found) + m_ActiveRequests.erase(findit); +} - bool IsRegistered(T *val) - { - bool found = false; - std::lock_guard lck(m_ActiveRequestMtx); - - for (auto it = m_ActiveRequests.begin(); it != m_ActiveRequests.end(); ++it) - { - auto v = (*it); - if (v.get() == val) - { - found = true; - break; - } - } - TRACE("%-35s:%-8d:%-16p found:%d\n", __func__, __LINE__, (void*)val, found); - - return found; - } - - void Register(std::shared_ptr rqst) - { - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)(rqst.get())); - std::lock_guard lck(m_ActiveRequestMtx); - m_ActiveRequests.push_back(rqst); - } -}; - -class WinHttpSessionImp :public WinHttpBase +template +bool WinHttpHandleContainer::IsRegistered(T *val) { - std::string m_ServerName; - std::string m_Referrer; - std::vector m_Proxies; - std::string m_Proxy; - WINHTTP_STATUS_CALLBACK m_InternetCallback = NULL; - DWORD m_NotificationFlags = 0; + bool found = false; + std::lock_guard lck(m_ActiveRequestMtx); - int m_ServerPort = 0; - long m_Timeout = 15000; - BOOL m_Async = false; - - bool m_closing = false; - DWORD m_MaxConnections = 0; - DWORD m_SecureProtocol = 0; - void *m_UserBuffer = NULL; - -public: - - BOOL SetUserData(void **data) + for (auto it = m_ActiveRequests.begin(); it != m_ActiveRequests.end(); ++it) { - if (!data) - return FALSE; - - m_UserBuffer = *data; - return TRUE; + auto v = (*it); + if (v.get() == val) + { + found = true; + break; + } } - void *GetUserData() { return m_UserBuffer; } + TRACE("%-35s:%-8d:%-16p found:%d\n", __func__, __LINE__, (void*)val, found); - BOOL SetSecureProtocol(DWORD *data) - { - if (!data) - return FALSE; + return found; +} - m_SecureProtocol = *data; - return TRUE; - } - DWORD GetSecureProtocol() const { return m_SecureProtocol; } +template +void WinHttpHandleContainer::Register(std::shared_ptr rqst) +{ + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)(rqst.get())); + std::lock_guard lck(m_ActiveRequestMtx); + m_ActiveRequests.push_back(rqst); +} - BOOL SetMaxConnections(DWORD *data) - { - if (!data) - return FALSE; +long WinHttpSessionImp::GetTimeout() const { + if (m_Timeout) + return m_Timeout; - m_MaxConnections = *data; - return TRUE; - } - DWORD GetMaxConnections() const { return m_MaxConnections; } + long curl_timeo; - void SetAsync() { m_Async = TRUE; } - BOOL GetAsync() const { return m_Async; } + curl_timeo = ComContainer::GetInstance().GetTimeout(); - BOOL GetThreadClosing() const { return m_closing; } + if (curl_timeo < 0) + curl_timeo = 10000; - std::vector &GetProxies() { return m_Proxies; } - void SetProxies(std::vector &proxies) { m_Proxies = proxies; } + return curl_timeo; +} - std::string &GetProxy() { return m_Proxy; } - - int GetServerPort() const { return m_ServerPort; } - void SetServerPort(int port) { m_ServerPort = port; } - - long GetTimeout() const { - if (m_Timeout) - return m_Timeout; - - long curl_timeo; - - curl_timeo = ComContainer::GetInstance().GetTimeout(); - - if (curl_timeo < 0) - curl_timeo = 10000; - - return curl_timeo; - } - void SetTimeout(long timeout) { m_Timeout = timeout; } - - std::string &GetServerName() { return m_ServerName; } - - std::string &GetReferrer() { return m_Referrer; } - - void SetCallback(WINHTTP_STATUS_CALLBACK lpfnInternetCallback, DWORD dwNotificationFlags) { - m_InternetCallback = lpfnInternetCallback; - m_NotificationFlags = dwNotificationFlags; - } - WINHTTP_STATUS_CALLBACK GetCallback(DWORD *dwNotificationFlags) - { - if (dwNotificationFlags) - *dwNotificationFlags = m_NotificationFlags; - return m_InternetCallback; - } - - ~WinHttpSessionImp() - { - TRACE("%-35s:%-8d:%-16p sesion\n", __func__, __LINE__, (void*)this); - SetCallback(NULL, 0); - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); - } -}; +WinHttpSessionImp::~WinHttpSessionImp() +{ + TRACE("%-35s:%-8d:%-16p sesion\n", __func__, __LINE__, (void*)this); + SetCallback(NULL, 0); + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)this); +} THREADRETURN ComContainer::AsyncThreadFunction(LPVOID lpThreadParameter) { @@ -1348,6 +840,7 @@ THREADRETURN ComContainer::AsyncThreadFunction(LPVOID lpThreadParameter) { TRACE("%-35s:%-8d:%-16p consumed length:%lu\n", __func__, __LINE__, (void*)srequest.get(), totalread); } + request->FlushIncoming(srequest); } else if (m->data.result == CURLE_OPERATION_TIMEDOUT) { @@ -1406,116 +899,126 @@ THREADRETURN ComContainer::AsyncThreadFunction(LPVOID lpThreadParameter) return 0; } -THREADRETURN UserCallbackContainer::UserCallbackThreadFunction(LPVOID lpThreadParameter) +int ComContainer::QueryData(int *still_running) { - UserCallbackContainer *cbContainer = static_cast(lpThreadParameter); + if (!still_running) + return 0; - while (true) - { - { - std::unique_lock GetEventHndlMtx(cbContainer->m_hEventMtx); - while (!cbContainer->m_EventCounter) - cbContainer->m_hEvent.wait(GetEventHndlMtx); - GetEventHndlMtx.unlock(); + int rc = 0; + struct timeval timeout; - cbContainer->m_EventCounter--; - } + fd_set fdread; + fd_set fdwrite; + fd_set fdexcep; + int maxfd = -1; - if (cbContainer->GetClosing()) - { - TRACE("%s", "exiting\n"); - break; - } + long curl_timeo = -1; - while (true) - { - cbContainer->m_MapMutex.lock(); - UserCallbackContext *ctx = NULL; - ctx = cbContainer->GetNext(); - cbContainer->m_MapMutex.unlock(); - if (!ctx) - break; + FD_ZERO(&fdread); + FD_ZERO(&fdwrite); + FD_ZERO(&fdexcep); - void *statusInformation = NULL; + /* set a suitable timeout to play around with */ + timeout.tv_sec = 1; + timeout.tv_usec = 0; - if (ctx->GetStatusInformationValid()) - statusInformation = ctx->GetStatusInformation(); + m_MultiMutex.lock(); + curl_multi_timeout(m_curlm, &curl_timeo); + m_MultiMutex.unlock(); + if (curl_timeo < 0) + /* no set timeout, use a default */ + curl_timeo = 10000; - if (!ctx->GetRequestRef()->GetClosed() && ctx->GetCb()) { - TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p ctx->GetUserdata() = %p dwInternetStatus:0x%lx statusInformation=%p refcount:%lu\n", - __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), - ctx->GetUserdata(), ctx->GetInternetStatus(), statusInformation, ctx->GetRequestRef().use_count()); - ctx->GetCb()(ctx->GetRequest(), - (DWORD_PTR)(ctx->GetUserdata()), - ctx->GetInternetStatus(), - statusInformation, - ctx->GetStatusInformationLength()); - - TRACE_VERBOSE("%-35s:%-8d:%-16p ctx = %p cb = %p ctx->GetUserdata() = %p dwInternetStatus:0x%lx statusInformation=%p refcount:%lu\n", - __func__, __LINE__, (void*)ctx->GetRequest(), reinterpret_cast(ctx), reinterpret_cast(ctx->GetCb()), - ctx->GetUserdata(), ctx->GetInternetStatus(), statusInformation, ctx->GetRequestRef().use_count()); - } - ctx->GetRequestCompletionCb()(ctx->GetRequestRef(), ctx->GetInternetStatus()); - delete ctx; - } + if (curl_timeo > 0) { + timeout.tv_sec = curl_timeo / 1000; + if (timeout.tv_sec > 1) + timeout.tv_sec = 1; + else + timeout.tv_usec = (curl_timeo % 1000) * 1000; } + /* get file descriptors from the transfers */ + m_MultiMutex.lock(); + rc = curl_multi_fdset(m_curlm, &fdread, &fdwrite, &fdexcep, &maxfd); + m_MultiMutex.unlock(); + + if ((maxfd == -1) || (rc != CURLM_OK)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + else + rc = select(maxfd + 1, &fdread, &fdwrite, &fdexcep, &timeout); + + switch (rc) { + case -1: + /* select error */ + *still_running = 0; + TRACE("%s\n", "select() returns error, this is badness\n"); + break; + case 0: + default: + /* timeout or readable/writable sockets */ + m_MultiMutex.lock(); + curl_multi_perform(m_curlm, still_running); + m_MultiMutex.unlock(); + break; + } + + return rc; +} + +THREADRETURN WinHttpRequestImp::UploadThreadFunction(THREADPARAM lpThreadParameter) +{ + WinHttpRequestImp *request = static_cast(lpThreadParameter); + if (!request) + return NULL; + + CURL *curl = request->GetCurl(); + CURLcode res; + + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)request); + /* Perform the request, res will get the return code */ + res = curl_easy_perform(curl); + /* Check for errors */ + CURL_BAILOUT_ONERROR(res, request, NULL); + + TRACE("%-35s:%-8d:%-16p res:%d\n", __func__, __LINE__, (void*)request, res); + request->GetUploadThreadExitStatus() = true; + #if OPENSSL_VERSION_NUMBER >= 0x10000000L && OPENSSL_VERSION_NUMBER < 0x10100000L ERR_remove_thread_state(NULL); #elif OPENSSL_VERSION_NUMBER < 0x10000000L ERR_remove_state(0); #endif - return 0; } -UserCallbackContext::UserCallbackContext(std::shared_ptr &request, - DWORD dwInternetStatus, - DWORD dwStatusInformationLength, - DWORD dwNotificationFlags, - WINHTTP_STATUS_CALLBACK cb, - LPVOID userdata, - LPVOID statusInformation, - DWORD statusInformationCopySize, bool allocate, - CompletionCb completion) - : m_request(request), m_dwInternetStatus(dwInternetStatus), - m_dwStatusInformationLength(dwStatusInformationLength), - m_dwNotificationFlags(dwNotificationFlags), m_cb(cb), m_userdata(userdata), - m_requestCompletionCb(completion) +BOOL WinHttpRequestImp::SetProxy(std::vector &proxies) { - if (statusInformation) - SetAsyncResult(statusInformation, statusInformationCopySize, allocate); + std::vector::iterator it; + + for (it = proxies.begin(); it != proxies.end(); it++) + { + std::string &urlstr = (*it); + CURLcode res; + + res = curl_easy_setopt(GetCurl(), CURLOPT_PROXY, urlstr.c_str()); + CURL_BAILOUT_ONERROR(res, this, FALSE); + } + + return TRUE; } -WinHttpSessionImp *GetImp(WinHttpBase *base) +BOOL WinHttpRequestImp::SetServer(std::string &ServerName, int nServerPort) { - WinHttpConnectImp *connect; - WinHttpSessionImp *session; - WinHttpRequestImp *request; + CURLcode res; - if ((connect = dynamic_cast(base))) - { - session = connect->GetHandle(); - } - else if ((request = dynamic_cast(base))) - { - connect = request->GetSession(); - if (!connect) - return NULL; - session = connect->GetHandle(); - } - else - { - session = dynamic_cast(base); - } + res = curl_easy_setopt(GetCurl(), CURLOPT_URL, ServerName.c_str()); + CURL_BAILOUT_ONERROR(res, this, FALSE); - return session; -} + res = curl_easy_setopt(GetCurl(), CURLOPT_PORT, nServerPort); + CURL_BAILOUT_ONERROR(res, this, FALSE); - -UserCallbackContext::~UserCallbackContext() -{ - delete [] m_StatusInformation; + return TRUE; } void WinHttpRequestImp::HandleReceiveNotifications(std::shared_ptr &srequest) @@ -1586,7 +1089,6 @@ size_t WinHttpRequestImp::WriteHeaderFunction(void *ptr, size_t size, size_t nme TRACE("%-35s:%-8d:%-16p %zu\n", __func__, __LINE__, (void*)request, size * nmemb); { std::lock_guard lck(request->GetHeaderStringMutex()); - request->GetHeaderString().append(static_cast(ptr), size * nmemb); if (request->GetHeaderString().find("\r\n\r\n") != std::string::npos) @@ -1624,38 +1126,6 @@ size_t WinHttpRequestImp::WriteHeaderFunction(void *ptr, size_t size, size_t nme return size * nmemb; } -void WinHttpRequestImp::ConsumeIncoming(std::shared_ptr &srequest, void* &ptr, size_t &available, size_t &read) -{ - while (available) - { - BufferRequest buf = GetBufferRequest(srequest->GetOutstandingReads()); - if (!buf.m_Length) - break; - - size_t len = MIN(buf.m_Length, available); - - if (len) - { - TRACE("%-35s:%-8d:%-16p reading length:%lu written:%lu %p %ld\n", - __func__, __LINE__, (void*)srequest.get(), len, read, buf.m_Buffer, buf.m_Length); - memcpy(static_cast(buf.m_Buffer), ptr, len); - } - - ptr = static_cast(ptr) + len; - available -= len; - - if (len) - { - LPVOID StatusInformation = buf.m_Buffer; - - TRACE("%-35s:%-8d:%-16p WINHTTP_CALLBACK_STATUS_READ_COMPLETE lpBuffer: %p length:%lu\n", __func__, __LINE__, (void*)srequest.get(), buf.m_Buffer, len); - srequest->AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_READ_COMPLETE, len, StatusInformation, sizeof(StatusInformation), false); - } - - read += len; - } -} - size_t WinHttpRequestImp::WriteBodyFunction(void *ptr, size_t size, size_t nmemb, void* rqst) { size_t read = 0; WinHttpRequestImp *request = static_cast(rqst); @@ -1707,6 +1177,53 @@ size_t WinHttpRequestImp::WriteBodyFunction(void *ptr, size_t size, size_t nmemb return read; } +void WinHttpRequestImp::ConsumeIncoming(std::shared_ptr &srequest, void* &ptr, size_t &available, size_t &read) +{ + while (available) + { + BufferRequest buf = GetBufferRequest(srequest->GetOutstandingReads()); + if (!buf.m_Length) + break; + + size_t len = MIN(buf.m_Length, available); + + if (len) + { + TRACE("%-35s:%-8d:%-16p reading length:%lu written:%lu %p %ld\n", + __func__, __LINE__, (void*)srequest.get(), len, read, buf.m_Buffer, buf.m_Length); + memcpy(static_cast(buf.m_Buffer), ptr, len); + } + + ptr = static_cast(ptr) + len; + available -= len; + + if (len) + { + LPVOID StatusInformation = buf.m_Buffer; + + TRACE("%-35s:%-8d:%-16p WINHTTP_CALLBACK_STATUS_READ_COMPLETE lpBuffer: %p length:%lu\n", __func__, __LINE__, (void*)srequest.get(), buf.m_Buffer, len); + srequest->AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_READ_COMPLETE, len, StatusInformation, sizeof(StatusInformation), false); + } + + read += len; + } +} + +void WinHttpRequestImp::FlushIncoming(std::shared_ptr &srequest) +{ + while (1) + { + BufferRequest buf = GetBufferRequest(srequest->GetOutstandingReads()); + if (!buf.m_Length) + break; + + LPVOID StatusInformation = buf.m_Buffer; + + TRACE("%-35s:%-8d:%-16p WINHTTP_CALLBACK_STATUS_READ_COMPLETE lpBuffer: %p length:%d\n", __func__, __LINE__, (void*)srequest.get(), buf.m_Buffer, 0); + srequest->AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_READ_COMPLETE, 0, StatusInformation, sizeof(StatusInformation), false); + } +} + void WinHttpRequestImp::CleanUp() { m_CompletionCode = CURLE_OK; @@ -1844,7 +1361,6 @@ size_t WinHttpRequestImp::ReadCallback(void *ptr, size_t size, size_t nmemb, voi { std::lock_guard lck(request->GetReadDataEventMtx()); if (!request->GetOutstandingWrites().empty()) - { BufferRequest &buf = PeekBufferRequest(request->GetOutstandingWrites()); len = MIN(buf.m_Length - buf.m_Used, size * nmemb); @@ -1935,6 +1451,62 @@ int WinHttpRequestImp::SocketCallback(CURL *handle, curl_infotype type, return 0; } +bool WinHttpRequestImp::HandleQueryDataNotifications(std::shared_ptr &srequest, size_t available) +{ + bool expected = true; + bool result = std::atomic_compare_exchange_strong(&GetQueryDataPending(), &expected, false); + if (result) + { + if (!available) + { + size_t length; + + GetBodyStringMutex().lock(); + length = GetResponseString().size(); + available = length; + GetBodyStringMutex().unlock(); + } + + DWORD lpvStatusInformation = static_cast(available); + TRACE("%-35s:%-8d:%-16p WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE:%lu\n", __func__, __LINE__, (void*)this, lpvStatusInformation); + AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE, sizeof(lpvStatusInformation), + (LPVOID)&lpvStatusInformation, sizeof(lpvStatusInformation), true); + } + return result; +} + +void WinHttpRequestImp::WaitAsyncQueryDataCompletion(std::shared_ptr &srequest) +{ + bool completed; + GetQueryDataPending() = true; + TRACE("%-35s:%-8d:%-16p GetQueryDataPending() = %d\n", __func__, __LINE__, (void*)this, (int)GetQueryDataPending()); + { + std::lock_guard lck(GetQueryDataEventMtx()); + completed = GetQueryDataEventState(); + } + + if (completed) { + TRACE("%-35s:%-8d:%-16p transfer already finished\n", __func__, __LINE__, (void*)this); + HandleQueryDataNotifications(srequest, 0); + } +} + +void WinHttpRequestImp::WaitAsyncReceiveCompletion(std::shared_ptr &srequest) +{ + bool expected = false; + std::atomic_compare_exchange_strong(&GetReceiveResponsePending(), &expected, true); + + TRACE("%-35s:%-8d:%-16p GetReceiveResponsePending() = %d\n", __func__, __LINE__, (void*)this, (int) GetReceiveResponsePending()); + { + std::lock_guard lck(GetReceiveCompletionEventMtx()); + if (ResponseCallbackEventCounter()) + { + TRACE("%-35s:%-8d:%-16p HandleReceiveNotifications \n", __func__, __LINE__, (void*)this); + HandleReceiveNotifications(srequest); + } + } +} + WINHTTPAPI HINTERNET WINAPI WinHttpOpen ( LPCTSTR pszAgentW, @@ -1959,7 +1531,12 @@ WINHTTPAPI HINTERNET WINAPI WinHttpOpen std::vector proxies = Split(session->GetProxy(), ';'); if (proxies.empty()) - proxies.push_back(std::string(pszProxyW)); + { + std::string sproxy; + + ConvertCstrAssign(pszProxyW, WCTLEN(pszProxyW), sproxy); + proxies.push_back(sproxy); + } session->SetProxies(proxies); } @@ -2217,7 +1794,6 @@ WINHTTPAPI HINTERNET WINAPI WinHttpOpenRequest( ver = CURL_HTTP_VERSION_2_0; else return NULL; - res = curl_easy_setopt(request->GetCurl(), CURLOPT_HTTP_VERSION, ver); CURL_BAILOUT_ONERROR(res, request, NULL); } @@ -2538,6 +2114,7 @@ BOOLAPI WinHttpSendRequest TRACE("%-35s:%-8d:%-16p \n", __func__, __LINE__, (void*)request); return FALSE; } + request->CleanUp(); request->AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_SENDING_REQUEST, 0, NULL, 0, false); @@ -2572,22 +2149,6 @@ BOOLAPI WinHttpSendRequest return TRUE; } -void WinHttpRequestImp::WaitAsyncReceiveCompletion(std::shared_ptr &srequest) -{ - bool expected = false; - std::atomic_compare_exchange_strong(&GetReceiveResponsePending(), &expected, true); - - TRACE("%-35s:%-8d:%-16p GetReceiveResponsePending() = %d\n", __func__, __LINE__, (void*)this, (int) GetReceiveResponsePending()); - { - std::lock_guard lck(GetReceiveCompletionEventMtx()); - if (ResponseCallbackEventCounter()) - { - TRACE("%-35s:%-8d:%-16p HandleReceiveNotifications \n", __func__, __LINE__, (void*)this); - HandleReceiveNotifications(srequest); - } - } -} - WINHTTPAPI BOOL WINAPI @@ -2654,46 +2215,6 @@ WinHttpReceiveResponse return TRUE; } -bool WinHttpRequestImp::HandleQueryDataNotifications(std::shared_ptr &srequest, size_t available) -{ - bool expected = true; - bool result = std::atomic_compare_exchange_strong(&GetQueryDataPending(), &expected, false); - if (result) - { - if (!available) - { - size_t length; - - GetBodyStringMutex().lock(); - length = GetResponseString().size(); - available = length; - GetBodyStringMutex().unlock(); - } - - DWORD lpvStatusInformation = static_cast(available); - TRACE("%-35s:%-8d:%-16p WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE:%lu\n", __func__, __LINE__, (void*)this, lpvStatusInformation); - AsyncQueue(srequest, WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE, sizeof(lpvStatusInformation), - (LPVOID)&lpvStatusInformation, sizeof(lpvStatusInformation), true); - } - return result; -} - -void WinHttpRequestImp::WaitAsyncQueryDataCompletion(std::shared_ptr &srequest) -{ - bool completed; - GetQueryDataPending() = true; - TRACE("%-35s:%-8d:%-16p GetQueryDataPending() = %d\n", __func__, __LINE__, (void*)this, (int)GetQueryDataPending()); - { - std::lock_guard lck(GetQueryDataEventMtx()); - completed = GetQueryDataEventState(); - } - - if (completed) { - TRACE("%-35s:%-8d:%-16p transfer already finished\n", __func__, __LINE__, (void*)this); - HandleQueryDataNotifications(srequest, 0); - } -} - BOOLAPI WinHttpQueryDataAvailable ( @@ -3062,7 +2583,6 @@ BOOLAPI WinHttpQueryHeaders( return TRUE; } - if (dwInfoLevel == WINHTTP_QUERY_RAW_HEADERS) { std::string header; @@ -3125,71 +2645,6 @@ BOOLAPI WinHttpQueryHeaders( return FALSE; } -static DWORD ConvertSecurityProtocol(DWORD offered) -{ - DWORD min = 0; - DWORD max = 0; - - if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL2) { - min = CURL_SSLVERSION_SSLv2; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL3) { - min = CURL_SSLVERSION_SSLv3; - } - else - min = CURL_SSLVERSION_DEFAULT; - - if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1) { - min = CURL_SSLVERSION_TLSv1_0; - - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_1) { - min = CURL_SSLVERSION_TLSv1_1; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2) { - min = CURL_SSLVERSION_TLSv1_2; - } - -#if (LIBCURL_VERSION_MAJOR>=7) && (LIBCURL_VERSION_MINOR >= 61) - if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2) { - max = CURL_SSLVERSION_MAX_TLSv1_2; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_1) { - max = CURL_SSLVERSION_MAX_TLSv1_1; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_TLS1) { - max = CURL_SSLVERSION_MAX_TLSv1_0; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL3) { - max = CURL_SSLVERSION_SSLv3; - } - else if (offered & WINHTTP_FLAG_SECURE_PROTOCOL_SSL2) { - max = CURL_SSLVERSION_SSLv2; - } - else - max = CURL_SSLVERSION_MAX_DEFAULT; -#endif - - return min | max; -} - - -template -static BOOL CallMemberFunction(WinHttpBase *base, std::function fn, LPVOID lpBuffer) -{ - T *obj; - - if (!lpBuffer) - return FALSE; - - if ((obj = dynamic_cast(base))) - { - if (fn(obj, static_cast(lpBuffer))) - return TRUE; - } - return FALSE; -} - BOOLAPI WinHttpSetOption( HINTERNET hInternet, DWORD dwOption, @@ -3291,17 +2746,30 @@ WinHttpSetStatusCallback DWORD_PTR dwReserved ) { - WinHttpSessionImp *session = static_cast(hInternet); - WINHTTP_STATUS_CALLBACK oldcb; - DWORD olddwNotificationFlags; - - TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)session); - - if (hInternet == NULL) + if (hInternet == NULL) return WINHTTP_INVALID_STATUS_CALLBACK; - oldcb = session->GetCallback(&olddwNotificationFlags); - session->SetCallback(lpfnInternetCallback, dwNotificationFlags); + WinHttpBase *base = static_cast(hInternet); + WinHttpSessionImp *session; + WinHttpRequestImp *request; + + WINHTTP_STATUS_CALLBACK oldcb; + DWORD olddwNotificationFlags; + TRACE("%-35s:%-8d:%-16p\n", __func__, __LINE__, (void*)hInternet); + + if ((session = dynamic_cast(base))) + { + oldcb = session->GetCallback(&olddwNotificationFlags); + session->SetCallback(lpfnInternetCallback, dwNotificationFlags); + } + else if ((request = dynamic_cast(base))) + { + oldcb = request->GetCallback(&olddwNotificationFlags); + request->SetCallback(lpfnInternetCallback, dwNotificationFlags); + } + else + return FALSE; + return oldcb; } @@ -3688,4 +3156,3 @@ BOOLAPI WinHttpWriteData return TRUE; } - diff --git a/src/winhttppal_imp.h b/src/winhttppal_imp.h new file mode 100644 index 0000000..5c075ea --- /dev/null +++ b/src/winhttppal_imp.h @@ -0,0 +1,613 @@ +/*** + * Copyright (C) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE.txt file in the project root for full license information. + * + * =+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+=+ + * + * HTTP Library: Client-side APIs. + * + * This file contains WinHttp client implementation using CURL. + * + * For the latest on this and related APIs, please see: https://github.com/microsoft/WinHttpPAL + * + * =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- + ****/ + +extern "C" +{ +#include +#include +} + +#ifdef WIN32 +#define THREAD_ID GetCurrentThreadId() +#define THREADPARAM LPVOID +#define CREATETHREAD(func, param, id) \ + CreateThread( \ + NULL, \ + 0, \ + (LPTHREAD_START_ROUTINE)func, \ + param, \ + 0, \ + id) + +#define THREADJOIN(h) WaitForSingleObject(h, INFINITE); +#define THREADRETURN DWORD +#define THREAD_HANDLE HANDLE +#else +#define THREADPARAM void* +#define THREADRETURN void* +#define THREAD_ID pthread_self() +#define THREADJOIN(x) pthread_join(x, NULL) +#define THREAD_HANDLE pthread_t + +typedef void* (*LPTHREAD_START_ROUTINE)(void *); +static inline THREAD_HANDLE CREATETHREAD(LPTHREAD_START_ROUTINE func, LPVOID param, pthread_t *id) +{ + pthread_t inc_x_thread; + pthread_create(&inc_x_thread, NULL, func, param); + return inc_x_thread; +} +#endif + +class WinHttpSessionImp; + +class WinHttpBase +{ +public: + virtual ~WinHttpBase() {} +}; + +class WinHttpSessionImp :public WinHttpBase +{ + std::string m_ServerName; + std::string m_Referrer; + std::vector m_Proxies; + std::string m_Proxy; + WINHTTP_STATUS_CALLBACK m_InternetCallback = NULL; + DWORD m_NotificationFlags = 0; + + int m_ServerPort = 0; + long m_Timeout = 15000; + BOOL m_Async = false; + + bool m_closing = false; + DWORD m_MaxConnections = 0; + DWORD m_SecureProtocol = 0; + void *m_UserBuffer = NULL; + +public: + + BOOL SetUserData(void **data) + { + if (!data) + return FALSE; + + m_UserBuffer = *data; + return TRUE; + } + void *GetUserData() { return m_UserBuffer; } + + BOOL SetSecureProtocol(DWORD *data) + { + if (!data) + return FALSE; + + m_SecureProtocol = *data; + return TRUE; + } + DWORD GetSecureProtocol() const { return m_SecureProtocol; } + + BOOL SetMaxConnections(DWORD *data) + { + if (!data) + return FALSE; + + m_MaxConnections = *data; + return TRUE; + } + DWORD GetMaxConnections() const { return m_MaxConnections; } + + void SetAsync() { m_Async = TRUE; } + BOOL GetAsync() const { return m_Async; } + + BOOL GetThreadClosing() const { return m_closing; } + + std::vector &GetProxies() { return m_Proxies; } + void SetProxies(std::vector &proxies) { m_Proxies = proxies; } + + std::string &GetProxy() { return m_Proxy; } + + int GetServerPort() const { return m_ServerPort; } + void SetServerPort(int port) { m_ServerPort = port; } + + long GetTimeout() const; + void SetTimeout(long timeout) { m_Timeout = timeout; } + + std::string &GetServerName() { return m_ServerName; } + + std::string &GetReferrer() { return m_Referrer; } + + void SetCallback(WINHTTP_STATUS_CALLBACK lpfnInternetCallback, DWORD dwNotificationFlags) { + m_InternetCallback = lpfnInternetCallback; + m_NotificationFlags = dwNotificationFlags; + } + WINHTTP_STATUS_CALLBACK GetCallback(DWORD *dwNotificationFlags) + { + if (dwNotificationFlags) + *dwNotificationFlags = m_NotificationFlags; + return m_InternetCallback; + } + + ~WinHttpSessionImp(); +}; + +class WinHttpConnectImp :public WinHttpBase +{ + WinHttpSessionImp *m_Handle = NULL; + void *m_UserBuffer = NULL; + +public: + void SetHandle(WinHttpSessionImp *session) { m_Handle = session; } + WinHttpSessionImp *GetHandle() { return m_Handle; } + + BOOL SetUserData(void **data) + { + if (!data) + return FALSE; + + m_UserBuffer = *data; + return TRUE; + } + void *GetUserData() { return m_UserBuffer; } +}; + +struct BufferRequest +{ + LPVOID m_Buffer = NULL; + size_t m_Length = 0; + size_t m_Used = 0; +}; + +class WinHttpRequestImp :public WinHttpBase, public std::enable_shared_from_this +{ + CURL *m_curl = NULL; + std::vector m_ResponseString; + std::string m_HeaderString; + std::string m_Header; + std::string m_FullPath; + std::string m_OptionalData; + size_t m_TotalSize = 0; + size_t m_TotalReceiveSize = 0; + std::vector m_ReadData; + + std::mutex m_ReadDataEventMtx; + DWORD m_ReadDataEventCounter = 0; + + THREAD_HANDLE m_UploadCallbackThread; + bool m_UploadThreadExitStatus = false; + + DWORD m_SecureProtocol = 0; + DWORD m_MaxConnections = 0; + + std::string m_Type; + LPVOID m_UserBuffer = NULL; + bool m_HeaderReceiveComplete = false; + + struct curl_slist *m_HeaderList = NULL; + WinHttpConnectImp *m_Session = NULL; + std::mutex m_HeaderStringMutex; + std::mutex m_BodyStringMutex; + + std::mutex m_QueryDataEventMtx; + bool m_QueryDataEventState = false; + std::atomic m_QueryDataPending; + + std::mutex m_ReceiveCompletionEventMtx; + + std::mutex m_ReceiveResponseMutex; + std::atomic m_ReceiveResponsePending; + std::atomic m_ReceiveResponseEventCounter; + int m_ReceiveResponseSendCounter = 0; + + std::atomic m_RedirectPending; + + bool m_closing = false; + bool m_closed = false; + bool m_Completion = false; + bool m_Async = false; + CURLcode m_CompletionCode = CURLE_OK; + long m_VerifyPeer = 1; + long m_VerifyHost = 2; + bool m_Uploading = false; + + WINHTTP_STATUS_CALLBACK m_InternetCallback = NULL; + DWORD m_NotificationFlags = 0; + std::vector m_OutstandingWrites; + std::vector m_OutstandingReads; + bool m_Secure = false; + +public: + bool &GetSecure() { return m_Secure; } + std::vector &GetOutstandingWrites() { return m_OutstandingWrites; } + std::vector &GetOutstandingReads() { return m_OutstandingReads; } + + long &VerifyPeer() { return m_VerifyPeer; } + long &VerifyHost() { return m_VerifyHost; } + + static size_t WriteHeaderFunction(void *ptr, size_t size, size_t nmemb, void* rqst); + static size_t WriteBodyFunction(void *ptr, size_t size, size_t nmemb, void* rqst); + void ConsumeIncoming(std::shared_ptr &srequest, void* &ptr, size_t &available, size_t &read); + void FlushIncoming(std::shared_ptr &srequest); + void SetCallback(WINHTTP_STATUS_CALLBACK lpfnInternetCallback, DWORD dwNotificationFlags) { + m_InternetCallback = lpfnInternetCallback; + m_NotificationFlags = dwNotificationFlags; + } + WINHTTP_STATUS_CALLBACK GetCallback(DWORD *dwNotificationFlags) + { + if (dwNotificationFlags) + *dwNotificationFlags = m_NotificationFlags; + return m_InternetCallback; + } + + void SetAsync() { m_Async = TRUE; } + BOOL GetAsync() { return m_Async; } + + WinHttpRequestImp(); + + bool &GetQueryDataEventState() { return m_QueryDataEventState; } + std::mutex &GetQueryDataEventMtx() { return m_QueryDataEventMtx; } + + bool HandleQueryDataNotifications(std::shared_ptr &, size_t available); + void WaitAsyncQueryDataCompletion(std::shared_ptr &); + + void HandleReceiveNotifications(std::shared_ptr &srequest); + void WaitAsyncReceiveCompletion(std::shared_ptr &srequest); + + CURLcode &GetCompletionCode() { return m_CompletionCode; } + bool &GetCompletionStatus() { return m_Completion; } + bool &GetClosing() { return m_closing; } + bool &GetClosed() { return m_closed; } + void CleanUp(); + ~WinHttpRequestImp(); + + bool &GetUploadThreadExitStatus() { return m_UploadThreadExitStatus; } + std::atomic &GetQueryDataPending() { return m_QueryDataPending; } + + THREAD_HANDLE &GetUploadThread() { + return m_UploadCallbackThread; + } + + static THREADRETURN UploadThreadFunction(THREADPARAM lpThreadParameter); + + // used to wake up CURL ReadCallback triggered on a upload request + std::mutex &GetReadDataEventMtx() { return m_ReadDataEventMtx; } + DWORD &GetReadDataEventCounter() { return m_ReadDataEventCounter; } + + std::mutex &GetReceiveCompletionEventMtx() { return m_ReceiveCompletionEventMtx; } + + // counters used to see which one of these events are observed: ResponseCallbackEventCounter + // counters used to see which one of these events are broadcasted: ResponseCallbackSendCounter + // WINHTTP_CALLBACK_STATUS_RECEIVING_RESPONSE + // WINHTTP_CALLBACK_STATUS_RESPONSE_RECEIVED + // WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE + // + // These counters need to be 3 and identical under normal circumstances + std::atomic &ResponseCallbackEventCounter() { return m_ReceiveResponseEventCounter; } + int &ResponseCallbackSendCounter() { return m_ReceiveResponseSendCounter; } + std::atomic &GetReceiveResponsePending() { return m_ReceiveResponsePending; } + std::mutex &GetReceiveResponseMutex() { return m_ReceiveResponseMutex; } + + std::atomic &GetRedirectPending() { return m_RedirectPending; } + + std::mutex &GetHeaderStringMutex() { return m_HeaderStringMutex; } + std::mutex &GetBodyStringMutex() { return m_HeaderStringMutex; } + + BOOL AsyncQueue(std::shared_ptr &, + DWORD dwInternetStatus, size_t statusInformationLength, + LPVOID statusInformation, DWORD statusInformationCopySize, bool allocate); + + void SetHeaderReceiveComplete() { m_HeaderReceiveComplete = TRUE; } + + BOOL SetSecureProtocol(DWORD *data) + { + if (!data) + return FALSE; + + m_SecureProtocol = *data; + return TRUE; + } + DWORD GetSecureProtocol() { return m_SecureProtocol; } + + BOOL SetMaxConnections(DWORD *data) + { + if (!data) + return FALSE; + + m_MaxConnections = *data; + return TRUE; + } + DWORD GetMaxConnections() { return m_MaxConnections; } + + BOOL SetUserData(void **data) + { + if (!data) + return FALSE; + + m_UserBuffer = *data; + return TRUE; + } + LPVOID GetUserData() { return m_UserBuffer; } + + void SetFullPath(std::string &server, std::string &relative) + { + m_FullPath.append(server); + m_FullPath.append(relative); + } + std::string &GetFullPath() { return m_FullPath; } + std::string &GetType() { return m_Type; } + bool &Uploading() { return m_Uploading; } + + void SetSession(WinHttpConnectImp *connect) { m_Session = connect; } + WinHttpConnectImp *GetSession() { return m_Session; } + + struct curl_slist *GetHeaderList() { return m_HeaderList; } + + void SetHeaderList(struct curl_slist *list) { m_HeaderList = list; } + std::string &GetOutgoingHeaderList() { return m_Header; } + + void AddHeader(std::string &headers) { + std::stringstream check1(headers); + std::string str; + + while(getline(check1, str, '\n')) + { + str.erase(std::remove(str.begin(), str.end(), '\r'), str.end()); + SetHeaderList(curl_slist_append(GetHeaderList(), str.c_str())); + } + } + + std::vector &GetResponseString() { return m_ResponseString; } + std::string &GetHeaderString() { return m_HeaderString; } + CURL *GetCurl() { return m_curl; } + + BOOL SetProxy(std::vector &proxies); + BOOL SetServer(std::string &ServerName, int nServerPort); + + size_t &GetTotalLength() { return m_TotalSize; } + size_t &GetReadLength() { return m_TotalReceiveSize; } + + std::string &GetOptionalData() { return m_OptionalData; } + BOOL SetOptionalData(void *lpOptional, size_t dwOptionalLength) + { + if (!lpOptional || !dwOptionalLength) + return FALSE; + + m_OptionalData.assign(&(static_cast(lpOptional))[0], dwOptionalLength); + return TRUE; + } + + std::vector &GetReadData() { return m_ReadData; } + + void AppendReadData(const void *data, size_t len) + { + std::lock_guard lck(GetReadDataEventMtx()); + m_ReadData.insert(m_ReadData.end(), static_cast(data), static_cast(data) + len); + } + + static int SocketCallback(CURL *handle, curl_infotype type, + char *data, size_t size, + void *userp); + + static size_t ReadCallback(void *ptr, size_t size, size_t nmemb, void *userp); +}; + +class UserCallbackContext +{ + typedef void (*CompletionCb)(std::shared_ptr &, DWORD status); + + std::shared_ptr m_request; + DWORD m_dwInternetStatus; + DWORD m_dwStatusInformationLength; + DWORD m_dwNotificationFlags; + WINHTTP_STATUS_CALLBACK m_cb; + LPVOID m_userdata; + LPVOID m_StatusInformationVal = NULL; + BYTE* m_StatusInformation = NULL; + bool m_allocate = FALSE; + BOOL m_AsyncResultValid = false; + CompletionCb m_requestCompletionCb; + + BOOL SetAsyncResult(LPVOID statusInformation, DWORD statusInformationCopySize, bool allocate) { + if (allocate) + { + m_StatusInformation = new BYTE[statusInformationCopySize]; + if (m_StatusInformation) + { + memcpy(m_StatusInformation, statusInformation, statusInformationCopySize); + } + else + return FALSE; + } + else + { + m_StatusInformationVal = statusInformation; + } + m_AsyncResultValid = true; + m_allocate = allocate; + + return TRUE; + } + +public: + UserCallbackContext(std::shared_ptr &request, + DWORD dwInternetStatus, + DWORD dwStatusInformationLength, + DWORD dwNotificationFlags, + WINHTTP_STATUS_CALLBACK cb, + LPVOID userdata, + LPVOID statusInformation, + DWORD statusInformationCopySize, bool allocate, + CompletionCb completion) + : m_request(request), m_dwInternetStatus(dwInternetStatus), + m_dwStatusInformationLength(dwStatusInformationLength), + m_dwNotificationFlags(dwNotificationFlags), m_cb(cb), m_userdata(userdata), + m_requestCompletionCb(completion) + { + if (statusInformation) + SetAsyncResult(statusInformation, statusInformationCopySize, allocate); + } + + ~UserCallbackContext() + { + delete [] m_StatusInformation; + } + + std::shared_ptr &GetRequestRef() { return m_request; } + WinHttpRequestImp *GetRequest() { return m_request.get(); } + DWORD GetInternetStatus() const { return m_dwInternetStatus; } + DWORD GetStatusInformationLength() const { return m_dwStatusInformationLength; } + DWORD GetNotificationFlags() const { return m_dwNotificationFlags; } + WINHTTP_STATUS_CALLBACK &GetCb() { return m_cb; } + CompletionCb &GetRequestCompletionCb() { return m_requestCompletionCb; } + LPVOID GetUserdata() { return m_userdata; } + LPVOID GetStatusInformation() { + if (m_allocate) + return m_StatusInformation; + else + return m_StatusInformationVal; + } + BOOL GetStatusInformationValid() const { return m_AsyncResultValid; } + +private: + UserCallbackContext(const UserCallbackContext&); + UserCallbackContext& operator=(const UserCallbackContext&); +}; + +template +class WinHttpHandleContainer +{ + std::mutex m_ActiveRequestMtx; + std::vector> m_ActiveRequests; + +public: + static WinHttpHandleContainer &Instance() + { + static WinHttpHandleContainer the_instance; + return the_instance; + } + + void UnRegister(T *val); + bool IsRegistered(T *val); + void Register(std::shared_ptr rqst); +}; + + +class UserCallbackContainer +{ + typedef std::vector UserCallbackQueue; + + UserCallbackQueue m_Queue; + + // used to protect GetCallbackMap() + std::mutex m_MapMutex; + + THREAD_HANDLE m_hThread; + + // used by UserCallbackThreadFunction as a wake-up signal + std::mutex m_hEventMtx; + std::condition_variable m_hEvent; + + // cleared by UserCallbackThreadFunction, set by multiple event providers + std::atomic m_EventCounter; + + bool m_closing = false; + UserCallbackQueue &GetCallbackQueue() { return m_Queue; } + + +public: + + bool GetClosing() const { return m_closing; } + + UserCallbackContext* GetNext(); + + static THREADRETURN UserCallbackThreadFunction(LPVOID lpThreadParameter); + + BOOL Queue(UserCallbackContext *ctx); + void DrainQueue(); + + UserCallbackContainer(): m_EventCounter(0) + { + DWORD thread_id; + m_hThread = CREATETHREAD( + UserCallbackThreadFunction, // thread function name + this, // argument to thread function + &thread_id + ); + } + + ~UserCallbackContainer() + { + m_closing = true; + { + std::lock_guard lck(m_hEventMtx); + m_EventCounter++; + m_hEvent.notify_all(); + } + THREADJOIN(m_hThread); + DrainQueue(); + } + + static UserCallbackContainer &GetInstance() + { + static UserCallbackContainer the_instance; + return the_instance; + } +private: + UserCallbackContainer(const UserCallbackContainer&); + UserCallbackContainer& operator=(const UserCallbackContainer&); +}; + +class EnvInit +{ +public: + EnvInit(); +}; + +class ComContainer +{ + THREAD_HANDLE m_hAsyncThread; + std::mutex m_hAsyncEventMtx; + + // used to wake up the Async Thread + // Set by external components, cleared by the Async thread + std::atomic m_hAsyncEventCounter; + std::condition_variable m_hAsyncEvent; + std::mutex m_MultiMutex; + CURLM *m_curlm = NULL; + std::vector< CURL *> m_ActiveCurl; + std::vector> m_ActiveRequests; + BOOL m_closing = FALSE; + + // used to protect CURLM data structures + BOOL GetThreadClosing() const { return m_closing; } + +public: + CURL *AllocCURL(); + void FreeCURL(CURL *ptr); + static ComContainer &GetInstance(); + void ResumeTransfer(CURL *handle, int bitmask); + BOOL AddHandle(std::shared_ptr &srequest, CURL *handle); + BOOL RemoveHandle(std::shared_ptr &srequest, CURL *handle, bool clearPrivate); + long GetTimeout(); + void KickStart(); + ComContainer(); + ~ComContainer(); + + static THREADRETURN AsyncThreadFunction(THREADPARAM lpThreadParameter); + + int QueryData(int *still_running); +private: + ComContainer(const ComContainer&); + ComContainer& operator=(const ComContainer&); +};