Fix other unit test.
This commit is contained in:
Родитель
4ff71f8067
Коммит
de511fca9f
|
@ -42,7 +42,7 @@ def run_distributed_trainer(tmpdir, quantized):
|
|||
momentum_time_constant = momentum_as_time_constant_schedule(1100)
|
||||
|
||||
trainer = Trainer(z, ce, errs, \
|
||||
sgd(z.parameters, 0.007, momentum_time_constant, 0.5, True),
|
||||
momentum_sgd(z.parameters, 0.007, momentum_time_constant),
|
||||
distributed_trainer=dist_trainer)
|
||||
in1_value = [[1],[2]]
|
||||
label_value = [[0], [1]]
|
||||
|
|
|
@ -23,7 +23,7 @@ def test_trainer(tmpdir):
|
|||
m_schedule = momentum_schedule(1100)
|
||||
|
||||
trainer = Trainer(z, ce, errs, \
|
||||
[sgd(z.parameters, 0.007, m_schedule, 0.5, True)])
|
||||
[momentum_sgd(z.parameters, 0.007, m_schedule)])
|
||||
in1_value = [[1],[2]]
|
||||
label_value = [[0], [1]]
|
||||
arguments = {in1: in1_value, labels: label_value}
|
||||
|
@ -52,7 +52,7 @@ def test_output_to_retain():
|
|||
m_schedule = momentum_schedule(1100)
|
||||
|
||||
trainer = Trainer(z, ce, errs, \
|
||||
[sgd(z.parameters, 0.007, m_schedule, 0.5, True)])
|
||||
[momentum_sgd(z.parameters, 0.007, m_schedule)])
|
||||
in1_value = [[1],[2]]
|
||||
label_value = [[0], [1]]
|
||||
arguments = {in1: in1_value, labels: label_value}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
|
@ -157,7 +157,10 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
|
|||
momentum_time_constant = momentum_as_time_constant_schedule(1100)
|
||||
clipping_threshold_per_sample = 2.3
|
||||
gradient_clipping_with_truncation = True
|
||||
learner = momentum_sgd(z.parameters, lr_per_sample, momentum_time_constant, clipping_threshold_per_sample, gradient_clipping_with_truncation)
|
||||
learner = momentum_sgd(z.parameters,
|
||||
lr_per_sample, momentum_time_constant,
|
||||
gradient_clipping_threshold_per_sample=clipping_threshold_per_sample,
|
||||
gradient_clipping_with_truncation=gradient_clipping_with_truncation)
|
||||
trainer = Trainer(z, ce, errs, learner)
|
||||
|
||||
# setup data
|
||||
|
|
Загрузка…
Ссылка в новой задаче