[Fuzzer] Add fuzzer support for linux (#21996)

### Description
Added some changes to the fuzzer project code to also support Linux.

How to test on linux:
1. Make sure you have installed clang/llvm.
2. Run the command below to build the ASan-instrumented project:
```
CFLAGS="-g -fsanitize=address -shared-libasan -fprofile-instr-generate -fcoverage-mapping" CXXFLAGS="-g -shared-libasan -fsanitize=address -fprofile-instr-generate -fcoverage-mapping" CC=clang CXX=clang++ ./build.sh --update --build --config Debug --compile_no_warning_as_error --build_shared_lib --skip_submodule_sync --skip_tests --use_full_protobuf  --parallel --fuzz_testing --build_dir build/
```

3. Run the fuzzer for some time; it will generate a *.profraw file:
```
LLVM_PROFILE_FILE="%p.profraw" ./build/Debug/onnxruntime_security_fuzz /t /v onnxruntime/test/testdata/bart_tiny.onnx 1 m
```
4. Get the coverage report by running the commands below:
```
llvm-profdata merge -sparse *.profraw -o default.profdata
llvm-cov report ./build/Debug/onnxruntime_security_fuzz  -instr-profile=default.profdata
```

<img width="1566" alt="Screenshot 2024-09-05 at 4 25 08 PM"
src="https://github.com/user-attachments/assets/2aa0bb83-6634-4d33-b026-3535e97df431">



### Motivation and Context
1. Currently the fuzzer only supports Windows and MSVC, and we can't
generate code coverage using MSVC. With clang/LLVM we can use clang
instrumentation and LLVM tools such as llvm-cov.
2. In the future we can add a coverage-guided fuzzer (libFuzzer) to the
same project. (Working on it.)
This commit is contained in:
0xdr3dd 2024-09-06 00:22:15 +05:30 коммит произвёл GitHub
Родитель f4d62eeb2e
Коммит 2dae8aaced
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
8 изменённых файлов: 215 добавлений и 137 удалений

Просмотреть файл

@ -413,8 +413,8 @@ endif(MSVC_Z7_OVERRIDE)
set_msvc_c_cpp_compiler_warning_level(3)
# Fuzz test has only been tested with BUILD_SHARED_LIB option,
# using the MSVC compiler and on windows OS.
if (MSVC AND WIN32 AND onnxruntime_FUZZ_TEST AND onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_FULL_PROTOBUF)
# using the MSVC compiler and on windows OS and clang/gcc compiler on Linux.
if (onnxruntime_FUZZ_TEST AND onnxruntime_BUILD_SHARED_LIB AND onnxruntime_USE_FULL_PROTOBUF)
# Fuzz test library dependency, protobuf-mutator,
# needs the onnx message to be compiled using "non-lite protobuf version"
set(onnxruntime_FUZZ_ENABLED ON)

Просмотреть файл

@ -5,15 +5,15 @@
# the fuzzing project
if (onnxruntime_FUZZ_ENABLED)
message(STATUS "Building dependency protobuf-mutator and libfuzzer")
# set the options used to control the protobuf-mutator build
set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB})
set(LIB_PROTO_MUTATOR_TESTING OFF)
# include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes
# with google test
add_subdirectory("external/libprotobuf-mutator/src")
# add the appropriate include directory and compilation flags
# needed by the protobuf-mutator target and the libfuzzer
set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator")
@ -21,45 +21,65 @@ if (onnxruntime_FUZZ_ENABLED)
onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB})
target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS})
target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS})
target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456")
target_compile_options(protobuf-mutator-libfuzzer PRIVATE "/wd4146" "/wd4267")
# add Fuzzing Engine Build Configuration
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
# MSVC-specific compiler options
target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456")
target_compile_options(protobuf-mutator-libfuzzer PRIVATE "/wd4146" "/wd4267")
else()
# Linux-specific compiler options
target_compile_options(protobuf-mutator PRIVATE
-Wno-shorten-64-to-32
-Wno-conversion
-Wno-sign-compare
-Wno-unused-parameter
-Wno-shadow
-Wno-unused
-fexceptions
)
target_compile_options(protobuf-mutator-libfuzzer PRIVATE
-Wno-shorten-64-to-32
-Wno-conversion
-Wno-unused
-fexceptions
)
endif()
# add Fuzzing Engine Build Configuration
message(STATUS "Building Fuzzing engine")
# set Fuzz root directory
set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing)
# Security fuzzing engine src file reference
set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp"
"${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp"
"${SEC_FUZZ_ROOT}/src/testlog.cpp"
# Security fuzzing engine src file reference
set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp"
"${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp"
"${SEC_FUZZ_ROOT}/src/testlog.cpp"
"${SEC_FUZZ_ROOT}/src/test.cpp")
# compile the executables
onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC})
# compile with c++17
target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17)
# Security fuzzing engine header file reference
onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime)
# Assign all include to one variable
set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include")
set(INCLUDE_FILES ${SEC_FUZ_INC} "$<TARGET_PROPERTY:protobuf-mutator,INCLUDE_DIRECTORIES>")
# add all these include directory to the Fuzzing engine
target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES})
# add link libraries the project
target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB})
# add the dependencies
add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB})
# copy the dlls to the execution directory
add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> $<TARGET_FILE_DIR:onnxruntime_security_fuzz>
COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:${PROTOBUF_LIB}> $<TARGET_FILE_DIR:onnxruntime_security_fuzz>)
endif()
endif()

