[WebNN EP] Remove workaround for CPU op supported list (#21962)
We assume all WebNN ops are supported across all backends.
This commit is contained in:
Родитель
f3725b9f06
Коммит
ad9afbb042
|
@ -31,11 +31,6 @@ enum class WebnnDeviceType {
|
|||
NPU,
|
||||
};
|
||||
|
||||
typedef struct {
|
||||
std::string opName;
|
||||
bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP).
|
||||
} WebnnOpInfo;
|
||||
|
||||
// Collects all the initializer tensors in the subGraph and its ancestor graphs.
|
||||
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);
|
||||
|
||||
|
@ -154,109 +149,100 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
|
|||
const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type,
|
||||
const logging::Logger& logger);
|
||||
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
|
||||
{"Abs", {"abs", true}},
|
||||
{"Add", {"add", true}},
|
||||
{"ArgMax", {"argMax", true}},
|
||||
{"ArgMin", {"argMin", true}},
|
||||
{"AveragePool", {"averagePool2d", true}},
|
||||
{"BatchNormalization", {"batchNormalization", true}},
|
||||
{"Cast", {"cast", true}},
|
||||
{"Ceil", {"ceil", true}},
|
||||
{"Clip", {"clamp", true}},
|
||||
{"Concat", {"concat", true}},
|
||||
{"Conv", {"conv2d", true}},
|
||||
{"ConvInteger", {"conv2dInteger", false}},
|
||||
{"ConvTranspose", {"convTranspose2d", true}},
|
||||
{"Cos", {"cos", true}},
|
||||
{"Div", {"div", true}},
|
||||
{"DequantizeLinear", {"dequantizeLinear", false}},
|
||||
{"Dropout", {"identity", true}},
|
||||
{"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}},
|
||||
{"Elu", {"elu", true}},
|
||||
{"Equal", {"equal", true}},
|
||||
{"Erf", {"erf", true}},
|
||||
{"Exp", {"exp", true}},
|
||||
{"Expand", {"expand", true}},
|
||||
{"Flatten", {"reshape", true}},
|
||||
{"Floor", {"floor", true}},
|
||||
{"Gather", {"gather", true}},
|
||||
{"Gelu", {"gelu", true}},
|
||||
{"Gemm", {"gemm", true}},
|
||||
{"GlobalAveragePool", {"averagePool2d", true}},
|
||||
{"GlobalMaxPool", {"maxPool2d", true}},
|
||||
{"GlobalLpPool", {"l2Pool2d", false}},
|
||||
{"Greater", {"greater", true}},
|
||||
{"GreaterOrEqual", {"greaterOrEqual", true}},
|
||||
{"HardSigmoid", {"hardSigmoid", true}},
|
||||
{"HardSwish", {"hardSwish", true}},
|
||||
{"Identity", {"identity", true}},
|
||||
{"InstanceNormalization", {"instanceNormalization", true}},
|
||||
{"LayerNormalization", {"layerNormalization", true}},
|
||||
{"LeakyRelu", {"leakyRelu", true}},
|
||||
{"Less", {"lesser", true}},
|
||||
{"LessOrEqual", {"lesserOrEqual", true}},
|
||||
{"Log", {"log", true}},
|
||||
{"LpPool", {"l2Pool2d", false}},
|
||||
{"MatMul", {"matmul", true}},
|
||||
{"MatMulInteger", {"matmulInteger", false}},
|
||||
{"Max", {"max", true}},
|
||||
{"MaxPool", {"maxPool2d", true}},
|
||||
{"Min", {"min", true}},
|
||||
{"Mul", {"mul", true}},
|
||||
{"Neg", {"neg", true}},
|
||||
{"Not", {"logicalNot", true}},
|
||||
{"Pad", {"pad", true}},
|
||||
{"Pow", {"pow", true}},
|
||||
{"PRelu", {"prelu", true}},
|
||||
{"Reciprocal", {"reciprocal", true}},
|
||||
{"ReduceL1", {"reduceL1", true}},
|
||||
{"ReduceL2", {"reduceL2", true}},
|
||||
{"ReduceLogSum", {"reduceLogSum", true}},
|
||||
{"ReduceLogSumExp", {"reduceLogSumExp", true}},
|
||||
{"ReduceMax", {"reduceMax", true}},
|
||||
{"ReduceMean", {"reduceMean", true}},
|
||||
{"ReduceMin", {"reduceMin", true}},
|
||||
{"ReduceProd", {"reduceProduct", true}},
|
||||
{"ReduceSum", {"reduceSum", true}},
|
||||
{"ReduceSumSquare", {"reduceSumSquare", true}},
|
||||
{"Relu", {"relu", true}},
|
||||
{"Reshape", {"reshape", true}},
|
||||
{"Resize", {"resample2d", true}},
|
||||
{"Shape", {"slice", true}},
|
||||
{"Sigmoid", {"sigmoid", true}},
|
||||
{"Softplus", {"softplus", true}},
|
||||
{"Softsign", {"softsign", true}},
|
||||
{"Sin", {"sin", true}},
|
||||
{"Slice", {"slice", true}},
|
||||
{"Softmax", {"softmax", true}},
|
||||
{"Split", {"split", true}},
|
||||
{"Sqrt", {"sqrt", true}},
|
||||
{"Squeeze", {"reshape", true}},
|
||||
{"Sub", {"sub", true}},
|
||||
{"Tan", {"tan", true}},
|
||||
{"Tanh", {"tanh", true}},
|
||||
{"Transpose", {"transpose", true}},
|
||||
{"Trilu", {"triangular", true}},
|
||||
{"Unsqueeze", {"reshape", true}},
|
||||
{"Where", {"where", true}},
|
||||
static const InlinedHashMap<std::string, std::string> op_map = {
|
||||
{"Abs", "abs"},
|
||||
{"Add", "add"},
|
||||
{"ArgMax", "argMax"},
|
||||
{"ArgMin", "argMin"},
|
||||
{"AveragePool", "averagePool2d"},
|
||||
{"BatchNormalization", "batchNormalization"},
|
||||
{"Cast", "cast"},
|
||||
{"Ceil", "ceil"},
|
||||
{"Clip", "clamp"},
|
||||
{"Concat", "concat"},
|
||||
{"Conv", "conv2d"},
|
||||
{"ConvInteger", "conv2dInteger"},
|
||||
{"ConvTranspose", "convTranspose2d"},
|
||||
{"Cos", "cos"},
|
||||
{"Div", "div"},
|
||||
{"DequantizeLinear", "dequantizeLinear"},
|
||||
{"Dropout", "identity"},
|
||||
{"DynamicQuantizeLinear", "dynamicQuantizeLinear"},
|
||||
{"Elu", "elu"},
|
||||
{"Equal", "equal"},
|
||||
{"Erf", "erf"},
|
||||
{"Exp", "exp"},
|
||||
{"Expand", "expand"},
|
||||
{"Flatten", "reshape"},
|
||||
{"Floor", "floor"},
|
||||
{"Gather", "gather"},
|
||||
{"Gelu", "gelu"},
|
||||
{"Gemm", "gemm"},
|
||||
{"GlobalAveragePool", "averagePool2d"},
|
||||
{"GlobalMaxPool", "maxPool2d"},
|
||||
{"GlobalLpPool", "l2Pool2d"},
|
||||
{"Greater", "greater"},
|
||||
{"GreaterOrEqual", "greaterOrEqual"},
|
||||
{"HardSigmoid", "hardSigmoid"},
|
||||
{"HardSwish", "hardSwish"},
|
||||
{"Identity", "identity"},
|
||||
{"InstanceNormalization", "instanceNormalization"},
|
||||
{"LayerNormalization", "layerNormalization"},
|
||||
{"LeakyRelu", "leakyRelu"},
|
||||
{"Less", "lesser"},
|
||||
{"LessOrEqual", "lesserOrEqual"},
|
||||
{"Log", "log"},
|
||||
{"LpPool", "l2Pool2d"},
|
||||
{"MatMul", "matmul"},
|
||||
{"MatMulInteger", "matmulInteger"},
|
||||
{"Max", "max"},
|
||||
{"MaxPool", "maxPool2d"},
|
||||
{"Min", "min"},
|
||||
{"Mul", "mul"},
|
||||
{"Neg", "neg"},
|
||||
{"Not", "logicalNot"},
|
||||
{"Pad", "pad"},
|
||||
{"Pow", "pow"},
|
||||
{"PRelu", "prelu"},
|
||||
{"Reciprocal", "reciprocal"},
|
||||
{"ReduceL1", "reduceL1"},
|
||||
{"ReduceL2", "reduceL2"},
|
||||
{"ReduceLogSum", "reduceLogSum"},
|
||||
{"ReduceLogSumExp", "reduceLogSumExp"},
|
||||
{"ReduceMax", "reduceMax"},
|
||||
{"ReduceMean", "reduceMean"},
|
||||
{"ReduceMin", "reduceMin"},
|
||||
{"ReduceProd", "reduceProduct"},
|
||||
{"ReduceSum", "reduceSum"},
|
||||
{"ReduceSumSquare", "reduceSumSquare"},
|
||||
{"Relu", "relu"},
|
||||
{"Reshape", "reshape"},
|
||||
{"Resize", "resample2d"},
|
||||
{"Shape", "slice"},
|
||||
{"Sigmoid", "sigmoid"},
|
||||
{"Softplus", "softplus"},
|
||||
{"Softsign", "softsign"},
|
||||
{"Sin", "sin"},
|
||||
{"Slice", "slice"},
|
||||
{"Softmax", "softmax"},
|
||||
{"Split", "split"},
|
||||
{"Sqrt", "sqrt"},
|
||||
{"Squeeze", "reshape"},
|
||||
{"Sub", "sub"},
|
||||
{"Tan", "tan"},
|
||||
{"Tanh", "tanh"},
|
||||
{"Transpose", "transpose"},
|
||||
{"Trilu", "triangular"},
|
||||
{"Unsqueeze", "reshape"},
|
||||
{"Where", "where"},
|
||||
};
|
||||
|
||||
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
|
||||
const WebnnDeviceType device_type) {
|
||||
// Returns false if the op_type is not listed in the op_map.
|
||||
if (op_map.find(op_type) == op_map.end()) {
|
||||
return false;
|
||||
}
|
||||
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
|
||||
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
|
||||
return false;
|
||||
}
|
||||
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
|
||||
// fall back early to the ORT CPU EP rather than fail in the WebNN "cpu" deviceType.
|
||||
// This is a workaround because the op may be included in MLGraphBuilder for DirectML
|
||||
// backend but without TFLite implementation in Chromium.
|
||||
if (!op_map.find(op_type)->second.isCpuSupported && device_type == WebnnDeviceType::CPU) {
|
||||
auto op_map_entry = op_map.find(op_type);
|
||||
// Returns false if the op_type is not listed in the op_map or
|
||||
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
|
||||
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче