Add Tracelogging for profiling (#1639)
Enabled only if onnxruntime_ENABLE_INSTRUMENT is ON
This commit is contained in:
Родитель
0c6e9f94d0
Коммит
fc6773a65b
|
@ -83,6 +83,7 @@ option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in l
|
|||
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
|
||||
option(onnxruntime_USE_DML "Build with DirectML support" OFF)
|
||||
option(onnxruntime_USE_ACL "Build with ACL support" OFF)
|
||||
option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF)
|
||||
|
||||
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE)
|
||||
#nsync tests failed on Mac Build
|
||||
|
@ -91,6 +92,15 @@ set(ONNX_ML 1)
|
|||
if(NOT onnxruntime_ENABLE_PYTHON)
|
||||
set(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS OFF)
|
||||
endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
#TODO: On Linux we may try https://github.com/microsoft/TraceLogging
|
||||
if(onnxruntime_ENABLE_INSTRUMENT)
|
||||
message(WARNING "Instrument is only supported on Windows now")
|
||||
set(onnxruntime_ENABLE_INSTRUMENT OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(onnxruntime_USE_OPENMP)
|
||||
find_package(OpenMP)
|
||||
if (OPENMP_FOUND)
|
||||
|
|
|
@ -10,7 +10,9 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
|
|||
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_framework_srcs})
|
||||
|
||||
add_library(onnxruntime_framework ${onnxruntime_framework_srcs})
|
||||
|
||||
if(onnxruntime_ENABLE_INSTRUMENT)
|
||||
target_compile_definitions(onnxruntime_framework PRIVATE ONNXRUNTIME_ENABLE_INSTRUMENT)
|
||||
endif()
|
||||
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||
onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf)
|
||||
set_target_properties(onnxruntime_framework PROPERTIES FOLDER "ONNXRuntime")
|
||||
|
|
|
@ -12,6 +12,9 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs})
|
|||
add_library(onnxruntime_session ${onnxruntime_session_srcs})
|
||||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/session DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
|
||||
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
|
||||
if(onnxruntime_ENABLE_INSTRUMENT)
|
||||
target_compile_definitions(onnxruntime_session PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
|
||||
endif()
|
||||
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
|
||||
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
|
||||
|
|
|
@ -776,6 +776,17 @@ if (onnxruntime_BUILD_SERVER)
|
|||
|
||||
endif()
|
||||
|
||||
#some ETW tools
|
||||
if(WIN32 AND onnxruntime_ENABLE_INSTRUMENT)
|
||||
add_executable(generate_perf_report_from_etl ${ONNXRUNTIME_ROOT}/tool/etw/main.cc ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc)
|
||||
target_compile_definitions(generate_perf_report_from_etl PRIVATE "_CONSOLE" "_UNICODE" "UNICODE")
|
||||
target_link_libraries(generate_perf_report_from_etl PRIVATE tdh Advapi32)
|
||||
|
||||
add_executable(compare_two_sessions ${ONNXRUNTIME_ROOT}/tool/etw/compare_two_sessions.cc ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc)
|
||||
target_compile_definitions(compare_two_sessions PRIVATE "_CONSOLE" "_UNICODE" "UNICODE")
|
||||
target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32)
|
||||
endif()
|
||||
|
||||
add_executable(onnxruntime_mlas_test ${TEST_SRC_DIR}/mlas/unittest.cpp)
|
||||
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT})
|
||||
set(onnxruntime_mlas_test_libs onnxruntime_mlas onnxruntime_common)
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <windows.h>
|
||||
#include <TraceLoggingProvider.h>
|
||||
|
||||
TRACELOGGING_DECLARE_PROVIDER(telemetry_provider_handle);
|
|
@ -25,6 +25,22 @@
|
|||
using namespace Concurrency;
|
||||
#endif
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
#include <Windows.h>
|
||||
#include "core/platform/tracing.h"
|
||||
namespace {
|
||||
LARGE_INTEGER OrtGetPerformanceFrequency() {
|
||||
LARGE_INTEGER v;
|
||||
// On systems that run Windows XP or later, the QueryPerformanceFrequency function will always succeed
|
||||
// and will thus never return zero.
|
||||
(void)QueryPerformanceFrequency(&v);
|
||||
return v;
|
||||
}
|
||||
|
||||
LARGE_INTEGER perf_freq = OrtGetPerformanceFrequency();
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
static Status ReleaseNodeMLValues(ExecutionFrame& frame,
|
||||
|
@ -87,7 +103,10 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
if (p_op_kernel == nullptr)
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
|
||||
node.Name());
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
LARGE_INTEGER kernel_start;
|
||||
QueryPerformanceCounter(&kernel_start);
|
||||
#endif
|
||||
// construct OpKernelContext
|
||||
// TODO: log kernel inputs?
|
||||
OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, terminate_flag_);
|
||||
|
@ -128,7 +147,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if defined DEBUG_NODE_INPUTS_OUTPUTS
|
||||
utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node());
|
||||
#endif
|
||||
|
@ -202,7 +220,19 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
LARGE_INTEGER kernel_stop;
|
||||
QueryPerformanceCounter(&kernel_stop);
|
||||
LARGE_INTEGER elapsed;
|
||||
elapsed.QuadPart = kernel_stop.QuadPart - kernel_start.QuadPart;
|
||||
elapsed.QuadPart *= 1000000;
|
||||
elapsed.QuadPart /= perf_freq.QuadPart;
|
||||
// Log an event
|
||||
TraceLoggingWrite(telemetry_provider_handle, // handle to my provider
|
||||
"OpEnd", // Event Name that should uniquely identify your event.
|
||||
TraceLoggingValue(p_op_kernel->KernelDef().OpName().c_str(), "op_name"),
|
||||
TraceLoggingValue(elapsed.QuadPart, "time"));
|
||||
#endif
|
||||
if (is_profiler_enabled) {
|
||||
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
|
||||
p_op_kernel->Node().Name() + "_fence_after",
|
||||
|
|
|
@ -271,7 +271,10 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s
|
|||
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
|
||||
" for attribute ", attribute_name);
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
session_state->parent_ = this;
|
||||
GenerateGraphId();
|
||||
#endif
|
||||
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
|
||||
}
|
||||
|
||||
|
|
|
@ -268,6 +268,19 @@ class SessionState {
|
|||
|
||||
std::unique_ptr<NodeIndexInfo> node_index_info_;
|
||||
std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
SessionState* parent_ = nullptr;
|
||||
//Assign each graph in each session an unique id.
|
||||
int graph_id_ = 0;
|
||||
int next_graph_id_ = 1;
|
||||
|
||||
void GenerateGraphId() {
|
||||
SessionState* p = this;
|
||||
while (p->parent_ != nullptr) p = p->parent_;
|
||||
graph_id_ = p->next_graph_id_ ++;
|
||||
}
|
||||
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -19,6 +19,10 @@
|
|||
|
||||
#include "core/platform/env.h"
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
#include "core/platform/tracing.h"
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
|
@ -73,7 +73,7 @@ inline const wchar_t* GetDateFormatString<wchar_t>() {
|
|||
return L"%Y-%m-%d_%H-%M-%S";
|
||||
}
|
||||
#endif
|
||||
//TODO: use LoggingManager::GetTimestamp and date::operator<<
|
||||
// TODO: use LoggingManager::GetTimestamp and date::operator<<
|
||||
// (see ostream_sink.cc for an example)
|
||||
// to simplify this and match the log file timestamp format.
|
||||
template <typename T>
|
||||
|
@ -115,7 +115,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
|
|||
insert_cast_transformer_("CastFloat16Transformer") {
|
||||
ORT_ENFORCE(Environment::IsInitialized(),
|
||||
"Environment must be initialized before creating an InferenceSession.");
|
||||
|
||||
InitLogger(logging_manager);
|
||||
|
||||
session_state_.SetDataTransferMgr(&data_transfer_mgr_);
|
||||
|
@ -144,6 +143,9 @@ InferenceSession::~InferenceSession() {
|
|||
LOGS(*session_logger_, ERROR) << "Unknown error during EndProfiling()";
|
||||
}
|
||||
}
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
|
||||
#endif
|
||||
}
|
||||
|
||||
common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExecutionProvider> p_exec_provider) {
|
||||
|
@ -176,8 +178,8 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExec
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr<onnxruntime::GraphTransformer> p_graph_transformer,
|
||||
TransformerLevel level) {
|
||||
common::Status InferenceSession::RegisterGraphTransformer(
|
||||
std::unique_ptr<onnxruntime::GraphTransformer> p_graph_transformer, TransformerLevel level) {
|
||||
if (p_graph_transformer == nullptr) {
|
||||
return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer");
|
||||
}
|
||||
|
@ -185,8 +187,7 @@ common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr<onnxru
|
|||
}
|
||||
|
||||
common::Status InferenceSession::AddCustomTransformerList(const std::vector<std::string>& transformers_to_enable) {
|
||||
std::copy(transformers_to_enable.begin(), transformers_to_enable.end(),
|
||||
std::back_inserter(transformers_to_enable_));
|
||||
std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), std::back_inserter(transformers_to_enable_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -213,7 +214,8 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRe
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader, const std::string& event_name) {
|
||||
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
|
||||
const std::string& event_name) {
|
||||
Status status = Status::OK();
|
||||
TimePoint tp;
|
||||
if (session_profiler_.IsEnabled()) {
|
||||
|
@ -223,8 +225,7 @@ common::Status InferenceSession::Load(std::function<common::Status(std::shared_p
|
|||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
if (is_model_loaded_) { // already loaded
|
||||
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
|
||||
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED,
|
||||
"This session already contains a loaded model.");
|
||||
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
|
||||
}
|
||||
|
||||
std::shared_ptr<onnxruntime::Model> p_tmp_model;
|
||||
|
@ -281,14 +282,10 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
common::Status InferenceSession::Load(const std::string& model_uri) {
|
||||
return Load<char>(model_uri);
|
||||
}
|
||||
common::Status InferenceSession::Load(const std::string& model_uri) { return Load<char>(model_uri); }
|
||||
|
||||
#ifdef _WIN32
|
||||
common::Status InferenceSession::Load(const std::wstring& model_uri) {
|
||||
return Load<PATH_CHAR_TYPE>(model_uri);
|
||||
}
|
||||
common::Status InferenceSession::Load(const std::wstring& model_uri) { return Load<PATH_CHAR_TYPE>(model_uri); }
|
||||
#endif
|
||||
|
||||
common::Status InferenceSession::Load(const ModelProto& model_proto) {
|
||||
|
@ -522,7 +519,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
|
|||
const auto implicit_inputs = node.ImplicitInputDefs();
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(initializer.CreatePlan(&node, &implicit_inputs,
|
||||
session_options_.execution_mode));
|
||||
|
||||
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
|
||||
// &*subgraph_info.session_state);
|
||||
|
||||
|
@ -554,12 +550,14 @@ common::Status InferenceSession::Initialize() {
|
|||
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
||||
return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded.");
|
||||
}
|
||||
|
||||
if (is_inited_) { // already initialized
|
||||
LOGS(*session_logger_, INFO) << "Session has already been initialized.";
|
||||
return common::Status::OK();
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity");
|
||||
session_activity_started_ = true;
|
||||
#endif
|
||||
// Register default CPUExecutionProvider if user didn't provide it through the Register() calls
|
||||
if (!execution_providers_.Get(onnxruntime::kCpuExecutionProvider)) {
|
||||
LOGS(*session_logger_, INFO) << "Adding default CPU execution provider.";
|
||||
|
@ -578,7 +576,8 @@ common::Status InferenceSession::Initialize() {
|
|||
}
|
||||
|
||||
// add predefined transformers
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, transformers_to_enable_);
|
||||
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
|
||||
transformers_to_enable_);
|
||||
|
||||
onnxruntime::Graph& graph = model_->MainGraph();
|
||||
|
||||
|
@ -624,7 +623,6 @@ common::Status InferenceSession::Initialize() {
|
|||
// handle any subgraphs
|
||||
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, session_state_));
|
||||
is_inited_ = true;
|
||||
|
||||
LOGS(*session_logger_, INFO) << "Session successfully initialized.";
|
||||
} catch (const NotImplementedException& ex) {
|
||||
status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what());
|
||||
|
@ -643,9 +641,7 @@ common::Status InferenceSession::Initialize() {
|
|||
return status;
|
||||
}
|
||||
|
||||
int InferenceSession::GetCurrentNumRuns() const {
|
||||
return current_num_runs_.load();
|
||||
}
|
||||
int InferenceSession::GetCurrentNumRuns() const { return current_num_runs_.load(); }
|
||||
|
||||
const std::vector<std::string>& InferenceSession::GetRegisteredProviderTypes() const {
|
||||
return execution_providers_.GetIds();
|
||||
|
@ -662,8 +658,7 @@ common::Status InferenceSession::CheckShapes(const std::string& input_name,
|
|||
auto expected_shape_sz = expected_shape.NumDimensions();
|
||||
if (input_shape_sz != expected_shape_sz) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Invalid rank for input: " << input_name
|
||||
<< " Got: " << input_shape_sz << " Expected: " << expected_shape_sz
|
||||
ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz
|
||||
<< " Please fix either the inputs or the model.";
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str());
|
||||
}
|
||||
|
@ -705,10 +700,8 @@ static common::Status CheckTypes(MLDataType actual, MLDataType expected) {
|
|||
common::Status InferenceSession::ValidateInputs(const std::vector<std::string>& feed_names,
|
||||
const std::vector<OrtValue>& feeds) const {
|
||||
if (feed_names.size() != feeds.size()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Size mismatch: feed_names has ",
|
||||
feed_names.size(), "elements, but feeds has ",
|
||||
feeds.size(), " elements.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
|
||||
"elements, but feeds has ", feeds.size(), " elements.");
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < feeds.size(); ++i) {
|
||||
|
@ -716,8 +709,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
|
||||
auto iter = input_def_map_.find(feed_name);
|
||||
if (input_def_map_.end() == iter) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Invalid Feed Input Name:", feed_name);
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name);
|
||||
}
|
||||
|
||||
auto expected_type = iter->second.ml_data_type;
|
||||
|
@ -725,8 +717,8 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
if (input_ml_value.IsTensor()) {
|
||||
// check for type
|
||||
if (!expected_type->IsTensorType()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ",
|
||||
feed_name, " is not expected to be of type tensor.");
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
|
||||
" is not expected to be of type tensor.");
|
||||
}
|
||||
|
||||
auto expected_element_type = expected_type->AsTensorType()->GetElementType();
|
||||
|
@ -751,17 +743,14 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
|
|||
common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
|
||||
const std::vector<OrtValue>* p_fetches) const {
|
||||
if (p_fetches == nullptr) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Output vector pointer is NULL");
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
|
||||
}
|
||||
|
||||
if (output_names.empty()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"At least one output should be requested.");
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested.");
|
||||
}
|
||||
|
||||
if (!p_fetches->empty() &&
|
||||
(output_names.size() != p_fetches->size())) {
|
||||
if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) {
|
||||
std::ostringstream ostr;
|
||||
ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size()
|
||||
<< "p_fetches->size(): " << p_fetches->size();
|
||||
|
@ -770,8 +759,7 @@ common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>&
|
|||
|
||||
for (const auto& name : output_names) {
|
||||
if (model_output_names_.find(name) == model_output_names_.end()) {
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
|
||||
"Invalid Output Name:" + name);
|
||||
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -787,6 +775,12 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
|
|||
if (session_profiler_.IsEnabled()) {
|
||||
tp = session_profiler_.StartTime();
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
TraceLoggingActivity<telemetry_provider_handle> ortrun_activity;
|
||||
ortrun_activity.SetRelatedActivity(session_activity);
|
||||
TraceLoggingWriteStart(ortrun_activity, "OrtRun");
|
||||
#endif
|
||||
Status retval = Status::OK();
|
||||
|
||||
try {
|
||||
|
@ -858,7 +852,9 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
|
|||
if (session_profiler_.IsEnabled()) {
|
||||
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
|
||||
}
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
TraceLoggingWriteStop(ortrun_activity, "OrtRun");
|
||||
#endif
|
||||
return retval;
|
||||
}
|
||||
|
||||
|
@ -889,8 +885,7 @@ std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetada
|
|||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
if (!is_model_loaded_) {
|
||||
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
|
||||
nullptr);
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -902,8 +897,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetModelInputs(
|
|||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
if (!is_model_loaded_) {
|
||||
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
|
||||
nullptr);
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -916,8 +910,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetOverridableI
|
|||
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
|
||||
if (!is_model_loaded_) {
|
||||
LOGS(*session_logger_, ERROR) << "Model was not loaded";
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
|
||||
nullptr);
|
||||
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -971,14 +964,10 @@ void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
|
|||
session_profiler_.StartProfiling(ss.str());
|
||||
}
|
||||
|
||||
void InferenceSession::StartProfiling(const std::string& file_prefix) {
|
||||
StartProfiling<char>(file_prefix);
|
||||
}
|
||||
void InferenceSession::StartProfiling(const std::string& file_prefix) { StartProfiling<char>(file_prefix); }
|
||||
|
||||
#ifdef _WIN32
|
||||
void InferenceSession::StartProfiling(const std::wstring& file_prefix) {
|
||||
StartProfiling<PATH_CHAR_TYPE>(file_prefix);
|
||||
}
|
||||
void InferenceSession::StartProfiling(const std::wstring& file_prefix) { StartProfiling<PATH_CHAR_TYPE>(file_prefix); }
|
||||
#endif
|
||||
|
||||
void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) {
|
||||
|
@ -1026,11 +1015,11 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
|
|||
for (auto elem : inputs) {
|
||||
auto elem_type = utils::GetMLDataType(*elem);
|
||||
auto elem_shape_proto = elem->Shape();
|
||||
input_def_map_.insert({elem->Name(), InputDefMetaData(elem,
|
||||
elem_type,
|
||||
elem_shape_proto
|
||||
? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto)
|
||||
: TensorShape())});
|
||||
input_def_map_.insert(
|
||||
{elem->Name(),
|
||||
InputDefMetaData(
|
||||
elem, elem_type,
|
||||
elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1086,10 +1075,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru
|
|||
severity = static_cast<logging::Severity>(run_options.run_log_severity_level);
|
||||
}
|
||||
|
||||
new_run_logger = logging_manager_->CreateLogger(run_log_id,
|
||||
severity,
|
||||
false,
|
||||
run_options.run_log_verbosity_level);
|
||||
new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level);
|
||||
|
||||
run_logger = new_run_logger.get();
|
||||
VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id;
|
||||
|
@ -1116,9 +1102,7 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
|
|||
severity = static_cast<logging::Severity>(session_options_.session_log_severity_level);
|
||||
}
|
||||
|
||||
owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid,
|
||||
severity,
|
||||
false,
|
||||
owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false,
|
||||
session_options_.session_log_verbosity_level);
|
||||
session_logger_ = owned_session_logger_.get();
|
||||
} else {
|
||||
|
|
|
@ -24,6 +24,10 @@
|
|||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
#include "core/language_interop_ops/language_interop_ops.h"
|
||||
#endif
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
#include "core/platform/tracing.h"
|
||||
#include <TraceLoggingActivity.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime { // forward declarations
|
||||
class GraphTransformer;
|
||||
|
@ -434,7 +438,6 @@ class InferenceSession {
|
|||
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
|
||||
InterOpDomains interop_domains_;
|
||||
#endif
|
||||
|
||||
// used to support platform telemetry
|
||||
static std::atomic<uint32_t> global_session_id_; // a monotonically increasing session id
|
||||
uint32_t session_id_; // the current session's id
|
||||
|
@ -442,5 +445,10 @@ class InferenceSession {
|
|||
long long total_run_duration_since_last_; // the total duration (us) of Run() calls since the last report
|
||||
TimePoint time_sent_last_; // the TimePoint of the last report
|
||||
const long long kDurationBetweenSending = 1000* 1000 * 60 * 10; // duration in (us). send a report every 10 mins
|
||||
|
||||
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
|
||||
bool session_activity_started_ = false;
|
||||
TraceLoggingActivity<telemetry_provider_handle> session_activity;
|
||||
#endif
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -0,0 +1,306 @@
|
|||
//*********************************************************
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// This code is licensed under the MIT License (MIT).
|
||||
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
|
||||
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
|
||||
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
|
||||
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
|
||||
//
|
||||
//*********************************************************
|
||||
|
||||
#include "TraceSession.h"
|
||||
|
||||
namespace {
|
||||
|
||||
VOID WINAPI EventRecordCallback(EVENT_RECORD* pEventRecord)
|
||||
{
|
||||
auto session = (TraceSession*) pEventRecord->UserContext;
|
||||
auto const& hdr = pEventRecord->EventHeader;
|
||||
|
||||
if (session->startTime_ == 0) {
|
||||
session->startTime_ = hdr.TimeStamp.QuadPart;
|
||||
}
|
||||
|
||||
auto iter = session->eventHandler_.find(hdr.ProviderId);
|
||||
if (iter != session->eventHandler_.end()) {
|
||||
auto const& h = iter->second;
|
||||
(*h.fn_)(pEventRecord, h.ctxt_);
|
||||
}
|
||||
}
|
||||
|
||||
ULONG WINAPI BufferCallback(EVENT_TRACE_LOGFILE* pLogFile)
|
||||
{
|
||||
auto session = (TraceSession*) pLogFile->Context;
|
||||
auto shouldStopFn = session->shouldStopProcessingEventsFn_;
|
||||
if (shouldStopFn && (*shouldStopFn)()) {
|
||||
return FALSE; // break out of ProcessTrace()
|
||||
}
|
||||
|
||||
return TRUE; // continue processing events
|
||||
}
|
||||
|
||||
bool OpenLogger(
|
||||
TraceSession* session,
|
||||
TCHAR const* name,
|
||||
bool realtime)
|
||||
{
|
||||
// Open trace
|
||||
EVENT_TRACE_LOGFILE loggerInfo = {};
|
||||
/* Filled out below based on realtime:
|
||||
loggerInfo.LogFileName = nullptr;
|
||||
loggerInfo.LoggerName = nullptr;
|
||||
*/
|
||||
loggerInfo.ProcessTraceMode = PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_RAW_TIMESTAMP;
|
||||
loggerInfo.BufferCallback = BufferCallback;
|
||||
loggerInfo.EventRecordCallback = EventRecordCallback;
|
||||
loggerInfo.Context = session;
|
||||
/* Output members (passed also to BufferCallback()):
|
||||
loggerInfo.CurrentTime
|
||||
loggerInfo.BuffersRead
|
||||
loggerInfo.CurrentEvent
|
||||
loggerInfo.LogfileHeader
|
||||
loggerInfo.BufferSize
|
||||
loggerInfo.Filled
|
||||
loggerInfo.IsKernelTrace
|
||||
*/
|
||||
/* Not used:
|
||||
loggerInfo.EventsLost
|
||||
*/
|
||||
|
||||
if (realtime) {
|
||||
loggerInfo.LoggerName = const_cast<decltype(loggerInfo.LoggerName)>(name);
|
||||
loggerInfo.ProcessTraceMode |= PROCESS_TRACE_MODE_REAL_TIME;
|
||||
} else {
|
||||
loggerInfo.LogFileName = const_cast<decltype(loggerInfo.LoggerName)>(name);
|
||||
}
|
||||
|
||||
session->traceHandle_ = OpenTrace(&loggerInfo);
|
||||
if (session->traceHandle_ == INVALID_PROCESSTRACE_HANDLE) {
|
||||
fprintf(stderr, "error: failed to open trace");
|
||||
auto lastError = GetLastError();
|
||||
switch (lastError) {
|
||||
case ERROR_INVALID_PARAMETER: fprintf(stderr, " (Logfile is NULL)"); break;
|
||||
case ERROR_BAD_PATHNAME: fprintf(stderr, " (invalid LoggerName)"); break;
|
||||
case ERROR_ACCESS_DENIED: fprintf(stderr, " (access denied)"); break;
|
||||
default: fprintf(stderr, " (error=%u)", lastError); break;
|
||||
}
|
||||
fprintf(stderr, ".\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy desired state from loggerInfo
|
||||
session->frequency_ = loggerInfo.LogfileHeader.PerfFreq.QuadPart;
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
size_t TraceSession::GUIDHash::operator()(GUID const& g) const
|
||||
{
|
||||
static_assert((sizeof(g) % sizeof(size_t)) == 0, "sizeof(GUID) must be multiple of sizeof(size_t)");
|
||||
auto p = (size_t const*) &g;
|
||||
auto h = (size_t) 0;
|
||||
for (size_t i = 0; i < sizeof(g) / sizeof(size_t); ++i) {
|
||||
h ^= p[i];
|
||||
}
|
||||
return h;
|
||||
}
|
||||
|
||||
bool TraceSession::GUIDEqual::operator()(GUID const& lhs, GUID const& rhs) const
|
||||
{
|
||||
return IsEqualGUID(lhs, rhs) != FALSE;
|
||||
}
|
||||
|
||||
bool TraceSession::AddProvider(GUID providerId, UCHAR level,
|
||||
ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword)
|
||||
{
|
||||
auto p = eventProvider_.emplace(std::make_pair(providerId, Provider()));
|
||||
if (!p.second) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto h = &p.first->second;
|
||||
h->matchAny_ = matchAnyKeyword;
|
||||
h->matchAll_ = matchAllKeyword;
|
||||
h->level_ = level;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TraceSession::AddHandler(GUID providerId, EventHandlerFn handlerFn, void* handlerContext)
|
||||
{
|
||||
auto p = eventHandler_.emplace(std::make_pair(providerId, Handler()));
|
||||
if (!p.second) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto h = &p.first->second;
|
||||
h->fn_ = handlerFn;
|
||||
h->ctxt_ = handlerContext;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TraceSession::AddProviderAndHandler(GUID providerId, UCHAR level,
|
||||
ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword,
|
||||
EventHandlerFn handlerFn, void* handlerContext)
|
||||
{
|
||||
if (!AddProvider(providerId, level, matchAnyKeyword, matchAllKeyword))
|
||||
return false;
|
||||
if (!AddHandler(providerId, handlerFn, handlerContext)) {
|
||||
RemoveProvider(providerId);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TraceSession::RemoveProvider(GUID providerId)
|
||||
{
|
||||
if (sessionHandle_ != 0) {
|
||||
auto status = EnableTraceEx2(sessionHandle_, &providerId, EVENT_CONTROL_CODE_DISABLE_PROVIDER, 0, 0, 0, 0, nullptr);
|
||||
(void) status;
|
||||
}
|
||||
|
||||
return eventProvider_.erase(providerId) != 0;
|
||||
}
|
||||
|
||||
bool TraceSession::RemoveHandler(GUID providerId)
|
||||
{
|
||||
return eventHandler_.erase(providerId) != 0;
|
||||
}
|
||||
|
||||
bool TraceSession::RemoveProviderAndHandler(GUID providerId)
|
||||
{
|
||||
return RemoveProvider(providerId) || RemoveHandler(providerId);
|
||||
}
|
||||
|
||||
bool TraceSession::InitializeEtlFile(TCHAR const* inputEtlPath, ShouldStopProcessingEventsFn shouldStopFn)
|
||||
{
|
||||
// Open the trace
|
||||
if (!OpenLogger(this, inputEtlPath, false)) {
|
||||
Finalize();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Initialize state
|
||||
shouldStopProcessingEventsFn_ = shouldStopFn;
|
||||
eventsLostCount_ = 0;
|
||||
buffersLostCount_ = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool TraceSession::InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopFn)
|
||||
{
|
||||
// Set up and start a real-time collection session
|
||||
memset(&properties_, 0, sizeof(properties_));
|
||||
|
||||
properties_.Wnode.BufferSize = (ULONG) offsetof(TraceSession, sessionHandle_);
|
||||
//properties_.Wnode.Guid // ETW will create Guid
|
||||
properties_.Wnode.ClientContext = 1; // Clock resolution to use when logging the timestamp for each event
|
||||
// 1 == query performance counter
|
||||
properties_.Wnode.Flags = 0;
|
||||
//properties_.BufferSize = 0;
|
||||
properties_.MinimumBuffers = 200;
|
||||
//properties_.MaximumBuffers = 0;
|
||||
//properties_.MaximumFileSize = 0;
|
||||
properties_.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
|
||||
//properties_.FlushTimer = 0;
|
||||
//properties_.EnableFlags = 0;
|
||||
properties_.LogFileNameOffset = 0;
|
||||
properties_.LoggerNameOffset = offsetof(TraceSession, loggerName_);
|
||||
|
||||
auto status = StartTrace(&sessionHandle_, traceSessionName, &properties_);
|
||||
if (status == ERROR_ALREADY_EXISTS) {
|
||||
#ifdef _DEBUG
|
||||
fprintf(stderr, "warning: trying to start trace session that already exists.\n");
|
||||
#endif
|
||||
status = ControlTrace((TRACEHANDLE) 0, traceSessionName, &properties_, EVENT_TRACE_CONTROL_STOP);
|
||||
if (status == ERROR_SUCCESS) {
|
||||
status = StartTrace(&sessionHandle_, traceSessionName, &properties_);
|
||||
}
|
||||
}
|
||||
if (status != ERROR_SUCCESS) {
|
||||
fprintf(stderr, "error: failed to start trace session (error=%lu).\n", status);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Enable desired providers
|
||||
for (auto const& p : eventProvider_) {
|
||||
auto pGuid = &p.first;
|
||||
auto const& h = p.second;
|
||||
|
||||
status = EnableTraceEx2(sessionHandle_, pGuid, EVENT_CONTROL_CODE_ENABLE_PROVIDER, h.level_, h.matchAny_, h.matchAll_, 0, nullptr);
|
||||
if (status != ERROR_SUCCESS) {
|
||||
fprintf(stderr, "error: failed to enable provider {%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x}.\n",
|
||||
pGuid->Data1, pGuid->Data2, pGuid->Data3, pGuid->Data4[0], pGuid->Data4[1], pGuid->Data4[2],
|
||||
pGuid->Data4[3], pGuid->Data4[4], pGuid->Data4[5], pGuid->Data4[6], pGuid->Data4[7]);
|
||||
Finalize();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Open the trace
|
||||
if (!OpenLogger(this, traceSessionName, true)) {
|
||||
Finalize();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Initialize state
|
||||
shouldStopProcessingEventsFn_ = shouldStopFn;
|
||||
eventsLostCount_ = 0;
|
||||
buffersLostCount_ = 0;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void TraceSession::Finalize()
|
||||
{
|
||||
ULONG status = ERROR_SUCCESS;
|
||||
|
||||
if (traceHandle_ != INVALID_PROCESSTRACE_HANDLE) {
|
||||
status = CloseTrace(traceHandle_);
|
||||
traceHandle_ = INVALID_PROCESSTRACE_HANDLE;
|
||||
}
|
||||
|
||||
if (sessionHandle_ != 0) {
|
||||
status = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_STOP);
|
||||
|
||||
while (!eventProvider_.empty()) {
|
||||
RemoveProvider(eventProvider_.begin()->first);
|
||||
}
|
||||
while (!eventHandler_.empty()) {
|
||||
RemoveHandler(eventHandler_.begin()->first);
|
||||
}
|
||||
|
||||
sessionHandle_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool TraceSession::CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost)
|
||||
{
|
||||
if (sessionHandle_ == 0) {
|
||||
*eventsLost = 0;
|
||||
*buffersLost = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto status = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_QUERY);
|
||||
if (status == ERROR_MORE_DATA) { // The buffer &properties_ is too small to hold all the information
|
||||
*eventsLost = 0; // for the session. If you don't need the session's property information
|
||||
*buffersLost = 0; // you can ignore this error.
|
||||
return false;
|
||||
}
|
||||
|
||||
if (status != ERROR_SUCCESS) {
|
||||
fprintf(stderr, "error: failed to query trace status (%lu).\n", status);
|
||||
*eventsLost = 0;
|
||||
*buffersLost = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
*eventsLost = properties_.EventsLost - eventsLostCount_;
|
||||
*buffersLost = properties_.RealTimeBuffersLost - buffersLostCount_;
|
||||
eventsLostCount_ = properties_.EventsLost;
|
||||
buffersLostCount_ = properties_.RealTimeBuffersLost;
|
||||
return *eventsLost + *buffersLost > 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
//*********************************************************
|
||||
//
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
// This code is licensed under the MIT License (MIT).
|
||||
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
|
||||
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
|
||||
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
|
||||
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
|
||||
//
|
||||
//*********************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <windows.h>
|
||||
#include <tchar.h>
|
||||
#include <evntcons.h> // must be after windows.h
|
||||
#include <stdint.h>
|
||||
#include <unordered_map>
|
||||
|
||||
typedef void (*EventHandlerFn)(EVENT_RECORD* pEventRecord, void* pContext);
|
||||
typedef bool (*ShouldStopProcessingEventsFn)();
|
||||
|
||||
struct TraceSession {
|
||||
// BEGIN trace property block, must be beginning of TraceSession
|
||||
EVENT_TRACE_PROPERTIES properties_;
|
||||
wchar_t loggerName_[MAX_PATH];
|
||||
// END Trace property block
|
||||
|
||||
TRACEHANDLE sessionHandle_; // Must be first member after trace property block
|
||||
TRACEHANDLE traceHandle_;
|
||||
ShouldStopProcessingEventsFn shouldStopProcessingEventsFn_;
|
||||
uint64_t startTime_;
|
||||
uint64_t frequency_;
|
||||
uint32_t eventsLostCount_;
|
||||
uint32_t buffersLostCount_;
|
||||
|
||||
// Structure to hold the mapping from provider ID to event handler function
|
||||
struct GUIDHash { size_t operator()(GUID const& g) const; };
|
||||
struct GUIDEqual { bool operator()(GUID const& lhs, GUID const& rhs) const; };
|
||||
struct Provider {
|
||||
ULONGLONG matchAny_;
|
||||
ULONGLONG matchAll_;
|
||||
UCHAR level_;
|
||||
};
|
||||
struct Handler {
|
||||
EventHandlerFn fn_;
|
||||
void* ctxt_;
|
||||
};
|
||||
std::unordered_map<GUID, Provider, GUIDHash, GUIDEqual> eventProvider_;
|
||||
std::unordered_map<GUID, Handler, GUIDHash, GUIDEqual> eventHandler_;
|
||||
|
||||
TraceSession()
|
||||
: sessionHandle_(0)
|
||||
, traceHandle_(INVALID_PROCESSTRACE_HANDLE)
|
||||
, startTime_(0)
|
||||
, frequency_(0)
|
||||
, shouldStopProcessingEventsFn_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
// Usage:
|
||||
//
|
||||
// 1) use TraceSession::AddProvider() to add the IDs for all the providers
|
||||
// you want to trace. Use TraceSession::AddHandler() to add the handler
|
||||
// functions for the providers/events you want to trace.
|
||||
//
|
||||
// 2) call TraceSession::InitializeRealtime() or
|
||||
// TraceSession::InitializeEtlFile(), to start tracing events from
|
||||
// real-time collection or from a previously-captured .etl file. At this
|
||||
// point, events start to be traced.
|
||||
//
|
||||
// 3) call ::ProcessTrace() to start collecting the events; provider
|
||||
// handler functions will be called as those provider events are collected.
|
||||
// ProcessTrace() will exit when shouldStopProcessingEventsFn_ returns
|
||||
// true, or when the .etl file is fully consumed.
|
||||
//
|
||||
// 4) Finalize() to clean up.
|
||||
|
||||
// AddProvider/Handler() returns false if the providerId already has a handler.
|
||||
// RemoveProvider/Handler() returns false if the providerId don't have a handler.
|
||||
bool AddProvider(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword);
|
||||
bool AddHandler(GUID handlerId, EventHandlerFn handlerFn, void* handlerContext);
|
||||
bool AddProviderAndHandler(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword,
|
||||
EventHandlerFn handlerFn, void* handlerContext);
|
||||
bool RemoveProvider(GUID providerId);
|
||||
bool RemoveHandler(GUID handlerId);
|
||||
bool RemoveProviderAndHandler(GUID providerId);
|
||||
|
||||
// InitializeRealtime() and InitializeEtlFile() return false if the session
|
||||
// could not be created.
|
||||
bool InitializeEtlFile(TCHAR const* etlPath, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn);
|
||||
bool InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn);
|
||||
void Finalize();
|
||||
|
||||
// Call CheckLostReports() at any time the session is initialized to query
|
||||
// how many events and buffers have been lost while tracing.
|
||||
bool CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost);
|
||||
};
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <iostream>
|
||||
#include "eparser.h"
|
||||
#include "TraceSession.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <tchar.h>
|
||||
#else
|
||||
#define TCHAR char
|
||||
#define _tmain main
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
#include "getopt.h"
|
||||
#else
|
||||
#include <getopt.h>
|
||||
#include <thread>
|
||||
#endif
|
||||
|
||||
static const GUID OrtProviderGuid = {0x54d81939, 0x62a0, 0x4dc0, {0xbf, 0x32, 0x3, 0x5e, 0xbd, 0xc7, 0xbc, 0xe9}};
|
||||
|
||||
int fetch_data(TCHAR* filename, ProfilingInfo& context) {
|
||||
TraceSession session;
|
||||
session.AddHandler(OrtProviderGuid, OrtEventHandler, &context);
|
||||
session.InitializeEtlFile(filename, nullptr);
|
||||
ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0);
|
||||
if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) {
|
||||
std::cout << "OpenTrace failed with " << status << std::endl;
|
||||
session.Finalize();
|
||||
return -1;
|
||||
}
|
||||
session.Finalize();
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::pair<double, double> CalcMeanAndStdSquare(const T* input, size_t input_len) {
|
||||
T sum = 0;
|
||||
T sum_square = 0;
|
||||
const size_t N = input_len;
|
||||
for (size_t i = 0; i != N; ++i) {
|
||||
T t = input[i];
|
||||
sum += t;
|
||||
sum_square += t * t;
|
||||
}
|
||||
double mean = ((double)sum) / N;
|
||||
double std = (sum_square - N * mean * mean) / (N - 1);
|
||||
return std::make_pair(mean, std);
|
||||
}
|
||||
// see: "Statistical Distributions", 4th Edition, by Catherine Forbes, Merran Evans, Nicholas Hastings and Brian
|
||||
// Peacock. Chapter 42: "Student’s t Distribution". I only implemented when v is even.
|
||||
double TDistributionCDF(int v, double x) {
|
||||
assert(v >= 2 && (v & 1) == 0);
|
||||
double t = x / (2 * std::sqrt(v + x * x));
|
||||
double sum1 = 0;
|
||||
double b_j = 1;
|
||||
for (int j = 0; j <= (v - 2) / 2; ++j) {
|
||||
sum1 += b_j / std::pow(1 + x * x / v, j);
|
||||
b_j *= static_cast<double>(2 * j + 1) / (2 * j + 2);
|
||||
}
|
||||
return 0.5 + t * sum1;
|
||||
}
|
||||
|
||||
struct TTestResult {
|
||||
double mean1, mean2;
|
||||
double std1, std2;
|
||||
double tvalue;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
TTestResult CalcTValue(const T* input1, size_t input1_len, const T* input2, size_t input2_len) {
|
||||
TTestResult result;
|
||||
auto p1 = CalcMeanAndStdSquare(input1, input1_len);
|
||||
result.mean1 = p1.first;
|
||||
result.std1 = std::sqrt(p1.second);
|
||||
auto p2 = CalcMeanAndStdSquare(input2, input2_len);
|
||||
result.mean2 = p2.first;
|
||||
result.std2 = std::sqrt(p2.second);
|
||||
auto diff_mean = p1.first - p2.first;
|
||||
size_t n1 = input1_len;
|
||||
size_t n2 = input2_len;
|
||||
auto sdiff = ((n1 - 1) * p1.second + (n2 - 1) * p2.second) / (n1 + n2 - 2);
|
||||
sdiff *= ((double)1) / n1 + ((double)1) / n2;
|
||||
result.tvalue = diff_mean / std::sqrt(sdiff);
|
||||
return result;
|
||||
}
|
||||
|
||||
int real_main(int argc, TCHAR* argv[]) {
|
||||
if (argc < 3) {
|
||||
printf("error\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
ProfilingInfo context1;
|
||||
int ret = fetch_data(argv[1], context1);
|
||||
if (ret != 0) return ret;
|
||||
ProfilingInfo context2;
|
||||
ret = fetch_data(argv[2], context2);
|
||||
if (ret != 0) return ret;
|
||||
size_t n1 = context1.time_per_run.size();
|
||||
size_t n2 = context2.time_per_run.size();
|
||||
if (n1 <= 10 || n2 <= 10) {
|
||||
printf("samples are too few, please try to gather more\n");
|
||||
return -1;
|
||||
}
|
||||
// ignore the first run
|
||||
--n1;
|
||||
--n2;
|
||||
if (((n1 + n2) & 1) != 0) {
|
||||
if (n1 > n2)
|
||||
n1--;
|
||||
else
|
||||
n2--;
|
||||
}
|
||||
TTestResult tresult = CalcTValue(context1.time_per_run.data() + 1, n1, context2.time_per_run.data() + 1, n2);
|
||||
size_t freedom = n1 + n2 - 2;
|
||||
double p = TDistributionCDF(static_cast<int>(freedom), std::abs(tresult.tvalue));
|
||||
std::cout << "Mean1: " << tresult.mean1 << " std1: " << tresult.std1 << "\n"
|
||||
<< "Mean2: " << tresult.mean2 << " std2: " << tresult.std2 << "\n"
|
||||
<< "H0: Mean1 = Mean2\n"
|
||||
<< "H1: Mean1 != Mean2\n"
|
||||
<< "Test statistic: T = " << tresult.tvalue << "\n"
|
||||
<< "Degrees of Freedom: v = " << freedom << "\n"
|
||||
<< "Significance level:" << (1 - p) * 2 << ". The lower the more likely to reject H0\n";
|
||||
if (p > 0.99995) {
|
||||
std::cout << "The two population means are different at the 0.0001 significance level." << std::endl;
|
||||
return -1;
|
||||
} else {
|
||||
std::cout << "They don't have significant statistical difference." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
int _tmain(int argc, TCHAR* argv[]) {
|
||||
int retval = -1;
|
||||
try {
|
||||
retval = real_main(argc, argv);
|
||||
} catch (std::exception& ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
retval = -1;
|
||||
}
|
||||
return retval;
|
||||
}
|
|
@ -0,0 +1,355 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "eparser.h"
|
||||
|
||||
// Get the metadata for the event.
|
||||
|
||||
// Get the length of the property data. For MOF-based events, the size is inferred from the data type
|
||||
// of the property. For manifest-based events, the property can specify the size of the property value
|
||||
// using the length attribute. The length attribue can specify the size directly or specify the name
|
||||
// of another property in the event data that contains the size. If the property does not include the
|
||||
// length attribute, the size is inferred from the data type. The length will be zero for variable
|
||||
// length, null-terminated strings and structures.
|
||||
|
||||
DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength);
|
||||
|
||||
// Get the size of the array. For MOF-based events, the size is specified in the declaration or using
|
||||
// the MAX qualifier. For manifest-based events, the property can specify the size of the array
|
||||
// using the count attribute. The count attribue can specify the size directly or specify the name
|
||||
// of another property in the event data that contains the size.
|
||||
|
||||
DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize);
|
||||
|
||||
// Both MOF-based events and manifest-based events can specify name/value maps. The
|
||||
// map values can be integer values or bit values. If the property specifies a value
|
||||
// map, get the map.
|
||||
|
||||
DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo);
|
||||
|
||||
// Print the property.
|
||||
template <typename T>
|
||||
PBYTE PrintProperties(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, DWORD PointerSize, USHORT i, PBYTE pUserData,
|
||||
PBYTE pEndOfUserData, const T& t) {
|
||||
TDHSTATUS status = ERROR_SUCCESS;
|
||||
USHORT PropertyLength = 0;
|
||||
DWORD FormattedDataSize = 0;
|
||||
USHORT UserDataConsumed = 0;
|
||||
LPWSTR pFormattedData = NULL;
|
||||
DWORD LastMember = 0; // Last member of a structure
|
||||
USHORT ArraySize = 0;
|
||||
PEVENT_MAP_INFO pMapInfo = NULL;
|
||||
|
||||
// Get the length of the property.
|
||||
|
||||
status = GetPropertyLength(pEvent, pInfo, i, &PropertyLength);
|
||||
if (ERROR_SUCCESS != status) {
|
||||
wprintf(L"GetPropertyLength failed.\n");
|
||||
pUserData = NULL;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Get the size of the array if the property is an array.
|
||||
|
||||
status = GetArraySize(pEvent, pInfo, i, &ArraySize);
|
||||
|
||||
for (USHORT k = 0; k < ArraySize; k++) {
|
||||
// If the property is a structure, print the members of the structure.
|
||||
|
||||
if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) {
|
||||
LastMember = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex +
|
||||
pInfo->EventPropertyInfoArray[i].structType.NumOfStructMembers;
|
||||
|
||||
for (USHORT j = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex; j < LastMember; j++) {
|
||||
pUserData = PrintProperties(pEvent, pInfo, PointerSize, j, pUserData, pEndOfUserData, t);
|
||||
if (NULL == pUserData) {
|
||||
wprintf(L"Printing the members of the structure failed.\n");
|
||||
pUserData = NULL;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Get the name/value mapping if the property specifies a value map.
|
||||
|
||||
status =
|
||||
GetMapInfo(pEvent, (PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].nonStructType.MapNameOffset),
|
||||
pInfo->DecodingSource, pMapInfo);
|
||||
|
||||
if (ERROR_SUCCESS != status) {
|
||||
wprintf(L"GetMapInfo failed\n");
|
||||
pUserData = NULL;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Get the size of the buffer required for the formatted data.
|
||||
|
||||
status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType,
|
||||
pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength,
|
||||
(USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData,
|
||||
&UserDataConsumed);
|
||||
|
||||
if (ERROR_INSUFFICIENT_BUFFER == status) {
|
||||
if (pFormattedData) {
|
||||
free(pFormattedData);
|
||||
pFormattedData = NULL;
|
||||
}
|
||||
|
||||
pFormattedData = (LPWSTR)malloc(FormattedDataSize);
|
||||
if (pFormattedData == NULL) {
|
||||
wprintf(L"Failed to allocate memory for formatted data (size=%lu).\n", FormattedDataSize);
|
||||
status = ERROR_OUTOFMEMORY;
|
||||
pUserData = NULL;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Retrieve the formatted data.
|
||||
|
||||
status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType,
|
||||
pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength,
|
||||
(USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData,
|
||||
&UserDataConsumed);
|
||||
}
|
||||
|
||||
if (ERROR_SUCCESS == status) {
|
||||
t((PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].NameOffset), pFormattedData);
|
||||
pUserData += UserDataConsumed;
|
||||
} else {
|
||||
wprintf(L"TdhFormatProperty failed with %lu.\n", status);
|
||||
pUserData = NULL;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
|
||||
if (pFormattedData) {
|
||||
free(pFormattedData);
|
||||
pFormattedData = NULL;
|
||||
}
|
||||
|
||||
if (pMapInfo) {
|
||||
free(pMapInfo);
|
||||
pMapInfo = NULL;
|
||||
}
|
||||
|
||||
return pUserData;
|
||||
}
|
||||
|
||||
DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength) {
|
||||
DWORD status = ERROR_SUCCESS;
|
||||
PROPERTY_DATA_DESCRIPTOR DataDescriptor;
|
||||
DWORD PropertySize = 0;
|
||||
|
||||
// If the property is a binary blob and is defined in a manifest, the property can
|
||||
// specify the blob's size or it can point to another property that defines the
|
||||
// blob's size. The PropertyParamLength flag tells you where the blob's size is defined.
|
||||
|
||||
if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamLength) == PropertyParamLength) {
|
||||
DWORD Length = 0; // Expects the length to be defined by a UINT16 or UINT32
|
||||
DWORD j = pInfo->EventPropertyInfoArray[i].lengthPropertyIndex;
|
||||
ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR));
|
||||
DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset);
|
||||
DataDescriptor.ArrayIndex = ULONG_MAX;
|
||||
status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize);
|
||||
status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Length);
|
||||
*PropertyLength = (USHORT)Length;
|
||||
} else {
|
||||
if (pInfo->EventPropertyInfoArray[i].length > 0) {
|
||||
*PropertyLength = pInfo->EventPropertyInfoArray[i].length;
|
||||
} else {
|
||||
// If the property is a binary blob and is defined in a MOF class, the extension
|
||||
// qualifier is used to determine the size of the blob. However, if the extension
|
||||
// is IPAddrV6, you must set the PropertyLength variable yourself because the
|
||||
// EVENT_PROPERTY_INFO.length field will be zero.
|
||||
|
||||
if (TDH_INTYPE_BINARY == pInfo->EventPropertyInfoArray[i].nonStructType.InType &&
|
||||
TDH_OUTTYPE_IPV6 == pInfo->EventPropertyInfoArray[i].nonStructType.OutType) {
|
||||
*PropertyLength = (USHORT)sizeof(IN6_ADDR);
|
||||
} else if (TDH_INTYPE_UNICODESTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType ||
|
||||
TDH_INTYPE_ANSISTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType ||
|
||||
(pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) {
|
||||
*PropertyLength = pInfo->EventPropertyInfoArray[i].length;
|
||||
} else {
|
||||
wprintf(L"Unexpected length of 0 for intype %d and outtype %d\n",
|
||||
pInfo->EventPropertyInfoArray[i].nonStructType.InType,
|
||||
pInfo->EventPropertyInfoArray[i].nonStructType.OutType);
|
||||
|
||||
status = ERROR_EVT_INVALID_EVENT_DATA;
|
||||
goto cleanup;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize) {
|
||||
DWORD status = ERROR_SUCCESS;
|
||||
PROPERTY_DATA_DESCRIPTOR DataDescriptor;
|
||||
DWORD PropertySize = 0;
|
||||
|
||||
if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamCount) == PropertyParamCount) {
|
||||
DWORD Count = 0; // Expects the count to be defined by a UINT16 or UINT32
|
||||
DWORD j = pInfo->EventPropertyInfoArray[i].countPropertyIndex;
|
||||
ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR));
|
||||
DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset);
|
||||
DataDescriptor.ArrayIndex = ULONG_MAX;
|
||||
status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize);
|
||||
status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Count);
|
||||
*ArraySize = (USHORT)Count;
|
||||
} else {
|
||||
*ArraySize = pInfo->EventPropertyInfoArray[i].count;
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo) {
|
||||
DWORD status = ERROR_SUCCESS;
|
||||
DWORD MapSize = 0;
|
||||
|
||||
// Retrieve the required buffer size for the map info.
|
||||
|
||||
status = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &MapSize);
|
||||
|
||||
if (ERROR_INSUFFICIENT_BUFFER == status) {
|
||||
pMapInfo = (PEVENT_MAP_INFO)malloc(MapSize);
|
||||
if (pMapInfo == NULL) {
|
||||
wprintf(L"Failed to allocate memory for map info (size=%lu).\n", MapSize);
|
||||
status = ERROR_OUTOFMEMORY;
|
||||
goto cleanup;
|
||||
}
|
||||
|
||||
// Retrieve the map info.
|
||||
|
||||
status = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &MapSize);
|
||||
}
|
||||
|
||||
if (ERROR_SUCCESS == status) {
|
||||
if (DecodingSourceXMLFile == DecodingSource) {
|
||||
abort();
|
||||
}
|
||||
} else {
|
||||
if (ERROR_NOT_FOUND == status) {
|
||||
status = ERROR_SUCCESS; // This case is okay.
|
||||
} else {
|
||||
wprintf(L"TdhGetEventMapInformation failed with 0x%x.\n", status);
|
||||
}
|
||||
}
|
||||
|
||||
cleanup:
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
LoggingEventRecord LoggingEventRecord::CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status) {
|
||||
LoggingEventRecord ret;
|
||||
ret.event_record_ = pEvent;
|
||||
status = ERROR_SUCCESS;
|
||||
DWORD BufferSize = 0;
|
||||
|
||||
// Retrieve the required buffer size for the event metadata.
|
||||
|
||||
status = TdhGetEventInformation(pEvent, 0, NULL, nullptr, &BufferSize);
|
||||
|
||||
if (ERROR_INSUFFICIENT_BUFFER != status) return ret;
|
||||
ret.buffer_.resize(BufferSize);
|
||||
// Retrieve the event metadata.
|
||||
status = TdhGetEventInformation(pEvent, 0, NULL, ret.GetEventInfo(), &BufferSize);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void OrtEventHandler(EVENT_RECORD* pEvent, void* pContext) {
|
||||
ProfilingInfo& info = *(ProfilingInfo*)pContext;
|
||||
DWORD status = ERROR_SUCCESS;
|
||||
|
||||
LoggingEventRecord record = LoggingEventRecord::CreateLoggingEventRecord(pEvent, status);
|
||||
if (ERROR_SUCCESS != status) {
|
||||
if (status == ERROR_NOT_FOUND) return;
|
||||
wprintf(L"GetEventInformation failed with %lu\n", status);
|
||||
abort();
|
||||
}
|
||||
DWORD PointerSize = 0;
|
||||
if (EVENT_HEADER_FLAG_32_BIT_HEADER == (pEvent->EventHeader.Flags & EVENT_HEADER_FLAG_32_BIT_HEADER)) {
|
||||
PointerSize = 4;
|
||||
} else {
|
||||
PointerSize = 8;
|
||||
}
|
||||
|
||||
PTRACE_EVENT_INFO pInfo = record.GetEventInfo();
|
||||
const wchar_t* name = record.GetTaskName();
|
||||
if (wcscmp(name, L"OpEnd") == 0) {
|
||||
if (!info.session_started || info.session_ended) return;
|
||||
PBYTE pUserData = (PBYTE)pEvent->UserData;
|
||||
PBYTE pEndOfUserData = (PBYTE)pEvent->UserData + pEvent->UserDataLength;
|
||||
|
||||
// Print the event data for all the top-level properties. Metadata for all the
|
||||
// top-level properties come before structure member properties in the
|
||||
// property information array.
|
||||
std::wstring opname;
|
||||
long time_spent_in_this_op = 0;
|
||||
for (USHORT i = 0; i < pInfo->TopLevelPropertyCount; i++) {
|
||||
pUserData = PrintProperties(pEvent, pInfo, PointerSize, i, pUserData, pEndOfUserData,
|
||||
[&opname, &time_spent_in_this_op](const wchar_t* key, wchar_t* value) {
|
||||
if (wcscmp(key, L"op_name") == 0) {
|
||||
opname = value;
|
||||
} else if (wcscmp(key, L"time") == 0) {
|
||||
time_spent_in_this_op = wcstol(value, nullptr, 10);
|
||||
} else {
|
||||
wprintf(key);
|
||||
abort();
|
||||
}
|
||||
});
|
||||
if (NULL == pUserData) {
|
||||
wprintf(L"Printing top level properties failed.\n");
|
||||
abort();
|
||||
}
|
||||
}
|
||||
auto iter = info.op_stat.find(opname);
|
||||
if (iter == info.op_stat.end()) {
|
||||
OpStat s;
|
||||
s.name = opname;
|
||||
s.count = 1;
|
||||
s.total_time = time_spent_in_this_op;
|
||||
info.op_stat[opname] = s;
|
||||
} else {
|
||||
OpStat& s = iter->second;
|
||||
++s.count;
|
||||
s.total_time += time_spent_in_this_op;
|
||||
}
|
||||
} else if (wcscmp(name, L"OrtRun") == 0) {
|
||||
if (!info.session_started || info.session_ended) return;
|
||||
if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) {
|
||||
info.op_start_time = pEvent->EventHeader.TimeStamp;
|
||||
++info.ortrun_count;
|
||||
} else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) {
|
||||
if (pEvent->EventHeader.TimeStamp.QuadPart < info.op_start_time.QuadPart) {
|
||||
throw std::runtime_error("time error");
|
||||
}
|
||||
info.time_per_run.push_back(pEvent->EventHeader.TimeStamp.QuadPart - info.op_start_time.QuadPart);
|
||||
++info.ortrun_end_count;
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
else if (wcscmp(name, L"OrtInferenceSessionActivity") == 0) {
|
||||
if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) {
|
||||
info.session_started = true;
|
||||
} else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) {
|
||||
info.session_ended = true;
|
||||
} else {
|
||||
abort();
|
||||
}
|
||||
|
||||
printf("OrtInferenceSessionActivity\n");
|
||||
} else if (wcscmp(name, L"NodeNameMapping") == 0) {
|
||||
// ignore
|
||||
} else {
|
||||
wprintf(L"unknown event:%s\n", name);
|
||||
abort();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <windows.h>
|
||||
#include <stdio.h>
|
||||
#include <wbemidl.h>
|
||||
#include <wmistr.h>
|
||||
#include <evntrace.h>
|
||||
#include <assert.h>
|
||||
#include <tdh.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <in6addr.h>
|
||||
#include <unordered_map>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
|
||||
void OrtEventHandler(EVENT_RECORD* pEventRecord, void* pContext);
|
||||
|
||||
class LoggingEventRecord {
|
||||
private:
|
||||
std::vector<char> buffer_;
|
||||
EVENT_RECORD* event_record_;
|
||||
|
||||
public:
|
||||
const TRACE_EVENT_INFO* GetEventInfo() const { return (const TRACE_EVENT_INFO*)buffer_.data(); }
|
||||
TRACE_EVENT_INFO* GetEventInfo() { return (TRACE_EVENT_INFO*)buffer_.data(); }
|
||||
|
||||
const wchar_t* GetTaskName() const {
|
||||
const TRACE_EVENT_INFO* p = GetEventInfo();
|
||||
return (const wchar_t*)(buffer_.data() + p->TaskNameOffset);
|
||||
}
|
||||
|
||||
static LoggingEventRecord CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status);
|
||||
};
|
||||
|
||||
struct OpStat {
|
||||
std::wstring name;
|
||||
size_t count = 0;
|
||||
uint64_t total_time = 0;
|
||||
};
|
||||
|
||||
struct ProfilingInfo {
|
||||
int ortrun_count = 0;
|
||||
int ortrun_end_count = 0;
|
||||
int session_count = 0;
|
||||
bool session_started = false;
|
||||
bool session_ended = false;
|
||||
LARGE_INTEGER op_start_time;
|
||||
|
||||
std::unordered_map<std::wstring, OpStat> op_stat;
|
||||
std::vector<ULONG64> time_per_run;
|
||||
};
|
||||
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "eparser.h"
|
||||
#include "TraceSession.h"
|
||||
|
||||
// Turns the DEFINE_GUID for EventTraceGuid into a const.
|
||||
#define INITGUID
|
||||
|
||||
static const GUID OrtProviderGuid = {0x3a26b1ff, 0x7484, 0x7484, {0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d}};
|
||||
|
||||
int real_main(int argc, TCHAR* argv[]) {
|
||||
ProfilingInfo context;
|
||||
TraceSession session;
|
||||
session.AddHandler(OrtProviderGuid, OrtEventHandler, &context);
|
||||
session.InitializeEtlFile(argv[1], nullptr);
|
||||
ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0);
|
||||
if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) {
|
||||
std::cout << "OpenTrace failed with " << status << std::endl;
|
||||
session.Finalize();
|
||||
return -1;
|
||||
}
|
||||
session.Finalize();
|
||||
|
||||
assert(context.ortrun_count == context.ortrun_end_count);
|
||||
std::vector<OpStat*> stat_array(context.op_stat.size());
|
||||
size_t i = 0;
|
||||
for (auto& p : context.op_stat) {
|
||||
stat_array[i++] = &p.second;
|
||||
}
|
||||
std::sort(stat_array.begin(), stat_array.end(),
|
||||
[](const OpStat* left, const OpStat* right) { return left->total_time > right->total_time; });
|
||||
size_t iterations = context.time_per_run.size();
|
||||
ULONG64 total_time = std::accumulate(context.time_per_run.begin() + 1, context.time_per_run.end(), (ULONG64)0);
|
||||
// in microseconds
|
||||
ULONG64 avg_time = total_time / (context.time_per_run.size() - 1) / 10;
|
||||
double sum = 0;
|
||||
for (OpStat* p : stat_array) {
|
||||
if (p->name == L"Scan") {
|
||||
continue;
|
||||
}
|
||||
uint64_t avg_time_per_op = p->total_time / iterations;
|
||||
if (avg_time_per_op >= 0) {
|
||||
double t = avg_time_per_op * 100.0 / avg_time;
|
||||
std::wcout << p->name << L" " << p->total_time / p->count << L" " << std::fixed << std::setprecision(1) << t
|
||||
<< L"%\n";
|
||||
}
|
||||
sum += p->total_time / (double)iterations;
|
||||
}
|
||||
std::wcout << L"total " << std::fixed << std::setprecision(1) << (sum * 100.0) / avg_time << L"%\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
int _tmain(int argc, TCHAR* argv[]) {
|
||||
int retval = -1;
|
||||
try {
|
||||
retval = real_main(argc, argv);
|
||||
} catch (std::exception& ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
retval = -1;
|
||||
}
|
||||
return retval;
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- TODO:
|
||||
1. Find and replace "OrtTraceLoggingProvider" with your component name.
|
||||
2. See TODO below to update GUID for your event provider
|
||||
-->
|
||||
<WindowsPerformanceRecorder Version="1.0" Author="Microsoft Corporation"
|
||||
Copyright="Microsoft Corporation" Company="Microsoft Corporation">
|
||||
<Profiles>
|
||||
<EventCollector Id="EventCollector_OrtTraceLoggingProvider"
|
||||
Name="OrtTraceLoggingProviderCollector">
|
||||
<BufferSize Value="65536" />
|
||||
<Buffers Value="10" PercentageOfTotalMemory="true"/>
|
||||
</EventCollector>
|
||||
|
||||
<EventProvider Id="EventProvider_OrtTraceLoggingProvider"
|
||||
Name="3a26b1ff-7484-7484-7484-15261f42614d" />
|
||||
<Profile Id="OrtTraceLoggingProvider.Verbose.File"
|
||||
Name="OrtTraceLoggingProvider" Description="OrtTraceLoggingProvider"
|
||||
LoggingMode="File" DetailLevel="Verbose">
|
||||
<Collectors>
|
||||
<EventCollectorId Value="EventCollector_OrtTraceLoggingProvider">
|
||||
<EventProviders>
|
||||
<EventProviderId Value="EventProvider_OrtTraceLoggingProvider" />
|
||||
</EventProviders>
|
||||
</EventCollectorId>
|
||||
</Collectors>
|
||||
</Profile>
|
||||
|
||||
<Profile Id="OrtTraceLoggingProvider.Light.File"
|
||||
Name="OrtTraceLoggingProvider"
|
||||
Description="OrtTraceLoggingProvider"
|
||||
Base="OrtTraceLoggingProvider.Verbose.File"
|
||||
LoggingMode="File"
|
||||
DetailLevel="Light" />
|
||||
|
||||
<Profile Id="OrtTraceLoggingProvider.Verbose.Memory"
|
||||
Name="OrtTraceLoggingProvider"
|
||||
Description="OrtTraceLoggingProvider"
|
||||
Base="OrtTraceLoggingProvider.Verbose.File"
|
||||
LoggingMode="Memory"
|
||||
DetailLevel="Verbose" />
|
||||
|
||||
<Profile Id="OrtTraceLoggingProvider.Light.Memory"
|
||||
Name="OrtTraceLoggingProvider"
|
||||
Description="OrtTraceLoggingProvider"
|
||||
Base="OrtTraceLoggingProvider.Verbose.File"
|
||||
LoggingMode="Memory"
|
||||
DetailLevel="Light" />
|
||||
|
||||
</Profiles>
|
||||
</WindowsPerformanceRecorder>
|
Загрузка…
Ссылка в новой задаче