[WebNN] Improve data type check of slice op (#22988)
A follow-up of [[WebNN] Support negative steps for slice](https://github.com/microsoft/onnxruntime/pull/22871#discussion_r1847929774). Slice op is emulated by reverse+slice when steps < 0 so `SliceOpBuilder::HasSupportedInputsImpl()` should also check the supported data types of reverse. --------- Co-authored-by: Wanming Lin <wanming.lin@intel.com>
This commit is contained in:
Родитель
fa6ad202aa
Коммит
02f0af0d08
|
@ -178,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
|
|||
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
|
||||
return false;
|
||||
|
||||
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
|
||||
LOGS(logger, VERBOSE) << "[" << onnx_op_type
|
||||
<< "] " << onnx_input_output_name
|
||||
<< " type: [" << onnx_data_type
|
||||
<< "] is not supported for now";
|
||||
return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
|
||||
webnn_input_output_name, onnx_input_output_name, logger);
|
||||
}
|
||||
|
||||
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
|
||||
const std::string& webnn_op_type,
|
||||
const int32_t onnx_data_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const std::string& webnn_input_output_name,
|
||||
const std::string& onnx_input_output_name,
|
||||
const logging::Logger& logger) {
|
||||
if (wnn_limits[webnn_op_type].isUndefined()) {
|
||||
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
|
||||
return false;
|
||||
}
|
||||
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
|
||||
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
|
||||
<< webnn_input_output_name << "]";
|
||||
return false;
|
||||
}
|
||||
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
|
||||
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
|
||||
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -340,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
|
|||
const std::string& webnn_input_output_name,
|
||||
const std::string& onnx_input_output_name,
|
||||
const logging::Logger& logger);
|
||||
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
|
||||
const std::string& webnn_op_type,
|
||||
const int32_t onnx_data_type,
|
||||
const emscripten::val& wnn_limits,
|
||||
const std::string& webnn_input_output_name,
|
||||
const std::string& onnx_input_output_name,
|
||||
const logging::Logger& logger);
|
||||
|
||||
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
|
||||
std::vector<int64_t>& shape_b,
|
||||
|
|
|
@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
|
|||
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
if (!HasSupportedInputs(node, wnn_limits, logger))
|
||||
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
|
||||
return false;
|
||||
|
||||
if (!HasSupportedOutputs(node, wnn_limits, logger))
|
||||
|
@ -41,7 +41,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
|
|||
return IsOpSupportedImpl(initializers, node, device_type, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
|
||||
bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
|
||||
for (const auto* input : node.InputDefs()) {
|
||||
|
@ -50,10 +50,10 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
|
|||
}
|
||||
}
|
||||
|
||||
return HasSupportedInputsImpl(node, wnn_limits, logger);
|
||||
return HasSupportedInputsImpl(initializers, node, wnn_limits, logger);
|
||||
}
|
||||
|
||||
bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
|
||||
bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
// We only check the type of input 0 by default, specific op builder can override this.
|
||||
|
|
|
@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
return true;
|
||||
}
|
||||
|
||||
virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const;
|
||||
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const;
|
||||
|
@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder {
|
|||
|
||||
private:
|
||||
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
|
||||
|
||||
const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
|
||||
|
|
|
@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
|
|||
return true;
|
||||
}
|
||||
|
||||
bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
|
|
|
@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder {
|
|||
|
||||
// Operator support related.
|
||||
private:
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
}
|
||||
|
||||
// Operator support related.
|
||||
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
|
|
|
@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder {
|
|||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -55,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
|
|
|
@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
|
||||
|
@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // input data type
|
||||
|
|
|
@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Helper functions, thanks for DML EP's OperatorHelper.
|
||||
|
@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
const auto& op_type = node.OpType();
|
||||
|
@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten:
|
|||
return false;
|
||||
} else if (recognized_operator_type == RecognizedOperatorType::Pairwise) {
|
||||
// Map to WebNN's gemm or matmul
|
||||
return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger);
|
||||
return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger);
|
||||
} else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) {
|
||||
return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger);
|
||||
return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger);
|
||||
} else {
|
||||
return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger);
|
||||
return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder {
|
|||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde
|
|||
|
||||
// Operator support related.
|
||||
|
||||
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
|
|
|
@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
|
@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
|
@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // A data type
|
||||
|
|
|
@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
@ -187,8 +187,8 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
|
|||
return true;
|
||||
}
|
||||
|
||||
bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_X_type = 0; // input data type
|
||||
|
|
|
@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -71,8 +71,8 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
|
|||
return true;
|
||||
}
|
||||
|
||||
bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
|
|
|
@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
};
|
||||
|
@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type = 0; // input data type
|
||||
|
|
|
@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
|
|||
return true;
|
||||
}
|
||||
|
||||
bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type;
|
||||
|
|
|
@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
@ -228,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
|
|||
return true;
|
||||
}
|
||||
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
|
|
|
@ -22,8 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder {
|
|||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
|
||||
// Operator support related.
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
||||
|
@ -118,8 +118,8 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type = 0; // input data type
|
||||
|
|
|
@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /*
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
|
|
|
@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder {
|
|||
// Operator support related.
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
const auto& data = *node.InputDefs()[0];
|
||||
const auto& indices = *node.InputDefs()[1];
|
||||
|
|
|
@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder {
|
|||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
// TODO: Support Slice opset < 10, which uses attributes for starts and ends.
|
||||
int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; }
|
||||
};
|
||||
|
@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& input = *input_defs[0];
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input_type;
|
||||
if (!GetType(input, input_type, logger))
|
||||
return false;
|
||||
|
||||
// If there is step < 0, check data type support of reverse.
|
||||
if (input_defs.size() > 4 && input_defs[4]->Exists()) {
|
||||
std::vector<int64_t> steps;
|
||||
if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger))
|
||||
return false;
|
||||
if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) {
|
||||
if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger);
|
||||
}
|
||||
|
||||
void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
|
||||
op_registrations.builders.push_back(std::make_unique<SliceOpBuilder>());
|
||||
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
|
||||
|
|
|
@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder {
|
|||
private:
|
||||
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
|
||||
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
|
||||
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const override;
|
||||
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
// Add operator related.
|
||||
|
@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
|
||||
const logging::Logger& logger) const {
|
||||
bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
|
||||
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const auto& op_type = node.OpType();
|
||||
int32_t input0_type; // condition data type
|
||||
|
|
Загрузка…
Ссылка в новой задаче