[WebNN EP] Remove workaround for CPU op supported list (#21962)

We assume all WebNN ops are supported across all backends.
This commit is contained in:
Wanming Lin 2024-09-07 13:14:52 +08:00 коммит произвёл GitHub
Родитель f3725b9f06
Коммит ad9afbb042
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 90 добавлений и 104 удалений

Просмотреть файл

@ -31,11 +31,6 @@ enum class WebnnDeviceType {
NPU,
};
typedef struct {
std::string opName;
bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP).
} WebnnOpInfo;
// Collects all the initializer tensors in the subGraph and its ancestor graphs.
InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);
@ -154,109 +149,100 @@ std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_v
const emscripten::val& wnn_builder,
const WebnnDeviceType device_type,
const logging::Logger& logger);
static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Abs", {"abs", true}},
{"Add", {"add", true}},
{"ArgMax", {"argMax", true}},
{"ArgMin", {"argMin", true}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"batchNormalization", true}},
{"Cast", {"cast", true}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
{"Concat", {"concat", true}},
{"Conv", {"conv2d", true}},
{"ConvInteger", {"conv2dInteger", false}},
{"ConvTranspose", {"convTranspose2d", true}},
{"Cos", {"cos", true}},
{"Div", {"div", true}},
{"DequantizeLinear", {"dequantizeLinear", false}},
{"Dropout", {"identity", true}},
{"DynamicQuantizeLinear", {"dynamicQuantizeLinear", false}},
{"Elu", {"elu", true}},
{"Equal", {"equal", true}},
{"Erf", {"erf", true}},
{"Exp", {"exp", true}},
{"Expand", {"expand", true}},
{"Flatten", {"reshape", true}},
{"Floor", {"floor", true}},
{"Gather", {"gather", true}},
{"Gelu", {"gelu", true}},
{"Gemm", {"gemm", true}},
{"GlobalAveragePool", {"averagePool2d", true}},
{"GlobalMaxPool", {"maxPool2d", true}},
{"GlobalLpPool", {"l2Pool2d", false}},
{"Greater", {"greater", true}},
{"GreaterOrEqual", {"greaterOrEqual", true}},
{"HardSigmoid", {"hardSigmoid", true}},
{"HardSwish", {"hardSwish", true}},
{"Identity", {"identity", true}},
{"InstanceNormalization", {"instanceNormalization", true}},
{"LayerNormalization", {"layerNormalization", true}},
{"LeakyRelu", {"leakyRelu", true}},
{"Less", {"lesser", true}},
{"LessOrEqual", {"lesserOrEqual", true}},
{"Log", {"log", true}},
{"LpPool", {"l2Pool2d", false}},
{"MatMul", {"matmul", true}},
{"MatMulInteger", {"matmulInteger", false}},
{"Max", {"max", true}},
{"MaxPool", {"maxPool2d", true}},
{"Min", {"min", true}},
{"Mul", {"mul", true}},
{"Neg", {"neg", true}},
{"Not", {"logicalNot", true}},
{"Pad", {"pad", true}},
{"Pow", {"pow", true}},
{"PRelu", {"prelu", true}},
{"Reciprocal", {"reciprocal", true}},
{"ReduceL1", {"reduceL1", true}},
{"ReduceL2", {"reduceL2", true}},
{"ReduceLogSum", {"reduceLogSum", true}},
{"ReduceLogSumExp", {"reduceLogSumExp", true}},
{"ReduceMax", {"reduceMax", true}},
{"ReduceMean", {"reduceMean", true}},
{"ReduceMin", {"reduceMin", true}},
{"ReduceProd", {"reduceProduct", true}},
{"ReduceSum", {"reduceSum", true}},
{"ReduceSumSquare", {"reduceSumSquare", true}},
{"Relu", {"relu", true}},
{"Reshape", {"reshape", true}},
{"Resize", {"resample2d", true}},
{"Shape", {"slice", true}},
{"Sigmoid", {"sigmoid", true}},
{"Softplus", {"softplus", true}},
{"Softsign", {"softsign", true}},
{"Sin", {"sin", true}},
{"Slice", {"slice", true}},
{"Softmax", {"softmax", true}},
{"Split", {"split", true}},
{"Sqrt", {"sqrt", true}},
{"Squeeze", {"reshape", true}},
{"Sub", {"sub", true}},
{"Tan", {"tan", true}},
{"Tanh", {"tanh", true}},
{"Transpose", {"transpose", true}},
{"Trilu", {"triangular", true}},
{"Unsqueeze", {"reshape", true}},
{"Where", {"where", true}},
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
{"ArgMax", "argMax"},
{"ArgMin", "argMin"},
{"AveragePool", "averagePool2d"},
{"BatchNormalization", "batchNormalization"},
{"Cast", "cast"},
{"Ceil", "ceil"},
{"Clip", "clamp"},
{"Concat", "concat"},
{"Conv", "conv2d"},
{"ConvInteger", "conv2dInteger"},
{"ConvTranspose", "convTranspose2d"},
{"Cos", "cos"},
{"Div", "div"},
{"DequantizeLinear", "dequantizeLinear"},
{"Dropout", "identity"},
{"DynamicQuantizeLinear", "dynamicQuantizeLinear"},
{"Elu", "elu"},
{"Equal", "equal"},
{"Erf", "erf"},
{"Exp", "exp"},
{"Expand", "expand"},
{"Flatten", "reshape"},
{"Floor", "floor"},
{"Gather", "gather"},
{"Gelu", "gelu"},
{"Gemm", "gemm"},
{"GlobalAveragePool", "averagePool2d"},
{"GlobalMaxPool", "maxPool2d"},
{"GlobalLpPool", "l2Pool2d"},
{"Greater", "greater"},
{"GreaterOrEqual", "greaterOrEqual"},
{"HardSigmoid", "hardSigmoid"},
{"HardSwish", "hardSwish"},
{"Identity", "identity"},
{"InstanceNormalization", "instanceNormalization"},
{"LayerNormalization", "layerNormalization"},
{"LeakyRelu", "leakyRelu"},
{"Less", "lesser"},
{"LessOrEqual", "lesserOrEqual"},
{"Log", "log"},
{"LpPool", "l2Pool2d"},
{"MatMul", "matmul"},
{"MatMulInteger", "matmulInteger"},
{"Max", "max"},
{"MaxPool", "maxPool2d"},
{"Min", "min"},
{"Mul", "mul"},
{"Neg", "neg"},
{"Not", "logicalNot"},
{"Pad", "pad"},
{"Pow", "pow"},
{"PRelu", "prelu"},
{"Reciprocal", "reciprocal"},
{"ReduceL1", "reduceL1"},
{"ReduceL2", "reduceL2"},
{"ReduceLogSum", "reduceLogSum"},
{"ReduceLogSumExp", "reduceLogSumExp"},
{"ReduceMax", "reduceMax"},
{"ReduceMean", "reduceMean"},
{"ReduceMin", "reduceMin"},
{"ReduceProd", "reduceProduct"},
{"ReduceSum", "reduceSum"},
{"ReduceSumSquare", "reduceSumSquare"},
{"Relu", "relu"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"Shape", "slice"},
{"Sigmoid", "sigmoid"},
{"Softplus", "softplus"},
{"Softsign", "softsign"},
{"Sin", "sin"},
{"Slice", "slice"},
{"Softmax", "softmax"},
{"Split", "split"},
{"Sqrt", "sqrt"},
{"Squeeze", "reshape"},
{"Sub", "sub"},
{"Tan", "tan"},
{"Tanh", "tanh"},
{"Transpose", "transpose"},
{"Trilu", "triangular"},
{"Unsqueeze", "reshape"},
{"Where", "where"},
};
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
const WebnnDeviceType device_type) {
// Returns false if the op_type is not listed in the op_map.
if (op_map.find(op_type) == op_map.end()) {
return false;
}
// Returns false if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (!wnn_builder[op_map.find(op_type)->second.opName].as<bool>()) {
return false;
}
// The current WebNN CPU (TFLite) backend supports a limited op list, and we'd rather
// fall back early to the ORT CPU EP rather than fail in the WebNN "cpu" deviceType.
// This is a workaround because the op may be included in MLGraphBuilder for DirectML
// backend but without TFLite implementation in Chromium.
if (!op_map.find(op_type)->second.isCpuSupported && device_type == WebnnDeviceType::CPU) {
auto op_map_entry = op_map.find(op_type);
// Returns false if the op_type is not listed in the op_map or
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
return false;
}