reduction all axes export fix
This commit is contained in:
Родитель
d2ff41272d
Коммит
4a6238d979
|
@ -3464,11 +3464,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
|
|||
Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
|
||||
// Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
|
||||
// For ONNX export we need to make sure we export keepdims as 0 (false).
|
||||
if (axis == Axis::DefaultBatchAxis())
|
||||
// The same applies for All axes
|
||||
if (axis == Axis::DefaultBatchAxis() || axis == Axis::AllAxes() || axis == Axis::AllStaticAxes())
|
||||
keepReducedDimensions = 0;
|
||||
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]);
|
||||
|
||||
node->AddAttribute("axis", ax);
|
||||
if (node->OpType() != "ArgMax" && node->OpType() != "ArgMin")
|
||||
{
|
||||
std::vector<int64_t> axes = ConvertAxesToOnnx(std::vector<Axis>({ axis }), src->Inputs()[0]);
|
||||
node->AddAttribute("axes", axes);
|
||||
}
|
||||
else
|
||||
{
|
||||
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]);
|
||||
node->AddAttribute("axis", ax);
|
||||
}
|
||||
}
|
||||
|
||||
node->AddAttribute("keepdims", keepReducedDimensions);
|
||||
|
|
Загрузка…
Ссылка в новой задаче