[WebNN] Improve data type check of slice op (#22988)

A follow-up of [[WebNN] Support negative steps for
slice](https://github.com/microsoft/onnxruntime/pull/22871#discussion_r1847929774).
Slice op is emulated by reverse+slice when steps < 0 so
`SliceOpBuilder::HasSupportedInputsImpl()` should also check the
supported data types of reverse.

---------

Co-authored-by: Wanming Lin <wanming.lin@intel.com>
This commit is contained in:
shiyi 2024-12-11 07:48:16 +08:00 коммит произвёл GitHub
Родитель fa6ad202aa
Коммит 02f0af0d08
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
23 изменённых файлов: 137 добавлений и 83 удалений

Просмотреть файл

@ -178,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type
<< "] " << onnx_input_output_name
<< " type: [" << onnx_data_type
<< "] is not supported for now";
return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
webnn_input_output_name, onnx_input_output_name, logger);
}
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger) {
if (wnn_limits[webnn_op_type].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
return false;
}
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
return false;
}
return true;
}

Просмотреть файл

@ -340,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,

Просмотреть файл

@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(node, wnn_limits, logger))
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
return false;
if (!HasSupportedOutputs(node, wnn_limits, logger))
@ -41,7 +41,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
@ -50,10 +50,10 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
}
}
return HasSupportedInputsImpl(node, wnn_limits, logger);
return HasSupportedInputsImpl(initializers, node, wnn_limits, logger);
}
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.

Просмотреть файл

@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder {
return true;
}
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder {
private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.

Просмотреть файл

@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;

Просмотреть файл

@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder {
// Operator support related.
private:
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}
// Operator support related.
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;

Просмотреть файл

@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -55,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;

Просмотреть файл

@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // input data type

Просмотреть файл

@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Helper functions, thanks for DML EP's OperatorHelper.
@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten:
return false;
} else if (recognized_operator_type == RecognizedOperatorType::Pairwise) {
// Map to WebNN's gemm or matmul
return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger);
} else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) {
return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger);
} else {
return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger);
}
}

Просмотреть файл

@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
// Operator support related.
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];

Просмотреть файл

@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial
return true;
}
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();

Просмотреть файл

@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();

Просмотреть файл

@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
return true;
}
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // A data type

Просмотреть файл

@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -187,8 +187,8 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
return true;
}
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_X_type = 0; // input data type

Просмотреть файл

@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -71,8 +71,8 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
return true;
}
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;

Просмотреть файл

@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type = 0; // input data type

Просмотреть файл

@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;

Просмотреть файл

@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -228,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
return true;
}
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();

Просмотреть файл

@ -22,8 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
@ -118,8 +118,8 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}
bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type = 0; // input data type

Просмотреть файл

@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /*
return true;
}
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];

Просмотреть файл

@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];

Просмотреть файл

@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
};
@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}
bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& input = *input_defs[0];
const auto& op_type = node.OpType();
int32_t input_type;
if (!GetType(input, input_type, logger))
return false;
// If there is step < 0, check data type support of reverse.
if (input_defs.size() > 4 && input_defs[4]->Exists()) {
std::vector<int64_t> steps;
if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger))
return false;
if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) {
if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) {
return false;
}
}
}
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger);
}
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

Просмотреть файл

@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};
// Add operator related.
@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
return Status::OK();
}
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // condition data type