Add Tracelogging for profiling (#1639)

Enabled only if onnxruntime_ENABLE_INSTRUMENT is ON
This commit is contained in:
Changming Sun 2019-11-11 21:34:10 -08:00 коммит произвёл GitHub
Родитель 0c6e9f94d0
Коммит fc6773a65b
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
18 изменённых файлов: 1228 добавлений и 72 удалений

Просмотреть файл

@ -83,6 +83,7 @@ option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in l
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
option(onnxruntime_USE_DML "Build with DirectML support" OFF)
option(onnxruntime_USE_ACL "Build with ACL support" OFF)
# Opt-in switch for ETW-based instrumentation. It defaults to OFF and is reset
# to OFF (with a warning) by the NOT WIN32 check that follows the option list.
option(onnxruntime_ENABLE_INSTRUMENT "Enable Instrument with Event Tracing for Windows (ETW)" OFF)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE)
#nsync tests failed on Mac Build
@ -91,6 +92,15 @@ set(ONNX_ML 1)
if(NOT onnxruntime_ENABLE_PYTHON)
set(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS OFF)
endif()
# Instrumentation is only implemented for Windows, so a request for it on any
# other platform is downgraded to OFF with a configure-time warning.
# TODO: On Linux we may try https://github.com/microsoft/TraceLogging
if(onnxruntime_ENABLE_INSTRUMENT AND NOT WIN32)
  message(WARNING "Instrument is only supported on Windows now")
  set(onnxruntime_ENABLE_INSTRUMENT OFF)
endif()
if(onnxruntime_USE_OPENMP)
find_package(OpenMP)
if (OPENMP_FOUND)

Просмотреть файл

@ -10,7 +10,9 @@ file(GLOB_RECURSE onnxruntime_framework_srcs CONFIGURE_DEPENDS
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_framework_srcs})
add_library(onnxruntime_framework ${onnxruntime_framework_srcs})
# ONNXRUNTIME_ENABLE_INSTRUMENT adds data members to SessionState in
# session_state.h, a header that is compiled into other targets as well.
# The definition therefore has to be PUBLIC (as it already is for
# onnxruntime_session): a PRIVATE definition would let consumers of the header
# see a different class definition than the framework library itself.
if(onnxruntime_ENABLE_INSTRUMENT)
  target_compile_definitions(onnxruntime_framework PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
onnxruntime_add_include_to_target(onnxruntime_framework onnxruntime_common onnx onnx_proto protobuf::libprotobuf)
set_target_properties(onnxruntime_framework PROPERTIES FOLDER "ONNXRuntime")

Просмотреть файл

@ -12,6 +12,9 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_session_srcs})
add_library(onnxruntime_session ${onnxruntime_session_srcs})
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/session DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
onnxruntime_add_include_to_target(onnxruntime_session onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf)
# Define ONNXRUNTIME_ENABLE_INSTRUMENT for the session library and for its
# consumers, but only when the instrumentation option is switched on.
target_compile_definitions(onnxruntime_session
  PUBLIC
    $<$<BOOL:${onnxruntime_ENABLE_INSTRUMENT}>:ONNXRUNTIME_ENABLE_INSTRUMENT>
)
target_include_directories(onnxruntime_session PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS})
add_dependencies(onnxruntime_session ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")

Просмотреть файл

@ -776,6 +776,17 @@ if (onnxruntime_BUILD_SERVER)
endif()
#some ETW tools
# Two console tools built from tool/etw. Both share the eparser and
# TraceSession sources and link tdh and Advapi32; compare_two_sessions
# additionally links ${GETOPT_LIB_WIDE}.
if(onnxruntime_ENABLE_INSTRUMENT AND WIN32)
  add_executable(generate_perf_report_from_etl
    ${ONNXRUNTIME_ROOT}/tool/etw/main.cc
    ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h
    ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc
    ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h
    ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc
  )
  target_compile_definitions(generate_perf_report_from_etl
    PRIVATE
      _CONSOLE
      _UNICODE
      UNICODE
  )
  target_link_libraries(generate_perf_report_from_etl
    PRIVATE
      tdh
      Advapi32
  )

  add_executable(compare_two_sessions
    ${ONNXRUNTIME_ROOT}/tool/etw/compare_two_sessions.cc
    ${ONNXRUNTIME_ROOT}/tool/etw/eparser.h
    ${ONNXRUNTIME_ROOT}/tool/etw/eparser.cc
    ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.h
    ${ONNXRUNTIME_ROOT}/tool/etw/TraceSession.cc
  )
  target_compile_definitions(compare_two_sessions
    PRIVATE
      _CONSOLE
      _UNICODE
      UNICODE
  )
  target_link_libraries(compare_two_sessions
    PRIVATE
      ${GETOPT_LIB_WIDE}
      tdh
      Advapi32
  )
endif()
add_executable(onnxruntime_mlas_test ${TEST_SRC_DIR}/mlas/unittest.cpp)
target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT})
set(onnxruntime_mlas_test_libs onnxruntime_mlas onnxruntime_common)

Просмотреть файл

@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
// windows.h is included first because TraceLoggingProvider.h relies on its definitions.
#include <windows.h>
#include <TraceLoggingProvider.h>
// Forward declaration of the TraceLogging provider handle used by the
// instrumentation code. This header only declares the handle; the matching
// definition is not part of this header.
TRACELOGGING_DECLARE_PROVIDER(telemetry_provider_handle);

Просмотреть файл

