[Fuzzer] Add fuzzer support for linux (#21996)
### Description Added some change in fuzzer project code to support linux also. How to test on linux: 1. Make sure you have installed clang/llvm. 2. run below command to build asan instrumented project: ``` CFLAGS="-g -fsanitize=address -shared-libasan -fprofile-instr-generate -fcoverage-mapping" CXXFLAGS="-g -shared-libasan -fsanitize=address -fprofile-instr-generate -fcoverage-mapping" CC=clang CXX=clang++ ./build.sh --update --build --config Debug --compile_no_warning_as_error --build_shared_lib --skip_submodule_sync --skip_tests --use_full_protobuf --parallel --fuzz_testing --build_dir build/ ``` 3. run fuzzer for some time, it will generate *.profraw file: ``` LLVM_PROFILE_FILE="%p.profraw" ./build/Debug/onnxruntime_security_fuzz /t /v onnxruntime/test/testdata/bart_tiny.onnx 1 m ``` 4. Get the cov by running below cmd: ``` llvm-profdata merge -sparse *.profraw -o default.profdata llvm-cov report ./build/Debug/onnxruntime_security_fuzz -instr-profile=default.profdata ``` <img width="1566" alt="Screenshot 2024-09-05 at 4 25 08 PM" src="https://github.com/user-attachments/assets/2aa0bb83-6634-4d33-b026-3535e97df431"> ### Motivation and Context 1. Currently fuzzer only supports windows and MSVC, we can't generate the code coverage using MSVC. With clang/llvm we can try and use clang instrumentation and llvm tools like llvm-cov. 2. In future we can add coverage guided fuzzer (libfuzzer) in same project. (Working on it)
This commit is contained in:
Родитель
f4d62eeb2e
Коммит
2dae8aaced
|
@ -413,8 +413,8 @@ endif(MSVC_Z7_OVERRIDE)
|
|||
set_msvc_c_cpp_compiler_warning_level(3)
|
||||
|
||||
# Fuzz test has only been tested with BUILD_SHARED_LIB option,
|
||||
# using the MSVC compiler and on windows OS.
|
||||
if (MSVC AND WIN32 AND onnxruntime_FUZZ_TEST AND onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_FULL_PROTOBUF)
|
||||
# using the MSVC compiler and on windows OS and clang/gcc compiler on Linux.
|
||||
if (onnxruntime_FUZZ_TEST AND onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_FULL_PROTOBUF)
|
||||
# Fuzz test library dependency, protobuf-mutator,
|
||||
# needs the onnx message to be compiled using "non-lite protobuf version"
|
||||
set(onnxruntime_FUZZ_ENABLED ON)
|
||||
|
|
|
@ -5,15 +5,15 @@
|
|||
# the fuzzing project
|
||||
if (onnxruntime_FUZZ_ENABLED)
|
||||
message(STATUS "Building dependency protobuf-mutator and libfuzzer")
|
||||
|
||||
|
||||
# set the options used to control the protobuf-mutator build
|
||||
set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB})
|
||||
set(LIB_PROTO_MUTATOR_TESTING OFF)
|
||||
|
||||
|
||||
# include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes
|
||||
# with google test
|
||||
add_subdirectory("external/libprotobuf-mutator/src")
|
||||
|
||||
|
||||
# add the appropriate include directory and compilation flags
|
||||
# needed by the protobuf-mutator target and the libfuzzer
|
||||
set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator")
|
||||
|
@ -21,45 +21,65 @@ if (onnxruntime_FUZZ_ENABLED)
|
|||
onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB})
|
||||
target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS})
|
||||
target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS})
|
||||
target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456")
|
||||
target_compile_options(protobuf-mutator-libfuzzer PRIVATE "/wd4146" "/wd4267")
|
||||
|
||||
# add Fuzzing Engine Build Configuration
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
|
||||
# MSVC-specific compiler options
|
||||
target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456")
|
||||
target_compile_options(protobuf-mutator-libfuzzer PRIVATE "/wd4146" "/wd4267")
|
||||
else()
|
||||
# Linux-specific compiler options
|
||||
target_compile_options(protobuf-mutator PRIVATE
|
||||
-Wno-shorten-64-to-32
|
||||
-Wno-conversion
|
||||
-Wno-sign-compare
|
||||
-Wno-unused-parameter
|
||||
-Wno-shadow
|
||||
-Wno-unused
|
||||
-fexceptions
|
||||
)
|
||||
target_compile_options(protobuf-mutator-libfuzzer PRIVATE
|
||||
-Wno-shorten-64-to-32
|
||||
-Wno-conversion
|
||||
-Wno-unused
|
||||
-fexceptions
|
||||
)
|
||||
endif()
|
||||
|
||||
# add Fuzzing Engine Build Configuration
|
||||
message(STATUS "Building Fuzzing engine")
|
||||
|
||||
|
||||
# set Fuzz root directory
|
||||
set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing)
|
||||
|
||||
# Security fuzzing engine src file reference
|
||||
set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp"
|
||||
"${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp"
|
||||
"${SEC_FUZZ_ROOT}/src/testlog.cpp"
|
||||
|
||||
# Security fuzzing engine src file reference
|
||||
set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp"
|
||||
"${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp"
|
||||
"${SEC_FUZZ_ROOT}/src/testlog.cpp"
|
||||
"${SEC_FUZZ_ROOT}/src/test.cpp")
|
||||
|
||||
|
||||
# compile the executables
|
||||
onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC})
|
||||
|
||||
|
||||
# compile with c++17
|
||||
target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17)
|
||||
|
||||
|
||||
# Security fuzzing engine header file reference
|
||||
onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime)
|
||||
|
||||
|
||||
# Assign all include to one variable
|
||||
set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include")
|
||||
set(INCLUDE_FILES ${SEC_FUZ_INC} "$<TARGET_PROPERTY:protobuf-mutator,INCLUDE_DIRECTORIES>")
|
||||
|
||||
|
||||
# add all these include directory to the Fuzzing engine
|
||||
target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES})
|
||||
|
||||
|
||||
# add link libraries the project
|
||||
target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB})
|
||||
|
||||
|
||||
# add the dependencies
|
||||
add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB})
|
||||
|
||||
|
||||
# copy the dlls to the execution directory
|
||||
add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> $<TARGET_FILE_DIR:onnxruntime_security_fuzz>
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:${PROTOBUF_LIB}> $<TARGET_FILE_DIR:onnxruntime_security_fuzz>)
|
||||
endif()
|
||||
endif()
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <random>
|
||||
#include <map>
|
||||
#include <chrono>
|
||||
#include <stdexcept>
|
||||
|
||||
// Default parameter will produce a shape with alpha = 0.5
|
||||
// and beta = 0.5.
|
||||
|
@ -81,7 +82,7 @@ class BetaDistribution {
|
|||
for (int i = 0; i < sample_size; i++) {
|
||||
calc_type sample = convert_to_fixed_range(gen);
|
||||
calc_type highest_probability_temp = highest_probability;
|
||||
highest_probability = std::max({highest_probability_temp, distribution(sample)});
|
||||
highest_probability = std::max(highest_probability_temp, distribution(sample));
|
||||
|
||||
// A new sample number with a higher probability has been found
|
||||
//
|
||||
|
@ -133,7 +134,7 @@ class BetaDistribution {
|
|||
|
||||
return static_cast<double>(result);
|
||||
} else {
|
||||
throw std::exception("Non special gamma values not yet Implemeted");
|
||||
throw std::runtime_error("Non special gamma values not yet Implemeted");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -154,7 +155,7 @@ class BetaDistribution {
|
|||
calc_type term2{pow((randVar - min()) / range, m_alpha - 1)};
|
||||
calc_type term3{pow((max() - randVar) / range, m_beta - 1)};
|
||||
|
||||
return {term * term1 * term2 * term3};
|
||||
return (term * term1 * term2 * term3);
|
||||
}
|
||||
|
||||
// Used to convert the number that generator produces
|
||||
|
@ -170,9 +171,9 @@ class BetaDistribution {
|
|||
|
||||
// Convert the number to the range [beginRange, endRange]
|
||||
//
|
||||
calc_type range{std::numeric_limits<generator::result_type>::max() - std::numeric_limits<generator::result_type>::lowest()};
|
||||
calc_type range{std::numeric_limits<typename generator::result_type>::max() - std::numeric_limits<typename generator::result_type>::lowest()};
|
||||
|
||||
calc_type delta{x - std::numeric_limits<generator::result_type>::lowest()};
|
||||
calc_type delta{x - std::numeric_limits<typename generator::result_type>::lowest()};
|
||||
calc_type ratio{delta / range};
|
||||
|
||||
calc_type new_range{static_cast<calc_type>(max()) - static_cast<calc_type>(min())};
|
||||
|
@ -183,7 +184,7 @@ class BetaDistribution {
|
|||
}
|
||||
|
||||
calc_type res{(ratio * new_range) + min()};
|
||||
return {res};
|
||||
return (res);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -235,4 +236,4 @@ GenerateRandomData(ONNX_ELEMENT_VALUE_TYPE initialValue, size_t numElementsToGen
|
|||
|
||||
return randomDataBucket;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -51,6 +51,13 @@ class OnnxPrediction {
|
|||
//
|
||||
OnnxPrediction(const std::vector<char>& model_data, Ort::Env& env);
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
// Helper function to convert std::wstring to std::string
|
||||
std::string wstring_to_string(const std::wstring& wstr) {
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
return converter.to_bytes(wstr);
|
||||
}
|
||||
#endif
|
||||
// Data to run prediction on
|
||||
//
|
||||
template <typename T>
|
||||
|
@ -80,7 +87,7 @@ class OnnxPrediction {
|
|||
input_value = Ort::Value::CreateTensor(alloc.GetInfo(),
|
||||
input_data[curr_input_index].get(), data_size_in_bytes, shapeInfo.data(), shapeInfo.size(), elem_type);
|
||||
} else {
|
||||
throw std::exception("only floats are implemented");
|
||||
throw std::runtime_error("only floats are implemented");
|
||||
}
|
||||
|
||||
// Insert data into the next input type
|
||||
|
|
|
@ -9,9 +9,16 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
#include <chrono>
|
||||
#include <unordered_map>
|
||||
#include <cwchar>
|
||||
#include <locale>
|
||||
#include <codecvt>
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
#include <windows.h> // For MultiByteToWideChar and CP_UTF8
|
||||
#endif
|
||||
|
||||
// Forward declarartions
|
||||
// Forward declarations
|
||||
//
|
||||
class OnnxPrediction;
|
||||
std::wostream& operator<<(std::wostream& out, OnnxPrediction& pred);
|
||||
|
@ -25,49 +32,37 @@ using LogEndln = void*;
|
|||
//
|
||||
class TestLog {
|
||||
public:
|
||||
// Flush out all output streams
|
||||
// used by logger
|
||||
//
|
||||
// Flush out all output streams used by logger
|
||||
void flush();
|
||||
|
||||
// Print out output of a prediction.
|
||||
//
|
||||
TestLog& operator<<(OnnxPrediction& pred);
|
||||
|
||||
// Generic log output that appends timing
|
||||
// information.
|
||||
//
|
||||
// Generic log output that appends timing information.
|
||||
template <typename T>
|
||||
TestLog& operator<<(const T& info);
|
||||
|
||||
// Disable logging
|
||||
//
|
||||
inline void disable();
|
||||
|
||||
// Enable logging
|
||||
//
|
||||
inline void enable();
|
||||
|
||||
// Ends the current line so that the
|
||||
// next line can start with time information.
|
||||
//
|
||||
// Ends the current line so that the next line
|
||||
// can start with time information.
|
||||
void operator<<(LogEndln info);
|
||||
|
||||
// Minimize log
|
||||
//
|
||||
inline void minLog();
|
||||
|
||||
// Maintain ring buffer
|
||||
// Note:
|
||||
// This is only used for minimun logging
|
||||
// if normal logging is being used this map
|
||||
// must be constrained.
|
||||
//
|
||||
// Note: This is only used for minimum logging;
|
||||
// if normal logging is being used, this map must
|
||||
// be constrained.
|
||||
void insert(std::wstring data);
|
||||
|
||||
// Singleton constructor only one object exists
|
||||
// Hence this resource is not thread-safe
|
||||
//
|
||||
// Singleton constructor - only one object exists
|
||||
// Note: this resource is not thread-safe
|
||||
TestLog();
|
||||
|
||||
private:
|
||||
|
@ -79,16 +74,32 @@ class TestLog {
|
|||
std::map<size_t, std::pair<size_t, std::wstring>> ring_buffer;
|
||||
static constexpr int logFileLineWidth{128};
|
||||
static constexpr int logFileLen{1000};
|
||||
std::wstring string_to_wstring(const std::string& str) {
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
// Windows implementation
|
||||
int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
|
||||
std::wstring wstrTo(size_needed, 0);
|
||||
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstrTo[0], size_needed);
|
||||
return wstrTo;
|
||||
#else
|
||||
// Linux implementation
|
||||
std::mbstate_t state = std::mbstate_t();
|
||||
const char* c_str = str.c_str();
|
||||
size_t len = std::mbsrtowcs(nullptr, &c_str, 0, &state);
|
||||
std::wstring wstr(len, L'\0');
|
||||
std::mbsrtowcs(&wstr[0], &c_str, len, &state);
|
||||
return wstr;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Reference to initialized logger
|
||||
// Not this resource is not thread safe and only one
|
||||
// exists for the entire process.
|
||||
// Note: this resource is not thread-safe and only
|
||||
// one exists for the entire process.
|
||||
//
|
||||
extern TestLog testLog;
|
||||
|
||||
// Object used to mark end of line format
|
||||
// for testLog
|
||||
// Object used to mark end of line format for testLog
|
||||
//
|
||||
static constexpr LogEndln endl = nullptr;
|
||||
|
||||
|
@ -100,13 +111,11 @@ std::wstring towstr(const char* pStr);
|
|||
// Inline Functions
|
||||
|
||||
// Minimize log
|
||||
//
|
||||
inline void Logger::TestLog::minLog() {
|
||||
min_log = true;
|
||||
}
|
||||
|
||||
// Enable logging
|
||||
//
|
||||
inline void Logger::TestLog::enable() {
|
||||
logging_on = true;
|
||||
}
|
||||
|
@ -115,49 +124,60 @@ inline void Logger::TestLog::disable() {
|
|||
logging_on = false;
|
||||
}
|
||||
|
||||
// Template functions
|
||||
namespace Logger {
|
||||
|
||||
// Generic log output that appends timing
|
||||
// information.
|
||||
//
|
||||
template <typename T>
|
||||
Logger::TestLog& Logger::TestLog::operator<<(const T& info) {
|
||||
TestLog& TestLog::operator<<(const T& info) {
|
||||
if (!logging_on) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Get the current time
|
||||
std::chrono::system_clock::time_point today{std::chrono::system_clock::now()};
|
||||
time_t tt{std::chrono::system_clock::to_time_t(today)};
|
||||
constexpr int length_time_str = 28;
|
||||
char buf[length_time_str];
|
||||
std::time_t tt{std::chrono::system_clock::to_time_t(today)};
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
std::tm tm;
|
||||
localtime_s(&tm, &tt); // Thread-safe on Windows
|
||||
#else
|
||||
std::tm tm = *std::localtime(&tt); // Thread-safe on Linux
|
||||
#endif
|
||||
|
||||
if (0 == ctime_s(buf, sizeof(buf), &tt)) {
|
||||
wchar_t wbuf[length_time_str];
|
||||
char const* ptr = buf;
|
||||
std::mbstate_t ps;
|
||||
size_t retVal;
|
||||
mbsrtowcs_s(&retVal, wbuf, length_time_str, &ptr, length_time_str, &ps);
|
||||
std::wstring_view temp(wbuf, retVal - 2);
|
||||
std::wstringstream stream;
|
||||
if (print_time_info) {
|
||||
stream << L"[" << temp << L"]" << L"\t";
|
||||
}
|
||||
// Buffer for formatted time
|
||||
char buf[100];
|
||||
std::strftime(buf, sizeof(buf), "%c", &tm);
|
||||
|
||||
if constexpr (std::is_same<T, std::string>()) {
|
||||
stream << towstr(info.data());
|
||||
} else {
|
||||
stream << info;
|
||||
}
|
||||
// Convert multi-byte to wide character string
|
||||
std::wstring wstr = string_to_wstring(buf);
|
||||
|
||||
if (min_log) {
|
||||
insert(stream.str());
|
||||
} else {
|
||||
std::wcout << stream.str();
|
||||
}
|
||||
print_time_info = false;
|
||||
// Use std::wstring_view to avoid copying the wide string
|
||||
std::wstring_view temp(wstr);
|
||||
|
||||
// Create a wstringstream to format the output
|
||||
std::wstringstream stream;
|
||||
if (print_time_info) {
|
||||
stream << L"[" << temp << L"]" << L"\t";
|
||||
}
|
||||
|
||||
// Append info to the stream
|
||||
if constexpr (std::is_same<T, std::string>()) {
|
||||
// Convert std::string to std::wstring and append
|
||||
std::wstring winfo = string_to_wstring(info);
|
||||
stream << winfo;
|
||||
} else {
|
||||
stream << info;
|
||||
}
|
||||
|
||||
// Output the formatted log
|
||||
if (min_log) {
|
||||
insert(stream.str());
|
||||
} else {
|
||||
std::wcout << stream.str();
|
||||
}
|
||||
|
||||
print_time_info = false;
|
||||
return *this;
|
||||
}
|
||||
} // namespace Logger
|
||||
|
||||
namespace Logger {
|
||||
template <typename CharT>
|
||||
|
@ -168,11 +188,10 @@ class cache_streambuf : public std::basic_streambuf<CharT> {
|
|||
using int_type = typename Base::int_type;
|
||||
|
||||
// Get the total number of unique errors found
|
||||
//
|
||||
inline size_t get_unique_errors();
|
||||
|
||||
protected:
|
||||
virtual int_type overflow(int_type ch = Traits::eof());
|
||||
virtual int_type overflow(int_type ch = Base::traits_type::eof());
|
||||
|
||||
private:
|
||||
std::basic_stringstream<char_type> buffer;
|
||||
|
@ -189,10 +208,8 @@ inline size_t cache_streambuf<CharT>::get_unique_errors() {
|
|||
|
||||
template <typename CharT>
|
||||
auto cache_streambuf<CharT>::overflow(int_type ch) -> int_type {
|
||||
// if not end of file
|
||||
//
|
||||
if (!Base::traits_type::eq_int_type(ch,
|
||||
Base::traits_type::eof())) {
|
||||
// If not end of file
|
||||
if (!Base::traits_type::eq_int_type(ch, Base::traits_type::eof())) {
|
||||
if (ch > 255) {
|
||||
if constexpr (std::is_same_v<char_type, char>) {
|
||||
std::cout << "Yikes";
|
||||
|
@ -217,4 +234,4 @@ auto cache_streambuf<CharT>::overflow(int_type ch) -> int_type {
|
|||
}
|
||||
} // namespace Logger
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
// Uses the onnxruntime to load the model
|
||||
// into a session.
|
||||
//
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
|
||||
: raw_model{nullptr},
|
||||
ptr_session{std::make_unique<Ort::Session>(env, onnx_model_file.c_str(), empty_session_option)},
|
||||
|
@ -16,6 +17,16 @@ OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
|
|||
output_names(ptr_session->GetOutputCount()) {
|
||||
init();
|
||||
}
|
||||
#else
|
||||
OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
|
||||
: raw_model{nullptr},
|
||||
// Convert std::wstring to std::string
|
||||
ptr_session{std::make_unique<Ort::Session>(env, wstring_to_string(onnx_model_file).c_str(), empty_session_option)},
|
||||
input_names(ptr_session->GetInputCount()),
|
||||
output_names(ptr_session->GetOutputCount()) {
|
||||
init();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Uses the onnx to seri
|
||||
//
|
||||
|
@ -60,7 +71,7 @@ std::wostream& operator<<(std::wostream& out, OnnxPrediction& pred) {
|
|||
auto pretty_print = [&out](auto ptr, Ort::Value& val) {
|
||||
out << L"[";
|
||||
std::wstring msg = L"";
|
||||
for (int i = 0; i < val.GetTensorTypeAndShapeInfo().GetElementCount(); i++) {
|
||||
for (int i = 0; i < static_cast<int>(val.GetTensorTypeAndShapeInfo().GetElementCount()); i++) {
|
||||
out << msg << ptr[i];
|
||||
msg = L", ";
|
||||
}
|
||||
|
@ -84,28 +95,15 @@ void GenerateDataForInputTypeTensor(OnnxPrediction& predict,
|
|||
(void)input_name;
|
||||
(void)input_index;
|
||||
|
||||
auto pretty_print = [&input_name](auto raw_data) {
|
||||
Logger::testLog << input_name << L" = ";
|
||||
Logger::testLog << L"[";
|
||||
std::wstring msg = L"";
|
||||
for (int i = 0; i < raw_data.size(); i++) {
|
||||
Logger::testLog << msg << raw_data[i];
|
||||
msg = L", ";
|
||||
}
|
||||
Logger::testLog << L"]\n";
|
||||
};
|
||||
|
||||
if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
|
||||
auto raw_data = GenerateRandomData(0.0f, elem_count, seed);
|
||||
// pretty_print(raw_data);
|
||||
predict << std::move(raw_data);
|
||||
} else if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
|
||||
int32_t initial = 0;
|
||||
auto raw_data = GenerateRandomData(initial, elem_count, seed);
|
||||
// pretty_print(raw_data);
|
||||
predict << std::move(raw_data);
|
||||
} else {
|
||||
throw std::exception("only floats are implemented");
|
||||
throw std::runtime_error("only floats are implemented");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,7 +141,7 @@ void OnnxPrediction::PrintOutputValues() {
|
|||
void OnnxPrediction::init() {
|
||||
// Initialize model input names
|
||||
//
|
||||
for (int i = 0; i < ptr_session->GetInputCount(); i++) {
|
||||
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
|
||||
// TODO Use push_back on input_names instead of assignment
|
||||
input_names_ptrs.push_back(ptr_session->GetInputNameAllocated(i, alloc));
|
||||
input_names[i] = input_names_ptrs.back().get();
|
||||
|
@ -152,7 +150,7 @@ void OnnxPrediction::init() {
|
|||
|
||||
// Initialize model output names
|
||||
//
|
||||
for (int i = 0; i < ptr_session->GetOutputCount(); i++) {
|
||||
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
|
||||
// TODO Use push_back on output_names instead of assignment
|
||||
output_names_ptrs.push_back(ptr_session->GetOutputNameAllocated(i, alloc));
|
||||
output_names[i] = output_names_ptrs.back().get();
|
||||
|
@ -169,8 +167,7 @@ Ort::AllocatorWithDefaultOptions& OnnxPrediction::GetAllocator() {
|
|||
void OnnxPrediction::SetupInput(
|
||||
InputGeneratorFunctionType GenerateData,
|
||||
size_t seed) {
|
||||
Logger::testLog << L"input data:\n";
|
||||
for (int i = 0; i < ptr_session->GetInputCount(); i++) {
|
||||
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
|
||||
auto inputType = ptr_session->GetInputTypeInfo(i);
|
||||
|
||||
if (inputType.GetONNXType() == ONNX_TYPE_TENSOR) {
|
||||
|
@ -189,4 +186,4 @@ void OnnxPrediction::SetupInput(
|
|||
}
|
||||
}
|
||||
Logger::testLog << Logger::endl;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
#include "testlog.h"
|
||||
#include "OnnxPrediction.h"
|
||||
#include "onnxruntime_session_options_config_keys.h"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
using user_options = struct
|
||||
|
@ -16,6 +15,13 @@ using user_options = struct
|
|||
bool is_ort;
|
||||
};
|
||||
|
||||
#if !defined(_WIN32) || !defined(_WIN64)
|
||||
std::string wstring_to_string(const std::wstring& wstr) {
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
return converter.to_bytes(wstr);
|
||||
}
|
||||
#endif
|
||||
|
||||
void predict(onnx::ModelProto& model_proto, unsigned int seed, Ort::Env& env) {
|
||||
// Create object for prediction
|
||||
//
|
||||
|
@ -61,10 +67,9 @@ void mutateModelTest(onnx::ModelProto& model_proto,
|
|||
|
||||
// Mutate model
|
||||
//
|
||||
Logger::testLog << L"Model Successfully Initialized" << Logger::endl;
|
||||
Logger::testLog << "Model Successfully Initialized" << Logger::endl;
|
||||
mutator.Seed(seed);
|
||||
mutator.Mutate(&model_proto, model_proto.ByteSizeLong());
|
||||
|
||||
if (opt.write_model) {
|
||||
// Create file to store model
|
||||
//
|
||||
|
@ -73,8 +78,11 @@ void mutateModelTest(onnx::ModelProto& model_proto,
|
|||
auto mutateModelFileName = mutateModelName.str();
|
||||
|
||||
// Log the model to a file
|
||||
//
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
std::ofstream outStream(mutateModelFileName);
|
||||
#else
|
||||
std::ofstream outStream(wstring_to_string(mutateModelFileName));
|
||||
#endif
|
||||
model_proto.SerializeToOstream(&outStream);
|
||||
Logger::testLog << "Mutated Model Written to file: " << mutateModelFileName << Logger::endl;
|
||||
|
||||
|
@ -153,7 +161,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
|
|||
std::stringstream parser{argv[3]};
|
||||
parser >> opt.seed;
|
||||
if (parser.bad()) {
|
||||
throw std::exception("Could not parse seed from command line");
|
||||
throw std::runtime_error("Could not parse seed from command line");
|
||||
}
|
||||
|
||||
std::wcout << L"seed: " << opt.seed << L"\n";
|
||||
|
@ -169,7 +177,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
|
|||
parser >> desired_scale;
|
||||
|
||||
if (parser.bad()) {
|
||||
throw std::exception("Could not parse the time scale from the command line");
|
||||
throw std::runtime_error("Could not parse the time scale from the command line");
|
||||
}
|
||||
|
||||
opt.scale = static_cast<timeScale>(std::tolower(desired_scale));
|
||||
|
@ -179,13 +187,13 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
|
|||
case timeScale::Sec:
|
||||
break;
|
||||
default:
|
||||
throw std::exception("Could not parse the time scale from the command line");
|
||||
throw std::runtime_error("Could not parse the time scale from the command line");
|
||||
}
|
||||
|
||||
parser << argv[index--];
|
||||
parser >> opt.test_time_out;
|
||||
if (parser.bad()) {
|
||||
throw std::exception("Could not parse the time value from the command line");
|
||||
throw std::runtime_error("Could not parse the time value from the command line");
|
||||
}
|
||||
|
||||
Logger::testLog << L"Running Test for: " << opt.test_time_out << desired_scale << Logger::endl;
|
||||
|
@ -193,7 +201,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
|
|||
Logger::testLog << L"Model file: " << opt.model_file_name << Logger::endl;
|
||||
std::filesystem::path model_file_namePath{opt.model_file_name};
|
||||
if (!std::filesystem::exists(model_file_namePath)) {
|
||||
throw std::exception("Cannot find model file");
|
||||
throw std::runtime_error("Cannot find model file");
|
||||
}
|
||||
|
||||
// process options
|
||||
|
@ -251,9 +259,9 @@ struct RunStats {
|
|||
static void fuzz_handle_exception(struct RunStats& run_stats) {
|
||||
try {
|
||||
throw;
|
||||
} catch (const Ort::Exception& ortException) {
|
||||
} catch (const Ort::Exception& ortexception) {
|
||||
run_stats.num_ort_exception++;
|
||||
Logger::testLog << L"onnx runtime exception: " << ortException.what() << Logger::endl;
|
||||
Logger::testLog << L"onnx runtime exception: " << ortexception.what() << Logger::endl;
|
||||
Logger::testLog << "Failed Test iteration: " << run_stats.iteration++ << Logger::endl;
|
||||
} catch (const std::exception& e) {
|
||||
run_stats.num_std_exception++;
|
||||
|
@ -299,8 +307,11 @@ int main(int argc, char* argv[]) {
|
|||
std::wstring model_file{model_file_name};
|
||||
|
||||
// Create a stream to hold the model
|
||||
//
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
std::ifstream modelStream{model_file, std::ios::in | std::ios::binary};
|
||||
#else
|
||||
std::ifstream modelStream(wstring_to_string(model_file), std::ios::in | std::ios::binary);
|
||||
#endif
|
||||
if (opt.user_opt.is_ort == false) {
|
||||
// Create an onnx protobuf object
|
||||
//
|
||||
|
@ -345,7 +356,7 @@ int main(int argc, char* argv[]) {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
throw std::exception("Unable to initialize the Onnx model in memory");
|
||||
throw std::runtime_error("Unable to initialize the Onnx model in memory");
|
||||
}
|
||||
} else {
|
||||
std::wstring ort_model_file = model_file;
|
||||
|
@ -353,16 +364,25 @@ int main(int argc, char* argv[]) {
|
|||
ort_model_file = model_file + L".ort";
|
||||
Ort::SessionOptions so;
|
||||
so.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
|
||||
so.SetOptimizedModelFilePath(ort_model_file.c_str());
|
||||
so.AddConfigEntry(kOrtSessionOptionsConfigSaveModelFormat, "ORT");
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
so.SetOptimizedModelFilePath(ort_model_file.c_str());
|
||||
Ort::Session session(env, model_file.c_str(), so);
|
||||
#else
|
||||
so.SetOptimizedModelFilePath(wstring_to_string(ort_model_file).c_str());
|
||||
Ort::Session session(env, wstring_to_string(model_file).c_str(), so);
|
||||
#endif
|
||||
} else if (model_file.substr(model_file.find_last_of(L".") + 1) != L"ort") {
|
||||
Logger::testLog << L"Input file name extension is not 'onnx' or 'ort' " << Logger::endl;
|
||||
return 1;
|
||||
}
|
||||
size_t num_bytes = std::filesystem::file_size(ort_model_file);
|
||||
std::vector<char> model_data(num_bytes);
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
std::ifstream ortModelStream(ort_model_file, std::ifstream::in | std::ifstream::binary);
|
||||
#else
|
||||
std::ifstream ortModelStream(wstring_to_string(ort_model_file), std::ifstream::in | std::ifstream::binary);
|
||||
#endif
|
||||
ortModelStream.read(model_data.data(), num_bytes);
|
||||
ortModelStream.close();
|
||||
// Currently mutations are generated by using XOR of a byte with the preceding byte at a time.
|
||||
|
@ -389,7 +409,8 @@ int main(int argc, char* argv[]) {
|
|||
if (user_opt.stress) {
|
||||
Logger::testLog.enable();
|
||||
}
|
||||
size_t toal_num_exception = run_stats.num_unknown_exception + run_stats.num_std_exception + run_stats.num_ort_exception;
|
||||
size_t toal_num_exception =
|
||||
run_stats.num_unknown_exception + run_stats.num_std_exception + run_stats.num_ort_exception;
|
||||
Logger::testLog << L"Total number of exceptions: " << toal_num_exception << Logger::endl;
|
||||
Logger::testLog << L"Number of Unknown exceptions: " << run_stats.num_unknown_exception << Logger::endl;
|
||||
Logger::testLog << L"Number of ort exceptions: " << run_stats.num_ort_exception << Logger::endl;
|
||||
|
|
|
@ -112,7 +112,7 @@ TestLog::TestLog()
|
|||
if (!std::filesystem::exists(mutateModelDir)) {
|
||||
std::filesystem::create_directory(mutateModelDir);
|
||||
}
|
||||
logFile.open(L"out/log", std::ios::ate);
|
||||
logFile.open("out/log", std::ios::ate);
|
||||
} else {
|
||||
throw std::runtime_error("TestLog has already been initialized. Call GetTestLog() to use it");
|
||||
}
|
||||
|
@ -121,18 +121,33 @@ TestLog::TestLog()
|
|||
// Helper function to convert string to wstring
|
||||
//
|
||||
std::wstring towstr(const char* pStr) {
|
||||
std::mbstate_t ps;
|
||||
size_t retVal;
|
||||
std::mbstate_t ps = std::mbstate_t();
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
size_t length_str = strnlen(pStr, onnxruntime::kMaxStrLen);
|
||||
size_t retVal;
|
||||
// On Windows, use mbsrtowcs_s which is safe and provides size checking
|
||||
mbsrtowcs_s(&retVal, nullptr, 0, &pStr, length_str, &ps);
|
||||
retVal += 1;
|
||||
auto ptr = std::make_unique<wchar_t[]>(retVal);
|
||||
if (ptr == nullptr) {
|
||||
std::stringstream str;
|
||||
str << "Failed to allocate memory: " << __func__ << __LINE__ << "\n";
|
||||
throw std::exception{str.str().data()};
|
||||
str << "Failed to allocate memory: " << __func__ << " " << __LINE__ << "\n";
|
||||
throw std::runtime_error{str.str().data()};
|
||||
}
|
||||
mbsrtowcs_s(&retVal, ptr.get(), retVal, &pStr, length_str, &ps);
|
||||
#else
|
||||
// On Linux, use mbsrtowcs which is part of the standard library
|
||||
size_t retVal = std::mbsrtowcs(nullptr, &pStr, 0, &ps);
|
||||
retVal += 1;
|
||||
auto ptr = std::make_unique<wchar_t[]>(retVal);
|
||||
if (ptr == nullptr) {
|
||||
std::stringstream str;
|
||||
str << "Failed to allocate memory: " << __func__ << " " << __LINE__ << "\n";
|
||||
throw std::runtime_error{str.str().data()};
|
||||
}
|
||||
std::mbsrtowcs(ptr.get(), &pStr, retVal, &ps);
|
||||
#endif
|
||||
|
||||
return std::wstring{ptr.get()};
|
||||
}
|
||||
} // namespace Logger
|
||||
|
|
Загрузка…
Ссылка в новой задаче