### Description

Run clang-format in CI. Formatted all C/C++ and Objective-C/C++ files.

Excluded

```
    'onnxruntime/core/mlas/**',
    'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**',
```

because they contain assembly code or are data heavy.
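
For reference, a minimal sketch of checking and applying this formatting locally through lintrunner; the `init`, `--take`, and `-a` options are assumptions about current lintrunner releases rather than something stated in this PR:

```bash
# Install the pinned linter dependencies declared in .lintrunner.toml
# (for this linter, clang-format==16.0.1 via pip).
lintrunner init

# Report clang-format violations in files changed relative to origin/main.
lintrunner --take CLANGFORMAT

# Apply the suggested formatting in place.
lintrunner --take CLANGFORMAT -a
```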


### Motivation and Context

Coding style consistency
Justin Chu 2023-04-18 09:26:58 -07:00, committed by GitHub
Parent 2700d01642
Commit cf19c3697d
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1023 changed files: 11239 additions and 11055 deletions


@ -19,4 +19,3 @@ DerivePointerAlignment: false
# NamespaceIndentation: All
...


@ -14,7 +14,7 @@
# To lint local changes:
#
# ```bash
# lintrunner -m main
# lintrunner
# ```
#
# To lint all files:
@ -33,6 +33,8 @@
# To update an existing linting rule or create a new one, modify this file or create a
# new adapter following examples in https://github.com/justinchuby/lintrunner-adapters.
merge_base_with = 'origin/main'
[[linter]]
code = 'RUFF'
include_patterns = [
@ -168,3 +170,44 @@ command = [
'@{{PATHSFILE}}'
]
is_formatter = true
[[linter]]
code = 'CLANGFORMAT'
include_patterns = [
'**/*.h',
'**/*.cc',
'**/*.hpp',
'**/*.cpp',
'**/*.m',
'**/*.mm',
]
exclude_patterns = [
'java/**', # FIXME: Enable clang-format for java
'js/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
'onnxruntime/core/flatbuffers/schema/ort.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'winml/**', # FIXME: Enable clang-format for winml
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'clangformat_linter',
'--binary=clang-format',
'--fallback',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'clang-format==16.0.1',
]
is_formatter = true
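
For CI, the `init_command` above provisions the pinned clang-format before linting. A hedged sketch of a CI step built on that configuration (the package names, the `--all-files` flag, and the non-zero exit on violations are assumptions about lintrunner, not taken from this diff):

```bash
#!/usr/bin/env bash
set -euo pipefail

# Install the lintrunner driver and the adapters this config calls into (assumed package names).
python -m pip install lintrunner lintrunner-adapters

# Run each linter's init_command; for CLANGFORMAT that installs clang-format==16.0.1 via pip.
lintrunner init

# Lint every tracked file with the clang-format linter; a non-zero exit fails the CI job.
lintrunner --all-files --take CLANGFORMAT
```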


@ -6,6 +6,6 @@
// Extending the std namespace is undefined behavior
// NOLINTNEXTLINE
namespace std {
inline char *getenv(const char*) { return nullptr; }
}
inline char* getenv(const char*) { return nullptr; }
} // namespace std
#endif


@ -1,11 +1,11 @@
#include <stdio.h>
#include "winrt/microsoft.ai.machinelearning.h"
#include "winrt/windows.storage.h"
#include "winrt/windows.foundation.h"
#include "winrt/windows.foundation.collections.h"
#include "winrt/Windows.Graphics.h"
#include "winrt/Windows.Graphics.Imaging.h"
#include "winrt/Windows.Graphics.h"
#include "winrt/Windows.Media.h"
#include "winrt/microsoft.ai.machinelearning.h"
#include "winrt/windows.foundation.collections.h"
#include "winrt/windows.foundation.h"
#include "winrt/windows.storage.h"
#include <stdio.h>
#include <windows.h>
EXTERN_C IMAGE_DOS_HEADER __ImageBase;
@ -15,42 +15,44 @@ using namespace winrt::Windows::Storage;
using namespace winrt::Windows::Media;
using namespace winrt::Windows::Graphics::Imaging;
std::wstring GetModulePath() {
std::wstring val;
wchar_t modulePath[MAX_PATH] = {0};
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
wchar_t drive[_MAX_DRIVE];
wchar_t dir[_MAX_DIR];
wchar_t filename[_MAX_FNAME];
wchar_t ext[_MAX_EXT];
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
std::wstring GetModulePath()
{
std::wstring val;
wchar_t modulePath[MAX_PATH] = {0};
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
wchar_t drive[_MAX_DRIVE];
wchar_t dir[_MAX_DIR];
wchar_t filename[_MAX_FNAME];
wchar_t ext[_MAX_EXT];
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
val = drive;
val += dir;
val = drive;
val += dir;
return val;
return val;
}
int main() {
printf("Load squeezenet.onnx.\n");
auto model = LearningModel::LoadFromFilePath(L"squeezenet.onnx");
printf("Load kitten_224.png as StorageFile.\n");
auto name = GetModulePath() + L"kitten_224.png";
auto image = StorageFile::GetFileFromPathAsync(name).get();
printf("Load StorageFile into Stream.\n");
auto stream = image.OpenAsync(FileAccessMode::Read).get();
printf("Create SoftwareBitmap from decoded Stream.\n");
auto softwareBitmap = BitmapDecoder::CreateAsync(stream).get().GetSoftwareBitmapAsync().get();
printf("Create VideoFrame.\n");
auto frame = VideoFrame::CreateWithSoftwareBitmap(softwareBitmap);
printf("Create LearningModelSession.\n");
auto session = LearningModelSession(model);
printf("Create LearningModelBinding.\n");
auto binding = LearningModelBinding(session);
printf("Bind data_0.\n");
binding.Bind(L"data_0", frame);
printf("Evaluate.\n");
auto results = session.Evaluate(binding, L"");
printf("Success!\n");
return 0;
int main()
{
printf("Load squeezenet.onnx.\n");
auto model = LearningModel::LoadFromFilePath(L"squeezenet.onnx");
printf("Load kitten_224.png as StorageFile.\n");
auto name = GetModulePath() + L"kitten_224.png";
auto image = StorageFile::GetFileFromPathAsync(name).get();
printf("Load StorageFile into Stream.\n");
auto stream = image.OpenAsync(FileAccessMode::Read).get();
printf("Create SoftwareBitmap from decoded Stream.\n");
auto softwareBitmap = BitmapDecoder::CreateAsync(stream).get().GetSoftwareBitmapAsync().get();
printf("Create VideoFrame.\n");
auto frame = VideoFrame::CreateWithSoftwareBitmap(softwareBitmap);
printf("Create LearningModelSession.\n");
auto session = LearningModelSession(model);
printf("Create LearningModelBinding.\n");
auto binding = LearningModelBinding(session);
printf("Bind data_0.\n");
binding.Bind(L"data_0", frame);
printf("Evaluate.\n");
auto results = session.Evaluate(binding, L"");
printf("Success!\n");
return 0;
}


@ -47,10 +47,10 @@ struct CodeLocation {
out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function;
return out.str();
}
//utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
// utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
const std::string file_and_path;
const int line_num;
//utf-8
// utf-8
const std::string function;
const std::vector<std::string> stacktrace;
};


@ -37,7 +37,6 @@
#include "core/common/make_string.h"
#include "core/common/status.h"
namespace onnxruntime {
using TimePoint = std::chrono::high_resolution_clock::time_point;


@ -158,9 +158,9 @@ class NodeHashSet : public std::unordered_set<T,
template <typename Key, typename Value, typename Allocator>
class NodeHashMap : public std::unordered_map<Key, Value,
std::hash<Key>,
std::equal_to<Key>,
Allocator> {
std::hash<Key>,
std::equal_to<Key>,
Allocator> {
using Base = std::unordered_map<Key, Value,
std::hash<Key>,
std::equal_to<Key>,


@ -255,15 +255,17 @@
#else
// Disabled in Release builds.
#define VLOGS(logger, level) \
if constexpr (true) {} else LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
if constexpr (true) { \
} else \
LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
#define VLOGS_USER(logger, level) \
if constexpr (true) {} else LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
if constexpr (true) { \
} else \
LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
#define VLOGF(logger, level, format_str, ...)
#define VLOGF_USER(logger, level, format_str, ...)
#endif
// Default logger variants
#define VLOGS_DEFAULT(level) \
VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)


@ -75,7 +75,7 @@ struct EventRecord {
using Events = std::vector<EventRecord>;
//Execution Provider Profiler
// Execution Provider Profiler
class EpProfiler {
public:
virtual ~EpProfiler() = default;


@ -121,10 +121,10 @@ class [[nodiscard]] Status {
Status(StatusCategory category, int code);
GSL_SUPPRESS(r.11)
GSL_SUPPRESS(r .11)
Status(const Status& other)
: state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {}
GSL_SUPPRESS(r.11)
GSL_SUPPRESS(r .11)
Status& operator=(const Status& other) {
if (state_ != other.state_) {
if (other.state_ == nullptr) {


@ -22,12 +22,11 @@ namespace onnxruntime {
class ORTInvoker {
public:
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
const logging::Logger& logger,
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) :
execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
if (!execution_provider_) {
ORT_THROW("Execution provider is nullptr");
ORT_THROW("Execution provider is nullptr");
}
}
@ -36,7 +35,7 @@ class ORTInvoker {
}
common::Status Invoke(const std::string& op_name,
//optional inputs / outputs?
// optional inputs / outputs?
const std::vector<OrtValue>& inputs,
std::vector<OrtValue>& outputs,
const NodeAttributes* attributes,


@ -386,8 +386,8 @@ void AssignOpaqueDomainName(const char* domain, const char* name,
} // namespace data_types_internal
//The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
//However, we do not allocate this type on heap.
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
// However, we do not allocate this type on heap.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26436)
@ -614,7 +614,7 @@ class OptionalType :
#if !defined(DISABLE_OPTIONAL_TYPE)
OptionalType()
#else
OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
OptionalType() : DisabledTypeBase{DataTypeImpl::GeneralType::kOptional, 0}
#endif
{
using namespace data_types_internal;


@ -29,17 +29,17 @@
namespace onnxruntime {
namespace utils {
// The following primitives are strongly recommended for switching on tensor input datatypes for
// kernel implementations.
//
// 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
// DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
// 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
// if you have a standalone MLDatatType with a sequence of if/else statements.
// 3) For something in between, we suggest to use CallDispatcher pattern.
//
// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
// The following primitives are strongly recommended for switching on tensor input datatypes for
// kernel implementations.
//
// 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
// DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
// 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
// if you have a standalone MLDatatType with a sequence of if/else statements.
// 3) For something in between, we suggest to use CallDispatcher pattern.
//
// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
#define DispatchOnTensorType(tensor_type, function, ...) \
switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \
@ -498,11 +498,10 @@ class ContainerChecker {
ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
constexpr int32_t prim_type = ToTensorProtoElementType<T>();
// Check if this is a primitive type and it matches
if constexpr(prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
c[index].IsPrimType(prim_type);
}
else {
} else {
// T is not primitive, check next entry for non-primitive proto
return IsContainerOfType<T>::check(c, index);
}
@ -528,11 +527,11 @@ class ContainerChecker {
}
ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
constexpr int32_t val_type = ToTensorProtoElementType<V>();
if constexpr(val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
c[index].IsPrimType(val_type);
}
else return IsContainerOfType<V>::check(c, index);
} else
return IsContainerOfType<V>::check(c, index);
}
};


@ -69,10 +69,9 @@ struct BFloat16 {
val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
#else
if constexpr(endian::native == endian::little) {
if constexpr (endian::native == endian::little) {
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
}
else {
} else {
std::memcpy(&val, &v, sizeof(uint16_t));
}
#endif
@ -93,11 +92,10 @@ struct BFloat16 {
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
if constexpr(endian::native == endian::little) {
if constexpr (endian::native == endian::little) {
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
}
else {
} else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}
@ -117,7 +115,6 @@ inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& rig
inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
// User defined suffixes to make it easier to declare
// initializers with MLFloat16 and BFloat16 from unsigned short
// E.g 10_f16 or 10_b16


@ -10,7 +10,7 @@ using DestroyFunc = void (*)(void*, void*);
using AllocatorHandle = void*;
typedef struct {
//right now we only include allocation for host memory
// right now we only include allocation for host memory
AllocateFunc allocate_func;
DestroyFunc release_func;
AllocatorHandle allocator_handle;


@ -78,10 +78,10 @@ inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { ret
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info);
namespace std {
template<>
template <>
struct hash<OrtMemoryInfo> {
size_t operator()(const OrtMemoryInfo& i) const {
return i.Hash();
}
};
}
} // namespace std


@ -18,8 +18,8 @@ class DataTransferManager;
/**
* @brief This is a Sparse Format enumeration
*
*
*
*
*/
enum class SparseFormat : uint32_t {
kUndefined = 0x0U, // For completeness
@ -31,7 +31,7 @@ enum class SparseFormat : uint32_t {
std::ostream& operator<<(std::ostream&, SparseFormat);
/**
* @brief This class implements SparseTensor.
* @brief This class implements SparseTensor.
* This class holds sparse non-zero data (values) and sparse format
* specific indices. There are two main uses for the class (similar to that of Tensor)
* - one is to re-present model sparse inputs. Such inputs typically reside
@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream&, SparseFormat);
* be used to supply pointers to format specific indices. These buffers are used as is
* and will not be modified or deallocated by the instance. However, the lifespan of the buffers
* must eclipse the lifespan of the SparseTensor instance.
*
*
* - Represent sparse data that is a result of format conversion or a computation result. Use second constructor
* to supply a desired allocator. Use Make*() format specific interfaces to supply values and format
* specific indices. The specified data will be copied into an internally allocated buffer.
@ -446,7 +446,6 @@ class SparseTensor final {
const TensorShape& values_shape, const void* values_data,
const TensorShape& indices_shape, const int32_t* indices_data);
/// <summary>
/// The method allocates a single contiguous buffer and creates instances of std::strings in it, with
/// copies of the supplied zero-terminated strings followed by COO indices.


@ -167,5 +167,4 @@ class IStreamCommandHandleRegistry {
virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
};
} // namespace onnxruntime


@ -3,7 +3,7 @@
#pragma once
#include <stddef.h> // needed for size_t on some platforms
#include <stddef.h> // needed for size_t on some platforms
namespace onnxruntime {


@ -13,8 +13,8 @@ class Node;
namespace onnxruntime {
/**
@class Function
/**
@class Function
Class representing a Function.
*/
class Function {


@ -408,9 +408,9 @@ class Node {
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/**
* Clears removable attributes. These are no longer needed after the initialization
* of the session. The function returns the number of removed attributes.
*/
* Clears removable attributes. These are no longer needed after the initialization
* of the session. The function returns the number of removed attributes.
*/
int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);
#if !defined(ORT_MINIMAL_BUILD)
@ -659,7 +659,7 @@ class Node {
std::vector<std::unique_ptr<Graph>> subgraphs_;
// Can be saved? The node cannot be saved anymore if removable attributes have been cleared.
bool can_be_saved_;
bool can_be_saved_;
};
/**


@ -21,7 +21,7 @@ using OpName_Domain_Version_Schema_Map = std::unordered_map<
std::string,
std::unordered_map<std::string, std::map<ONNX_NAMESPACE::OperatorSetVersion, ONNX_NAMESPACE::OpSchema>>>;
/**
/**
@struct SchemaRegistryVersion
onnxruntime schema registry is a supplement to the built-in ONNX schema.
Every schema registry represent a collection of schema deltas from baseline_opset_version to opset_version
@ -60,7 +60,7 @@ class IOnnxRuntimeOpSchemaCollection : public ONNX_NAMESPACE::ISchemaRegistry {
int* earliest_opset_where_unchanged) const = 0;
};
/**
/**
@class OnnxRuntimeOpSchemaRegistry
OnnxRuntimeOpSchemaRegistry is used to provide supplement for built-in ONNX schemas.
@ -111,10 +111,10 @@ class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection {
};
/**
@class SchemaRegistryManager
@class SchemaRegistryManager
SchemaRegistryManager provides a view based on built-in ONNX schema and a list of
OnnxRuntimeOpSchemaRegistry as supplement.
SchemaRegistryManager provides a view based on built-in ONNX schema and a list of
OnnxRuntimeOpSchemaRegistry as supplement.
The user needs to make sure the customized schema registry is valid, otherwise the behavior is undefined.
@ -137,7 +137,7 @@ class SchemaRegistryManager : public onnxruntime::IOnnxRuntimeOpSchemaCollection
/** Gets the last released opset versions.
@param is_onnx_only If true, return ONNX schemas only. If false, return the schemas for all domains.
*/
DomainToVersionMap GetLastReleasedOpsetVersions(bool is_onnx_only) const ;
DomainToVersionMap GetLastReleasedOpsetVersions(bool is_onnx_only) const;
/**
Gets the OpSchema and its history.
Searches custom schema registries starting with the last one added. \


@ -19,7 +19,7 @@ The interface for in-place transformation of a Graph.
class GraphTransformer {
public:
GraphTransformer(const std::string& name,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: name_(name), compatible_provider_types_(compatible_execution_providers) {
}


@ -9,23 +9,23 @@
namespace onnxruntime {
/**
@class RewriteRule
@class RewriteRule
The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
computation graph. It can be used to represent, for example, the elimination of operators that serve as
no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
computation graph. It can be used to represent, for example, the elimination of operators that serve as
no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
- SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
- SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
because those will lead to discarding fast rules that cannot be applied on a node.
- Apply is the actual body of the rule that will be executed if SatisfyCondition returns true for a particular
node. Note that additional, more complex checks can be included in the Apply if putting them in the
SatisfyCondition would lead to duplicate work (e.g., when we make a check on a Node attribute but we need
that attribute to execute the rule too).
In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
in the Apply.
In order to avoid evaluating the SatisfyCondition for each rule and each node of the graph, each rewrite rule
@ -75,13 +75,13 @@ class RewriteRule {
const std::string name_;
/** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
/** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
of the nodes. */
virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0;
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0;


@ -13,12 +13,12 @@ namespace onnxruntime {
/**
@class RuleBasedGraphTransformer
Rule-based graph transformer that provides an API to register rewrite rules
Rule-based graph transformer that provides an API to register rewrite rules
and an API to apply all applicable rules to a Graph.
Represents an IGraphTransformer determined by a set of rewrite rules.
The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy.
Several rewriting-strategies are possible when traversing the graph and applying rewrite rules,
Several rewriting-strategies are possible when traversing the graph and applying rewrite rules,
each with different trade offs. At the moment, we define one that performs top-down traversal of nodes.
@TODO: Is a bottom-up traversal more efficient?
@ -36,7 +36,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
/** Registers a rewrite rule in this transformer. */
Status Register(std::unique_ptr<RewriteRule> rule);
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
by this rule-based transformer.
@returns a pointer to the vector containing all the registered rewrite rules. */
const InlinedVector<std::reference_wrapper<const RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {


@ -240,23 +240,23 @@ class ThreadPoolProfiler {
~ThreadPoolProfiler();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler);
using Clock = std::chrono::high_resolution_clock;
void Start(); //called by executor to start profiling
std::string Stop(); //called by executor to stop profiling and return collected numbers
void LogStart(); //called in main thread to record the starting time point
void LogEnd(ThreadPoolEvent); //called in main thread to calculate and save the time elapsed from last start point
void Start(); // called by executor to start profiling
std::string Stop(); // called by executor to stop profiling and return collected numbers
void LogStart(); // called in main thread to record the starting time point
void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point
void LogEndAndStart(ThreadPoolEvent);
void LogStartAndCoreAndBlock(std::ptrdiff_t block_size);
void LogCoreAndBlock(std::ptrdiff_t block_size); //called in main thread to log core and block size for task breakdown
void LogThreadId(int thread_idx); //called in child thread to log its id
void LogRun(int thread_idx); //called in child thread to log num of run
std::string DumpChildThreadStat(); //return all child statitics collected so far
void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown
void LogThreadId(int thread_idx); // called in child thread to log its id
void LogRun(int thread_idx); // called in child thread to log num of run
std::string DumpChildThreadStat(); // return all child statitics collected so far
private:
static const char* GetEventName(ThreadPoolEvent);
struct MainThreadStat {
uint64_t events_[MAX_EVENT] = {};
int32_t core_ = -1;
std::vector<std::ptrdiff_t> blocks_; //block size determined by cost model
std::vector<std::ptrdiff_t> blocks_; // block size determined by cost model
std::vector<onnxruntime::TimePoint> points_;
void LogCore();
void LogBlockSize(std::ptrdiff_t block_size);
@ -266,7 +266,7 @@ class ThreadPoolProfiler {
std::string Reset();
};
bool enabled_ = false;
MainThreadStat& GetMainThreadStat(); //return thread local stat
MainThreadStat& GetMainThreadStat(); // return thread local stat
int num_threads_;
#ifdef _MSC_VER
#pragma warning(push)
@ -277,7 +277,7 @@ class ThreadPoolProfiler {
std::thread::id thread_id_;
uint64_t num_run_ = 0;
onnxruntime::TimePoint last_logged_point_ = Clock::now();
int32_t core_ = -1; //core that the child thread is running on
int32_t core_ = -1; // core that the child thread is running on
};
#ifdef _MSC_VER
#pragma warning(pop)
@ -770,7 +770,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
for (auto i = 0u; i < num_threads_; i++) {
worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
}
} ORT_CATCH(...) {
}
ORT_CATCH(...) {
ORT_HANDLE_EXCEPTION([&]() {
SignalAllAndWait();
throw;
@ -1336,7 +1337,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
#pragma warning(push)
// C4324: structure was padded due to alignment specifier
#pragma warning(disable : 4324)
#endif // _MSC_VER
#endif // _MSC_VER
struct ORT_ALIGN_TO_AVOID_FALSE_SHARING PerThread {
constexpr PerThread() : pool(nullptr) {
@ -1358,8 +1359,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
#ifdef _MSC_VER
#pragma warning(pop)
#endif // _MSC_VER
#endif // _MSC_VER
struct WorkerData {
constexpr WorkerData() : thread(), queue() {


@ -127,7 +127,6 @@ struct TensorOpCost {
double compute_cycles;
};
namespace concurrency {
template <typename Environment>
@ -197,11 +196,11 @@ class ThreadPool {
// parallel loops.
class ParallelSection {
public:
explicit ParallelSection(ThreadPool *tp);
public:
explicit ParallelSection(ThreadPool* tp);
~ParallelSection();
private:
private:
friend class ThreadPool;
// Owning reference for the underlying ThreadPoolParallelSection
@ -210,7 +209,7 @@ class ThreadPool {
// ThreadPoolParallelSection does not need to be available at this
// point to avoid a dependence on the Eigen headers.
ThreadPoolParallelSection* ps_{nullptr};
ThreadPool *tp_;
ThreadPool* tp_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ParallelSection);
};


@ -25,53 +25,53 @@ Environment:
// in this file configures the provider as a normal (non-telemetry) provider.
#ifndef TraceLoggingOptionMicrosoftTelemetry
#define TraceLoggingOptionMicrosoftTelemetry() \
TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000)
// Empty definition for TraceLoggingOptionMicrosoftTelemetry
TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000)
// Empty definition for TraceLoggingOptionMicrosoftTelemetry
#endif
// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition
// in this file configures the provider as a normal (non-telemetry) provider.
#define TraceLoggingOptionWindowsCoreTelemetry() \
// Empty definition for TraceLoggingOptionWindowsCoreTelemetry
// Empty definition for TraceLoggingOptionWindowsCoreTelemetry
// Event privacy tags. Use the PDT macro values for the tag parameter, e.g.:
// TraceLoggingWrite(...,
// TelemetryPrivacyDataTag(PDT_BrowsingHistory | PDT_ProductAndServiceUsage),
// ...);
#define TelemetryPrivacyDataTag(tag) TraceLoggingUInt64((tag), "PartA_PrivTags")
#define PDT_BrowsingHistory 0x0000000000000002u
#define PDT_BrowsingHistory 0x0000000000000002u
#define PDT_DeviceConnectivityAndConfiguration 0x0000000000000800u
#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u
#define PDT_ProductAndServicePerformance 0x0000000001000000u
#define PDT_ProductAndServiceUsage 0x0000000002000000u
#define PDT_SoftwareSetupAndInventory 0x0000000080000000u
#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u
#define PDT_ProductAndServicePerformance 0x0000000001000000u
#define PDT_ProductAndServiceUsage 0x0000000002000000u
#define PDT_SoftwareSetupAndInventory 0x0000000080000000u
// Event categories specified via keywords, e.g.:
// TraceLoggingWrite(...,
// TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// ...);
#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47
#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46
#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45
#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment)
#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47
#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46
#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45
#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment)
// Event categories specified via event tags, e.g.:
// TraceLoggingWrite(...,
// TraceLoggingEventTag(MICROSOFT_EVENTTAG_REALTIME_LATENCY),
// ...);
#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000
#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000
#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000
#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000
#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000
#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000
#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000
#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000
#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000
#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000
#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000
#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000
#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000
#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000
#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000
#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000
#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000
#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000
#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000
#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000
#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000
#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000
#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000
#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000
#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000
#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000
// Field categories specified via field tags, e.g.:
// TraceLoggingWrite(...,


@ -4,7 +4,7 @@
#pragma once
#pragma warning(push)
#pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#ifdef _GAMING_XBOX_SCARLETT
#include <d3d12_xs.h>
#elif defined(_GAMING_XBOX_XBOXONE)
@ -15,10 +15,10 @@
#pragma warning(pop)
#ifdef __cplusplus
#include <DirectML.h>
#include <DirectML.h>
#else
struct IDMLDevice;
typedef struct IDMLDevice IDMLDevice;
struct IDMLDevice;
typedef struct IDMLDevice IDMLDevice;
#endif
// Windows pollutes the macro space, causing a build break in constants.h.
@ -36,8 +36,8 @@ extern "C" {
* The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
*
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
@ -49,8 +49,8 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
*
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
@ -58,47 +58,46 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
struct OrtDmlApi;
typedef struct OrtDmlApi OrtDmlApi;
struct OrtDmlApi {
/**
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
/**
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
/**
* CreateGPUAllocationFromD3DResource
* This API creates a DML EP resource based on a user-specified D3D12 resource.
*/
* CreateGPUAllocationFromD3DResource
* This API creates a DML EP resource based on a user-specified D3D12 resource.
*/
ORT_API2_STATUS(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource);
/**
* FreeGPUAllocation
* This API frees the DML EP resource created by CreateGPUAllocationFromD3DResource.
*/
* FreeGPUAllocation
* This API frees the DML EP resource created by CreateGPUAllocationFromD3DResource.
*/
ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource);
/**
* GetD3D12ResourceFromAllocation
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
* GetD3D12ResourceFromAllocation
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);
};


@ -63,7 +63,7 @@ struct Value : Ort::Value {
static Ort::Value CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
};
}
}
} // namespace Experimental
} // namespace Ort
#include "experimental_onnxruntime_cxx_inline.h"


@ -104,5 +104,5 @@ inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape, ONNXTen
return Ort::Value::CreateTensor(allocator, shape.data(), shape.size(), type);
}
}
}
} // namespace Experimental
} // namespace Ort


@ -452,7 +452,6 @@ typedef struct OrtCUDAProviderOptions {
*/
int tunable_op_tuning_enable;
} OrtCUDAProviderOptions;
/** \brief ROCM Provider Options


@ -356,10 +356,10 @@ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
* constructors to construct an instance of a Status object from exceptions.
*/
struct Status : detail::Base<OrtStatus> {
explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
std::string GetErrorMessage() const;
OrtErrorCode GetErrorCode() const;
@ -473,7 +473,6 @@ struct RunOptions : detail::Base<OrtRunOptions> {
RunOptions& UnsetTerminate();
};
namespace detail {
// Utility function that returns a SessionOption config entry key for a specific custom operator.
// Ex: custom_op.[custom_op_name].[config]
@ -514,7 +513,7 @@ struct CustomOpConfigs {
* {"my_op.key", "value"}.
*
* \return An unordered map of flattened configurations.
*/
*/
const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
private:
@ -574,7 +573,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers


@ -6,7 +6,7 @@
# Requires a .clang-format config file to be in the current directory or a parent directory from where the script is run.
# Expected usage is to run it from its current location so that source in 'core' and 'test' is updated.
gci -Recurse -Include *.h, *.cc | foreach {
Write-Host "Updating " $_.FullName
clang-format -i $_
}
gci -Recurse -Include *.h, *.cc | foreach {
Write-Host "Updating " $_.FullName
clang-format -i $_
}


@ -14,14 +14,14 @@ class IAttentionMechanism {
virtual ~IAttentionMechanism() = default;
virtual void PrepareMemory(
const gsl::span<const T>& memory,
const gsl::span<const int>& memory_sequence_lengths) = 0;
const gsl::span<const T>& memory,
const gsl::span<const int>& memory_sequence_lengths) = 0;
virtual void Compute(
const gsl::span<const T>& query,
const gsl::span<const T>& prev_alignment,
const gsl::span<T>& output,
const gsl::span<T>& alignment) const = 0;
const gsl::span<const T>& query,
const gsl::span<const T>& prev_alignment,
const gsl::span<T>& output,
const gsl::span<T>& alignment) const = 0;
virtual const gsl::span<const T> Values() const = 0;
@ -32,5 +32,5 @@ class IAttentionMechanism {
virtual bool NeedPrevAlignment() const = 0;
};
}
}
} // namespace contrib
} // namespace onnxruntime


@ -8,7 +8,7 @@
#include <memory>
using onnxruntime::rnn::detail::Allocate;
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(disable : 26451)
#endif
@ -55,9 +55,9 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
}
if (has_attn_layer_) {
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
// concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
@ -76,7 +76,7 @@ void AttentionWrapper<T>::SetWeights(const gsl::span<const T>& wrapper_weights)
has_attn_layer_ = !wrapper_weights.empty();
if (has_attn_layer_) {
//cell weight size and attn weight size in the attn layer
// cell weight size and attn weight size in the attn layer
size_t cws = inner_cell_hidden_size_ * attn_layer_depth_;
size_t aws = attn_context_depth_ * attn_layer_depth_;
attn_layer_cell_weights_ = wrapper_weights.subspan(0, cws);


@ -108,17 +108,17 @@ static void SoftmaxInplace(const gsl::span<T>& alignments) {
}
/**
* Args:
* queries: Tensor, shape `[batch_size_, query_depth_]` to compare to keys.
* keys_: Processed memory, shape `[batch_size_, max_memory_step_, attn_depth_]`.
*/
* Args:
* queries: Tensor, shape `[batch_size_, query_depth_]` to compare to keys.
* keys_: Processed memory, shape `[batch_size_, max_memory_step_, attn_depth_]`.
*/
template <typename T>
void BahdanauAttention<T>::Compute(
const gsl::span<const T>& queries,
const gsl::span<const T>&, // Not used by bahdanau attention
const gsl::span<T>& output,
const gsl::span<T>& aligns) const {
//process query in dense query layer without bias
// process query in dense query layer without bias
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,


@ -11,7 +11,7 @@
#include "core/common/narrow.h"
#include "core/platform/threadpool.h"
#include "core/framework/allocator.h"
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)
@ -26,7 +26,7 @@ extern template class BahdanauAttention<float>;
/* AttnLSTM operator */
ONNX_OPERATOR_KERNEL_EX(
AttnLSTM, //name
AttnLSTM, // name
kMSDomain,
1,
kCpuExecutionProvider,


@ -14,7 +14,7 @@
#ifdef _MSC_VER
#pragma warning(pop)
#endif
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)
@ -280,7 +280,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// after the first step this will switch to the output from the previous step
span_T_const_iter previous_state = batched_hidden_state_one_step.begin();
//run through steps sequentially
// run through steps sequentially
for (int step = 0; step < max_sequence_length; step++) {
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";
@ -335,7 +335,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
}
if (output_sequence) {
//set to 0 if step >= sequence_length
// set to 0 if step >= sequence_length
for (int lrow = 0; lrow < batch_size_; lrow++) {
if (step >= min_sequence_length && step >= sequence_lengths[lrow]) {
auto dst = outputs.data() + step * output_step_length + lrow * hidden_size_;
@ -409,7 +409,7 @@ void UniDirectionalAttnLstm<T>::GateComputations(span_T_iter& out, span_T_iter&
const float* pBi = use_bias_ ? SafeRawConstPointer<T>(bias_WRi_, 0, hidden_size_) : nullptr;
clip_with_bias_ptr_(clip_, pBi, pi, hidden_size_); // post: pi has input to f() to calculate i
activation_f_.func(pi, hidden_size_, activation_f_.alpha, activation_f_.beta);
//DumpMatrix("i" + row_str, pi, 1, hidden_size_);
// DumpMatrix("i" + row_str, pi, 1, hidden_size_);
// Forget Gate
if (input_forget_) {


@ -279,9 +279,8 @@ Status AttentionBase::CheckMask(const Tensor* mask_index,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
}
mask_type = (mask_dims[0] == batch_size ?
AttentionMaskType::MASK_1D_KEY_SEQ_LEN :
mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START
: AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
} else if (mask_dims.size() == 2) {
if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;


@ -42,16 +42,16 @@ enum AttentionKernelType {
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int original_past_sequence_length; // original sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int original_past_sequence_length; // original sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;


@ -16,21 +16,21 @@ namespace contrib {
class AttentionCPUBase : public AttentionBase {
protected:
AttentionCPUBase(const OpKernelInfo& info, bool require_same_hidden_size)
: AttentionBase(info, require_same_hidden_size) {}
: AttentionBase(info, require_same_hidden_size) {}
template <typename T>
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxNxSxH
const T* V, // V value with size BxNxSxH_v
const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
const Tensor* past, // past state
Tensor* output, // output tensor
int batch_size, // batch size (B)
int sequence_length, // sequence length (S)
int qk_head_size, // head size of Q or K (H)
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
const T* K, // K data with shape BxNxSxH
const T* V, // V value with size BxNxSxH_v
const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
const Tensor* past, // past state
Tensor* output, // output tensor
int batch_size, // batch size (B)
int sequence_length, // sequence length (S)
int qk_head_size, // head size of Q or K (H)
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
OpKernelContext* context) const {
const int kv_sequence_length = sequence_length;
@ -206,7 +206,7 @@ class AttentionCPUBase : public AttentionBase {
const T* past, // past state
T* present, // present state
ThreadPool* tp) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
const ptrdiff_t input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
const ptrdiff_t present_chunk_length = past_chunk_length + input_chunk_length; // T x H_v


@ -60,7 +60,7 @@ Status BiasGelu<T, use_approximation>::Compute(OpKernelContext* context) const {
p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
}
MlasComputeTanh(p_output, p_output,narrow<size_t>(count));
MlasComputeTanh(p_output, p_output, narrow<size_t>(count));
for (int64_t i = 0; i < count; i++) {
p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);
@ -106,7 +106,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(
temp[i] = value * 0.5f;
}
MlasComputeTanh(output, output,narrow<size_t>(count));
MlasComputeTanh(output, output, narrow<size_t>(count));
for (int64_t i = 0; i < count; i++) {
output[i] = temp[i] * (output[i] + 1.0f);
@ -118,7 +118,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(
temp[i] = value * 0.5f;
}
MlasComputeErf(output, output,narrow<size_t>(count));
MlasComputeErf(output, output, narrow<size_t>(count));
for (int64_t i = 0; i < count; i++) {
output[i] = temp[i] * (output[i] + 1.0f);


@ -7,7 +7,6 @@
#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
class BifurcationDetector : public OpKernel {


@ -93,11 +93,7 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {
failed.store(true, std::memory_order_release);
return;
}
int position_col_index = (position_ids_data == nullptr) ?
index % sequence_length :
(broadcast_position_ids ?
position_ids_data[index % sequence_length] :
position_ids_data[index]);
int position_col_index = (position_ids_data == nullptr) ? index % sequence_length : (broadcast_position_ids ? position_ids_data[index % sequence_length] : position_ids_data[index]);
if (position_col_index >= position_embedding_length) {
failed.store(true, std::memory_order_release);
return;


@ -28,12 +28,12 @@ Status CheckInputs(const OpKernelContext* context, bool quantizedVersion) {
if (nullptr != position_ids) {
if (input_ids->Shape()[1] != position_ids->Shape()[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input_ids and position_ids shall have same sequence_length");
"input_ids and position_ids shall have same sequence_length");
}
if (position_ids->Shape()[0] != input_ids->Shape()[0] &&
position_ids->Shape()[0] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids's first dimension shall be 1 or batch_size");
"position_ids's first dimension shall be 1 or batch_size");
}
}
}


@ -163,7 +163,7 @@ Status CheckInputs(const T* query,
qkv_format = Q_KV_BSNH_BSN2H;
kv_sequence_length = static_cast<int>(key_dims[1]);
} else { // key_dims.size() == 4 (cross-attention with past_key)
} else { // key_dims.size() == 4 (cross-attention with past_key)
if (static_cast<int>(key_dims[1]) != num_heads || static_cast<int>(key_dims[3]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
@ -242,7 +242,7 @@ Status CheckInputs(const T* query,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
v_hidden_size = static_cast<int>(value_dims[2]);
} else { // value_dims.size() == 4
} else { // value_dims.size() == 4
if (static_cast<int64_t>(kv_sequence_length) != value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have the same dim 2 (kv_sequence_length)");


@ -3,7 +3,6 @@
#include "ngram_repeat_block.h"
namespace onnxruntime {
namespace contrib {


@ -9,5 +9,5 @@
namespace onnxruntime {
namespace contrib {
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry);
} // namespace contrib
} // namespace contrib
} // namespace onnxruntime


@ -130,4 +130,4 @@ class Crop final : public CropBase, public OpKernel {
};
} // namespace contrib
} //namespace onnxruntime
} // namespace onnxruntime


@ -21,14 +21,13 @@ namespace contrib {
*
* TODO!! implemente thread partition similar with
* fp16 conv operator
*/
*/
class NhwcPoolFp16 : public OpKernel {
public:
explicit NhwcPoolFp16(const OpKernelInfo& info)
: OpKernel(info),
pool_attrs_(info, info.GetKernelDef().OpName(), info.node().SinceVersion()),
is_max_pool_(info.GetKernelDef().OpName() == "MaxPool")
{}
pool_attrs_(info, info.GetKernelDef().OpName(), info.node().SinceVersion()),
is_max_pool_(info.GetKernelDef().OpName() == "MaxPool") {}
Status Compute(OpKernelContext* context) const override;
@ -182,7 +181,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
NhwcPoolFp16);
#endif
} // namespace contrib

View file

@ -6,16 +6,16 @@
namespace onnxruntime {
namespace contrib {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
GridSample<T>);
REGISTER_KERNEL_TYPED(float)

View file

@ -10,7 +10,7 @@
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
namespace contrib{
namespace contrib {
template <typename T>
class ImageScaler final : public OpKernel {
@ -54,5 +54,5 @@ class ImageScaler final : public OpKernel {
float scale_;
std::vector<float> bias_;
};
}
} //namespace onnxruntime
} // namespace contrib
} // namespace onnxruntime

View file

@ -62,9 +62,9 @@ inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrix
}
}
template<> inline
void SparseDenseMatMulImpl<float>(const ComputeCtx& ctx, const ConstSparseMatrixMap<float>& map_A,
const ConstEigenMatrixMapRowMajor<float>& map_B, EigenMatrixMapRowMajor<float>& output_map) {
template <>
inline void SparseDenseMatMulImpl<float>(const ComputeCtx& ctx, const ConstSparseMatrixMap<float>& map_A,
const ConstEigenMatrixMapRowMajor<float>& map_B, EigenMatrixMapRowMajor<float>& output_map) {
if (ctx.trans_A && ctx.trans_B) {
output_map = map_A.transpose() * ctx.alpha * map_B.transpose();
} else if (ctx.trans_A && !ctx.trans_B) {
@ -97,15 +97,15 @@ struct SparseToDenseCsr {
}
};
#endif //!defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
template<typename T> inline
T Mul(T a_value, float, T b_value) {
template <typename T>
inline T Mul(T a_value, float, T b_value) {
return a_value * b_value;
}
template <> inline
constexpr float Mul<float>(float a_value, float alpha, float b_value) {
template <>
inline constexpr float Mul<float>(float a_value, float alpha, float b_value) {
return a_value * alpha * b_value;
}
@ -203,7 +203,7 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "WASM and 32-bit builds support only COO format");
}
#endif //!defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
return Status::OK();
}
@ -211,4 +211,4 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
} // namespace contrib
} // namespace onnxruntime
#endif //!defined(DISABLE_SPARSE_TENSORS)
#endif //! defined(DISABLE_SPARSE_TENSORS)

View file

@ -5,14 +5,14 @@
namespace onnxruntime {
namespace contrib {
// Register MVN operator for backward compatibility.
// Register MVN operator for backward compatibility.
// The experimental MVN op was removed. The history has to be kept locally as below.
// As of (9/26/2018) MVN is a production function in ONNX.
// As of (9/26/2018) MVN is a production function in ONNX.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
MeanVarianceNormalization,
1,
8,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MeanVarianceNormalization_0<float>);
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -2,9 +2,9 @@
// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.
//scikit-learn is a Python module for machine learning built on top of SciPy and
//distributed under the 3-Clause BSD license. See https://github.com/scikit-learn/scikit-learn.
//This material is licensed under the BSD License (see https://github.com/scikit-learn/scikit-learn/blob/master/COPYING);
// scikit-learn is a Python module for machine learning built on top of SciPy and
// distributed under the 3-Clause BSD license. See https://github.com/scikit-learn/scikit-learn.
// This material is licensed under the BSD License (see https://github.com/scikit-learn/scikit-learn/blob/master/COPYING);
/* Modifications Copyright (c) Microsoft. */
#include "contrib_ops/cpu/murmur_hash3.h"
@ -200,7 +200,7 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {
}
} else {
auto input = reinterpret_cast<const unsigned char*>(keys->DataRaw());
//input_element_bytes is 4, 8,.. less than 4 bytes is not allowed
// input_element_bytes is 4, 8,.. less than 4 bytes is not allowed
int input_num_bytes = static_cast<int>(input_element_bytes);
ORT_ENFORCE(input_num_bytes % 4 == 0);
const auto input_end = input + input_count * input_num_bytes;

View file

@ -18,10 +18,10 @@ class MurmurHash3 final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
private:
private:
void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) const;
private :
private:
uint32_t seed_;
bool is_positive_{true};
};

View file

@ -323,12 +323,12 @@ Status NchwcUpsample::Compute(OpKernelContext* context) const {
const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]);
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
const ptrdiff_t total_work =((SafeInt<ptrdiff_t>(batch_count) * nchwc_channels) / nchwc_block_size) * output_h;
const ptrdiff_t total_work = ((SafeInt<ptrdiff_t>(batch_count) * nchwc_channels) / nchwc_block_size) * output_h;
// Partition the work with the goal of generating the following number of
// elements, so that operations involving a smaller number of columns will
// process more rows per worker.
constexpr ptrdiff_t worker_goal = 16 * 1024;
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (SafeInt<ptrdiff_t>(output_w) * nchwc_block_size), 1);
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (SafeInt<ptrdiff_t>(output_w) * nchwc_block_size), 1);
ptrdiff_t worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
auto upsample_worker = [&](ptrdiff_t batch) {

View file

@ -132,7 +132,7 @@ Status DynamicQuantizeLSTM::UseSharedPrePackedBuffers(std::vector<BufferUniquePt
#define WeightCheck(weight_shape, weight_name) \
if ((weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2) || \
(weight_shape.NumDimensions() == 2 && weight_shape[1] != static_cast<int64_t>(hidden_size_) * 4) || \
(weight_shape.NumDimensions() == 2 && weight_shape[1] != static_cast<int64_t>(hidden_size_) * 4) || \
weight_shape[0] != num_directions_) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, \
"Input ", #weight_name, " must have shape {", num_directions_, "} for per-tensor/layer quantization or shape {", \

View file

@ -51,4 +51,3 @@ class QLinearSigmoid final : public QLinearLookupBase<T> {
} // namespace contrib
} // namespace onnxruntime

View file

@ -13,7 +13,6 @@
namespace onnxruntime {
namespace contrib {
QLinearConcat::QLinearConcat(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info) {
size_t input_def_count = info.node().InputDefs().size();
ORT_ENFORCE(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,

View file

@ -19,7 +19,7 @@ class QLinearConcat final : public OpKernel, public ConcatBase {
private:
std::vector<std::vector<uint8_t>> fixed_lookup_tables_;
std::vector<int> fixed_table_attrs_; // is_static or not, is_copy or not
std::vector<int> fixed_table_attrs_; // is_static or not, is_copy or not
};
} // namespace contrib

View file

@ -21,7 +21,7 @@ class QLinearGlobalAveragePool final : public OpKernel {
bool channels_last_;
};
template<typename T8Bits>
template <typename T8Bits>
Status ComputeQLinearGlobalAvgPool(
const T8Bits* x,
float x_scale,

View file

@ -56,10 +56,10 @@ void QlinearBuildLookupTable(uint8_t* table,
const float X_scale = *(tensor_x_scale->Data<float>());
const T X_zero_point =
(tensor_x_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_x_zero_point->Data<T>());
(tensor_x_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_x_zero_point->Data<T>());
const float Y_scale = *(tensor_y_scale->Data<float>());
const T Y_zero_point =
(tensor_y_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_y_zero_point->Data<T>());
(tensor_y_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_y_zero_point->Data<T>());
float dequantized_input[256];
float dequantized_output[256];

View file

@ -37,6 +37,5 @@ void QlinearBuildLookupTable(uint8_t* table,
template <typename TOutput>
void QLinearLookupTableTransform(const uint8_t* x, const TOutput* table, TOutput* y, size_t n);
} // namespace contrib
} // namespace onnxruntime

View file

@ -67,7 +67,7 @@ struct QLinearPool1DTask final {
}
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
for (std::ptrdiff_t c = begin; c < end; ++c) {
for (std::ptrdiff_t c = begin; c < end; ++c) {
operator()(c);
}
}
@ -149,7 +149,7 @@ struct QLinearPoolNhwc1DTask final {
int64_t element_count = (pool_attrs_.count_include_pad) ? kernel_shape[0] : hend - hstart;
for (int64_t c = 0; c < channels; ++c) {
PoolType::Finalize(element_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
PoolType::Finalize(element_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
y_d[phc + c] = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
}
}
@ -181,7 +181,7 @@ struct QLinearPool2DTask final {
}
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
for (std::ptrdiff_t c = begin; c < end; ++c) {
for (std::ptrdiff_t c = begin; c < end; ++c) {
operator()(c);
}
}
@ -287,7 +287,7 @@ struct QLinearPoolNhwc2DTask final {
int64_t input_index = channels * (h * width + wstart);
for (int64_t w = wstart; w < wend; ++w) {
for (int64_t c = 0; c < channels; c++) {
PoolType::Process(x_d[input_index + c],Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
PoolType::Process(x_d[input_index + c], Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
}
input_index += channels;
}
@ -295,7 +295,7 @@ struct QLinearPoolNhwc2DTask final {
int64_t elements_count = (pool_attrs_.count_include_pad) ? kernel_size : (hend - hstart) * (wend - wstart);
for (int64_t c = 0; c < channels; c++) {
PoolType::Finalize(elements_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
PoolType::Finalize(elements_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
auto y_value = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
y_d[pool_index + c] = y_value;
}
@ -337,7 +337,7 @@ struct QLinearPool3DTask final {
}
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
for (std::ptrdiff_t c = begin; c < end; ++c) {
for (std::ptrdiff_t c = begin; c < end; ++c) {
operator()(c);
}
}
@ -462,7 +462,7 @@ struct QLinearPoolNhwc3DTask final {
int64_t input_index = channels * (input_index_h + w * depth + dstart);
for (int64_t d = dstart; d < dend; ++d) {
for (int64_t c = 0; c < channels; c++) {
PoolType::Process(x_d[input_index + c],Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
PoolType::Process(x_d[input_index + c], Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
}
input_index += channels;
}
@ -471,7 +471,7 @@ struct QLinearPoolNhwc3DTask final {
int64_t elements_count = (pool_attrs_.count_include_pad) ? kernel_size : (hend - hstart) * (wend - wstart) * (dend - dstart);
for (int64_t c = 0; c < channels; c++) {
PoolType::Finalize(elements_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
PoolType::Finalize(elements_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
auto y_value = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
y_d[pool_index + c] = y_value;
}

View file

@ -25,7 +25,6 @@ static inline bool has_same_zero_point(bool is_signed, const Tensor* tensor_x_ze
const uint8_t X_zero_point = (tensor_x_zero_point == nullptr) ? static_cast<uint8_t>(0) : *(tensor_x_zero_point->Data<uint8_t>());
const uint8_t Y_zero_point = (tensor_y_zero_point == nullptr) ? static_cast<uint8_t>(0) : *(tensor_y_zero_point->Data<uint8_t>());
return X_zero_point == Y_zero_point;
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -95,7 +95,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
if (condition == target) {
// Transform the output to the correct value from LookupTable
std::transform(value.cbegin(), value.cend(), output.begin(),
[condition, target, &look_up_table,is_copy](const T& value_element) {
[condition, target, &look_up_table, is_copy](const T& value_element) {
return is_copy ? value_element : look_up_table[value_element];
});
} else {
@ -112,7 +112,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
auto output = per_iter_bh.OutputSpan<T>();
// Transform the output to the correct value from LookupTable
std::transform(condition.begin(), condition.end(), output.begin(),
[target, &value,&look_up_table,is_copy](bool condition_element) {
[target, &value, &look_up_table, is_copy](bool condition_element) {
return condition_element == target ? is_copy ? value : look_up_table[value] : T{};
});
},
@ -126,7 +126,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
auto output = per_iter_bh.OutputSpan<T>();
// Transform the output to the correct value from LookupTable
std::transform(condition.begin(), condition.end(), value.cbegin(), output.begin(),
[target,&look_up_table,is_copy](bool condition_element, const T& value_element) {
[target, &look_up_table, is_copy](bool condition_element, const T& value_element) {
return condition_element == target ? is_copy ? value_element : look_up_table[value_element] : T{};
});
}};
@ -310,16 +310,16 @@ QLinearWhere::QLinearWhere(const OpKernelInfo& info) : OpKernel(info) {
}
Status QLinearWhere::Compute(OpKernelContext* ctx) const {
// const auto* tensor_condition = ctx->Input<Tensor>(0);
// const auto* tensor_x_input = ctx->Input<Tensor>(1);
// const auto* tensor_condition = ctx->Input<Tensor>(0);
// const auto* tensor_x_input = ctx->Input<Tensor>(1);
const auto* tensor_x_scale = ctx->Input<Tensor>(2);
const auto* tensor_x_zero_point = ctx->Input<Tensor>(3);
// const auto* tensor_y_input = ctx->Input<Tensor>(4);
// const auto* tensor_y_input = ctx->Input<Tensor>(4);
const auto* tensor_y_scale = ctx->Input<Tensor>(5);
const auto* tensor_y_zero_point = ctx->Input<Tensor>(6);
const auto* tensor_z_scale = ctx->Input<Tensor>(7);
const auto* tensor_z_zero_point = ctx->Input<Tensor>(8);
// auto* tensor_output = ctx->Output(0, tensor_condition->Shape());
// auto* tensor_output = ctx->Output(0, tensor_condition->Shape());
ORT_ENFORCE(tensor_x_scale->IsDataType<float>(), "Input scale is not float for quantized input x @ 2");
ORT_ENFORCE(tensor_y_scale->IsDataType<float>(), "Input scale is not float for quantized input y @ 5");
ORT_ENFORCE(tensor_z_scale->IsDataType<float>(), "Input scale is not float for quantized output z @ 7");
@ -377,7 +377,7 @@ Status QLinearWhere::Compute(OpKernelContext* ctx) const {
if (!is_y_copy) {
std::copy(y_lookup_table.begin(), y_lookup_table.end(), y_user_data.begin() + 2);
}
//Allocator, Allocation, and SelectBroadcastFuncs are the same implementation from where_op.cc
// Allocator, Allocation, and SelectBroadcastFuncs are the same implementation from where_op.cc
const auto typed_tensor_allocation = [](const TensorAllocator& allocator,
const TensorShape& shape) {
return allocator.Allocate<uint8_t>(shape);

View file

@ -32,7 +32,7 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
size_t N = SafeInt<size_t>(helper.N());
size_t K = SafeInt<size_t>(helper.K());
//validate scales and zero points
// validate scales and zero points
const auto* a_zp = context->Input<Tensor>(IN_A_ZERO_POINT);
const auto* b_zp = context->Input<Tensor>(IN_B_ZERO_POINT);
const auto* y_zp = context->Input<Tensor>(IN_Y_ZERO_POINT);

View file

@ -91,7 +91,6 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
}
}
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
const T* input_data = input->Data<T>();

View file

@ -149,14 +149,13 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
t5_decoder_subgraph_->head_size,
t5_decoder_subgraph_->num_layers);
}
}
else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
} else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
t5_encoder_subgraph_ = std::make_unique<WhisperEncoderSubgraph>(node,
attribute_name,
subgraph_session_state.GetGraphViewer());
attribute_name,
subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
@ -260,23 +259,23 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
// Subgraph has constraint that the output is either float or float16
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
cuda_device_prop_,
cuda_device_arch_};
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
cuda_device_prop_,
cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);

View file

@ -31,10 +31,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
ORT_ENFORCE(context != nullptr);
const Tensor* input_ids = context->Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
if (this->model_type == IGenerationParameters::kModelTypeWhisper){
if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
ORT_ENFORCE(dims.size() == 3, "input_features shall have 3 dimensions. Got ", dims.size());
}
else {
} else {
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
}
batch_size = static_cast<int>(dims[0]);
@ -69,8 +68,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
if (length_penalty_tensor) {
if (length_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
length_penalty = static_cast<float>(*length_penalty_tensor->Data<float>());
}
else {
} else {
length_penalty = static_cast<MLFloat16>(*length_penalty_tensor->Data<MLFloat16>());
}
} else {
@ -81,8 +79,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
if (repetition_penalty_tensor) {
if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<float>());
}
else {
} else {
repetition_penalty = static_cast<MLFloat16>(*repetition_penalty_tensor->Data<MLFloat16>());
}
} else {

View file

@ -124,17 +124,16 @@ class GenerateBase {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper){
if (dims.size() != 3){
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
if (dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
"Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
}
}
else if (dims.size() != 2) {
} else if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
}
}
if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
@ -184,10 +183,9 @@ class GenerateBase {
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
if (dims_attn.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
"Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
}
}
else if (dims_attn.size() != 2) {
} else if (dims_attn.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
}

View file

@ -329,7 +329,6 @@ Status CreateWhisperEncoderInputs(
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids);
// ---------------------------------------------------------------
// Utility Functions
// ---------------------------------------------------------------

View file

@ -147,7 +147,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
void Process(const ISequences* sequences, gsl::span<float>& next_token_scores, int step);
private:
template<typename GenerationParametersT>
template <typename GenerationParametersT>
void LogitsProcessorInitImpl(const GenerationParametersT& parameters) {
processor_list_.clear();
@ -159,8 +159,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
if (parameters.no_repeat_ngram_size > 0) {
no_repeat_ngram_processor_ = std::make_unique<
NoRepeatNGramLogitsProcessor<float>
>(parameters.no_repeat_ngram_size);
NoRepeatNGramLogitsProcessor<float>>(parameters.no_repeat_ngram_size);
processor_list_.push_back(no_repeat_ngram_processor_.get());
}
@ -171,9 +170,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
if (!parameters.prefix_vocab_mask.empty()) {
prefix_vocab_mask_processor_ = std::make_unique<
PrefixVocabMaskLogitsProcessor<float>
>(parameters.prefix_vocab_mask,
parameters.batch_size);
PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask,
parameters.batch_size);
processor_list_.push_back(prefix_vocab_mask_processor_.get());
}
@ -190,9 +188,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
if (!parameters.presence_mask.empty()) {
presence_penalty_processor_ = std::make_unique<
PresencePenaltyLogitsProcessor<float>
>(parameters.presence_mask,
parameters.presence_penalty);
PresencePenaltyLogitsProcessor<float>>(parameters.presence_mask,
parameters.presence_penalty);
processor_list_.push_back(presence_penalty_processor_.get());
}

View file

@ -154,12 +154,12 @@ Status Sample(AllocatorPtr& allocator,
*sampled_idx));
// TODO: update presense_mask()
#ifdef DEBUG_GENERATION
dumper->Print("sampled_idx", *sampled_idx);
dumper->Print("sampled_idx", *sampled_idx);
#endif
return Status::OK();
}
} // namespace SamplingCudaHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace SamplingCpuHelper
} // namespace contrib
} // namespace onnxruntime

View file

@ -40,7 +40,7 @@ namespace transformers {
*/
Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
@ -75,7 +75,7 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type && subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float16_type,
"encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
"encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,

View file

@ -38,9 +38,9 @@ Status Unique<float>::Compute(OpKernelContext* ctx) const {
int64_t* output_idx_data = output_idx->MutableData<int64_t>();
struct ElementData {
int64_t input_pos_; // original index
int64_t input_pos_; // original index
int64_t output_pos_;
int64_t count_; // number of times encountered
int64_t count_; // number of times encountered
};
// XXX: Refactoring for less memory allocations. unordered_map

View file

@ -34,7 +34,7 @@ void WordConvEmbedding::CharEmbeddingLookup(
}
}
//input : [sequence_length, word_length, char_embedding_size]
// input : [sequence_length, word_length, char_embedding_size]
void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
AllocatorPtr allocator,
const float* input,
@ -190,11 +190,11 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
// SafeInt<size_t>(seq_len) * word_len * char_embedding_size
size_t chars_embeddings_size = SafeInt<size_t>(seq_len) * word_len * char_embedding_size;
auto chars_embeddings_ptr = IAllocator::MakeUniquePtr<float>(alloc, chars_embeddings_size);
auto words_length_ptr = IAllocator::MakeUniquePtr<int>(alloc, onnxruntime::narrow<size_t>(seq_len) );
auto words_length_ptr = IAllocator::MakeUniquePtr<int>(alloc, onnxruntime::narrow<size_t>(seq_len));
std::memset(chars_embeddings_ptr.get(), 0, chars_embeddings_size * sizeof(float));
std::memset(words_length_ptr.get(), 0, SafeInt<size_t>(seq_len) * sizeof(int));
CalculateLengthOfEachWordInSequence(seq_ptr, words_length_ptr.get(), onnxruntime::narrow<size_t>(seq_len) , onnxruntime::narrow<size_t>(word_len));
CalculateLengthOfEachWordInSequence(seq_ptr, words_length_ptr.get(), onnxruntime::narrow<size_t>(seq_len), onnxruntime::narrow<size_t>(word_len));
CharEmbeddingLookup(seq_ptr,
w_char_embedding.Data<float>(),
@ -224,7 +224,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
/* Range operator */
ONNX_OPERATOR_KERNEL_EX(
WordConvEmbedding, //name
WordConvEmbedding, // name
kMSDomain,
1,
kCpuExecutionProvider,

View file

@ -22,19 +22,19 @@ namespace cuda {
.MayInplace(0, 0), \
x<T>);
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
Stream(context), \
#define UNARY_ACTIVATION_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
UnaryElementwisePreparation p; \
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
Stream(context), \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->Data<T>()), \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
&func_ctx, p.output_tensor->Shape().Size()); \
\
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
@ -56,6 +56,6 @@ REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
} //namespace cuda
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -62,10 +62,10 @@ __device__ inline void Softmax(const int all_sequence_length,
if (i >= valid_start) {
const int index = offset + i;
float input_at_idx = no_rpb
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
if (thread_data_max < input_at_idx) {
thread_data_max = input_at_idx;
}
@ -148,10 +148,10 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
const bool no_rpb = (rel_pos_bias == nullptr);
const int size_per_batch = gridDim.x * all_sequence_length;
float input_data = no_rpb
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
@ -229,12 +229,12 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length,
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
float input_data = is_valid
? (rel_pos_bias
? float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]))
: float(input[index]))
: float(-CUDART_INF_F);
? (rel_pos_bias
? float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]))
: float(input[index]))
: float(-CUDART_INF_F);
cached_data[seq_idx] = input_data;
thread_data_max = max(thread_data_max, input_data);
}
@ -300,8 +300,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
if (rel_pos_bias == nullptr) {
thread_data = float(input[index]) * rsqrt_head_size;
} else {
T rel_pos_bias_value = broadcast_rel_pos_bias ?
rel_pos_bias[index % size_per_batch] : rel_pos_bias[index];
T rel_pos_bias_value = broadcast_rel_pos_bias ? rel_pos_bias[index % size_per_batch] : rel_pos_bias[index];
thread_data = float(input[index] + rel_pos_bias_value) * rsqrt_head_size;
}
@ -433,8 +432,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
}
if (rel_pos_bias != nullptr) {
float rel_pos_bias_value = broadcast_rel_pos_bias ?
float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]);
float rel_pos_bias_value = broadcast_rel_pos_bias ? float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]);
thread_data += rel_pos_bias_value;
}
}
@ -590,10 +588,10 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length,
const bool no_rpb = (rel_pos_bias == nullptr);
const int size_per_batch = gridDim.x * sequence_length;
float input_data = no_rpb
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
? float(input[index])
: float(input[index] + (broadcast_rel_pos_bias
? rel_pos_bias[index % size_per_batch]
: rel_pos_bias[index]));
float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
@ -761,17 +759,17 @@ Status ComputeSoftmaxWithCumSeqLength(
template <typename T>
Status ComputeSoftmaxWithMask1D(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int num_heads,
const int* mask_index,
const int* mask_start,
const T* rel_pos_bias,
const bool broadcast_rel_pos_bias,
const T* input,
T* output,
const bool is_unidirectional) {
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int num_heads,
const int* mask_index,
const int* mask_start,
const T* rel_pos_bias,
const bool broadcast_rel_pos_bias,
const T* input,
T* output,
const bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {

View file

@ -30,12 +30,12 @@ struct MemoryEfficientAttentionParams {
int32_t* seqstart_k_ptr;
int32_t* seqlen_k_ptr;
const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
const void* attn_bias; // [N, S, S*] or null
void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
const void* attn_bias; // [N, S, S*] or null
void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
cudaStream_t stream;
static bool need_workspace(size_t v_head_size, bool is_float) {

View file

@ -65,24 +65,24 @@ Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const {
return LaunchEmbedLayerNormKernel(
Stream(context),
output->MutableData<T>(),
mask_index->MutableData<int32_t>(),
input_ids->Data<int32_t>(),
nullptr == segment_ids ? nullptr : segment_ids->Data<int32_t>(),
nullptr == mask ? nullptr : mask->Data<int32_t>(),
gamma->Data<T>(),
beta->Data<T>(),
word_embedding->Data<T>(),
position_embedding->Data<T>(),
nullptr == segment_embedding ? nullptr : segment_embedding->Data<T>(),
epsilon_,
static_cast<int>(hidden_size),
batch_size,
sequence_length,
element_size,
embedding_sum == nullptr ? nullptr : embedding_sum->MutableData<T>(),
position_ids == nullptr ? nullptr : position_ids->Data<int32_t>(),
broadcast_position_ids);
output->MutableData<T>(),
mask_index->MutableData<int32_t>(),
input_ids->Data<int32_t>(),
nullptr == segment_ids ? nullptr : segment_ids->Data<int32_t>(),
nullptr == mask ? nullptr : mask->Data<int32_t>(),
gamma->Data<T>(),
beta->Data<T>(),
word_embedding->Data<T>(),
position_embedding->Data<T>(),
nullptr == segment_embedding ? nullptr : segment_embedding->Data<T>(),
epsilon_,
static_cast<int>(hidden_size),
batch_size,
sequence_length,
element_size,
embedding_sum == nullptr ? nullptr : embedding_sum->MutableData<T>(),
position_ids == nullptr ? nullptr : position_ids->Data<int32_t>(),
broadcast_position_ids);
}
} // namespace cuda

View file

@ -8,24 +8,24 @@ namespace contrib {
namespace cuda {
Status LaunchEmbedLayerNormKernel(cudaStream_t stream,
void* output, // output tensor
void* mask_index, // output mask index
const int* input_ids, // input word IDs
const int* segment_ids, // input segment IDs
const int* input_mask, // input mask
const void* gamma, // weight for layer normalization
const void* beta, // bias for layer normalization
const void* word_embedding, // weights for word embeddings
const void* position_embedding, // weights for position embeddings
const void* segment_embedding, // weights for segment (like sentence) embeddings
float epsilon, // epsilon for layer normalization
const int hidden_size, // hidden size (that is head_size * num_heads)
int batch_size, // batch size
int sequence_length, // sequence length
const size_t element_size, // size of output element: 2 for half, 4 for float.
void* embedding_sum, // Optional output of sum of embeddings
const int* position_ids, // Optional input of position ids
const bool broadcast_position_ids); // Whether to broadcast position ids
void* output, // output tensor
void* mask_index, // output mask index
const int* input_ids, // input word IDs
const int* segment_ids, // input segment IDs
const int* input_mask, // input mask
const void* gamma, // weight for layer normalization
const void* beta, // bias for layer normalization
const void* word_embedding, // weights for word embeddings
const void* position_embedding, // weights for position embeddings
const void* segment_embedding, // weights for segment (like sentence) embeddings
float epsilon, // epsilon for layer normalization
const int hidden_size, // hidden size (that is head_size * num_heads)
int batch_size, // batch size
int sequence_length, // sequence length
const size_t element_size, // size of output element: 2 for half, 4 for float.
void* embedding_sum, // Optional output of sum of embeddings
const int* position_ids, // Optional input of position ids
const bool broadcast_position_ids); // Whether to broadcast position ids
} // namespace cuda
} // namespace contrib

View file

@ -10,7 +10,7 @@ namespace cuda {
template <typename T>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output, bool use_half2);
const T* input, const T* bias, T* output, bool use_half2);
} // namespace cuda
} // namespace contrib

View file

@ -88,12 +88,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
relative_position_bias,
past_key,
past_value,
nullptr, // past_seq_len
nullptr, // past_seq_len
&parameters,
num_heads_,
mask_filter_value_,
scale_,
false, // past_present_share_buffer
false, // past_present_share_buffer
device_prop.maxThreadsPerBlock));
int sequence_length = parameters.sequence_length;

View file

@ -219,7 +219,7 @@ MHARunner* PackedAttention<T>::TryGettingFusedRunner(const PackedAttentionParame
!parameters.has_relative_position_bias &&
parameters.hidden_size == parameters.v_hidden_size;
if(!use_fused_runner) {
if (!use_fused_runner) {
return fused_runner;
}
@ -232,7 +232,7 @@ MHARunner* PackedAttention<T>::TryGettingFusedRunner(const PackedAttentionParame
enable_trt_flash_attention_,
false);
if(!is_fMHA_supported) {
if (!is_fMHA_supported) {
return fused_runner;
}

View file

@ -12,14 +12,13 @@ using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace contrib {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
RelativePositionBias , \
RelativePositionBias, \
kMSDomain, \
1, \
T, \
@ -63,8 +62,8 @@ Status RelPosAttnBias<T>::ComputeInternal(OpKernelContext* context) const {
const int64_t num_buckets = bias_table_dims[0];
const int64_t num_heads = bias_table_dims[1];
const int64_t query_len = *query_length->Data<int64_t>();
const int64_t key_len = *key_length->Data<int64_t>();
const int64_t query_len = *query_length->Data<int64_t>();
const int64_t key_len = *key_length->Data<int64_t>();
if (query_len != key_len) {
ORT_THROW("Relatvie position bias currently only support query length equal to key length in Self Attention.");
@ -145,7 +144,7 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
typedef typename ToCudaType<T>::MappedType CudaT;
const auto BNS = batch_size * num_heads_ * seq_len;
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
const size_t elements_after_gemm = (size_t)BNS * (size_t)D;
bool reuse_output = (seq_len >= D);
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());

View file

@ -33,7 +33,6 @@ class GatedRelativePositionBias final : public CudaKernel {
int num_heads_;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -19,8 +19,7 @@ Status LaunchRelPosAttnBiasKernel(
const int num_bucket,
const int max_distance,
const bool is_bidirectional,
const int max_threads_per_block
);
const int max_threads_per_block);
template <typename T>
Status LaunchGatedRelativePositionBiasKernel(

The diff for this file is not shown because it is too large.

View file

@ -136,7 +136,7 @@ AllToAll::AllToAll(const OpKernelInfo& info) : NcclKernel(info) {
Status AllToAll::ComputeInternal(OpKernelContext* context) const {
const ncclComm_t comm = nccl_->Comm();
auto input_tensor = context->Input<Tensor>(0);
const char* input_data = static_cast<const char *>(input_tensor->DataRaw());
const char* input_data = static_cast<const char*>(input_tensor->DataRaw());
const auto in_shape = input_tensor->Shape();
const int64_t input_count = in_shape.Size();
auto out_shape = in_shape;
@ -144,7 +144,7 @@ Status AllToAll::ComputeInternal(OpKernelContext* context) const {
const int64_t rank_stride = input_count / group_size_;
const ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType());
char* output_data = static_cast<char *>(context->Output(0, out_shape)->MutableDataRaw());
char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());
#ifdef ORT_USE_NCCL
NCCL_RETURN_IF_ERROR(ncclGroupStart());

View file

@ -69,7 +69,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
nullptr, //bias
nullptr, // bias
mask_index,
relative_position_bias,
past_key,
@ -108,9 +108,9 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
Tensor* output = context->Output(0, output_shape);
std::vector<int64_t> present_dims{
parameters.batch_size, parameters.num_heads,
past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length,
parameters.head_size};
parameters.batch_size, parameters.num_heads,
past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length,
parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(kPresentOutputIndex, present_shape);
Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape);

View file

@ -23,7 +23,7 @@ static constexpr int kPresentOutputIndex = 1;
#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
DecoderMaskedSelfAttention, \
DecoderMaskedSelfAttention, \
kMSDomain, \
1, \
T1, \

View file

@ -34,11 +34,10 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
void* out = nullptr;
const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
const int32_t* mask = nullptr; // [B, total_sequence_length]
};
template<
template <
// The type of the inputs. Supported types: float and half.
typename T,
// The hidden dimension per head.
@ -51,11 +50,9 @@ template<
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params);
template<typename T, int head_size>
template <typename T, int head_size>
void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
} // namespace cuda
} // namespace contrib

View file

@ -71,10 +71,10 @@ Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
LaunchBiasAddKernel<CudaT>(Stream(context), grid_size, static_cast<int32_t>(input_dims[2]),
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(bias->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<CudaT*>(output->MutableData<T>()));
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(bias->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
reinterpret_cast<CudaT*>(output->MutableData<T>()));
CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
return Status::OK();

View file

@ -19,7 +19,7 @@ class GridSample final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
private:
int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic
int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic
int64_t padding_mode_i_; // 0:'zeros', 1: 'border', 2:'reflection'
int64_t align_corners_;
};

View file

@ -19,24 +19,24 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);
#define CONTRIB_BINARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
Stream(context), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
#define CONTRIB_BINARY_ELEMENTWISE_COMPUTE(x, T) \
template <> \
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
Stream(context), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()), \
&prepare.rhs_padded_strides, \
&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.rhs_tensor->Data<T>()), \
&prepare.fdm_output_strides, \
prepare.fdm_H, \
prepare.fdm_C, \
&prepare.fdm_output_strides, \
prepare.fdm_H, \
prepare.fdm_C, \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()), \
prepare.output_tensor->Shape().Size()); \
return Status::OK(); \
prepare.output_tensor->Shape().Size()); \
return Status::OK(); \
}
#define CONTRIB_BINARY_OP_TYPED(name, ver, T) \

View file

@ -20,7 +20,7 @@ namespace cuda {
#define CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
cudaStream_t stream, \
cudaStream_t stream, \
int32_t output_rank_or_simple_broadcast, \
const TArray<int64_t>* lhs_padded_strides, \
const T* lhs_data, \

View file

@ -12,7 +12,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
//key
// key
struct FFTState {
int64_t signal_ndim;
int64_t signal_dims[5];
@ -22,7 +22,7 @@ struct FFTState {
cudaDataType exec_type;
};
//value
// value
struct CufftPlanInfo {
cufftHandle plan;
size_t ws_size_t;

Some files were not shown because too many files have changed in this diff.