@ -25,6 +25,22 @@
using namespace Concurrency;
#endif
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
#include <Windows.h>
#include "core/platform/tracing.h"
namespace {
// Queries the frequency of the performance counter and returns it by value.
LARGE_INTEGER OrtGetPerformanceFrequency() {
  LARGE_INTEGER frequency;
  // QueryPerformanceFrequency always succeeds on Windows XP or later and never
  // reports zero there, so its return value is deliberately discarded.
  static_cast<void>(QueryPerformanceFrequency(&frequency));
  return frequency;
}

// Counter frequency, fetched once when this translation unit is initialized;
// used to convert performance-counter deltas into microseconds.
LARGE_INTEGER perf_freq = OrtGetPerformanceFrequency();
}  // namespace
#endif
namespace onnxruntime {
static Status ReleaseNodeMLValues(ExecutionFrame& frame,
@ -87,7 +103,10 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
if (p_op_kernel == nullptr)
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Got nullptr from GetKernel for node: ",
node.Name());
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
LARGE_INTEGER kernel_start;
QueryPerformanceCounter(&kernel_start);
#endif
// construct OpKernelContext
// TODO: log kernel inputs?
OpKernelContextInternal op_kernel_context(session_state, frame, *p_op_kernel, logger, terminate_flag_);
@ -128,7 +147,6 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
}
}
}
#if defined DEBUG_NODE_INPUTS_OUTPUTS
utils::DumpNodeInputs(op_kernel_context, p_op_kernel->Node());
#endif
@ -202,7 +220,19 @@ Status SequentialExecutor::Execute(const SessionState& session_state, const std:
}
}
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
LARGE_INTEGER kernel_stop;
QueryPerformanceCounter(&kernel_stop);
LARGE_INTEGER elapsed;
elapsed.QuadPart = kernel_stop.QuadPart - kernel_start.QuadPart;
elapsed.QuadPart *= 1000000;
elapsed.QuadPart /= perf_freq.QuadPart;
// Log an event
TraceLoggingWrite(telemetry_provider_handle, // handle to my provider
"OpEnd", // Event Name that should uniquely identify your event.
TraceLoggingValue(p_op_kernel->KernelDef().OpName().c_str(), "op_name"),
TraceLoggingValue(elapsed.QuadPart, "time"));
#endif
if (is_profiler_enabled) {
session_state.Profiler().EndTimeAndRecordEvent(profiling::NODE_EVENT,
p_op_kernel->Node().Name() + "_fence_after",

Просмотреть файл

@ -271,7 +271,10 @@ void SessionState::AddSubgraphSessionState(onnxruntime::NodeIndex index, const s
ORT_ENFORCE(existing_entries.find(attribute_name) == existing_entries.cend(), "Entry exists in node ", index,
" for attribute ", attribute_name);
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
session_state->parent_ = this;
GenerateGraphId();
#endif
subgraph_session_states_[index].insert(std::make_pair(attribute_name, std::move(session_state)));
}

Просмотреть файл

@ -268,6 +268,19 @@ class SessionState {
std::unique_ptr<NodeIndexInfo> node_index_info_;
std::multimap<int, std::unique_ptr<FeedsFetchesManager>> cached_feeds_fetches_managers_;
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
SessionState* parent_ = nullptr;
// Assign each graph in each session a unique id.
int graph_id_ = 0;
int next_graph_id_ = 1;
void GenerateGraphId() {
SessionState* p = this;
while (p->parent_ != nullptr) p = p->parent_;
graph_id_ = p->next_graph_id_ ++;
}
#endif
};
} // namespace onnxruntime

Просмотреть файл

@ -19,6 +19,10 @@
#include "core/platform/env.h"
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
#include "core/platform/tracing.h"
#endif
namespace onnxruntime {
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;

Просмотреть файл

@ -73,7 +73,7 @@ inline const wchar_t* GetDateFormatString<wchar_t>() {
return L"%Y-%m-%d_%H-%M-%S";
}
#endif
//TODO: use LoggingManager::GetTimestamp and date::operator<<
// TODO: use LoggingManager::GetTimestamp and date::operator<<
// (see ostream_sink.cc for an example)
// to simplify this and match the log file timestamp format.
template <typename T>
@ -115,7 +115,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
insert_cast_transformer_("CastFloat16Transformer") {
ORT_ENFORCE(Environment::IsInitialized(),
"Environment must be initialized before creating an InferenceSession.");
InitLogger(logging_manager);
session_state_.SetDataTransferMgr(&data_transfer_mgr_);
@ -144,6 +143,9 @@ InferenceSession::~InferenceSession() {
LOGS(*session_logger_, ERROR) << "Unknown error during EndProfiling()";
}
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
#endif
}
common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExecutionProvider> p_exec_provider) {
@ -176,8 +178,8 @@ common::Status InferenceSession::RegisterExecutionProvider(std::unique_ptr<IExec
return Status::OK();
}
common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr<onnxruntime::GraphTransformer> p_graph_transformer,
TransformerLevel level) {
common::Status InferenceSession::RegisterGraphTransformer(
std::unique_ptr<onnxruntime::GraphTransformer> p_graph_transformer, TransformerLevel level) {
if (p_graph_transformer == nullptr) {
return Status(common::ONNXRUNTIME, common::FAIL, "Received nullptr for graph transformer");
}
@ -185,8 +187,7 @@ common::Status InferenceSession::RegisterGraphTransformer(std::unique_ptr<onnxru
}
common::Status InferenceSession::AddCustomTransformerList(const std::vector<std::string>& transformers_to_enable) {
std::copy(transformers_to_enable.begin(), transformers_to_enable.end(),
std::back_inserter(transformers_to_enable_));
std::copy(transformers_to_enable.begin(), transformers_to_enable.end(), std::back_inserter(transformers_to_enable_));
return Status::OK();
}
@ -213,7 +214,8 @@ common::Status InferenceSession::RegisterCustomRegistry(std::shared_ptr<CustomRe
return Status::OK();
}
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader, const std::string& event_name) {
common::Status InferenceSession::Load(std::function<common::Status(std::shared_ptr<Model>&)> loader,
const std::string& event_name) {
Status status = Status::OK();
TimePoint tp;
if (session_profiler_.IsEnabled()) {
@ -223,8 +225,7 @@ common::Status InferenceSession::Load(std::function<common::Status(std::shared_p
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (is_model_loaded_) { // already loaded
LOGS(*session_logger_, ERROR) << "This session already contains a loaded model.";
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED,
"This session already contains a loaded model.");
return common::Status(common::ONNXRUNTIME, common::MODEL_LOADED, "This session already contains a loaded model.");
}
std::shared_ptr<onnxruntime::Model> p_tmp_model;
@ -281,14 +282,10 @@ common::Status InferenceSession::Load(const std::basic_string<T>& model_uri) {
return Status::OK();
}
common::Status InferenceSession::Load(const std::string& model_uri) {
return Load<char>(model_uri);
}
common::Status InferenceSession::Load(const std::string& model_uri) { return Load<char>(model_uri); }
#ifdef _WIN32
common::Status InferenceSession::Load(const std::wstring& model_uri) {
return Load<PATH_CHAR_TYPE>(model_uri);
}
common::Status InferenceSession::Load(const std::wstring& model_uri) { return Load<PATH_CHAR_TYPE>(model_uri); }
#endif
common::Status InferenceSession::Load(const ModelProto& model_proto) {
@ -522,7 +519,6 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
const auto implicit_inputs = node.ImplicitInputDefs();
ORT_RETURN_IF_ERROR_SESSIONID_(initializer.CreatePlan(&node, &implicit_inputs,
session_options_.execution_mode));
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
// &*subgraph_info.session_state);
@ -554,12 +550,14 @@ common::Status InferenceSession::Initialize() {
LOGS(*session_logger_, ERROR) << "Model was not loaded";
return common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded.");
}
if (is_inited_) { // already initialized
LOGS(*session_logger_, INFO) << "Session has already been initialized.";
return common::Status::OK();
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
TraceLoggingWriteStart(session_activity, "OrtInferenceSessionActivity");
session_activity_started_ = true;
#endif
// Register default CPUExecutionProvider if user didn't provide it through the Register() calls
if (!execution_providers_.Get(onnxruntime::kCpuExecutionProvider)) {
LOGS(*session_logger_, INFO) << "Adding default CPU execution provider.";
@ -578,7 +576,8 @@ common::Status InferenceSession::Initialize() {
}
// add predefined transformers
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level, transformers_to_enable_);
AddPredefinedTransformers(graph_transformation_mgr_, session_options_.graph_optimization_level,
transformers_to_enable_);
onnxruntime::Graph& graph = model_->MainGraph();
@ -624,7 +623,6 @@ common::Status InferenceSession::Initialize() {
// handle any subgraphs
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, session_state_));
is_inited_ = true;
LOGS(*session_logger_, INFO) << "Session successfully initialized.";
} catch (const NotImplementedException& ex) {
status = ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Exception during initialization: ", ex.what());
@ -643,9 +641,7 @@ common::Status InferenceSession::Initialize() {
return status;
}
int InferenceSession::GetCurrentNumRuns() const {
return current_num_runs_.load();
}
int InferenceSession::GetCurrentNumRuns() const { return current_num_runs_.load(); }
const std::vector<std::string>& InferenceSession::GetRegisteredProviderTypes() const {
return execution_providers_.GetIds();
@ -662,8 +658,7 @@ common::Status InferenceSession::CheckShapes(const std::string& input_name,
auto expected_shape_sz = expected_shape.NumDimensions();
if (input_shape_sz != expected_shape_sz) {
std::ostringstream ostr;
ostr << "Invalid rank for input: " << input_name
<< " Got: " << input_shape_sz << " Expected: " << expected_shape_sz
ostr << "Invalid rank for input: " << input_name << " Got: " << input_shape_sz << " Expected: " << expected_shape_sz
<< " Please fix either the inputs or the model.";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, ostr.str());
}
@ -705,10 +700,8 @@ static common::Status CheckTypes(MLDataType actual, MLDataType expected) {
common::Status InferenceSession::ValidateInputs(const std::vector<std::string>& feed_names,
const std::vector<OrtValue>& feeds) const {
if (feed_names.size() != feeds.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size mismatch: feed_names has ",
feed_names.size(), "elements, but feeds has ",
feeds.size(), " elements.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Size mismatch: feed_names has ", feed_names.size(),
"elements, but feeds has ", feeds.size(), " elements.");
}
for (size_t i = 0; i < feeds.size(); ++i) {
@ -716,8 +709,7 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
auto iter = input_def_map_.find(feed_name);
if (input_def_map_.end() == iter) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid Feed Input Name:", feed_name);
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid Feed Input Name:", feed_name);
}
auto expected_type = iter->second.ml_data_type;
@ -725,8 +717,8 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
if (input_ml_value.IsTensor()) {
// check for type
if (!expected_type->IsTensorType()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ",
feed_name, " is not expected to be of type tensor.");
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input with name: ", feed_name,
" is not expected to be of type tensor.");
}
auto expected_element_type = expected_type->AsTensorType()->GetElementType();
@ -751,17 +743,14 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>& output_names,
const std::vector<OrtValue>* p_fetches) const {
if (p_fetches == nullptr) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Output vector pointer is NULL");
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Output vector pointer is NULL");
}
if (output_names.empty()) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"At least one output should be requested.");
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "At least one output should be requested.");
}
if (!p_fetches->empty() &&
(output_names.size() != p_fetches->size())) {
if (!p_fetches->empty() && (output_names.size() != p_fetches->size())) {
std::ostringstream ostr;
ostr << "Output vector incorrectly sized: output_names.size(): " << output_names.size()
<< "p_fetches->size(): " << p_fetches->size();
@ -770,8 +759,7 @@ common::Status InferenceSession::ValidateOutputs(const std::vector<std::string>&
for (const auto& name : output_names) {
if (model_output_names_.find(name) == model_output_names_.end()) {
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Invalid Output Name:" + name);
return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid Output Name:" + name);
}
}
@ -787,6 +775,12 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
if (session_profiler_.IsEnabled()) {
tp = session_profiler_.StartTime();
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
TraceLoggingActivity<telemetry_provider_handle> ortrun_activity;
ortrun_activity.SetRelatedActivity(session_activity);
TraceLoggingWriteStart(ortrun_activity, "OrtRun");
#endif
Status retval = Status::OK();
try {
@ -858,7 +852,9 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
if (session_profiler_.IsEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
}
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
TraceLoggingWriteStop(ortrun_activity, "OrtRun");
#endif
return retval;
}
@ -889,8 +885,7 @@ std::pair<common::Status, const ModelMetadata*> InferenceSession::GetModelMetada
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (!is_model_loaded_) {
LOGS(*session_logger_, ERROR) << "Model was not loaded";
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
nullptr);
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
}
}
@ -902,8 +897,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetModelInputs(
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (!is_model_loaded_) {
LOGS(*session_logger_, ERROR) << "Model was not loaded";
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
nullptr);
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
}
}
@ -916,8 +910,7 @@ std::pair<common::Status, const InputDefList*> InferenceSession::GetOverridableI
std::lock_guard<onnxruntime::OrtMutex> l(session_mutex_);
if (!is_model_loaded_) {
LOGS(*session_logger_, ERROR) << "Model was not loaded";
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."),
nullptr);
return std::make_pair(common::Status(common::ONNXRUNTIME, common::FAIL, "Model was not loaded."), nullptr);
}
}
@ -971,14 +964,10 @@ void InferenceSession::StartProfiling(const std::basic_string<T>& file_prefix) {
session_profiler_.StartProfiling(ss.str());
}
void InferenceSession::StartProfiling(const std::string& file_prefix) {
StartProfiling<char>(file_prefix);
}
void InferenceSession::StartProfiling(const std::string& file_prefix) { StartProfiling<char>(file_prefix); }
#ifdef _WIN32
void InferenceSession::StartProfiling(const std::wstring& file_prefix) {
StartProfiling<PATH_CHAR_TYPE>(file_prefix);
}
void InferenceSession::StartProfiling(const std::wstring& file_prefix) { StartProfiling<PATH_CHAR_TYPE>(file_prefix); }
#endif
void InferenceSession::StartProfiling(const logging::Logger* logger_ptr) {
@ -1026,11 +1015,11 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
for (auto elem : inputs) {
auto elem_type = utils::GetMLDataType(*elem);
auto elem_shape_proto = elem->Shape();
input_def_map_.insert({elem->Name(), InputDefMetaData(elem,
elem_type,
elem_shape_proto
? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto)
: TensorShape())});
input_def_map_.insert(
{elem->Name(),
InputDefMetaData(
elem, elem_type,
elem_shape_proto ? utils::GetTensorShapeFromTensorShapeProto(*elem_shape_proto) : TensorShape())});
}
};
@ -1086,10 +1075,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru
severity = static_cast<logging::Severity>(run_options.run_log_severity_level);
}
new_run_logger = logging_manager_->CreateLogger(run_log_id,
severity,
false,
run_options.run_log_verbosity_level);
new_run_logger = logging_manager_->CreateLogger(run_log_id, severity, false, run_options.run_log_verbosity_level);
run_logger = new_run_logger.get();
VLOGS(*run_logger, 1) << "Created logger for run with id of " << run_log_id;
@ -1116,9 +1102,7 @@ void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) {
severity = static_cast<logging::Severity>(session_options_.session_log_severity_level);
}
owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid,
severity,
false,
owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false,
session_options_.session_log_verbosity_level);
session_logger_ = owned_session_logger_.get();
} else {

Просмотреть файл

@ -24,6 +24,10 @@
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
#include "core/language_interop_ops/language_interop_ops.h"
#endif
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
#include "core/platform/tracing.h"
#include <TraceLoggingActivity.h>
#endif
namespace onnxruntime { // forward declarations
class GraphTransformer;
@ -434,7 +438,6 @@ class InferenceSession {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
InterOpDomains interop_domains_;
#endif
// used to support platform telemetry
static std::atomic<uint32_t> global_session_id_; // a monotonically increasing session id
uint32_t session_id_; // the current session's id
@ -442,5 +445,10 @@ class InferenceSession {
long long total_run_duration_since_last_; // the total duration (us) of Run() calls since the last report
TimePoint time_sent_last_; // the TimePoint of the last report
const long long kDurationBetweenSending = 1000* 1000 * 60 * 10; // duration in (us). send a report every 10 mins
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
bool session_activity_started_ = false;
TraceLoggingActivity<telemetry_provider_handle> session_activity;
#endif
};
} // namespace onnxruntime

Просмотреть файл

@ -0,0 +1,306 @@
//*********************************************************
//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//
//*********************************************************
#include "TraceSession.h"
namespace {
// Per-event callback registered with ETW through EVENT_TRACE_LOGFILE.
// The TraceSession is recovered from the record's UserContext; the event is
// then forwarded to the handler registered for its provider id, if any.
VOID WINAPI EventRecordCallback(EVENT_RECORD* pEventRecord)
{
    auto* const traceSession = static_cast<TraceSession*>(pEventRecord->UserContext);
    auto const& header = pEventRecord->EventHeader;

    // While no start time has been recorded yet (still zero), take it from this event.
    if (traceSession->startTime_ == 0) {
        traceSession->startTime_ = header.TimeStamp.QuadPart;
    }

    // Events from providers without a registered handler are dropped.
    auto const handlerIt = traceSession->eventHandler_.find(header.ProviderId);
    if (handlerIt == traceSession->eventHandler_.end()) {
        return;
    }

    auto const& handler = handlerIt->second;
    handler.fn_(pEventRecord, handler.ctxt_);
}
// Per-buffer callback registered with ETW through EVENT_TRACE_LOGFILE.
// Returns FALSE to break out of ProcessTrace() when the session's stop
// predicate is set and reports true; returns TRUE otherwise so that event
// processing continues.
ULONG WINAPI BufferCallback(EVENT_TRACE_LOGFILE* pLogFile)
{
    auto* const traceSession = static_cast<TraceSession*>(pLogFile->Context);
    auto const stopRequested = traceSession->shouldStopProcessingEventsFn_;

    bool const keepProcessing = (stopRequested == nullptr) || !stopRequested();
    return keepProcessing ? TRUE : FALSE;
}
// Opens a trace for consumption and stores the resulting handle in
// session->traceHandle_.
//
//   session  - receives the trace handle and the trace's performance-counter
//              frequency; also passed to the callbacks as their context.
//   name     - logger (session) name when realtime is true, otherwise the
//              path of the .etl file to open.
//   realtime - selects real-time consumption instead of reading a log file.
//
// Returns true on success. On failure a diagnostic is written to stderr and
// false is returned; session->traceHandle_ is then INVALID_PROCESSTRACE_HANDLE.
bool OpenLogger(
    TraceSession* session,
    TCHAR const* name,
    bool realtime)
{
    // Open trace
    EVENT_TRACE_LOGFILE loggerInfo = {};
    /* Filled out below based on realtime:
    loggerInfo.LogFileName = nullptr;
    loggerInfo.LoggerName = nullptr;
    */
    loggerInfo.ProcessTraceMode = PROCESS_TRACE_MODE_EVENT_RECORD | PROCESS_TRACE_MODE_RAW_TIMESTAMP;
    loggerInfo.BufferCallback = BufferCallback;
    loggerInfo.EventRecordCallback = EventRecordCallback;
    loggerInfo.Context = session;
    /* Output members (passed also to BufferCallback()):
    loggerInfo.CurrentTime
    loggerInfo.BuffersRead
    loggerInfo.CurrentEvent
    loggerInfo.LogfileHeader
    loggerInfo.BufferSize
    loggerInfo.Filled
    loggerInfo.IsKernelTrace
    */
    /* Not used:
    loggerInfo.EventsLost
    */

    if (realtime) {
        loggerInfo.LoggerName = const_cast<decltype(loggerInfo.LoggerName)>(name);
        loggerInfo.ProcessTraceMode |= PROCESS_TRACE_MODE_REAL_TIME;
    } else {
        loggerInfo.LogFileName = const_cast<decltype(loggerInfo.LogFileName)>(name);
    }

    session->traceHandle_ = OpenTrace(&loggerInfo);
    if (session->traceHandle_ == INVALID_PROCESSTRACE_HANDLE) {
        // Read the error code before doing any I/O: the fprintf calls below
        // may themselves change the thread's last-error value.
        auto const lastError = GetLastError();

        fprintf(stderr, "error: failed to open trace");
        switch (lastError) {
        case ERROR_INVALID_PARAMETER: fprintf(stderr, " (Logfile is NULL)"); break;
        case ERROR_BAD_PATHNAME:      fprintf(stderr, " (invalid LoggerName)"); break;
        case ERROR_ACCESS_DENIED:     fprintf(stderr, " (access denied)"); break;
        // lastError is a DWORD (unsigned long), hence %lu.
        default:                      fprintf(stderr, " (error=%lu)", lastError); break;
        }
        fprintf(stderr, ".\n");
        return false;
    }

    // Copy desired state from loggerInfo
    session->frequency_ = loggerInfo.LogfileHeader.PerfFreq.QuadPart;
    return true;
}
}
// Hashes a GUID by XOR-ing together its size_t-sized words.
size_t TraceSession::GUIDHash::operator()(GUID const& g) const
{
    static_assert((sizeof(g) % sizeof(size_t)) == 0, "sizeof(GUID) must be multiple of sizeof(size_t)");

    auto const* word = reinterpret_cast<size_t const*>(&g);
    auto const* const lastWord = word + (sizeof(g) / sizeof(size_t));

    size_t hash = 0;
    while (word != lastWord) {
        hash ^= *word;
        ++word;
    }
    return hash;
}
// Equality predicate for GUID keys; converts IsEqualGUID's result to bool.
bool TraceSession::GUIDEqual::operator()(GUID const& lhs, GUID const& rhs) const
{
    return IsEqualGUID(lhs, rhs) ? true : false;
}
// Registers a provider together with the level and keyword masks that are
// later handed to EnableTraceEx2. Returns false, leaving the existing entry
// untouched, when the provider id is already registered.
bool TraceSession::AddProvider(GUID providerId, UCHAR level,
                               ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword)
{
    Provider settings = {};
    settings.matchAny_ = matchAnyKeyword;
    settings.matchAll_ = matchAllKeyword;
    settings.level_ = level;

    // emplace() does not overwrite: .second is false if the key already existed.
    return eventProvider_.emplace(providerId, settings).second;
}
// Registers the callback (and its opaque context pointer) that receives the
// events of the given provider. Returns false, leaving the existing entry
// untouched, when a handler is already registered for that provider id.
bool TraceSession::AddHandler(GUID providerId, EventHandlerFn handlerFn, void* handlerContext)
{
    Handler entry = {};
    entry.fn_ = handlerFn;
    entry.ctxt_ = handlerContext;

    // emplace() does not overwrite: .second is false if the key already existed.
    return eventHandler_.emplace(providerId, entry).second;
}
// Registers a provider and its handler as a unit: if the handler cannot be
// added, the provider registration made by this call is rolled back.
// Returns true only when both registrations succeeded.
bool TraceSession::AddProviderAndHandler(GUID providerId, UCHAR level,
                                         ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword,
                                         EventHandlerFn handlerFn, void* handlerContext)
{
    if (AddProvider(providerId, level, matchAnyKeyword, matchAllKeyword)) {
        if (AddHandler(providerId, handlerFn, handlerContext)) {
            return true;
        }
        // Undo the provider registration made just above.
        RemoveProvider(providerId);
    }
    return false;
}
// Unregisters a provider. When a session handle is set, the provider is first
// disabled on that session; the result of that call is intentionally ignored.
// Returns true if the provider id was registered.
bool TraceSession::RemoveProvider(GUID providerId)
{
    bool const haveSession = (sessionHandle_ != 0);
    if (haveSession) {
        (void) EnableTraceEx2(sessionHandle_, &providerId, EVENT_CONTROL_CODE_DISABLE_PROVIDER, 0, 0, 0, 0, nullptr);
    }

    auto const erasedCount = eventProvider_.erase(providerId);
    return erasedCount > 0;
}
// Unregisters the handler of the given provider id.
// Returns true if a handler was registered for it.
bool TraceSession::RemoveHandler(GUID providerId)
{
    auto const erasedCount = eventHandler_.erase(providerId);
    return erasedCount > 0;
}
// Unregisters both the provider and the handler of the given provider id.
// Returns true if at least one of the two was registered.
//
// Both removals are always performed. Combining the calls directly with `||`
// would short-circuit and leave the handler registered whenever the provider
// removal succeeded.
bool TraceSession::RemoveProviderAndHandler(GUID providerId)
{
    bool const providerRemoved = RemoveProvider(providerId);
    bool const handlerRemoved = RemoveHandler(providerId);
    return providerRemoved || handlerRemoved;
}
// Prepares the session for reading events from an .etl file.
// On success the stop predicate is stored and the lost-event and lost-buffer
// baselines are reset. On failure Finalize() is called and false is returned.
bool TraceSession::InitializeEtlFile(TCHAR const* inputEtlPath, ShouldStopProcessingEventsFn shouldStopFn)
{
    // Open the trace in log-file (non-realtime) mode.
    bool const opened = OpenLogger(this, inputEtlPath, false);
    if (opened) {
        shouldStopProcessingEventsFn_ = shouldStopFn;
        eventsLostCount_ = 0;
        buffersLostCount_ = 0;
    } else {
        Finalize();
    }
    return opened;
}
// Starts a real-time ETW collection session named traceSessionName, enables
// every provider previously registered in eventProvider_, and opens the
// session for consumption.
//
// traceSessionName - name of the session to start (and, if a session of that
//                    name already exists, to stop and restart).
// shouldStopFn     - predicate stored for BufferCallback on success.
//
// Returns true on success. Returns false after printing a diagnostic to
// stderr when the session cannot be started, a provider cannot be enabled, or
// the trace cannot be opened; in the latter two cases Finalize() is called
// first.
bool TraceSession::InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopFn)
{
// Set up and start a real-time collection session
memset(&properties_, 0, sizeof(properties_));
// The size reported to ETW is the offset of sessionHandle_, i.e. everything
// declared before it in TraceSession (the properties block and loggerName_).
properties_.Wnode.BufferSize = (ULONG) offsetof(TraceSession, sessionHandle_);
//properties_.Wnode.Guid // ETW will create Guid
properties_.Wnode.ClientContext = 1; // Clock resolution to use when logging the timestamp for each event
// 1 == query performance counter
properties_.Wnode.Flags = 0;
//properties_.BufferSize = 0;
properties_.MinimumBuffers = 200;
//properties_.MaximumBuffers = 0;
//properties_.MaximumFileSize = 0;
properties_.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
//properties_.FlushTimer = 0;
//properties_.EnableFlags = 0;
properties_.LogFileNameOffset = 0;
// The logger name is stored in the loggerName_ member of this object.
properties_.LoggerNameOffset = offsetof(TraceSession, loggerName_);
auto status = StartTrace(&sessionHandle_, traceSessionName, &properties_);
// A session with this name already exists: stop it and, if that worked, try
// to start again.
if (status == ERROR_ALREADY_EXISTS) {
#ifdef _DEBUG
fprintf(stderr, "warning: trying to start trace session that already exists.\n");
#endif
status = ControlTrace((TRACEHANDLE) 0, traceSessionName, &properties_, EVENT_TRACE_CONTROL_STOP);
if (status == ERROR_SUCCESS) {
status = StartTrace(&sessionHandle_, traceSessionName, &properties_);
}
}
// Covers a failed first start as well as a failed stop or restart above.
if (status != ERROR_SUCCESS) {
fprintf(stderr, "error: failed to start trace session (error=%lu).\n", status);
return false;
}
// Enable desired providers
for (auto const& p : eventProvider_) {
auto pGuid = &p.first;
auto const& h = p.second;
status = EnableTraceEx2(sessionHandle_, pGuid, EVENT_CONTROL_CODE_ENABLE_PROVIDER, h.level_, h.matchAny_, h.matchAll_, 0, nullptr);
if (status != ERROR_SUCCESS) {
fprintf(stderr, "error: failed to enable provider {%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x}.\n",
pGuid->Data1, pGuid->Data2, pGuid->Data3, pGuid->Data4[0], pGuid->Data4[1], pGuid->Data4[2],
pGuid->Data4[3], pGuid->Data4[4], pGuid->Data4[5], pGuid->Data4[6], pGuid->Data4[7]);
// The session was started above, so it has to be torn down again.
Finalize();
return false;
}
}
// Open the trace
if (!OpenLogger(this, traceSessionName, true)) {
Finalize();
return false;
}
// Initialize state
shouldStopProcessingEventsFn_ = shouldStopFn;
eventsLostCount_ = 0;
buffersLostCount_ = 0;
return true;
}
// Tears the session down: closes the consumer trace handle if one is open
// and, if a collection session was started, stops it and clears the
// registered providers and handlers. Both handles are reset to their "unset"
// values, so a repeated call skips both steps.
void TraceSession::Finalize()
{
// status is assigned below but never inspected; failures are ignored here.
ULONG status = ERROR_SUCCESS;
if (traceHandle_ != INVALID_PROCESSTRACE_HANDLE) {
status = CloseTrace(traceHandle_);
traceHandle_ = INVALID_PROCESSTRACE_HANDLE;
}
// NOTE: the provider and handler maps are only cleared inside this branch,
// i.e. when a collection session handle is set.
if (sessionHandle_ != 0) {
status = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_STOP);
// RemoveProvider()/RemoveHandler() erase the entry they are given, so each
// loop drains its map one element at a time.
while (!eventProvider_.empty()) {
RemoveProvider(eventProvider_.begin()->first);
}
while (!eventHandler_.empty()) {
RemoveHandler(eventHandler_.begin()->first);
}
sessionHandle_ = 0;
}
}
// Reports how many events and real-time buffers were lost since the previous
// successful call.
//
// eventsLost / buffersLost - receive the deltas against the totals seen on
//                            the last successful query; both are set to 0
//                            when nothing can be reported.
//
// Returns true when at least one event or buffer was lost. Returns false when
// nothing was lost, when no session handle is set, or when the query fails.
bool TraceSession::CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost)
{
    // Shared "nothing to report" exit: zero both outputs and yield false.
    auto const reportNothing = [eventsLost, buffersLost]() {
        *eventsLost = 0;
        *buffersLost = 0;
        return false;
    };

    if (sessionHandle_ == 0) {
        return reportNothing();
    }

    auto const queryStatus = ControlTraceW(sessionHandle_, nullptr, &properties_, EVENT_TRACE_CONTROL_QUERY);
    switch (queryStatus) {
    case ERROR_SUCCESS:
        break;
    case ERROR_MORE_DATA:
        // The buffer &properties_ is too small to hold all the information
        // for the session. If you don't need the session's property
        // information you can ignore this error.
        return reportNothing();
    default:
        fprintf(stderr, "error: failed to query trace status (%lu).\n", queryStatus);
        return reportNothing();
    }

    // Hand out the change since the last query, then move the baselines forward.
    uint32_t const totalEventsLost = properties_.EventsLost;
    uint32_t const totalBuffersLost = properties_.RealTimeBuffersLost;
    *eventsLost = totalEventsLost - eventsLostCount_;
    *buffersLost = totalBuffersLost - buffersLostCount_;
    eventsLostCount_ = totalEventsLost;
    buffersLostCount_ = totalBuffersLost;
    return *eventsLost + *buffersLost > 0;
}

Просмотреть файл

@ -0,0 +1,99 @@
//*********************************************************
//
// Copyright (c) Microsoft. All rights reserved.
// This code is licensed under the MIT License (MIT).
// THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
// IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
// PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
//
//*********************************************************
#pragma once
#include <windows.h>
#include <tchar.h>
#include <evntcons.h> // must be after windows.h
#include <stdint.h>
#include <unordered_map>
typedef void (*EventHandlerFn)(EVENT_RECORD* pEventRecord, void* pContext);
typedef bool (*ShouldStopProcessingEventsFn)();
// Wraps an ETW trace: the controller session (sessionHandle_), the consumer
// handle passed to ::ProcessTrace() (traceHandle_), and the provider/handler
// registrations keyed by GUID.
struct TraceSession {
    // BEGIN trace property block, must be beginning of TraceSession
    EVENT_TRACE_PROPERTIES properties_;
    wchar_t loggerName_[MAX_PATH];
    // END Trace property block

    TRACEHANDLE sessionHandle_;  // Must be first member after trace property block
    TRACEHANDLE traceHandle_;
    ShouldStopProcessingEventsFn shouldStopProcessingEventsFn_;
    uint64_t startTime_;
    uint64_t frequency_;
    uint32_t eventsLostCount_;   // cumulative EventsLost seen by the last CheckLostReports()
    uint32_t buffersLostCount_;  // cumulative RealTimeBuffersLost seen by the last CheckLostReports()

