Run clang-format in CI (#15524)
### Description
Run clang-format in CI. Formatted all C/C++ and Objective-C/C++ files. Excluded
```
'onnxruntime/core/mlas/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**',
```
because they contain assembly code or are data heavy.

### Motivation and Context
Coding style consistency.
This commit is contained in:
Parent: 2700d01642
Commit: cf19c3697d
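For anyone reproducing the formatting locally, below is a minimal sketch of the workflow based on the comments and the `CLANGFORMAT` linter entry in the `.lintrunner.toml` diff that follows. The `pip install` step and the `-a` flag are assumptions about a standard lintrunner setup, not something this PR adds.

```bash
# Install the lint driver and its adapters (assumed to be the standard PyPI packages).
pip install lintrunner lintrunner-adapters

# One-time setup: runs each linter's init_command, e.g. installs clang-format==16.0.1.
lintrunner init

# Lint only local changes relative to origin/main (merge_base_with in .lintrunner.toml).
lintrunner -m main

# Apply the fixes suggested by formatters such as clang-format.
lintrunner -a
```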
@@ -19,4 +19,3 @@ DerivePointerAlignment: false
# NamespaceIndentation: All
...
@@ -14,7 +14,7 @@
# To lint local changes:
#
# ```bash
# lintrunner -m main
# lintrunner
# ```
#
# To lint all files:

@@ -33,6 +33,8 @@
# To update an existing linting rule or create a new one, modify this file or create a
# new adapter following examples in https://github.com/justinchuby/lintrunner-adapters.
merge_base_with = 'origin/main'
[[linter]]
code = 'RUFF'
include_patterns = [

@@ -168,3 +170,44 @@ command = [
'@{{PATHSFILE}}'
]
is_formatter = true
[[linter]]
code = 'CLANGFORMAT'
include_patterns = [
'**/*.h',
'**/*.cc',
'**/*.hpp',
'**/*.cpp',
'**/*.m',
'**/*.mm',
]
exclude_patterns = [
'java/**', # FIXME: Enable clang-format for java
'js/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
'onnxruntime/core/flatbuffers/schema/ort.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'winml/**', # FIXME: Enable clang-format for winml
]
command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'clangformat_linter',
'--binary=clang-format',
'--fallback',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python',
'-m',
'lintrunner_adapters',
'run',
'pip_init',
'--dry-run={{DRYRUN}}',
'clang-format==16.0.1',
]
is_formatter = true
@@ -6,6 +6,6 @@
// Extending the std namespace is undefined behavior
// NOLINTNEXTLINE
namespace std {
inline char *getenv(const char*) { return nullptr; }
}
inline char* getenv(const char*) { return nullptr; }
}  // namespace std
#endif
@@ -1,11 +1,11 @@
#include <stdio.h>
#include "winrt/microsoft.ai.machinelearning.h"
#include "winrt/windows.storage.h"
#include "winrt/windows.foundation.h"
#include "winrt/windows.foundation.collections.h"
#include "winrt/Windows.Graphics.h"
#include "winrt/Windows.Graphics.Imaging.h"
#include "winrt/Windows.Graphics.h"
#include "winrt/Windows.Media.h"
#include "winrt/microsoft.ai.machinelearning.h"
#include "winrt/windows.foundation.collections.h"
#include "winrt/windows.foundation.h"
#include "winrt/windows.storage.h"
#include <stdio.h>
#include <windows.h>
EXTERN_C IMAGE_DOS_HEADER __ImageBase;

@@ -15,42 +15,44 @@ using namespace winrt::Windows::Storage;
using namespace winrt::Windows::Media;
using namespace winrt::Windows::Graphics::Imaging;
std::wstring GetModulePath() {
std::wstring val;
wchar_t modulePath[MAX_PATH] = {0};
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
wchar_t drive[_MAX_DRIVE];
wchar_t dir[_MAX_DIR];
wchar_t filename[_MAX_FNAME];
wchar_t ext[_MAX_EXT];
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
std::wstring GetModulePath()
{
std::wstring val;
wchar_t modulePath[MAX_PATH] = {0};
GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath));
wchar_t drive[_MAX_DRIVE];
wchar_t dir[_MAX_DIR];
wchar_t filename[_MAX_FNAME];
wchar_t ext[_MAX_EXT];
_wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT);
val = drive;
val += dir;
val = drive;
val += dir;
return val;
return val;
}
int main() {
printf("Load squeezenet.onnx.\n");
auto model = LearningModel::LoadFromFilePath(L"squeezenet.onnx");
printf("Load kitten_224.png as StorageFile.\n");
auto name = GetModulePath() + L"kitten_224.png";
auto image = StorageFile::GetFileFromPathAsync(name).get();
printf("Load StorageFile into Stream.\n");
auto stream = image.OpenAsync(FileAccessMode::Read).get();
printf("Create SoftwareBitmap from decoded Stream.\n");
auto softwareBitmap = BitmapDecoder::CreateAsync(stream).get().GetSoftwareBitmapAsync().get();
printf("Create VideoFrame.\n");
auto frame = VideoFrame::CreateWithSoftwareBitmap(softwareBitmap);
printf("Create LearningModelSession.\n");
auto session = LearningModelSession(model);
printf("Create LearningModelBinding.\n");
auto binding = LearningModelBinding(session);
printf("Bind data_0.\n");
binding.Bind(L"data_0", frame);
printf("Evaluate.\n");
auto results = session.Evaluate(binding, L"");
printf("Success!\n");
return 0;
int main()
{
printf("Load squeezenet.onnx.\n");
auto model = LearningModel::LoadFromFilePath(L"squeezenet.onnx");
printf("Load kitten_224.png as StorageFile.\n");
auto name = GetModulePath() + L"kitten_224.png";
auto image = StorageFile::GetFileFromPathAsync(name).get();
printf("Load StorageFile into Stream.\n");
auto stream = image.OpenAsync(FileAccessMode::Read).get();
printf("Create SoftwareBitmap from decoded Stream.\n");
auto softwareBitmap = BitmapDecoder::CreateAsync(stream).get().GetSoftwareBitmapAsync().get();
printf("Create VideoFrame.\n");
auto frame = VideoFrame::CreateWithSoftwareBitmap(softwareBitmap);
printf("Create LearningModelSession.\n");
auto session = LearningModelSession(model);
printf("Create LearningModelBinding.\n");
auto binding = LearningModelBinding(session);
printf("Bind data_0.\n");
binding.Bind(L"data_0", frame);
printf("Evaluate.\n");
auto results = session.Evaluate(binding, L"");
printf("Success!\n");
return 0;
}
@@ -47,10 +47,10 @@ struct CodeLocation {
out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function;
return out.str();
}
//utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
// utf-8. Because on Windows we compile our code with "/utf-8". And we assume the other platforms only use utf-8.
const std::string file_and_path;
const int line_num;
//utf-8
// utf-8
const std::string function;
const std::vector<std::string> stacktrace;
};
@@ -37,7 +37,6 @@
#include "core/common/make_string.h"
#include "core/common/status.h"
namespace onnxruntime {
using TimePoint = std::chrono::high_resolution_clock::time_point;
@@ -158,9 +158,9 @@ class NodeHashSet : public std::unordered_set<T,
template <typename Key, typename Value, typename Allocator>
class NodeHashMap : public std::unordered_map<Key, Value,
std::hash<Key>,
std::equal_to<Key>,
Allocator> {
std::hash<Key>,
std::equal_to<Key>,
Allocator> {
using Base = std::unordered_map<Key, Value,
std::hash<Key>,
std::equal_to<Key>,
@@ -255,15 +255,17 @@
#else
// Disabled in Release builds.
#define VLOGS(logger, level) \
if constexpr (true) {} else LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
if constexpr (true) { \
} else \
LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level)
#define VLOGS_USER(logger, level) \
if constexpr (true) {} else LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
if constexpr (true) { \
} else \
LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level)
#define VLOGF(logger, level, format_str, ...)
#define VLOGF_USER(logger, level, format_str, ...)
#endif
// Default logger variants
#define VLOGS_DEFAULT(level) \
VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level)
@@ -75,7 +75,7 @@ struct EventRecord {
using Events = std::vector<EventRecord>;
//Execution Provider Profiler
// Execution Provider Profiler
class EpProfiler {
public:
virtual ~EpProfiler() = default;
@@ -121,10 +121,10 @@ class [[nodiscard]] Status {
Status(StatusCategory category, int code);
GSL_SUPPRESS(r.11)
GSL_SUPPRESS(r .11)
Status(const Status& other)
: state_((other.state_ == nullptr) ? nullptr : new State(*other.state_)) {}
GSL_SUPPRESS(r.11)
GSL_SUPPRESS(r .11)
Status& operator=(const Status& other) {
if (state_ != other.state_) {
if (other.state_ == nullptr) {
@@ -22,12 +22,11 @@ namespace onnxruntime {
class ORTInvoker {
public:
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
const logging::Logger& logger,
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) :
execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
if (!execution_provider_) {
ORT_THROW("Execution provider is nullptr");
ORT_THROW("Execution provider is nullptr");
}
}

@@ -36,7 +35,7 @@ class ORTInvoker {
}
common::Status Invoke(const std::string& op_name,
//optional inputs / outputs?
// optional inputs / outputs?
const std::vector<OrtValue>& inputs,
std::vector<OrtValue>& outputs,
const NodeAttributes* attributes,
@@ -386,8 +386,8 @@ void AssignOpaqueDomainName(const char* domain, const char* name,
}  // namespace data_types_internal
//The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
//However, we do not allocate this type on heap.
// The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor."
// However, we do not allocate this type on heap.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 26436)

@@ -614,7 +614,7 @@ class OptionalType :
#if !defined(DISABLE_OPTIONAL_TYPE)
OptionalType()
#else
OptionalType() : DisabledTypeBase { DataTypeImpl::GeneralType::kOptional, 0 }
OptionalType() : DisabledTypeBase{DataTypeImpl::GeneralType::kOptional, 0}
#endif
{
using namespace data_types_internal;
@@ -29,17 +29,17 @@
namespace onnxruntime {
namespace utils {
// The following primitives are strongly recommended for switching on tensor input datatypes for
// kernel implementations.
//
// 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
// DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
// 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
// if you have a standalone MLDatatType with a sequence of if/else statements.
// 3) For something in between, we suggest to use CallDispatcher pattern.
//
// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
// The following primitives are strongly recommended for switching on tensor input datatypes for
// kernel implementations.
//
// 1) If you need to handle all of the primitive tensor contained datatypes, the best choice would be macros
// DispatchOnTensorType or DispatchOnTensorTypeWithReturn. Use inline wrappers so your function can be invoked as function<T>().
// 2) if you have a few types, use Tensor.IsDataType<T>()/IsDataTypeString() or use utils::IsPrimitiveDataType<T>()
// if you have a standalone MLDatatType with a sequence of if/else statements.
// 3) For something in between, we suggest to use CallDispatcher pattern.
//
// Invoking DataTypeImpl::GetType<T>() for switching on input types is discouraged and should be avoided.
// Every primitive type carries with it an integer constant that can be used for quick switching on types.
#define DispatchOnTensorType(tensor_type, function, ...) \
switch (tensor_type->AsPrimitiveDataType()->GetDataType()) { \

@@ -498,11 +498,10 @@ class ContainerChecker {
ORT_ENFORCE(++index < c.size(), "Sequence is missing type entry for its element");
constexpr int32_t prim_type = ToTensorProtoElementType<T>();
// Check if this is a primitive type and it matches
if constexpr(prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
if constexpr (prim_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
c[index].IsPrimType(prim_type);
}
else {
} else {
// T is not primitive, check next entry for non-primitive proto
return IsContainerOfType<T>::check(c, index);
}

@@ -528,11 +527,11 @@ class ContainerChecker {
}
ORT_ENFORCE(++index < c.size(), "Map is missing type entry for its value");
constexpr int32_t val_type = ToTensorProtoElementType<V>();
if constexpr(val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
if constexpr (val_type != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
return c[index].IsType(data_types_internal::ContainerType::kTensor) &&
c[index].IsPrimType(val_type);
}
else return IsContainerOfType<V>::check(c, index);
} else
return IsContainerOfType<V>::check(c, index);
}
};
@@ -69,10 +69,9 @@ struct BFloat16 {
val = static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}
#else
if constexpr(endian::native == endian::little) {
if constexpr (endian::native == endian::little) {
std::memcpy(&val, reinterpret_cast<char*>(&v) + sizeof(uint16_t), sizeof(uint16_t));
}
else {
} else {
std::memcpy(&val, &v, sizeof(uint16_t));
}
#endif

@@ -93,11 +92,10 @@ struct BFloat16 {
float result;
char* const first = reinterpret_cast<char*>(&result);
char* const second = first + sizeof(uint16_t);
if constexpr(endian::native == endian::little) {
if constexpr (endian::native == endian::little) {
std::memset(first, 0, sizeof(uint16_t));
std::memcpy(second, &val, sizeof(uint16_t));
}
else {
} else {
std::memcpy(first, &val, sizeof(uint16_t));
std::memset(second, 0, sizeof(uint16_t));
}

@@ -117,7 +115,6 @@ inline ORT_HOST_DEVICE bool operator==(const BFloat16& left, const BFloat16& rig
inline ORT_HOST_DEVICE bool operator!=(const BFloat16& left, const BFloat16& right) { return left.val != right.val; }
inline ORT_HOST_DEVICE bool operator<(const BFloat16& left, const BFloat16& right) { return left.val < right.val; }
// User defined suffixes to make it easier to declare
// initializers with MLFloat16 and BFloat16 from unsigned short
// E.g 10_f16 or 10_b16
@@ -10,7 +10,7 @@ using DestroyFunc = void (*)(void*, void*);
using AllocatorHandle = void*;
typedef struct {
//right now we only include allocation for host memory
// right now we only include allocation for host memory
AllocateFunc allocate_func;
DestroyFunc release_func;
AllocatorHandle allocator_handle;
@@ -78,10 +78,10 @@ inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { ret
std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info);
namespace std {
template<>
template <>
struct hash<OrtMemoryInfo> {
size_t operator()(const OrtMemoryInfo& i) const {
return i.Hash();
}
};
}
}  // namespace std
@@ -18,8 +18,8 @@ class DataTransferManager;
/**
* @brief This is a Sparse Format enumeration
*
*
*
*
*/
enum class SparseFormat : uint32_t {
kUndefined = 0x0U, // For completeness

@@ -31,7 +31,7 @@ enum class SparseFormat : uint32_t {
std::ostream& operator<<(std::ostream&, SparseFormat);
/**
* @brief This class implements SparseTensor.
* @brief This class implements SparseTensor.
* This class holds sparse non-zero data (values) and sparse format
* specific indices. There are two main uses for the class (similar to that of Tensor)
* - one is to re-present model sparse inputs. Such inputs typically reside

@@ -43,7 +43,7 @@ std::ostream& operator<<(std::ostream&, SparseFormat);
* be used to supply pointers to format specific indices. These buffers are used as is
* and will not be modified or deallocated by the instance. However, the lifespan of the buffers
* must eclipse the lifespan of the SparseTensor instance.
*
*
* - Represent sparse data that is a result of format conversion or a computation result. Use second constructor
* to supply a desired allocator. Use Make*() format specific interfaces to supply values and format
* specific indices. The specified data will be copied into an internally allocated buffer.

@@ -446,7 +446,6 @@ class SparseTensor final {
const TensorShape& values_shape, const void* values_data,
const TensorShape& indices_shape, const int32_t* indices_data);
/// <summary>
/// The method allocates a single contiguous buffer and creates instances of std::strings in it, with
/// copies of the supplied zero-terminated strings followed by COO indices.
@@ -167,5 +167,4 @@ class IStreamCommandHandleRegistry {
virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
};
}  // namespace onnxruntime
@@ -3,7 +3,7 @@
#pragma once
#include <stddef.h> // needed for size_t on some platforms
#include <stddef.h> // needed for size_t on some platforms
namespace onnxruntime {
@@ -13,8 +13,8 @@ class Node;
namespace onnxruntime {
/**
@class Function
/**
@class Function
Class representing a Function.
*/
class Function {
@@ -408,9 +408,9 @@ class Node {
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
/**
* Clears removable attributes. These are no longer needed after the initialization
* of the session. The function returns the number of removed attributes.
*/
* Clears removable attributes. These are no longer needed after the initialization
* of the session. The function returns the number of removed attributes.
*/
int PruneRemovableAttributes(gsl::span<const std::string> removable_attributes);
#if !defined(ORT_MINIMAL_BUILD)

@@ -659,7 +659,7 @@ class Node {
std::vector<std::unique_ptr<Graph>> subgraphs_;
// Can be saved? The node cannot be saved anymore if removable attributes have been cleared.
bool can_be_saved_;
bool can_be_saved_;
};
/**
@@ -21,7 +21,7 @@ using OpName_Domain_Version_Schema_Map = std::unordered_map<
std::string,
std::unordered_map<std::string, std::map<ONNX_NAMESPACE::OperatorSetVersion, ONNX_NAMESPACE::OpSchema>>>;
/**
/**
@struct SchemaRegistryVersion
onnxruntime schema registry is a supplement to the built-in ONNX schema.
Every schema registry represent a collection of schema deltas from baseline_opset_version to opset_version

@@ -60,7 +60,7 @@ class IOnnxRuntimeOpSchemaCollection : public ONNX_NAMESPACE::ISchemaRegistry {
int* earliest_opset_where_unchanged) const = 0;
};
/**
/**
@class OnnxRuntimeOpSchemaRegistry
OnnxRuntimeOpSchemaRegistry is used to provide supplement for built-in ONNX schemas.

@@ -111,10 +111,10 @@ class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection {
};
/**
@class SchemaRegistryManager
@class SchemaRegistryManager
SchemaRegistryManager provides a view based on built-in ONNX schema and a list of
OnnxRuntimeOpSchemaRegistry as supplement.
SchemaRegistryManager provides a view based on built-in ONNX schema and a list of
OnnxRuntimeOpSchemaRegistry as supplement.
The user needs to make sure the customized schema registry is valid, otherwise the behavior is undefined.

@@ -137,7 +137,7 @@ class SchemaRegistryManager : public onnxruntime::IOnnxRuntimeOpSchemaCollection
/** Gets the last released opset versions.
@param is_onnx_only If true, return ONNX schemas only. If false, return the schemas for all domains.
*/
DomainToVersionMap GetLastReleasedOpsetVersions(bool is_onnx_only) const ;
DomainToVersionMap GetLastReleasedOpsetVersions(bool is_onnx_only) const;
/**
Gets the OpSchema and its history.
Searches custom schema registries starting with the last one added. \
@@ -19,7 +19,7 @@ The interface for in-place transformation of a Graph.
class GraphTransformer {
public:
GraphTransformer(const std::string& name,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: name_(name), compatible_provider_types_(compatible_execution_providers) {
}
@@ -9,23 +9,23 @@
namespace onnxruntime {
/**
@class RewriteRule
@class RewriteRule
The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
computation graph. It can be used to represent, for example, the elimination of operators that serve as
no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
The base class for a rewrite rule. A rewrite rule represents a semantics-preserving transformation of a
computation graph. It can be used to represent, for example, the elimination of operators that serve as
no-ops (e.g., dropout during inference), as well as inlining of "function" definitions or the dual operation
of replacing a complex expression by an equivalent function-call). Unlike the more general GraphTransformer,
a rewrite rule is a more local transformation that is triggered on a particular node of the graph.
Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
- SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
Each rule has a set of conditions and a body. The conditions have to be satisfied for the body of the rule
to be triggered. Therefore, when creating a new rewrite rule, two main functions have to be implemented:
- SatisfyCondition defines the condition checks. It is advisable to add the more selective checks first,
because those will lead to discarding fast rules that cannot be applied on a node.
- Apply is the actual body of the rule that will be executed if SatisfyCondition returns true for a particular
node. Note that additional, more complex checks can be included in the Apply if putting them in the
SatisfyCondition would lead to duplicate work (e.g., when we make a check on a Node attribute but we need
that attribute to execute the rule too).
In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
In general, simple fast checks are a better fit for SatisfyCondition, whereas more complex ones can be added
in the Apply.
In order to avoid evaluating the SatisfyCondition for each rule and each node of the graph, each rewrite rule

@@ -75,13 +75,13 @@ class RewriteRule {
const std::string name_;
/** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
/** Checks if the Node of the given Graph satisfies the conditions of this rule. The body of the rule will be
evaluated if this condition function returns true. This can include a more complex pattern matching (conditions
on the ascending or descending nodes of the node for which this rule was triggered) or some other properties
of the nodes. */
virtual bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const = 0;
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
/** This is the actual body of the rule that performs the graph transformation. The transformation happens in-place.
The return-value of node may be different from the input-value due to rewriting.
The value of "rule_effect" indicates whether and how the graph was modified by the rule. */
virtual common::Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const = 0;
@@ -13,12 +13,12 @@ namespace onnxruntime {
/**
@class RuleBasedGraphTransformer
Rule-based graph transformer that provides an API to register rewrite rules
Rule-based graph transformer that provides an API to register rewrite rules
and an API to apply all applicable rules to a Graph.
Represents an IGraphTransformer determined by a set of rewrite rules.
The transformer will apply all the rewrite rules iteratively as determined by the underlying rewriting strategy.
Several rewriting-strategies are possible when traversing the graph and applying rewrite rules,
Several rewriting-strategies are possible when traversing the graph and applying rewrite rules,
each with different trade offs. At the moment, we define one that performs top-down traversal of nodes.
@TODO: Is a bottom-up traversal more efficient?

@@ -36,7 +36,7 @@ class RuleBasedGraphTransformer : public GraphTransformer {
/** Registers a rewrite rule in this transformer. */
Status Register(std::unique_ptr<RewriteRule> rule);
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
/** Gets the list of registered rewrite rules that will be triggered on nodes with the given op type
by this rule-based transformer.
@returns a pointer to the vector containing all the registered rewrite rules. */
const InlinedVector<std::reference_wrapper<const RewriteRule>>* GetRewriteRulesForOpType(const std::string& op_type) const {
@@ -240,23 +240,23 @@ class ThreadPoolProfiler {
~ThreadPoolProfiler();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler);
using Clock = std::chrono::high_resolution_clock;
void Start(); //called by executor to start profiling
std::string Stop(); //called by executor to stop profiling and return collected numbers
void LogStart(); //called in main thread to record the starting time point
void LogEnd(ThreadPoolEvent); //called in main thread to calculate and save the time elapsed from last start point
void Start(); // called by executor to start profiling
std::string Stop(); // called by executor to stop profiling and return collected numbers
void LogStart(); // called in main thread to record the starting time point
void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point
void LogEndAndStart(ThreadPoolEvent);
void LogStartAndCoreAndBlock(std::ptrdiff_t block_size);
void LogCoreAndBlock(std::ptrdiff_t block_size); //called in main thread to log core and block size for task breakdown
void LogThreadId(int thread_idx); //called in child thread to log its id
void LogRun(int thread_idx); //called in child thread to log num of run
std::string DumpChildThreadStat(); //return all child statitics collected so far
void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown
void LogThreadId(int thread_idx); // called in child thread to log its id
void LogRun(int thread_idx); // called in child thread to log num of run
std::string DumpChildThreadStat(); // return all child statitics collected so far
private:
static const char* GetEventName(ThreadPoolEvent);
struct MainThreadStat {
uint64_t events_[MAX_EVENT] = {};
int32_t core_ = -1;
std::vector<std::ptrdiff_t> blocks_; //block size determined by cost model
std::vector<std::ptrdiff_t> blocks_; // block size determined by cost model
std::vector<onnxruntime::TimePoint> points_;
void LogCore();
void LogBlockSize(std::ptrdiff_t block_size);

@@ -266,7 +266,7 @@ class ThreadPoolProfiler {
std::string Reset();
};
bool enabled_ = false;
MainThreadStat& GetMainThreadStat(); //return thread local stat
MainThreadStat& GetMainThreadStat(); // return thread local stat
int num_threads_;
#ifdef _MSC_VER
#pragma warning(push)

@@ -277,7 +277,7 @@ class ThreadPoolProfiler {
std::thread::id thread_id_;
uint64_t num_run_ = 0;
onnxruntime::TimePoint last_logged_point_ = Clock::now();
int32_t core_ = -1; //core that the child thread is running on
int32_t core_ = -1; // core that the child thread is running on
};
#ifdef _MSC_VER
#pragma warning(pop)

@@ -770,7 +770,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
for (auto i = 0u; i < num_threads_; i++) {
worker_data_[i].thread.reset(env_.CreateThread(name, i, WorkerLoop, this, thread_options));
}
} ORT_CATCH(...) {
}
ORT_CATCH(...) {
ORT_HANDLE_EXCEPTION([&]() {
SignalAllAndWait();
throw;

@@ -1336,7 +1337,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
#pragma warning(push)
// C4324: structure was padded due to alignment specifier
#pragma warning(disable : 4324)
#endif // _MSC_VER
#endif // _MSC_VER
struct ORT_ALIGN_TO_AVOID_FALSE_SHARING PerThread {
constexpr PerThread() : pool(nullptr) {

@@ -1358,8 +1359,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
#ifdef _MSC_VER
#pragma warning(pop)
#endif // _MSC_VER
#endif // _MSC_VER
struct WorkerData {
constexpr WorkerData() : thread(), queue() {
@@ -127,7 +127,6 @@ struct TensorOpCost {
double compute_cycles;
};
namespace concurrency {
template <typename Environment>

@@ -197,11 +196,11 @@ class ThreadPool {
// parallel loops.
class ParallelSection {
public:
explicit ParallelSection(ThreadPool *tp);
public:
explicit ParallelSection(ThreadPool* tp);
~ParallelSection();
private:
private:
friend class ThreadPool;
// Owning reference for the underlying ThreadPoolParallelSection

@@ -210,7 +209,7 @@ class ThreadPool {
// ThreadPoolParallelSection does not need to be available at this
// point to avoid a dependence on the Eigen headers.
ThreadPoolParallelSection* ps_{nullptr};
ThreadPool *tp_;
ThreadPool* tp_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ParallelSection);
};
@@ -25,53 +25,53 @@ Environment:
// in this file configures the provider as a normal (non-telemetry) provider.
#ifndef TraceLoggingOptionMicrosoftTelemetry
#define TraceLoggingOptionMicrosoftTelemetry() \
TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000)
// Empty definition for TraceLoggingOptionMicrosoftTelemetry
TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000)
// Empty definition for TraceLoggingOptionMicrosoftTelemetry
#endif
// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition
// in this file configures the provider as a normal (non-telemetry) provider.
#define TraceLoggingOptionWindowsCoreTelemetry() \
// Empty definition for TraceLoggingOptionWindowsCoreTelemetry
// Empty definition for TraceLoggingOptionWindowsCoreTelemetry
// Event privacy tags. Use the PDT macro values for the tag parameter, e.g.:
// TraceLoggingWrite(...,
// TelemetryPrivacyDataTag(PDT_BrowsingHistory | PDT_ProductAndServiceUsage),
// ...);
#define TelemetryPrivacyDataTag(tag) TraceLoggingUInt64((tag), "PartA_PrivTags")
#define PDT_BrowsingHistory 0x0000000000000002u
#define PDT_BrowsingHistory 0x0000000000000002u
#define PDT_DeviceConnectivityAndConfiguration 0x0000000000000800u
#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u
#define PDT_ProductAndServicePerformance 0x0000000001000000u
#define PDT_ProductAndServiceUsage 0x0000000002000000u
#define PDT_SoftwareSetupAndInventory 0x0000000080000000u
#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u
#define PDT_ProductAndServicePerformance 0x0000000001000000u
#define PDT_ProductAndServiceUsage 0x0000000002000000u
#define PDT_SoftwareSetupAndInventory 0x0000000080000000u
// Event categories specified via keywords, e.g.:
// TraceLoggingWrite(...,
// TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// ...);
#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47
#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46
#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45
#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment)
#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47
#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46
#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45
#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment)
// Event categories specified via event tags, e.g.:
// TraceLoggingWrite(...,
// TraceLoggingEventTag(MICROSOFT_EVENTTAG_REALTIME_LATENCY),
// ...);
#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000
#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000
#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000
#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000
#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000
#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000
#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000
#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000
#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000
#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000
#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000
#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000
#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000
#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000
#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000
#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000
#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000
#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000
#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000
#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000
#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000
#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000
#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000
#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000
#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000
#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000
// Field categories specified via field tags, e.g.:
// TraceLoggingWrite(...,
@@ -4,7 +4,7 @@
#pragma once
#pragma warning(push)
#pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#pragma warning(disable : 4201) // nonstandard extension used: nameless struct/union
#ifdef _GAMING_XBOX_SCARLETT
#include <d3d12_xs.h>
#elif defined(_GAMING_XBOX_XBOXONE)

@@ -15,10 +15,10 @@
#pragma warning(pop)
#ifdef __cplusplus
#include <DirectML.h>
#include <DirectML.h>
#else
struct IDMLDevice;
typedef struct IDMLDevice IDMLDevice;
struct IDMLDevice;
typedef struct IDMLDevice IDMLDevice;
#endif
// Windows pollutes the macro space, causing a build break in constants.h.

@@ -36,8 +36,8 @@ extern "C" {
* The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead.
*
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);

@@ -49,8 +49,8 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
*
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue

@@ -58,47 +58,46 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOpti
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
struct OrtDmlApi;
typedef struct OrtDmlApi OrtDmlApi;
struct OrtDmlApi {
/**
* Creates a DirectML Execution Provider which executes on the hardware adapter with the given device_id, also known as
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* the adapter index. The device ID corresponds to the enumeration order of hardware adapters as given by
* IDXGIFactory::EnumAdapters. A device_id of 0 always corresponds to the default adapter, which is typically the
* primary display GPU installed on the system. A negative device_id is invalid.
*/
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id);
/**
* Creates a DirectML Execution Provider using the given DirectML device, and which executes work on the supplied D3D12
* command queue. The DirectML device and D3D12 command queue must have the same parent ID3D12Device, or an error will
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* be returned. The D3D12 command queue must be of type DIRECT or COMPUTE (see D3D12_COMMAND_LIST_TYPE). If this
* function succeeds, the inference session maintains a strong reference on both the dml_device and the command_queue
* objects.
* See also: DMLCreateDevice
* See also: ID3D12Device::CreateCommandQueue
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML1, _In_ OrtSessionOptions* options,
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
_In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue);
/**
* CreateGPUAllocationFromD3DResource
* This API creates a DML EP resource based on a user-specified D3D12 resource.
*/
* CreateGPUAllocationFromD3DResource
* This API creates a DML EP resource based on a user-specified D3D12 resource.
*/
ORT_API2_STATUS(CreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* d3d_resource, _Out_ void** dml_resource);
/**
* FreeGPUAllocation
* This API frees the DML EP resource created by CreateGPUAllocationFromD3DResource.
*/
* FreeGPUAllocation
* This API frees the DML EP resource created by CreateGPUAllocationFromD3DResource.
*/
ORT_API2_STATUS(FreeGPUAllocation, _In_ void* dml_resource);
/**
* GetD3D12ResourceFromAllocation
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
* GetD3D12ResourceFromAllocation
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);
};
@@ -63,7 +63,7 @@ struct Value : Ort::Value {
static Ort::Value CreateTensor(const std::vector<int64_t>& shape, ONNXTensorElementDataType type);
};
}
}
}  // namespace Experimental
}  // namespace Ort
#include "experimental_onnxruntime_cxx_inline.h"
@@ -104,5 +104,5 @@ inline Ort::Value Value::CreateTensor(const std::vector<int64_t>& shape, ONNXTen
return Ort::Value::CreateTensor(allocator, shape.data(), shape.size(), type);
}
}
}
}  // namespace Experimental
}  // namespace Ort
@@ -452,7 +452,6 @@ typedef struct OrtCUDAProviderOptions {
*/
int tunable_op_tuning_enable;
} OrtCUDAProviderOptions;
/** \brief ROCM Provider Options
@@ -356,10 +356,10 @@ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
* constructors to construct an instance of a Status object from exceptions.
*/
struct Status : detail::Base<OrtStatus> {
explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
explicit Status(const std::exception&) noexcept; ///< Creates status instance out of exception
Status(const char* message, OrtErrorCode code) noexcept; ///< Creates status instance out of null-terminated string message.
std::string GetErrorMessage() const;
OrtErrorCode GetErrorCode() const;

@@ -473,7 +473,6 @@ struct RunOptions : detail::Base<OrtRunOptions> {
RunOptions& UnsetTerminate();
};
namespace detail {
// Utility function that returns a SessionOption config entry key for a specific custom operator.
// Ex: custom_op.[custom_op_name].[config]

@@ -514,7 +513,7 @@ struct CustomOpConfigs {
* {"my_op.key", "value"}.
*
* \return An unordered map of flattened configurations.
*/
*/
const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
private:

@@ -574,7 +573,7 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
@ -6,7 +6,7 @@
|
|||
# Requires a .clang-format config file to be in the current directory or a parent directory from where the script is run.
|
||||
# Expected usage is to run it from its current location so that source in 'core' and 'test' is updated.
|
||||
|
||||
gci -Recurse -Include *.h, *.cc | foreach {
|
||||
Write-Host "Updating " $_.FullName
|
||||
clang-format -i $_
|
||||
}
|
||||
gci -Recurse -Include *.h, *.cc | foreach {
|
||||
Write-Host "Updating " $_.FullName
|
||||
clang-format -i $_
|
||||
}
|
||||
|
|
|
@@ -14,14 +14,14 @@ class IAttentionMechanism {
virtual ~IAttentionMechanism() = default;
virtual void PrepareMemory(
const gsl::span<const T>& memory,
const gsl::span<const int>& memory_sequence_lengths) = 0;
const gsl::span<const T>& memory,
const gsl::span<const int>& memory_sequence_lengths) = 0;
virtual void Compute(
const gsl::span<const T>& query,
const gsl::span<const T>& prev_alignment,
const gsl::span<T>& output,
const gsl::span<T>& alignment) const = 0;
const gsl::span<const T>& query,
const gsl::span<const T>& prev_alignment,
const gsl::span<T>& output,
const gsl::span<T>& alignment) const = 0;
virtual const gsl::span<const T> Values() const = 0;

@@ -32,5 +32,5 @@ class IAttentionMechanism {
virtual bool NeedPrevAlignment() const = 0;
};
}
}
}  // namespace contrib
}  // namespace onnxruntime
@@ -8,7 +8,7 @@
#include <memory>
using onnxruntime::rnn::detail::Allocate;
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(disable : 26451)
#endif

@@ -55,9 +55,9 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
}
if (has_attn_layer_) {
//concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
// concat([p_cell_output, context]) * stack([attn_layer_cell_weights_, attn_layer_attn_weights_]) =
// p_cell_output * attn_layer_cell_weights_ + context * attn_layer_attn_weights_
// The first part is calulated above. Here just add the later.
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,

@@ -76,7 +76,7 @@ void AttentionWrapper<T>::SetWeights(const gsl::span<const T>& wrapper_weights)
has_attn_layer_ = !wrapper_weights.empty();
if (has_attn_layer_) {
//cell weight size and attn weight size in the attn layer
// cell weight size and attn weight size in the attn layer
size_t cws = inner_cell_hidden_size_ * attn_layer_depth_;
size_t aws = attn_context_depth_ * attn_layer_depth_;
attn_layer_cell_weights_ = wrapper_weights.subspan(0, cws);
@@ -108,17 +108,17 @@ static void SoftmaxInplace(const gsl::span<T>& alignments) {
}
/**
* Args:
* queries: Tensor, shape `[batch_size_, query_depth_]` to compare to keys.
* keys_: Processed memory, shape `[batch_size_, max_memory_step_, attn_depth_]`.
*/
* Args:
* queries: Tensor, shape `[batch_size_, query_depth_]` to compare to keys.
* keys_: Processed memory, shape `[batch_size_, max_memory_step_, attn_depth_]`.
*/
template <typename T>
void BahdanauAttention<T>::Compute(
const gsl::span<const T>& queries,
const gsl::span<const T>&, // Not used by bahdanau attention
const gsl::span<T>& output,
const gsl::span<T>& aligns) const {
//process query in dense query layer without bias
// process query in dense query layer without bias
math::GemmEx<T>(CblasNoTrans, CblasNoTrans,
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
@@ -11,7 +11,7 @@
#include "core/common/narrow.h"
#include "core/platform/threadpool.h"
#include "core/framework/allocator.h"
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)

@@ -26,7 +26,7 @@ extern template class BahdanauAttention<float>;
/* AttnLSTM operator */
ONNX_OPERATOR_KERNEL_EX(
AttnLSTM, //name
AttnLSTM, // name
kMSDomain,
1,
kCpuExecutionProvider,
@@ -14,7 +14,7 @@
#ifdef _MSC_VER
#pragma warning(pop)
#endif
//TODO: fix the warnings
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)

@@ -280,7 +280,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
// after the first step this will switch to the output from the previous step
span_T_const_iter previous_state = batched_hidden_state_one_step.begin();
//run through steps sequentially
// run through steps sequentially
for (int step = 0; step < max_sequence_length; step++) {
const std::string seqno_str = " [seqno=" + std::to_string(step) + "]";

@@ -335,7 +335,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
}
if (output_sequence) {
//set to 0 if step >= sequence_length
// set to 0 if step >= sequence_length
for (int lrow = 0; lrow < batch_size_; lrow++) {
if (step >= min_sequence_length && step >= sequence_lengths[lrow]) {
auto dst = outputs.data() + step * output_step_length + lrow * hidden_size_;

@@ -409,7 +409,7 @@ void UniDirectionalAttnLstm<T>::GateComputations(span_T_iter& out, span_T_iter&
const float* pBi = use_bias_ ? SafeRawConstPointer<T>(bias_WRi_, 0, hidden_size_) : nullptr;
clip_with_bias_ptr_(clip_, pBi, pi, hidden_size_); // post: pi has input to f() to calculate i
activation_f_.func(pi, hidden_size_, activation_f_.alpha, activation_f_.beta);
//DumpMatrix("i" + row_str, pi, 1, hidden_size_);
// DumpMatrix("i" + row_str, pi, 1, hidden_size_);
// Forget Gate
if (input_forget_) {
@@ -279,9 +279,8 @@ Status AttentionBase::CheckMask(const Tensor* mask_index,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
}
mask_type = (mask_dims[0] == batch_size ?
AttentionMaskType::MASK_1D_KEY_SEQ_LEN :
mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START
: AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
} else if (mask_dims.size() == 2) {
if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
@@ -42,16 +42,16 @@ enum AttentionKernelType {
struct AttentionParameters {
int batch_size;
int sequence_length;
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int original_past_sequence_length; // original sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int kv_sequence_length; // input sequence length of K or V
int past_sequence_length; // sequence length in past state of K or V
int original_past_sequence_length; // original sequence length in past state of K or V
int total_sequence_length; // total sequence length of K or V
int max_sequence_length; // max sequence length from 4D mask
int input_hidden_size; // first dimension of weights for input projection
int hidden_size; // hidden size of Q or K
int head_size; // hidden size per head of Q or K
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
bool is_unidirectional;
bool past_present_share_buffer;
@ -16,21 +16,21 @@ namespace contrib {
|
|||
class AttentionCPUBase : public AttentionBase {
|
||||
protected:
|
||||
AttentionCPUBase(const OpKernelInfo& info, bool require_same_hidden_size)
|
||||
: AttentionBase(info, require_same_hidden_size) {}
|
||||
: AttentionBase(info, require_same_hidden_size) {}
|
||||
|
||||
template <typename T>
|
||||
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
|
||||
const T* K, // K data with shape BxNxSxH
|
||||
const T* V, // V value with size BxNxSxH_v
|
||||
const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
|
||||
const Tensor* past, // past state
|
||||
Tensor* output, // output tensor
|
||||
int batch_size, // batch size (B)
|
||||
int sequence_length, // sequence length (S)
|
||||
int qk_head_size, // head size of Q or K (H)
|
||||
int v_head_size, // head size of V (H_v)
|
||||
int v_hidden_size, // hidden size of V (D_v)
|
||||
const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
|
||||
Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
|
||||
const T* K, // K data with shape BxNxSxH
|
||||
const T* V, // V value with size BxNxSxH_v
|
||||
const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
|
||||
const Tensor* past, // past state
|
||||
Tensor* output, // output tensor
|
||||
int batch_size, // batch size (B)
|
||||
int sequence_length, // sequence length (S)
|
||||
int qk_head_size, // head size of Q or K (H)
|
||||
int v_head_size, // head size of V (H_v)
|
||||
int v_hidden_size, // hidden size of V (D_v)
|
||||
const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
|
||||
OpKernelContext* context) const {
|
||||
const int kv_sequence_length = sequence_length;
|
||||
|
||||
|
@ -206,7 +206,7 @@ class AttentionCPUBase : public AttentionBase {
|
|||
const T* past, // past state
|
||||
T* present, // present state
|
||||
ThreadPool* tp) const {
|
||||
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
|
||||
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
|
||||
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
|
||||
const ptrdiff_t input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
|
||||
const ptrdiff_t present_chunk_length = past_chunk_length + input_chunk_length; // T x H_v
|
||||
|
|
|
@@ -60,7 +60,7 @@ Status BiasGelu<T, use_approximation>::Compute(OpKernelContext* context) const {

p_output[i] = value * (static_cast<T>(C) * value * value + static_cast<T>(B));
}

MlasComputeTanh(p_output, p_output,narrow<size_t>(count));
MlasComputeTanh(p_output, p_output, narrow<size_t>(count));

for (int64_t i = 0; i < count; i++) {
p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f);

@@ -106,7 +106,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(

temp[i] = value * 0.5f;
}

MlasComputeTanh(output, output,narrow<size_t>(count));
MlasComputeTanh(output, output, narrow<size_t>(count));

for (int64_t i = 0; i < count; i++) {
output[i] = temp[i] * (output[i] + 1.0f);

@@ -118,7 +118,7 @@ void BiasGelu<T, use_approximation>::AddBiasGelu(

temp[i] = value * 0.5f;
}

MlasComputeErf(output, output,narrow<size_t>(count));
MlasComputeErf(output, output, narrow<size_t>(count));

for (int64_t i = 0; i < count; i++) {
output[i] = temp[i] * (output[i] + 1.0f);

@@ -7,7 +7,6 @@

#include "core/common/safeint.h"
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {
class BifurcationDetector : public OpKernel {

@@ -93,11 +93,7 @@ Status EmbedLayerNorm<T>::Compute(OpKernelContext* context) const {

failed.store(true, std::memory_order_release);
return;
}
int position_col_index = (position_ids_data == nullptr) ?
index % sequence_length :
(broadcast_position_ids ?
position_ids_data[index % sequence_length] :
position_ids_data[index]);
int position_col_index = (position_ids_data == nullptr) ? index % sequence_length : (broadcast_position_ids ? position_ids_data[index % sequence_length] : position_ids_data[index]);
if (position_col_index >= position_embedding_length) {
failed.store(true, std::memory_order_release);
return;

@ -28,12 +28,12 @@ Status CheckInputs(const OpKernelContext* context, bool quantizedVersion) {
|
|||
if (nullptr != position_ids) {
|
||||
if (input_ids->Shape()[1] != position_ids->Shape()[1]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input_ids and position_ids shall have same sequence_length");
|
||||
"input_ids and position_ids shall have same sequence_length");
|
||||
}
|
||||
if (position_ids->Shape()[0] != input_ids->Shape()[0] &&
|
||||
position_ids->Shape()[0] != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"position_ids's first dimension shall be 1 or batch_size");
|
||||
"position_ids's first dimension shall be 1 or batch_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -163,7 +163,7 @@ Status CheckInputs(const T* query,
|
|||
|
||||
qkv_format = Q_KV_BSNH_BSN2H;
|
||||
kv_sequence_length = static_cast<int>(key_dims[1]);
|
||||
} else { // key_dims.size() == 4 (cross-attention with past_key)
|
||||
} else { // key_dims.size() == 4 (cross-attention with past_key)
|
||||
if (static_cast<int>(key_dims[1]) != num_heads || static_cast<int>(key_dims[3]) != head_size) {
|
||||
return ORT_MAKE_STATUS(
|
||||
ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
@ -242,7 +242,7 @@ Status CheckInputs(const T* query,
|
|||
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
|
||||
}
|
||||
v_hidden_size = static_cast<int>(value_dims[2]);
|
||||
} else { // value_dims.size() == 4
|
||||
} else { // value_dims.size() == 4
|
||||
if (static_cast<int64_t>(kv_sequence_length) != value_dims[2]) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'past_key' and 'past_value' shall have the same dim 2 (kv_sequence_length)");
|
||||
|
|
|
@@ -3,7 +3,6 @@

#include "ngram_repeat_block.h"

namespace onnxruntime {
namespace contrib {

@@ -9,5 +9,5 @@

namespace onnxruntime {
namespace contrib {
Status RegisterCpuContribKernels(KernelRegistry& kernel_registry);
} // namespace contrib
} // namespace contrib
} // namespace onnxruntime

@@ -130,4 +130,4 @@ class Crop final : public CropBase, public OpKernel {

};

} // namespace contrib
} //namespace onnxruntime
} // namespace onnxruntime

@ -21,14 +21,13 @@ namespace contrib {
|
|||
*
|
||||
* TODO!! implemente thread partition similar with
|
||||
* fp16 conv operator
|
||||
*/
|
||||
*/
|
||||
class NhwcPoolFp16 : public OpKernel {
|
||||
public:
|
||||
explicit NhwcPoolFp16(const OpKernelInfo& info)
|
||||
: OpKernel(info),
|
||||
pool_attrs_(info, info.GetKernelDef().OpName(), info.node().SinceVersion()),
|
||||
is_max_pool_(info.GetKernelDef().OpName() == "MaxPool")
|
||||
{}
|
||||
pool_attrs_(info, info.GetKernelDef().OpName(), info.node().SinceVersion()),
|
||||
is_max_pool_(info.GetKernelDef().OpName() == "MaxPool") {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
|
@ -182,7 +181,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
.TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
|
||||
NhwcPoolFp16);
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
@ -6,16 +6,16 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GridSample, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
GridSample, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
|
||||
GridSample<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
|
|
|
@@ -10,7 +10,7 @@

#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib{
namespace contrib {

template <typename T>
class ImageScaler final : public OpKernel {

@@ -54,5 +54,5 @@ class ImageScaler final : public OpKernel {

float scale_;
std::vector<float> bias_;
};
}
} //namespace onnxruntime
} // namespace contrib
} // namespace onnxruntime

@ -62,9 +62,9 @@ inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrix
|
|||
}
|
||||
}
|
||||
|
||||
template<> inline
|
||||
void SparseDenseMatMulImpl<float>(const ComputeCtx& ctx, const ConstSparseMatrixMap<float>& map_A,
|
||||
const ConstEigenMatrixMapRowMajor<float>& map_B, EigenMatrixMapRowMajor<float>& output_map) {
|
||||
template <>
|
||||
inline void SparseDenseMatMulImpl<float>(const ComputeCtx& ctx, const ConstSparseMatrixMap<float>& map_A,
|
||||
const ConstEigenMatrixMapRowMajor<float>& map_B, EigenMatrixMapRowMajor<float>& output_map) {
|
||||
if (ctx.trans_A && ctx.trans_B) {
|
||||
output_map = map_A.transpose() * ctx.alpha * map_B.transpose();
|
||||
} else if (ctx.trans_A && !ctx.trans_B) {
|
||||
|
@ -97,15 +97,15 @@ struct SparseToDenseCsr {
|
|||
}
|
||||
};
|
||||
|
||||
#endif //!defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
|
||||
#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
|
||||
|
||||
template<typename T> inline
|
||||
T Mul(T a_value, float, T b_value) {
|
||||
template <typename T>
|
||||
inline T Mul(T a_value, float, T b_value) {
|
||||
return a_value * b_value;
|
||||
}
|
||||
|
||||
template <> inline
|
||||
constexpr float Mul<float>(float a_value, float alpha, float b_value) {
|
||||
template <>
|
||||
inline constexpr float Mul<float>(float a_value, float alpha, float b_value) {
|
||||
return a_value * alpha * b_value;
|
||||
}
|
||||
|
||||
|
@ -203,7 +203,7 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
|
|||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "WASM and 32-bit builds support only COO format");
|
||||
}
|
||||
#endif //!defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
|
||||
#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -211,4 +211,4 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
|
|||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif //!defined(DISABLE_SPARSE_TENSORS)
|
||||
#endif //! defined(DISABLE_SPARSE_TENSORS)
|
|
@@ -5,14 +5,14 @@

namespace onnxruntime {
namespace contrib {
// Register MVN operator for backward compatibility.
// Register MVN operator for backward compatibility.
// The experimental MVN op was removed. The history has to be kept locally as below.
// As of (9/26/2018) MVN is a production function in ONNX.
// As of (9/26/2018) MVN is a production function in ONNX.
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
MeanVarianceNormalization,
1,
8,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MeanVarianceNormalization_0<float>);
}
} // namespace contrib
} // namespace onnxruntime

@@ -2,9 +2,9 @@

// MurmurHash3 was written by Austin Appleby, and is placed in the public
// domain. The author hereby disclaims copyright to this source code.

//scikit-learn is a Python module for machine learning built on top of SciPy and
//distributed under the 3-Clause BSD license. See https://github.com/scikit-learn/scikit-learn.
//This material is licensed under the BSD License (see https://github.com/scikit-learn/scikit-learn/blob/master/COPYING);
// scikit-learn is a Python module for machine learning built on top of SciPy and
// distributed under the 3-Clause BSD license. See https://github.com/scikit-learn/scikit-learn.
// This material is licensed under the BSD License (see https://github.com/scikit-learn/scikit-learn/blob/master/COPYING);
/* Modifications Copyright (c) Microsoft. */

#include "contrib_ops/cpu/murmur_hash3.h"

@@ -200,7 +200,7 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {

}
} else {
auto input = reinterpret_cast<const unsigned char*>(keys->DataRaw());
//input_element_bytes is 4, 8,.. less than 4 bytes is not allowed
// input_element_bytes is 4, 8,.. less than 4 bytes is not allowed
int input_num_bytes = static_cast<int>(input_element_bytes);
ORT_ENFORCE(input_num_bytes % 4 == 0);
const auto input_end = input + input_count * input_num_bytes;

@@ -18,10 +18,10 @@ class MurmurHash3 final : public OpKernel {

Status Compute(OpKernelContext* context) const override;

private:
private:
void MurmurHash3_x86_32(const void* key, int len, uint32_t seed, void* out) const;

private :
private:
uint32_t seed_;
bool is_positive_{true};
};

@ -323,12 +323,12 @@ Status NchwcUpsample::Compute(OpKernelContext* context) const {
|
|||
const auto interpolation_w = ComputeInterpolation(input_w, output_w, scales_[3]);
|
||||
|
||||
const int64_t nchwc_block_size = static_cast<int64_t>(MlasNchwcGetBlockSize());
|
||||
const ptrdiff_t total_work =((SafeInt<ptrdiff_t>(batch_count) * nchwc_channels) / nchwc_block_size) * output_h;
|
||||
const ptrdiff_t total_work = ((SafeInt<ptrdiff_t>(batch_count) * nchwc_channels) / nchwc_block_size) * output_h;
|
||||
// Partition the work with the goal of generating the following number of
|
||||
// elements, so that operations involving a smaller number of columns will
|
||||
// process more rows per worker.
|
||||
constexpr ptrdiff_t worker_goal = 16 * 1024;
|
||||
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (SafeInt<ptrdiff_t>(output_w) * nchwc_block_size), 1);
|
||||
ptrdiff_t work_per_worker = std::max<ptrdiff_t>(worker_goal / (SafeInt<ptrdiff_t>(output_w) * nchwc_block_size), 1);
|
||||
ptrdiff_t worker_count = std::max<ptrdiff_t>(total_work / work_per_worker, 1);
|
||||
|
||||
auto upsample_worker = [&](ptrdiff_t batch) {
|
||||
|
|
|
@@ -132,7 +132,7 @@ Status DynamicQuantizeLSTM::UseSharedPrePackedBuffers(std::vector<BufferUniquePt

#define WeightCheck(weight_shape, weight_name) \
if ((weight_shape.NumDimensions() != 1 && weight_shape.NumDimensions() != 2) || \
(weight_shape.NumDimensions() == 2 && weight_shape[1] != static_cast<int64_t>(hidden_size_) * 4) || \
(weight_shape.NumDimensions() == 2 && weight_shape[1] != static_cast<int64_t>(hidden_size_) * 4) || \
weight_shape[0] != num_directions_) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, \
"Input ", #weight_name, " must have shape {", num_directions_, "} for per-tensor/layer quantization or shape {", \

@ -51,4 +51,3 @@ class QLinearSigmoid final : public QLinearLookupBase<T> {
|
|||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
|
||||
QLinearConcat::QLinearConcat(const OpKernelInfo& info) : OpKernel(info), ConcatBase(info) {
|
||||
size_t input_def_count = info.node().InputDefs().size();
|
||||
ORT_ENFORCE(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
|
||||
|
|
|
@ -19,7 +19,7 @@ class QLinearConcat final : public OpKernel, public ConcatBase {
|
|||
|
||||
private:
|
||||
std::vector<std::vector<uint8_t>> fixed_lookup_tables_;
|
||||
std::vector<int> fixed_table_attrs_; // is_static or not, is_copy or not
|
||||
std::vector<int> fixed_table_attrs_; // is_static or not, is_copy or not
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
|
|
|
@ -21,7 +21,7 @@ class QLinearGlobalAveragePool final : public OpKernel {
|
|||
bool channels_last_;
|
||||
};
|
||||
|
||||
template<typename T8Bits>
|
||||
template <typename T8Bits>
|
||||
Status ComputeQLinearGlobalAvgPool(
|
||||
const T8Bits* x,
|
||||
float x_scale,
|
||||
|
|
|
@ -56,10 +56,10 @@ void QlinearBuildLookupTable(uint8_t* table,
|
|||
|
||||
const float X_scale = *(tensor_x_scale->Data<float>());
|
||||
const T X_zero_point =
|
||||
(tensor_x_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_x_zero_point->Data<T>());
|
||||
(tensor_x_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_x_zero_point->Data<T>());
|
||||
const float Y_scale = *(tensor_y_scale->Data<float>());
|
||||
const T Y_zero_point =
|
||||
(tensor_y_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_y_zero_point->Data<T>());
|
||||
(tensor_y_zero_point == nullptr) ? static_cast<T>(0) : *(tensor_y_zero_point->Data<T>());
|
||||
|
||||
float dequantized_input[256];
|
||||
float dequantized_output[256];
|
||||
|
|
|
@ -37,6 +37,5 @@ void QlinearBuildLookupTable(uint8_t* table,
|
|||
template <typename TOutput>
|
||||
void QLinearLookupTableTransform(const uint8_t* x, const TOutput* table, TOutput* y, size_t n);
|
||||
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -67,7 +67,7 @@ struct QLinearPool1DTask final {
|
|||
}
|
||||
|
||||
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
operator()(c);
|
||||
}
|
||||
}
|
||||
|
@ -149,7 +149,7 @@ struct QLinearPoolNhwc1DTask final {
|
|||
|
||||
int64_t element_count = (pool_attrs_.count_include_pad) ? kernel_shape[0] : hend - hstart;
|
||||
for (int64_t c = 0; c < channels; ++c) {
|
||||
PoolType::Finalize(element_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
PoolType::Finalize(element_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
y_d[phc + c] = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
|
||||
}
|
||||
}
|
||||
|
@ -181,7 +181,7 @@ struct QLinearPool2DTask final {
|
|||
}
|
||||
|
||||
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
operator()(c);
|
||||
}
|
||||
}
|
||||
|
@ -287,7 +287,7 @@ struct QLinearPoolNhwc2DTask final {
|
|||
int64_t input_index = channels * (h * width + wstart);
|
||||
for (int64_t w = wstart; w < wend; ++w) {
|
||||
for (int64_t c = 0; c < channels; c++) {
|
||||
PoolType::Process(x_d[input_index + c],Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
PoolType::Process(x_d[input_index + c], Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
}
|
||||
input_index += channels;
|
||||
}
|
||||
|
@ -295,7 +295,7 @@ struct QLinearPoolNhwc2DTask final {
|
|||
|
||||
int64_t elements_count = (pool_attrs_.count_include_pad) ? kernel_size : (hend - hstart) * (wend - wstart);
|
||||
for (int64_t c = 0; c < channels; c++) {
|
||||
PoolType::Finalize(elements_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
PoolType::Finalize(elements_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
auto y_value = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
|
||||
y_d[pool_index + c] = y_value;
|
||||
}
|
||||
|
@ -337,7 +337,7 @@ struct QLinearPool3DTask final {
|
|||
}
|
||||
|
||||
void operator()(std::ptrdiff_t begin, std::ptrdiff_t end) const {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
for (std::ptrdiff_t c = begin; c < end; ++c) {
|
||||
operator()(c);
|
||||
}
|
||||
}
|
||||
|
@ -462,7 +462,7 @@ struct QLinearPoolNhwc3DTask final {
|
|||
int64_t input_index = channels * (input_index_h + w * depth + dstart);
|
||||
for (int64_t d = dstart; d < dend; ++d) {
|
||||
for (int64_t c = 0; c < channels; c++) {
|
||||
PoolType::Process(x_d[input_index + c],Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
PoolType::Process(x_d[input_index + c], Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
}
|
||||
input_index += channels;
|
||||
}
|
||||
|
@ -471,7 +471,7 @@ struct QLinearPoolNhwc3DTask final {
|
|||
|
||||
int64_t elements_count = (pool_attrs_.count_include_pad) ? kernel_size : (hend - hstart) * (wend - wstart) * (dend - dstart);
|
||||
for (int64_t c = 0; c < channels; c++) {
|
||||
PoolType::Finalize(elements_count,Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
PoolType::Finalize(elements_count, Yh[onnxruntime::narrow<size_t>(c)], pool_context_);
|
||||
auto y_value = quantize_value(Yh[onnxruntime::narrow<size_t>(c)], y_scale, y_zero_point);
|
||||
y_d[pool_index + c] = y_value;
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@ static inline bool has_same_zero_point(bool is_signed, const Tensor* tensor_x_ze
|
|||
const uint8_t X_zero_point = (tensor_x_zero_point == nullptr) ? static_cast<uint8_t>(0) : *(tensor_x_zero_point->Data<uint8_t>());
|
||||
const uint8_t Y_zero_point = (tensor_y_zero_point == nullptr) ? static_cast<uint8_t>(0) : *(tensor_y_zero_point->Data<uint8_t>());
|
||||
return X_zero_point == Y_zero_point;
|
||||
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -95,7 +95,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
|
|||
if (condition == target) {
|
||||
// Transform the output to the correct value from LookupTable
|
||||
std::transform(value.cbegin(), value.cend(), output.begin(),
|
||||
[condition, target, &look_up_table,is_copy](const T& value_element) {
|
||||
[condition, target, &look_up_table, is_copy](const T& value_element) {
|
||||
return is_copy ? value_element : look_up_table[value_element];
|
||||
});
|
||||
} else {
|
||||
|
@ -112,7 +112,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
|
|||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
// Transform the output to the correct value from LookupTable
|
||||
std::transform(condition.begin(), condition.end(), output.begin(),
|
||||
[target, &value,&look_up_table,is_copy](bool condition_element) {
|
||||
[target, &value, &look_up_table, is_copy](bool condition_element) {
|
||||
return condition_element == target ? is_copy ? value : look_up_table[value] : T{};
|
||||
});
|
||||
},
|
||||
|
@ -126,7 +126,7 @@ ProcessBroadcastSpanFuncs CreateNonScalarBroadcastFuncs() {
|
|||
auto output = per_iter_bh.OutputSpan<T>();
|
||||
// Transform the output to the correct value from LookupTable
|
||||
std::transform(condition.begin(), condition.end(), value.cbegin(), output.begin(),
|
||||
[target,&look_up_table,is_copy](bool condition_element, const T& value_element) {
|
||||
[target, &look_up_table, is_copy](bool condition_element, const T& value_element) {
|
||||
return condition_element == target ? is_copy ? value_element : look_up_table[value_element] : T{};
|
||||
});
|
||||
}};
|
||||
|
@ -310,16 +310,16 @@ QLinearWhere::QLinearWhere(const OpKernelInfo& info) : OpKernel(info) {
|
|||
}
|
||||
|
||||
Status QLinearWhere::Compute(OpKernelContext* ctx) const {
|
||||
// const auto* tensor_condition = ctx->Input<Tensor>(0);
|
||||
// const auto* tensor_x_input = ctx->Input<Tensor>(1);
|
||||
// const auto* tensor_condition = ctx->Input<Tensor>(0);
|
||||
// const auto* tensor_x_input = ctx->Input<Tensor>(1);
|
||||
const auto* tensor_x_scale = ctx->Input<Tensor>(2);
|
||||
const auto* tensor_x_zero_point = ctx->Input<Tensor>(3);
|
||||
// const auto* tensor_y_input = ctx->Input<Tensor>(4);
|
||||
// const auto* tensor_y_input = ctx->Input<Tensor>(4);
|
||||
const auto* tensor_y_scale = ctx->Input<Tensor>(5);
|
||||
const auto* tensor_y_zero_point = ctx->Input<Tensor>(6);
|
||||
const auto* tensor_z_scale = ctx->Input<Tensor>(7);
|
||||
const auto* tensor_z_zero_point = ctx->Input<Tensor>(8);
|
||||
// auto* tensor_output = ctx->Output(0, tensor_condition->Shape());
|
||||
// auto* tensor_output = ctx->Output(0, tensor_condition->Shape());
|
||||
ORT_ENFORCE(tensor_x_scale->IsDataType<float>(), "Input scale is not float for quantized input x @ 2");
|
||||
ORT_ENFORCE(tensor_y_scale->IsDataType<float>(), "Input scale is not float for quantized input y @ 5");
|
||||
ORT_ENFORCE(tensor_z_scale->IsDataType<float>(), "Input scale is not float for quantized output z @ 7");
|
||||
|
@ -377,7 +377,7 @@ Status QLinearWhere::Compute(OpKernelContext* ctx) const {
|
|||
if (!is_y_copy) {
|
||||
std::copy(y_lookup_table.begin(), y_lookup_table.end(), y_user_data.begin() + 2);
|
||||
}
|
||||
//Allocator, Allocation, and SelectBroadcastFuncs are the same implementation from where_op.cc
|
||||
// Allocator, Allocation, and SelectBroadcastFuncs are the same implementation from where_op.cc
|
||||
const auto typed_tensor_allocation = [](const TensorAllocator& allocator,
|
||||
const TensorShape& shape) {
|
||||
return allocator.Allocate<uint8_t>(shape);
|
||||
|
|
|
@ -32,7 +32,7 @@ class QGemm : protected GemmBase, public MatMulIntegerBase {
|
|||
size_t N = SafeInt<size_t>(helper.N());
|
||||
size_t K = SafeInt<size_t>(helper.K());
|
||||
|
||||
//validate scales and zero points
|
||||
// validate scales and zero points
|
||||
const auto* a_zp = context->Input<Tensor>(IN_A_ZERO_POINT);
|
||||
const auto* b_zp = context->Input<Tensor>(IN_B_ZERO_POINT);
|
||||
const auto* y_zp = context->Input<Tensor>(IN_Y_ZERO_POINT);
|
||||
|
|
|
@ -91,7 +91,6 @@ Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
|
||||
|
||||
const T* input_data = input->Data<T>();
|
||||
|
|
|
@ -149,14 +149,13 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
|||
t5_decoder_subgraph_->head_size,
|
||||
t5_decoder_subgraph_->num_layers);
|
||||
}
|
||||
}
|
||||
else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
|
||||
} else if (parameters_.model_type == IGenerationParameters::kModelTypeWhisper) {
|
||||
if (attribute_name == "encoder") {
|
||||
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
|
||||
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
t5_encoder_subgraph_ = std::make_unique<WhisperEncoderSubgraph>(node,
|
||||
attribute_name,
|
||||
subgraph_session_state.GetGraphViewer());
|
||||
attribute_name,
|
||||
subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
|
||||
|
||||
|
@ -260,23 +259,23 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
|
||||
// Subgraph has constraint that the output is either float or float16
|
||||
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
|
||||
BeamSearchT5<float> impl{
|
||||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
|
||||
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
|
||||
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
|
||||
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
|
||||
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
|
||||
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
|
||||
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
|
||||
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
|
||||
cuda_device_prop_,
|
||||
cuda_device_arch_};
|
||||
BeamSearchT5<float> impl{
|
||||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
|
||||
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
|
||||
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
|
||||
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : GenerationCpuDeviceHelper::CreateEncoderInputs,
|
||||
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : GenerationCpuDeviceHelper::UpdateDecoderFeeds<float>,
|
||||
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
|
||||
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
|
||||
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
|
||||
cuda_device_prop_,
|
||||
cuda_device_arch_};
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
|
||||
|
|
|
@ -31,10 +31,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
ORT_ENFORCE(context != nullptr);
|
||||
const Tensor* input_ids = context->Input<Tensor>(0);
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (this->model_type == IGenerationParameters::kModelTypeWhisper){
|
||||
if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
|
||||
ORT_ENFORCE(dims.size() == 3, "input_features shall have 3 dimensions. Got ", dims.size());
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
|
||||
}
|
||||
batch_size = static_cast<int>(dims[0]);
|
||||
|
@ -69,8 +68,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
if (length_penalty_tensor) {
|
||||
if (length_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
|
||||
length_penalty = static_cast<float>(*length_penalty_tensor->Data<float>());
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
length_penalty = static_cast<MLFloat16>(*length_penalty_tensor->Data<MLFloat16>());
|
||||
}
|
||||
} else {
|
||||
|
@ -81,8 +79,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
if (repetition_penalty_tensor) {
|
||||
if (repetition_penalty_tensor->DataType() == DataTypeImpl::GetType<float>()) {
|
||||
repetition_penalty = static_cast<float>(*repetition_penalty_tensor->Data<float>());
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
repetition_penalty = static_cast<MLFloat16>(*repetition_penalty_tensor->Data<MLFloat16>());
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -124,17 +124,16 @@ class GenerateBase {
|
|||
const Tensor* attention_mask,
|
||||
const Tensor* presence_mask) const {
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper){
|
||||
if (dims.size() != 3){
|
||||
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
|
||||
if (dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
|
||||
"Input 'input_features' is expected to have 3 dimensions, got ", dims.size());
|
||||
}
|
||||
|
||||
}
|
||||
else if (dims.size() != 2) {
|
||||
} else if (dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
|
||||
}
|
||||
}
|
||||
|
||||
if (vocab_mask != nullptr) { // vocab_mask is optional
|
||||
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
|
||||
|
@ -184,10 +183,9 @@ class GenerateBase {
|
|||
if (parameters->model_type == IGenerationParameters::kModelTypeWhisper) {
|
||||
if (dims_attn.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
|
||||
"Input 'attention_mask' is expected to have 3 dimensions, got ", dims_attn.size());
|
||||
}
|
||||
}
|
||||
else if (dims_attn.size() != 2) {
|
||||
} else if (dims_attn.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'attention_mask' is expected to have 2 dimensions, got ", dims_attn.size());
|
||||
}
|
||||
|
|
|
@ -329,7 +329,6 @@ Status CreateWhisperEncoderInputs(
|
|||
OrtValue& encoder_attention_mask,
|
||||
OrtValue& decoder_input_ids);
|
||||
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Utility Functions
|
||||
// ---------------------------------------------------------------
|
||||
|
|
|
@ -147,7 +147,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
void Process(const ISequences* sequences, gsl::span<float>& next_token_scores, int step);
|
||||
|
||||
private:
|
||||
template<typename GenerationParametersT>
|
||||
template <typename GenerationParametersT>
|
||||
void LogitsProcessorInitImpl(const GenerationParametersT& parameters) {
|
||||
processor_list_.clear();
|
||||
|
||||
|
@ -159,8 +159,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
|
||||
if (parameters.no_repeat_ngram_size > 0) {
|
||||
no_repeat_ngram_processor_ = std::make_unique<
|
||||
NoRepeatNGramLogitsProcessor<float>
|
||||
>(parameters.no_repeat_ngram_size);
|
||||
NoRepeatNGramLogitsProcessor<float>>(parameters.no_repeat_ngram_size);
|
||||
processor_list_.push_back(no_repeat_ngram_processor_.get());
|
||||
}
|
||||
|
||||
|
@ -171,9 +170,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
|
||||
if (!parameters.prefix_vocab_mask.empty()) {
|
||||
prefix_vocab_mask_processor_ = std::make_unique<
|
||||
PrefixVocabMaskLogitsProcessor<float>
|
||||
>(parameters.prefix_vocab_mask,
|
||||
parameters.batch_size);
|
||||
PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask,
|
||||
parameters.batch_size);
|
||||
processor_list_.push_back(prefix_vocab_mask_processor_.get());
|
||||
}
|
||||
|
||||
|
@ -190,9 +188,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
|
||||
if (!parameters.presence_mask.empty()) {
|
||||
presence_penalty_processor_ = std::make_unique<
|
||||
PresencePenaltyLogitsProcessor<float>
|
||||
>(parameters.presence_mask,
|
||||
parameters.presence_penalty);
|
||||
PresencePenaltyLogitsProcessor<float>>(parameters.presence_mask,
|
||||
parameters.presence_penalty);
|
||||
processor_list_.push_back(presence_penalty_processor_.get());
|
||||
}
|
||||
|
||||
|
|
|
@@ -154,12 +154,12 @@ Status Sample(AllocatorPtr& allocator,

*sampled_idx));
// TODO: update presense_mask()
#ifdef DEBUG_GENERATION
dumper->Print("sampled_idx", *sampled_idx);
dumper->Print("sampled_idx", *sampled_idx);
#endif

return Status::OK();
}

} // namespace SamplingCudaHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace SamplingCpuHelper
} // namespace contrib
} // namespace onnxruntime

@@ -40,7 +40,7 @@ namespace transformers {

*/

Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);

ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);

@@ -75,7 +75,7 @@ Status WhisperEncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgr

constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;

ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float32_type && subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != float16_type,
"encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
"encoder subgraph input 0 (encoder_input_features) shall have float32 or float16 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,

@ -38,9 +38,9 @@ Status Unique<float>::Compute(OpKernelContext* ctx) const {
|
|||
int64_t* output_idx_data = output_idx->MutableData<int64_t>();
|
||||
|
||||
struct ElementData {
|
||||
int64_t input_pos_; // original index
|
||||
int64_t input_pos_; // original index
|
||||
int64_t output_pos_;
|
||||
int64_t count_; // number of times encountered
|
||||
int64_t count_; // number of times encountered
|
||||
};
|
||||
|
||||
// XXX: Refactoring for less memory allocations. unordered_map
|
||||
|
|
|
@ -34,7 +34,7 @@ void WordConvEmbedding::CharEmbeddingLookup(
|
|||
}
|
||||
}
|
||||
|
||||
//input : [sequence_length, word_length, char_embedding_size]
|
||||
// input : [sequence_length, word_length, char_embedding_size]
|
||||
void WordConvEmbedding::ComputeConvMaxPoolWithActivation(
|
||||
AllocatorPtr allocator,
|
||||
const float* input,
|
||||
|
@ -190,11 +190,11 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
|
|||
// SafeInt<size_t>(seq_len) * word_len * char_embedding_size
|
||||
size_t chars_embeddings_size = SafeInt<size_t>(seq_len) * word_len * char_embedding_size;
|
||||
auto chars_embeddings_ptr = IAllocator::MakeUniquePtr<float>(alloc, chars_embeddings_size);
|
||||
auto words_length_ptr = IAllocator::MakeUniquePtr<int>(alloc, onnxruntime::narrow<size_t>(seq_len) );
|
||||
auto words_length_ptr = IAllocator::MakeUniquePtr<int>(alloc, onnxruntime::narrow<size_t>(seq_len));
|
||||
std::memset(chars_embeddings_ptr.get(), 0, chars_embeddings_size * sizeof(float));
|
||||
std::memset(words_length_ptr.get(), 0, SafeInt<size_t>(seq_len) * sizeof(int));
|
||||
|
||||
CalculateLengthOfEachWordInSequence(seq_ptr, words_length_ptr.get(), onnxruntime::narrow<size_t>(seq_len) , onnxruntime::narrow<size_t>(word_len));
|
||||
CalculateLengthOfEachWordInSequence(seq_ptr, words_length_ptr.get(), onnxruntime::narrow<size_t>(seq_len), onnxruntime::narrow<size_t>(word_len));
|
||||
|
||||
CharEmbeddingLookup(seq_ptr,
|
||||
w_char_embedding.Data<float>(),
|
||||
|
@ -224,7 +224,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
/* Range operator */
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
WordConvEmbedding, //name
|
||||
WordConvEmbedding, // name
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
|
|
|
@ -22,19 +22,19 @@ namespace cuda {
|
|||
.MayInplace(0, 0), \
|
||||
x<T>);
|
||||
|
||||
#define UNARY_ACTIVATION_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
UnaryElementwisePreparation p; \
|
||||
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
|
||||
Ctx##x func_ctx = MakeFuncCtx(); \
|
||||
Impl_##x<typename ToCudaType<T>::MappedType>( \
|
||||
Stream(context), \
|
||||
#define UNARY_ACTIVATION_COMPUTE(x, T) \
|
||||
template <> \
|
||||
Status x<T>::ComputeInternal(OpKernelContext* context) const { \
|
||||
UnaryElementwisePreparation p; \
|
||||
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
|
||||
Ctx##x func_ctx = MakeFuncCtx(); \
|
||||
Impl_##x<typename ToCudaType<T>::MappedType>( \
|
||||
Stream(context), \
|
||||
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->Data<T>()), \
|
||||
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->MutableData<T>()), \
|
||||
&func_ctx, p.output_tensor->Shape().Size()); \
|
||||
\
|
||||
return Status::OK(); \
|
||||
&func_ctx, p.output_tensor->Shape().Size()); \
|
||||
\
|
||||
return Status::OK(); \
|
||||
}
|
||||
|
||||
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
|
||||
|
@ -56,6 +56,6 @@ REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
|
|||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
|
||||
|
||||
} //namespace cuda
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
@ -62,10 +62,10 @@ __device__ inline void Softmax(const int all_sequence_length,
|
|||
if (i >= valid_start) {
|
||||
const int index = offset + i;
|
||||
float input_at_idx = no_rpb
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
if (thread_data_max < input_at_idx) {
|
||||
thread_data_max = input_at_idx;
|
||||
}
|
||||
|
@ -148,10 +148,10 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
|
|||
const bool no_rpb = (rel_pos_bias == nullptr);
|
||||
const int size_per_batch = gridDim.x * all_sequence_length;
|
||||
float input_data = no_rpb
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
|
||||
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
|
||||
|
||||
|
@ -229,12 +229,12 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length,
|
|||
// a math transform as below is leveraged to get a stable softmax:
|
||||
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
|
||||
float input_data = is_valid
|
||||
? (rel_pos_bias
|
||||
? float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]))
|
||||
: float(input[index]))
|
||||
: float(-CUDART_INF_F);
|
||||
? (rel_pos_bias
|
||||
? float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]))
|
||||
: float(input[index]))
|
||||
: float(-CUDART_INF_F);
|
||||
cached_data[seq_idx] = input_data;
|
||||
thread_data_max = max(thread_data_max, input_data);
|
||||
}
|
||||
|
@ -300,8 +300,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
|
|||
if (rel_pos_bias == nullptr) {
|
||||
thread_data = float(input[index]) * rsqrt_head_size;
|
||||
} else {
|
||||
T rel_pos_bias_value = broadcast_rel_pos_bias ?
|
||||
rel_pos_bias[index % size_per_batch] : rel_pos_bias[index];
|
||||
T rel_pos_bias_value = broadcast_rel_pos_bias ? rel_pos_bias[index % size_per_batch] : rel_pos_bias[index];
|
||||
thread_data = float(input[index] + rel_pos_bias_value) * rsqrt_head_size;
|
||||
}
|
||||
|
||||
|
@ -433,8 +432,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
}
|
||||
|
||||
if (rel_pos_bias != nullptr) {
|
||||
float rel_pos_bias_value = broadcast_rel_pos_bias ?
|
||||
float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]);
|
||||
float rel_pos_bias_value = broadcast_rel_pos_bias ? float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]);
|
||||
thread_data += rel_pos_bias_value;
|
||||
}
|
||||
}
|
||||
|
@ -590,10 +588,10 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length,
|
|||
const bool no_rpb = (rel_pos_bias == nullptr);
|
||||
const int size_per_batch = gridDim.x * sequence_length;
|
||||
float input_data = no_rpb
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
? float(input[index])
|
||||
: float(input[index] + (broadcast_rel_pos_bias
|
||||
? rel_pos_bias[index % size_per_batch]
|
||||
: rel_pos_bias[index]));
|
||||
|
||||
float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
|
||||
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
|
||||
|
@ -761,17 +759,17 @@ Status ComputeSoftmaxWithCumSeqLength(
|
|||
|
||||
template <typename T>
|
||||
Status ComputeSoftmaxWithMask1D(cudaStream_t stream,
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int* mask_index,
|
||||
const int* mask_start,
|
||||
const T* rel_pos_bias,
|
||||
const bool broadcast_rel_pos_bias,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional) {
|
||||
const int all_sequence_length,
|
||||
const int sequence_length,
|
||||
const int batch_size,
|
||||
const int num_heads,
|
||||
const int* mask_index,
|
||||
const int* mask_start,
|
||||
const T* rel_pos_bias,
|
||||
const bool broadcast_rel_pos_bias,
|
||||
const T* input,
|
||||
T* output,
|
||||
const bool is_unidirectional) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
|
||||
if (all_sequence_length <= 32) {
|
||||
|
|
|
@ -30,12 +30,12 @@ struct MemoryEfficientAttentionParams {
|
|||
int32_t* seqstart_k_ptr;
|
||||
int32_t* seqlen_k_ptr;
|
||||
|
||||
const void* query; // [B, S, N, H]
|
||||
const void* key; // [B, L, N, H], where L is kv_sequence_length
|
||||
const void* value; // [B, L, N, H_v]
|
||||
const void* attn_bias; // [N, S, S*] or null
|
||||
void* output; // [B, S, N, H_v]
|
||||
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
|
||||
const void* query; // [B, S, N, H]
|
||||
const void* key; // [B, L, N, H], where L is kv_sequence_length
|
||||
const void* value; // [B, L, N, H_v]
|
||||
const void* attn_bias; // [N, S, S*] or null
|
||||
void* output; // [B, S, N, H_v]
|
||||
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
|
||||
cudaStream_t stream;
|
||||
|
||||
static bool need_workspace(size_t v_head_size, bool is_float) {
|
||||
|
|
|
@ -65,24 +65,24 @@ Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
return LaunchEmbedLayerNormKernel(
|
||||
Stream(context),
|
||||
output->MutableData<T>(),
|
||||
mask_index->MutableData<int32_t>(),
|
||||
input_ids->Data<int32_t>(),
|
||||
nullptr == segment_ids ? nullptr : segment_ids->Data<int32_t>(),
|
||||
nullptr == mask ? nullptr : mask->Data<int32_t>(),
|
||||
gamma->Data<T>(),
|
||||
beta->Data<T>(),
|
||||
word_embedding->Data<T>(),
|
||||
position_embedding->Data<T>(),
|
||||
nullptr == segment_embedding ? nullptr : segment_embedding->Data<T>(),
|
||||
epsilon_,
|
||||
static_cast<int>(hidden_size),
|
||||
batch_size,
|
||||
sequence_length,
|
||||
element_size,
|
||||
embedding_sum == nullptr ? nullptr : embedding_sum->MutableData<T>(),
|
||||
position_ids == nullptr ? nullptr : position_ids->Data<int32_t>(),
|
||||
broadcast_position_ids);
|
||||
output->MutableData<T>(),
|
||||
mask_index->MutableData<int32_t>(),
|
||||
input_ids->Data<int32_t>(),
|
||||
nullptr == segment_ids ? nullptr : segment_ids->Data<int32_t>(),
|
||||
nullptr == mask ? nullptr : mask->Data<int32_t>(),
|
||||
gamma->Data<T>(),
|
||||
beta->Data<T>(),
|
||||
word_embedding->Data<T>(),
|
||||
position_embedding->Data<T>(),
|
||||
nullptr == segment_embedding ? nullptr : segment_embedding->Data<T>(),
|
||||
epsilon_,
|
||||
static_cast<int>(hidden_size),
|
||||
batch_size,
|
||||
sequence_length,
|
||||
element_size,
|
||||
embedding_sum == nullptr ? nullptr : embedding_sum->MutableData<T>(),
|
||||
position_ids == nullptr ? nullptr : position_ids->Data<int32_t>(),
|
||||
broadcast_position_ids);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
@ -8,24 +8,24 @@ namespace contrib {
|
|||
namespace cuda {
|
||||
|
||||
Status LaunchEmbedLayerNormKernel(cudaStream_t stream,
|
||||
void* output, // output tensor
|
||||
void* mask_index, // output mask index
|
||||
const int* input_ids, // input word IDs
|
||||
const int* segment_ids, // input segment IDs
|
||||
const int* input_mask, // input mask
|
||||
const void* gamma, // weight for layer normalization
|
||||
const void* beta, // bias for layer normalization
|
||||
const void* word_embedding, // weights for word embeddings
|
||||
const void* position_embedding, // weights for position embeddings
|
||||
const void* segment_embedding, // weights for segment (like sentence) embeddings
|
||||
float epsilon, // epsilon for layer normalization
|
||||
const int hidden_size, // hidden size (that is head_size * num_heads)
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence length
|
||||
const size_t element_size, // size of output element: 2 for half, 4 for float.
|
||||
void* embedding_sum, // Optional output of sum of embeddings
|
||||
const int* position_ids, // Optional input of position ids
|
||||
const bool broadcast_position_ids); // Whether to broadcast position ids
|
||||
void* output, // output tensor
|
||||
void* mask_index, // output mask index
|
||||
const int* input_ids, // input word IDs
|
||||
const int* segment_ids, // input segment IDs
|
||||
const int* input_mask, // input mask
|
||||
const void* gamma, // weight for layer normalization
|
||||
const void* beta, // bias for layer normalization
|
||||
const void* word_embedding, // weights for word embeddings
|
||||
const void* position_embedding, // weights for position embeddings
|
||||
const void* segment_embedding, // weights for segment (like sentence) embeddings
|
||||
float epsilon, // epsilon for layer normalization
|
||||
const int hidden_size, // hidden size (that is head_size * num_heads)
|
||||
int batch_size, // batch size
|
||||
int sequence_length, // sequence length
|
||||
const size_t element_size, // size of output element: 2 for half, 4 for float.
|
||||
void* embedding_sum, // Optional output of sum of embeddings
|
||||
const int* position_ids, // Optional input of position ids
|
||||
const bool broadcast_position_ids); // Whether to broadcast position ids
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
@@ -10,7 +10,7 @@ namespace cuda {

template <typename T>
Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length,
const T* input, const T* bias, T* output, bool use_half2);
const T* input, const T* bias, T* output, bool use_half2);

} // namespace cuda
} // namespace contrib

@ -88,12 +88,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
relative_position_bias,
|
||||
past_key,
|
||||
past_value,
|
||||
nullptr, // past_seq_len
|
||||
nullptr, // past_seq_len
|
||||
¶meters,
|
||||
num_heads_,
|
||||
mask_filter_value_,
|
||||
scale_,
|
||||
false, // past_present_share_buffer
|
||||
false, // past_present_share_buffer
|
||||
device_prop.maxThreadsPerBlock));
|
||||
|
||||
int sequence_length = parameters.sequence_length;
|
||||
|
|
|
@@ -219,7 +219,7 @@ MHARunner* PackedAttention<T>::TryGettingFusedRunner(const PackedAttentionParame

!parameters.has_relative_position_bias &&
parameters.hidden_size == parameters.v_hidden_size;

if(!use_fused_runner) {
if (!use_fused_runner) {
return fused_runner;
}

@@ -232,7 +232,7 @@ MHARunner* PackedAttention<T>::TryGettingFusedRunner(const PackedAttentionParame

enable_trt_flash_attention_,
false);

if(!is_fMHA_supported) {
if (!is_fMHA_supported) {
return fused_runner;
}

@ -12,14 +12,13 @@ using namespace onnxruntime::cuda;
|
|||
using namespace ::onnxruntime::common;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
RelativePositionBias , \
|
||||
RelativePositionBias, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
|
@ -63,8 +62,8 @@ Status RelPosAttnBias<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const int64_t num_buckets = bias_table_dims[0];
|
||||
const int64_t num_heads = bias_table_dims[1];
|
||||
|
||||
const int64_t query_len = *query_length->Data<int64_t>();
|
||||
const int64_t key_len = *key_length->Data<int64_t>();
|
||||
const int64_t query_len = *query_length->Data<int64_t>();
|
||||
const int64_t key_len = *key_length->Data<int64_t>();
|
||||
|
||||
if (query_len != key_len) {
|
||||
ORT_THROW("Relatvie position bias currently only support query length equal to key length in Self Attention.");
|
||||
|
@ -145,7 +144,7 @@ Status GatedRelativePositionBias<T>::ComputeInternal(OpKernelContext* context) c
|
|||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
const auto BNS = batch_size * num_heads_ * seq_len;
|
||||
const size_t elements_in_query = (size_t)BNS * (size_t)head_size;
|
||||
const size_t elements_after_gemm = (size_t)BNS *(size_t)D;
|
||||
const size_t elements_after_gemm = (size_t)BNS * (size_t)D;
|
||||
bool reuse_output = (seq_len >= D);
|
||||
size_t workspace_size = sizeof(T) * (elements_in_query + (reuse_output ? (size_t)0 : elements_after_gemm));
|
||||
auto workspace = GetScratchBuffer<void>(workspace_size, context->GetComputeStream());
|
||||
|
|
|
@@ -33,7 +33,6 @@ class GatedRelativePositionBias final : public CudaKernel {
int num_heads_;
};

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
@@ -19,8 +19,7 @@ Status LaunchRelPosAttnBiasKernel(
const int num_bucket,
const int max_distance,
const bool is_bidirectional,
-const int max_threads_per_block
-);
+const int max_threads_per_block);

template <typename T>
Status LaunchGatedRelativePositionBiasKernel(
The diff for this file was not shown because it is too large.
@@ -136,7 +136,7 @@ AllToAll::AllToAll(const OpKernelInfo& info) : NcclKernel(info) {
Status AllToAll::ComputeInternal(OpKernelContext* context) const {
const ncclComm_t comm = nccl_->Comm();
auto input_tensor = context->Input<Tensor>(0);
-const char* input_data = static_cast<const char *>(input_tensor->DataRaw());
+const char* input_data = static_cast<const char*>(input_tensor->DataRaw());
const auto in_shape = input_tensor->Shape();
const int64_t input_count = in_shape.Size();
auto out_shape = in_shape;
@@ -144,7 +144,7 @@ Status AllToAll::ComputeInternal(OpKernelContext* context) const {
const int64_t rank_stride = input_count / group_size_;
const ncclDataType_t dtype = GetNcclDataType(input_tensor->DataType());

-char* output_data = static_cast<char *>(context->Output(0, out_shape)->MutableDataRaw());
+char* output_data = static_cast<char*>(context->Output(0, out_shape)->MutableDataRaw());

#ifdef ORT_USE_NCCL
NCCL_RETURN_IF_ERROR(ncclGroupStart());
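Both AllToAll hunks above move the `*` next to the type inside `static_cast`. That matches left pointer alignment (`PointerAlignment: Left`), which clang-format applies uniformly once pointer alignment is no longer derived from the surrounding file; the exact option values used by this repo are assumed here. A small sketch with an invented helper:

```cpp
// Illustrative only: left pointer alignment attaches '*' to the type.
// AsBytes and the variable names are invented for this sketch.
#include <cstdint>

const char* AsBytes(const void* data) {
  // Unformatted: return static_cast<const char *>(data);
  // Formatted:   return static_cast<const char*>(data);
  return static_cast<const char*>(data);
}

int main() {
  std::int32_t value = 42;
  const char* bytes = AsBytes(&value);
  return bytes != nullptr ? 0 : 1;
}
```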
@@ -69,7 +69,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
key,
value,
-nullptr, //bias
+nullptr,  // bias
mask_index,
relative_position_bias,
past_key,
@@ -108,9 +108,9 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> present_dims{
-parameters.batch_size, parameters.num_heads,
-past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length,
-parameters.head_size};
+parameters.batch_size, parameters.num_heads,
+past_present_share_buffer_ ? parameters.max_sequence_length : parameters.total_sequence_length,
+parameters.head_size};
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(kPresentOutputIndex, present_shape);
Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape);
@@ -23,7 +23,7 @@ static constexpr int kPresentOutputIndex = 1;

#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
-DecoderMaskedSelfAttention, \
+DecoderMaskedSelfAttention, \
kMSDomain, \
1, \
T1, \
@@ -34,11 +34,10 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
void* out = nullptr;

const int32_t* cache_indir = nullptr;
-const int32_t* mask = nullptr; // [B, total_sequence_length]
+const int32_t* mask = nullptr;  // [B, total_sequence_length]
};

-template<
+template <
// The type of the inputs. Supported types: float and half.
typename T,
// The hidden dimension per head.
@@ -51,11 +50,9 @@ template<
int THREADS_PER_BLOCK>
__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params);

-template<typename T, int head_size>
+template <typename T, int head_size>
void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);

}  // namespace cuda
}  // namespace contrib
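The `template<` to `template <` changes in the two hunks above come from `SpaceAfterTemplateKeyword`, which is enabled by default in the usual base styles. A sketch (the function below is invented; only the template-header formatting mirrors the diff):

```cpp
// Illustrative only: a space is inserted after the 'template' keyword.
#include <iostream>

// Unformatted: template<typename T, int head_size>
// Formatted:   template <typename T, int head_size>
template <typename T, int head_size>
T ScaleByHeadSize(T value) {
  return value * static_cast<T>(head_size);
}

int main() {
  std::cout << ScaleByHeadSize<float, 64>(0.5f) << "\n";
  return 0;
}
```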
@@ -71,10 +71,10 @@ Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType<T>::MappedType CudaT;
const int32_t grid_size = static_cast<int32_t>(input_dims[0] * input_dims[1]);
LaunchBiasAddKernel<CudaT>(Stream(context), grid_size, static_cast<int32_t>(input_dims[2]),
-reinterpret_cast<const CudaT*>(input->Data<T>()),
-reinterpret_cast<const CudaT*>(bias->Data<T>()),
-reinterpret_cast<const CudaT*>(skip->Data<T>()),
-reinterpret_cast<CudaT*>(output->MutableData<T>()));
+reinterpret_cast<const CudaT*>(input->Data<T>()),
+reinterpret_cast<const CudaT*>(bias->Data<T>()),
+reinterpret_cast<const CudaT*>(skip->Data<T>()),
+reinterpret_cast<CudaT*>(output->MutableData<T>()));

CUDA_RETURN_IF_ERROR(cudaPeekAtLastError());
return Status::OK();
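The BiasAdd hunk above only re-indents the wrapped kernel arguments so they line up under the first argument of the call, which is what clang-format's default `AlignAfterOpenBracket: Align` does when a call spans several lines. A standalone sketch of the same wrapping (the function and its arguments are invented):

```cpp
// Illustrative only: wrapped call arguments align under the first argument.
#include <cstdio>

void LaunchBiasAdd(int grid_size, int block_size, const float* input,
                   const float* bias, float* output) {
  std::printf("grid=%d block=%d %p %p %p\n", grid_size, block_size,
              static_cast<const void*>(input), static_cast<const void*>(bias),
              static_cast<void*>(output));
}

int main() {
  float input = 1.0f, bias = 0.5f, output = 0.0f;
  // Continuation lines start in the column of the first argument.
  LaunchBiasAdd(/*grid_size=*/128, /*block_size=*/256, &input,
                &bias, &output);
  return 0;
}
```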
@@ -19,7 +19,7 @@ class GridSample final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;

private:
-int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic
+int64_t mode_i_;          // 0: bilinear (default), 1: nearest 2: bicubic
int64_t padding_mode_i_;  // 0:'zeros', 1: 'border', 2:'reflection'
int64_t align_corners_;
};
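In the GridSample hunk the only change is the padding before the trailing comment on `mode_i_`, so that consecutive end-of-line comments share a column (`AlignTrailingComments`, plus the two-space minimum before a trailing comment; the exact settings in this repo are assumed). Sketch with an invented struct:

```cpp
// Illustrative only: consecutive trailing comments are padded into one column.
struct GridSampleOptions {
  long mode = 0;            // 0: bilinear (default), 1: nearest, 2: bicubic
  long padding_mode = 0;    // 0: 'zeros', 1: 'border', 2: 'reflection'
  long align_corners = 0;
};

int main() { return GridSampleOptions{}.mode == 0 ? 0 : 1; }
```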
@@ -19,24 +19,24 @@ namespace cuda {
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
x<T>);

-#define CONTRIB_BINARY_ELEMENTWISE_COMPUTE(x, T) \
-template <> \
-Status x<T>::ComputeInternal(OpKernelContext* context) const { \
-BinaryElementwisePreparation prepare; \
-ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
-Impl_##x<typename ToCudaType<T>::MappedType>( \
-Stream(context), \
-prepare.output_rank_or_simple_broadcast, \
-&prepare.lhs_padded_strides, \
+#define CONTRIB_BINARY_ELEMENTWISE_COMPUTE(x, T) \
+template <> \
+Status x<T>::ComputeInternal(OpKernelContext* context) const { \
+BinaryElementwisePreparation prepare; \
+ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
+Impl_##x<typename ToCudaType<T>::MappedType>( \
+Stream(context), \
+prepare.output_rank_or_simple_broadcast, \
+&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.lhs_tensor->Data<T>()), \
-&prepare.rhs_padded_strides, \
+&prepare.rhs_padded_strides, \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.rhs_tensor->Data<T>()), \
-&prepare.fdm_output_strides, \
-prepare.fdm_H, \
-prepare.fdm_C, \
+&prepare.fdm_output_strides, \
+prepare.fdm_H, \
+prepare.fdm_C, \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(prepare.output_tensor->MutableData<T>()), \
-prepare.output_tensor->Shape().Size()); \
-return Status::OK(); \
+prepare.output_tensor->Shape().Size()); \
+return Status::OK(); \
}

#define CONTRIB_BINARY_OP_TYPED(name, ver, T) \
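The long hunk above changes nothing but the column of the `\` line-continuation characters inside the macros; clang-format re-aligns escaped newlines into a consistent column (the `AlignEscapedNewlines` option — whether this repo aligns them left or right is not visible here, so that detail is an assumption). A compact sketch with an invented macro:

```cpp
// Illustrative only: the '\' continuations of a macro are kept in one column.
#include <cstdio>

#define PRINT_TWICE(msg)  \
  do {                    \
    std::puts(msg);       \
    std::puts(msg);       \
  } while (0)

int main() {
  PRINT_TWICE("formatted");
  return 0;
}
```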
@@ -20,7 +20,7 @@ namespace cuda {
#define CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
-cudaStream_t stream, \
+cudaStream_t stream, \
int32_t output_rank_or_simple_broadcast, \
const TArray<int64_t>* lhs_padded_strides, \
const T* lhs_data, \
@@ -12,7 +12,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

-//key
+// key
struct FFTState {
int64_t signal_ndim;
int64_t signal_dims[5];
@@ -22,7 +22,7 @@ struct FFTState {
cudaDataType exec_type;
};

-//value
+// value
struct CufftPlanInfo {
cufftHandle plan;
size_t ws_size_t;
Some files were not shown because too many files changed in this diff.