Fix Convnet test
This commit is contained in:
Родитель
bc2b046cdd
Коммит
f30aba05fb
|
@ -126,7 +126,7 @@ def train_and_evaluate(reader_train, reader_test, max_epochs):
|
|||
# Set learning parameters
|
||||
lr_per_minibatch = learning_rate_schedule([0.01]*10 + [0.003]*10 + [0.001], epoch_size, UnitType.minibatch)
|
||||
momentum_time_constant = momentum_as_time_constant_schedule(-minibatch_size/np.log(0.9))
|
||||
l2_reg_weight = 0.03
|
||||
l2_reg_weight = 0.0001
|
||||
|
||||
# trainer object
|
||||
learner = momentum_sgd(z.parameters,
|
||||
|
|
|
@ -44,7 +44,7 @@ def test_cifar_resnet_error(device_id):
|
|||
reader_test = create_reader(os.path.join(base_path, 'test_map.txt'), os.path.join(base_path, 'CIFAR-10_mean.xml'), False)
|
||||
|
||||
test_error = train_and_evaluate(reader_train, reader_test, max_epochs=5)
|
||||
expected_test_error = 0.462
|
||||
expected_test_error = 0.463
|
||||
|
||||
assert np.allclose(test_error, expected_test_error,
|
||||
atol=TOLERANCE_ABSOLUTE)
|
||||
|
|
|
@ -44,7 +44,7 @@ def test_cifar_resnet_error(device_id):
|
|||
reader_test = create_reader(os.path.join(base_path, 'test_map.txt'), os.path.join(base_path, 'CIFAR-10_mean.xml'), False)
|
||||
|
||||
test_error = train_and_evaluate(reader_train, reader_test, max_epochs=5)
|
||||
expected_test_error = 0.36
|
||||
expected_test_error = 0.282
|
||||
|
||||
assert np.allclose(test_error, expected_test_error,
|
||||
atol=TOLERANCE_ABSOLUTE)
|
||||
|
|
Загрузка…
Ссылка в новой задаче