onnx 1.2.1 for BatchNorm and Dropout
This commit is contained in:
Родитель
172cf67741
Коммит
327c030594
|
@ -2832,7 +2832,6 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node
|
|||
momentum = 1.0f + expm1(-48.0f / normalizationTimeConstant);
|
||||
|
||||
node->AddAttribute(attributesMap[L"spatial"], spatial);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
node->AddAttribute(attributesMap[L"epsilon"], epsilon);
|
||||
node->AddAttribute("momentum", momentum);
|
||||
}
|
||||
|
@ -2877,7 +2876,6 @@ void CNTKToONNXHelper::CopyAttributes(const FunctionPtr& src, ONNXIR::Node* node
|
|||
{
|
||||
auto dropoutRate = (float)src->Attributes()[L"dropoutRate"].Value<double>();
|
||||
node->AddAttribute(attributesMap[L"dropoutRate"], dropoutRate);
|
||||
node->AddAttribute("is_test", (int64_t)1);
|
||||
}
|
||||
else if ((src->OpName() == L"RandomDistribution") ||
|
||||
(src->OpName() == L"UniformRandom") || (src->OpName() == L"NormalRandom") ||
|
||||
|
|
|
@ -1962,10 +1962,6 @@ FunctionPtr ONNXToCNTKHelper::CreateFunction(const Node *node, const std::vector
|
|||
}
|
||||
else if (onnxOpName == "BatchNormalization" || onnxOpName == "SpatialBN")
|
||||
{
|
||||
auto is_test = GetNamedAttributeAsInt64(node, "is_test", 0);
|
||||
if (is_test == 0)
|
||||
NOT_IMPLEMENTED;
|
||||
// TODO: implement this right once ready.
|
||||
const Variable &operand = inputs[0];
|
||||
const Variable &scale = inputs[1];
|
||||
const Variable &bias = inputs[2];
|
||||
|
|
|
@ -181,8 +181,6 @@ def test_AveragePool(tmpdir):
|
|||
|
||||
#BatchNormalization
|
||||
def test_BatchNormalization(tmpdir):
|
||||
pytest.skip('Need to support new ONNX spec.')
|
||||
|
||||
dtype = np.float32
|
||||
|
||||
sample = [ # 5 samples having 4 classes
|
||||
|
@ -326,7 +324,6 @@ def test_Div(tmpdir):
|
|||
|
||||
#Dropout
|
||||
def test_Dropout(tmpdir):
|
||||
pytest.skip('Need to support new ONNX spec.')
|
||||
data = np.asarray([[10, 20],[30, 40],[50, 60]], dtype=np.float32)
|
||||
model = C.dropout(data, 0.5)
|
||||
verify_no_input(model, tmpdir, 'Dropout_0')
|
||||
|
|
Загрузка…
Ссылка в новой задаче