Просмотреть файл

@ -6,6 +6,7 @@
#include <random>
#include <map>
#include <chrono>
#include <stdexcept>
// Default parameter will produce a shape with alpha = 0.5
// and beta = 0.5.
@ -81,7 +82,7 @@ class BetaDistribution {
for (int i = 0; i < sample_size; i++) {
calc_type sample = convert_to_fixed_range(gen);
calc_type highest_probability_temp = highest_probability;
highest_probability = std::max({highest_probability_temp, distribution(sample)});
highest_probability = std::max(highest_probability_temp, distribution(sample));
// A new sample number with a higher probability has been found
//
@ -133,7 +134,7 @@ class BetaDistribution {
return static_cast<double>(result);
} else {
throw std::exception("Non special gamma values not yet Implemeted");
throw std::runtime_error("Non special gamma values not yet Implemeted");
}
}
@ -154,7 +155,7 @@ class BetaDistribution {
calc_type term2{pow((randVar - min()) / range, m_alpha - 1)};
calc_type term3{pow((max() - randVar) / range, m_beta - 1)};
return {term * term1 * term2 * term3};
return (term * term1 * term2 * term3);
}
// Used to convert the number that generator produces
@ -170,9 +171,9 @@ class BetaDistribution {
// Convert the number to the range [beginRange, endRange]
//
calc_type range{std::numeric_limits<generator::result_type>::max() - std::numeric_limits<generator::result_type>::lowest()};
calc_type range{std::numeric_limits<typename generator::result_type>::max() - std::numeric_limits<typename generator::result_type>::lowest()};
calc_type delta{x - std::numeric_limits<generator::result_type>::lowest()};
calc_type delta{x - std::numeric_limits<typename generator::result_type>::lowest()};
calc_type ratio{delta / range};
calc_type new_range{static_cast<calc_type>(max()) - static_cast<calc_type>(min())};
@ -183,7 +184,7 @@ class BetaDistribution {
}
calc_type res{(ratio * new_range) + min()};
return {res};
return (res);
}
private:
@ -235,4 +236,4 @@ GenerateRandomData(ONNX_ELEMENT_VALUE_TYPE initialValue, size_t numElementsToGen
return randomDataBucket;
}
#endif
#endif

Просмотреть файл

@ -51,6 +51,13 @@ class OnnxPrediction {
//
OnnxPrediction(const std::vector<char>& model_data, Ort::Env& env);
#if !defined(_WIN32) && !defined(_WIN64)
// Helper function to convert std::wstring to std::string
std::string wstring_to_string(const std::wstring& wstr) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(wstr);
}
#endif
// Data to run prediction on
//
template <typename T>
@ -80,7 +87,7 @@ class OnnxPrediction {
input_value = Ort::Value::CreateTensor(alloc.GetInfo(),
input_data[curr_input_index].get(), data_size_in_bytes, shapeInfo.data(), shapeInfo.size(), elem_type);
} else {
throw std::exception("only floats are implemented");
throw std::runtime_error("only floats are implemented");
}
// Insert data into the next input type

Просмотреть файл

@ -9,9 +9,16 @@
#include <vector>
#include <map>
#include <fstream>
#include <chrono>
#include <unordered_map>
#include <cwchar>
#include <locale>
#include <codecvt>
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h> // For MultiByteToWideChar and CP_UTF8
#endif
// Forward declarartions
// Forward declarations
//
class OnnxPrediction;
std::wostream& operator<<(std::wostream& out, OnnxPrediction& pred);
@ -25,49 +32,37 @@ using LogEndln = void*;
//
class TestLog {
public:
// Flush out all output streams
// used by logger
//
// Flush out all output streams used by logger
void flush();
// Print out output of a prediction.
//
TestLog& operator<<(OnnxPrediction& pred);
// Generic log output that appends timing
// information.
//
// Generic log output that appends timing information.
template <typename T>
TestLog& operator<<(const T& info);
// Disable logging
//
inline void disable();
// Enable logging
//
inline void enable();
// Ends the current line so that the
// next line can start with time information.
//
// Ends the current line so that the next line
// can start with time information.
void operator<<(LogEndln info);
// Minimize log
//
inline void minLog();
// Maintain ring buffer
// Note:
// This is only used for minimun logging
// if normal logging is being used this map
// must be constrained.
//
// Note: This is only used for minimum logging;
// if normal logging is being used, this map must
// be constrained.
void insert(std::wstring data);
// Singleton constructor only one object exists
// Hence this resource is not thread-safe
//
// Singleton constructor - only one object exists
// Note: this resource is not thread-safe
TestLog();
private:
@ -79,16 +74,32 @@ class TestLog {
std::map<size_t, std::pair<size_t, std::wstring>> ring_buffer;
static constexpr int logFileLineWidth{128};
static constexpr int logFileLen{1000};
std::wstring string_to_wstring(const std::string& str) {
#if defined(_WIN32) || defined(_WIN64)
// Windows implementation
int size_needed = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
std::wstring wstrTo(size_needed, 0);
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstrTo[0], size_needed);
return wstrTo;
#else
// Linux implementation
std::mbstate_t state = std::mbstate_t();
const char* c_str = str.c_str();
size_t len = std::mbsrtowcs(nullptr, &c_str, 0, &state);
std::wstring wstr(len, L'\0');
std::mbsrtowcs(&wstr[0], &c_str, len, &state);
return wstr;
#endif
}
};
// Reference to initialized logger
// Not this resource is not thread safe and only one
// exists for the entire process.
// Note: this resource is not thread-safe and only
// one exists for the entire process.
//
extern TestLog testLog;
// Object used to mark end of line format
// for testLog
// Object used to mark end of line format for testLog
//
static constexpr LogEndln endl = nullptr;
@ -100,13 +111,11 @@ std::wstring towstr(const char* pStr);
// Inline Functions
// Minimize log
//
inline void Logger::TestLog::minLog() {
min_log = true;
}
// Enable logging
//
inline void Logger::TestLog::enable() {
logging_on = true;
}
@ -115,49 +124,60 @@ inline void Logger::TestLog::disable() {
logging_on = false;
}
// Template functions
namespace Logger {
// Generic log output that appends timing
// information.
//
template <typename T>
Logger::TestLog& Logger::TestLog::operator<<(const T& info) {
TestLog& TestLog::operator<<(const T& info) {
if (!logging_on) {
return *this;
}
// Get the current time
std::chrono::system_clock::time_point today{std::chrono::system_clock::now()};
time_t tt{std::chrono::system_clock::to_time_t(today)};
constexpr int length_time_str = 28;
char buf[length_time_str];
std::time_t tt{std::chrono::system_clock::to_time_t(today)};
#if defined(_WIN32) || defined(_WIN64)
std::tm tm;
localtime_s(&tm, &tt); // Thread-safe on Windows
#else
std::tm tm = *std::localtime(&tt); // Thread-safe on Linux
#endif
if (0 == ctime_s(buf, sizeof(buf), &tt)) {
wchar_t wbuf[length_time_str];
char const* ptr = buf;
std::mbstate_t ps;
size_t retVal;
mbsrtowcs_s(&retVal, wbuf, length_time_str, &ptr, length_time_str, &ps);
std::wstring_view temp(wbuf, retVal - 2);
std::wstringstream stream;
if (print_time_info) {
stream << L"[" << temp << L"]" << L"\t";
}
// Buffer for formatted time
char buf[100];
std::strftime(buf, sizeof(buf), "%c", &tm);
if constexpr (std::is_same<T, std::string>()) {
stream << towstr(info.data());
} else {
stream << info;
}
// Convert multi-byte to wide character string
std::wstring wstr = string_to_wstring(buf);
if (min_log) {
insert(stream.str());
} else {
std::wcout << stream.str();
}
print_time_info = false;
// Use std::wstring_view to avoid copying the wide string
std::wstring_view temp(wstr);
// Create a wstringstream to format the output
std::wstringstream stream;
if (print_time_info) {
stream << L"[" << temp << L"]" << L"\t";
}
// Append info to the stream
if constexpr (std::is_same<T, std::string>()) {
// Convert std::string to std::wstring and append
std::wstring winfo = string_to_wstring(info);
stream << winfo;
} else {
stream << info;
}
// Output the formatted log
if (min_log) {
insert(stream.str());
} else {
std::wcout << stream.str();
}
print_time_info = false;
return *this;
}
} // namespace Logger
namespace Logger {
template <typename CharT>
@ -168,11 +188,10 @@ class cache_streambuf : public std::basic_streambuf<CharT> {
using int_type = typename Base::int_type;
// Get the total number of unique errors found
//
inline size_t get_unique_errors();
protected:
virtual int_type overflow(int_type ch = Traits::eof());
virtual int_type overflow(int_type ch = Base::traits_type::eof());
private:
std::basic_stringstream<char_type> buffer;
@ -189,10 +208,8 @@ inline size_t cache_streambuf<CharT>::get_unique_errors() {
template <typename CharT>
auto cache_streambuf<CharT>::overflow(int_type ch) -> int_type {
// if not end of file
//
if (!Base::traits_type::eq_int_type(ch,
Base::traits_type::eof())) {
// If not end of file
if (!Base::traits_type::eq_int_type(ch, Base::traits_type::eof())) {
if (ch > 255) {
if constexpr (std::is_same_v<char_type, char>) {
std::cout << "Yikes";
@ -217,4 +234,4 @@ auto cache_streambuf<CharT>::overflow(int_type ch) -> int_type {
}
} // namespace Logger
#endif
#endif

Просмотреть файл

@ -9,6 +9,7 @@
// Uses the onnxruntime to load the model
// into a session.
//
#if defined(_WIN32) || defined(_WIN64)
OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
: raw_model{nullptr},
ptr_session{std::make_unique<Ort::Session>(env, onnx_model_file.c_str(), empty_session_option)},
@ -16,6 +17,16 @@ OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
output_names(ptr_session->GetOutputCount()) {
init();
}
#else
OnnxPrediction::OnnxPrediction(std::wstring& onnx_model_file, Ort::Env& env)
: raw_model{nullptr},
// Convert std::wstring to std::string
ptr_session{std::make_unique<Ort::Session>(env, wstring_to_string(onnx_model_file).c_str(), empty_session_option)},
input_names(ptr_session->GetInputCount()),
output_names(ptr_session->GetOutputCount()) {
init();
}
#endif
// Uses the onnx to seri
//
@ -60,7 +71,7 @@ std::wostream& operator<<(std::wostream& out, OnnxPrediction& pred) {
auto pretty_print = [&out](auto ptr, Ort::Value& val) {
out << L"[";
std::wstring msg = L"";
for (int i = 0; i < val.GetTensorTypeAndShapeInfo().GetElementCount(); i++) {
for (int i = 0; i < static_cast<int>(val.GetTensorTypeAndShapeInfo().GetElementCount()); i++) {
out << msg << ptr[i];
msg = L", ";
}
@ -84,28 +95,15 @@ void GenerateDataForInputTypeTensor(OnnxPrediction& predict,
(void)input_name;
(void)input_index;
auto pretty_print = [&input_name](auto raw_data) {
Logger::testLog << input_name << L" = ";
Logger::testLog << L"[";
std::wstring msg = L"";
for (int i = 0; i < raw_data.size(); i++) {
Logger::testLog << msg << raw_data[i];
msg = L", ";
}
Logger::testLog << L"]\n";
};
if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
auto raw_data = GenerateRandomData(0.0f, elem_count, seed);
// pretty_print(raw_data);
predict << std::move(raw_data);
} else if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
int32_t initial = 0;
auto raw_data = GenerateRandomData(initial, elem_count, seed);
// pretty_print(raw_data);
predict << std::move(raw_data);
} else {
throw std::exception("only floats are implemented");
throw std::runtime_error("only floats are implemented");
}
}
@ -143,7 +141,7 @@ void OnnxPrediction::PrintOutputValues() {
void OnnxPrediction::init() {
// Initialize model input names
//
for (int i = 0; i < ptr_session->GetInputCount(); i++) {
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
// TODO Use push_back on input_names instead of assignment
input_names_ptrs.push_back(ptr_session->GetInputNameAllocated(i, alloc));
input_names[i] = input_names_ptrs.back().get();
@ -152,7 +150,7 @@ void OnnxPrediction::init() {
// Initialize model output names
//
for (int i = 0; i < ptr_session->GetOutputCount(); i++) {
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
// TODO Use push_back on output_names instead of assignment
output_names_ptrs.push_back(ptr_session->GetOutputNameAllocated(i, alloc));
output_names[i] = output_names_ptrs.back().get();
@ -169,8 +167,7 @@ Ort::AllocatorWithDefaultOptions& OnnxPrediction::GetAllocator() {
void OnnxPrediction::SetupInput(
InputGeneratorFunctionType GenerateData,
size_t seed) {
Logger::testLog << L"input data:\n";
for (int i = 0; i < ptr_session->GetInputCount(); i++) {
for (int i = 0; i < static_cast<int>(ptr_session->GetOutputCount()); i++) {
auto inputType = ptr_session->GetInputTypeInfo(i);
if (inputType.GetONNXType() == ONNX_TYPE_TENSOR) {
@ -189,4 +186,4 @@ void OnnxPrediction::SetupInput(
}
}
Logger::testLog << Logger::endl;
}
}

Просмотреть файл

@ -5,7 +5,6 @@
#include "testlog.h"
#include "OnnxPrediction.h"
#include "onnxruntime_session_options_config_keys.h"
#include <type_traits>
using user_options = struct
@ -16,6 +15,13 @@ using user_options = struct
bool is_ort;
};
#if !defined(_WIN32) || !defined(_WIN64)
std::string wstring_to_string(const std::wstring& wstr) {
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
return converter.to_bytes(wstr);
}
#endif
void predict(onnx::ModelProto& model_proto, unsigned int seed, Ort::Env& env) {
// Create object for prediction
//
@ -61,10 +67,9 @@ void mutateModelTest(onnx::ModelProto& model_proto,
// Mutate model
//
Logger::testLog << L"Model Successfully Initialized" << Logger::endl;
Logger::testLog << "Model Successfully Initialized" << Logger::endl;
mutator.Seed(seed);
mutator.Mutate(&model_proto, model_proto.ByteSizeLong());
if (opt.write_model) {
// Create file to store model
//
@ -73,8 +78,11 @@ void mutateModelTest(onnx::ModelProto& model_proto,
auto mutateModelFileName = mutateModelName.str();
// Log the model to a file
//
#if defined(_WIN32) || defined(_WIN64)
std::ofstream outStream(mutateModelFileName);
#else
std::ofstream outStream(wstring_to_string(mutateModelFileName));
#endif
model_proto.SerializeToOstream(&outStream);
Logger::testLog << "Mutated Model Written to file: " << mutateModelFileName << Logger::endl;
@ -153,7 +161,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
std::stringstream parser{argv[3]};
parser >> opt.seed;
if (parser.bad()) {
throw std::exception("Could not parse seed from command line");
throw std::runtime_error("Could not parse seed from command line");
}
std::wcout << L"seed: " << opt.seed << L"\n";
@ -169,7 +177,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
parser >> desired_scale;
if (parser.bad()) {
throw std::exception("Could not parse the time scale from the command line");
throw std::runtime_error("Could not parse the time scale from the command line");
}
opt.scale = static_cast<timeScale>(std::tolower(desired_scale));
@ -179,13 +187,13 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
case timeScale::Sec:
break;
default:
throw std::exception("Could not parse the time scale from the command line");
throw std::runtime_error("Could not parse the time scale from the command line");
}
parser << argv[index--];
parser >> opt.test_time_out;
if (parser.bad()) {
throw std::exception("Could not parse the time value from the command line");
throw std::runtime_error("Could not parse the time value from the command line");
}
Logger::testLog << L"Running Test for: " << opt.test_time_out << desired_scale << Logger::endl;
@ -193,7 +201,7 @@ int processCommandLine(int argc, char* argv[], runtimeOpt& opt) {
Logger::testLog << L"Model file: " << opt.model_file_name << Logger::endl;
std::filesystem::path model_file_namePath{opt.model_file_name};
if (!std::filesystem::exists(model_file_namePath)) {
throw std::exception("Cannot find model file");
throw std::runtime_error("Cannot find model file");
}
// process options
@ -251,9 +259,9 @@ struct RunStats {
static void fuzz_handle_exception(struct RunStats& run_stats) {
try {
throw;
} catch (const Ort::Exception& ortException) {
} catch (const Ort::Exception& ortexception) {
run_stats.num_ort_exception++;
Logger::testLog << L"onnx runtime exception: " << ortException.what() << Logger::endl;
Logger::testLog << L"onnx runtime exception: " << ortexception.what() << Logger::endl;
Logger::testLog << "Failed Test iteration: " << run_stats.iteration++ << Logger::endl;
} catch (const std::exception& e) {
run_stats.num_std_exception++;
@ -299,8 +307,11 @@ int main(int argc, char* argv[]) {
std::wstring model_file{model_file_name};
// Create a stream to hold the model
//
#if defined(_WIN32) || defined(_WIN64)
std::ifstream modelStream{model_file, std::ios::in | std::ios::binary};
#else
std::ifstream modelStream(wstring_to_string(model_file), std::ios::in | std::ios::binary);
#endif
if (opt.user_opt.is_ort == false) {
// Create an onnx protobuf object
//
@ -345,7 +356,7 @@ int main(int argc, char* argv[]) {
}
}
} else {
throw std::exception("Unable to initialize the Onnx model in memory");
throw std::runtime_error("Unable to initialize the Onnx model in memory");
}
} else {
std::wstring ort_model_file = model_file;
@ -353,16 +364,25 @@ int main(int argc, char* argv[]) {
ort_model_file = model_file + L".ort";
Ort::SessionOptions so;
so.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
so.SetOptimizedModelFilePath(ort_model_file.c_str());
so.AddConfigEntry(kOrtSessionOptionsConfigSaveModelFormat, "ORT");
#if defined(_WIN32) || defined(_WIN64)
so.SetOptimizedModelFilePath(ort_model_file.c_str());
Ort::Session session(env, model_file.c_str(), so);
#else
so.SetOptimizedModelFilePath(wstring_to_string(ort_model_file).c_str());
Ort::Session session(env, wstring_to_string(model_file).c_str(), so);
#endif
} else if (model_file.substr(model_file.find_last_of(L".") + 1) != L"ort") {
Logger::testLog << L"Input file name extension is not 'onnx' or 'ort' " << Logger::endl;
return 1;
}
size_t num_bytes = std::filesystem::file_size(ort_model_file);
std::vector<char> model_data(num_bytes);
#if defined(_WIN32) || defined(_WIN64)
std::ifstream ortModelStream(ort_model_file, std::ifstream::in | std::ifstream::binary);
#else
std::ifstream ortModelStream(wstring_to_string(ort_model_file), std::ifstream::in | std::ifstream::binary);
#endif
ortModelStream.read(model_data.data(), num_bytes);
ortModelStream.close();
// Currently mutations are generated by using XOR of a byte with the preceding byte at a time.
@ -389,7 +409,8 @@ int main(int argc, char* argv[]) {
if (user_opt.stress) {
Logger::testLog.enable();
}
size_t toal_num_exception = run_stats.num_unknown_exception + run_stats.num_std_exception + run_stats.num_ort_exception;
size_t toal_num_exception =
run_stats.num_unknown_exception + run_stats.num_std_exception + run_stats.num_ort_exception;
Logger::testLog << L"Total number of exceptions: " << toal_num_exception << Logger::endl;
Logger::testLog << L"Number of Unknown exceptions: " << run_stats.num_unknown_exception << Logger::endl;
Logger::testLog << L"Number of ort exceptions: " << run_stats.num_ort_exception << Logger::endl;

Просмотреть файл

@ -112,7 +112,7 @@ TestLog::TestLog()
if (!std::filesystem::exists(mutateModelDir)) {
std::filesystem::create_directory(mutateModelDir);
}
logFile.open(L"out/log", std::ios::ate);
logFile.open("out/log", std::ios::ate);
} else {
throw std::runtime_error("TestLog has already been initialized. Call GetTestLog() to use it");
}
@ -121,18 +121,33 @@ TestLog::TestLog()
// Helper function to convert string to wstring
//
std::wstring towstr(const char* pStr) {
std::mbstate_t ps;
size_t retVal;
std::mbstate_t ps = std::mbstate_t();
#if defined(_WIN32) || defined(_WIN64)
size_t length_str = strnlen(pStr, onnxruntime::kMaxStrLen);
size_t retVal;
// On Windows, use mbsrtowcs_s which is safe and provides size checking
mbsrtowcs_s(&retVal, nullptr, 0, &pStr, length_str, &ps);
retVal += 1;
auto ptr = std::make_unique<wchar_t[]>(retVal);
if (ptr == nullptr) {
std::stringstream str;
str << "Failed to allocate memory: " << __func__ << __LINE__ << "\n";
throw std::exception{str.str().data()};
str << "Failed to allocate memory: " << __func__ << " " << __LINE__ << "\n";
throw std::runtime_error{str.str().data()};
}
mbsrtowcs_s(&retVal, ptr.get(), retVal, &pStr, length_str, &ps);
#else
// On Linux, use mbsrtowcs which is part of the standard library
size_t retVal = std::mbsrtowcs(nullptr, &pStr, 0, &ps);
retVal += 1;
auto ptr = std::make_unique<wchar_t[]>(retVal);
if (ptr == nullptr) {
std::stringstream str;
str << "Failed to allocate memory: " << __func__ << " " << __LINE__ << "\n";
throw std::runtime_error{str.str().data()};
}
std::mbsrtowcs(ptr.get(), &pStr, retVal, &ps);
#endif
return std::wstring{ptr.get()};
}
} // namespace Logger