[CoreML] ML Program more ops (2/N) (#22480)
Adds ML Program support for more operators:
- Cast
- ArgMax
- Gelu
- LayerNorm
- GroupNorm
- InstanceNorm

---------
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
Co-authored-by: Scott McKay <skottmckay@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Parent: c7ecc081ca
Commit: 9daf7664fc
@@ -31,10 +31,10 @@ enum COREMLFlags {
  // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later.
  COREML_FLAG_CREATE_MLPROGRAM = 0x010,

  // Exclude ANE as sometimes this decreases performance
  // https://developer.apple.com/documentation/coreml/mlcomputeunits?language=objc
  // there are four compute units:
  // MLComputeUnitsCPUAndNeuralEngine|MLComputeUnitsCPUAndGPU|MLComputeUnitsCPUOnly|MLComputeUnitsAll
  // different CUs will have different performance and power consumption
  COREML_FLAG_USE_CPU_AND_GPU = 0x020,

  // Keep COREML_FLAG_LAST at the end of the enum definition
  // And assign the last COREMLFlag to it
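For reference, a minimal sketch of how an application could opt in to the ML Program path via this
flag. It assumes the public C API entry point OrtSessionOptionsAppendExecutionProvider_CoreML declared
in coreml_provider_factory.h and the C++ helper Ort::ThrowOnError; the exact include paths depend on
how ONNX Runtime is installed.

    #include "onnxruntime_cxx_api.h"
    #include "coreml_provider_factory.h"  // COREMLFlags

    // Request an ML Program model from the CoreML EP. The flags form a bitmask, so other
    // COREMLFlags values (e.g. COREML_FLAG_USE_CPU_AND_GPU) could be OR'ed in as well.
    Ort::SessionOptions session_options;
    uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));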
@ -40,6 +40,25 @@ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, con
}

namespace {
template <typename T>
void HandlePReluWeight(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger,
                       std::vector<T>& alpha_values) {
  // add slope initializer as alpha weight
  const auto& slope_tensor = *model_builder.GetConstantInitializer(node.InputDefs()[1]->Name());
  Initializer unpacked_tensor(slope_tensor);
  const auto alpha_v = unpacked_tensor.DataAsSpan<T>();

  if (alpha_v.size() == 1) {
    // expand to number of channels
    std::vector<int64_t> x_shape;
    GetShape(*node.InputDefs()[0], x_shape, logger);
    alpha_values.resize(x_shape[x_shape.size() - 3], alpha_v[0]);
  } else {
    alpha_values.assign(alpha_v.begin(), alpha_v.end());
  }
}

Status AddPReluWeight(ModelBuilder& model_builder, const Node& node,
                      const logging::Logger& logger,
                      COREML_SPEC::ActivationPReLU& prelu) {
@ -84,6 +103,7 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.activation
|
||||
std::string_view coreml_op_type;
|
||||
bool add_alpha = false;
|
||||
bool add_gelu_mode = false;
|
||||
if (op_type == "Sigmoid") {
|
||||
coreml_op_type = "sigmoid";
|
||||
} else if (op_type == "Tanh") {
|
||||
|
@ -93,6 +113,12 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
} else if (op_type == "LeakyRelu") {
|
||||
coreml_op_type = "leaky_relu";
|
||||
add_alpha = true;
|
||||
} else if (op_type == "Gelu") {
|
||||
coreml_op_type = "gelu";
|
||||
add_gelu_mode = true;
|
||||
} else if (op_type == "PRelu") {
|
||||
coreml_op_type = "prelu";
|
||||
add_alpha = true;
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"ActivationOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
|
||||
|
@ -102,16 +128,39 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
|
||||
|
||||
if (add_alpha) {
|
||||
NodeAttrHelper helper(node);
|
||||
const auto alpha = helper.Get("alpha", 0.01f);
|
||||
|
||||
auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
|
||||
|
||||
if ("PRelu" == op_type) {
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
std::vector<float> alpha_values;
|
||||
HandlePReluWeight(model_builder, node, logger, alpha_values);
|
||||
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
|
||||
} else {
|
||||
std::vector<MLFloat16> alpha_values;
|
||||
HandlePReluWeight(model_builder, node, logger, alpha_values);
|
||||
AddOperationInput(*op, "alpha", model_builder.AddConstant(op->type(), "alpha", alpha_values));
|
||||
}
|
||||
} else {
|
||||
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
|
||||
NodeAttrHelper helper(node);
|
||||
const auto alpha = helper.Get("alpha", 0.01f);
|
||||
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", alpha));
|
||||
} else {
|
||||
AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", MLFloat16(alpha)));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (add_gelu_mode) {
|
||||
NodeAttrHelper helper(node);
|
||||
std::string approximate = helper.Get("approximate", std::string("none"));
|
||||
if (approximate == "tanh") {
|
||||
approximate = "TANH_APPROXIMATION";
|
||||
} else if (approximate == "none") {
|
||||
approximate = "EXACT";
|
||||
}
|
||||
AddOperationInput(*op, "mode", model_builder.AddScalarConstant(op->type(), "mode", std::string(approximate)));
|
||||
}
|
||||
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0]);
|
||||
|
||||
|
@ -213,17 +262,11 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInp
|
|||
const logging::Logger& logger) const {
|
||||
const auto& op_type = node.OpType();
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (input_params.create_mlprogram) {
|
||||
if (op_type == "PRelu") { // TODO: ML Program supports this so should be easy to enable
|
||||
return false;
|
||||
}
|
||||
} else
|
||||
#endif // (COREML_ENABLE_MLPROGRAM)
|
||||
{
|
||||
if (op_type == "PRelu") {
|
||||
return IsPReluOpSupported(node, input_params, logger);
|
||||
}
|
||||
if (op_type == "Gelu" && !input_params.create_mlprogram) {
|
||||
return false;
|
||||
}
|
||||
if (op_type == "PRelu") {
|
||||
return IsPReluOpSupported(node, input_params, logger);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -245,6 +288,7 @@ void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistration
|
|||
"Relu",
|
||||
"PRelu",
|
||||
"LeakyRelu",
|
||||
"Gelu",
|
||||
};
|
||||
|
||||
op_registrations.builders.push_back(std::make_unique<ActivationOpBuilder>());
@ -3,6 +3,7 @@
|
|||
|
||||
#include "core/providers/coreml/builders/impl/base_op_builder.h"
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#include "core/providers/coreml/builders/impl/builder_utils.h"
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
|
@ -15,6 +16,9 @@ class ArgMaxOpBuilder : public BaseOpBuilder {
|
|||
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
public:
|
||||
bool SupportsMLProgram() const override { return true; }
|
||||
};
|
||||
|
||||
Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
@ -24,41 +28,60 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
const auto& graph_viewer = model_builder.GetGraphViewer();
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
const auto axis = helper.Get("axis", 0);
|
||||
const auto keepdims = helper.Get("keepdims", 1);
|
||||
const int64_t axis = helper.Get("axis", 0);
|
||||
const int64_t keepdims = helper.Get("keepdims", 1);
|
||||
const bool removedim = keepdims != 1;
|
||||
|
||||
auto* coreml_argmax = layer->mutable_argmax();
|
||||
coreml_argmax->set_axis(axis);
|
||||
coreml_argmax->set_removedim(removedim);
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.reduction
|
||||
|
||||
// There are two cases here:
|
||||
// 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input
|
||||
// (We still have this special case here because CoreML model does not have Cast)
|
||||
// 2. Otherwise, we add Argmax layer normally
|
||||
if (node.GetOutputEdgesCount() == 1) {
|
||||
auto it = node.OutputEdgesBegin();
|
||||
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
|
||||
// If Argmax's successive node is a Cast from int64 to int32 output
|
||||
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
|
||||
// so we omit the check here
|
||||
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
|
||||
// Skip the cast's input/argmax's output
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "reduce_argmax");
|
||||
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
|
||||
AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis));
|
||||
AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", bool(keepdims)));
|
||||
|
||||
int32_t output_datatype = ONNX_NAMESPACE::TensorProto_DataType_INT32;
|
||||
// the output of ArgMax must be int32
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0], output_datatype);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
} else
|
||||
#endif // (COREML_ENABLE_MLPROGRAM)
|
||||
{
|
||||
auto* coreml_argmax = layer->mutable_argmax();
|
||||
coreml_argmax->set_axis(axis);
|
||||
coreml_argmax->set_removedim(removedim);
|
||||
|
||||
// There are two cases here:
|
||||
// 1. Special Case (ArgMax-Cast(from int64 to int32)), we fuse the Argmax's output/Cast's input
|
||||
// (We still have this special case here because CoreML model does not have Cast)
|
||||
// 2. Otherwise, we add Argmax layer normally
|
||||
if (node.GetOutputEdgesCount() == 1) {
|
||||
auto it = node.OutputEdgesBegin();
|
||||
const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index());
|
||||
// If Argmax's successive node is a Cast from int64 to int32 output
|
||||
// The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl())
|
||||
// so we omit the check here
|
||||
if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") {
|
||||
// Skip the cast's input/argmax's output
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name();
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
}
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
|
||||
bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node,
|
||||
[[maybe_unused]] const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const {
|
||||
// Attribute `select_last_index` of ArgMax op is not supported
|
||||
NodeAttrHelper helper(node);
|
||||
|
@ -68,6 +91,12 @@ bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (input_params.create_mlprogram) {
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
// If there are multiple downstream nodes and a Cast (to int32) is one of them, it is not supported; exit here.
// Otherwise, for general multiple downstream nodes, it is supported.
@ -16,11 +16,10 @@ namespace coreml {
|
|||
// Once all ops support FP16, we can remove this. Before that, we keep a set of ops to
// filter the supported ones.
static std::set<std::string> Float16Ops = {
    "Add", "Mul", "Sub", "Div", "Pow", "Sqrt", "Reciprocal",
    "Sigmoid", "Tanh", "Relu", "LeakyRelu", "Concat", "GridSample", "GlobalAveragePool",
    "Clip", "DepthToSpace", "Resize", "Slice", "Conv",
    "ConvTranspose", "GlobalMaxPool", "Gemm", "MatMul",
    "AveragePool", "MaxPool", "Reshape", "Split", "Transpose"};
    "Add", "ArgMax", "AveragePool", "BatchNormalization", "Cast", "Clip", "Concat", "Conv", "ConvTranspose",
    "DepthToSpace", "Div", "Gelu", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "GridSample", "GroupNormalization",
    "InstanceNormalization", "LayerNormalization", "LeakyRelu", "MatMul", "MaxPool", "Mul", "PRelu", "Pow",
    "Reciprocal", "Relu", "Reshape", "Resize", "Sigmoid", "Slice", "Split", "Sqrt", "Sub", "Tanh", "Transpose"};
|
||||
|
||||
namespace {
|
||||
// TODO, move this to shared_library
@ -10,6 +10,10 @@
|
|||
#include "core/providers/coreml/shape_utils.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <TargetConditionals.h>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
|
@ -24,6 +28,9 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder {
|
|||
|
||||
// BatchNormalization opset 6- has unsupported attributes
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; }
|
||||
|
||||
public:
|
||||
bool SupportsMLProgram() const override { return true; }
|
||||
};
|
||||
|
||||
void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
|
@ -50,21 +57,46 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
|
|||
const auto eps = helper.Get("epsilon", 1e-5f);
|
||||
const auto channels = scale_tensor.dims()[0];
|
||||
|
||||
auto* coreml_batch_norm = layer->mutable_batchnorm();
|
||||
coreml_batch_norm->set_channels(channels);
|
||||
coreml_batch_norm->set_epsilon(eps);
|
||||
coreml_batch_norm->set_computemeanvar(false);
|
||||
coreml_batch_norm->set_instancenormalization(false);
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.batch_norm
|
||||
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var
|
||||
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "batch_norm");
|
||||
AddOperationInput(*op, "x", input_defs[0]->Name());
|
||||
AddOperationInput(*op, "mean", model_builder.AddConstant(op->type(), input_defs[3]->Name() + "mean", mean_tensor));
|
||||
AddOperationInput(*op, "variance", model_builder.AddConstant(op->type(), input_defs[4]->Name() + "variance", var_tensor));
|
||||
AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name(), scale_tensor));
|
||||
AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name(), bias_tensor));
|
||||
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
MLFloat16 epsilon_fp16(eps);
|
||||
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16));
|
||||
} else {
|
||||
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps));
|
||||
}
|
||||
|
||||
*layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0]);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
} else
|
||||
#endif // (COREML_ENABLE_MLPROGRAM)
|
||||
{
|
||||
auto* coreml_batch_norm = layer->mutable_batchnorm();
|
||||
coreml_batch_norm->set_channels(channels);
|
||||
coreml_batch_norm->set_epsilon(eps);
|
||||
coreml_batch_norm->set_computemeanvar(false);
|
||||
coreml_batch_norm->set_instancenormalization(false);
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_gamma(), scale_tensor)); // scale
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_beta(), bias_tensor)); // B
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_mean(), mean_tensor)); // mean
|
||||
ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_batch_norm->mutable_variance(), var_tensor)); // var
|
||||
|
||||
*layer->mutable_input()->Add() = input_defs[0]->Name();
|
||||
*layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();
|
||||
|
||||
model_builder.AddLayer(std::move(layer));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -119,6 +151,15 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
|
|||
return false;
|
||||
}
|
||||
|
||||
#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
|
||||
// To pass the iOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
|
||||
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {
|
||||
LOGS(logger, VERBOSE) << "float16 input is not supported on the iOS x86_64 simulator"
|
||||
<< " due to CoreML producing invalid output.";
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
@ -4,6 +4,7 @@
|
|||
#include "core/providers/coreml/builders/helper.h"
|
||||
#include "core/providers/coreml/builders/impl/base_op_builder.h"
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#include "core/providers/coreml/builders/impl/builder_utils.h"
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
|
||||
|
@ -18,14 +19,62 @@ class CastOpBuilder : public BaseOpBuilder {
|
|||
|
||||
bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
|
||||
public:
|
||||
bool SupportsMLProgram() const override { return true; }
|
||||
};
|
||||
|
||||
Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */,
|
||||
const Node& /* node */,
|
||||
const logging::Logger& /* logger */) const {
|
||||
// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type.
|
||||
// The ArgMax is fused with the Cast node and produces an int32 output.
|
||||
// Cast node is not provided in CoreML model, so we're skipping adding the Cast node here.
|
||||
Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
|
||||
[[maybe_unused]] const Node& node,
|
||||
[[maybe_unused]] const logging::Logger& logger) const {
|
||||
// This is a special handling case for ArgMax Op, where argmax is followed by a cast to int32 type.
|
||||
// The ArgMax is fused with the Cast node and produces an int32 output.
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.cast
|
||||
|
||||
NodeAttrHelper helper(node);
|
||||
auto cast_to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto::UNDEFINED);
|
||||
std::string to_dtype = "";
|
||||
if (cast_to_type == ONNX_NAMESPACE::TensorProto::INT32 || cast_to_type == ONNX_NAMESPACE::TensorProto::INT64) {
|
||||
to_dtype = "int32";
|
||||
// CoreML doesn't support int64, while ONNX uses int64 for indices as well as data values.
|
||||
// We convert the data inputs/outputs between int64 and int32 when calling onnxruntime::coreml::Model::Predict,
|
||||
// and when adding int64 initializers to the CoreML model.
|
||||
// CoreML operators can only produce int32 and not int64 values.
|
||||
// Due to that there should be no actual int64 values inside the CoreML model and we can infer any
|
||||
// ONNX_NAMESPACE::TensorProto::INT64 values to be int32.
|
||||
cast_to_type = ONNX_NAMESPACE::TensorProto::INT32;
|
||||
} else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT) {
|
||||
to_dtype = "fp32";
|
||||
} else if (cast_to_type == ONNX_NAMESPACE::TensorProto::FLOAT16) {
|
||||
to_dtype = "fp16";
|
||||
} else if (cast_to_type == ONNX_NAMESPACE::TensorProto::BOOL) {
|
||||
to_dtype = "bool";
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported cast type: ", cast_to_type);
|
||||
}
|
||||
|
||||
std::string_view op_type = "cast";
|
||||
auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
if (((input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
|
||||
input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT32) &&
|
||||
to_dtype == "int32") ||
|
||||
cast_to_type == input_dtype) {
|
||||
op_type = "identity";
|
||||
}
|
||||
|
||||
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, op_type);
|
||||
AddOperationInput(*op, "x", node.InputDefs()[0]->Name());
|
||||
if (op_type == "cast") {
|
||||
AddOperationInput(*op, "dtype", model_builder.AddScalarConstant(op->type(), "dtype", std::string(to_dtype)));
|
||||
}
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0], cast_to_type);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
}
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -36,6 +85,10 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
|
|||
return false;
|
||||
}
|
||||
|
||||
if (input_params.create_mlprogram) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto& prec_node = node.InputEdgesBegin()->GetNode();
|
||||
|
||||
/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
|
||||
|
@ -67,14 +120,39 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara
|
|||
return true;
|
||||
}
|
||||
|
||||
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
|
||||
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const {
|
||||
// We only check the type of input 0
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& output = *node.OutputDefs()[0];
|
||||
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
int32_t input_type, output_type;
|
||||
if (!GetType(input, input_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
if (!GetType(output, output_type, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
if (input_params.create_mlprogram) {
|
||||
if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
|
||||
input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
|
||||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
|
||||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) &&
|
||||
(output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
|
||||
output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
|
||||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
|
||||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
|
||||
return true;
|
||||
} else {
|
||||
LOGS(logger, VERBOSE) << "[" << node.OpType()
|
||||
<< "] Input type: [" << input_type
|
||||
<< "] is not supported.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// only support int64 coming from ArgMax (check for ArgMax is done in IsOpSupportedImpl())
|
||||
if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
@ -0,0 +1,277 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/coreml/builders/helper.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/providers/coreml/builders/impl/base_op_builder.h"
|
||||
#include "core/providers/coreml/builders/impl/builder_utils.h"
|
||||
#include "core/providers/coreml/builders/model_builder.h"
|
||||
#include "core/providers/coreml/builders/op_builder_factory.h"
|
||||
#include "core/providers/coreml/shape_utils.h"
|
||||
#include "core/providers/shared/utils/utils.h"
|
||||
#include <numeric>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace coreml {
|
||||
|
||||
class NormalizationOpBuilder : public BaseOpBuilder {
|
||||
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
|
||||
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override;
|
||||
Status AddGroupNormToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const;
|
||||
|
||||
bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const override;
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 1; }
|
||||
|
||||
public:
|
||||
bool SupportsMLProgram() const override { return true; }
|
||||
};
|
||||
|
||||
void NormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
// skip everything except input0 for Normalization
|
||||
const auto& input_defs = node.InputDefs();
|
||||
model_builder.AddInitializerToSkip(input_defs[1]->Name()); // scale
|
||||
if (input_defs.size() > 2) {
|
||||
model_builder.AddInitializerToSkip(input_defs[2]->Name()); // B
|
||||
}
|
||||
}
|
||||
|
||||
Status NormalizationOpBuilder::AddToModelBuilderImpl(
|
||||
[[maybe_unused]] ModelBuilder& model_builder,
|
||||
[[maybe_unused]] const Node& node,
|
||||
[[maybe_unused]] const logging::Logger& logger) const {
|
||||
if (node.OpType() == "GroupNormalization") {
|
||||
return AddGroupNormToModelBuilderImpl(model_builder, node, logger);
|
||||
}
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
const auto& input_defs = node.InputDefs();
|
||||
NodeAttrHelper helper(node);
|
||||
const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name());
|
||||
|
||||
const auto eps = helper.Get("epsilon", 1e-5f);
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
// GetShape will never fail as we have already verified the input shape in IsOpSupportedImpl
|
||||
GetShape(*input_defs[0], input_shape, logger);
|
||||
|
||||
const auto rank = input_shape.size();
|
||||
auto axis = static_cast<size_t>(HandleNegativeAxis(helper.Get("axis", 1), rank));
|
||||
|
||||
std::vector<int64_t> axes(rank - axis);
|
||||
std::iota(axes.begin(), axes.end(), axis);
|
||||
auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
std::string_view layer_input_name_x = node.InputDefs()[0]->Name();
|
||||
std::string_view op_name = (node.OpType() == "InstanceNormalization") ? "instance_norm" : "layer_norm";
|
||||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm
|
||||
|
||||
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, op_name);
|
||||
AddOperationInput(*op, "x", layer_input_name_x);
|
||||
if (op_name == "layer_norm") {
|
||||
AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), input_defs[0]->Name() + "axes", axes));
|
||||
}
|
||||
AddOperationInput(*op, "gamma", model_builder.AddConstant(op->type(), input_defs[1]->Name() + "gamma", scale_tensor));
|
||||
if (input_defs.size() > 2) {
|
||||
const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name());
|
||||
AddOperationInput(*op, "beta", model_builder.AddConstant(op->type(), input_defs[2]->Name() + "beta", bias_tensor));
|
||||
}
|
||||
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
MLFloat16 epsilon_fp16(eps);
|
||||
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", epsilon_fp16));
|
||||
} else {
|
||||
AddOperationInput(*op, "epsilon", model_builder.AddScalarConstant(op->type(), "epsilon", eps));
|
||||
}
|
||||
|
||||
AddOperationOutput(*op, *node.OutputDefs()[0]);
|
||||
model_builder.AddOperation(std::move(op));
|
||||
}
|
||||
#endif // (COREML_ENABLE_MLPROGRAM)
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status NormalizationOpBuilder::AddGroupNormToModelBuilderImpl(
|
||||
[[maybe_unused]] ModelBuilder& model_builder,
|
||||
[[maybe_unused]] const Node& node,
|
||||
[[maybe_unused]] const logging::Logger& logger) const {
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
const auto& input_defs = node.InputDefs();
|
||||
NodeAttrHelper helper(node);
|
||||
// CoreML doesn't support GroupNorm directly yet.
// We decompose GroupNorm into sub-ops and leverage LayerNorm to implement it:
// groupnorm --> reshape [b, num_groups, c // num_groups, h, w] --> layer_norm --> reshape [b, c, h, w] -> mul(scale) -> add(bias)

// scale and bias are required for GroupNorm by the ONNX spec
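// Illustrative shape walk-through of the decomposition above (a made-up example, assuming
// X has shape [2, 6, 4, 4] and num_groups = 3):
//   reshape (pre):  [2, 3, 2, 4, 4]   // num_groups inserted at dim 1; dim 2 becomes c / num_groups
//   layer_norm:     axes = [2, 3, 4]  // normalize over everything after the group dimension
//   reshape (post): [2, 6, 4, 4]
//   mul / add:      scale and bias broadcast with shape [1, 6, 1, 1]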
|
||||
const auto& scale_tensor = *model_builder.GetConstantInitializer(input_defs[1]->Name());
|
||||
const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name());
|
||||
|
||||
const auto eps = helper.Get("epsilon", 1e-5f);
|
||||
int64_t num_groups = helper.Get("num_groups", 1); // GroupNorm
|
||||
|
||||
std::vector<int64_t> input_shape;
|
||||
GetShape(*input_defs[0], input_shape, logger);
|
||||
|
||||
const auto input_size = input_shape.size();
|
||||
int64_t axis = 2;
|
||||
std::vector<int64_t> axes(input_size + 1 - axis); // Group add one more dim
|
||||
std::iota(axes.begin(), axes.end(), axis);
|
||||
auto input_dtype = node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
int64_t channel_dims = input_shape[1];
|
||||
|
||||
if (model_builder.CreateMLProgram()) {
|
||||
using namespace CoreML::Specification::MILSpec;
|
||||
std::string_view layer_input_name_x = node.InputDefs()[0]->Name();
|
||||
const int32_t elem_type = static_cast<int32_t>(input_dtype);
|
||||
|
||||
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.normalization.layer_norm
|
||||
// https://github.com/apple/coremltools/blob/9827d424b3c5b5fbb6ddc8891a000d87a188c84f/coremltools/converters/mil/frontend/torch/ops.py#L1354
|
||||
// reshape to [b, num_groups, c // (num_groups), h, w]
|
||||
auto reshape1 = model_builder.CreateOperation(node, "reshape", "pre");
|
||||
std::vector<int64_t> shape1 = input_shape;
|
||||
shape1.insert(shape1.begin() + 1, num_groups);
|
||||
shape1[2] = input_shape[1] / num_groups;
|
||||
std::vector<int64_t> shape_scale_bias(input_shape.size(), 1);
|
||||
shape_scale_bias[1] = channel_dims;
|
||||
AddOperationInput(*reshape1, "x", node.InputDefs()[0]->Name());
|
||||
AddOperationInput(*reshape1, "shape", model_builder.AddConstant(reshape1->type(), "shape1", shape1));
|
||||
layer_input_name_x = model_builder.GetUniqueName(node, "ln_reshape1_");
|
||||
AddIntermediateOperationOutput(*reshape1, layer_input_name_x, elem_type, shape1);
|
||||
|
||||
std::unique_ptr<Operation> layer_norm = model_builder.CreateOperation(node, "layer_norm");
|
||||
AddOperationInput(*layer_norm, "x", layer_input_name_x);
|
||||
AddOperationInput(*layer_norm, "axes", model_builder.AddConstant(layer_norm->type(), "axes", axes));
|
||||
|
||||
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
MLFloat16 epsilon_fp16(eps);
|
||||
AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", epsilon_fp16));
|
||||
} else {
|
||||
AddOperationInput(*layer_norm, "epsilon", model_builder.AddScalarConstant(layer_norm->type(), "epsilon", eps));
|
||||
}
|
||||
|
||||
const auto& ln_output_name = model_builder.GetUniqueName(node, "ln_output_");
|
||||
AddIntermediateOperationOutput(*layer_norm, ln_output_name, elem_type, shape1);
|
||||
|
||||
auto reshape2 = model_builder.CreateOperation(node, "reshape", "post");
|
||||
AddOperationInput(*reshape2, "x", ln_output_name);
|
||||
AddOperationInput(*reshape2, "shape", model_builder.AddConstant(reshape2->type(), "shape2", input_shape));
|
||||
|
||||
const auto& reshape2_output_name = model_builder.GetUniqueName(node, "gn_reshape_output_");
|
||||
AddIntermediateOperationOutput(*reshape2, reshape2_output_name, elem_type, input_shape);
|
||||
|
||||
auto mul = model_builder.CreateOperation(node, "mul", "post_mul");
|
||||
AddOperationInput(*mul, "x", reshape2_output_name);
|
||||
AddOperationInput(*mul, "y", model_builder.AddConstant(mul->type(), "mul1", scale_tensor, shape_scale_bias));
|
||||
const auto& mul_output_name = model_builder.GetUniqueName(node, "mul_output_");
|
||||
AddIntermediateOperationOutput(*mul, mul_output_name, elem_type, input_shape);
|
||||
|
||||
auto add = model_builder.CreateOperation(node, "add", "post_add");
|
||||
AddOperationInput(*add, "x", mul_output_name);
|
||||
AddOperationInput(*add, "y", model_builder.AddConstant(add->type(), "add1", bias_tensor, shape_scale_bias));
|
||||
AddOperationOutput(*add, *node.OutputDefs()[0]);
|
||||
|
||||
model_builder.AddOperation(std::move(reshape1));
|
||||
model_builder.AddOperation(std::move(layer_norm));
|
||||
model_builder.AddOperation(std::move(reshape2));
|
||||
model_builder.AddOperation(std::move(mul));
|
||||
model_builder.AddOperation(std::move(add));
|
||||
}
|
||||
#endif // (COREML_ENABLE_MLPROGRAM)
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool NormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const {
|
||||
// LayerNormalization may have three outputs in training mode, but we only support inference mode.
// InstanceNormalization and GroupNormalization only have one output, so this check always passes for them.
|
||||
if (node.OutputDefs().size() != 1) {
|
||||
LOGS(logger, VERBOSE) << "Your onnx model (with LayerNormalization) may be in training mode,"
|
||||
<< " please export it for inferencing.";
|
||||
return false;
|
||||
}
|
||||
const auto& input_defs = node.InputDefs();
|
||||
std::vector<int64_t> input_shape;
|
||||
if (!GetShape(*input_defs[0], input_shape, logger)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// GroupNormalization and LayerNormalization have the attribute "stash_type", while InstanceNormalization doesn't.
// It is the type of Mean and InvStdDev and also specifies stage one's computation precision:
// if stash_type is 1, the operator casts all input variables to 32-bit float,
// performs the computation, and finally casts Normalized back to the original type of X.
// CoreML doesn't have a similar attribute to stash_type, so for now we only support the default value.
|
||||
if (node.OpType() != "InstanceNormalization") {
|
||||
NodeAttrHelper helper(node);
|
||||
const auto stash_type = helper.Get("stash_type", 1);
|
||||
if (stash_type != 1) {
|
||||
LOGS(logger, VERBOSE) << "stash_type != 1 is not supported";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& scale_name = input_defs[1]->Name();
|
||||
const auto* scale_tensor = input_params.graph_viewer.GetConstantInitializer(scale_name);
|
||||
if (!scale_tensor) {
|
||||
LOGS(logger, VERBOSE) << "Scale must be a constant initializer";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_defs.size() > 2) {
|
||||
const auto& b_name = input_defs[2]->Name();
|
||||
const auto& b_tensor = input_params.graph_viewer.GetConstantInitializer(b_name);
|
||||
if (!b_tensor) {
|
||||
LOGS(logger, VERBOSE) << "Bias must be a constant initializer";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
|
||||
const logging::Logger& logger) const {
|
||||
if (!input_params.create_mlprogram) {
|
||||
return false;
|
||||
}
|
||||
// We only check the type of input 0,1,2
|
||||
const auto& input_0 = *node.InputDefs()[0];
|
||||
const auto& input_1 = *node.InputDefs()[1];
|
||||
const auto& input_2 = node.InputDefs().size() > 2 ? *node.InputDefs()[2] : input_0;
|
||||
int32_t input_type_0, input_type_1, input_type_2;
|
||||
if (!GetType(input_0, input_type_0, logger)) {
|
||||
return false;
|
||||
}
|
||||
if (!GetType(input_1, input_type_1, logger)) {
|
||||
return false;
|
||||
}
|
||||
if (!GetType(input_2, input_type_2, logger)) {
|
||||
return false;
|
||||
}
|
||||
if (input_type_0 != input_type_1 || input_type_0 != input_type_2) {
|
||||
LOGS(logger, VERBOSE) << "Input types of LayerNorm must be the same";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
|
||||
input_type_0 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
LOGS(logger, VERBOSE) << "Input types of LayerNorm must be float or float16";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<NormalizationOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
}
|
||||
|
||||
} // namespace coreml
|
||||
} // namespace onnxruntime
@ -14,6 +14,7 @@
|
|||
#include "core/providers/coreml/coreml_provider_factory.h"
|
||||
#include "core/providers/coreml/model/host_utils.h"
|
||||
#include "core/providers/coreml/shape_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
// includes from coremltools-src in _deps
|
||||
|
@ -1003,6 +1004,48 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
|
|||
return model->LoadModel(); // load using CoreML API, including compilation
|
||||
}
|
||||
|
||||
#if defined(COREML_ENABLE_MLPROGRAM)
|
||||
std::string_view ModelBuilder::AddConstant(std::string_view op_type, std::string_view value_type,
|
||||
const ONNX_NAMESPACE::TensorProto& tensor,
|
||||
std::optional<gsl::span<const int64_t>> shape) {
|
||||
const auto data_type = tensor.data_type();
|
||||
Initializer unpacked_tensor(tensor);
|
||||
std::string_view ret;
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<float>(), shape ? shape : tensor.dims());
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<MLFloat16>(), shape ? shape : tensor.dims());
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
|
||||
ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<int64_t>(), shape ? shape : tensor.dims());
|
||||
break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_INT32:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<int32_t>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<double>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_INT8:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<int8_t>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<uint8_t>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<uint32_t>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
// case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
|
||||
// ret = AddConstant(op_type, value_type, unpacked_tensor.DataAsSpan<bool>(), shape?shape:tensor.dims());
|
||||
// break;
|
||||
default:
|
||||
ORT_THROW("AddConstant: Unsupported data type: ", data_type);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
// static
|
||||
Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger,
|
||||
int32_t coreml_version, uint32_t coreml_flags,
@ -129,6 +129,12 @@ class ModelBuilder {
|
|||
return AddConstant(op_type, value_type, gsl::span<const T>(value), shape);
|
||||
}
|
||||
|
||||
// helper to convert an initializer to a constant
// by default the shape is inferred from tensor.dims(), but it can be provided to override that if needed
|
||||
std::string_view AddConstant(std::string_view op_type, std::string_view value_type,
|
||||
const ONNX_NAMESPACE::TensorProto& tensor,
|
||||
std::optional<gsl::span<const int64_t>> shape = std::nullopt);
|
||||
|
||||
/// <summary>
|
||||
/// Add a scalar value as a 'const' operation. See AddConstant for details.
|
||||
/// </summary>
@ -21,6 +21,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateActivationOpBuilder("Relu", op_registrations);
|
||||
CreateActivationOpBuilder("PRelu", op_registrations);
|
||||
CreateActivationOpBuilder("LeakyRelu", op_registrations);
|
||||
CreateActivationOpBuilder("Gelu", op_registrations);
|
||||
|
||||
// Unary ops
|
||||
CreateUnaryOpBuilder("Reciprocal", op_registrations);
|
||||
|
@ -43,8 +44,13 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
|
|||
CreateReductionOpBuilder("ReduceMean", op_registrations);
|
||||
CreateReductionOpBuilder("ReduceSum", op_registrations);
|
||||
|
||||
CreateArgMaxOpBuilder("ArgMax", op_registrations);
|
||||
// Normalization ops
|
||||
CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations);
|
||||
CreateNormalizationOpBuilder("GroupNormalization", op_registrations);
|
||||
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
|
||||
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
|
||||
|
||||
CreateArgMaxOpBuilder("ArgMax", op_registrations);
|
||||
CreateCastOpBuilder("Cast", op_registrations);
|
||||
CreateClipOpBuilder("Clip", op_registrations);
|
||||
CreateConcatOpBuilder("Concat", op_registrations);
@ -19,6 +19,7 @@ const std::unordered_map<std::string, const IOpBuilder*>& GetOpBuilders();
|
|||
void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateArgMaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateBatchNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
|
||||
void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
@ -26,6 +26,8 @@
|
|||
// - iOS 16 ops
|
||||
// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7)
|
||||
// - iOS 17 ops
|
||||
// 9 : iOS 18, macOS 15, tvOS 18, watchOS 11 (Core ML 8)
|
||||
// - iOS 18 ops
|
||||
//
|
||||
// **NOTE** We use the Core ML version not the spec version.
|
||||
//
|
||||
|
@ -39,6 +41,7 @@
|
|||
#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15))
|
||||
#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16))
|
||||
#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17))
|
||||
#define API_AVAILABLE_COREML8 API_AVAILABLE(macos(15), ios(18))
|
||||
|
||||
// @available is used in implementation code
|
||||
// Base required OS to run CoreML Specification Version 4 (Core ML 3)
|
||||
|
@ -47,6 +50,7 @@
|
|||
#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *)
|
||||
#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *)
|
||||
#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *)
|
||||
#define HAS_COREML8_OR_LATER @available(macOS 15, iOS 18, *)
|
||||
|
||||
#endif
@ -16,6 +16,8 @@ bool HasRequiredBaseOS() {
|
|||
}
|
||||
|
||||
int32_t CoreMLVersion() {
|
||||
if (HAS_COREML8_OR_LATER)
|
||||
return 8;
|
||||
if (HAS_COREML7_OR_LATER)
|
||||
return 7;
|
||||
if (HAS_COREML6_OR_LATER)
|
@ -194,6 +194,24 @@ inline void CheckTensor(const Tensor& expected_tensor, const Tensor& output_tens
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
std::vector<T> GetTypedArray(std::vector<float> inputs) {
  static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value ||
                    std::is_same<T, MLFloat16>::value || std::is_integral_v<T>,
                "Only float, double, MLFloat16, and integral types are supported.");
  if constexpr (std::is_same<T, float>::value) {
    return inputs;
  } else if constexpr (std::is_integral_v<T> || std::is_same<T, double>::value) {
    std::vector<T> result(inputs.size());
    for (size_t i = 0; i < inputs.size(); i++) {
      result[i] = static_cast<T>(inputs[i]);
    }
    return result;
  } else {
    return ToFloat16(inputs);
  }
}
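// Example usage (an illustrative sketch; it mirrors how the LayerNorm tests in this change call it):
//   test.AddInput<float>("gamma", {3}, GetTypedArray<float>({1.0f, 1.0f, 1.0f}));
//   test.AddInput<MLFloat16>("gamma", {3}, GetTypedArray<MLFloat16>({1.0f, 1.0f, 1.0f}), /*is_initializer*/ true);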
|
||||
|
||||
class ParallelRandomValueGenerator {
|
||||
public:
|
||||
using RandomEngine = std::default_random_engine;
@ -162,7 +162,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput_Initializers) {
|
|||
// TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider});
|
||||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Bias) {
|
||||
|
@ -211,20 +211,31 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) {
|
|||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) {
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
auto run_test = [](bool is_initializer) {
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
||||
std::vector<int64_t> dims{1, 3, 2};
|
||||
test.AddInput<MLFloat16>("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}));
|
||||
test.AddInput<MLFloat16>("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}));
|
||||
test.AddInput<MLFloat16>("bias", {2}, ToFloat16({0.6435f, -0.3964f}));
|
||||
test.AddOutput<MLFloat16>("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}));
|
||||
// TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider});
|
||||
std::vector<int64_t> dims{1, 3, 2};
|
||||
test.AddInput<MLFloat16>("x", dims, ToFloat16({1.2416f, 0.946123f, 13.1685f, 0.36423f, 21.145f, 0.03941f}));
|
||||
test.AddInput<MLFloat16>("gamma", {2}, ToFloat16({-0.6953f, 5.1824f}), is_initializer);
|
||||
test.AddInput<MLFloat16>("bias", {2}, ToFloat16({0.6435f, -0.3964f}), is_initializer);
|
||||
test.AddOutput<MLFloat16>("output", dims, ToFloat16({-0.0516f, -5.5776f, -0.0518f, -5.5788f, -0.0518f, -5.5788f}));
|
||||
// TRT, DNNL, OpenVINO and NNAPI don't support this combination of datatypes
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider});
|
||||
};
|
||||
run_test(false);
|
||||
run_test(true);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class LayerNormTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
using LayerNormTestTypes = ::testing::Types<float, MLFloat16>;
|
||||
TYPED_TEST_SUITE(LayerNormTest, LayerNormTestTypes);
|
||||
|
||||
TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializers) {
|
||||
OpTester test("LayerNormalization");
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
@ -237,19 +248,41 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput_Initializer
|
|||
// TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider});
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider});
|
||||
}
|
||||
|
||||
// LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check.
|
||||
TEST(LayerNormTest, LayerNorm17_float) {
|
||||
OpTester test("LayerNormalization", 17);
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
TYPED_TEST(LayerNormTest, LayerNorm17_opset) {
|
||||
auto run_test = [](bool is_initializer) {
|
||||
OpTester test("LayerNormalization", 17);
|
||||
test.AddAttribute<float>("epsilon", 1e-05f);
|
||||
|
||||
std::vector<int64_t> dims{1, 2, 3};
|
||||
test.AddInput<float>("x", dims, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
|
||||
test.AddInput<float>("gamma", {3}, {1.0f, 1.0f, 1.0f});
|
||||
test.AddOutput<float>("output", dims, {-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f});
|
||||
test.Run();
|
||||
std::vector<int64_t> dims{1, 2, 3};
|
||||
test.AddInput<TypeParam>("x", dims, GetTypedArray<TypeParam>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
|
||||
test.AddInput<TypeParam>("gamma", {3}, GetTypedArray<TypeParam>({1.0f, 1.0f, 1.0f}), is_initializer);
|
||||
test.AddOutput<TypeParam>("output", dims, GetTypedArray<TypeParam>({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f}));
|
||||
if (std::is_same<TypeParam, MLFloat16>::value) {
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
|
||||
// coreml EP requires weight and bias to be initializers
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
|
||||
{kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider,
|
||||
kNnapiExecutionProvider, kQnnExecutionProvider},
|
||||
nullptr, &execution_providers);
|
||||
} else {
|
||||
test.Run();
|
||||
}
|
||||
};
|
||||
// Execution provider entry invalid.
|
||||
// when other EPs support layer-norm fp16, this test should be updated to include them.
|
||||
if (std::is_same<TypeParam, MLFloat16>::value) {
|
||||
#if !defined(COREML_ENABLE_MLPROGRAM)
|
||||
return;
|
||||
#endif
|
||||
}
|
||||
|
||||
run_test(false);
|
||||
run_test(true);
|
||||
}
|
||||
|
||||
TEST(LayerNormTest, LayerNorm17_double) {
@ -127,6 +127,10 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) {
|
|||
MakeCoreMLExecutionProvider(),
|
||||
feeds,
|
||||
verification_params);
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM),
|
||||
feeds,
|
||||
verification_params);
|
||||
#else
|
||||
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All);
|
||||
#endif
|
||||
|
@ -164,6 +168,11 @@ TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) {
|
|||
MakeCoreMLExecutionProvider(),
|
||||
feeds,
|
||||
verification_params);
|
||||
|
||||
RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(),
|
||||
MakeCoreMLExecutionProvider(COREML_FLAG_CREATE_MLPROGRAM),
|
||||
feeds,
|
||||
verification_params);
|
||||
#else
|
||||
TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some);
|
||||
#endif
@ -105,7 +105,12 @@ class ActivationOpTest : public ::testing::Test {
|
|||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_real_distribution<float> dist(low, high);
|
||||
#ifdef COREML_ENABLE_MLPROGRAM
|
||||
// please check onnxruntime/onnxruntime/core/providers/coreml/builders/helper.cc:81
|
||||
std::vector<std::size_t> batch_size_list = {1, 2, 4, 9, 100};
|
||||
#else
|
||||
std::vector<std::size_t> batch_size_list = {1, 2, 4, 9, 100000};
|
||||
#endif
|
||||
for (auto batch_size : batch_size_list) {
|
||||
std::vector<float> vec(batch_size);
|
||||
for (size_t i = 0; i != batch_size; ++i) {
@@ -704,7 +704,7 @@ TEST(BatchNormTest, NonSpatial_Complicated) {
}

// Only CUDA and ROCm kernels have float 16 support
#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)
TEST(BatchNormTest, BatchNorm2d_fp16) {
vector<float> X{-0.91221f, -0.283559f, 0.937637f, 2.09818f, -0.100199f, -0.608113f, 0.444562f, -1.07505f, 0.940591f,
-0.922262f, 0.0931303f, 0.69611f, 1.55187f, 0.159808f, 0.914874f, -1.24856f, -1.98928f, -0.331621f,

@@ -765,9 +765,6 @@ TEST(BatchNormTest, BatchNorm2d_fp16) {
-0.0989828f, -0.160014f, 0.362077f, 0.0649763f, -0.371465f, 0.727401f, 0.0320011f};
float epsilon = 1e-05f;

OpTester test("BatchNormalization");
test.AddAttribute("epsilon", epsilon);

vector<int64_t> input_shape{2, 3, 6, 6};
int input_size = 2 * 3 * 6 * 6;

@@ -785,13 +782,20 @@ TEST(BatchNormTest, BatchNorm2d_fp16) {
ConvertFloatToMLFloat16(var.data(), f_var.data(), 3);
ConvertFloatToMLFloat16(expected_output.data(), f_output.data(), input_size);

test.AddInput<MLFloat16>("X", input_shape, f_X);
test.AddInput<MLFloat16>("scale", {3}, f_scale);
test.AddInput<MLFloat16>("B", {3}, f_B);
test.AddInput<MLFloat16>("mean", {3}, f_mean);
test.AddInput<MLFloat16>("var", {3}, f_var);
test.AddOutput<MLFloat16>("output", input_shape, f_output);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
auto run_test = [&](bool is_initializer) {
OpTester test("BatchNormalization");
test.AddAttribute("epsilon", epsilon);
test.AddInput<MLFloat16>("X", input_shape, f_X);
test.AddInput<MLFloat16>("scale", {3}, f_scale, is_initializer);
test.AddInput<MLFloat16>("B", {3}, f_B, is_initializer);
test.AddInput<MLFloat16>("mean", {3}, f_mean, is_initializer);
test.AddInput<MLFloat16>("var", {3}, f_var, is_initializer);
test.AddOutput<MLFloat16>("output", input_shape, f_output, is_initializer);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
};
run_test(false);
// coreml EP requires initializer
run_test(true);
}
#endif

@@ -4,6 +4,7 @@
#include "core/providers/xnnpack/xnnpack_init.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "default_providers.h"

using namespace std;

@@ -130,17 +131,6 @@ TEST(ConvTransposeTest, ConvTranspose_1D) {
TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
}

template <typename T>
static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
if constexpr (std::is_same<T, float>::value) {
return inputs;
} else {
std::vector<T> inputs_fp16(inputs.size());
ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
return inputs_fp16;
}
}

TYPED_TEST(ConvTransposeTest, ConvTranspose_2D_outputpadding_strides2) {
ConvTransposeOpAttributes attrs = {
vector<int64_t>{3, 3},  // kernel_shape

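The hunk above deletes this file's private copy of `GetTypedArray`, and matching hunks further down do the same for the Concat, Slice, Split and Transpose tests; each of those files instead gains `#include "test/common/tensor_op_test_utils.h"`, so the helper is presumably consolidated there. A minimal sketch of what that shared helper looks like, pieced together from the copies removed in this commit (the exact signature in tensor_op_test_utils.h may differ):

```cpp
// Sketch only: widen a float reference vector to the type under test.
// Assumes MLFloat16 and ConvertFloatToMLFloat16 from the ORT test utilities.
template <typename T>
static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
  if constexpr (std::is_same<T, float>::value) {
    return inputs;  // already float, pass through
  } else if constexpr (std::is_integral_v<T>) {
    std::vector<T> outputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
      outputs[i] = static_cast<T>(inputs[i]);  // integer test data, e.g. Slice indices
    }
    return outputs;
  } else {
    std::vector<T> outputs(inputs.size());
    ConvertFloatToMLFloat16(inputs.data(), outputs.data(), inputs.size());  // fp32 -> fp16
    return outputs;
  }
}
```

Tests written against a type list such as `::testing::Types<float, MLFloat16>` can then call `GetTypedArray<TypeParam>(...)` for both inputs and expected outputs, as the new GroupNormalization tests below do.
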
@@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/util/include/default_providers.h"

#ifdef COREML_ENABLE_MLPROGRAM
using namespace std;
namespace onnxruntime {
namespace test {

template <typename T>
class GroupNormalizationOpTest : public ::testing::Test {
};
using GroupNormalizationOpTestTypes = ::testing::Types<float, MLFloat16>;
TYPED_TEST_SUITE(GroupNormalizationOpTest, GroupNormalizationOpTestTypes);

// GroupSize = channel_dims to simulate InstanceNorm
// Disable TensorRT on some of the tests because its parser doesn't support weight as input
TYPED_TEST(GroupNormalizationOpTest, Equivalent_InstanceNorm_G_C) {
OpTester test("GroupNormalization", 18);
test.AddAttribute("epsilon", 0.3F);
test.AddAttribute("num_groups", int64_t(3));

vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
7.9195533F, 7.638727F, 8.065445F, 3.8082376F,

2.3667817F, 2.8248506F, 3.7754705F, 5.861325F,
5.058735F, 3.2787242F, 3.6843839F, 9.755121F,
2.7902672F, 7.3974323F, 8.283609F, 8.488337F};
vector<int64_t> input_dims = {2, 3, 4};
test.AddInput<TypeParam>("X", input_dims, GetTypedArray<TypeParam>(input));

vector<float> scale = {1.F, 1.F, 1.F};
vector<int64_t> scale_dims = {3};
test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), true);

vector<float> B = {0.F, 0.F, 0.F};
vector<int64_t> B_dims = {3};
test.AddInput<TypeParam>("bias", B_dims, GetTypedArray<TypeParam>(B), true);

// expected output is calculated using torch.nn.GroupNorm(3, 3, eps=0.3)
vector<float> expected_output = {-0.56495477f, 1.48930046f, -1.13334329f, 0.20899761f,
1.46688162f, -0.98600774f, -0.79911913f, 0.31824524f,
0.57370438f, 0.42193634f, 0.6525492f, -1.64818992f,

-0.92380346f, -0.60808484f, 0.04711878f, 1.48476953f,
-0.14644464f, -0.82262872f, -0.66852817f, 1.63760153f,
-1.65898662f, 0.27618144f, 0.64840618f, 0.734399f};

test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
// coreml EP requires weight and bias to be initializers
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// GroupSize = 1 to simulate LayerNorm, (LayerNorm)
// expected output is calculated using torch.nn.GroupNorm(1, 3, eps=1e-5f)
TYPED_TEST(GroupNormalizationOpTest, Equivalent_LayerNorm_G_1) {
auto run_test = [](bool is_initializer) {
OpTester test("GroupNormalization", 18);
test.AddAttribute<float>("epsilon", 1e-5f);
test.AddAttribute("num_groups", int64_t(1));

std::vector<int64_t> dims{1, 2, 3};
test.AddInput<TypeParam>("x", dims, GetTypedArray<TypeParam>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}));
test.AddInput<TypeParam>("scale", {2}, GetTypedArray<TypeParam>({1.0f, 1.0f}), is_initializer);
test.AddInput<TypeParam>("bias", {2}, GetTypedArray<TypeParam>({2.0f, 1.0f}), is_initializer);
test.AddOutput<TypeParam>("output", dims, GetTypedArray<TypeParam>({0.5361f, 1.1216f, 1.7072f, 1.2928f, 1.8783f, 2.4638f}));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
// coreml EP requires weight and bias to be initializers
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
};

run_test(true);
}

// expected output is calculated using torch.nn.GroupNorm(2, 6, eps=0.3)
TYPED_TEST(GroupNormalizationOpTest, GroupSize_N) {
OpTester test("GroupNormalization", 18);
test.AddAttribute("epsilon", 0.3F);
test.AddAttribute("num_groups", int64_t(2));

vector<float> input = {-1.1258f, -1.1524f, -0.2506f, -0.4339f,
0.8487f, 0.6920f, -0.3160f, -2.1152f,
0.3223f, -1.2633f, 0.3500f, 0.3081f,
0.1198f, 1.2377f, 1.1168f, -0.2473f,
-1.3527f, -1.6959f, 0.5667f, 0.7935f,
0.5988f, -1.5551f, -0.3414f, 1.8530f,

0.7502f, -0.5855f, -0.1734f, 0.1835f,
1.3894f, 1.5863f, 0.9463f, -0.8437f,
-0.6136f, 0.0316f, -0.4927f, 0.2484f,
0.4397f, 0.1124f, 0.6408f, 0.4412f,
-0.1023f, 0.7924f, -0.2897f, 0.0525f,
0.5229f, 2.3022f, -1.4689f, -1.5867f};
vector<int64_t> input_dims = {2, 6, 4};
test.AddInput<TypeParam>("X", input_dims, GetTypedArray<TypeParam>(input));

vector<float> scale = {1.F, 1.F, 1.F, 1.F, 1.F, 1.F};
vector<int64_t> scale_dims = {6};
test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), true);

vector<float> B = {.0F, .0F, .0F, .0F, .0F, .0F};
vector<int64_t> B_dims = {6};
test.AddInput<TypeParam>("bias", B_dims, GetTypedArray<TypeParam>(B), true);

vector<float> expected_output = {
-0.7590f, -0.7848f, 0.0914f, -0.0867f,
1.1595f, 1.0073f, 0.0278f, -1.7203f,
0.6480f, -0.8926f, 0.6749f, 0.6343f,
0.0232f, 0.9274f, 0.8296f, -0.2738f,
-1.1679f, -1.4456f, 0.3846f, 0.5681f,
0.4107f, -1.3317f, -0.3499f, 1.4252f,

0.5772f, -0.8298f, -0.3957f, -0.0198f,
1.2505f, 1.4580f, 0.7838f, -1.1017f,
-0.8594f, -0.1798f, -0.7320f, 0.0486f,
0.2541f, -0.0377f, 0.4334f, 0.2554f,
-0.2291f, 0.5686f, -0.3962f, -0.0911f,
0.3282f, 1.9145f, -1.4475f, -1.5525f};
test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCoreMLExecutionProvider(true));
// coreml EP requires weight and bias to be initializers
if constexpr (std::is_same<TypeParam, float>::value) {
test.SetOutputTolerance(1e-4f);
} else {
test.SetOutputTolerance(0.005f);
}
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

}  // namespace test
}  // namespace onnxruntime
#endif

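As the comments note, the expected outputs in this new test file were generated with torch.nn.GroupNorm. For readers who want to sanity-check those numbers without PyTorch, here is a small stand-alone reference computation (illustrative only, not part of the commit); it follows the usual GroupNormalization definition of per-(batch, group) mean and population variance for an {N, C, L} input, followed by per-channel scale and bias:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference GroupNormalization for input laid out as {N, C, L}.
// For each (batch, group): normalize over the group's channels and spatial
// elements, then apply the per-channel scale and bias.
std::vector<float> GroupNormRef(const std::vector<float>& x, size_t N, size_t C, size_t L,
                                size_t num_groups, float epsilon,
                                const std::vector<float>& scale, const std::vector<float>& bias) {
  std::vector<float> y(x.size());
  const size_t channels_per_group = C / num_groups;
  const size_t group_size = channels_per_group * L;
  for (size_t n = 0; n < N; ++n) {
    for (size_t g = 0; g < num_groups; ++g) {
      const size_t base = n * C * L + g * group_size;
      float mean = 0.f, var = 0.f;
      for (size_t i = 0; i < group_size; ++i) mean += x[base + i];
      mean /= group_size;
      for (size_t i = 0; i < group_size; ++i) var += (x[base + i] - mean) * (x[base + i] - mean);
      var /= group_size;  // population variance, as torch.nn.GroupNorm uses
      const float inv_std = 1.f / std::sqrt(var + epsilon);
      for (size_t c = 0; c < channels_per_group; ++c) {
        const size_t channel = g * channels_per_group + c;
        for (size_t l = 0; l < L; ++l) {
          const size_t idx = base + c * L + l;
          y[idx] = (x[idx] - mean) * inv_std * scale[channel] + bias[channel];
        }
      }
    }
  }
  return y;
}
```

Calling `GroupNormRef(input, 2, 3, 4, /*num_groups*/ 3, 0.3f, scale, B)` with the data from the first test above should reproduce the `expected_output` values to the printed precision (the first element works out to about -0.5650).
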
@@ -3,71 +3,87 @@

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"

using namespace std;
namespace onnxruntime {
namespace test {

template <typename T>
class InstanceNormalizationOpTest : public ::testing::Test {
};
using InstanceNormalizationOpTestTypes = ::testing::Types<float, MLFloat16>;
TYPED_TEST_SUITE(InstanceNormalizationOpTest, InstanceNormalizationOpTestTypes);

// Disable TensorRT on some of the tests because its parser doesn't support weight as input

TEST(InstanceNormalizationOpTest, InstanceNorm) {
OpTester test("InstanceNormalization");
test.AddAttribute("epsilon", 0.3F);
TYPED_TEST(InstanceNormalizationOpTest, InstanceNorm) {
auto run_test = [](bool is_initializer) {
OpTester test("InstanceNormalization");
test.AddAttribute("epsilon", 0.3F);

vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
7.9195533F, 7.638727F, 8.065445F, 3.8082376F,
vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
7.9195533F, 7.638727F, 8.065445F, 3.8082376F,

2.3667817F, 2.8248506F, 3.7754705F, 5.861325F,
5.058735F, 3.2787242F, 3.6843839F, 9.755121F,
2.7902672F, 7.3974323F, 8.283609F, 8.488337F};
vector<int64_t> input_dims = {2, 3, 4};
test.AddInput<float>("input", input_dims, input);
2.3667817F, 2.8248506F, 3.7754705F, 5.861325F,
5.058735F, 3.2787242F, 3.6843839F, 9.755121F,
2.7902672F, 7.3974323F, 8.283609F, 8.488337F};
vector<int64_t> input_dims = {2, 3, 4};
test.AddInput<TypeParam>("input", input_dims, GetTypedArray<TypeParam>(input));

// vector<float> scale = {2.1F, 0.1F, 1.F};
vector<float> scale = {1.0F, 1.0F, 1.F};
vector<int64_t> scale_dims = {3};
test.AddInput<float>("scale", scale_dims, scale);
// vector<float> scale = {2.1F, 0.1F, 1.F};
vector<float> scale = {1.0F, 1.0F, 1.F};
vector<int64_t> scale_dims = {3};
test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), is_initializer);

// vector<float> B = {2.3F, 1.5F, 0.F};
vector<float> B = {0.0F, 0.0F, 0.F};
vector<int64_t> B_dims = {3};
test.AddInput<float>("B", B_dims, B);
// vector<float> B = {2.3F, 1.5F, 0.F};
vector<float> B = {0.0F, 0.0F, 0.F};
vector<int64_t> B_dims = {3};
test.AddInput<TypeParam>("B", B_dims, GetTypedArray<TypeParam>(B), is_initializer);

vector<float> expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F,
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F,
0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F,
vector<float> expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F,
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F,
0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F,

-0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F,
-0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F,
-1.65898662F, 0.27618144F, 0.64840618F, 0.734399F};
test.AddOutput<float>("Y", input_dims, expected_output);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
-0.92380346F, -0.60808484F, 0.04711878F, 1.48476953F,
-0.14644464F, -0.82262872F, -0.66852817F, 1.63760153F,
-1.65898662F, 0.27618144F, 0.64840618F, 0.734399F};
test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
};
run_test(false);
run_test(true);
}

TEST(InstanceNormalizationOpTest, InstanceNormBatch1) {
OpTester test("InstanceNormalization");
test.AddAttribute("epsilon", 0.3F);
TYPED_TEST(InstanceNormalizationOpTest, InstanceNormBatch1) {
auto run_test = [](bool is_initializer) {
OpTester test("InstanceNormalization");
test.AddAttribute("epsilon", 0.3F);

vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
7.9195533F, 7.638727F, 8.065445F, 3.8082376F};
vector<int64_t> input_dims = {1, 3, 4};
test.AddInput<float>("input", input_dims, input);
vector<float> input = {3.1513367F, 9.283596F, 1.4546119F, 5.4617004F,
8.519701F, 1.2382338F, 1.7930176F, 5.1099434F,
7.9195533F, 7.638727F, 8.065445F, 3.8082376F};
vector<int64_t> input_dims = {1, 3, 4};
test.AddInput<TypeParam>("input", input_dims, GetTypedArray<TypeParam>(input));

vector<float> scale = {1.0F, 1.0F, 1.F};
vector<int64_t> scale_dims = {3};
test.AddInput<float>("scale", scale_dims, scale);
vector<float> scale = {1.0F, 1.0F, 1.F};
vector<int64_t> scale_dims = {3};
test.AddInput<TypeParam>("scale", scale_dims, GetTypedArray<TypeParam>(scale), is_initializer);

vector<float> B = {0.0F, 0.0F, 0.F};
vector<int64_t> B_dims = {3};
test.AddInput<float>("B", B_dims, B);
vector<float> B = {0.0F, 0.0F, 0.F};
vector<int64_t> B_dims = {3};
test.AddInput<TypeParam>("B", B_dims, GetTypedArray<TypeParam>(B), is_initializer);

vector<float> expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F,
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F,
0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F};
test.AddOutput<float>("Y", input_dims, expected_output);
vector<float> expected_output = {-0.56495477F, 1.48930046F, -1.13334329F, 0.20899761F,
1.46688162F, -0.98600774F, -0.79911913F, 0.31824524F,
0.57370438F, 0.42193634F, 0.6525492F, -1.64818992F};
test.AddOutput<TypeParam>("Y", input_dims, GetTypedArray<TypeParam>(expected_output));

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
};
run_test(false);
run_test(true);
}

TEST(InstanceNormalizationOpTest, InstanceNormBatch2) {

@@ -105,7 +121,7 @@ TEST(InstanceNormalizationOpTest, InstanceNormBatch2) {
}

// Only CUDA and ROCm kernels have float 16 support
#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(COREML_ENABLE_MLPROGRAM)

TEST(InstanceNormalizationOpTest, InstanceNormBatch1_fp16) {
OpTester test("InstanceNormalization");

@@ -3175,19 +3175,26 @@ TEST(ReductionOpTest, ReduceProd0DTensor) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

TEST(ReductionOpTest, ArgMax) {
template <typename T>
class ReductionOpTest : public ::testing::Test {
};

using ReductionOpTestTypes = ::testing::Types<float, MLFloat16>;
TYPED_TEST_SUITE(ReductionOpTest, ReductionOpTestTypes);

TYPED_TEST(ReductionOpTest, ArgMax) {
OpTester test("ArgMax");
test.AddAttribute("axis", (int64_t)1);
test.AddAttribute("keepdims", (int64_t)1);
test.AddInput<float>("data", {3, 2, 2},
{1.0f, 2.0f,
3.0f, 4.0f,
test.AddInput<TypeParam>("data", {3, 2, 2},
GetTypedArray<TypeParam>({1.0f, 2.0f,
3.0f, 4.0f,

5.0f, 6.0f,
7.0f, 8.0f,
5.0f, 6.0f,
7.0f, 8.0f,

9.0f, 10.0f,
11.0f, 12.0f});
9.0f, 10.0f,
11.0f, 12.0f}));
test.AddOutput<int64_t>("reduced", {3, 1, 2},
{1, 1,
1, 1,

@@ -3,6 +3,7 @@

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"

namespace onnxruntime {
namespace test {

@@ -75,17 +76,6 @@ TEST(ConcatOpTest, Concat1D_2) {
kQnnExecutionProvider});  // QNN: not support dynamic shape tensor
}

template <typename T>
static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
if constexpr (std::is_same<T, float>::value) {
return inputs;
} else {
std::vector<T> inputs_fp16(inputs.size());
ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
return inputs_fp16;
}
}

TYPED_TEST(ConcatOpTest, Concat2D_1) {
OpTester test("Concat");
test.AddAttribute("axis", int64_t{0});

@@ -5,6 +5,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "test/common/tensor_op_test_utils.h"

namespace onnxruntime {
namespace test {

@@ -263,22 +264,6 @@ TEST(SliceTest, Slice3D) {
332.0f, 333.0f});
}

template <typename T>
static std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
std::vector<T> inputs_T(inputs.size());
if constexpr (std::is_same<T, float>::value) {
return inputs;
} else if constexpr (std::is_integral_v<T>) {
for (size_t i = 0; i < inputs.size(); i++) {
inputs_T[i] = static_cast<T>(inputs[i]);
}
return inputs_T;
} else {
ConvertFloatToMLFloat16(inputs.data(), inputs_T.data(), inputs.size());
return inputs_T;
}
}

template <typename T>
static void TestSlice1DIntData() {
// static_assert(std::is_integral_v<TInt>);

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "core/framework/to_tensor_proto_element_type.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/tensor_op_test_utils.h"

namespace onnxruntime {
namespace test {

@@ -178,17 +179,6 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
}

template <typename T>
std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
if constexpr (std::is_same<T, float>::value) {
return inputs;
} else {
std::vector<T> inputs_fp16(inputs.size());
ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
return inputs_fp16;
}
}

TEST(SplitOperatorTest, Axis0UnequalSplitString) {
constexpr int64_t axis = 0;
std::vector<ShapeAndStringData> outputs;

@@ -69,17 +69,6 @@ void TransposeTest(const std::vector<int64_t>& input_shape,
}
}

template <typename T>
std::vector<T> GetTypedArray(std::vector<float> inputs, [[maybe_unused]] T v = T(0.f)) {
if constexpr (std::is_same<T, float>::value) {
return inputs;
} else {
std::vector<T> inputs_fp16(inputs.size());
ConvertFloatToMLFloat16(inputs.data(), inputs_fp16.data(), inputs.size());
return inputs_fp16;
}
}

// Test 2 dimensional transpose, with no permutation attribute specified
TYPED_TEST(TransposeOpTest, TwoDimNoAttr) {
std::vector<int64_t> input_shape({2, 3});

@@ -4,7 +4,9 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|Operator|Note|
|--------|------|
|ai.onnx:Add||
|ai.onnx:ArgMax||
|ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:Cast||
|ai.onnx:Clip||
|ai.onnx:Concat||
|ai.onnx:Conv|Only 1D/2D Conv is supported.<br/>Bias if provided must be constant.|

@@ -12,14 +14,19 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution
|ai.onnx:DepthToSpace|If 'mode' is 'CRD' the input must have a fixed shape.|
|ai.onnx:Div||
|ai.onnx:Gemm|Input B must be constant.|
|ai.onnx:Gelu||
|ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:GlobalMaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:GridSample|4D input.<br/>'mode' of 'linear' or 'zeros'.<br/>(mode==linear && padding_mode==reflection && align_corners==0) is not supported.|
|ai.onnx:GroupNormalization||
|ai.onnx:InstanceNormalization||
|ai.onnx:LayerNormalization||
|ai.onnx:LeakyRelu||
|ai.onnx:MatMul|Only support for transA == 0, alpha == 1.0 and beta == 1.0 is currently implemented.|
|ai.onnx:MaxPool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.|
|ai.onnx:Mul||
|ai.onnx:Pow|Only supports cases when both inputs are fp32.|
|ai.onnx:PRelu||
|ai.onnx:Reciprocal|The Core ML op requires an `epsilon` value (default 1e-4) that the ONNX op does not have.|
|ai.onnx:Relu||
|ai.onnx:Reshape||
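The table above lists the operators the CoreML EP can now map when it is asked to build an ML Program. For completeness, a sketch of opting into that path from application code; the header path is an assumption about the installed package layout, but `OrtSessionOptionsAppendExecutionProvider_CoreML` and `COREML_FLAG_CREATE_MLPROGRAM` are the same names used in the test hunks above:

```cpp
#include <onnxruntime_cxx_api.h>
#include <coreml_provider_factory.h>  // assumed header location in the CoreML-enabled package

// Sketch: build a session that asks the CoreML EP to emit an ML Program
// instead of the default NeuralNetwork model (needs a recent Core ML version).
Ort::Session CreateCoreMLSession(Ort::Env& env, const char* model_path) {
  Ort::SessionOptions session_options;
  uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
  Ort::ThrowOnError(
      OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
  return Ort::Session(env, model_path, session_options);
}
```

Nodes that fall outside the table are assigned to other execution providers (ultimately the CPU EP), so a model can still run end to end even if only part of it maps to the ML Program.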