deferred context test
This commit is contained in:
Родитель
a00526732e
Коммит
be9a0efdaf
|
@ -6,7 +6,9 @@
|
|||
|
||||
import numpy as np
|
||||
from ..context import *
|
||||
|
||||
from ..ops.cntk2 import Input
|
||||
from ..sgd import *
|
||||
from ..reader import *
|
||||
|
||||
def test_parse_shapes_1():
|
||||
output = '''\
|
||||
|
@ -107,3 +109,25 @@ Final Results: Minibatch[1-1]: eval_node = 2.77790430 * 500; crit_node = 0.44370
|
|||
assert result['eval_node'] == 2.77790430
|
||||
assert result['crit_node'] == 0.44370050
|
||||
assert len(result) == 3
|
||||
|
||||
def test_export_deferred_context():
|
||||
X = Input(2)
|
||||
reader = CNTKTextFormatReader("Data.txt")
|
||||
my_sgd = SGDParams()
|
||||
|
||||
with DeferredExecutionContext() as ctx:
|
||||
input_map=reader.map(X, alias='I', dim=2)
|
||||
ctx.train(
|
||||
root_nodes=[X],
|
||||
training_params=my_sgd,
|
||||
input_map=input_map)
|
||||
|
||||
ctx.test(
|
||||
root_nodes=[X],
|
||||
input_map=input_map)
|
||||
|
||||
ctx.write(input_map=input_map)
|
||||
ctx.eval(X, input_map)
|
||||
with open(ctx.export("name")) as config_file:
|
||||
assert config_file.readlines()[-1] == "command=Train:Test:Write:Eval"
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче