* Refactoring of the crypto code into a proper library with implementations hidden behind abstract interfaces.

Co-authored-by: Amaury Chamayou <amaury@xargs.fr>
This commit is contained in:
Christoph M. Wintersteiger 2021-02-25 17:34:07 +00:00 коммит произвёл GitHub
Родитель 05e79a5960
Коммит 57c2d2eb78
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
83 изменённых файлов: 4056 добавлений и 3229 удалений

Просмотреть файл

@ -1 +1 @@
Always choose the lesser of two weevils.
A daily, just for good measure.

Просмотреть файл

@ -121,7 +121,7 @@ if("virtual" IN_LIST COMPILE_TARGETS)
_LIBCPP_HAS_THREAD_API_PTHREAD
)
target_compile_options(ccf.virtual PUBLIC -stdlib=libc++)
target_compile_options(ccf.virtual PUBLIC ${COMPILE_LIBCXX})
add_warning_checks(ccf.virtual)
target_include_directories(
@ -147,7 +147,6 @@ if("virtual" IN_LIST COMPILE_TARGETS)
set_property(TARGET ccf.virtual PROPERTY POSITION_INDEPENDENT_CODE ON)
use_client_mbedtls(ccf.virtual)
add_san(ccf.virtual)
add_lvi_mitigations(ccf.virtual)
@ -252,7 +251,6 @@ if(BUILD_TESTS)
${CMAKE_CURRENT_SOURCE_DIR}/src/kv/test/kv_snapshot.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/kv/test/kv_dynamic_tables.cpp
)
use_client_mbedtls(kv_test)
target_link_libraries(
kv_test PRIVATE ${CMAKE_THREAD_LIBS_INIT} http_parser.host
)
@ -278,33 +276,31 @@ if(BUILD_TESTS)
raft_test ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/main.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/view_history.cpp
)
target_link_libraries(raft_test PRIVATE ${CRYPTO_LIBRARY})
target_link_libraries(raft_test PRIVATE ccfcrypto.host)
add_unit_test(
raft_enclave_test
${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/enclave.cpp
)
target_include_directories(raft_enclave_test PRIVATE ${CCFCRYPTO_INC})
target_link_libraries(raft_enclave_test PRIVATE ${CRYPTO_LIBRARY})
target_link_libraries(raft_enclave_test PRIVATE ccfcrypto.host)
add_unit_test(
crypto_test ${CMAKE_CURRENT_SOURCE_DIR}/src/crypto/test/crypto.cpp
)
target_include_directories(crypto_test PRIVATE ${CCFCRYPTO_INC})
target_link_libraries(crypto_test PRIVATE ${CRYPTO_LIBRARY})
target_link_libraries(crypto_test PRIVATE ccfcrypto.host)
add_unit_test(
history_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/history.cpp
)
target_link_libraries(
history_test PRIVATE ${CRYPTO_LIBRARY} http_parser.host
)
target_link_libraries(history_test PRIVATE ccfcrypto.host http_parser.host)
add_unit_test(
progress_tracker_test
${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/progress_tracker.cpp
)
target_link_libraries(progress_tracker_test PRIVATE ${CRYPTO_LIBRARY})
target_link_libraries(progress_tracker_test PRIVATE ccfcrypto.host)
add_unit_test(
secret_sharing_test
@ -314,10 +310,8 @@ if(BUILD_TESTS)
add_unit_test(
encryptor_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/encryptor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/crypto/symmetric_key.cpp
)
use_client_mbedtls(encryptor_test)
target_link_libraries(encryptor_test PRIVATE)
target_link_libraries(encryptor_test PRIVATE ccfcrypto.host)
add_unit_test(
historical_queries_test
@ -353,13 +347,11 @@ if(BUILD_TESTS)
key_exchange_test
${CMAKE_CURRENT_SOURCE_DIR}/src/tls/test/key_exchange.cpp
)
use_client_mbedtls(key_exchange_test)
target_link_libraries(key_exchange_test PRIVATE)
add_unit_test(
channels_test ${CMAKE_CURRENT_SOURCE_DIR}/src/node/test/channels.cpp
)
use_client_mbedtls(channels_test)
target_link_libraries(channels_test PRIVATE)
add_unit_test(
@ -430,10 +422,9 @@ if(BUILD_TESTS)
# Merkle Tree memory test
add_executable(merkle_mem src/node/test/merkle_mem.cpp)
target_link_libraries(
merkle_mem PRIVATE ccfcrypto.host ${CMAKE_THREAD_LIBS_INIT}
$<BUILD_INTERFACE:merklecpp> crypto
merkle_mem PRIVATE ${CMAKE_THREAD_LIBS_INIT} $<BUILD_INTERFACE:merklecpp>
crypto
)
use_client_mbedtls(merkle_mem)
# merklecpp tests
set(MERKLECPP_TEST_PREFIX "merklecpp-")
@ -444,7 +435,6 @@ if(BUILD_TESTS)
raft_driver ${CMAKE_CURRENT_SOURCE_DIR}/src/consensus/aft/test/driver.cpp
)
target_link_libraries(raft_driver PRIVATE ccfcrypto.host)
use_client_mbedtls(raft_driver)
target_include_directories(raft_driver PRIVATE src/aft)
add_test(
@ -476,28 +466,13 @@ if(BUILD_TESTS)
SRCS src/crypto/test/bench.cpp
LINK_LIBS
)
add_picobench(merkle_bench SRCS src/node/test/merkle_bench.cpp)
add_picobench(history_bench SRCS src/node/test/history_bench.cpp)
add_picobench(
merkle_bench
SRCS src/node/test/merkle_bench.cpp
LINK_LIBS ccfcrypto.host crypto
)
add_picobench(
history_bench
SRCS src/node/test/history_bench.cpp
LINK_LIBS ccfcrypto.host crypto
)
add_picobench(
kv_bench
SRCS src/kv/test/kv_bench.cpp src/crypto/symmetric_key.cpp
src/enclave/thread_local.cpp
LINK_LIBS ccfcrypto.host
kv_bench SRCS src/kv/test/kv_bench.cpp src/enclave/thread_local.cpp
)
add_picobench(hash_bench SRCS src/ds/test/hash_bench.cpp)
add_picobench(
digest_bench
SRCS src/crypto/test/digest_bench.cpp
LINK_LIBS ccfcrypto.host
)
add_picobench(digest_bench SRCS src/crypto/test/digest_bench.cpp)
# Storing signed governance operations
add_e2e_test(

Просмотреть файл

@ -30,7 +30,7 @@ if("virtual" IN_LIST COMPILE_TARGETS)
add_library(aft.virtual STATIC ${AFT_SRC})
add_san(aft.virtual)
target_compile_options(aft.virtual PRIVATE -stdlib=libc++)
target_compile_options(aft.virtual PRIVATE ${COMPILE_LIBCXX})
target_compile_definitions(
aft.virtual PUBLIC INSIDE_ENCLAVE VIRTUAL_ENCLAVE
_LIBCPP_HAS_THREAD_API_PTHREAD

Просмотреть файл

@ -231,7 +231,6 @@ function(add_ccf_app name)
set_property(TARGET ${virt_name} PROPERTY POSITION_INDEPENDENT_CODE ON)
use_client_mbedtls(${virt_name})
add_san(${virt_name})
add_lvi_mitigations(${virt_name})

Просмотреть файл

@ -182,23 +182,23 @@ set(HTTP_PARSER_SOURCES
find_library(CRYPTO_LIBRARY crypto)
list(APPEND COMPILE_LIBCXX -stdlib=libc++)
list(APPEND LINK_LIBCXX -lc++ -lc++abi -lc++fs -stdlib=libc++)
include(${CCF_DIR}/cmake/crypto.cmake)
include(${CCF_DIR}/cmake/quickjs.cmake)
include(${CCF_DIR}/cmake/sss.cmake)
list(APPEND LINK_LIBCXX -lc++ -lc++abi -lc++fs -stdlib=libc++)
# Unit test wrapper
function(add_unit_test name)
add_executable(${name} ${CCF_DIR}/src/enclave/thread_local.cpp ${ARGN})
target_compile_options(${name} PRIVATE -stdlib=libc++)
target_compile_options(${name} PRIVATE ${COMPILE_LIBCXX})
target_include_directories(${name} PRIVATE src ${CCFCRYPTO_INC})
enable_coverage(${name})
target_link_libraries(
${name} PRIVATE ${LINK_LIBCXX} ccfcrypto.host openenclave::oehost
$<BUILD_INTERFACE:merklecpp> crypto
$<BUILD_INTERFACE:merklecpp>
)
use_client_mbedtls(${name})
add_san(${name})
add_test(NAME ${name} COMMAND ${CCF_DIR}/tests/unit_test_wrapper.sh ${name})
@ -212,11 +212,10 @@ endfunction()
# Test binary wrapper
function(add_test_bin name)
add_executable(${name} ${CCF_DIR}/src/enclave/thread_local.cpp ${ARGN})
target_compile_options(${name} PRIVATE -stdlib=libc++)
target_compile_options(${name} PRIVATE ${COMPILE_LIBCXX})
target_include_directories(${name} PRIVATE src ${CCFCRYPTO_INC})
enable_coverage(${name})
target_link_libraries(${name} PRIVATE ${LINK_LIBCXX} ccfcrypto.host)
use_client_mbedtls(${name})
add_san(${name})
endfunction()
@ -227,8 +226,7 @@ if("sgx" IN_LIST COMPILE_TARGETS)
)
add_warning_checks(cchost)
use_client_mbedtls(cchost)
target_compile_options(cchost PRIVATE -stdlib=libc++)
target_compile_options(cchost PRIVATE ${COMPILE_LIBCXX})
target_include_directories(cchost PRIVATE ${CCF_GENERATED_DIR})
add_san(cchost)
add_lvi_mitigations(cchost)
@ -267,9 +265,8 @@ if("virtual" IN_LIST COMPILE_TARGETS)
# Virtual Host Executable
add_executable(cchost.virtual ${SNMALLOC_CPP} ${CCF_DIR}/src/host/main.cpp)
use_client_mbedtls(cchost.virtual)
target_compile_definitions(cchost.virtual PRIVATE -DVIRTUAL_ENCLAVE)
target_compile_options(cchost.virtual PRIVATE -stdlib=libc++)
target_compile_options(cchost.virtual PRIVATE ${COMPILE_LIBCXX})
target_include_directories(
cchost.virtual PRIVATE ${OE_INCLUDEDIR} ${CCF_GENERATED_DIR}
)
@ -294,10 +291,9 @@ endif()
add_executable(
scenario_perf_client ${CCF_DIR}/src/perf_client/scenario_perf_client.cpp
)
use_client_mbedtls(scenario_perf_client)
target_link_libraries(
scenario_perf_client PRIVATE ${CMAKE_THREAD_LIBS_INIT} http_parser.host
ccfcrypto.host
ccfcrypto.host c++fs
)
install(TARGETS scenario_perf_client DESTINATION bin)
@ -376,13 +372,13 @@ function(add_client_exe name)
add_executable(${name} ${PARSED_ARGS_SRCS})
target_link_libraries(${name} PRIVATE ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(
${name} PRIVATE ${CMAKE_THREAD_LIBS_INIT} ccfcrypto.host
)
target_include_directories(
${name} PRIVATE ${CCF_DIR}/src/perf_client ${PARSED_ARGS_INCLUDE_DIRS}
)
use_client_mbedtls(${name})
endfunction()
# Helper for building end-to-end function tests using the python infrastructure
@ -596,7 +592,7 @@ function(add_picobench name)
target_link_libraries(
${name} PRIVATE ${CMAKE_THREAD_LIBS_INIT} ${PARSED_ARGS_LINK_LIBS}
$<BUILD_INTERFACE:merklecpp> crypto
$<BUILD_INTERFACE:merklecpp> ccfcrypto.host
)
# -Wall -Werror catches a number of warnings in picobench
@ -609,7 +605,5 @@ function(add_picobench name)
"$<TARGET_FILE:${name}> --samples=1000 --out-fmt=csv --output=${name}.csv && cat ${name}.csv"
)
use_client_mbedtls(${name})
set_property(TEST ${name} PROPERTY LABELS benchmark)
endfunction()

Просмотреть файл

@ -1,8 +1,27 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the Apache 2.0 License.
set(CCFCRYPTO_SRC ${CCF_DIR}/src/crypto/hash.cpp
${CCF_DIR}/src/crypto/symmetric_key.cpp
set(CCFCRYPTO_SRC
${CCF_DIR}/src/crypto/entropy.cpp
${CCF_DIR}/src/crypto/hash.cpp
${CCF_DIR}/src/crypto/symmetric_key.cpp
${CCF_DIR}/src/crypto/key_pair.cpp
${CCF_DIR}/src/crypto/rsa_key_pair.cpp
${CCF_DIR}/src/crypto/verifier.cpp
${CCF_DIR}/src/crypto/mbedtls/symmetric_key.cpp
${CCF_DIR}/src/crypto/openssl/symmetric_key.cpp
${CCF_DIR}/src/crypto/mbedtls/public_key.cpp
${CCF_DIR}/src/crypto/openssl/public_key.cpp
${CCF_DIR}/src/crypto/mbedtls/key_pair.cpp
${CCF_DIR}/src/crypto/openssl/key_pair.cpp
${CCF_DIR}/src/crypto/mbedtls/hash.cpp
${CCF_DIR}/src/crypto/openssl/hash.cpp
${CCF_DIR}/src/crypto/mbedtls/rsa_public_key.cpp
${CCF_DIR}/src/crypto/openssl/rsa_public_key.cpp
${CCF_DIR}/src/crypto/mbedtls/rsa_key_pair.cpp
${CCF_DIR}/src/crypto/openssl/rsa_key_pair.cpp
${CCF_DIR}/src/crypto/mbedtls/verifier.cpp
${CCF_DIR}/src/crypto/openssl/verifier.cpp
)
if("sgx" IN_LIST COMPILE_TARGETS)
@ -28,8 +47,9 @@ endif()
add_library(ccfcrypto.host STATIC ${CCFCRYPTO_SRC})
add_san(ccfcrypto.host)
target_compile_options(ccfcrypto.host PRIVATE -stdlib=libc++)
target_link_libraries(ccfcrypto.host PRIVATE crypto)
target_compile_options(ccfcrypto.host PUBLIC ${COMPILE_LIBCXX})
target_link_options(ccfcrypto.host PUBLIC ${LINK_LIBCXX})
target_link_libraries(ccfcrypto.host PUBLIC crypto)
use_client_mbedtls(ccfcrypto.host)
set_property(TARGET ccfcrypto.host PROPERTY POSITION_INDEPENDENT_CODE ON)

Просмотреть файл

@ -5,16 +5,18 @@
add_picobench(
small_bank_serdes_bench
SRCS ${CMAKE_CURRENT_LIST_DIR}/tests/small_bank_serdes_bench.cpp
src/crypto/symmetric_key.cpp src/enclave/thread_local.cpp
src/enclave/thread_local.cpp
INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}
LINK_LIBS ccfcrypto.host crypto
LINK_LIBS
)
add_client_exe(
small_bank_client
SRCS ${CMAKE_CURRENT_LIST_DIR}/clients/small_bank_client.cpp
)
target_link_libraries(small_bank_client PRIVATE http_parser.host ccfcrypto.host)
target_link_libraries(
small_bank_client PRIVATE http_parser.host ccfcrypto.host c++fs
)
# SmallBank application
add_ccf_app(smallbank SRCS ${CMAKE_CURRENT_LIST_DIR}/app/smallbank.cpp)

Просмотреть файл

@ -53,39 +53,39 @@
},
{
"account": 1,
"balance": 0
"balance": -67
},
{
"account": 2,
"balance": 28
"balance": -24
},
{
"account": 3,
"balance": 890665
"balance": -47
},
{
"account": 4,
"balance": 19
"balance": -48
},
{
"account": 5,
"balance": 322
"balance": 0
},
{
"account": 6,
"balance": 40
"balance": 113
},
{
"account": 7,
"balance": 43
"balance": 881672
},
{
"account": 8,
"balance": 19
"balance": 0
},
{
"account": 9,
"balance": 63
"balance": 0
}
]
}

Просмотреть файл

@ -49,7 +49,7 @@
"final": [
{
"account": 0,
"balance": 120018
"balance": 36
},
{
"account": 1,
@ -57,35 +57,35 @@
},
{
"account": 2,
"balance": 187
"balance": -50
},
{
"account": 3,
"balance": 42
"balance": 53
},
{
"account": 4,
"balance": -35
"balance": 49
},
{
"account": 5,
"balance": 36
},
{
"account": 6,
"balance": 1189
},
{
"account": 7,
"balance": 6
},
{
"account": 8,
"balance": 0
},
{
"account": 6,
"balance": 0
},
{
"account": 7,
"balance": 107846
},
{
"account": 8,
"balance": 75
},
{
"account": 9,
"balance": 39
"balance": 11
}
]
}

Просмотреть файл

@ -49,11 +49,11 @@
"final": [
{
"account": 0,
"balance": 52
"balance": 0
},
{
"account": 1,
"balance": 32
"balance": 31146
},
{
"account": 2,
@ -61,27 +61,27 @@
},
{
"account": 3,
"balance": 41951
},
{
"account": 4,
"balance": 0
},
{
"account": 4,
"balance": -33
},
{
"account": 5,
"balance": 5
"balance": 626
},
{
"account": 6,
"balance": 94
"balance": -138
},
{
"account": 7,
"balance": 737
"balance": 284
},
{
"account": 8,
"balance": 493
"balance": 18
},
{
"account": 9,

Просмотреть файл

@ -49,43 +49,43 @@
"final": [
{
"account": 0,
"balance": 15
"balance": 48
},
{
"account": 1,
"balance": 79
"balance": 36
},
{
"account": 2,
"balance": 2181296
"balance": -27
},
{
"account": 3,
"balance": 37
"balance": 0
},
{
"account": 4,
"balance": 30
"balance": 2182875
},
{
"account": 5,
"balance": 0
"balance": 147
},
{
"account": 6,
"balance": -40
"balance": 0
},
{
"account": 7,
"balance": 505
"balance": 1
},
{
"account": 8,
"balance": -34
"balance": 0
},
{
"account": 9,
"balance": 0
"balance": -27
}
]
}

Просмотреть файл

@ -49,43 +49,43 @@
"final": [
{
"account": 0,
"balance": 49
},
{
"account": 1,
"balance": 43
},
{
"account": 2,
"balance": 1764
},
{
"account": 3,
"balance": 43
},
{
"account": 4,
"balance": 244
},
{
"account": 5,
"balance": -38
},
{
"account": 6,
"balance": -42
},
{
"account": 7,
"balance": 248231
},
{
"account": 8,
"balance": 0
},
{
"account": 1,
"balance": 0
},
{
"account": 2,
"balance": 13
},
{
"account": 3,
"balance": 0
},
{
"account": 4,
"balance": -85
},
{
"account": 5,
"balance": 0
},
{
"account": 6,
"balance": 241685
},
{
"account": 7,
"balance": 185
},
{
"account": 8,
"balance": -25
},
{
"account": 9,
"balance": -8
"balance": 128
}
]
}

Просмотреть файл

@ -21,6 +21,8 @@
#include <string>
#include <vector>
using namespace crypto;
class TlsClient
{
protected:

Просмотреть файл

@ -23,13 +23,6 @@ namespace crypto
static constexpr CurveID service_identity_curve_choice = CurveID::SECP384R1;
// SNIPPET_END: supported_curves
// Helper to access elliptic curve id from context
inline mbedtls_ecp_group_id get_mbedtls_ec_from_context(
const mbedtls_pk_context& ctx)
{
return mbedtls_pk_ec(ctx)->grp.id;
}
// Get message digest algorithm to use for given elliptic curve
inline MDType get_md_for_ec(CurveID ec)
{
@ -45,28 +38,4 @@ namespace crypto
}
}
}
inline mbedtls_md_type_t get_mbedtls_md_for_ec(
mbedtls_ecp_group_id ec, bool allow_none = false)
{
switch (ec)
{
case MBEDTLS_ECP_DP_SECP384R1:
return MBEDTLS_MD_SHA384;
case MBEDTLS_ECP_DP_SECP256R1:
return MBEDTLS_MD_SHA256;
default:
{
if (allow_none)
{
return MBEDTLS_MD_NONE;
}
else
{
const auto error = fmt::format("Unhandled ecp group id: {}", ec);
throw std::logic_error(error);
}
}
}
}
}

21
src/crypto/entropy.cpp Normal file
Просмотреть файл

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "entropy.h"
#include "mbedtls/entropy.h"
namespace crypto
{
EntropyPtr create_entropy()
{
if (use_drng)
{
if (!intel_drng_ptr)
intel_drng_ptr = std::make_shared<IntelDRNG>();
return intel_drng_ptr;
}
return std::make_shared<MbedtlsEntropy>();
}
}

Просмотреть файл

@ -2,43 +2,276 @@
// Licensed under the Apache 2.0 License.
#pragma once
#include "intel_drng.h"
#include "mbedtls_wrappers.h"
#include <functional>
#include <memory>
#include <cassert>
#include <utility>
#include <vector>
// Adapted from:
// https://software.intel.com/en-us/articles/intel-digital-random-number-generator-drng-software-implementation-guide
#define DRNG_NO_SUPPORT 0x0
#define DRNG_HAS_RDRAND 0x1
#define DRNG_HAS_RDSEED 0x2
// `It is recommended that applications attempt 10 retries in a tight loop in
// the unlikely event that the RDRAND instruction does not return a random
// number. This number is based on a binomial probability argument: given
// the design margins of the DRNG, the odds of ten failures in a row are
// astronomically small and would in fact be an indication of a larger CPU
// issue.`
#define RDRAND_RETRIES 10
namespace crypto
{
static bool use_drng = IntelDRNG::is_drng_supported();
using EntropyPtr = std::shared_ptr<Entropy>;
static EntropyPtr intel_drng_ptr;
EntropyPtr create_entropy();
using rng_func_t = int (*)(void* ctx, unsigned char* output, size_t len);
class MbedtlsEntropy : public Entropy
class Entropy
{
public:
virtual void* get_data() = 0;
virtual rng_func_t get_rng() = 0;
virtual std::vector<uint8_t> random(size_t len) = 0;
virtual void random(unsigned char* data, size_t len) = 0;
virtual uint64_t random64() = 0;
virtual ~Entropy() {}
};
class IntelDRNG : public Entropy
{
private:
mbedtls::Entropy entropy = mbedtls::make_unique<mbedtls::Entropy>();
mbedtls::CtrDrbg drbg = mbedtls::make_unique<mbedtls::CtrDrbg>();
typedef struct cpuid_struct
{
unsigned int eax;
unsigned int ebx;
unsigned int ecx;
unsigned int edx;
} cpuid_t;
static bool gen(uint64_t& v);
static void cpuid(cpuid_t* info, unsigned int leaf, unsigned int subleaf)
{
asm volatile(
"cpuid"
: "=a"(info->eax), "=b"(info->ebx), "=c"(info->ecx), "=d"(info->edx)
: "a"(leaf), "c"(subleaf));
}
static int _is_intel_cpu()
{
static int intel_cpu = -1;
cpuid_t info;
if (intel_cpu == -1)
{
cpuid(&info, 0, 0);
if (
memcmp((char*)&info.ebx, "Genu", 4) ||
memcmp((char*)&info.edx, "ineI", 4) ||
memcmp((char*)&info.ecx, "ntel", 4))
intel_cpu = 0;
else
intel_cpu = 1;
}
return intel_cpu;
}
static int get_drng_support()
{
static int drng_features = -1;
/* So we don't call cpuid multiple times for the same information */
if (drng_features == -1)
{
drng_features = DRNG_NO_SUPPORT;
if (_is_intel_cpu())
{
cpuid_t info;
cpuid(&info, 1, 0);
if ((info.ecx & 0x40000000) == 0x40000000)
drng_features |= DRNG_HAS_RDRAND;
cpuid(&info, 7, 0);
if ((info.ebx & 0x40000) == 0x40000)
drng_features |= DRNG_HAS_RDSEED;
}
}
return drng_features;
}
static int rdrand16_step(uint16_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand32_step(uint32_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand64_step(uint64_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand16_retry(unsigned int retries, uint16_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand16_step(rand))
return 1;
++count;
}
return 0;
}
static int rdrand32_retry(unsigned int retries, uint32_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand32_step(rand))
return 1;
++count;
}
return 0;
}
static int rdrand64_retry(unsigned int retries, uint64_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand64_step(rand))
return 1;
++count;
}
return 0;
}
static unsigned int rdrand_get_bytes(unsigned int n, unsigned char* dest)
{
unsigned char *headstart, *tailstart = nullptr;
uint64_t* blockstart;
unsigned int count, ltail, lhead, lblock;
uint64_t i, temprand;
/* Get the address of the first 64-bit aligned block in the
* destination buffer. */
headstart = dest;
if (((uint64_t)headstart % (uint64_t)8) == 0)
{
blockstart = (uint64_t*)headstart;
lblock = n;
lhead = 0;
}
else
{
blockstart =
(uint64_t*)(((uint64_t)headstart & ~(uint64_t)7) + (uint64_t)8);
lhead = (unsigned int)((uint64_t)blockstart - (uint64_t)headstart);
lblock =
((n - lhead) & ~(unsigned int)7); // cwinter: this bit is/as buggy in
// the Intel examples.
}
/* Compute the number of 64-bit blocks and the remaining number
* of bytes (the tail) */
ltail = n - lblock - lhead;
count = lblock / 8; /* The number 64-bit rands needed */
assert(lhead < 8);
assert(lblock <= n);
assert(ltail < 8);
if (ltail)
tailstart = (unsigned char*)((uint64_t)blockstart + (uint64_t)lblock);
/* Populate the starting, mis-aligned section (the head) */
if (lhead)
{
if (!rdrand64_retry(RDRAND_RETRIES, &temprand))
return 0;
memcpy(headstart, &temprand, lhead);
}
/* Populate the central, aligned block */
for (i = 0; i < count; ++i, ++blockstart)
{
if (!rdrand64_retry(RDRAND_RETRIES, blockstart))
return i * 8 + lhead;
}
/* Populate the tail */
if (ltail)
{
if (!rdrand64_retry(RDRAND_RETRIES, &temprand))
return count * 8 + lhead;
memcpy(tailstart, &temprand, ltail);
}
return n;
}
// The following three functions should be used to generate
// randomness that will be used as seed for another RNG
static int rdseed16_step(uint16_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
static int rdseed32_step(uint32_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
static int rdseed64_step(uint64_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
public:
MbedtlsEntropy()
IntelDRNG()
{
mbedtls_ctr_drbg_seed(
drbg.get(), mbedtls_entropy_func, entropy.get(), nullptr, 0);
if (!is_drng_supported())
throw std::logic_error("No support for RDRAND / RDSEED on this CPU.");
}
std::vector<uint8_t> random(size_t len) override
{
std::vector<uint8_t> data(len);
std::vector<uint8_t> buf(len);
if (mbedtls_ctr_drbg_random(drbg.get(), data.data(), data.size()) != 0)
if (rdrand_get_bytes(buf.size(), buf.data()) < buf.size())
throw std::logic_error("Couldn't create random data");
return data;
return buf;
}
uint64_t random64() override
@ -46,9 +279,7 @@ namespace crypto
uint64_t rnd;
uint64_t len = sizeof(uint64_t);
if (
mbedtls_ctr_drbg_random(
drbg.get(), reinterpret_cast<unsigned char*>(&rnd), len) != 0)
if (rdrand_get_bytes(len, reinterpret_cast<unsigned char*>(&rnd)) < len)
{
throw std::logic_error("Couldn't create random data");
}
@ -58,13 +289,15 @@ namespace crypto
void random(unsigned char* data, size_t len) override
{
if (mbedtls_ctr_drbg_random(drbg.get(), data, len) != 0)
if (rdrand_get_bytes(len, data) < len)
throw std::logic_error("Couldn't create random data");
}
static int rng(void* ctx, unsigned char* output, size_t len)
static int rng(void*, unsigned char* output, size_t len)
{
return mbedtls_ctr_drbg_random(ctx, output, len);
if (rdrand_get_bytes(len, output) < len)
throw std::logic_error("Couldn't create random data");
return 0;
}
rng_func_t get_rng() override
@ -74,20 +307,20 @@ namespace crypto
void* get_data() override
{
return drbg.get();
return this;
}
static bool is_drng_supported()
{
return (get_drng_support() & (DRNG_HAS_RDRAND | DRNG_HAS_RDSEED)) ==
(DRNG_HAS_RDRAND | DRNG_HAS_RDSEED);
}
};
inline EntropyPtr create_entropy()
{
if (use_drng)
{
if (!intel_drng_ptr)
intel_drng_ptr = std::make_shared<IntelDRNG>();
return intel_drng_ptr;
}
return std::make_shared<MbedtlsEntropy>();
}
static bool use_drng = IntelDRNG::is_drng_supported();
using EntropyPtr = std::shared_ptr<Entropy>;
static EntropyPtr intel_drng_ptr;
EntropyPtr create_entropy();
EntropyPtr create_entropy();
}

Просмотреть файл

@ -2,104 +2,13 @@
// Licensed under the Apache 2.0 License.
#include "hash.h"
#include "mbedtls_wrappers.h"
#include <mbedtls/sha256.h>
#include <openssl/sha.h>
#include <stdexcept>
using namespace std;
#include "mbedtls/hash.h"
#include "openssl/hash.h"
namespace crypto
{
void Sha256Hash::mbedtls_sha256(const CBuffer& data, uint8_t* h)
void default_sha256(const CBuffer& data, uint8_t* h)
{
mbedtls_sha256_context ctx;
mbedtls_sha256_init(&ctx);
mbedtls_sha256_starts_ret(&ctx, 0);
mbedtls_sha256_update_ret(&ctx, data.p, data.rawSize());
mbedtls_sha256_finish_ret(&ctx, h);
mbedtls_sha256_free(&ctx);
}
void Sha256Hash::openssl_sha256(const CBuffer& data, uint8_t* h)
{
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, data.p, data.rawSize());
SHA256_Final(h, &ctx);
}
ISha256MbedTLS::ISha256MbedTLS()
{
ctx = new mbedtls_sha256_context();
mbedtls_sha256_starts_ret((mbedtls_sha256_context*)ctx, 0);
}
ISha256MbedTLS::~ISha256MbedTLS()
{
delete (mbedtls_sha256_context*)ctx;
}
Sha256Hash ISha256MbedTLS::finalise()
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
Sha256Hash r;
mbedtls_sha256_finish_ret((mbedtls_sha256_context*)ctx, r.h.data());
mbedtls_sha256_free((mbedtls_sha256_context*)ctx);
delete (mbedtls_sha256_context*)ctx;
ctx = nullptr;
return r;
}
void ISha256MbedTLS::update_hash(CBuffer data)
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
mbedtls_sha256_update_ret(
(mbedtls_sha256_context*)ctx, data.p, data.rawSize());
}
ISha256OpenSSL::ISha256OpenSSL()
{
ctx = new SHA256_CTX;
SHA256_Init((SHA256_CTX*)ctx);
}
ISha256OpenSSL::~ISha256OpenSSL()
{
delete (SHA256_CTX*)ctx;
}
void ISha256OpenSSL::update_hash(CBuffer data)
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
SHA256_Update((SHA256_CTX*)ctx, data.p, data.rawSize());
}
Sha256Hash ISha256OpenSSL::finalise()
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
Sha256Hash r;
SHA256_Final(r.h.data(), (SHA256_CTX*)ctx);
delete (SHA256_CTX*)ctx;
ctx = nullptr;
return r;
return openssl_sha256(data, h);
}
}

Просмотреть файл

@ -1,216 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "ds/buffer.h"
#include "ds/json.h"
#include <mbedtls/md.h>
#include <mbedtls/pk.h>
#include <openssl/evp.h>
#include <openssl/sha.h>
#include "mbedtls/hash.h"
#include "openssl/hash.h"
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <msgpack/msgpack.hpp>
#include <ostream>
namespace crypto
{
enum class MDType
{
NONE = 0,
SHA1,
SHA256,
SHA384,
SHA512
};
using HashBytes = std::vector<uint8_t>;
class HashProviderBase
{
public:
virtual HashBytes Hash(const uint8_t*, size_t, MDType) const = 0;
};
class MBedHashProvider : public HashProviderBase
{
public:
static inline mbedtls_md_type_t get_md_type(MDType type)
{
switch (type)
{
case MDType::NONE:
return MBEDTLS_MD_NONE;
case MDType::SHA1:
return MBEDTLS_MD_SHA1;
case MDType::SHA256:
return MBEDTLS_MD_SHA256;
case MDType::SHA384:
return MBEDTLS_MD_SHA384;
case MDType::SHA512:
return MBEDTLS_MD_SHA512;
default:
throw std::runtime_error("Unsupported hash algorithm");
}
return MBEDTLS_MD_NONE;
}
virtual HashBytes Hash(const uint8_t* data, size_t size, MDType type) const
{
HashBytes r;
const auto mbedtls_md_type = get_md_type(type);
const auto info = mbedtls_md_info_from_type(mbedtls_md_type);
const auto hash_size = mbedtls_md_get_size(info);
r.resize(hash_size);
if (mbedtls_md(info, data, size, r.data()) != 0)
r.clear();
return r;
}
};
class OpenSSLHashProvider : public HashProviderBase
{
public:
static inline const EVP_MD* get_md_type(MDType type)
{
switch (type)
{
case MDType::NONE:
return nullptr;
case MDType::SHA1:
return EVP_sha1();
case MDType::SHA256:
return EVP_sha256();
case MDType::SHA384:
return EVP_sha384();
case MDType::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported hash algorithm");
}
return nullptr;
}
virtual HashBytes Hash(const uint8_t* data, size_t size, MDType type) const
{
auto o_md_type = get_md_type(type);
HashBytes r(EVP_MD_size(o_md_type));
unsigned int len = 0;
if (EVP_Digest(data, size, r.data(), &len, o_md_type, NULL) != 1)
throw std::runtime_error("OpenSSL hash update error");
return r;
}
};
typedef MBedHashProvider HashProvider;
class Sha256Hash
{
public:
static constexpr size_t SIZE = 256 / 8;
Sha256Hash() : h{0} {}
Sha256Hash(const CBuffer& data) : h{0}
{
::SHA256(data.p, data.rawSize(), h.data());
}
std::array<uint8_t, SIZE> h;
static void mbedtls_sha256(const CBuffer& data, uint8_t* h);
static void openssl_sha256(const CBuffer& data, uint8_t* h);
friend std::ostream& operator<<(
std::ostream& os, const crypto::Sha256Hash& h)
{
for (unsigned i = 0; i < crypto::Sha256Hash::SIZE; i++)
{
os << std::hex << static_cast<int>(h.h[i]);
}
return os;
}
std::string hex_str() const
{
return fmt::format("{:02x}", fmt::join(h, ""));
};
MSGPACK_DEFINE(h);
};
DECLARE_JSON_TYPE(Sha256Hash);
DECLARE_JSON_REQUIRED_FIELDS(Sha256Hash, h);
inline bool operator==(const Sha256Hash& lhs, const Sha256Hash& rhs)
{
for (unsigned i = 0; i < crypto::Sha256Hash::SIZE; i++)
{
if (lhs.h[i] != rhs.h[i])
{
return false;
}
}
return true;
}
inline bool operator!=(const Sha256Hash& lhs, const Sha256Hash& rhs)
{
return !(lhs == rhs);
}
// Incremental Hash Objects
class ISha256HashBase
{
public:
ISha256HashBase() {}
virtual ~ISha256HashBase() {}
virtual void update_hash(CBuffer data) = 0;
virtual Sha256Hash finalise() = 0;
template <typename T>
void update(const T& t)
{
update_hash({reinterpret_cast<const uint8_t*>(&t), sizeof(T)});
}
template <>
void update<std::vector<uint8_t>>(const std::vector<uint8_t>& d)
{
update_hash({d.data(), d.size()});
}
};
class ISha256MbedTLS : public ISha256HashBase
{
public:
ISha256MbedTLS();
~ISha256MbedTLS();
virtual void update_hash(CBuffer data);
virtual Sha256Hash finalise();
protected:
void* ctx;
};
class ISha256OpenSSL : public ISha256HashBase
{
public:
ISha256OpenSSL();
~ISha256OpenSSL();
virtual void update_hash(CBuffer data);
virtual Sha256Hash finalise();
protected:
void* ctx;
};
typedef ISha256OpenSSL ISha256Hash;
}

107
src/crypto/hash_base.h Normal file
Просмотреть файл

@ -0,0 +1,107 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "ds/buffer.h"
#include "ds/json.h"
#include <cstdint>
#include <iostream>
#include <msgpack/msgpack.hpp>
#include <vector>
namespace crypto
{
extern void default_sha256(const CBuffer& data, uint8_t* h);
enum class MDType
{
NONE = 0,
SHA1,
SHA256,
SHA384,
SHA512
};
using HashBytes = std::vector<uint8_t>;
class HashProviderBase
{
public:
virtual HashBytes Hash(const uint8_t*, size_t, MDType) const = 0;
};
class Sha256Hash
{
public:
static constexpr size_t SIZE = 256 / 8;
Sha256Hash() : h{0} {}
Sha256Hash(const CBuffer& data) : h{0}
{
default_sha256(data, h.data());
}
std::array<uint8_t, SIZE> h;
friend std::ostream& operator<<(
std::ostream& os, const crypto::Sha256Hash& h)
{
for (unsigned i = 0; i < crypto::Sha256Hash::SIZE; i++)
{
os << std::hex << static_cast<int>(h.h[i]);
}
return os;
}
std::string hex_str() const
{
return fmt::format("{:02x}", fmt::join(h, ""));
};
MSGPACK_DEFINE(h);
};
DECLARE_JSON_TYPE(Sha256Hash);
DECLARE_JSON_REQUIRED_FIELDS(Sha256Hash, h);
inline bool operator==(const Sha256Hash& lhs, const Sha256Hash& rhs)
{
for (unsigned i = 0; i < crypto::Sha256Hash::SIZE; i++)
{
if (lhs.h[i] != rhs.h[i])
{
return false;
}
}
return true;
}
inline bool operator!=(const Sha256Hash& lhs, const Sha256Hash& rhs)
{
return !(lhs == rhs);
}
// Incremental Hash Objects
class ISha256HashBase
{
public:
ISha256HashBase() {}
virtual ~ISha256HashBase() {}
virtual void update_hash(CBuffer data) = 0;
virtual Sha256Hash finalise() = 0;
template <typename T>
void update(const T& t)
{
update_hash({reinterpret_cast<const uint8_t*>(&t), sizeof(T)});
}
template <>
void update<std::vector<uint8_t>>(const std::vector<uint8_t>& d)
{
update_hash({d.data(), d.size()});
}
};
}

Просмотреть файл

@ -1,319 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include <cassert>
#include <utility>
#include <vector>
// Adapted from:
// https://software.intel.com/en-us/articles/intel-digital-random-number-generator-drng-software-implementation-guide
#define DRNG_NO_SUPPORT 0x0
#define DRNG_HAS_RDRAND 0x1
#define DRNG_HAS_RDSEED 0x2
// `It is recommended that applications attempt 10 retries in a tight loop in
// the unlikely event that the RDRAND instruction does not return a random
// number. This number is based on a binomial probability argument: given
// the design margins of the DRNG, the odds of ten failures in a row are
// astronomically small and would in fact be an indication of a larger CPU
// issue.`
#define RDRAND_RETRIES 10
namespace crypto
{
using rng_func_t = int (*)(void* ctx, unsigned char* output, size_t len);
class Entropy
{
public:
virtual void* get_data() = 0;
virtual rng_func_t get_rng() = 0;
virtual std::vector<uint8_t> random(size_t len) = 0;
virtual void random(unsigned char* data, size_t len) = 0;
virtual uint64_t random64() = 0;
virtual ~Entropy() {}
};
class IntelDRNG : public Entropy
{
private:
typedef struct cpuid_struct
{
unsigned int eax;
unsigned int ebx;
unsigned int ecx;
unsigned int edx;
} cpuid_t;
static void cpuid(cpuid_t* info, unsigned int leaf, unsigned int subleaf)
{
asm volatile(
"cpuid"
: "=a"(info->eax), "=b"(info->ebx), "=c"(info->ecx), "=d"(info->edx)
: "a"(leaf), "c"(subleaf));
}
static int _is_intel_cpu()
{
static int intel_cpu = -1;
cpuid_t info;
if (intel_cpu == -1)
{
cpuid(&info, 0, 0);
if (
memcmp((char*)&info.ebx, "Genu", 4) ||
memcmp((char*)&info.edx, "ineI", 4) ||
memcmp((char*)&info.ecx, "ntel", 4))
intel_cpu = 0;
else
intel_cpu = 1;
}
return intel_cpu;
}
static int get_drng_support()
{
static int drng_features = -1;
/* So we don't call cpuid multiple times for the same information */
if (drng_features == -1)
{
drng_features = DRNG_NO_SUPPORT;
if (_is_intel_cpu())
{
cpuid_t info;
cpuid(&info, 1, 0);
if ((info.ecx & 0x40000000) == 0x40000000)
drng_features |= DRNG_HAS_RDRAND;
cpuid(&info, 7, 0);
if ((info.ebx & 0x40000) == 0x40000)
drng_features |= DRNG_HAS_RDSEED;
}
}
return drng_features;
}
static int rdrand16_step(uint16_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand32_step(uint32_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand64_step(uint64_t* rand)
{
unsigned char ok;
asm volatile("rdrand %0; setc %1" : "=r"(*rand), "=qm"(ok));
return (int)ok;
}
static int rdrand16_retry(unsigned int retries, uint16_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand16_step(rand))
return 1;
++count;
}
return 0;
}
static int rdrand32_retry(unsigned int retries, uint32_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand32_step(rand))
return 1;
++count;
}
return 0;
}
static int rdrand64_retry(unsigned int retries, uint64_t* rand)
{
unsigned int count = 0;
while (count <= retries)
{
if (rdrand64_step(rand))
return 1;
++count;
}
return 0;
}
static unsigned int rdrand_get_bytes(unsigned int n, unsigned char* dest)
{
unsigned char *headstart, *tailstart = nullptr;
uint64_t* blockstart;
unsigned int count, ltail, lhead, lblock;
uint64_t i, temprand;
/* Get the address of the first 64-bit aligned block in the
* destination buffer. */
headstart = dest;
if (((uint64_t)headstart % (uint64_t)8) == 0)
{
blockstart = (uint64_t*)headstart;
lblock = n;
lhead = 0;
}
else
{
blockstart =
(uint64_t*)(((uint64_t)headstart & ~(uint64_t)7) + (uint64_t)8);
lhead = (unsigned int)((uint64_t)blockstart - (uint64_t)headstart);
lblock =
((n - lhead) & ~(unsigned int)7); // cwinter: this bit is/as buggy in
// the Intel examples.
}
/* Compute the number of 64-bit blocks and the remaining number
* of bytes (the tail) */
ltail = n - lblock - lhead;
count = lblock / 8; /* The number 64-bit rands needed */
assert(lhead < 8);
assert(lblock <= n);
assert(ltail < 8);
if (ltail)
tailstart = (unsigned char*)((uint64_t)blockstart + (uint64_t)lblock);
/* Populate the starting, mis-aligned section (the head) */
if (lhead)
{
if (!rdrand64_retry(RDRAND_RETRIES, &temprand))
return 0;
memcpy(headstart, &temprand, lhead);
}
/* Populate the central, aligned block */
for (i = 0; i < count; ++i, ++blockstart)
{
if (!rdrand64_retry(RDRAND_RETRIES, blockstart))
return i * 8 + lhead;
}
/* Populate the tail */
if (ltail)
{
if (!rdrand64_retry(RDRAND_RETRIES, &temprand))
return count * 8 + lhead;
memcpy(tailstart, &temprand, ltail);
}
return n;
}
// The following three functions should be used to generate
// randomness that will be used as seed for another RNG
static int rdseed16_step(uint16_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
static int rdseed32_step(uint32_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
static int rdseed64_step(uint64_t* seed)
{
unsigned char ok;
asm volatile("rdseed %0; setc %1" : "=r"(*seed), "=qm"(ok));
return (int)ok;
}
public:
IntelDRNG()
{
if (!is_drng_supported())
throw std::logic_error("No support for RDRAND / RDSEED on this CPU.");
}
std::vector<uint8_t> random(size_t len) override
{
std::vector<uint8_t> buf(len);
if (rdrand_get_bytes(buf.size(), buf.data()) < buf.size())
throw std::logic_error("Couldn't create random data");
return buf;
}
uint64_t random64() override
{
uint64_t rnd;
uint64_t len = sizeof(uint64_t);
if (rdrand_get_bytes(len, reinterpret_cast<unsigned char*>(&rnd)) < len)
{
throw std::logic_error("Couldn't create random data");
}
return rnd;
}
void random(unsigned char* data, size_t len) override
{
if (rdrand_get_bytes(len, data) < len)
throw std::logic_error("Couldn't create random data");
}
static int rng(void*, unsigned char* output, size_t len)
{
if (rdrand_get_bytes(len, output) < len)
throw std::logic_error("Couldn't create random data");
return 0;
}
rng_func_t get_rng() override
{
return &rng;
}
void* get_data() override
{
return this;
}
static bool is_drng_supported()
{
return (get_drng_support() & (DRNG_HAS_RDRAND | DRNG_HAS_RDSEED)) ==
(DRNG_HAS_RDRAND | DRNG_HAS_RDSEED);
}
};
}

46
src/crypto/key_pair.cpp Normal file
Просмотреть файл

@ -0,0 +1,46 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "key_pair.h"
#include "mbedtls/key_pair.h"
#include "mbedtls/public_key.h"
#include "openssl/key_pair.h"
#include "openssl/public_key.h"
#include <cstring>
#include <iomanip>
#include <limits>
#include <memory>
#include <string>
namespace crypto
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
using PublicKeyImpl = PublicKey_mbedTLS;
using KeyPairImpl = KeyPair_mbedTLS;
#else
using PublicKeyImpl = PublicKey_OpenSSL;
using KeyPairImpl = KeyPair_OpenSSL;
#endif
PublicKeyPtr make_public_key(const Pem& pem)
{
return std::make_shared<PublicKeyImpl>(pem);
}
PublicKeyPtr make_public_key(const std::vector<uint8_t>& der)
{
return std::make_shared<PublicKeyImpl>(der);
}
KeyPairPtr make_key_pair(CurveID curve_id)
{
return std::make_shared<KeyPairImpl>(curve_id);
}
KeyPairPtr make_key_pair(const Pem& pem, CBuffer pw)
{
return std::make_shared<KeyPairImpl>(pem, pw);
}
}

Просмотреть файл

@ -3,68 +3,96 @@
#pragma once
#include "curve.h"
#include "key_pair_base.h"
#include "key_pair_mbedtls.h"
#include "key_pair_openssl.h"
#include "hash.h"
#include "pem.h"
#include "public_key.h"
#include "san.h"
#include <cstring>
#include <iomanip>
#include <limits>
#include <memory>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
using PublicKey = PublicKey_mbedTLS;
#else
using PublicKey = PublicKey_OpenSSL;
#endif
using PublicKeyPtr = std::shared_ptr<PublicKeyBase>;
class KeyPair
{
public:
virtual ~KeyPair() = default;
virtual Pem private_key_pem() const = 0;
virtual Pem public_key_pem() const = 0;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) = 0;
virtual std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const = 0;
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const = 0;
virtual std::vector<uint8_t> sign(CBuffer d, MDType md_type = {}) const = 0;
virtual Pem create_csr(const std::string& name) const = 0;
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const = 0;
Pem self_sign(
const std::string& name,
const std::optional<SubjectAltName> subject_alt_name = std::nullopt,
bool ca = true) const
{
std::vector<SubjectAltName> sans;
if (subject_alt_name.has_value())
sans.push_back(subject_alt_name.value());
auto csr = create_csr(name);
return sign_csr(Pem(0), csr, sans, ca);
}
Pem self_sign(
const std::string& name,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = true) const
{
auto csr = create_csr(name);
return sign_csr(Pem(0), csr, subject_alt_names, ca);
}
};
using PublicKeyPtr = std::shared_ptr<PublicKey>;
using KeyPairPtr = std::shared_ptr<KeyPair>;
/**
* Construct PublicKey from a raw public key in PEM format
*
* @param public_pem Sequence of bytes containing the key in PEM format
*/
inline PublicKeyPtr make_public_key(const Pem& public_pem)
{
return std::make_shared<PublicKey>(public_pem);
}
PublicKeyPtr make_public_key(const Pem& pem);
/**
* Construct PublicKey from a raw public key in DER format
*
* @param public_der Sequence of bytes containing the key in DER format
*/
inline PublicKeyPtr make_public_key(const std::vector<uint8_t> public_der)
{
return std::make_shared<PublicKey>(public_der);
}
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
using KeyPair = KeyPair_mbedTLS;
#else
using KeyPair = KeyPair_OpenSSL;
#endif
using KeyPairPtr = std::shared_ptr<KeyPairBase>;
PublicKeyPtr make_public_key(const std::vector<uint8_t> der);
/**
* Create a new public / private ECDSA key pair on specified curve and
* implementation
*/
inline KeyPairPtr make_key_pair(
CurveID curve_id = service_identity_curve_choice)
{
return std::make_shared<KeyPair>(curve_id);
}
KeyPairPtr make_key_pair(CurveID curve_id = service_identity_curve_choice);
/**
* Create a public / private ECDSA key pair from existing private key data
*/
inline KeyPairPtr make_key_pair(const Pem& pkey, CBuffer pw = nullb)
{
return std::make_shared<KeyPair>(pkey, pw);
}
KeyPairPtr make_key_pair(const Pem& pkey, CBuffer pw = nullb);
}

Просмотреть файл

@ -1,719 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "ds/net.h"
#include "entropy.h"
#include "key_pair_base.h"
#include "mbedtls_wrappers.h"
#include "san.h"
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <mbedtls/asn1write.h>
#include <mbedtls/bignum.h>
#include <mbedtls/error.h>
#include <mbedtls/oid.h>
#include <mbedtls/pem.h>
#include <mbedtls/pk.h>
#include <mbedtls/x509.h>
#include <mbedtls/x509_crt.h>
namespace crypto
{
class PublicKey_mbedTLS : public PublicKeyBase
{
protected:
mbedtls::PKContext ctx = mbedtls::make_unique<mbedtls::PKContext>();
PublicKey_mbedTLS() {}
inline mbedtls_md_type_t get_md_type(MDType mdt) const
{
switch (mdt)
{
case MDType::NONE:
return MBEDTLS_MD_NONE;
case MDType::SHA1:
return MBEDTLS_MD_SHA1;
case MDType::SHA256:
return MBEDTLS_MD_SHA256;
case MDType::SHA384:
return MBEDTLS_MD_SHA384;
case MDType::SHA512:
return MBEDTLS_MD_SHA512;
default:
return MBEDTLS_MD_NONE;
}
return MBEDTLS_MD_NONE;
}
public:
PublicKey_mbedTLS(PublicKey_mbedTLS&& pk) = default;
/**
* Construct from PEM
*/
PublicKey_mbedTLS(const Pem& pem)
{
int rc = mbedtls_pk_parse_public_key(ctx.get(), pem.data(), pem.size());
if (rc != 0)
{
throw std::logic_error(fmt::format(
"Could not parse public key PEM: {}\n\n(Key: {})",
error_string(rc),
pem.str()));
}
}
/**
* Construct from DER
*/
PublicKey_mbedTLS(const std::vector<uint8_t>& der)
{
int rc = mbedtls_pk_parse_public_key(ctx.get(), der.data(), der.size());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Could not parse public key DER: {}", error_string(rc)));
}
}
virtual CurveID get_curve_id() const override
{
return get_curve_id(ctx.get());
}
static CurveID get_curve_id(const mbedtls_pk_context* pk_ctx)
{
if (mbedtls_pk_can_do(pk_ctx, MBEDTLS_PK_ECKEY))
{
auto grp_id = mbedtls_pk_ec(*pk_ctx)->grp.id;
switch (grp_id)
{
case MBEDTLS_ECP_DP_SECP384R1:
return CurveID::SECP384R1;
case MBEDTLS_ECP_DP_SECP256R1:
return CurveID::SECP256R1;
default:
throw std::logic_error(
fmt::format("unsupported mbedTLS group ID {}", grp_id));
}
}
return CurveID::NONE;
}
/**
* Construct from a pre-initialised pk context
*/
PublicKey_mbedTLS(mbedtls::PKContext&& c) : ctx(std::move(c)) {}
virtual ~PublicKey_mbedTLS() = default;
using PublicKeyBase::verify;
using PublicKeyBase::verify_hash;
virtual bool verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes) override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
bytes = hp.Hash(contents, contents_size, md_type);
return verify_hash(bytes.data(), bytes.size(), sig, sig_size, md_type);
}
virtual bool verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type) override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
const auto mmdt = get_md_type(md_type);
int rc =
mbedtls_pk_verify(ctx.get(), mmdt, hash, hash_size, sig, sig_size);
if (rc)
LOG_DEBUG_FMT("Failed to verify signature: {}", error_string(rc));
return rc == 0;
}
/**
* Get the public key in PEM format
*/
virtual Pem public_key_pem() const override
{
uint8_t data[max_pem_key_size];
int rc = mbedtls_pk_write_pubkey_pem(ctx.get(), data, max_pem_key_size);
if (rc != 0)
{
throw std::logic_error(
"mbedtls_pk_write_pubkey_pem: " + error_string(rc));
}
const size_t len = strlen((char const*)data);
return Pem(data, len);
}
mbedtls_pk_context* get_raw_context() const
{
return ctx.get();
}
static std::string error_string(int err)
{
constexpr size_t len = 256;
char buf[len];
mbedtls_strerror(err, buf, len);
if (strlen(buf) == 0)
{
return std::to_string(err);
}
return std::string(buf);
}
};
class KeyPair_mbedTLS : public PublicKey_mbedTLS, public KeyPairBase
{
public:
inline mbedtls_ecp_group_id get_mbedtls_group_id(CurveID gid)
{
switch (gid)
{
case CurveID::NONE:
return MBEDTLS_ECP_DP_NONE;
case CurveID::SECP384R1:
return MBEDTLS_ECP_DP_SECP384R1;
case CurveID::SECP256R1:
return MBEDTLS_ECP_DP_SECP256R1;
default:
throw std::logic_error(fmt::format("unsupported CurveID {}", gid));
}
return MBEDTLS_ECP_DP_NONE;
}
/**
* Create a new public / private ECDSA key pair
*/
KeyPair_mbedTLS(CurveID cid) : PublicKey_mbedTLS()
{
mbedtls_ecp_group_id ec = get_mbedtls_group_id(cid);
EntropyPtr entropy = create_entropy();
int rc = mbedtls_pk_setup(
ctx.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY));
if (rc != 0)
{
throw std::logic_error(
"Could not set up ECDSA context: " + error_string(rc));
}
rc = mbedtls_ecp_gen_key(
ec, mbedtls_pk_ec(*ctx), entropy->get_rng(), entropy->get_data());
if (rc != 0)
{
throw std::logic_error(
"Could not generate ECDSA keypair: " + error_string(rc));
}
const auto actual_ec = get_mbedtls_ec_from_context(*ctx);
if (actual_ec != ec)
{
throw std::logic_error(
"Created key and received unexpected type: " +
std::to_string(actual_ec) + " != " + error_string(ec));
}
}
KeyPair_mbedTLS(const Pem& pem, CBuffer pw = nullb)
{
// keylen is +1 to include terminating null byte
int rc =
mbedtls_pk_parse_key(ctx.get(), pem.data(), pem.size(), pw.p, pw.n);
if (rc != 0)
{
throw std::logic_error(
"Could not parse private key: " + error_string(rc));
}
}
/**
* Initialise from existing pre-parsed key
*/
KeyPair_mbedTLS(mbedtls::PKContext&& k) : PublicKey_mbedTLS(std::move(k)) {}
KeyPair_mbedTLS(const KeyPair_mbedTLS&) = delete;
using PublicKey_mbedTLS::verify;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) override
{
return PublicKey_mbedTLS::verify(contents, signature);
}
/**
* Get the private key in PEM format
*/
virtual Pem private_key_pem() const override
{
uint8_t data[max_pem_key_size];
int rc = mbedtls_pk_write_key_pem(ctx.get(), data, max_pem_key_size);
if (rc != 0)
{
throw std::logic_error("mbedtls_pk_write_key_pem: " + error_string(rc));
}
const size_t len = strlen((char const*)data);
return Pem(data, len);
}
/**
* Get the public key in PEM format
*/
virtual Pem public_key_pem() const override
{
return PublicKey_mbedTLS::public_key_pem();
}
/**
* Create signature over hash of data from private key.
*
* @param d data
*
* @return Signature as a vector
*/
virtual std::vector<uint8_t> sign(
CBuffer d, MDType md_type = {}) const override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size());
}
/**
* Write signature over hash of data, and the size of that signature to
* specified locations.
*
* Important: sig must point somewhere that's at least
* MBEDTLS_E{C,D}DSA_MAX_LEN.
*
* @param d data
* @param sig_size location to which the signature size will be written.
* Initial value should be max size of sig
* @param sig location to which the signature will be written
*
* @return 0 if successful, error code of mbedtls_pk_sign otherwise,
* or 0xf if the signature_size exceeds that of a uint8_t.
*/
int sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type = {}) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size(), sig_size, sig);
}
/**
* Create signature over hashed data.
*
* @param hash First byte in hash sequence
* @param hash_size Number of bytes in hash sequence
*
* @return Signature as a vector
*/
std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const override
{
uint8_t sig[MBEDTLS_ECDSA_MAX_LEN];
size_t written = sizeof(sig);
if (sign_hash(hash, hash_size, &written, sig) != 0)
{
return {};
}
return {sig, sig + written};
}
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const override
{
EntropyPtr entropy = create_entropy();
// mbedTLS wants an MD algorithm; even when it doesn't use it, it stil
// checks that the hash size matches.
const auto mmdt = get_md_type(get_md_for_ec(get_curve_id()));
int r = mbedtls_pk_sign(
ctx.get(),
mmdt,
hash,
hash_size,
sig,
sig_size,
entropy->get_rng(),
entropy->get_data());
return r;
}
/**
* Create a certificate signing request for this key pair. If we were
* loaded from a private key, there will be no public key available for
* this call.
*/
virtual Pem create_csr(const std::string& name) const override
{
auto csr = mbedtls::make_unique<mbedtls::X509WriteCsr>();
mbedtls_x509write_csr_set_md_alg(csr.get(), MBEDTLS_MD_SHA512);
if (mbedtls_x509write_csr_set_subject_name(csr.get(), name.c_str()) != 0)
return {};
mbedtls_x509write_csr_set_key(csr.get(), ctx.get());
uint8_t buf[4096];
memset(buf, 0, sizeof(buf));
EntropyPtr entropy = create_entropy();
if (
mbedtls_x509write_csr_pem(
csr.get(),
buf,
sizeof(buf),
entropy->get_rng(),
entropy->get_data()) != 0)
return {};
auto len = strlen((char*)buf);
return Pem(buf, len);
}
void MCHK(int rc) const
{
if (rc != 0)
{
throw std::logic_error(
fmt::format("mbedTLS error: {}", error_string(rc)));
}
}
// Unfortunately, mbedtls does not provide a convenient API to write x509v3
// extensions for all supported Subject Alternative Name (SAN). Until they
// do, we have to write raw ASN1 ourselves.
// rfc5280 does not specify a maximum length for SAN,
// but rfc1035 specified that 255 bytes is enough for a DNS name
static constexpr auto max_san_length = 256;
static constexpr auto max_san_entries = 8;
// As per https://tools.ietf.org/html/rfc5280#section-4.2.1.6
enum san_type
{
other_name = 0,
rfc822_name = 1,
dns_name = 2,
x400_address = 3,
directory_name = 4,
edi_party_name = 5,
uniform_resource_identifier = 6,
ip_address = 7,
registeredID = 8
};
inline int x509write_crt_set_subject_alt_name(
mbedtls_x509write_cert* ctx,
const char* name,
san_type san = san_type::dns_name)
{
uint8_t san_buf[max_san_length];
int ret = 0;
size_t len = 0;
// mbedtls asn1 write API writes backward in san_buf
uint8_t* pc = san_buf + max_san_length;
auto name_len = strlen(name);
if (name_len > max_san_length)
{
throw std::logic_error(fmt::format(
"Subject Alternative Name {} is too long ({}>{})",
name,
name_len,
max_san_length));
}
switch (san)
{
case san_type::dns_name:
{
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)name, name_len));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, name_len));
break;
}
// mbedtls (2.16.2) only supports parsing of subject alternative name
// that is DNS= (so no IPAddress=). When connecting to a node that has
// IPAddress set, mbedtls_ssl_set_hostname() should not be called.
// However, it should work fine with a majority of other clients (e.g.
// curl).
case san_type::ip_address:
{
auto addr = ds::ip_to_binary(name);
if (!addr.has_value())
{
throw std ::logic_error(fmt::format(
"Subject Alternative Name {} is not a valid IPv4 or "
"IPv6 address",
name));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)&addr->buf, addr->size));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, addr->size));
break;
}
default:
{
throw std::logic_error(
fmt::format("Subject Alternative Name {} is not supported", san));
}
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONTEXT_SPECIFIC | san));
MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&pc, san_buf, len));
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE));
return mbedtls_x509write_crt_set_extension(
ctx,
MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
0, // Mark SAN as non-critical
san_buf + max_san_length - len,
len);
}
inline int x509write_crt_set_subject_alt_names(
mbedtls_x509write_cert* ctx,
const std::vector<SubjectAltName>& sans) const
{
if (sans.size() == 0)
return 0;
if (sans.size() > max_san_entries)
{
throw std::logic_error(fmt::format(
"Cannot set more than {} subject alternative names",
max_san_entries));
}
// The factor of two is an extremely conservative provision for ASN.1
// metadata
size_t buf_len = sans.size() * max_san_length * 2;
std::vector<uint8_t> buf(buf_len);
uint8_t* san_buf = buf.data();
int ret = 0;
size_t len = 0;
// mbedtls asn1 write API writes backward in san_buf
uint8_t* pc = san_buf + buf_len;
for (auto& san : sans)
{
if (san.san.size() > max_san_length)
{
throw std::logic_error(fmt::format(
"Subject Alternative Name {} is too long ({}>{})",
san.san,
san.san.size(),
max_san_length));
}
if (san.is_ip)
{
// mbedtls (2.16.2) only supports parsing of subject alternative name
// that is DNS= (so no IPAddress=). When connecting to a node that has
// IPAddress set, mbedtls_ssl_set_hostname() should not be called.
// However, it should work fine with a majority of other clients (e.g.
// curl).
auto addr = ds::ip_to_binary(san.san.c_str());
if (!addr.has_value())
{
throw std ::logic_error(fmt::format(
"Subject Alternative Name {} is not a valid IPv4 or "
"IPv6 address",
san.san));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)&addr->buf, addr->size));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, addr->size));
}
else
{
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc,
san_buf,
(const unsigned char*)san.san.data(),
san.san.size()));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, san.san.size()));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc,
san_buf,
MBEDTLS_ASN1_CONTEXT_SPECIFIC |
(san.is_ip ? san_type::ip_address : san_type::dns_name)));
}
MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&pc, san_buf, len));
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE));
return mbedtls_x509write_crt_set_extension(
ctx,
MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
0, // Mark SAN as non-critical
san_buf + buf_len - len,
len);
}
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const override
{
auto entropy = create_entropy();
auto csr = mbedtls::make_unique<mbedtls::X509Csr>();
auto serial = mbedtls::make_unique<mbedtls::MPI>();
auto crt = mbedtls::make_unique<mbedtls::X509WriteCrt>();
auto icrt = mbedtls::make_unique<mbedtls::X509Crt>();
MCHK(mbedtls_x509_csr_parse(
csr.get(), signing_request.data(), signing_request.size()));
char subject[512];
mbedtls_x509_dn_gets(subject, sizeof(subject), &csr->subject);
mbedtls_x509write_crt_set_md_alg(
crt.get(), get_mbedtls_md_for_ec(get_mbedtls_ec_from_context(*ctx)));
mbedtls_x509write_crt_set_subject_key(crt.get(), &csr->pk);
if (!issuer_cert.empty())
{
MCHK(mbedtls_x509_crt_parse(
icrt.get(), issuer_cert.data(), issuer_cert.size()));
mbedtls_x509write_crt_set_issuer_key(crt.get(), ctx.get());
char issuer_name[512];
mbedtls_x509_dn_gets(issuer_name, sizeof(issuer_name), &icrt->subject);
MCHK(mbedtls_x509write_crt_set_issuer_name(crt.get(), issuer_name));
}
else
{
mbedtls_x509write_crt_set_issuer_key(crt.get(), ctx.get());
MCHK(mbedtls_x509write_crt_set_issuer_name(crt.get(), subject));
}
MCHK(mbedtls_mpi_fill_random(
serial.get(), 16, entropy->get_rng(), entropy->get_data()));
MCHK(mbedtls_x509write_crt_set_subject_name(crt.get(), subject));
MCHK(mbedtls_x509write_crt_set_serial(crt.get(), serial.get()));
// Note: 825-day validity range
// https://support.apple.com/en-us/HT210176
MCHK(mbedtls_x509write_crt_set_validity(
crt.get(), "20191101000000", "20211231235959"));
MCHK(
mbedtls_x509write_crt_set_basic_constraints(crt.get(), ca ? 1 : 0, 0));
MCHK(mbedtls_x509write_crt_set_subject_key_identifier(crt.get()));
MCHK(mbedtls_x509write_crt_set_authority_key_identifier(crt.get()));
// Because mbedtls does not support parsing x509v3 extensions from a
// CSR (https://github.com/ARMmbed/mbedtls/issues/2912), the CA sets the
// SAN directly instead of reading it from the CSR
try
{
MCHK(x509write_crt_set_subject_alt_names(crt.get(), subject_alt_names));
}
catch (const std::logic_error& err)
{
LOG_FAIL_FMT("Error writing SAN: {}", err.what());
return {};
}
uint8_t buf[4096];
memset(buf, 0, sizeof(buf));
MCHK(mbedtls_x509write_crt_pem(
crt.get(), buf, sizeof(buf), entropy->get_rng(), entropy->get_data()));
auto len = strlen((char*)buf);
return Pem(buf, len);
}
};
}

Просмотреть файл

@ -1,634 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "key_pair_base.h"
#include <openssl/ec.h>
#include <openssl/engine.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/x509v3.h>
namespace crypto
{
namespace
{
inline void OPENSSL_CHECK1(int rc)
{
unsigned long ec = ERR_get_error();
if (rc != 1 && ec != 0)
{
throw std::runtime_error(
fmt::format("OpenSSL error: {}", ERR_error_string(ec, NULL)));
}
}
inline void OPENSSL_CHECKNULL(void* ptr)
{
if (ptr == NULL)
{
throw std::runtime_error("OpenSSL error: missing object");
}
}
class Unique_BIO
{
std::unique_ptr<BIO, void (*)(BIO*)> p;
public:
Unique_BIO() : p(BIO_new(BIO_s_mem()), [](auto x) { BIO_free(x); })
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_BIO(const void* buf, int len) :
p(BIO_new_mem_buf(buf, len), [](auto x) { BIO_free(x); })
{
if (!p)
throw std::runtime_error("out of memory");
}
operator BIO*()
{
return p.get();
}
};
class Unique_EVP_PKEY_CTX
{
std::unique_ptr<EVP_PKEY_CTX, void (*)(EVP_PKEY_CTX*)> p;
public:
Unique_EVP_PKEY_CTX(EVP_PKEY* key) :
p(EVP_PKEY_CTX_new(key, NULL), EVP_PKEY_CTX_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_EVP_PKEY_CTX() :
p(EVP_PKEY_CTX_new_id(EVP_PKEY_EC, NULL), EVP_PKEY_CTX_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator EVP_PKEY_CTX*()
{
return p.get();
}
};
class Unique_X509_REQ
{
std::unique_ptr<X509_REQ, void (*)(X509_REQ*)> p;
public:
Unique_X509_REQ() : p(X509_REQ_new(), X509_REQ_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_X509_REQ(BIO* mem) :
p(PEM_read_bio_X509_REQ(mem, NULL, NULL, NULL), X509_REQ_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator X509_REQ*()
{
return p.get();
}
};
class Unique_X509
{
std::unique_ptr<X509, void (*)(X509*)> p;
public:
Unique_X509() : p(X509_new(), X509_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_X509(BIO* mem) :
p(PEM_read_bio_X509(mem, NULL, NULL, NULL), X509_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator X509*()
{
return p.get();
}
};
}
class PublicKey_OpenSSL : public PublicKeyBase
{
protected:
EVP_PKEY* key = nullptr;
PublicKey_OpenSSL() {}
inline const EVP_MD* get_md_type(MDType mdt) const
{
switch (mdt)
{
case MDType::NONE:
return nullptr;
case MDType::SHA1:
return EVP_sha1();
case MDType::SHA256:
return EVP_sha256();
case MDType::SHA384:
return EVP_sha384();
case MDType::SHA512:
return EVP_sha512();
default:
return nullptr;
}
return nullptr;
}
public:
PublicKey_OpenSSL(PublicKey_OpenSSL&& key) = default;
/**
* Construct from PEM
*/
PublicKey_OpenSSL(const Pem& pem)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PUBKEY(mem, NULL, NULL, NULL);
if (!key)
throw std::runtime_error("could not parse PEM");
}
/**
* Construct from DER
*/
PublicKey_OpenSSL(const std::vector<uint8_t>& der)
{
const unsigned char* pp = der.data();
key = d2i_PublicKey(EVP_PKEY_EC, &key, &pp, der.size());
if (!key)
{
throw new std::runtime_error("Could not read DER");
}
}
/**
* Construct from a pre-initialised pk context
*/
PublicKey_OpenSSL(EVP_PKEY* key) : key(key) {}
virtual ~PublicKey_OpenSSL()
{
if (key)
EVP_PKEY_free(key);
}
virtual CurveID get_curve_id() const override
{
int nid =
EC_GROUP_get_curve_name(EC_KEY_get0_group(EVP_PKEY_get0_EC_KEY(key)));
switch (nid)
{
case NID_secp384r1:
return CurveID::SECP384R1;
case NID_X9_62_prime256v1:
return CurveID::SECP256R1;
default:
throw std::runtime_error(
fmt::format("Unknown OpenSSL curve {}", nid));
}
return CurveID::NONE;
}
using PublicKeyBase::verify;
using PublicKeyBase::verify_hash;
virtual bool verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes) override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
bytes = hp.Hash(contents, contents_size, md_type);
return verify_hash(bytes.data(), bytes.size(), sig, sig_size, md_type);
}
virtual bool verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type) override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
Unique_EVP_PKEY_CTX pctx(key);
OPENSSL_CHECK1(EVP_PKEY_verify_init(pctx));
if (md_type != MDType::NONE)
{
OPENSSL_CHECK1(
EVP_PKEY_CTX_set_signature_md(pctx, get_md_type(md_type)));
}
int rc = EVP_PKEY_verify(pctx, sig, sig_size, hash, hash_size);
bool ok = rc == 1;
if (!ok)
{
int ec = ERR_get_error();
LOG_DEBUG_FMT(
"OpenSSL signature verification failure: {}",
ERR_error_string(ec, NULL));
}
return ok;
}
/**
* Get the public key in PEM format
*/
virtual Pem public_key_pem() const override
{
Unique_BIO buf;
OPENSSL_CHECK1(PEM_write_bio_PUBKEY(buf, key));
BUF_MEM* bptr;
BIO_get_mem_ptr(buf, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
static std::string error_string(unsigned long ec)
{
return ERR_error_string(ec, NULL);
}
};
class KeyPair_OpenSSL : public PublicKey_OpenSSL, public KeyPairBase
{
protected:
inline int get_openssl_group_id(CurveID gid)
{
switch (gid)
{
case CurveID::NONE:
return NID_undef;
case CurveID::SECP384R1:
return NID_secp384r1;
case CurveID::SECP256R1:
return NID_X9_62_prime256v1;
default:
throw std::logic_error(
fmt::format("unsupported OpenSSL CurveID {}", gid));
}
return MBEDTLS_ECP_DP_NONE;
}
static std::vector<std::pair<std::string, std::string>> parse_name(
const std::string& name)
{
std::vector<std::pair<std::string, std::string>> r;
char* name_cpy = strdup(name.c_str());
char* p = std::strtok(name_cpy, ",");
while (p)
{
char* eq = strchr(p, '=');
*eq = '\0';
r.push_back(std::make_pair(p, eq + 1));
p = std::strtok(NULL, ",");
}
free(name_cpy);
return r;
}
public:
/**
* Generate a fresh key
*/
KeyPair_OpenSSL(CurveID curve_id)
{
int curve_nid = get_openssl_group_id(curve_id);
key = EVP_PKEY_new();
Unique_EVP_PKEY_CTX pkctx;
if (
EVP_PKEY_paramgen_init(pkctx) < 0 ||
EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pkctx, curve_nid) < 0 ||
EVP_PKEY_CTX_set_ec_param_enc(pkctx, OPENSSL_EC_NAMED_CURVE) < 0)
throw std::runtime_error("could not initialize PK context");
if (EVP_PKEY_keygen_init(pkctx) < 0 || EVP_PKEY_keygen(pkctx, &key) < 0)
throw std::runtime_error("could not generate new EC key");
}
KeyPair_OpenSSL(const KeyPair_OpenSSL&) = delete;
KeyPair_OpenSSL(const Pem& pem, CBuffer pw = nullb)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PrivateKey(mem, NULL, NULL, (void*)pw.p);
if (!key)
throw std::runtime_error("could not parse PEM");
}
virtual ~KeyPair_OpenSSL() = default;
using PublicKey_OpenSSL::verify;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) override
{
return PublicKey_OpenSSL::verify(contents, signature);
}
/**
* Get the private key in PEM format
*/
virtual Pem private_key_pem() const override
{
Unique_BIO buf;
OPENSSL_CHECK1(
PEM_write_bio_PrivateKey(buf, key, NULL, NULL, 0, NULL, NULL));
BUF_MEM* bptr;
BIO_get_mem_ptr(buf, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
/**
* Get the public key in PEM format
*/
virtual Pem public_key_pem() const override
{
return PublicKey_OpenSSL::public_key_pem();
}
/**
* Create signature over hash of data from private key.
*
* @param d data
*
* @return Signature as a vector
*/
virtual std::vector<uint8_t> sign(
CBuffer d, MDType md_type = {}) const override
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size());
}
/**
* Write signature over hash of data, and the size of that signature to
* specified locations.
*
* @param d data
* @param sig_size location to which the signature size will be written.
* Initial value should be max size of sig
* @param sig location to which the signature will be written
*
* @return 0 if successful, otherwise OpenSSL error code
*/
int sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type = {}) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size(), sig_size, sig);
}
/**
* Create signature over hashed data.
*
* @param hash First byte in hash sequence
* @param hash_size Number of bytes in hash sequence
*
* @return Signature as a vector
*/
std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const override
{
std::vector<uint8_t> sig(EVP_PKEY_size(key));
size_t written = sig.size();
if (sign_hash(hash, hash_size, &written, sig.data()) != 0)
{
return {};
}
sig.resize(written);
return sig;
}
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const override
{
Unique_EVP_PKEY_CTX pctx(key);
OPENSSL_CHECK1(EVP_PKEY_sign_init(pctx));
OPENSSL_CHECK1(EVP_PKEY_sign(pctx, sig, sig_size, hash, hash_size));
return 0;
}
/**
* Create a certificate signing request for this key pair. If we were
* loaded from a private key, there will be no public key available for
* this call.
*/
virtual Pem create_csr(const std::string& name) const override
{
Unique_X509_REQ req;
OPENSSL_CHECK1(X509_REQ_set_pubkey(req, key));
X509_NAME* subj_name = NULL;
OPENSSL_CHECKNULL(subj_name = X509_NAME_new());
for (auto kv : parse_name(name))
{
OPENSSL_CHECK1(X509_NAME_add_entry_by_txt(
subj_name,
kv.first.c_str(),
MBSTRING_ASC,
(const unsigned char*)kv.second.c_str(),
-1,
-1,
0));
}
OPENSSL_CHECK1(X509_REQ_set_subject_name(req, subj_name));
X509_NAME_free(subj_name);
if (key)
OPENSSL_CHECK1(X509_REQ_sign(req, key, EVP_sha512()));
Unique_BIO mem;
OPENSSL_CHECK1(PEM_write_bio_X509_REQ(mem, req));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
Pem result((uint8_t*)bptr->data, bptr->length);
return result;
}
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const override
{
X509* icrt = NULL;
Unique_BIO mem(signing_request.data(), -1);
Unique_X509_REQ csr(mem);
Unique_X509 crt;
OPENSSL_CHECK1(X509_set_version(crt, 2));
// Add serial number
unsigned char rndbytes[16];
OPENSSL_CHECK1(RAND_bytes(rndbytes, sizeof(rndbytes)));
BIGNUM* bn = NULL;
OPENSSL_CHECKNULL(bn = BN_new());
BN_bin2bn(rndbytes, sizeof(rndbytes), bn);
ASN1_INTEGER* serial = ASN1_INTEGER_new();
BN_to_ASN1_INTEGER(bn, serial);
OPENSSL_CHECK1(X509_set_serialNumber(crt, serial));
ASN1_INTEGER_free(serial);
BN_free(bn);
// Add issuer name
if (!issuer_cert.empty())
{
Unique_BIO imem(issuer_cert.data(), -1);
OPENSSL_CHECKNULL(icrt = PEM_read_bio_X509(imem, NULL, NULL, NULL));
OPENSSL_CHECK1(X509_set_issuer_name(crt, X509_get_subject_name(icrt)));
}
else
{
OPENSSL_CHECK1(
X509_set_issuer_name(crt, X509_REQ_get_subject_name(csr)));
}
// Note: 825-day validity range
// https://support.apple.com/en-us/HT210176
ASN1_TIME *before = NULL, *after = NULL;
OPENSSL_CHECKNULL(before = ASN1_TIME_new());
OPENSSL_CHECKNULL(after = ASN1_TIME_new());
OPENSSL_CHECK1(ASN1_TIME_set_string(before, "20191101000000Z"));
OPENSSL_CHECK1(ASN1_TIME_set_string(after, "20211231235959Z"));
X509_set1_notBefore(crt, before);
X509_set1_notAfter(crt, after);
ASN1_TIME_free(before);
ASN1_TIME_free(after);
X509_set_subject_name(crt, X509_REQ_get_subject_name(csr));
EVP_PKEY* req_pubkey = X509_REQ_get_pubkey(csr);
X509_set_pubkey(crt, req_pubkey);
EVP_PKEY_free(req_pubkey);
// Extensions
X509V3_CTX v3ctx;
X509V3_set_ctx_nodb(&v3ctx);
X509V3_set_ctx(&v3ctx, icrt ? icrt : crt, NULL, csr, NULL, 0);
// Add basic constraints
X509_EXTENSION* ext = NULL;
OPENSSL_CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_basic_constraints, ca ? "CA:TRUE" : "CA:FALSE"));
OPENSSL_CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Add subject key identifier
OPENSSL_CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_subject_key_identifier, "hash"));
OPENSSL_CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Add authority key identifier
OPENSSL_CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_authority_key_identifier, "keyid:always"));
OPENSSL_CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Subject alternative names (Necessary? Shouldn't they be in the CSR?)
if (!subject_alt_names.empty())
{
std::string all_alt_names;
bool first = true;
for (auto san : subject_alt_names)
{
if (first)
{
first = !first;
}
else
{
all_alt_names += ", ";
}
if (san.is_ip)
all_alt_names += "IP:";
else
all_alt_names += "DNS:";
all_alt_names += san.san;
}
OPENSSL_CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_subject_alt_name, all_alt_names.c_str()));
OPENSSL_CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
}
// Sign
auto md = get_md_type(get_md_for_ec(get_curve_id()));
int size = X509_sign(crt, key, md);
if (size <= 0)
throw std::runtime_error("could not sign CRT");
Unique_BIO omem;
OPENSSL_CHECK1(PEM_write_bio_X509(omem, crt));
// Export
BUF_MEM* bptr;
BIO_get_mem_ptr(omem, &bptr);
Pem result((uint8_t*)bptr->data, bptr->length);
if (icrt)
X509_free(icrt);
return result;
}
};
}

Просмотреть файл

@ -0,0 +1,45 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/curve.h"
#include "crypto/hash.h"
#include <mbedtls/ecp.h>
#include <stdexcept>
#include <string>
namespace crypto
{
// Helper to access elliptic curve id from context
inline mbedtls_ecp_group_id get_mbedtls_ec_from_context(
const mbedtls_pk_context& ctx)
{
return mbedtls_pk_ec(ctx)->grp.id;
}
inline mbedtls_md_type_t get_mbedtls_md_for_ec(
mbedtls_ecp_group_id ec, bool allow_none = false)
{
switch (ec)
{
case MBEDTLS_ECP_DP_SECP384R1:
return MBEDTLS_MD_SHA384;
case MBEDTLS_ECP_DP_SECP256R1:
return MBEDTLS_MD_SHA256;
default:
{
if (allow_none)
{
return MBEDTLS_MD_NONE;
}
else
{
const auto error = fmt::format("Unhandled ecp group id: {}", ec);
throw std::logic_error(error);
}
}
}
}
}

Просмотреть файл

@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/entropy.h"
#include "mbedtls_wrappers.h"
#include <functional>
#include <memory>
#include <vector>
namespace crypto
{
class MbedtlsEntropy : public Entropy
{
private:
mbedtls::Entropy entropy = mbedtls::make_unique<mbedtls::Entropy>();
mbedtls::CtrDrbg drbg = mbedtls::make_unique<mbedtls::CtrDrbg>();
static bool gen(uint64_t& v);
public:
MbedtlsEntropy()
{
mbedtls_ctr_drbg_seed(
drbg.get(), mbedtls_entropy_func, entropy.get(), nullptr, 0);
}
std::vector<uint8_t> random(size_t len) override
{
std::vector<uint8_t> data(len);
if (mbedtls_ctr_drbg_random(drbg.get(), data.data(), data.size()) != 0)
throw std::logic_error("Couldn't create random data");
return data;
}
uint64_t random64() override
{
uint64_t rnd;
uint64_t len = sizeof(uint64_t);
if (
mbedtls_ctr_drbg_random(
drbg.get(), reinterpret_cast<unsigned char*>(&rnd), len) != 0)
{
throw std::logic_error("Couldn't create random data");
}
return rnd;
}
void random(unsigned char* data, size_t len) override
{
if (mbedtls_ctr_drbg_random(drbg.get(), data, len) != 0)
throw std::logic_error("Couldn't create random data");
}
static int rng(void* ctx, unsigned char* output, size_t len)
{
return mbedtls_ctr_drbg_random(ctx, output, len);
}
rng_func_t get_rng() override
{
return &rng;
}
void* get_data() override
{
return drbg.get();
}
};
}

Просмотреть файл

@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "hash.h"
#include "ds/buffer.h"
#include "mbedtls_wrappers.h"
#include <mbedtls/sha256.h>
#include <stdexcept>
namespace crypto
{
using namespace mbedtls;
void mbedtls_sha256(const CBuffer& data, uint8_t* h)
{
mbedtls_sha256_context ctx;
mbedtls_sha256_init(&ctx);
mbedtls_sha256_starts_ret(&ctx, 0);
mbedtls_sha256_update_ret(&ctx, data.p, data.rawSize());
mbedtls_sha256_finish_ret(&ctx, h);
mbedtls_sha256_free(&ctx);
}
ISha256MbedTLS::ISha256MbedTLS()
{
ctx = new mbedtls_sha256_context();
mbedtls_sha256_starts_ret((mbedtls_sha256_context*)ctx, 0);
}
ISha256MbedTLS::~ISha256MbedTLS()
{
delete (mbedtls_sha256_context*)ctx;
}
Sha256Hash ISha256MbedTLS::finalise()
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
Sha256Hash r;
mbedtls_sha256_finish_ret((mbedtls_sha256_context*)ctx, r.h.data());
mbedtls_sha256_free((mbedtls_sha256_context*)ctx);
delete (mbedtls_sha256_context*)ctx;
ctx = nullptr;
return r;
}
void ISha256MbedTLS::update_hash(CBuffer data)
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
mbedtls_sha256_update_ret(
(mbedtls_sha256_context*)ctx, data.p, data.rawSize());
}
}

72
src/crypto/mbedtls/hash.h Normal file
Просмотреть файл

@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/hash_base.h"
#include "ds/buffer.h"
#include <mbedtls/md.h>
#include <mbedtls/pk.h>
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <msgpack/msgpack.hpp>
namespace crypto
{
namespace mbedtls
{
inline mbedtls_md_type_t get_md_type(MDType type)
{
switch (type)
{
case MDType::NONE:
return MBEDTLS_MD_NONE;
case MDType::SHA1:
return MBEDTLS_MD_SHA1;
case MDType::SHA256:
return MBEDTLS_MD_SHA256;
case MDType::SHA384:
return MBEDTLS_MD_SHA384;
case MDType::SHA512:
return MBEDTLS_MD_SHA512;
default:
throw std::runtime_error("Unsupported hash algorithm");
}
return MBEDTLS_MD_NONE;
}
}
class MBedHashProvider : public HashProviderBase
{
public:
virtual HashBytes Hash(const uint8_t* data, size_t size, MDType type) const
{
HashBytes r;
const auto mbedtls_md_type = mbedtls::get_md_type(type);
const auto info = mbedtls_md_info_from_type(mbedtls_md_type);
const auto hash_size = mbedtls_md_get_size(info);
r.resize(hash_size);
if (mbedtls_md(info, data, size, r.data()) != 0)
r.clear();
return r;
}
};
class ISha256MbedTLS : public ISha256HashBase
{
public:
ISha256MbedTLS();
~ISha256MbedTLS();
virtual void update_hash(CBuffer data);
virtual Sha256Hash finalise();
protected:
void* ctx;
};
void mbedtls_sha256(const CBuffer& data, uint8_t* h);
}

Просмотреть файл

@ -0,0 +1,486 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "key_pair.h"
#include "curve.h"
#include "ds/net.h"
#include "entropy.h"
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <iomanip>
#include <limits>
#include <mbedtls/asn1write.h>
#include <mbedtls/bignum.h>
#include <mbedtls/error.h>
#include <mbedtls/oid.h>
#include <mbedtls/pem.h>
#include <mbedtls/pk.h>
#include <mbedtls/x509.h>
#include <mbedtls/x509_crt.h>
#include <memory>
#include <string>
namespace crypto
{
using namespace mbedtls;
static constexpr size_t max_pem_key_size = 2048;
static mbedtls_ecp_group_id get_mbedtls_group_id(CurveID gid)
{
switch (gid)
{
case CurveID::NONE:
return MBEDTLS_ECP_DP_NONE;
case CurveID::SECP384R1:
return MBEDTLS_ECP_DP_SECP384R1;
case CurveID::SECP256R1:
return MBEDTLS_ECP_DP_SECP256R1;
default:
throw std::logic_error(fmt::format("unsupported CurveID {}", gid));
}
return MBEDTLS_ECP_DP_NONE;
}
KeyPair_mbedTLS::KeyPair_mbedTLS(CurveID cid) : PublicKey_mbedTLS()
{
mbedtls_ecp_group_id ec = get_mbedtls_group_id(cid);
EntropyPtr entropy = create_entropy();
int rc =
mbedtls_pk_setup(ctx.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY));
if (rc != 0)
{
throw std::logic_error(
"Could not set up ECDSA context: " + error_string(rc));
}
rc = mbedtls_ecp_gen_key(
ec, mbedtls_pk_ec(*ctx), entropy->get_rng(), entropy->get_data());
if (rc != 0)
{
throw std::logic_error(
"Could not generate ECDSA keypair: " + error_string(rc));
}
const auto actual_ec = get_mbedtls_ec_from_context(*ctx);
if (actual_ec != ec)
{
throw std::logic_error(
"Created key and received unexpected type: " +
std::to_string(actual_ec) + " != " + error_string(ec));
}
}
KeyPair_mbedTLS::KeyPair_mbedTLS(const Pem& pem, CBuffer pw)
{
// keylen is +1 to include terminating null byte
int rc =
mbedtls_pk_parse_key(ctx.get(), pem.data(), pem.size(), pw.p, pw.n);
if (rc != 0)
{
throw std::logic_error(
"Could not parse private key: " + error_string(rc));
}
}
KeyPair_mbedTLS::KeyPair_mbedTLS(mbedtls::PKContext&& k) :
PublicKey_mbedTLS(std::move(k))
{}
Pem KeyPair_mbedTLS::private_key_pem() const
{
uint8_t data[max_pem_key_size];
int rc = mbedtls_pk_write_key_pem(ctx.get(), data, max_pem_key_size);
if (rc != 0)
{
throw std::logic_error("mbedtls_pk_write_key_pem: " + error_string(rc));
}
const size_t len = strlen((char const*)data);
return Pem(data, len);
}
Pem KeyPair_mbedTLS::public_key_pem() const
{
return PublicKey_mbedTLS::public_key_pem();
}
bool KeyPair_mbedTLS::verify(
const std::vector<uint8_t>& contents, const std::vector<uint8_t>& signature)
{
return PublicKey_mbedTLS::verify(contents, signature);
}
std::vector<uint8_t> KeyPair_mbedTLS::sign(CBuffer d, MDType md_type) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size());
}
int KeyPair_mbedTLS::sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size(), sig_size, sig);
}
std::vector<uint8_t> KeyPair_mbedTLS::sign_hash(
const uint8_t* hash, size_t hash_size) const
{
std::vector<uint8_t> sig(MBEDTLS_ECDSA_MAX_LEN);
size_t written = sizeof(sig);
if (sign_hash(hash, hash_size, &written, sig.data()) != 0)
{
return {};
}
sig.resize(written);
return sig;
}
int KeyPair_mbedTLS::sign_hash(
const uint8_t* hash, size_t hash_size, size_t* sig_size, uint8_t* sig) const
{
EntropyPtr entropy = create_entropy();
const auto mmdt = get_md_type(get_md_for_ec(get_curve_id()));
int r = mbedtls_pk_sign(
ctx.get(),
mmdt,
hash,
hash_size,
sig,
sig_size,
entropy->get_rng(),
entropy->get_data());
return r;
}
Pem KeyPair_mbedTLS::create_csr(const std::string& name) const
{
auto csr = mbedtls::make_unique<mbedtls::X509WriteCsr>();
mbedtls_x509write_csr_set_md_alg(csr.get(), MBEDTLS_MD_SHA512);
if (mbedtls_x509write_csr_set_subject_name(csr.get(), name.c_str()) != 0)
return {};
mbedtls_x509write_csr_set_key(csr.get(), ctx.get());
uint8_t buf[4096];
memset(buf, 0, sizeof(buf));
EntropyPtr entropy = create_entropy();
if (
mbedtls_x509write_csr_pem(
csr.get(), buf, sizeof(buf), entropy->get_rng(), entropy->get_data()) !=
0)
return {};
auto len = strlen((char*)buf);
return Pem(buf, len);
}
static void MCHK(int rc)
{
if (rc != 0)
{
throw std::logic_error(
fmt::format("mbedTLS error: {}", error_string(rc)));
}
}
// Unfortunately, mbedtls does not provide a convenient API to write x509v3
// extensions for all supported Subject Alternative Name (SAN). Until they
// do, we have to write raw ASN1 ourselves.
// rfc5280 does not specify a maximum length for SAN,
// but rfc1035 specified that 255 bytes is enough for a DNS name
static constexpr auto max_san_length = 256;
static constexpr auto max_san_entries = 8;
// As per https://tools.ietf.org/html/rfc5280#section-4.2.1.6
enum san_type
{
other_name = 0,
rfc822_name = 1,
dns_name = 2,
x400_address = 3,
directory_name = 4,
edi_party_name = 5,
uniform_resource_identifier = 6,
ip_address = 7,
registeredID = 8
};
static inline int x509write_crt_set_subject_alt_name(
mbedtls_x509write_cert* ctx, const char* name, san_type san)
{
uint8_t san_buf[max_san_length];
int ret = 0;
size_t len = 0;
// mbedtls asn1 write API writes backward in san_buf
uint8_t* pc = san_buf + max_san_length;
auto name_len = strlen(name);
if (name_len > max_san_length)
{
throw std::logic_error(fmt::format(
"Subject Alternative Name {} is too long ({}>{})",
name,
name_len,
max_san_length));
}
switch (san)
{
case san_type::dns_name:
{
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)name, name_len));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, name_len));
break;
}
// mbedtls (2.16.2) only supports parsing of subject alternative name
// that is DNS= (so no IPAddress=). When connecting to a node that has
// IPAddress set, mbedtls_ssl_set_hostname() should not be called.
// However, it should work fine with a majority of other clients (e.g.
// curl).
case san_type::ip_address:
{
auto addr = ds::ip_to_binary(name);
if (!addr.has_value())
{
throw std ::logic_error(fmt::format(
"Subject Alternative Name {} is not a valid IPv4 or "
"IPv6 address",
name));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)&addr->buf, addr->size));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, addr->size));
break;
}
default:
{
throw std::logic_error(
fmt::format("Subject Alternative Name {} is not supported", san));
}
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONTEXT_SPECIFIC | san));
MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&pc, san_buf, len));
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE));
return mbedtls_x509write_crt_set_extension(
ctx,
MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
0, // Mark SAN as non-critical
san_buf + max_san_length - len,
len);
}
static inline int x509write_crt_set_subject_alt_names(
mbedtls_x509write_cert* ctx, const std::vector<SubjectAltName>& sans)
{
if (sans.size() == 0)
return 0;
if (sans.size() > max_san_entries)
{
throw std::logic_error(fmt::format(
"Cannot set more than {} subject alternative names", max_san_entries));
}
// The factor of two is an extremely conservative provision for ASN.1
// metadata
size_t buf_len = sans.size() * max_san_length * 2;
std::vector<uint8_t> buf(buf_len);
uint8_t* san_buf = buf.data();
int ret = 0;
size_t len = 0;
// mbedtls asn1 write API writes backward in san_buf
uint8_t* pc = san_buf + buf_len;
for (auto& san : sans)
{
if (san.san.size() > max_san_length)
{
throw std::logic_error(fmt::format(
"Subject Alternative Name {} is too long ({}>{})",
san.san,
san.san.size(),
max_san_length));
}
if (san.is_ip)
{
// mbedtls (2.16.2) only supports parsing of subject alternative name
// that is DNS= (so no IPAddress=). When connecting to a node that has
// IPAddress set, mbedtls_ssl_set_hostname() should not be called.
// However, it should work fine with a majority of other clients (e.g.
// curl).
auto addr = ds::ip_to_binary(san.san.c_str());
if (!addr.has_value())
{
throw std ::logic_error(fmt::format(
"Subject Alternative Name {} is not a valid IPv4 or "
"IPv6 address",
san.san));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc, san_buf, (const unsigned char*)&addr->buf, addr->size));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, addr->size));
}
else
{
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_raw_buffer(
&pc,
san_buf,
(const unsigned char*)san.san.data(),
san.san.size()));
MBEDTLS_ASN1_CHK_ADD(
len, mbedtls_asn1_write_len(&pc, san_buf, san.san.size()));
}
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc,
san_buf,
MBEDTLS_ASN1_CONTEXT_SPECIFIC |
(san.is_ip ? san_type::ip_address : san_type::dns_name)));
}
MBEDTLS_ASN1_CHK_ADD(len, mbedtls_asn1_write_len(&pc, san_buf, len));
MBEDTLS_ASN1_CHK_ADD(
len,
mbedtls_asn1_write_tag(
&pc, san_buf, MBEDTLS_ASN1_CONSTRUCTED | MBEDTLS_ASN1_SEQUENCE));
return mbedtls_x509write_crt_set_extension(
ctx,
MBEDTLS_OID_SUBJECT_ALT_NAME,
MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
0, // Mark SAN as non-critical
san_buf + buf_len - len,
len);
}
Pem KeyPair_mbedTLS::sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca) const
{
auto entropy = create_entropy();
auto csr = mbedtls::make_unique<mbedtls::X509Csr>();
auto serial = mbedtls::make_unique<mbedtls::MPI>();
auto crt = mbedtls::make_unique<mbedtls::X509WriteCrt>();
auto icrt = mbedtls::make_unique<mbedtls::X509Crt>();
MCHK(mbedtls_x509_csr_parse(
csr.get(), signing_request.data(), signing_request.size()));
char subject[512];
mbedtls_x509_dn_gets(subject, sizeof(subject), &csr->subject);
mbedtls_x509write_crt_set_md_alg(
crt.get(), get_mbedtls_md_for_ec(get_mbedtls_ec_from_context(*ctx)));
mbedtls_x509write_crt_set_subject_key(crt.get(), &csr->pk);
if (!issuer_cert.empty())
{
MCHK(mbedtls_x509_crt_parse(
icrt.get(), issuer_cert.data(), issuer_cert.size()));
mbedtls_x509write_crt_set_issuer_key(crt.get(), ctx.get());
char issuer_name[512];
mbedtls_x509_dn_gets(issuer_name, sizeof(issuer_name), &icrt->subject);
MCHK(mbedtls_x509write_crt_set_issuer_name(crt.get(), issuer_name));
}
else
{
mbedtls_x509write_crt_set_issuer_key(crt.get(), ctx.get());
MCHK(mbedtls_x509write_crt_set_issuer_name(crt.get(), subject));
}
MCHK(mbedtls_mpi_fill_random(
serial.get(), 16, entropy->get_rng(), entropy->get_data()));
MCHK(mbedtls_x509write_crt_set_subject_name(crt.get(), subject));
MCHK(mbedtls_x509write_crt_set_serial(crt.get(), serial.get()));
// Note: 825-day validity range
// https://support.apple.com/en-us/HT210176
MCHK(mbedtls_x509write_crt_set_validity(
crt.get(), "20191101000000", "20211231235959"));
MCHK(mbedtls_x509write_crt_set_basic_constraints(crt.get(), ca ? 1 : 0, 0));
MCHK(mbedtls_x509write_crt_set_subject_key_identifier(crt.get()));
MCHK(mbedtls_x509write_crt_set_authority_key_identifier(crt.get()));
// Because mbedtls does not support parsing x509v3 extensions from a
// CSR (https://github.com/ARMmbed/mbedtls/issues/2912), the CA sets the
// SAN directly instead of reading it from the CSR
try
{
MCHK(x509write_crt_set_subject_alt_names(crt.get(), subject_alt_names));
}
catch (const std::logic_error& err)
{
LOG_FAIL_FMT("Error writing SAN: {}", err.what());
return {};
}
uint8_t buf[4096];
memset(buf, 0, sizeof(buf));
MCHK(mbedtls_x509write_crt_pem(
crt.get(), buf, sizeof(buf), entropy->get_rng(), entropy->get_data()));
auto len = strlen((char*)buf);
return Pem(buf, len);
}
}

Просмотреть файл

@ -0,0 +1,54 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "../key_pair.h"
#include "../san.h"
#include "mbedtls_wrappers.h"
#include "public_key.h"
namespace crypto
{
class KeyPair_mbedTLS : public PublicKey_mbedTLS, public KeyPair
{
public:
KeyPair_mbedTLS(CurveID cid);
KeyPair_mbedTLS(const Pem& pem, CBuffer pw = nullb);
KeyPair_mbedTLS(mbedtls::PKContext&& k);
KeyPair_mbedTLS(const KeyPair_mbedTLS&) = delete;
virtual ~KeyPair_mbedTLS() = default;
virtual Pem private_key_pem() const override;
virtual Pem public_key_pem() const override;
using PublicKey_mbedTLS::verify;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) override;
virtual std::vector<uint8_t> sign(
CBuffer d, MDType md_type = {}) const override;
int sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type = {}) const;
std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const override;
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const override;
virtual Pem create_csr(const std::string& name) const override;
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const override;
};
}

Просмотреть файл

@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/error.h>
#include <mbedtls/gcm.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/sha256.h>
#include <mbedtls/ssl.h>
#include <mbedtls/x509.h>
#include <mbedtls/x509_crt.h>
#include <mbedtls/x509_csr.h>
#include <memory>
#include <string>
namespace crypto
{
namespace mbedtls
{
template <typename T>
T make_unique();
#define DEFINE_MBEDTLS_WRAPPER( \
NEW_TYPE, MBED_TYPE, MBED_INIT_FN, MBED_FREE_FN) \
struct NEW_TYPE##Deleter \
{ \
void operator()(MBED_TYPE* ptr) \
{ \
MBED_FREE_FN(ptr); \
delete ptr; \
} \
}; \
using NEW_TYPE = std::unique_ptr<MBED_TYPE, NEW_TYPE##Deleter>; \
template <> \
inline NEW_TYPE make_unique<NEW_TYPE>() \
{ \
auto p = new MBED_TYPE; \
MBED_INIT_FN(p); \
return NEW_TYPE(p); \
}
DEFINE_MBEDTLS_WRAPPER(
CtrDrbg,
mbedtls_ctr_drbg_context,
mbedtls_ctr_drbg_init,
mbedtls_ctr_drbg_free);
DEFINE_MBEDTLS_WRAPPER(
ECDHContext, mbedtls_ecdh_context, mbedtls_ecdh_init, mbedtls_ecdh_free);
DEFINE_MBEDTLS_WRAPPER(
Entropy,
mbedtls_entropy_context,
mbedtls_entropy_init,
mbedtls_entropy_free);
DEFINE_MBEDTLS_WRAPPER(
GcmContext, mbedtls_gcm_context, mbedtls_gcm_init, mbedtls_gcm_free);
DEFINE_MBEDTLS_WRAPPER(
MPI, mbedtls_mpi, mbedtls_mpi_init, mbedtls_mpi_free);
DEFINE_MBEDTLS_WRAPPER(
NetContext, mbedtls_net_context, mbedtls_net_init, mbedtls_net_free);
DEFINE_MBEDTLS_WRAPPER(
PKContext, mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free);
DEFINE_MBEDTLS_WRAPPER(
SSLConfig,
mbedtls_ssl_config,
mbedtls_ssl_config_init,
mbedtls_ssl_config_free);
DEFINE_MBEDTLS_WRAPPER(
SSLContext, mbedtls_ssl_context, mbedtls_ssl_init, mbedtls_ssl_free);
DEFINE_MBEDTLS_WRAPPER(
X509Crl, mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free);
DEFINE_MBEDTLS_WRAPPER(
X509Crt, mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free);
DEFINE_MBEDTLS_WRAPPER(
X509Csr, mbedtls_x509_csr, mbedtls_x509_csr_init, mbedtls_x509_csr_free);
DEFINE_MBEDTLS_WRAPPER(
X509WriteCrt,
mbedtls_x509write_cert,
mbedtls_x509write_crt_init,
mbedtls_x509write_crt_free);
DEFINE_MBEDTLS_WRAPPER(
X509WriteCsr,
mbedtls_x509write_csr,
mbedtls_x509write_csr_init,
mbedtls_x509write_csr_free);
DEFINE_MBEDTLS_WRAPPER(
SHA256Ctx,
mbedtls_sha256_context,
mbedtls_sha256_init,
mbedtls_sha256_free);
#undef DEFINE_MBEDTLS_WRAPPER
inline std::string error_string(int err)
{
constexpr size_t len = 256;
char buf[len];
mbedtls_strerror(err, buf, len);
if (strlen(buf) == 0)
{
return std::to_string(err);
}
return std::string(buf);
}
}
}

Просмотреть файл

@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "curve.h"
#include "ds/net.h"
#include "entropy.h"
#include "key_pair.h"
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <iomanip>
#include <limits>
#include <mbedtls/asn1write.h>
#include <mbedtls/bignum.h>
#include <mbedtls/error.h>
#include <mbedtls/oid.h>
#include <mbedtls/pem.h>
#include <mbedtls/pk.h>
#include <mbedtls/x509.h>
#include <mbedtls/x509_crt.h>
#include <memory>
#include <string>
namespace crypto
{
using namespace mbedtls;
static constexpr size_t max_pem_key_size = 2048;
PublicKey_mbedTLS::PublicKey_mbedTLS() {}
PublicKey_mbedTLS::PublicKey_mbedTLS(const Pem& pem)
{
int rc = mbedtls_pk_parse_public_key(ctx.get(), pem.data(), pem.size());
if (rc != 0)
{
throw std::logic_error(fmt::format(
"Could not parse public key PEM: {}\n\n(Key: {})",
error_string(rc),
pem.str()));
}
}
PublicKey_mbedTLS::PublicKey_mbedTLS(const std::vector<uint8_t>& der)
{
int rc = mbedtls_pk_parse_public_key(ctx.get(), der.data(), der.size());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Could not parse public key DER: {}", error_string(rc)));
}
}
static CurveID get_curve_id(const mbedtls_pk_context* pk_ctx)
{
if (mbedtls_pk_can_do(pk_ctx, MBEDTLS_PK_ECKEY))
{
auto grp_id = mbedtls_pk_ec(*pk_ctx)->grp.id;
switch (grp_id)
{
case MBEDTLS_ECP_DP_SECP384R1:
return CurveID::SECP384R1;
case MBEDTLS_ECP_DP_SECP256R1:
return CurveID::SECP256R1;
default:
throw std::logic_error(
fmt::format("unsupported mbedTLS group ID {}", grp_id));
}
}
return CurveID::NONE;
}
CurveID PublicKey_mbedTLS::get_curve_id() const
{
return crypto::get_curve_id(ctx.get());
}
PublicKey_mbedTLS::PublicKey_mbedTLS(mbedtls::PKContext&& c) :
ctx(std::move(c))
{}
bool PublicKey_mbedTLS::verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes)
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
MBedHashProvider hp;
bytes = hp.Hash(contents, contents_size, md_type);
return verify_hash(bytes.data(), bytes.size(), sig, sig_size, md_type);
}
bool PublicKey_mbedTLS::verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type)
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
const auto mmdt = get_md_type(md_type);
int rc = mbedtls_pk_verify(ctx.get(), mmdt, hash, hash_size, sig, sig_size);
if (rc)
LOG_DEBUG_FMT("Failed to verify signature: {}", error_string(rc));
return rc == 0;
}
Pem PublicKey_mbedTLS::public_key_pem() const
{
uint8_t data[max_pem_key_size];
int rc = mbedtls_pk_write_pubkey_pem(ctx.get(), data, max_pem_key_size);
if (rc != 0)
{
throw std::logic_error(
"mbedtls_pk_write_pubkey_pem: " + error_string(rc));
}
const size_t len = strlen((char const*)data);
return Pem(data, len);
}
mbedtls_pk_context* PublicKey_mbedTLS::get_raw_context() const
{
return ctx.get();
}
}

Просмотреть файл

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "../public_key.h"
#include "../san.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
class PublicKey_mbedTLS : public PublicKey
{
protected:
mbedtls::PKContext ctx = mbedtls::make_unique<mbedtls::PKContext>();
PublicKey_mbedTLS();
CurveID get_curve_id() const;
public:
PublicKey_mbedTLS(PublicKey_mbedTLS&& pk) = default;
PublicKey_mbedTLS(mbedtls::PKContext&& c);
PublicKey_mbedTLS(const Pem& pem);
PublicKey_mbedTLS(const std::vector<uint8_t>& der);
virtual ~PublicKey_mbedTLS() = default;
using PublicKey::verify;
using PublicKey::verify_hash;
virtual bool verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes) override;
virtual bool verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type) override;
virtual Pem public_key_pem() const override;
mbedtls_pk_context* get_raw_context() const;
};
}

Просмотреть файл

@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "rsa_key_pair.h"
#include "entropy.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
using namespace mbedtls;
RSAKeyPair_mbedTLS::RSAKeyPair_mbedTLS(
size_t public_key_size, size_t public_exponent)
{
EntropyPtr entropy = create_entropy();
int rc =
mbedtls_pk_setup(ctx.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA));
if (rc != 0)
{
throw std::logic_error(
"Could not set up RSA context: " + error_string(rc));
}
rc = mbedtls_rsa_gen_key(
mbedtls_pk_rsa(*ctx.get()),
entropy->get_rng(),
entropy->get_data(),
public_key_size,
public_exponent);
if (rc != 0)
{
throw std::logic_error(
"Could not generate RSA keypair: " + error_string(rc));
}
}
RSAKeyPair_mbedTLS::RSAKeyPair_mbedTLS(mbedtls::PKContext&& k) :
RSAPublicKey_mbedTLS(std::move(k))
{}
RSAKeyPair_mbedTLS::RSAKeyPair_mbedTLS(const Pem& pem, CBuffer pw) :
RSAPublicKey_mbedTLS()
{
// keylen is +1 to include terminating null byte
int rc =
mbedtls_pk_parse_key(ctx.get(), pem.data(), pem.size(), pw.p, pw.n);
if (rc != 0)
{
throw std::logic_error(
"Could not parse private key: " + error_string(rc));
}
}
std::vector<uint8_t> RSAKeyPair_mbedTLS::unwrap(
const std::vector<uint8_t>& input, std::optional<std::string> label)
{
mbedtls_rsa_context* rsa_ctx = mbedtls_pk_rsa(*ctx.get());
mbedtls_rsa_set_padding(rsa_ctx, rsa_padding_mode, rsa_padding_digest_id);
std::vector<uint8_t> output_buf(rsa_ctx->len);
auto entropy = create_entropy();
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
size_t olen;
auto rc = mbedtls_rsa_rsaes_oaep_decrypt(
rsa_ctx,
entropy->get_rng(),
entropy->get_data(),
MBEDTLS_RSA_PRIVATE,
label_,
label_size,
&olen,
input.data(),
output_buf.data(),
output_buf.size());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Error during RSA OEAP unwrap: {}", error_string(rc)));
}
output_buf.resize(olen);
return output_buf;
}
Pem RSAKeyPair_mbedTLS::public_key_pem() const
{
return PublicKey_mbedTLS::public_key_pem();
}
}

Просмотреть файл

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/rsa_key_pair.h"
#include "mbedtls_wrappers.h"
#include "rsa_public_key.h"
#include <optional>
#include <vector>
namespace crypto
{
class RSAKeyPair_mbedTLS : public RSAPublicKey_mbedTLS, public RSAKeyPair
{
public:
RSAKeyPair_mbedTLS(
size_t public_key_size = default_public_key_size,
size_t public_exponent = default_public_exponent);
RSAKeyPair_mbedTLS(mbedtls::PKContext&& k);
RSAKeyPair_mbedTLS(const RSAKeyPair&) = delete;
RSAKeyPair_mbedTLS(const Pem& pem, CBuffer pw = nullb);
virtual ~RSAKeyPair_mbedTLS() = default;
virtual std::vector<uint8_t> unwrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt);
virtual Pem public_key_pem() const;
};
}

Просмотреть файл

@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "rsa_public_key.h"
#include "entropy.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
using namespace mbedtls;
RSAPublicKey_mbedTLS::RSAPublicKey_mbedTLS(mbedtls::PKContext&& c) :
PublicKey_mbedTLS(std::move(c))
{}
RSAPublicKey_mbedTLS::RSAPublicKey_mbedTLS(const Pem& pem) :
PublicKey_mbedTLS(pem)
{
if (!mbedtls_pk_can_do(ctx.get(), MBEDTLS_PK_RSA))
{
throw std::logic_error("invalid RSA key");
}
}
RSAPublicKey_mbedTLS::RSAPublicKey_mbedTLS(const std::vector<uint8_t>& der) :
PublicKey_mbedTLS(der)
{
if (!mbedtls_pk_can_do(ctx.get(), MBEDTLS_PK_RSA))
{
throw std::logic_error("invalid RSA key");
}
}
std::vector<uint8_t> RSAPublicKey_mbedTLS::wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label,
size_t label_size)
{
mbedtls_rsa_context* rsa_ctx = mbedtls_pk_rsa(*ctx.get());
mbedtls_rsa_set_padding(rsa_ctx, rsa_padding_mode, rsa_padding_digest_id);
std::vector<uint8_t> output_buf(rsa_ctx->len);
auto entropy = create_entropy();
// Note that the maximum input size to wrap is k - 2*hLen - 2
// where hLen is the hash size (32 bytes = SHA256) and
// k the wrapping key modulus size (e.g. 256 bytes = 2048 bits).
// In this example, it would be 190 bytes (1520 bits) max.
// This is enough for wrapping AES keys for example.
auto rc = mbedtls_rsa_rsaes_oaep_encrypt(
rsa_ctx,
entropy->get_rng(),
entropy->get_data(),
MBEDTLS_RSA_PUBLIC,
label,
label_size,
input_size,
input,
output_buf.data());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Error during RSA OEAP wrap: {}", error_string(rc)));
}
return output_buf;
}
std::vector<uint8_t> RSAPublicKey_mbedTLS::wrap(
const std::vector<uint8_t>& input, std::optional<std::string> label)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
return wrap(input.data(), input.size(), label_, label_size);
}
Pem RSAPublicKey_mbedTLS::public_key_pem() const
{
return PublicKey_mbedTLS::public_key_pem();
}
}

Просмотреть файл

@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/rsa_public_key.h"
#include "mbedtls_wrappers.h"
#include "public_key.h"
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
class RSAPublicKey_mbedTLS : public PublicKey_mbedTLS, public RSAPublicKey
{
public:
// Compatible with Azure HSM encryption schemes (see
// https://docs.microsoft.com/en-gb/azure/key-vault/keys/about-keys#wrapkeyunwrapkey-encryptdecrypt)
static constexpr auto rsa_padding_mode = MBEDTLS_RSA_PKCS_V21;
static constexpr auto rsa_padding_digest_id = MBEDTLS_MD_SHA256;
RSAPublicKey_mbedTLS() = default;
virtual ~RSAPublicKey_mbedTLS() = default;
RSAPublicKey_mbedTLS(crypto::mbedtls::PKContext&& c);
RSAPublicKey_mbedTLS(const Pem& pem);
RSAPublicKey_mbedTLS(const std::vector<uint8_t>& der);
virtual std::vector<uint8_t> wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label = nullptr,
size_t label_size = 0);
virtual std::vector<uint8_t> wrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt);
virtual Pem public_key_pem() const;
};
}

Просмотреть файл

@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "symmetric_key.h"
#include <mbedtls/aes.h>
#include <mbedtls/error.h>
namespace crypto
{
using namespace mbedtls;
KeyAesGcm_mbedTLS::KeyAesGcm_mbedTLS(CBuffer rawKey)
{
for (uint32_t i = 0; i < ctxs.size(); ++i)
{
ctxs[i] = mbedtls::make_unique<mbedtls::GcmContext>();
size_t n_bits;
const auto n = static_cast<unsigned int>(rawKey.rawSize() * 8);
if (n >= 256)
{
n_bits = 256;
}
else if (n >= 192)
{
n_bits = 192;
}
else if (n >= 128)
{
n_bits = 128;
}
else
{
throw std::logic_error(
fmt::format("Need at least {} bits, only have {}", 128, n));
}
int rc = mbedtls_gcm_setkey(
ctxs[i].get(), MBEDTLS_CIPHER_ID_AES, rawKey.p, n_bits);
if (rc != 0)
{
throw std::logic_error(error_string(rc));
}
}
}
void KeyAesGcm_mbedTLS::encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const
{
auto ctx = ctxs[threading::get_current_thread_id()].get();
int rc = mbedtls_gcm_crypt_and_tag(
ctx,
MBEDTLS_GCM_ENCRYPT,
plain.n,
iv.p,
iv.n,
aad.p,
aad.n,
plain.p,
cipher,
GCM_SIZE_TAG,
tag);
if (rc != 0)
{
throw std::logic_error(error_string(rc));
}
}
bool KeyAesGcm_mbedTLS::decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const
{
auto ctx = ctxs[threading::get_current_thread_id()].get();
return !mbedtls_gcm_auth_decrypt(
ctx,
cipher.n,
iv.p,
iv.n,
aad.p,
aad.n,
tag,
GCM_SIZE_TAG,
cipher.p,
plain);
}
}

Просмотреть файл

@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/symmetric_key.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
class KeyAesGcm_mbedTLS : public KeyAesGcm
{
private:
mutable std::
array<mbedtls::GcmContext, threading::ThreadMessaging::max_num_threads>
ctxs;
public:
KeyAesGcm_mbedTLS(CBuffer rawKey);
KeyAesGcm_mbedTLS(const KeyAesGcm_mbedTLS& that) = delete;
KeyAesGcm_mbedTLS(KeyAesGcm_mbedTLS&& that);
virtual ~KeyAesGcm_mbedTLS() = default;
virtual void encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const override;
virtual bool decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const override;
};
}

Просмотреть файл

@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "verifier.h"
#include "public_key.h"
#include "rsa_key_pair.h"
#include <mbedtls/pem.h>
namespace crypto
{
using namespace mbedtls;
static constexpr size_t max_pem_cert_size = 4096;
// As these are not exposed by mbedTLS, define them here to allow simple
// conversion from DER to PEM format
static constexpr auto PEM_CERTIFICATE_HEADER =
"-----BEGIN CERTIFICATE-----\n";
static constexpr auto PEM_CERTIFICATE_FOOTER = "-----END CERTIFICATE-----\n";
MDType Verifier_mbedTLS::get_md_type(mbedtls_md_type_t mdt) const
{
switch (mdt)
{
case MBEDTLS_MD_NONE:
return MDType::NONE;
case MBEDTLS_MD_SHA1:
return MDType::SHA1;
case MBEDTLS_MD_SHA256:
return MDType::SHA256;
case MBEDTLS_MD_SHA384:
return MDType::SHA384;
case MBEDTLS_MD_SHA512:
return MDType::SHA512;
default:
return MDType::NONE;
}
return MDType::NONE;
}
Verifier_mbedTLS::Verifier_mbedTLS(const std::vector<uint8_t>& c) : Verifier()
{
cert = mbedtls::make_unique<mbedtls::X509Crt>();
int rc = mbedtls_x509_crt_parse(cert.get(), c.data(), c.size());
if (rc)
{
throw std::invalid_argument(
fmt::format("Failed to parse certificate: {}", error_string(rc)));
}
md_type = get_md_type(cert->sig_md);
// public_key expects to have unique ownership of the context and so does
// `cert`, so we duplicate the key context here.
unsigned char buf[2048];
rc = mbedtls_pk_write_pubkey_pem(&cert->pk, buf, sizeof(buf));
if (rc != 0)
{
throw std::runtime_error(
fmt::format("PEM export failed: {}", error_string(rc)));
}
Pem pem(buf, sizeof(buf));
if (mbedtls_pk_can_do(&cert->pk, MBEDTLS_PK_ECKEY))
{
public_key = std::make_unique<PublicKey_mbedTLS>(pem);
}
else if (mbedtls_pk_can_do(&cert->pk, MBEDTLS_PK_RSA))
{
public_key = std::make_unique<RSAPublicKey_mbedTLS>(pem);
}
else
{
throw std::logic_error("unsupported public key type");
}
}
std::vector<uint8_t> Verifier_mbedTLS::cert_der()
{
return {cert->raw.p, cert->raw.p + cert->raw.len};
}
Pem Verifier_mbedTLS::cert_pem()
{
unsigned char buf[max_pem_cert_size];
size_t len;
auto rc = mbedtls_pem_write_buffer(
PEM_CERTIFICATE_HEADER,
PEM_CERTIFICATE_FOOTER,
cert->raw.p,
cert->raw.len,
buf,
max_pem_cert_size,
&len);
if (rc != 0)
{
throw std::logic_error(
"mbedtls_pem_write_buffer failed: " + error_string(rc));
}
return Pem(buf, len);
}
}

Просмотреть файл

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/verifier.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
class Verifier_mbedTLS : public Verifier
{
protected:
mutable mbedtls::X509Crt cert;
MDType get_md_type(mbedtls_md_type_t mdt) const;
public:
Verifier_mbedTLS(const std::vector<uint8_t>& c);
Verifier_mbedTLS(const Verifier_mbedTLS&) = delete;
virtual ~Verifier_mbedTLS() = default;
virtual std::vector<uint8_t> cert_der() override;
virtual Pem cert_pem() override;
};
}

Просмотреть файл

@ -1,89 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include <mbedtls/ctr_drbg.h>
#include <mbedtls/entropy.h>
#include <mbedtls/gcm.h>
#include <mbedtls/net_sockets.h>
#include <mbedtls/sha256.h>
#include <mbedtls/ssl.h>
#include <mbedtls/x509.h>
#include <mbedtls/x509_crt.h>
#include <mbedtls/x509_csr.h>
#include <memory>
namespace mbedtls
{
template <typename T>
T make_unique();
#define DEFINE_MBEDTLS_WRAPPER( \
NEW_TYPE, MBED_TYPE, MBED_INIT_FN, MBED_FREE_FN) \
struct NEW_TYPE##Deleter \
{ \
void operator()(MBED_TYPE* ptr) \
{ \
MBED_FREE_FN(ptr); \
delete ptr; \
} \
}; \
using NEW_TYPE = std::unique_ptr<MBED_TYPE, NEW_TYPE##Deleter>; \
template <> \
inline NEW_TYPE make_unique<NEW_TYPE>() \
{ \
auto p = new MBED_TYPE; \
MBED_INIT_FN(p); \
return NEW_TYPE(p); \
}
DEFINE_MBEDTLS_WRAPPER(
CtrDrbg,
mbedtls_ctr_drbg_context,
mbedtls_ctr_drbg_init,
mbedtls_ctr_drbg_free);
DEFINE_MBEDTLS_WRAPPER(
ECDHContext, mbedtls_ecdh_context, mbedtls_ecdh_init, mbedtls_ecdh_free);
DEFINE_MBEDTLS_WRAPPER(
Entropy,
mbedtls_entropy_context,
mbedtls_entropy_init,
mbedtls_entropy_free);
DEFINE_MBEDTLS_WRAPPER(
GcmContext, mbedtls_gcm_context, mbedtls_gcm_init, mbedtls_gcm_free);
DEFINE_MBEDTLS_WRAPPER(MPI, mbedtls_mpi, mbedtls_mpi_init, mbedtls_mpi_free);
DEFINE_MBEDTLS_WRAPPER(
NetContext, mbedtls_net_context, mbedtls_net_init, mbedtls_net_free);
DEFINE_MBEDTLS_WRAPPER(
PKContext, mbedtls_pk_context, mbedtls_pk_init, mbedtls_pk_free);
DEFINE_MBEDTLS_WRAPPER(
SSLConfig,
mbedtls_ssl_config,
mbedtls_ssl_config_init,
mbedtls_ssl_config_free);
DEFINE_MBEDTLS_WRAPPER(
SSLContext, mbedtls_ssl_context, mbedtls_ssl_init, mbedtls_ssl_free);
DEFINE_MBEDTLS_WRAPPER(
X509Crl, mbedtls_x509_crl, mbedtls_x509_crl_init, mbedtls_x509_crl_free);
DEFINE_MBEDTLS_WRAPPER(
X509Crt, mbedtls_x509_crt, mbedtls_x509_crt_init, mbedtls_x509_crt_free);
DEFINE_MBEDTLS_WRAPPER(
X509Csr, mbedtls_x509_csr, mbedtls_x509_csr_init, mbedtls_x509_csr_free);
DEFINE_MBEDTLS_WRAPPER(
X509WriteCrt,
mbedtls_x509write_cert,
mbedtls_x509write_crt_init,
mbedtls_x509write_crt_free);
DEFINE_MBEDTLS_WRAPPER(
X509WriteCsr,
mbedtls_x509write_csr,
mbedtls_x509write_csr_init,
mbedtls_x509write_csr_free);
DEFINE_MBEDTLS_WRAPPER(
SHA256Ctx,
mbedtls_sha256_context,
mbedtls_sha256_init,
mbedtls_sha256_free);
#undef DEFINE_MBEDTLS_WRAPPER
}

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "hash.h"
#include <openssl/sha.h>
#include <stdexcept>
namespace crypto
{
using namespace OpenSSL;
void openssl_sha256(const CBuffer& data, uint8_t* h)
{
SHA256_CTX ctx;
SHA256_Init(&ctx);
SHA256_Update(&ctx, data.p, data.rawSize());
SHA256_Final(h, &ctx);
}
ISha256OpenSSL::ISha256OpenSSL()
{
ctx = new SHA256_CTX;
SHA256_Init((SHA256_CTX*)ctx);
}
ISha256OpenSSL::~ISha256OpenSSL()
{
delete (SHA256_CTX*)ctx;
}
void ISha256OpenSSL::update_hash(CBuffer data)
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
SHA256_Update((SHA256_CTX*)ctx, data.p, data.rawSize());
}
Sha256Hash ISha256OpenSSL::finalise()
{
if (!ctx)
{
throw std::logic_error("Attempting to use hash after it was finalised");
}
Sha256Hash r;
SHA256_Final(r.h.data(), (SHA256_CTX*)ctx);
delete (SHA256_CTX*)ctx;
ctx = nullptr;
return r;
}
}

68
src/crypto/openssl/hash.h Normal file
Просмотреть файл

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/hash_base.h"
#include <openssl/evp.h>
#include <openssl/sha.h>
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <msgpack/msgpack.hpp>
namespace crypto
{
namespace OpenSSL
{
inline const EVP_MD* get_md_type(MDType type)
{
switch (type)
{
case MDType::NONE:
return nullptr;
case MDType::SHA1:
return EVP_sha1();
case MDType::SHA256:
return EVP_sha256();
case MDType::SHA384:
return EVP_sha384();
case MDType::SHA512:
return EVP_sha512();
default:
throw std::runtime_error("Unsupported hash algorithm");
}
return nullptr;
}
}
class OpenSSLHashProvider : public HashProviderBase
{
public:
virtual HashBytes Hash(const uint8_t* data, size_t size, MDType type) const
{
auto o_md_type = OpenSSL::get_md_type(type);
HashBytes r(EVP_MD_size(o_md_type));
unsigned int len = 0;
if (EVP_Digest(data, size, r.data(), &len, o_md_type, NULL) != 1)
throw std::runtime_error("OpenSSL hash update error");
return r;
}
};
class ISha256OpenSSL : public ISha256HashBase
{
public:
ISha256OpenSSL();
~ISha256OpenSSL();
virtual void update_hash(CBuffer data);
virtual Sha256Hash finalise();
protected:
void* ctx;
};
void openssl_sha256(const CBuffer& data, uint8_t* h);
}

Просмотреть файл

@ -0,0 +1,320 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "key_pair.h"
#include "crypto/curve.h"
#include "hash.h"
#include "openssl_wrappers.h"
#include <openssl/ec.h>
#include <openssl/engine.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/x509v3.h>
#include <stdexcept>
#include <string>
namespace crypto
{
using namespace OpenSSL;
static inline int get_openssl_group_id(CurveID gid)
{
switch (gid)
{
case CurveID::NONE:
return NID_undef;
case CurveID::SECP384R1:
return NID_secp384r1;
case CurveID::SECP256R1:
return NID_X9_62_prime256v1;
default:
throw std::logic_error(
fmt::format("unsupported OpenSSL CurveID {}", gid));
}
return MBEDTLS_ECP_DP_NONE;
}
static std::vector<std::pair<std::string, std::string>> parse_name(
const std::string& name)
{
std::vector<std::pair<std::string, std::string>> r;
char* name_cpy = strdup(name.c_str());
char* p = std::strtok(name_cpy, ",");
while (p)
{
char* eq = strchr(p, '=');
*eq = '\0';
r.push_back(std::make_pair(p, eq + 1));
p = std::strtok(NULL, ",");
}
free(name_cpy);
return r;
}
KeyPair_OpenSSL::KeyPair_OpenSSL(CurveID curve_id)
{
int curve_nid = get_openssl_group_id(curve_id);
key = EVP_PKEY_new();
Unique_EVP_PKEY_CTX pkctx;
if (
EVP_PKEY_paramgen_init(pkctx) < 0 ||
EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pkctx, curve_nid) < 0 ||
EVP_PKEY_CTX_set_ec_param_enc(pkctx, OPENSSL_EC_NAMED_CURVE) < 0)
throw std::runtime_error("could not initialize PK context");
if (EVP_PKEY_keygen_init(pkctx) < 0 || EVP_PKEY_keygen(pkctx, &key) < 0)
throw std::runtime_error("could not generate new EC key");
}
KeyPair_OpenSSL::KeyPair_OpenSSL(const Pem& pem, CBuffer pw)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PrivateKey(mem, NULL, NULL, (void*)pw.p);
if (!key)
throw std::runtime_error("could not parse PEM");
}
Pem KeyPair_OpenSSL::private_key_pem() const
{
Unique_BIO buf;
OpenSSL::CHECK1(
PEM_write_bio_PrivateKey(buf, key, NULL, NULL, 0, NULL, NULL));
BUF_MEM* bptr;
BIO_get_mem_ptr(buf, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
Pem KeyPair_OpenSSL::public_key_pem() const
{
return PublicKey_OpenSSL::public_key_pem();
}
bool KeyPair_OpenSSL::verify(
const std::vector<uint8_t>& contents, const std::vector<uint8_t>& signature)
{
return PublicKey_OpenSSL::verify(contents, signature);
}
std::vector<uint8_t> KeyPair_OpenSSL::sign(CBuffer d, MDType md_type) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size());
}
int KeyPair_OpenSSL::sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type) const
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
HashBytes hash = hp.Hash(d.p, d.rawSize(), md_type);
return sign_hash(hash.data(), hash.size(), sig_size, sig);
}
std::vector<uint8_t> KeyPair_OpenSSL::sign_hash(
const uint8_t* hash, size_t hash_size) const
{
std::vector<uint8_t> sig(EVP_PKEY_size(key));
size_t written = sig.size();
if (sign_hash(hash, hash_size, &written, sig.data()) != 0)
{
return {};
}
sig.resize(written);
return sig;
}
int KeyPair_OpenSSL::sign_hash(
const uint8_t* hash, size_t hash_size, size_t* sig_size, uint8_t* sig) const
{
Unique_EVP_PKEY_CTX pctx(key);
OpenSSL::CHECK1(EVP_PKEY_sign_init(pctx));
OpenSSL::CHECK1(EVP_PKEY_sign(pctx, sig, sig_size, hash, hash_size));
return 0;
}
Pem KeyPair_OpenSSL::create_csr(const std::string& name) const
{
Unique_X509_REQ req;
OpenSSL::CHECK1(X509_REQ_set_pubkey(req, key));
X509_NAME* subj_name = NULL;
OpenSSL::CHECKNULL(subj_name = X509_NAME_new());
for (auto kv : parse_name(name))
{
OpenSSL::CHECK1(X509_NAME_add_entry_by_txt(
subj_name,
kv.first.c_str(),
MBSTRING_ASC,
(const unsigned char*)kv.second.c_str(),
-1,
-1,
0));
}
OpenSSL::CHECK1(X509_REQ_set_subject_name(req, subj_name));
X509_NAME_free(subj_name);
if (key)
OpenSSL::CHECK1(X509_REQ_sign(req, key, EVP_sha512()));
Unique_BIO mem;
OpenSSL::CHECK1(PEM_write_bio_X509_REQ(mem, req));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
Pem result((uint8_t*)bptr->data, bptr->length);
return result;
}
Pem KeyPair_OpenSSL::sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca) const
{
X509* icrt = NULL;
Unique_BIO mem(signing_request.data(), -1);
Unique_X509_REQ csr(mem);
Unique_X509 crt;
OpenSSL::CHECK1(X509_set_version(crt, 2));
// Add serial number
unsigned char rndbytes[16];
OpenSSL::CHECK1(RAND_bytes(rndbytes, sizeof(rndbytes)));
BIGNUM* bn = NULL;
OpenSSL::CHECKNULL(bn = BN_new());
BN_bin2bn(rndbytes, sizeof(rndbytes), bn);
ASN1_INTEGER* serial = ASN1_INTEGER_new();
BN_to_ASN1_INTEGER(bn, serial);
OpenSSL::CHECK1(X509_set_serialNumber(crt, serial));
ASN1_INTEGER_free(serial);
BN_free(bn);
// Add issuer name
if (!issuer_cert.empty())
{
Unique_BIO imem(issuer_cert.data(), -1);
OpenSSL::CHECKNULL(icrt = PEM_read_bio_X509(imem, NULL, NULL, NULL));
OpenSSL::CHECK1(X509_set_issuer_name(crt, X509_get_subject_name(icrt)));
}
else
{
OpenSSL::CHECK1(
X509_set_issuer_name(crt, X509_REQ_get_subject_name(csr)));
}
// Note: 825-day validity range
// https://support.apple.com/en-us/HT210176
ASN1_TIME *before = NULL, *after = NULL;
OpenSSL::CHECKNULL(before = ASN1_TIME_new());
OpenSSL::CHECKNULL(after = ASN1_TIME_new());
OpenSSL::CHECK1(ASN1_TIME_set_string(before, "20191101000000Z"));
OpenSSL::CHECK1(ASN1_TIME_set_string(after, "20211231235959Z"));
X509_set1_notBefore(crt, before);
X509_set1_notAfter(crt, after);
ASN1_TIME_free(before);
ASN1_TIME_free(after);
X509_set_subject_name(crt, X509_REQ_get_subject_name(csr));
EVP_PKEY* req_pubkey = X509_REQ_get_pubkey(csr);
X509_set_pubkey(crt, req_pubkey);
EVP_PKEY_free(req_pubkey);
// Extensions
X509V3_CTX v3ctx;
X509V3_set_ctx_nodb(&v3ctx);
X509V3_set_ctx(&v3ctx, icrt ? icrt : crt, NULL, csr, NULL, 0);
// Add basic constraints
X509_EXTENSION* ext = NULL;
OpenSSL::CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_basic_constraints, ca ? "CA:TRUE" : "CA:FALSE"));
OpenSSL::CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Add subject key identifier
OpenSSL::CHECKNULL(
ext =
X509V3_EXT_conf_nid(NULL, &v3ctx, NID_subject_key_identifier, "hash"));
OpenSSL::CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Add authority key identifier
OpenSSL::CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_authority_key_identifier, "keyid:always"));
OpenSSL::CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
// Subject alternative names (Necessary? Shouldn't they be in the CSR?)
if (!subject_alt_names.empty())
{
std::string all_alt_names;
bool first = true;
for (auto san : subject_alt_names)
{
if (first)
{
first = !first;
}
else
{
all_alt_names += ", ";
}
if (san.is_ip)
all_alt_names += "IP:";
else
all_alt_names += "DNS:";
all_alt_names += san.san;
}
OpenSSL::CHECKNULL(
ext = X509V3_EXT_conf_nid(
NULL, &v3ctx, NID_subject_alt_name, all_alt_names.c_str()));
OpenSSL::CHECK1(X509_add_ext(crt, ext, -1));
X509_EXTENSION_free(ext);
}
// Sign
auto md = get_md_type(get_md_for_ec(get_curve_id()));
int size = X509_sign(crt, key, md);
if (size <= 0)
throw std::runtime_error("could not sign CRT");
Unique_BIO omem;
OpenSSL::CHECK1(PEM_write_bio_X509(omem, crt));
// Export
BUF_MEM* bptr;
BIO_get_mem_ptr(omem, &bptr);
Pem result((uint8_t*)bptr->data, bptr->length);
if (icrt)
X509_free(icrt);
return result;
}
}

Просмотреть файл

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "../key_pair.h"
#include "openssl_wrappers.h"
#include "public_key.h"
#include <stdexcept>
#include <string>
namespace crypto
{
class KeyPair_OpenSSL : public PublicKey_OpenSSL, public KeyPair
{
public:
KeyPair_OpenSSL(CurveID curve_id);
KeyPair_OpenSSL(const KeyPair_OpenSSL&) = delete;
KeyPair_OpenSSL(const Pem& pem, CBuffer pw = nullb);
virtual ~KeyPair_OpenSSL() = default;
virtual Pem private_key_pem() const override;
virtual Pem public_key_pem() const override;
using PublicKey_OpenSSL::verify;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) override;
virtual std::vector<uint8_t> sign(
CBuffer d, MDType md_type = {}) const override;
int sign(
CBuffer d, size_t* sig_size, uint8_t* sig, MDType md_type = {}) const;
std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const override;
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const override;
virtual Pem create_csr(const std::string& name) const override;
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const override;
};
}

Просмотреть файл

@ -0,0 +1,147 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#define FMT_HEADER_ONLY
#include <fmt/format.h>
#include <memory>
#include <openssl/ec.h>
#include <openssl/engine.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/x509v3.h>
namespace crypto
{
namespace OpenSSL
{
inline void CHECK1(int rc)
{
unsigned long ec = ERR_get_error();
if (rc != 1 && ec != 0)
{
throw std::runtime_error(
fmt::format("OpenSSL error: {}", ERR_error_string(ec, NULL)));
}
}
inline void CHECKNULL(void* ptr)
{
if (ptr == NULL)
{
throw std::runtime_error("OpenSSL error: missing object");
}
}
class Unique_BIO
{
std::unique_ptr<BIO, void (*)(BIO*)> p;
public:
Unique_BIO() : p(BIO_new(BIO_s_mem()), [](auto x) { BIO_free(x); })
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_BIO(const void* buf, int len) :
p(BIO_new_mem_buf(buf, len), [](auto x) { BIO_free(x); })
{
if (!p)
throw std::runtime_error("out of memory");
}
operator BIO*()
{
return p.get();
}
};
class Unique_EVP_PKEY_CTX
{
std::unique_ptr<EVP_PKEY_CTX, void (*)(EVP_PKEY_CTX*)> p;
public:
Unique_EVP_PKEY_CTX(EVP_PKEY* key) :
p(EVP_PKEY_CTX_new(key, NULL), EVP_PKEY_CTX_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_EVP_PKEY_CTX() :
p(EVP_PKEY_CTX_new_id(EVP_PKEY_EC, NULL), EVP_PKEY_CTX_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator EVP_PKEY_CTX*()
{
return p.get();
}
};
class Unique_X509_REQ
{
std::unique_ptr<X509_REQ, void (*)(X509_REQ*)> p;
public:
Unique_X509_REQ() : p(X509_REQ_new(), X509_REQ_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_X509_REQ(BIO* mem) :
p(PEM_read_bio_X509_REQ(mem, NULL, NULL, NULL), X509_REQ_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator X509_REQ*()
{
return p.get();
}
};
class Unique_X509
{
std::unique_ptr<X509, void (*)(X509*)> p;
public:
Unique_X509() : p(X509_new(), X509_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
Unique_X509(BIO* mem) :
p(PEM_read_bio_X509(mem, NULL, NULL, NULL), X509_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator X509*()
{
return p.get();
}
};
class Unique_EVP_CIPHER_CTX
{
std::unique_ptr<EVP_CIPHER_CTX, void (*)(EVP_CIPHER_CTX*)> p;
public:
Unique_EVP_CIPHER_CTX() : p(EVP_CIPHER_CTX_new(), EVP_CIPHER_CTX_free)
{
if (!p)
throw std::runtime_error("out of memory");
}
operator EVP_CIPHER_CTX*()
{
return p.get();
}
};
inline std::string error_string(int ec)
{
return ERR_error_string((unsigned long)ec, NULL);
}
}
}

Просмотреть файл

@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "public_key.h"
#include "openssl_wrappers.h"
#include <openssl/ec.h>
#include <openssl/engine.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/x509v3.h>
#include <stdexcept>
#include <string>
namespace crypto
{
using namespace OpenSSL;
PublicKey_OpenSSL::PublicKey_OpenSSL() {}
PublicKey_OpenSSL::PublicKey_OpenSSL(const Pem& pem)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PUBKEY(mem, NULL, NULL, NULL);
if (!key)
throw std::runtime_error("could not parse PEM");
}
PublicKey_OpenSSL::PublicKey_OpenSSL(const std::vector<uint8_t>& der)
{
const unsigned char* pp = der.data();
key = d2i_PublicKey(EVP_PKEY_EC, &key, &pp, der.size());
if (!key)
{
throw new std::runtime_error("Could not read DER");
}
}
PublicKey_OpenSSL::PublicKey_OpenSSL(EVP_PKEY* key) : key(key) {}
PublicKey_OpenSSL::~PublicKey_OpenSSL()
{
if (key)
EVP_PKEY_free(key);
}
CurveID PublicKey_OpenSSL::get_curve_id() const
{
int nid =
EC_GROUP_get_curve_name(EC_KEY_get0_group(EVP_PKEY_get0_EC_KEY(key)));
switch (nid)
{
case NID_secp384r1:
return CurveID::SECP384R1;
case NID_X9_62_prime256v1:
return CurveID::SECP256R1;
default:
throw std::runtime_error(fmt::format("Unknown OpenSSL curve {}", nid));
}
return CurveID::NONE;
}
bool PublicKey_OpenSSL::verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes)
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
OpenSSLHashProvider hp;
bytes = hp.Hash(contents, contents_size, md_type);
return verify_hash(bytes.data(), bytes.size(), sig, sig_size, md_type);
}
bool PublicKey_OpenSSL::verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type)
{
if (md_type == MDType::NONE)
{
md_type = get_md_for_ec(get_curve_id());
}
Unique_EVP_PKEY_CTX pctx(key);
OpenSSL::CHECK1(EVP_PKEY_verify_init(pctx));
if (md_type != MDType::NONE)
{
OpenSSL::CHECK1(
EVP_PKEY_CTX_set_signature_md(pctx, get_md_type(md_type)));
}
int rc = EVP_PKEY_verify(pctx, sig, sig_size, hash, hash_size);
bool ok = rc == 1;
if (!ok)
{
int ec = ERR_get_error();
LOG_DEBUG_FMT(
"OpenSSL signature verification failure: {}",
ERR_error_string(ec, NULL));
}
return ok;
}
Pem PublicKey_OpenSSL::public_key_pem() const
{
Unique_BIO buf;
OpenSSL::CHECK1(PEM_write_bio_PUBKEY(buf, key));
BUF_MEM* bptr;
BIO_get_mem_ptr(buf, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
}

Просмотреть файл

@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "../public_key.h"
#include <openssl/err.h>
#include <openssl/evp.h>
#include <stdexcept>
#include <string>
namespace crypto
{
class PublicKey_OpenSSL : public PublicKey
{
protected:
EVP_PKEY* key = nullptr;
PublicKey_OpenSSL();
CurveID get_curve_id() const;
public:
PublicKey_OpenSSL(PublicKey_OpenSSL&& key) = default;
PublicKey_OpenSSL(EVP_PKEY* key);
PublicKey_OpenSSL(const Pem& pem);
PublicKey_OpenSSL(const std::vector<uint8_t>& der);
virtual ~PublicKey_OpenSSL();
using PublicKey::verify;
using PublicKey::verify_hash;
virtual bool verify(
const uint8_t* contents,
size_t contents_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type,
HashBytes& bytes) override;
virtual bool verify_hash(
const uint8_t* hash,
size_t hash_size,
const uint8_t* sig,
size_t sig_size,
MDType md_type) override;
virtual Pem public_key_pem() const override;
};
}

Просмотреть файл

@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "rsa_key_pair.h"
#include "openssl_wrappers.h"
namespace crypto
{
using namespace OpenSSL;
RSAKeyPair_OpenSSL::RSAKeyPair_OpenSSL(
size_t public_key_size, size_t public_exponent)
{
RSA* rsa = NULL;
BIGNUM* big_exp = NULL;
OpenSSL::CHECKNULL(big_exp = BN_new());
OpenSSL::CHECK1(BN_set_word(big_exp, public_exponent));
OpenSSL::CHECKNULL(rsa = RSA_new());
OpenSSL::CHECK1(RSA_generate_key_ex(rsa, public_key_size, big_exp, NULL));
OpenSSL::CHECKNULL(key = EVP_PKEY_new());
OpenSSL::CHECK1(EVP_PKEY_set1_RSA(key, rsa));
BN_free(big_exp);
RSA_free(rsa);
}
RSAKeyPair_OpenSSL::RSAKeyPair_OpenSSL(EVP_PKEY* k) :
RSAPublicKey_OpenSSL(std::move(k))
{}
RSAKeyPair_OpenSSL::RSAKeyPair_OpenSSL(const Pem& pem, CBuffer pw)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PrivateKey(mem, NULL, NULL, (void*)pw.p);
if (!key)
{
throw std::runtime_error("could not parse PEM");
}
}
std::vector<uint8_t> RSAKeyPair_OpenSSL::unwrap(
const std::vector<uint8_t>& input, std::optional<std::string> label)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
Unique_EVP_PKEY_CTX ctx(key);
OpenSSL::CHECK1(EVP_PKEY_decrypt_init(ctx));
EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING);
EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256());
EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, EVP_sha256());
if (label_)
{
unsigned char* openssl_label = (unsigned char*)OPENSSL_malloc(label_size);
std::copy(label_, label_ + label_size, openssl_label);
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, openssl_label, label_size);
}
else
{
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
}
size_t olen;
OpenSSL::CHECK1(
EVP_PKEY_decrypt(ctx, NULL, &olen, input.data(), input.size()));
std::vector<uint8_t> output(olen);
OpenSSL::CHECK1(
EVP_PKEY_decrypt(ctx, output.data(), &olen, input.data(), input.size()));
output.resize(olen);
return output;
}
Pem RSAKeyPair_OpenSSL::public_key_pem() const
{
return PublicKey_OpenSSL::public_key_pem();
}
}

Просмотреть файл

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/rsa_key_pair.h"
#include "rsa_public_key.h"
#include <optional>
#include <vector>
namespace crypto
{
class RSAKeyPair_OpenSSL : public RSAPublicKey_OpenSSL, public RSAKeyPair
{
public:
RSAKeyPair_OpenSSL(
size_t public_key_size = default_public_key_size,
size_t public_exponent = default_public_exponent);
RSAKeyPair_OpenSSL(EVP_PKEY* k);
RSAKeyPair_OpenSSL(const RSAKeyPair&) = delete;
RSAKeyPair_OpenSSL(const Pem& pem, CBuffer pw = nullb);
virtual ~RSAKeyPair_OpenSSL() = default;
virtual std::vector<uint8_t> unwrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt);
virtual Pem public_key_pem() const;
};
}

Просмотреть файл

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "openssl_wrappers.h"
#include "rsa_key_pair.h"
namespace crypto
{
using namespace OpenSSL;
RSAPublicKey_OpenSSL::RSAPublicKey_OpenSSL(EVP_PKEY* c) : PublicKey_OpenSSL(c)
{
if (!EVP_PKEY_get0_RSA(key))
{
throw std::logic_error("invalid RSA key");
}
}
RSAPublicKey_OpenSSL::RSAPublicKey_OpenSSL(const Pem& pem)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PUBKEY(mem, NULL, NULL, NULL);
if (!key || !EVP_PKEY_get0_RSA(key))
{
throw std::logic_error("invalid RSA key");
}
}
RSAPublicKey_OpenSSL::RSAPublicKey_OpenSSL(const std::vector<uint8_t>& der)
{
const unsigned char* pp = der.data();
RSA* rsa = NULL;
if (
((rsa = d2i_RSA_PUBKEY(NULL, &pp, der.size())) ==
NULL) && // "SubjectPublicKeyInfo structure" format
((rsa = d2i_RSAPublicKey(NULL, &pp, der.size())) ==
NULL)) // PKCS#1 structure format
{
unsigned long ec = ERR_get_error();
const char* msg = ERR_error_string(ec, NULL);
throw new std::runtime_error(fmt::format("OpenSSL error: {}", msg));
}
key = EVP_PKEY_new();
OpenSSL::CHECK1(EVP_PKEY_set1_RSA(key, rsa));
RSA_free(rsa);
}
std::vector<uint8_t> RSAPublicKey_OpenSSL::wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label,
size_t label_size)
{
Unique_EVP_PKEY_CTX ctx(key);
OpenSSL::CHECK1(EVP_PKEY_encrypt_init(ctx));
EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING);
EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256());
EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, EVP_sha256());
if (label)
{
unsigned char* openssl_label = (unsigned char*)OPENSSL_malloc(label_size);
std::copy(label, label + label_size, openssl_label);
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, openssl_label, label_size);
}
else
{
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
}
size_t olen;
OpenSSL::CHECK1(EVP_PKEY_encrypt(ctx, NULL, &olen, input, input_size));
std::vector<uint8_t> output(olen);
OpenSSL::CHECK1(
EVP_PKEY_encrypt(ctx, output.data(), &olen, input, input_size));
output.resize(olen);
return output;
}
std::vector<uint8_t> RSAPublicKey_OpenSSL::wrap(
const std::vector<uint8_t>& input, std::optional<std::string> label)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
return wrap(input.data(), input.size(), label_, label_size);
}
Pem RSAPublicKey_OpenSSL::public_key_pem() const
{
return PublicKey_OpenSSL::public_key_pem();
}
}

Просмотреть файл

@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/rsa_public_key.h"
#include "key_pair.h"
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
class RSAPublicKey_OpenSSL : public PublicKey_OpenSSL, public RSAPublicKey
{
public:
RSAPublicKey_OpenSSL() = default;
RSAPublicKey_OpenSSL(EVP_PKEY* c);
RSAPublicKey_OpenSSL(const Pem& pem);
RSAPublicKey_OpenSSL(const std::vector<uint8_t>& der);
virtual ~RSAPublicKey_OpenSSL() = default;
virtual std::vector<uint8_t> wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label = nullptr,
size_t label_size = 0);
virtual std::vector<uint8_t> wrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt);
virtual Pem public_key_pem() const;
};
}

Просмотреть файл

@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "symmetric_key.h"
#include "../mbedtls/symmetric_key.h"
#include "crypto/openssl/openssl_wrappers.h"
#include "crypto/symmetric_key.h"
#include "ds/logger.h"
#include "ds/thread_messaging.h"
#include <openssl/aes.h>
#include <openssl/evp.h>
namespace crypto
{
using namespace OpenSSL;
KeyAesGcm_OpenSSL::KeyAesGcm_OpenSSL(CBuffer rawKey) :
key(std::vector<uint8_t>(rawKey.p, rawKey.p + rawKey.n)),
evp_cipher(nullptr)
{
const auto n = static_cast<unsigned int>(rawKey.rawSize() * 8);
if (n >= 256)
{
evp_cipher = EVP_aes_256_gcm();
}
else if (n >= 192)
{
evp_cipher = EVP_aes_192_gcm();
}
else if (n >= 128)
{
evp_cipher = EVP_aes_128_gcm();
}
else
{
throw std::logic_error(
fmt::format("Need at least {} bits, only have {}", 128, n));
}
}
void KeyAesGcm_OpenSSL::encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const
{
std::vector<uint8_t> cb(plain.n + GCM_SIZE_TAG);
int len = 0;
Unique_EVP_CIPHER_CTX ctx;
CHECK1(EVP_EncryptInit_ex(ctx, evp_cipher, NULL, key.data(), iv.p));
CHECK1(EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv.n, NULL));
CHECK1(EVP_EncryptInit_ex(ctx, NULL, NULL, key.data(), iv.p));
if (aad.n > 0)
CHECK1(EVP_EncryptUpdate(ctx, NULL, &len, aad.p, aad.n));
CHECK1(EVP_EncryptUpdate(ctx, cb.data(), &len, plain.p, plain.n));
CHECK1(EVP_EncryptFinal_ex(ctx, cb.data() + len, &len));
CHECK1(
EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_SIZE_TAG, &tag[0]));
if (plain.n > 0)
memcpy(cipher, cb.data(), plain.n);
}
bool KeyAesGcm_OpenSSL::decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const
{
std::vector<uint8_t> pb(cipher.n + GCM_SIZE_TAG);
int len = 0;
Unique_EVP_CIPHER_CTX ctx;
CHECK1(EVP_DecryptInit_ex(ctx, evp_cipher, NULL, NULL, NULL));
CHECK1(EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv.n, NULL));
CHECK1(EVP_DecryptInit_ex(ctx, NULL, NULL, key.data(), iv.p));
if (aad.n > 0)
CHECK1(EVP_DecryptUpdate(ctx, NULL, &len, aad.p, aad.n));
CHECK1(EVP_DecryptUpdate(ctx, pb.data(), &len, cipher.p, cipher.n));
CHECK1(EVP_CIPHER_CTX_ctrl(
ctx, EVP_CTRL_GCM_SET_TAG, GCM_SIZE_TAG, (uint8_t*)tag));
int r = EVP_DecryptFinal_ex(ctx, pb.data() + len, &len) > 0;
if (r == 1 && cipher.n > 0)
memcpy(plain, pb.data(), cipher.n);
return r == 1;
}
}

Просмотреть файл

@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/symmetric_key.h"
#include "openssl_wrappers.h"
namespace crypto
{
class KeyAesGcm_OpenSSL : public KeyAesGcm
{
private:
const std::vector<uint8_t> key;
const EVP_CIPHER* evp_cipher;
public:
KeyAesGcm_OpenSSL(CBuffer rawKey);
KeyAesGcm_OpenSSL(const KeyAesGcm_OpenSSL& that) = delete;
KeyAesGcm_OpenSSL(KeyAesGcm_OpenSSL&& that);
virtual ~KeyAesGcm_OpenSSL() = default;
virtual void encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const override;
virtual bool decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const override;
};
}

Просмотреть файл

@ -0,0 +1,94 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "verifier.h"
#include "public_key.h"
#include "rsa_key_pair.h"
#include <openssl/evp.h>
#include <openssl/x509.h>
namespace crypto
{
using namespace OpenSSL;
MDType Verifier_OpenSSL::get_md_type(int mdt) const
{
switch (mdt)
{
case NID_undef:
return MDType::NONE;
case NID_sha1:
return MDType::SHA1;
case NID_sha256:
return MDType::SHA256;
case NID_sha384:
return MDType::SHA384;
case NID_sha512:
return MDType::SHA512;
default:
return MDType::NONE;
}
return MDType::NONE;
}
Verifier_OpenSSL::Verifier_OpenSSL(const std::vector<uint8_t>& c) : Verifier()
{
Unique_BIO certbio(c.data(), c.size());
if (!(cert = PEM_read_bio_X509(certbio, NULL, 0, NULL)))
{
BIO_reset(certbio);
if (!(cert = d2i_X509_bio(certbio, NULL)))
{
throw std::invalid_argument(fmt::format(
"OpenSSL error: {}", OpenSSL::error_string(ERR_get_error())));
}
}
int mdnid, pknid, secbits;
X509_get_signature_info(cert, &mdnid, &pknid, &secbits, 0);
md_type = get_md_type(mdnid);
EVP_PKEY* pk = X509_get_pubkey(cert);
if (EVP_PKEY_get0_EC_KEY(pk))
{
public_key = std::make_unique<PublicKey_OpenSSL>(pk);
}
else if (EVP_PKEY_get0_RSA(pk))
{
public_key = std::make_unique<RSAPublicKey_OpenSSL>(pk);
}
else
{
throw std::logic_error("unsupported public key type");
}
}
Verifier_OpenSSL::~Verifier_OpenSSL()
{
if (cert)
X509_free(cert);
}
std::vector<uint8_t> Verifier_OpenSSL::cert_der()
{
Unique_BIO mem;
CHECK1(i2d_X509_bio(mem, cert));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
return {(uint8_t*)bptr->data, (uint8_t*)bptr->data + bptr->length};
}
Pem Verifier_OpenSSL::cert_pem()
{
Unique_BIO mem;
CHECK1(PEM_write_bio_X509(mem, cert));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
}

Просмотреть файл

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/verifier.h"
#include <openssl/x509.h>
namespace crypto
{
class Verifier_OpenSSL : public Verifier
{
protected:
mutable X509* cert;
MDType get_md_type(int mdt) const;
public:
Verifier_OpenSSL(const std::vector<uint8_t>& c);
Verifier_OpenSSL(Verifier_OpenSSL&& v) = default;
Verifier_OpenSSL(const Verifier_OpenSSL&) = delete;
virtual ~Verifier_OpenSSL();
virtual std::vector<uint8_t> cert_der() override;
virtual Pem cert_pem() override;
};
}

Просмотреть файл

@ -7,24 +7,29 @@
#include "pem.h"
#include "san.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
static constexpr size_t max_pem_key_size = 2048;
static inline void hexdump(
const char* name, const uint8_t* bytes, size_t size)
{
printf("%s: ", name);
for (size_t i = 0; i < size; i++)
printf("%02x", bytes[i]);
printf("\n");
}
class PublicKeyBase
class PublicKey
{
public:
virtual CurveID get_curve_id() const = 0;
/**
* Verify that a signature was produced on contents with the private key
* associated with the public key held by the object.
*
* @param contents address of contents
* @param contents_size size of contents
* @param sig address of signature
* @param sig_size size of signature
* @param md_type Digest algorithm to use
* @param bytes Buffer to write the hash to
*
* @return Whether the signature matches the contents and the key
*/
virtual bool verify(
const uint8_t* contents,
size_t contents_size,
@ -41,8 +46,8 @@ namespace crypto
* @param contents_size size of contents
* @param sig address of signature
* @param sig_size size of signature
* @param md_type Digest algorithm to use. Derived from the public key if
* MDType::None.
* @param md_type Digest algorithm to use (derived from the public key if
* md_type == MDType::None).
*
* @return Whether the signature matches the contents and the key
*/
@ -74,6 +79,15 @@ namespace crypto
contents.data(), contents.size(), signature.data(), signature.size());
}
/**
* Verify that a signature was produced on the hash of some contents with
* the private key associated with the public key held by the object.
*
* @param hash Hash of some content
* @param signature Signature as a sequence of bytes
*
* @return Whether the signature matches the hash and the key
*/
virtual bool verify_hash(
const std::vector<uint8_t>& hash,
const std::vector<uint8_t>& signature,
@ -83,6 +97,18 @@ namespace crypto
hash.data(), hash.size(), signature.data(), signature.size(), md_type);
}
/**
* Verify that a signature was produced on the hash of some contents with
* the private key associated with the public key held by the object.
*
* @param hash Hash of some content
* @param hash_size length of @p hash
* @param sig Signature as a sequence of bytes
* @param sig_size Length og @p sig
* @param md_type Digest algorithm
*
* @return Whether the signature matches the hash and the key
*/
virtual bool verify_hash(
const uint8_t* hash,
size_t hash_size,
@ -95,58 +121,4 @@ namespace crypto
*/
virtual Pem public_key_pem() const = 0;
};
class KeyPairBase
{
public:
virtual ~KeyPairBase() = default;
virtual Pem private_key_pem() const = 0;
virtual Pem public_key_pem() const = 0;
virtual bool verify(
const std::vector<uint8_t>& contents,
const std::vector<uint8_t>& signature) = 0;
virtual std::vector<uint8_t> sign_hash(
const uint8_t* hash, size_t hash_size) const = 0;
virtual int sign_hash(
const uint8_t* hash,
size_t hash_size,
size_t* sig_size,
uint8_t* sig) const = 0;
virtual std::vector<uint8_t> sign(CBuffer d, MDType md_type = {}) const = 0;
virtual Pem create_csr(const std::string& name) const = 0;
virtual Pem sign_csr(
const Pem& issuer_cert,
const Pem& signing_request,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = false) const = 0;
Pem self_sign(
const std::string& name,
const std::optional<SubjectAltName> subject_alt_name = std::nullopt,
bool ca = true) const
{
std::vector<SubjectAltName> sans;
if (subject_alt_name.has_value())
sans.push_back(subject_alt_name.value());
auto csr = create_csr(name);
return sign_csr(Pem(0), csr, sans, ca);
}
Pem self_sign(
const std::string& name,
const std::vector<SubjectAltName> subject_alt_names,
bool ca = true) const
{
auto csr = create_csr(name);
return sign_csr(Pem(0), csr, subject_alt_names, ca);
}
};
}
}

Просмотреть файл

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "rsa_key_pair.h"
#include "mbedtls/rsa_key_pair.h"
#include "openssl/rsa_key_pair.h"
namespace crypto
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
using RSAPublicKeyImpl = RSAPublicKey_mbedTLS;
using RSAKeyPairImpl = RSAKeyPair_mbedTLS;
#else
using RSAPublicKeyImpl = RSAPublicKey_OpenSSL;
using RSAKeyPairImpl = RSAKeyPair_OpenSSL;
#endif
RSAPublicKeyPtr make_rsa_public_key(const Pem& public_pem)
{
return make_rsa_public_key(public_pem.data(), public_pem.size());
}
RSAPublicKeyPtr make_rsa_public_key(const std::vector<uint8_t>& der)
{
return std::make_shared<RSAPublicKeyImpl>(der);
}
static constexpr auto PEM_BEGIN = "-----BEGIN";
static constexpr auto PEM_BEGIN_LEN =
std::char_traits<char>::length(PEM_BEGIN);
RSAPublicKeyPtr make_rsa_public_key(const uint8_t* data, size_t size)
{
if (size < 10 || strncmp(PEM_BEGIN, (char*)data, PEM_BEGIN_LEN) != 0)
{
std::vector<uint8_t> der = {data, data + size};
return std::make_shared<RSAPublicKeyImpl>(der);
}
else
{
Pem pem(data, size);
return std::make_shared<RSAPublicKeyImpl>(pem);
}
}
/**
* Create a new public / private RSA key pair with specified size and exponent
*/
RSAKeyPairPtr make_rsa_key_pair(
size_t public_key_size, size_t public_exponent)
{
return std::make_shared<RSAKeyPairImpl>(public_key_size, public_exponent);
}
/**
* Create a public / private RSA key pair from existing private key data
*/
RSAKeyPairPtr make_rsa_key_pair(const Pem& pem, CBuffer pw)
{
return std::make_shared<RSAKeyPairImpl>(pem, pw);
}
}

Просмотреть файл

@ -3,296 +3,26 @@
#pragma once
#include "key_pair.h"
#include "pem.h"
#include "rsa_public_key.h"
#include <algorithm>
#include <openssl/bn.h>
#include <openssl/evp.h>
#include <openssl/pem.h>
#include <openssl/rsa.h>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
// Compatible with Azure HSM encryption schemes (see
// https://docs.microsoft.com/en-gb/azure/key-vault/keys/about-keys#wrapkeyunwrapkey-encryptdecrypt)
static constexpr auto rsa_padding_mode = MBEDTLS_RSA_PKCS_V21;
static constexpr auto rsa_padding_digest_id = MBEDTLS_MD_SHA256;
class RSAPublicKey_mbedTLS : public PublicKey_mbedTLS
{
public:
RSAPublicKey_mbedTLS() = default;
RSAPublicKey_mbedTLS(mbedtls::PKContext&& c) :
PublicKey_mbedTLS(std::move(c))
{}
/**
* Construct from PEM
*/
RSAPublicKey_mbedTLS(const Pem& pem) : PublicKey_mbedTLS(pem)
{
if (!mbedtls_pk_can_do(ctx.get(), MBEDTLS_PK_RSA))
{
throw std::logic_error("invalid RSA key");
}
}
/**
* Construct from DER
*/
RSAPublicKey_mbedTLS(const std::vector<uint8_t>& der) :
PublicKey_mbedTLS(der)
{
if (!mbedtls_pk_can_do(ctx.get(), MBEDTLS_PK_RSA))
{
throw std::logic_error("invalid RSA key");
}
}
/**
* Wrap data using RSA-OAEP-256
*
* @param input Pointer to raw data to wrap
* @param input_size Size of raw data
* @param label Optional string used as label during wrapping
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
std::vector<uint8_t> wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label = nullptr,
size_t label_size = 0)
{
mbedtls_rsa_context* rsa_ctx = mbedtls_pk_rsa(*ctx.get());
mbedtls_rsa_set_padding(rsa_ctx, rsa_padding_mode, rsa_padding_digest_id);
std::vector<uint8_t> output_buf(rsa_ctx->len);
auto entropy = create_entropy();
// Note that the maximum input size to wrap is k - 2*hLen - 2
// where hLen is the hash size (32 bytes = SHA256) and
// k the wrapping key modulus size (e.g. 256 bytes = 2048 bits).
// In this example, it would be 190 bytes (1520 bits) max.
// This is enough for wrapping AES keys for example.
auto rc = mbedtls_rsa_rsaes_oaep_encrypt(
rsa_ctx,
entropy->get_rng(),
entropy->get_data(),
MBEDTLS_RSA_PUBLIC,
label,
label_size,
input_size,
input,
output_buf.data());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Error during RSA OEAP wrap: {}", error_string(rc)));
}
return output_buf;
}
/**
* Wrap data using RSA-OAEP-256
*
* @param input Raw data to wrap
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
std::vector<uint8_t> wrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
return wrap(input.data(), input.size(), label_, label_size);
}
};
class RSAPublicKey_OpenSSL : public PublicKey_OpenSSL
{
public:
RSAPublicKey_OpenSSL() = default;
RSAPublicKey_OpenSSL(EVP_PKEY* c) : PublicKey_OpenSSL(c)
{
if (!EVP_PKEY_get0_RSA(key))
{
throw std::logic_error("invalid RSA key");
}
}
/**
* Construct from PEM
*/
RSAPublicKey_OpenSSL(const Pem& pem)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PUBKEY(mem, NULL, NULL, NULL);
if (!key || !EVP_PKEY_get0_RSA(key))
{
throw std::logic_error("invalid RSA key");
}
}
/**
* Construct from DER
*/
RSAPublicKey_OpenSSL(const std::vector<uint8_t>& der)
{
const unsigned char* pp = der.data();
RSA* rsa = NULL;
if (
((rsa = d2i_RSA_PUBKEY(NULL, &pp, der.size())) ==
NULL) && // "SubjectPublicKeyInfo structure" format
((rsa = d2i_RSAPublicKey(NULL, &pp, der.size())) ==
NULL)) // PKCS#1 structure format
{
unsigned long ec = ERR_get_error();
const char* msg = ERR_error_string(ec, NULL);
throw new std::runtime_error(fmt::format("OpenSSL error: {}", msg));
}
key = EVP_PKEY_new();
OPENSSL_CHECK1(EVP_PKEY_set1_RSA(key, rsa));
RSA_free(rsa);
}
/**
* Wrap data using RSA-OAEP-256
*
* @param input Pointer to raw data to wrap
* @param input_size Size of raw data
* @param label Optional string used as label during wrapping
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
std::vector<uint8_t> wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label = nullptr,
size_t label_size = 0)
{
Unique_EVP_PKEY_CTX ctx(key);
OPENSSL_CHECK1(EVP_PKEY_encrypt_init(ctx));
EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING);
EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256());
EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, EVP_sha256());
if (label)
{
unsigned char* openssl_label =
(unsigned char*)OPENSSL_malloc(label_size);
std::copy(label, label + label_size, openssl_label);
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, openssl_label, label_size);
}
else
{
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
}
size_t olen;
OPENSSL_CHECK1(EVP_PKEY_encrypt(ctx, NULL, &olen, input, input_size));
std::vector<uint8_t> output(olen);
OPENSSL_CHECK1(
EVP_PKEY_encrypt(ctx, output.data(), &olen, input, input_size));
output.resize(olen);
return output;
}
/**
* Wrap data using RSA-OAEP-256
*
* @param input Raw data to wrap
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
std::vector<uint8_t> wrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
return wrap(input.data(), input.size(), label_, label_size);
}
};
class RSAKeyPair_mbedTLS : public RSAPublicKey_mbedTLS
class RSAKeyPair
{
public:
static constexpr size_t default_public_key_size = 2048;
static constexpr size_t default_public_exponent = 65537;
/**
* Create a new public / private RSA key pair
*/
RSAKeyPair_mbedTLS(
size_t public_key_size = default_public_key_size,
size_t public_exponent = default_public_exponent)
{
EntropyPtr entropy = create_entropy();
int rc =
mbedtls_pk_setup(ctx.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA));
if (rc != 0)
{
throw std::logic_error(
"Could not set up RSA context: " + error_string(rc));
}
rc = mbedtls_rsa_gen_key(
mbedtls_pk_rsa(*ctx.get()),
entropy->get_rng(),
entropy->get_data(),
public_key_size,
public_exponent);
if (rc != 0)
{
throw std::logic_error(
"Could not generate RSA keypair: " + error_string(rc));
}
}
RSAKeyPair_mbedTLS(mbedtls::PKContext&& k) :
RSAPublicKey_mbedTLS(std::move(k))
{}
RSAKeyPair_mbedTLS(const RSAKeyPair_mbedTLS&) = delete;
RSAKeyPair_mbedTLS(const Pem& pem, CBuffer pw = nullb) :
RSAPublicKey_mbedTLS()
{
// keylen is +1 to include terminating null byte
int rc =
mbedtls_pk_parse_key(ctx.get(), pem.data(), pem.size(), pw.p, pw.n);
if (rc != 0)
{
throw std::logic_error(
"Could not parse private key: " + error_string(rc));
}
}
RSAKeyPair() = default;
RSAKeyPair(const RSAKeyPair&) = delete;
RSAKeyPair(const Pem& pem, CBuffer pw = nullb);
virtual ~RSAKeyPair() = default;
/**
* Unwrap data using RSA-OAEP-256
@ -302,181 +32,32 @@ namespace crypto
*
* @return Unwrapped data
*/
std::vector<uint8_t> unwrap(
virtual std::vector<uint8_t> unwrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt)
{
mbedtls_rsa_context* rsa_ctx = mbedtls_pk_rsa(*ctx.get());
mbedtls_rsa_set_padding(rsa_ctx, rsa_padding_mode, rsa_padding_digest_id);
std::vector<uint8_t> output_buf(rsa_ctx->len);
auto entropy = create_entropy();
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
size_t olen;
auto rc = mbedtls_rsa_rsaes_oaep_decrypt(
rsa_ctx,
entropy->get_rng(),
entropy->get_data(),
MBEDTLS_RSA_PRIVATE,
label_,
label_size,
&olen,
input.data(),
output_buf.data(),
output_buf.size());
if (rc != 0)
{
throw std::logic_error(
fmt::format("Error during RSA OEAP unwrap: {}", error_string(rc)));
}
output_buf.resize(olen);
return output_buf;
}
};
class RSAKeyPair_OpenSSL : public RSAPublicKey_OpenSSL
{
public:
static constexpr size_t default_public_key_size = 2048;
static constexpr size_t default_public_exponent = 65537;
std::optional<std::string> label = std::nullopt) = 0;
/**
* Create a new public / private RSA key pair
* Get the public key in PEM format
*/
RSAKeyPair_OpenSSL(
size_t public_key_size = default_public_key_size,
size_t public_exponent = default_public_exponent)
{
RSA* rsa = NULL;
BIGNUM* big_exp = NULL;
OPENSSL_CHECKNULL(big_exp = BN_new());
OPENSSL_CHECK1(BN_set_word(big_exp, public_exponent));
OPENSSL_CHECKNULL(rsa = RSA_new());
OPENSSL_CHECK1(RSA_generate_key_ex(rsa, public_key_size, big_exp, NULL));
OPENSSL_CHECKNULL(key = EVP_PKEY_new());
OPENSSL_CHECK1(EVP_PKEY_set1_RSA(key, rsa));
BN_free(big_exp);
RSA_free(rsa);
}
RSAKeyPair_OpenSSL(EVP_PKEY* k) : RSAPublicKey_OpenSSL(std::move(k)) {}
RSAKeyPair_OpenSSL(const RSAKeyPair_OpenSSL&) = delete;
RSAKeyPair_OpenSSL(const Pem& pem, CBuffer pw = nullb)
{
Unique_BIO mem(pem.data(), -1);
key = PEM_read_bio_PrivateKey(mem, NULL, NULL, (void*)pw.p);
if (!key)
{
throw std::runtime_error("could not parse PEM");
}
}
/**
* Unwrap data using RSA-OAEP-256
*
* @param input Raw data to unwrap
* @param label Optional string used as label during unwrapping
*
* @return Unwrapped data
*/
std::vector<uint8_t> unwrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt)
{
const unsigned char* label_ = NULL;
size_t label_size = 0;
if (label.has_value())
{
label_ = reinterpret_cast<const unsigned char*>(label->c_str());
label_size = label->size();
}
Unique_EVP_PKEY_CTX ctx(key);
OPENSSL_CHECK1(EVP_PKEY_decrypt_init(ctx));
EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING);
EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256());
EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, EVP_sha256());
if (label_)
{
unsigned char* openssl_label =
(unsigned char*)OPENSSL_malloc(label_size);
std::copy(label_, label_ + label_size, openssl_label);
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, openssl_label, label_size);
}
else
{
EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, NULL, 0);
}
size_t olen;
OPENSSL_CHECK1(
EVP_PKEY_decrypt(ctx, NULL, &olen, input.data(), input.size()));
std::vector<uint8_t> output(olen);
OPENSSL_CHECK1(EVP_PKEY_decrypt(
ctx, output.data(), &olen, input.data(), input.size()));
output.resize(olen);
return output;
}
virtual Pem public_key_pem() const = 0;
};
using RSAPublicKey = RSAPublicKey_OpenSSL;
using RSAKeyPair = RSAKeyPair_OpenSSL;
using RSAKeyPairPtr = std::shared_ptr<RSAKeyPair>;
using RSAPublicKeyPtr = std::shared_ptr<RSAPublicKey>;
using RSAKeyPairPtr = std::shared_ptr<RSAKeyPair>;
RSAPublicKeyPtr make_rsa_public_key(const Pem& pem);
RSAPublicKeyPtr make_rsa_public_key(const std::vector<uint8_t>& der);
RSAPublicKeyPtr make_rsa_public_key(const uint8_t* data, size_t size);
/**
* Create a new public / private RSA key pair with specified size and exponent
*/
inline RSAKeyPairPtr make_rsa_key_pair(
RSAKeyPairPtr make_rsa_key_pair(
size_t public_key_size = RSAKeyPair::default_public_key_size,
size_t public_exponent = RSAKeyPair::default_public_exponent)
{
return std::make_shared<RSAKeyPair>(public_key_size, public_exponent);
}
size_t public_exponent = RSAKeyPair::default_public_exponent);
/**
* Create a public / private RSA key pair from existing private key data
*/
inline RSAKeyPairPtr make_rsa_key_pair(const Pem& pkey, CBuffer pw = nullb)
{
return std::make_shared<RSAKeyPair>(pkey, pw);
}
inline RSAPublicKeyPtr make_rsa_public_key(const std::vector<uint8_t>& der)
{
return std::make_shared<RSAPublicKey>(der);
}
inline RSAPublicKeyPtr make_rsa_public_key(const uint8_t* data, size_t size)
{
if (size < 10 || strncmp("-----BEGIN", (char*)data, 10) != 0)
{
std::vector<uint8_t> der = {data, data + size};
return std::make_shared<RSAPublicKey>(der);
}
else
{
Pem pem(data, size);
return std::make_shared<RSAPublicKey>(pem);
}
}
inline RSAPublicKeyPtr make_rsa_public_key(const Pem& public_pem)
{
return make_rsa_public_key(public_pem.data(), public_pem.size());
}
}
RSAKeyPairPtr make_rsa_key_pair(const Pem& pem, CBuffer pw = nullb);
}

Просмотреть файл

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#pragma once
#include "pem.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
namespace crypto
{
class RSAPublicKey
{
public:
RSAPublicKey() = default;
virtual ~RSAPublicKey() = default;
/**
* Construct from PEM
*/
RSAPublicKey(const Pem& pem);
/**
* Construct from DER
*/
RSAPublicKey(const std::vector<uint8_t>& der);
/**
* Wrap data using RSA-OAEP-256
*
* @param input Pointer to raw data to wrap
* @param input_size Size of raw data
* @param label Optional string used as label during wrapping
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
virtual std::vector<uint8_t> wrap(
const uint8_t* input,
size_t input_size,
const uint8_t* label = nullptr,
size_t label_size = 0) = 0;
/**
* Wrap data using RSA-OAEP-256
*
* @param input Raw data to wrap
* @param label Optional string used as label during wrapping
*
* @return Wrapped data
*/
virtual std::vector<uint8_t> wrap(
const std::vector<uint8_t>& input,
std::optional<std::string> label = std::nullopt) = 0;
/**
* Get the public key in PEM format
*/
virtual Pem public_key_pem() const = 0;
};
}

Просмотреть файл

@ -1,102 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "symmetric_key.h"
#include "crypto/key_pair_mbedtls.h"
#include "ds/logger.h"
#include "ds/thread_messaging.h"
#include "crypto/mbedtls/symmetric_key.h"
#include <mbedtls/aes.h>
#include <mbedtls/error.h>
#include "crypto/openssl/symmetric_key.h"
namespace crypto
{
KeyAesGcm::KeyAesGcm(CBuffer rawKey)
using namespace mbedtls;
std::unique_ptr<KeyAesGcm> make_key_aes_gcm(CBuffer rawKey)
{
for (uint32_t i = 0; i < ctxs.size(); ++i)
{
ctxs[i] = mbedtls::make_unique<mbedtls::GcmContext>();
size_t n_bits;
const auto n = static_cast<unsigned int>(rawKey.rawSize() * 8);
if (n >= 256)
{
n_bits = 256;
}
else if (n >= 192)
{
n_bits = 192;
}
else if (n >= 128)
{
n_bits = 128;
}
else
{
throw std::logic_error(
fmt::format("Need at least {} bits, only have {}", 128, n));
}
int rc = mbedtls_gcm_setkey(
ctxs[i].get(), MBEDTLS_CIPHER_ID_AES, rawKey.p, n_bits);
if (rc != 0)
{
throw std::logic_error(PublicKey_mbedTLS::error_string(rc));
}
}
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
return std::make_unique<KeyAesGcm_mbedTLS>(rawKey);
#else
return std::make_unique<KeyAesGcm_OpenSSL>(rawKey);
#endif
}
KeyAesGcm::KeyAesGcm(KeyAesGcm&& that)
{
ctxs = std::move(that.ctxs);
}
void KeyAesGcm::encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const
{
auto ctx = ctxs[threading::get_current_thread_id()].get();
int rc = mbedtls_gcm_crypt_and_tag(
ctx,
MBEDTLS_GCM_ENCRYPT,
plain.n,
iv.p,
iv.n,
aad.p,
aad.n,
plain.p,
cipher,
GCM_SIZE_TAG,
tag);
if (rc != 0)
{
throw std::logic_error(PublicKey_mbedTLS::error_string(rc));
}
}
bool KeyAesGcm::decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const
{
auto ctx = ctxs[threading::get_current_thread_id()].get();
return !mbedtls_gcm_auth_decrypt(
ctx,
cipher.n,
iv.p,
iv.n,
aad.p,
aad.n,
tag,
GCM_SIZE_TAG,
cipher.p,
plain);
}
}
}

Просмотреть файл

@ -4,7 +4,6 @@
#include "ds/buffer.h"
#include "ds/serialized.h"
#include "ds/thread_messaging.h"
#include "mbedtls_wrappers.h"
namespace crypto
{
@ -135,28 +134,24 @@ namespace crypto
class KeyAesGcm
{
private:
mutable std::
array<mbedtls::GcmContext, threading::ThreadMessaging::max_num_threads>
ctxs;
public:
KeyAesGcm(CBuffer rawKey);
KeyAesGcm(const KeyAesGcm& that) = delete;
KeyAesGcm(KeyAesGcm&& that);
KeyAesGcm() = default;
virtual ~KeyAesGcm() = default;
void encrypt(
virtual void encrypt(
CBuffer iv,
CBuffer plain,
CBuffer aad,
uint8_t* cipher,
uint8_t tag[GCM_SIZE_TAG]) const;
uint8_t tag[GCM_SIZE_TAG]) const = 0;
bool decrypt(
virtual bool decrypt(
CBuffer iv,
const uint8_t tag[GCM_SIZE_TAG],
CBuffer cipher,
CBuffer aad,
uint8_t* plain) const;
uint8_t* plain) const = 0;
};
std::unique_ptr<KeyAesGcm> make_key_aes_gcm(CBuffer rawKey);
}

Просмотреть файл

@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "crypto/key_pair.h"
#include "crypto/mbedtls/key_pair.h"
#include "crypto/openssl/key_pair.h"
#define PICOBENCH_IMPLEMENT_WITH_MAIN
#include <picobench/picobench.hpp>

Просмотреть файл

