Add support for softmaxcrossentropy loss. This is already enabled on our ROCm Fork of the MIGraphX EP ### Motivation and Context Adds support for the SoftmaxCrossEntropyLoss operator and removes the filtering of inputs here.
This commit is contained in:
Родитель
05fbb43b34
Коммит
7ad78733e6
|
@ -915,6 +915,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
|
|||
"SkipSimplifiedLayerNormalization",
|
||||
"Slice",
|
||||
"Softmax",
|
||||
"SoftmaxCrossEntropyLoss",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"SpaceToDepth",
|
||||
|
@ -1026,15 +1027,6 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v
|
|||
return result;
|
||||
}
|
||||
|
||||
// migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now,
|
||||
// so if a model contain any of these operators, fall back to CPU
|
||||
std::unordered_set<std::string> vec_ops = {"SoftmaxCrossEntropyLoss"};
|
||||
if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) {
|
||||
return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0);
|
||||
})) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes);
|
||||
|
||||
// check whether a subgrap should fallback to CPU
|
||||
|
|
Загрузка…
Ссылка в новой задаче