This commit is contained in:
Peyman Manikashani 2018-08-23 11:05:56 -07:00
Родитель d2ff41272d
Коммит 4a6238d979
1 изменённых файлов: 12 добавлений и 4 удалений

Просмотреть файл

@ -3464,11 +3464,19 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, LotusIR::Node* nod
Axis axis = (Axis)(src->Attributes()[L"axis"].Value<Axis>());
// Reduction on batch axis in CNTK removes the batch axis, even if keepdims is true.
// For ONNX export we need to make sure we export keepdims as 0 (false).
if (axis == Axis::DefaultBatchAxis())
// The same applies for All axes
if (axis == Axis::DefaultBatchAxis() || axis == Axis::AllAxes() || axis == Axis::AllStaticAxes())
keepReducedDimensions = 0;
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]);
node->AddAttribute("axis", ax);
if (node->OpType() != "ArgMax" && node->OpType() != "ArgMin")
{
std::vector<int64_t> axes = ConvertAxesToOnnx(std::vector<Axis>({ axis }), src->Inputs()[0]);
node->AddAttribute("axes", axes);
}
else
{
int64_t ax = ConvertAxisToOnnx(axis, src->Inputs()[0]);
node->AddAttribute("axis", ax);
}
}
node->AddAttribute("keepdims", keepReducedDimensions);