@ -2,14 +2,23 @@
// Licensed under the Apache 2.0 License.
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
#include "crypto/key_pair.h"
#include "crypto/mbedtls/entropy.h"
#include "crypto/mbedtls/key_pair.h"
#include "crypto/mbedtls/rsa_key_pair.h"
#include "crypto/mbedtls/symmetric_key.h"
#include "crypto/mbedtls/verifier.h"
#include "crypto/openssl/key_pair.h"
#include "crypto/openssl/rsa_key_pair.h"
#include "crypto/openssl/symmetric_key.h"
#include "crypto/openssl/verifier.h"
#include "crypto/rsa_key_pair.h"
#include "crypto/symmetric_key.h"
#include "crypto/verifier.h"
#include "tls/base64.h"
#include <chrono>
#include <cstring>
#include <doctest/doctest.h>
#include <string>
using namespace std;
using namespace tls;
@ -375,15 +384,114 @@ static const vector<uint8_t>& getRawKey()
TEST_CASE("ExtendedIv0")
{
KeyAesGcm k(getRawKey());
auto k = crypto::make_key_aes_gcm(getRawKey());
// setup plain text
unsigned char rawP[100];
memset(rawP, 'x', sizeof(rawP));
Buffer p{rawP, sizeof(rawP)};
// test large IV
GcmHeader<1234> h;
k.encrypt(h.get_iv(), p, nullb, p.p, h.tag);
k->encrypt(h.get_iv(), p, nullb, p.p, h.tag);
KeyAesGcm k2(getRawKey());
REQUIRE(k2.decrypt(h.get_iv(), h.tag, p, nullb, p.p));
auto k2 = crypto::make_key_aes_gcm(getRawKey());
REQUIRE(k2->decrypt(h.get_iv(), h.tag, p, nullb, p.p));
}
TEST_CASE("AES mbedTLS vs OpenSSL")
{
auto key = getRawKey();
GcmHeader<1234> h;
{ // mbedTLS -> OpenSSL
auto mbed = std::make_unique<KeyAesGcm_mbedTLS>(key);
auto ossl = std::make_unique<KeyAesGcm_OpenSSL>(key);
std::vector<uint8_t> encrypted(contents_.size());
mbed->encrypt(h.get_iv(), contents_, nullb, encrypted.data(), h.tag);
std::vector<unsigned char> rawP(contents_.size(), 'x');
Buffer p{rawP.data(), rawP.size()};
std::vector<uint8_t> decrypted(contents_.size());
REQUIRE(ossl->decrypt(
h.get_iv(),
h.tag,
{encrypted.data(), encrypted.size()},
nullb,
decrypted.data()));
REQUIRE(decrypted.size() == contents_.size());
REQUIRE(memcmp(decrypted.data(), contents_.data(), sizeof(contents_)) == 0);
}
{ // OpenSSL -> mbedTLS
auto mbed = std::make_unique<KeyAesGcm_mbedTLS>(key);
auto ossl = std::make_unique<KeyAesGcm_OpenSSL>(key);
std::vector<uint8_t> encrypted(contents_.size());
ossl->encrypt(h.get_iv(), contents_, nullb, encrypted.data(), h.tag);
std::vector<unsigned char> rawP(contents_.size(), 'x');
Buffer p{rawP.data(), rawP.size()};
std::vector<uint8_t> decrypted(contents_.size());
CBuffer encbuf{encrypted.data(), encrypted.size()};
REQUIRE(mbed->decrypt(h.get_iv(), h.tag, encbuf, nullb, decrypted.data()));
REQUIRE(decrypted.size() == contents_.size());
REQUIRE(memcmp(decrypted.data(), contents_.data(), sizeof(contents_)) == 0);
}
}
TEST_CASE("AES mbedTLS vs OpenSSL + AAD")
{
auto key = getRawKey();
GcmHeader<1234> h;
std::vector<uint8_t> aad(123, 'y');
{
INFO("mbedTLS -> OpenSSL");
auto mbed = std::make_unique<KeyAesGcm_mbedTLS>(key);
auto ossl = std::make_unique<KeyAesGcm_OpenSSL>(key);
std::vector<uint8_t> encrypted(contents_.size());
mbed->encrypt(h.get_iv(), contents_, aad, encrypted.data(), h.tag);
std::vector<unsigned char> rawP(contents_.size(), 'x');
Buffer p{rawP.data(), rawP.size()};
std::vector<uint8_t> decrypted(contents_.size());
REQUIRE(ossl->decrypt(
h.get_iv(),
h.tag,
{encrypted.data(), encrypted.size()},
aad,
decrypted.data()));
REQUIRE(decrypted.size() == contents_.size());
REQUIRE(memcmp(decrypted.data(), contents_.data(), sizeof(contents_)) == 0);
}
{
INFO("OpenSSL -> mbedTLS");
auto mbed = std::make_unique<KeyAesGcm_mbedTLS>(key);
auto ossl = std::make_unique<KeyAesGcm_OpenSSL>(key);
std::vector<uint8_t> encrypted(contents_.size());
ossl->encrypt(h.get_iv(), contents_, aad, encrypted.data(), h.tag);
std::vector<unsigned char> rawP(contents_.size(), 'x');
Buffer p{rawP.data(), rawP.size()};
std::vector<uint8_t> decrypted(contents_.size());
CBuffer encbuf{encrypted.data(), encrypted.size()};
REQUIRE(mbed->decrypt(h.get_iv(), h.tag, encbuf, aad, decrypted.data()));
REQUIRE(decrypted.size() == contents_.size());
REQUIRE(memcmp(decrypted.data(), contents_.data(), sizeof(contents_)) == 0);
}
}

Просмотреть файл

@ -2,6 +2,8 @@
// Licensed under the Apache 2.0 License.
#include "crypto/hash.h"
#include "crypto/mbedtls/hash.h"
#include "crypto/openssl/hash.h"
#define PICOBENCH_IMPLEMENT_WITH_MAIN
#include <picobench/picobench.hpp>
@ -28,11 +30,11 @@ static void sha256_bench(picobench::state& s)
{
if constexpr (IMPL == HashImpl::mbedtls)
{
crypto::Sha256Hash::mbedtls_sha256(v, h.h.data());
crypto::mbedtls_sha256(v, h.h.data());
}
else if constexpr (IMPL == HashImpl::openssl)
{
crypto::Sha256Hash::openssl_sha256(v, h.h.data());
crypto::openssl_sha256(v, h.h.data());
}
}
s.stop_timer();

66
src/crypto/verifier.cpp Normal file
Просмотреть файл

@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "verifier.h"
#include "crypto/mbedtls/verifier.h"
#include "crypto/openssl/verifier.h"
namespace crypto
{
using VerifierPtr = std::shared_ptr<Verifier>;
using VerifierUniquePtr = std::unique_ptr<Verifier>;
/**
* Construct Verifier from a certificate in DER or PEM format
*
* @param cert Sequence of bytes containing the certificate
*/
VerifierUniquePtr make_unique_verifier(const std::vector<uint8_t>& cert)
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
return std::make_unique<Verifier_mbedTLS>(cert);
#else
return std::make_unique<Verifier_OpenSSL>(cert);
#endif
}
VerifierPtr make_verifier(const std::vector<uint8_t>& cert)
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
return std::make_shared<Verifier_mbedTLS>(cert);
#else
return std::make_shared<Verifier_OpenSSL>(cert);
#endif
}
VerifierUniquePtr make_unique_verifier(const Pem& pem)
{
return make_unique_verifier(pem.raw());
}
VerifierPtr make_verifier(const Pem& pem)
{
return make_verifier(pem.raw());
}
crypto::Pem cert_der_to_pem(const std::vector<uint8_t>& der)
{
return make_verifier(der)->cert_pem();
}
std::vector<uint8_t> cert_pem_to_der(const std::string& pem_string)
{
return make_verifier(Pem(pem_string).raw())->cert_der();
}
Pem public_key_pem_from_cert(const Pem& cert)
{
return make_unique_verifier(cert)->public_key_pem();
}
void check_is_cert(const CBuffer& der)
{
make_unique_verifier((std::vector<uint8_t>)der); // throws on error
}
}

Просмотреть файл

@ -2,35 +2,22 @@
// Licensed under the Apache 2.0 License.
#pragma once
#include "curve.h"
#include "hash.h"
#include "key_pair.h"
#include "pem.h"
#include "rsa_key_pair.h"
#include <mbedtls/pem.h>
#include <openssl/evp.h>
#include <openssl/x509.h>
#include "public_key.h"
namespace crypto
{
static constexpr size_t max_pem_cert_size = 4096;
// As these are not exposed by mbedTLS, define them here to allow simple
// conversion from DER to PEM format
static constexpr auto PEM_CERTIFICATE_HEADER =
"-----BEGIN CERTIFICATE-----\n";
static constexpr auto PEM_CERTIFICATE_FOOTER = "-----END CERTIFICATE-----\n";
class VerifierBase
class Verifier
{
protected:
std::shared_ptr<PublicKeyBase> public_key;
std::shared_ptr<PublicKey> public_key;
MDType md_type = MDType::NONE;
public:
VerifierBase() : public_key(nullptr) {}
virtual ~VerifierBase() {}
Verifier() : public_key(nullptr) {}
virtual ~Verifier() {}
virtual std::vector<uint8_t> cert_der() = 0;
virtual Pem cert_pem() = 0;
@ -124,270 +111,33 @@ namespace crypto
hash.data(), hash.size(), signature.data(), signature.size(), md_type);
}
virtual CurveID get_curve_id() const
{
return public_key->get_curve_id();
}
virtual Pem public_key_pem() const
{
return public_key->public_key_pem();
}
};
class Verifier_mbedTLS : public VerifierBase
{
protected:
mutable mbedtls::X509Crt cert;
inline MDType get_md_type(mbedtls_md_type_t mdt) const
{
switch (mdt)
{
case MBEDTLS_MD_NONE:
return MDType::NONE;
case MBEDTLS_MD_SHA1:
return MDType::SHA1;
case MBEDTLS_MD_SHA256:
return MDType::SHA256;
case MBEDTLS_MD_SHA384:
return MDType::SHA384;
case MBEDTLS_MD_SHA512:
return MDType::SHA512;
default:
return MDType::NONE;
}
return MDType::NONE;
}
public:
/**
* Construct from a certificate
*
* @param c Certificate in DER or PEM format
*/
Verifier_mbedTLS(const std::vector<uint8_t>& c) : VerifierBase()
{
cert = mbedtls::make_unique<mbedtls::X509Crt>();
int rc = mbedtls_x509_crt_parse(cert.get(), c.data(), c.size());
if (rc)
{
throw std::invalid_argument(fmt::format(
"Failed to parse certificate: {}",
PublicKey_mbedTLS::error_string(rc)));
}
md_type = get_md_type(cert->sig_md);
// public_key expects to have unique ownership of the context and so does
// `cert`, so we duplicate the key context here.
unsigned char buf[2048];
rc = mbedtls_pk_write_pubkey_pem(&cert->pk, buf, sizeof(buf));
if (rc != 0)
{
throw std::runtime_error(fmt::format(
"PEM export failed: {}", PublicKey_mbedTLS::error_string(rc)));
}
Pem pem(buf, sizeof(buf));
if (mbedtls_pk_can_do(&cert->pk, MBEDTLS_PK_ECKEY))
{
public_key = std::make_unique<PublicKey_mbedTLS>(pem);
}
else if (mbedtls_pk_can_do(&cert->pk, MBEDTLS_PK_RSA))
{
public_key = std::make_unique<RSAPublicKey_mbedTLS>(pem);
}
else
{
throw std::logic_error("unsupported public key type");
}
}
Verifier_mbedTLS(const Verifier_mbedTLS&) = delete;
virtual ~Verifier_mbedTLS() = default;
virtual std::vector<uint8_t> cert_der() override
{
return {cert->raw.p, cert->raw.p + cert->raw.len};
}
virtual Pem cert_pem() override
{
unsigned char buf[max_pem_cert_size];
size_t len;
auto rc = mbedtls_pem_write_buffer(
PEM_CERTIFICATE_HEADER,
PEM_CERTIFICATE_FOOTER,
cert->raw.p,
cert->raw.len,
buf,
max_pem_cert_size,
&len);
if (rc != 0)
{
throw std::logic_error(
"mbedtls_pem_write_buffer failed: " +
PublicKey_mbedTLS::error_string(rc));
}
return Pem(buf, len);
}
};
class Verifier_OpenSSL : public VerifierBase
{
protected:
mutable X509* cert;
MDType get_md_type(int mdt) const
{
switch (mdt)
{
case NID_undef:
return MDType::NONE;
case NID_sha1:
return MDType::SHA1;
case NID_sha256:
return MDType::SHA256;
case NID_sha384:
return MDType::SHA384;
case NID_sha512:
return MDType::SHA512;
default:
return MDType::NONE;
}
return MDType::NONE;
}
public:
/**
* Construct from a certificate
*
* @param c Certificate in DER or PEM format
*/
Verifier_OpenSSL(const std::vector<uint8_t>& c) : VerifierBase()
{
Unique_BIO certbio(c.data(), c.size());
if (!(cert = PEM_read_bio_X509(certbio, NULL, 0, NULL)))
{
BIO_reset(certbio);
if (!(cert = d2i_X509_bio(certbio, NULL)))
{
throw std::invalid_argument(fmt::format(
"OpenSSL error: {}", ERR_error_string(ERR_get_error(), NULL)));
}
}
int mdnid, pknid, secbits;
X509_get_signature_info(cert, &mdnid, &pknid, &secbits, 0);
md_type = get_md_type(mdnid);
EVP_PKEY* pk = X509_get_pubkey(cert);
if (EVP_PKEY_get0_EC_KEY(pk))
{
public_key = std::make_unique<PublicKey_OpenSSL>(pk);
}
else if (EVP_PKEY_get0_RSA(pk))
{
public_key = std::make_unique<RSAPublicKey_OpenSSL>(pk);
}
else
{
throw std::logic_error("unsupported public key type");
}
}
Verifier_OpenSSL(Verifier_OpenSSL&& v) = default;
Verifier_OpenSSL(const Verifier_OpenSSL&) = delete;
virtual ~Verifier_OpenSSL()
{
if (cert)
X509_free(cert);
}
virtual std::vector<uint8_t> cert_der() override
{
Unique_BIO mem;
OPENSSL_CHECK1(i2d_X509_bio(mem, cert));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
return {(uint8_t*)bptr->data, (uint8_t*)bptr->data + bptr->length};
}
virtual Pem cert_pem() override
{
Unique_BIO mem;
OPENSSL_CHECK1(PEM_write_bio_X509(mem, cert));
BUF_MEM* bptr;
BIO_get_mem_ptr(mem, &bptr);
return Pem((uint8_t*)bptr->data, bptr->length);
}
};
using VerifierPtr = std::shared_ptr<VerifierBase>;
using VerifierUniquePtr = std::unique_ptr<VerifierBase>;
using VerifierPtr = std::shared_ptr<Verifier>;
using VerifierUniquePtr = std::unique_ptr<Verifier>;
/**
* Construct Verifier from a certificate in DER or PEM format
*
* @param cert Sequence of bytes containing the certificate
*/
inline VerifierUniquePtr make_unique_verifier(
const std::vector<uint8_t>& cert)
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
return std::make_unique<Verifier_mbedTLS>(cert);
#else
return std::make_unique<Verifier_OpenSSL>(cert);
#endif
}
VerifierUniquePtr make_unique_verifier(const std::vector<uint8_t>& cert);
inline VerifierPtr make_verifier(const std::vector<uint8_t>& cert)
{
#ifdef CRYPTO_PROVIDER_IS_MBEDTLS
return std::make_shared<Verifier_mbedTLS>(cert);
#else
return std::make_shared<Verifier_OpenSSL>(cert);
#endif
}
VerifierPtr make_verifier(const std::vector<uint8_t>& cert);
inline VerifierPtr make_unique_verifier(const Pem& pem)
{
return make_unique_verifier(pem.raw());
}
VerifierUniquePtr make_unique_verifier(const Pem& pem);
inline VerifierPtr make_verifier(const Pem& pem)
{
return make_verifier(pem.raw());
}
VerifierPtr make_verifier(const Pem& pem);
inline crypto::Pem cert_der_to_pem(const std::vector<uint8_t>& der)
{
return make_verifier(der)->cert_pem();
}
crypto::Pem cert_der_to_pem(const std::vector<uint8_t>& der);
inline std::vector<uint8_t> cert_pem_to_der(const std::string& pem_string)
{
return make_verifier(Pem(pem_string).raw())->cert_der();
}
std::vector<uint8_t> cert_pem_to_der(const std::string& pem_string);
static inline Pem public_key_pem_from_cert(const Pem& cert)
{
return make_unique_verifier(cert)->public_key_pem();
}
Pem public_key_pem_from_cert(const Pem& cert);
inline void check_is_cert(const CBuffer& der)
{
make_unique_verifier((std::vector<uint8_t>)der); // throws on error
}
void check_is_cert(const CBuffer& der);
}

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "crypto/hash.h"
#include "crypto/key_pair.h"
#include "crypto/verifier.h"
#include "http_consts.h"
#include "http_parser.h"
#include "tls/base64.h"

Просмотреть файл

@ -330,7 +330,7 @@ namespace ccf
void establish(bool complete)
{
auto shared_secret = ctx.compute_shared_secret();
key = std::make_unique<crypto::KeyAesGcm>(shared_secret);
key = crypto::make_key_aes_gcm(shared_secret);
ctx.free_ctx();
status = ESTABLISHED;

Просмотреть файл

@ -119,7 +119,7 @@ namespace ccf
kv::Term term = 0;
public:
NullTxHistory(kv::Store& store_, NodeId id_, crypto::KeyPairBase&) :
NullTxHistory(kv::Store& store_, NodeId id_, crypto::KeyPair&) :
store(store_),
id(id_)
{}
@ -252,7 +252,7 @@ namespace ccf
kv::Store& store;
T& replicated_state_tree;
NodeId id;
crypto::KeyPairBase& kp;
crypto::KeyPair& kp;
public:
MerkleTreeHistoryPendingTx(
@ -261,7 +261,7 @@ namespace ccf
kv::Store& store_,
T& replicated_state_tree_,
NodeId id_,
crypto::KeyPairBase& kp_) :
crypto::KeyPair& kp_) :
txid(txid_),
commit_txid(commit_txid_),
store(store_),
@ -464,7 +464,7 @@ namespace ccf
NodeId id;
T replicated_state_tree;
crypto::KeyPairBase& kp;
crypto::KeyPair& kp;
std::map<RequestID, std::vector<uint8_t>> requests;
@ -479,7 +479,7 @@ namespace ccf
HashedTxHistory(
kv::Store& store_,
NodeId id_,
crypto::KeyPairBase& kp_,
crypto::KeyPair& kp_,
size_t sig_tx_interval_ = 0,
size_t sig_ms_interval_ = 0,
bool signature_timer = false) :

Просмотреть файл

@ -31,7 +31,7 @@ namespace ccf
// key.
LedgerSecret(const LedgerSecret& other) :
raw_key(other.raw_key),
key(std::make_shared<crypto::KeyAesGcm>(other.raw_key)),
key(crypto::make_key_aes_gcm(other.raw_key)),
previous_secret_stored_version(other.previous_secret_stored_version)
{}
@ -40,7 +40,7 @@ namespace ccf
std::optional<kv::Version> previous_secret_stored_version_ =
std::nullopt) :
raw_key(raw_key_),
key(std::make_shared<crypto::KeyAesGcm>(std::move(raw_key_))),
key(crypto::make_key_aes_gcm(std::move(raw_key_))),
previous_secret_stored_version(previous_secret_stored_version_)
{}
};

Просмотреть файл

@ -2,6 +2,7 @@
// Licensed under the Apache 2.0 License.
#pragma once
#include "crypto/entropy.h"
#include "ds/ccf_assert.h"
#include "ds/ccf_exception.h"
#include "kv/kv_types.h"

Просмотреть файл

@ -84,7 +84,7 @@ namespace ccf
{
public:
ProgressTrackerStoreAdapter(
kv::AbstractStore& store_, crypto::KeyPairBase& kp_) :
kv::AbstractStore& store_, crypto::KeyPair& kp_) :
store(store_),
kp(kp_),
nodes(ccf::Tables::NODES),
@ -278,7 +278,7 @@ namespace ccf
private:
kv::AbstractStore& store;
crypto::KeyPairBase& kp;
crypto::KeyPair& kp;
ccf::Nodes nodes;
ccf::BackupSignaturesMap backup_signatures;
aft::RevealedNoncesMap revealed_nonces;

Просмотреть файл

@ -93,7 +93,7 @@ class NullTxHistoryWithOverride : public ccf::NullTxHistory
public:
NullTxHistoryWithOverride(
kv::Store& store_, NodeId id_, crypto::KeyPairBase& kp_) :
kv::Store& store_, NodeId id_, crypto::KeyPair& kp_) :
ccf::NullTxHistory(store_, id_, kp_)
{}

Просмотреть файл

@ -22,7 +22,7 @@ namespace ccf
{
// Encrypt secrets with a shared secret derived from backup public
// key
crypto::KeyAesGcm backup_shared_secret(
auto backup_shared_secret = crypto::make_key_aes_gcm(
tls::KeyExchangeContext(encryption_key, backup_pubk)
.compute_shared_secret());
@ -30,7 +30,7 @@ namespace ccf
auto iv = crypto::create_entropy()->random(gcmcipher.hdr.get_iv().n);
std::copy(iv.begin(), iv.end(), gcmcipher.hdr.iv);
backup_shared_secret.encrypt(
backup_shared_secret->encrypt(
iv, plain, nullb, gcmcipher.cipher.data(), gcmcipher.hdr.tag);
return gcmcipher.serialise();
@ -113,11 +113,11 @@ namespace ccf
gcmcipher.deserialise(cipher);
std::vector<uint8_t> plain(gcmcipher.cipher.size());
crypto::KeyAesGcm primary_shared_key(
auto primary_shared_key = crypto::make_key_aes_gcm(
tls::KeyExchangeContext(encryption_key, primary_pubk)
.compute_shared_secret());
if (!primary_shared_key.decrypt(
if (!primary_shared_key->decrypt(
gcmcipher.hdr.get_iv(),
gcmcipher.hdr.tag,
gcmcipher.cipher,

Просмотреть файл

@ -53,7 +53,7 @@ namespace ccf
crypto::GcmCipher encrypted_ls(ledger_secret->raw_key.size());
crypto::KeyAesGcm(data).encrypt(
crypto::make_key_aes_gcm(data)->encrypt(
encrypted_ls.hdr.get_iv(), // iv is always 0 here as the share wrapping
// key is never re-used for encryption
ledger_secret->raw_key,
@ -73,7 +73,7 @@ namespace ccf
encrypted_ls.deserialise(wrapped_latest_ledger_secret);
std::vector<uint8_t> decrypted_ls(encrypted_ls.cipher.size());
if (!crypto::KeyAesGcm(data).decrypt(
if (!crypto::make_key_aes_gcm(data)->decrypt(
encrypted_ls.hdr.get_iv(),
encrypted_ls.hdr.tag,
encrypted_ls.cipher,

Просмотреть файл

@ -1,5 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the Apache 2.0 License.
#include "crypto/mbedtls/hash.h"
#include "kv/test/stub_consensus.h"
#include "node/history.h"
@ -58,7 +59,7 @@ static void hash_only(picobench::state& s)
(void)_;
auto data = txs[idx++];
crypto::Sha256Hash h;
crypto::Sha256Hash::mbedtls_sha256({data}, h.h.data());
crypto::mbedtls_sha256({data}, h.h.data());
do_not_optimize(h);
clobber_memory();
}

Просмотреть файл

@ -85,7 +85,7 @@ namespace timing
std::stringstream ss;
const auto now = Clock::now();
auto now_tt = Clock::to_time_t(now);
time_t now_tt = now.time_since_epoch().count();
tm now_tm;
::localtime_r(&now_tt, &now_tm);

Просмотреть файл

@ -2,7 +2,7 @@
// Licensed under the Apache 2.0 License.
#pragma once
#include "../crypto/mbedtls_wrappers.h"
#include "../crypto/mbedtls/mbedtls_wrappers.h"
#include "../crypto/pem.h"
#include "../ds/buffer.h"
@ -13,14 +13,14 @@ namespace tls
class CA
{
private:
mbedtls::X509Crt ca = nullptr;
mbedtls::X509Crl crl = nullptr;
crypto::mbedtls::X509Crt ca = nullptr;
crypto::mbedtls::X509Crl crl = nullptr;
public:
CA(CBuffer ca_ = nullb, CBuffer crl_ = nullb)
{
auto tmp_ca = mbedtls::make_unique<mbedtls::X509Crt>();
auto tmp_crl = mbedtls::make_unique<mbedtls::X509Crl>();
auto tmp_ca = crypto::mbedtls::make_unique<crypto::mbedtls::X509Crt>();
auto tmp_crl = crypto::mbedtls::make_unique<crypto::mbedtls::X509Crl>();
if (ca_.n > 0)
{

Просмотреть файл

@ -3,13 +3,15 @@
#pragma once
#include "ca.h"
#include "crypto/mbedtls_wrappers.h"
#include "crypto/mbedtls/mbedtls_wrappers.h"
#include "error_string.h"
#include <cstring>
#include <memory>
#include <optional>
using namespace crypto;
namespace tls
{
enum Auth

Просмотреть файл

@ -4,11 +4,13 @@
#include "cert.h"
#include "crypto/entropy.h"
#include "crypto/mbedtls_wrappers.h"
#include "crypto/mbedtls/mbedtls_wrappers.h"
#include "error_string.h"
#include <memory>
using namespace crypto;
namespace tls
{
class Context

Просмотреть файл

@ -3,7 +3,7 @@
#pragma once
#include "crypto/entropy.h"
#include "crypto/key_pair_mbedtls.h"
#include "crypto/mbedtls/key_pair.h"
#include "ds/logger.h"
#include "tls/error_string.h"
@ -16,7 +16,7 @@ namespace tls
class KeyExchangeContext
{
private:
mbedtls::ECDHContext ctx = nullptr;
crypto::mbedtls::ECDHContext ctx = nullptr;
std::vector<uint8_t> own_public;
crypto::EntropyPtr entropy;
@ -31,7 +31,8 @@ namespace tls
own_public(len_public),
entropy(crypto::create_entropy())
{
auto tmp_ctx = mbedtls::make_unique<mbedtls::ECDHContext>();
auto tmp_ctx =
crypto::mbedtls::make_unique<crypto::mbedtls::ECDHContext>();
size_t len;
int rc = mbedtls_ecp_group_load(&tmp_ctx->grp, domain_parameter);
@ -64,7 +65,8 @@ namespace tls
std::shared_ptr<crypto::PublicKey_mbedTLS> peer_pubk) :
entropy(crypto::create_entropy())
{
auto tmp_ctx = mbedtls::make_unique<mbedtls::ECDHContext>();
auto tmp_ctx =
crypto::mbedtls::make_unique<crypto::mbedtls::ECDHContext>();
int rc = mbedtls_ecdh_get_params(
tmp_ctx.get(),