    // Structure to hold the mapping from provider ID to event handler function
    struct GUIDHash { size_t operator()(GUID const& g) const; };
    struct GUIDEqual { bool operator()(GUID const& lhs, GUID const& rhs) const; };
    struct Provider {
        ULONGLONG matchAny_;
        ULONGLONG matchAll_;
        UCHAR level_;
    };
    struct Handler {
        EventHandlerFn fn_;
        void* ctxt_;
    };
    std::unordered_map<GUID, Provider, GUIDHash, GUIDEqual> eventProvider_;
    std::unordered_map<GUID, Handler, GUIDHash, GUIDEqual> eventHandler_;

    // Members are initialized in declaration order (which is the order the
    // compiler uses regardless of how the list is written). The lost-event
    // counters start at zero so they never hold indeterminate values.
    TraceSession()
        : sessionHandle_(0)
        , traceHandle_(INVALID_PROCESSTRACE_HANDLE)
        , shouldStopProcessingEventsFn_(nullptr)
        , startTime_(0)
        , frequency_(0)
        , eventsLostCount_(0)
        , buffersLostCount_(0)
    {
    }

    // Usage:
    //
    // 1) use TraceSession::AddProvider() to add the IDs for all the providers
    // you want to trace. Use TraceSession::AddHandler() to add the handler
    // functions for the providers/events you want to trace.
    //
    // 2) call TraceSession::InitializeRealtime() or
    // TraceSession::InitializeEtlFile(), to start tracing events from
    // real-time collection or from a previously-captured .etl file. At this
    // point, events start to be traced.
    //
    // 3) call ::ProcessTrace() to start collecting the events; provider
    // handler functions will be called as those provider events are collected.
    // ProcessTrace() will exit when shouldStopProcessingEventsFn_ returns
    // true, or when the .etl file is fully consumed.
    //
    // 4) Finalize() to clean up.

    // AddProvider/Handler() returns false if the providerId already has a handler.
    // RemoveProvider/Handler() returns false if the providerId doesn't have a handler.
    bool AddProvider(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword);
    bool AddHandler(GUID handlerId, EventHandlerFn handlerFn, void* handlerContext);
    bool AddProviderAndHandler(GUID providerId, UCHAR level, ULONGLONG matchAnyKeyword, ULONGLONG matchAllKeyword,
                               EventHandlerFn handlerFn, void* handlerContext);
    bool RemoveProvider(GUID providerId);
    bool RemoveHandler(GUID handlerId);
    bool RemoveProviderAndHandler(GUID providerId);

    // InitializeRealtime() and InitializeEtlFile() return false if the session
    // could not be created.
    bool InitializeEtlFile(TCHAR const* etlPath, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn);
    bool InitializeRealtime(TCHAR const* traceSessionName, ShouldStopProcessingEventsFn shouldStopProcessingEventsFn);
    void Finalize();

    // Call CheckLostReports() at any time the session is initialized to query
    // how many events and buffers have been lost while tracing.
    bool CheckLostReports(uint32_t* eventsLost, uint32_t* buffersLost);
};

Просмотреть файл

@ -0,0 +1,145 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <iostream>
#include "eparser.h"
#include "TraceSession.h"
#ifdef _WIN32
#include <tchar.h>
#else
#define TCHAR char
#define _tmain main
#endif
#ifdef _WIN32
#include "getopt.h"
#else
#include <getopt.h>
#include <thread>
#endif
// GUID under which OrtEventHandler is registered in fetch_data().
// NOTE(review): this value (54d81939-62a0-4dc0-bf32-035ebdc7bce9) differs from the
// 3a26b1ff-7484-7484-7484-15261f42614d GUID used by the other trace tool and by
// ort.wprp in this change - confirm which provider this tool is meant to read.
static const GUID OrtProviderGuid = {0x54d81939, 0x62a0, 0x4dc0, {0xbf, 0x32, 0x3, 0x5e, 0xbd, 0xc7, 0xbc, 0xe9}};
// Replays the .etl file `filename` and feeds every event of the ORT provider
// to OrtEventHandler, which accumulates its results in `context`.
//
// Returns 0 on success and -1 when the handler cannot be registered, the file
// cannot be opened, or ProcessTrace() fails. ERROR_CANCELLED from ProcessTrace()
// is treated as success.
int fetch_data(TCHAR* filename, ProfilingInfo& context) {
  TraceSession session;
  if (!session.AddHandler(OrtProviderGuid, OrtEventHandler, &context)) {
    std::cout << "Failed to register the event handler" << std::endl;
    return -1;
  }
  // Previously the result was ignored and the failure only surfaced indirectly
  // through ProcessTrace() being handed an invalid handle.
  if (!session.InitializeEtlFile(filename, nullptr)) {
    std::cout << "Failed to open the trace file" << std::endl;
    session.Finalize();
    return -1;
  }

  const ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0);
  session.Finalize();
  if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) {
    std::cout << "ProcessTrace failed with " << status << std::endl;
    return -1;
  }
  return 0;
}
// Computes the arithmetic mean and the unbiased sample variance (the square of
// the sample standard deviation, divisor N - 1) of `input[0..input_len)`.
//
// All accumulation happens in double, using two passes (mean first, then the
// squared deviations from it). Accumulating `t * t` in T overflows for integer
// element types, and the one-pass "sum of squares minus N * mean^2" form loses
// precision when the mean is large relative to the spread.
//
// Returns {mean, variance}. With fewer than two samples the variance is NaN
// (0/0); with no samples the mean is NaN as well.
template <typename T>
std::pair<double, double> CalcMeanAndStdSquare(const T* input, size_t input_len) {
  const size_t N = input_len;

  double sum = 0;
  for (size_t i = 0; i != N; ++i) {
    sum += static_cast<double>(input[i]);
  }
  const double mean = sum / static_cast<double>(N);

  double squared_deviation = 0;
  for (size_t i = 0; i != N; ++i) {
    const double d = static_cast<double>(input[i]) - mean;
    squared_deviation += d * d;
  }

  // N - 1 degrees of freedom; 0 for N <= 1 so that size_t never wraps around.
  const double degrees_of_freedom = (N > 1) ? static_cast<double>(N - 1) : 0.0;
  const double variance = squared_deviation / degrees_of_freedom;
  return std::make_pair(mean, variance);
}
// Cumulative distribution function of Student's t distribution with `v`
// degrees of freedom, evaluated at `x`.
// see: "Statistical Distributions", 4th Edition, by Catherine Forbes, Merran Evans, Nicholas Hastings and Brian
// Peacock. Chapter 42: "Student's t Distribution". Only the case of even v is implemented;
// the assert below enforces v >= 2 and v even (in builds where assert is active).
double TDistributionCDF(int v, double x) {
  assert(v >= 2 && (v & 1) == 0);

  const double scale = x / (2 * std::sqrt(v + x * x));
  const double base = 1 + x * x / v;
  const int last_term = (v - 2) / 2;

  double series = 0;
  double coeff = 1;  // coefficient of term j; starts at 1 for j == 0
  int j = 0;
  while (j <= last_term) {
    series += coeff / std::pow(base, j);
    // Advance to the coefficient of term j + 1.
    coeff *= static_cast<double>(2 * j + 1) / (2 * j + 2);
    ++j;
  }
  return 0.5 + scale * series;
}
// Result of CalcTValue(): per-sample statistics plus the t statistic.
struct TTestResult {
  double mean1;   // mean of the first sample
  double mean2;   // mean of the second sample
  double std1;    // standard deviation of the first sample
  double std2;    // standard deviation of the second sample
  double tvalue;  // test statistic
};
// Two-sample t statistic using a pooled variance.
//
// For each sample, the mean and variance come from CalcMeanAndStdSquare(). The
// variances are pooled with weights (n - 1), scaled by (1/n1 + 1/n2), and the
// difference of the means is divided by the square root of that quantity.
// Returns the means, standard deviations and the t value.
template <typename T>
TTestResult CalcTValue(const T* input1, size_t input1_len, const T* input2, size_t input2_len) {
  const auto stats1 = CalcMeanAndStdSquare(input1, input1_len);
  const auto stats2 = CalcMeanAndStdSquare(input2, input2_len);

  const size_t n1 = input1_len;
  const size_t n2 = input2_len;

  // Pooled variance, then scaled to the variance of the difference of means.
  auto pooled = ((n1 - 1) * stats1.second + (n2 - 1) * stats2.second) / (n1 + n2 - 2);
  pooled *= ((double)1) / n1 + ((double)1) / n2;

  TTestResult out;
  out.mean1 = stats1.first;
  out.mean2 = stats2.first;
  out.std1 = std::sqrt(stats1.second);
  out.std2 = std::sqrt(stats2.second);
  out.tvalue = (stats1.first - stats2.first) / std::sqrt(pooled);
  return out;
}
int real_main(int argc, TCHAR* argv[]) {
if (argc < 3) {
printf("error\n");
return -1;
}
ProfilingInfo context1;
int ret = fetch_data(argv[1], context1);
if (ret != 0) return ret;
ProfilingInfo context2;
ret = fetch_data(argv[2], context2);
if (ret != 0) return ret;
size_t n1 = context1.time_per_run.size();
size_t n2 = context2.time_per_run.size();
if (n1 <= 10 || n2 <= 10) {
printf("samples are too few, please try to gather more\n");
return -1;
}
// ignore the first run
--n1;
--n2;
if (((n1 + n2) & 1) != 0) {
if (n1 > n2)
n1--;
else
n2--;
}
TTestResult tresult = CalcTValue(context1.time_per_run.data() + 1, n1, context2.time_per_run.data() + 1, n2);
size_t freedom = n1 + n2 - 2;
double p = TDistributionCDF(static_cast<int>(freedom), std::abs(tresult.tvalue));
std::cout << "Mean1: " << tresult.mean1 << " std1: " << tresult.std1 << "\n"
<< "Mean2: " << tresult.mean2 << " std2: " << tresult.std2 << "\n"
<< "H0: Mean1 = Mean2\n"
<< "H1: Mean1 != Mean2\n"
<< "Test statistic: T = " << tresult.tvalue << "\n"
<< "Degrees of Freedom: v = " << freedom << "\n"
<< "Significance level:" << (1 - p) * 2 << ". The lower the more likely to reject H0\n";
if (p > 0.99995) {
std::cout << "The two population means are different at the 0.0001 significance level." << std::endl;
return -1;
} else {
std::cout << "They don't have significant statistical difference." << std::endl;
return 0;
}
}
// Process entry point: runs real_main() and turns any std::exception that
// escapes from it into a message on stderr and a -1 exit code.
int _tmain(int argc, TCHAR* argv[]) {
  try {
    return real_main(argc, argv);
  } catch (std::exception& ex) {
    fprintf(stderr, "%s\n", ex.what());
  }
  return -1;
}

Просмотреть файл

@ -0,0 +1,355 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "eparser.h"
// Get the metadata for the event.
// Get the length of the property data. For MOF-based events, the size is inferred from the data type
// of the property. For manifest-based events, the property can specify the size of the property value
// using the length attribute. The length attribute can specify the size directly or specify the name
// of another property in the event data that contains the size. If the property does not include the
// length attribute, the size is inferred from the data type. The length will be zero for variable
// length, null-terminated strings and structures.
DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength);
// Get the size of the array. For MOF-based events, the size is specified in the declaration or using
// the MAX qualifier. For manifest-based events, the property can specify the size of the array
// using the count attribute. The count attribute can specify the size directly or specify the name
// of another property in the event data that contains the size.
DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize);
// Both MOF-based events and manifest-based events can specify name/value maps. The
// map values can be integer values or bit values. If the property specifies a value
// map, get the map.
DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo);
// Print the property.
//
// Formats property `i` of the event and hands each formatted value to `t`,
// which is called as t(propertyName, formattedValue). A property that is a
// structure is expanded by recursing over its members; an array property is
// processed once per element.
//
// pUserData / pEndOfUserData delimit the not-yet-consumed part of the event
// payload. Returns the advanced payload pointer, or NULL on failure (a message
// is printed with wprintf in that case).
template <typename T>
PBYTE PrintProperties(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, DWORD PointerSize, USHORT i, PBYTE pUserData,
                      PBYTE pEndOfUserData, const T& t) {
  TDHSTATUS status = ERROR_SUCCESS;
  USHORT PropertyLength = 0;
  DWORD FormattedDataSize = 0;
  USHORT UserDataConsumed = 0;
  LPWSTR pFormattedData = NULL;
  DWORD LastMember = 0;  // One past the index of the last member of a structure
  USHORT ArraySize = 0;
  PEVENT_MAP_INFO pMapInfo = NULL;

  // Get the length of the property.
  status = GetPropertyLength(pEvent, pInfo, i, &PropertyLength);
  if (ERROR_SUCCESS != status) {
    wprintf(L"GetPropertyLength failed.\n");
    pUserData = NULL;
    goto cleanup;
  }

  // Get the size of the array if the property is an array.
  // NOTE(review): the returned status is not checked; on failure the loop runs
  // with whatever GetArraySize stored in ArraySize - confirm whether a failure
  // here should be fatal.
  status = GetArraySize(pEvent, pInfo, i, &ArraySize);

  for (USHORT k = 0; k < ArraySize; k++) {
    // If the property is a structure, print the members of the structure.
    if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) {
      LastMember = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex +
                   pInfo->EventPropertyInfoArray[i].structType.NumOfStructMembers;
      for (USHORT j = pInfo->EventPropertyInfoArray[i].structType.StructStartIndex; j < LastMember; j++) {
        pUserData = PrintProperties(pEvent, pInfo, PointerSize, j, pUserData, pEndOfUserData, t);
        if (NULL == pUserData) {
          wprintf(L"Printing the members of the structure failed.\n");
          pUserData = NULL;
          goto cleanup;
        }
      }
    } else {
      // GetMapInfo() overwrites pMapInfo with a fresh allocation, so release
      // the map obtained for the previous array element first; otherwise every
      // element after the first leaks one buffer.
      if (pMapInfo) {
        free(pMapInfo);
        pMapInfo = NULL;
      }

      // Get the name/value mapping if the property specifies a value map.
      status =
          GetMapInfo(pEvent, (PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].nonStructType.MapNameOffset),
                     pInfo->DecodingSource, pMapInfo);
      if (ERROR_SUCCESS != status) {
        wprintf(L"GetMapInfo failed\n");
        pUserData = NULL;
        goto cleanup;
      }

      // Get the size of the buffer required for the formatted data.
      status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType,
                                 pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength,
                                 (USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData,
                                 &UserDataConsumed);
      if (ERROR_INSUFFICIENT_BUFFER == status) {
        if (pFormattedData) {
          free(pFormattedData);
          pFormattedData = NULL;
        }
        pFormattedData = (LPWSTR)malloc(FormattedDataSize);
        if (pFormattedData == NULL) {
          wprintf(L"Failed to allocate memory for formatted data (size=%lu).\n", FormattedDataSize);
          status = ERROR_OUTOFMEMORY;
          pUserData = NULL;
          goto cleanup;
        }
        // Retrieve the formatted data.
        status = TdhFormatProperty(pInfo, pMapInfo, PointerSize, pInfo->EventPropertyInfoArray[i].nonStructType.InType,
                                   pInfo->EventPropertyInfoArray[i].nonStructType.OutType, PropertyLength,
                                   (USHORT)(pEndOfUserData - pUserData), pUserData, &FormattedDataSize, pFormattedData,
                                   &UserDataConsumed);
      }
      if (ERROR_SUCCESS == status) {
        t((PWCHAR)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[i].NameOffset), pFormattedData);
        pUserData += UserDataConsumed;
      } else {
        wprintf(L"TdhFormatProperty failed with %lu.\n", status);
        pUserData = NULL;
        goto cleanup;
      }
    }
  }

