From 7ad78733e6ab95378aee9d0255de6c466f203801 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Sun, 27 Oct 2024 16:59:35 -0400 Subject: [PATCH] Add support for softmaxcrossentropy loss to MIGraphX EP (#64) (#22603) Add support for softmaxcrossentropy loss. This is already enabled on our ROCm Fork of the MIGraphX EP ### Motivation and Context Adds support for the SoftmaxCrossEntropyLoss operator and removes the filtering of inputs here. --- .../providers/migraphx/migraphx_execution_provider.cc | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index dca3848043..9d651129fb 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -915,6 +915,7 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, "SkipSimplifiedLayerNormalization", "Slice", "Softmax", + "SoftmaxCrossEntropyLoss", "Softplus", "Softsign", "SpaceToDepth", @@ -1026,15 +1027,6 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v return result; } - // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, - // so if a model contain any of these operators, fall back to CPU - std::unordered_set vec_ops = {"SoftmaxCrossEntropyLoss"}; - if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { - return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); - })) { - return result; - } - auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); // check whether a subgrap should fallback to CPU