diff --git a/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.cpp b/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.cpp index 1b112c1..36a02b8 100644 --- a/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.cpp +++ b/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.cpp @@ -1,9 +1,17 @@ +#if _MSC_VER +// The use of std::getenv generates a warning on Windows. +// Since the aim is to have portable code, suppress the warnings. +#define _CRT_SECURE_NO_WARNINGS +#endif #include "SoftwareEntitlementClient.h" #include #include #include +#include +#include #include #include +#include #include #include #include @@ -34,7 +42,8 @@ #endif #ifdef _WIN32 -#include "Wincrypt.h" +#include +#include #endif @@ -282,6 +291,28 @@ class Curl char _errbuf[CURL_ERROR_SIZE]; std::string _response; +public: + class CurlException : public Exception + { + CURLcode _code; + + public: + CurlException(CURLcode code, const std::string& message) + : Exception(message) + , _code(code) + {} + + CurlException(CURLcode code, const char *message) + : Exception(message) + , _code(code) + {} + + CURLcode GetCode() const + { + return _code; + } + }; + private: void ThrowIfCurlError(CURLcode res) { @@ -292,7 +323,7 @@ private: std::ostringstream what; what << "libcurl_error " << res << ": " << _errbuf; - throw Exception(what.str()); + throw CurlException(res, what.str()); } static size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* context) @@ -454,6 +485,19 @@ public: #ifdef _WIN32 ThrowIfCurlError(curl_easy_setopt(_curl.get(), CURLOPT_SSL_CTX_FUNCTION, OpenSSLContextCallback)); #endif // _WIN32 + + // Allow overriding the default connection timeout of 300 seconds. + const char* env = std::getenv("AZ_BATCH_SES_CURLOPT_CONNECTTIMEOUT"); + long timeout = 0; + if (env != nullptr) + { + timeout = std::strtol(env, nullptr, 10); + } + if (timeout <= 0 || timeout == LONG_MAX) + { + timeout = 300L; + } + ThrowIfCurlError(curl_easy_setopt(_curl.get(), CURLOPT_CONNECTTIMEOUT, timeout)); } void Post( @@ -571,8 +615,99 @@ public: throw Exception(GetErrorMessage(code)); } + + static std::unique_ptr GetEntitlement( + const std::string& url, + const std::string& entitlement_token, + const std::string& requested_entitlement) + { + Curl curl; + curl.Post(url + "softwareEntitlements?api-version=2017-05-01.5.0", entitlement_token, requested_entitlement); + + curl.VerifyIntermediateCertificate(url); + + return curl.GetEntitlement(); + } }; +#ifdef _WIN32 +struct WinHttpDeleter +{ + void operator()(HINTERNET h) + { + WinHttpCloseHandle(h); + } +}; + +typedef std::unique_ptr WinHttpHandle; + +// +// OpenSSL does not hook into the Windows Automatic Root Certificates Update process. +// This results in certificate validation failures, so we perform a dummy connection +// using WinHTTP which will setup the root cert store correctly. +// +void EnsureRootCertsArePopulated(const std::string& url) +{ + std::string prefix("https://"); + size_t start = url.rfind(prefix, 0); + if (start == std::string::npos) + { + throw Exception("Malformed URL: " + url); + } + + start += prefix.length(); + size_t end = url.find("/", start); + std::string temp; + if (end == std::string::npos) + { + temp = url.substr(start); + } + else + { + temp = url.substr(start, end - start); + } + + std::wstring hostname(temp.cbegin(), temp.cend()); + + WinHttpHandle hSession( + WinHttpOpen( + L"Azure Batch Software Entitlement Service client", + WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY, + WINHTTP_NO_PROXY_NAME, + WINHTTP_NO_PROXY_BYPASS, + 0) + ); + ThrowIfWin32Error(hSession == nullptr); + + WinHttpHandle hConn( + WinHttpConnect(hSession.get(), hostname.c_str(), INTERNET_DEFAULT_HTTPS_PORT, 0) + ); + ThrowIfWin32Error(hConn == nullptr); + + WinHttpHandle hRequest( + WinHttpOpenRequest( + hConn.get(), + nullptr, + nullptr, + nullptr, + WINHTTP_NO_REFERER, + WINHTTP_DEFAULT_ACCEPT_TYPES, + WINHTTP_FLAG_SECURE) + ); + ThrowIfWin32Error(hRequest == nullptr); + + ThrowIfWin32Error(TRUE != WinHttpSendRequest( + hRequest.get(), + WINHTTP_NO_ADDITIONAL_HEADERS, + 0, + WINHTTP_NO_REQUEST_DATA, + 0, + 0, + 0) + ); +} +#endif + } // anonymous namespace @@ -628,7 +763,8 @@ void Cleanup() std::unique_ptr GetEntitlement( std::string url, const std::string& entitlement_token, - const std::string& requested_entitlement) + const std::string& requested_entitlement, + unsigned int retries) { std::lock_guard lock(s_lock); @@ -657,12 +793,32 @@ std::unique_ptr GetEntitlement( url += '/'; } - Curl curl; - curl.Post(url + "softwareEntitlements?api-version=2017-05-01.5.0", entitlement_token, requested_entitlement); + for (unsigned int retry = 1; retry <= retries; ++retry) + { + try + { + return Curl::GetEntitlement(url, entitlement_token, requested_entitlement); + } + catch (const Curl::CurlException& e) + { + if (e.GetCode() == CURLE_OPERATION_TIMEDOUT) + { + std::this_thread::sleep_for(std::chrono::seconds(retry)); + } +#ifdef _WIN32 + else if (e.GetCode() == CURLE_SSL_CACERT) + { + EnsureRootCertsArePopulated(url); + } +#endif + else + { + throw e; + } + } + } - curl.VerifyIntermediateCertificate(url); - - return curl.GetEntitlement(); + return Curl::GetEntitlement(url, entitlement_token, requested_entitlement); } diff --git a/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.h b/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.h index 539641b..e491823 100644 --- a/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.h +++ b/src/Microsoft.Azure.Batch.SoftwareEntitlement.Client.Native/SoftwareEntitlementClient.h @@ -54,7 +54,8 @@ public: std::unique_ptr GetEntitlement( std::string url, const std::string& entitlement_token, - const std::string& requested_entitlement + const std::string& requested_entitlement, + unsigned int retries = 5 ); diff --git a/src/sesclient.native/sesclient.native.vcxproj b/src/sesclient.native/sesclient.native.vcxproj index 412478e..a868acd 100644 --- a/src/sesclient.native/sesclient.native.vcxproj +++ b/src/sesclient.native/sesclient.native.vcxproj @@ -102,7 +102,7 @@ Console - crypt32.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) + winhttp.lib;crypt32.lib;%(AdditionalDependencies) @@ -116,7 +116,7 @@ Console - crypt32.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies) + winhttp.lib;crypt32.lib;%(AdditionalDependencies) @@ -134,7 +134,7 @@ Console true true - crypt32.lib;%(AdditionalDependencies) + winhttp.lib;crypt32.lib;%(AdditionalDependencies) @@ -152,7 +152,7 @@ Console true true - crypt32.lib;%(AdditionalDependencies) + winhttp.lib;crypt32.lib;%(AdditionalDependencies)