cleanup:
  if (pFormattedData) {
    free(pFormattedData);
    pFormattedData = NULL;
  }
  if (pMapInfo) {
    free(pMapInfo);
    pMapInfo = NULL;
  }
  return pUserData;
}
// Determines the length of property `i` and stores it in *PropertyLength.
//
// When the property carries the PropertyParamLength flag, the length is read
// from the event property identified by lengthPropertyIndex. Otherwise it is
// the static `length` field, with a special case for IPv6 binary blobs.
//
// Returns ERROR_SUCCESS on success. On failure a Win32 error code is returned
// and *PropertyLength is left untouched.
DWORD GetPropertyLength(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT PropertyLength) {
  // If the property is a binary blob and is defined in a manifest, the property can
  // specify the blob's size or it can point to another property that defines the
  // blob's size. The PropertyParamLength flag tells you where the blob's size is defined.
  if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamLength) == PropertyParamLength) {
    DWORD Length = 0;  // Expects the length to be defined by a UINT16 or UINT32
    DWORD PropertySize = 0;
    DWORD j = pInfo->EventPropertyInfoArray[i].lengthPropertyIndex;
    PROPERTY_DATA_DESCRIPTOR DataDescriptor;
    ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR));
    DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset);
    DataDescriptor.ArrayIndex = ULONG_MAX;

    DWORD status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize);
    if (ERROR_SUCCESS != status) {
      return status;
    }
    // TdhGetProperty() copies PropertySize bytes into the buffer; refuse
    // anything that would not fit in Length instead of overrunning it.
    if (PropertySize > sizeof(Length)) {
      return ERROR_EVT_INVALID_EVENT_DATA;
    }
    status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Length);
    if (ERROR_SUCCESS != status) {
      return status;
    }
    *PropertyLength = (USHORT)Length;
    return ERROR_SUCCESS;
  }

  if (pInfo->EventPropertyInfoArray[i].length > 0) {
    *PropertyLength = pInfo->EventPropertyInfoArray[i].length;
    return ERROR_SUCCESS;
  }

  // If the property is a binary blob and is defined in a MOF class, the extension
  // qualifier is used to determine the size of the blob. However, if the extension
  // is IPAddrV6, you must set the PropertyLength variable yourself because the
  // EVENT_PROPERTY_INFO.length field will be zero.
  if (TDH_INTYPE_BINARY == pInfo->EventPropertyInfoArray[i].nonStructType.InType &&
      TDH_OUTTYPE_IPV6 == pInfo->EventPropertyInfoArray[i].nonStructType.OutType) {
    *PropertyLength = (USHORT)sizeof(IN6_ADDR);
    return ERROR_SUCCESS;
  }

  // For strings and structures the `length` field (zero at this point) is
  // passed through unchanged.
  if (TDH_INTYPE_UNICODESTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType ||
      TDH_INTYPE_ANSISTRING == pInfo->EventPropertyInfoArray[i].nonStructType.InType ||
      (pInfo->EventPropertyInfoArray[i].Flags & PropertyStruct) == PropertyStruct) {
    *PropertyLength = pInfo->EventPropertyInfoArray[i].length;
    return ERROR_SUCCESS;
  }

  wprintf(L"Unexpected length of 0 for intype %d and outtype %d\n",
          pInfo->EventPropertyInfoArray[i].nonStructType.InType,
          pInfo->EventPropertyInfoArray[i].nonStructType.OutType);
  return ERROR_EVT_INVALID_EVENT_DATA;
}
// Determines how many elements property `i` has and stores the count in
// *ArraySize.
//
// When the property carries the PropertyParamCount flag, the count is read
// from the event property identified by countPropertyIndex; otherwise the
// static `count` field is used.
//
// Returns ERROR_SUCCESS on success. On failure a Win32 error code is returned
// and *ArraySize is set to 0, so a caller that ignores the status iterates
// over nothing.
DWORD GetArraySize(PEVENT_RECORD pEvent, PTRACE_EVENT_INFO pInfo, USHORT i, PUSHORT ArraySize) {
  if ((pInfo->EventPropertyInfoArray[i].Flags & PropertyParamCount) != PropertyParamCount) {
    *ArraySize = pInfo->EventPropertyInfoArray[i].count;
    return ERROR_SUCCESS;
  }

  DWORD Count = 0;  // Expects the count to be defined by a UINT16 or UINT32
  DWORD PropertySize = 0;
  DWORD j = pInfo->EventPropertyInfoArray[i].countPropertyIndex;
  PROPERTY_DATA_DESCRIPTOR DataDescriptor;
  ZeroMemory(&DataDescriptor, sizeof(PROPERTY_DATA_DESCRIPTOR));
  DataDescriptor.PropertyName = (ULONGLONG)((PBYTE)(pInfo) + pInfo->EventPropertyInfoArray[j].NameOffset);
  DataDescriptor.ArrayIndex = ULONG_MAX;

  *ArraySize = 0;
  DWORD status = TdhGetPropertySize(pEvent, 0, NULL, 1, &DataDescriptor, &PropertySize);
  if (ERROR_SUCCESS != status) {
    return status;
  }
  // TdhGetProperty() copies PropertySize bytes into the buffer; refuse
  // anything that would not fit in Count instead of overrunning it.
  if (PropertySize > sizeof(Count)) {
    return ERROR_EVT_INVALID_EVENT_DATA;
  }
  status = TdhGetProperty(pEvent, 0, NULL, 1, &DataDescriptor, PropertySize, (PBYTE)&Count);
  if (ERROR_SUCCESS != status) {
    return status;
  }
  *ArraySize = (USHORT)Count;
  return ERROR_SUCCESS;
}
// Fetches the name/value map called `pMapName` for the event.
//
// pMapInfo is an in/out reference: when a map exists it is replaced by a
// malloc()ed buffer (any previous value is overwritten, not freed) that the
// caller must free(). A missing map (ERROR_NOT_FOUND) is reported as success.
// Aborts the process if a map is returned and DecodingSource is
// DecodingSourceXMLFile.
// Returns ERROR_SUCCESS, ERROR_OUTOFMEMORY, or the TDH error code.
DWORD GetMapInfo(PEVENT_RECORD pEvent, LPWSTR pMapName, DWORD DecodingSource, PEVENT_MAP_INFO& pMapInfo) {
  DWORD bytesNeeded = 0;

  // Retrieve the required buffer size for the map info.
  DWORD rc = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &bytesNeeded);
  if (rc == ERROR_INSUFFICIENT_BUFFER) {
    pMapInfo = (PEVENT_MAP_INFO)malloc(bytesNeeded);
    if (pMapInfo == NULL) {
      wprintf(L"Failed to allocate memory for map info (size=%lu).\n", bytesNeeded);
      return ERROR_OUTOFMEMORY;
    }
    // Retrieve the map info.
    rc = TdhGetEventMapInformation(pEvent, pMapName, pMapInfo, &bytesNeeded);
  }

  if (rc == ERROR_SUCCESS) {
    if (DecodingSource == DecodingSourceXMLFile) {
      abort();
    }
    return ERROR_SUCCESS;
  }

  if (rc == ERROR_NOT_FOUND) {
    return ERROR_SUCCESS;  // This case is okay.
  }

  wprintf(L"TdhGetEventMapInformation failed with 0x%x.\n", rc);
  return rc;
}
// Builds a LoggingEventRecord for `pEvent` and loads the event's metadata
// (TRACE_EVENT_INFO) into its internal buffer.
//
// `status` receives the result of the last TdhGetEventInformation() call. If
// the initial size query returns anything other than ERROR_INSUFFICIENT_BUFFER,
// that value is reported and the buffer stays empty.
LoggingEventRecord LoggingEventRecord::CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status) {
  LoggingEventRecord record;
  record.event_record_ = pEvent;

  // Retrieve the required buffer size for the event metadata.
  DWORD bytesNeeded = 0;
  status = TdhGetEventInformation(pEvent, 0, NULL, nullptr, &bytesNeeded);

  if (status == ERROR_INSUFFICIENT_BUFFER) {
    // Retrieve the event metadata.
    record.buffer_.resize(bytesNeeded);
    status = TdhGetEventInformation(pEvent, 0, NULL, record.GetEventInfo(), &bytesNeeded);
  }
  return record;
}
// ETW callback for the ORT provider. `pContext` must point to the
// ProfilingInfo that accumulates the results.
//
// Handled events, by task name:
//  - "OpEnd": adds the event's "time" value to the per-operator statistics
//    keyed by its "op_name" value.
//  - "OrtRun": the start opcode records the start timestamp; the end opcode
//    appends the elapsed timestamp difference to time_per_run.
//  - "OrtInferenceSessionActivity": start/end opcodes set session_started /
//    session_ended. "OpEnd" and "OrtRun" events are ignored unless the session
//    has started and not yet ended.
//  - "NodeNameMapping": ignored.
// Any other event, property or opcode aborts the process. Throws
// std::runtime_error if an OrtRun end timestamp precedes the recorded start.
void OrtEventHandler(EVENT_RECORD* pEvent, void* pContext) {
  ProfilingInfo& info = *(ProfilingInfo*)pContext;
  DWORD status = ERROR_SUCCESS;
  LoggingEventRecord record = LoggingEventRecord::CreateLoggingEventRecord(pEvent, status);
  if (ERROR_SUCCESS != status) {
    if (status == ERROR_NOT_FOUND) return;
    wprintf(L"GetEventInformation failed with %lu\n", status);
    abort();
  }

  DWORD PointerSize = 0;
  if (EVENT_HEADER_FLAG_32_BIT_HEADER == (pEvent->EventHeader.Flags & EVENT_HEADER_FLAG_32_BIT_HEADER)) {
    PointerSize = 4;
  } else {
    PointerSize = 8;
  }

  PTRACE_EVENT_INFO pInfo = record.GetEventInfo();
  const wchar_t* name = record.GetTaskName();

  if (wcscmp(name, L"OpEnd") == 0) {
    if (!info.session_started || info.session_ended) return;
    PBYTE pUserData = (PBYTE)pEvent->UserData;
    PBYTE pEndOfUserData = (PBYTE)pEvent->UserData + pEvent->UserDataLength;
    // Walk all the top-level properties. Metadata for all the
    // top-level properties come before structure member properties in the
    // property information array.
    std::wstring opname;
    long time_spent_in_this_op = 0;
    for (USHORT i = 0; i < pInfo->TopLevelPropertyCount; i++) {
      pUserData = PrintProperties(pEvent, pInfo, PointerSize, i, pUserData, pEndOfUserData,
                                  [&opname, &time_spent_in_this_op](const wchar_t* key, wchar_t* value) {
                                    if (wcscmp(key, L"op_name") == 0) {
                                      opname = value;
                                    } else if (wcscmp(key, L"time") == 0) {
                                      time_spent_in_this_op = wcstol(value, nullptr, 10);
                                    } else {
                                      // The property name comes from the trace, so it must
                                      // never be used as the format string itself.
                                      wprintf(L"%ls", key);
                                      abort();
                                    }
                                  });
      if (NULL == pUserData) {
        wprintf(L"Printing top level properties failed.\n");
        abort();
      }
    }
    auto iter = info.op_stat.find(opname);
    if (iter == info.op_stat.end()) {
      OpStat s;
      s.name = opname;
      s.count = 1;
      s.total_time = time_spent_in_this_op;
      info.op_stat[opname] = s;
    } else {
      OpStat& s = iter->second;
      ++s.count;
      s.total_time += time_spent_in_this_op;
    }
  } else if (wcscmp(name, L"OrtRun") == 0) {
    if (!info.session_started || info.session_ended) return;
    if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) {
      info.op_start_time = pEvent->EventHeader.TimeStamp;
      ++info.ortrun_count;
    } else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) {
      if (pEvent->EventHeader.TimeStamp.QuadPart < info.op_start_time.QuadPart) {
        throw std::runtime_error("time error");
      }
      info.time_per_run.push_back(pEvent->EventHeader.TimeStamp.QuadPart - info.op_start_time.QuadPart);
      ++info.ortrun_end_count;
    } else {
      abort();
    }
  } else if (wcscmp(name, L"OrtInferenceSessionActivity") == 0) {
    if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_START) {
      info.session_started = true;
    } else if (pInfo->EventDescriptor.Opcode == EVENT_TRACE_TYPE_END) {
      info.session_ended = true;
    } else {
      abort();
    }
    printf("OrtInferenceSessionActivity\n");
  } else if (wcscmp(name, L"NodeNameMapping") == 0) {
    // ignore
  } else {
    wprintf(L"unknown event:%s\n", name);
    abort();
  }
}

Просмотреть файл

@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <windows.h>
#include <stdio.h>
#include <wbemidl.h>
#include <wmistr.h>
#include <evntrace.h>
#include <assert.h>
#include <tdh.h>
#include <stdexcept>
#include <vector>
#include <sstream>
#include <in6addr.h>
#include <unordered_map>
#include <algorithm>
#include <numeric>
#include <iostream>
#include <iomanip>
void OrtEventHandler(EVENT_RECORD* pEventRecord, void* pContext);
// Pairs an ETW event record with the TRACE_EVENT_INFO metadata describing it.
// Instances are produced by CreateLoggingEventRecord(), which fills buffer_.
class LoggingEventRecord {
 private:
  std::vector<char> buffer_;  // raw bytes of the event's TRACE_EVENT_INFO
  // Not owned. Initialized so a default-constructed record never carries an
  // indeterminate pointer.
  EVENT_RECORD* event_record_ = nullptr;

 public:
  // The metadata block stored in buffer_ (points at buffer_.data()).
  const TRACE_EVENT_INFO* GetEventInfo() const { return (const TRACE_EVENT_INFO*)buffer_.data(); }
  TRACE_EVENT_INFO* GetEventInfo() { return (TRACE_EVENT_INFO*)buffer_.data(); }

  // Wide string located TaskNameOffset bytes into the metadata block.
  const wchar_t* GetTaskName() const {
    const TRACE_EVENT_INFO* p = GetEventInfo();
    return (const wchar_t*)(buffer_.data() + p->TaskNameOffset);
  }

  // Creates a record for pEvent; `status` receives the TDH result code.
  static LoggingEventRecord CreateLoggingEventRecord(EVENT_RECORD* pEvent, DWORD& status);
};
// Accumulated statistics for one operator name.
struct OpStat {
  std::wstring name;         // operator name
  size_t count{0};           // number of occurrences recorded
  uint64_t total_time{0};    // sum of the recorded time values
};
// State accumulated by OrtEventHandler while a trace is being processed.
struct ProfilingInfo {
  int ortrun_count = 0;      // OrtRun start events seen
  int ortrun_end_count = 0;  // OrtRun end events seen
  int session_count = 0;
  bool session_started = false;  // set by the OrtInferenceSessionActivity start event
  bool session_ended = false;    // set by the OrtInferenceSessionActivity end event
  // Timestamp of the most recent OrtRun start event. Zero-initialized so that
  // an end event arriving before any start event compares against a defined
  // value instead of an indeterminate one.
  LARGE_INTEGER op_start_time{};
  std::unordered_map<std::wstring, OpStat> op_stat;  // per-operator statistics keyed by name
  std::vector<ULONG64> time_per_run;                 // end minus start timestamp of each OrtRun
};

Просмотреть файл

@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "eparser.h"
#include "TraceSession.h"
// Turns the DEFINE_GUID for EventTraceGuid into a const.
// NOTE(review): INITGUID is defined after the #include lines above, so it cannot
// influence any DEFINE_GUID expansion in those headers - confirm whether it is
// still needed here.
#define INITGUID
// GUID under which OrtEventHandler is registered in real_main().
static const GUID OrtProviderGuid = {0x3a26b1ff, 0x7484, 0x7484, {0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d}};
int real_main(int argc, TCHAR* argv[]) {
ProfilingInfo context;
TraceSession session;
session.AddHandler(OrtProviderGuid, OrtEventHandler, &context);
session.InitializeEtlFile(argv[1], nullptr);
ULONG status = ProcessTrace(&session.traceHandle_, 1, 0, 0);
if (status != ERROR_SUCCESS && status != ERROR_CANCELLED) {
std::cout << "OpenTrace failed with " << status << std::endl;
session.Finalize();
return -1;
}
session.Finalize();
assert(context.ortrun_count == context.ortrun_end_count);
std::vector<OpStat*> stat_array(context.op_stat.size());
size_t i = 0;
for (auto& p : context.op_stat) {
stat_array[i++] = &p.second;
}
std::sort(stat_array.begin(), stat_array.end(),
[](const OpStat* left, const OpStat* right) { return left->total_time > right->total_time; });
size_t iterations = context.time_per_run.size();
ULONG64 total_time = std::accumulate(context.time_per_run.begin() + 1, context.time_per_run.end(), (ULONG64)0);
// in microseconds
ULONG64 avg_time = total_time / (context.time_per_run.size() - 1) / 10;
double sum = 0;
for (OpStat* p : stat_array) {
if (p->name == L"Scan") {
continue;
}
uint64_t avg_time_per_op = p->total_time / iterations;
if (avg_time_per_op >= 0) {
double t = avg_time_per_op * 100.0 / avg_time;
std::wcout << p->name << L" " << p->total_time / p->count << L" " << std::fixed << std::setprecision(1) << t
<< L"%\n";
}
sum += p->total_time / (double)iterations;
}
std::wcout << L"total " << std::fixed << std::setprecision(1) << (sum * 100.0) / avg_time << L"%\n";
return 0;
}
// Process entry point: runs real_main() and turns any std::exception that
// escapes from it into a message on stderr and a -1 exit code.
int _tmain(int argc, TCHAR* argv[]) {
  try {
    return real_main(argc, argv);
  } catch (std::exception& ex) {
    fprintf(stderr, "%s\n", ex.what());
  }
  return -1;
}

51
ort.wprp Normal file
Просмотреть файл

@ -0,0 +1,51 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Windows Performance Recorder profile that collects events from a single
event provider, identified by GUID in the EventProvider element.
TODO:
1. Find and replace "OrtTraceLoggingProvider" with your component name.
2. Update the GUID in the Name attribute of the EventProvider element below
to match your event provider.
-->
<WindowsPerformanceRecorder Version="1.0" Author="Microsoft Corporation"
Copyright="Microsoft Corporation" Company="Microsoft Corporation">
<Profiles>
<!-- Buffer settings for the collector that the profiles below refer to.
NOTE(review): with PercentageOfTotalMemory="true" the Buffers value is
presumably a percentage of total memory rather than a count; confirm
against the WPR schema. -->
<EventCollector Id="EventCollector_OrtTraceLoggingProvider"
Name="OrtTraceLoggingProviderCollector">
<BufferSize Value="65536" />
<Buffers Value="10" PercentageOfTotalMemory="true"/>
</EventCollector>
<!-- The provider to record; Name holds the provider GUID. -->
<EventProvider Id="EventProvider_OrtTraceLoggingProvider"
Name="3a26b1ff-7484-7484-7484-15261f42614d" />
<!-- Base profile: verbose detail, logging to file. The three profiles after
it reuse its collectors through the Base attribute and only vary
LoggingMode and DetailLevel. -->
<Profile Id="OrtTraceLoggingProvider.Verbose.File"
Name="OrtTraceLoggingProvider" Description="OrtTraceLoggingProvider"
LoggingMode="File" DetailLevel="Verbose">
<Collectors>
<EventCollectorId Value="EventCollector_OrtTraceLoggingProvider">
<EventProviders>
<EventProviderId Value="EventProvider_OrtTraceLoggingProvider" />
</EventProviders>
</EventCollectorId>
</Collectors>
</Profile>
<Profile Id="OrtTraceLoggingProvider.Light.File"
Name="OrtTraceLoggingProvider"
Description="OrtTraceLoggingProvider"
Base="OrtTraceLoggingProvider.Verbose.File"
LoggingMode="File"
DetailLevel="Light" />
<Profile Id="OrtTraceLoggingProvider.Verbose.Memory"
Name="OrtTraceLoggingProvider"
Description="OrtTraceLoggingProvider"
Base="OrtTraceLoggingProvider.Verbose.File"
LoggingMode="Memory"
DetailLevel="Verbose" />
<Profile Id="OrtTraceLoggingProvider.Light.Memory"
Name="OrtTraceLoggingProvider"
Description="OrtTraceLoggingProvider"
Base="OrtTraceLoggingProvider.Verbose.File"
LoggingMode="Memory"
DetailLevel="Light" />
</Profiles>
</WindowsPerformanceRecorder>