renamed Function.save_model() to save(), and restore_model() likewise;
made load_model a static method Function.load()
This commit is contained in:
Родитель
adf0a9238a
Коммит
3b02db423d
|
@ -129,7 +129,7 @@ def convnetlrn_cifar10_dataaug(reader_train, reader_test, epoch_size=50000, max_
|
|||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
|
||||
z.save(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
|
||||
|
||||
### Evaluation action
|
||||
epoch_size = 10000
|
||||
|
|
|
@ -97,7 +97,7 @@ def convnet_cifar10(debug_output=False):
|
|||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_{}.dnn".format(epoch)))
|
||||
z.save(os.path.join(model_path, "ConvNet_CIFAR10_{}.dnn".format(epoch)))
|
||||
|
||||
# Load test data
|
||||
reader_test = create_reader(os.path.join(data_path, 'Test_cntk_text.txt'), False, input_dim, num_output_classes)
|
||||
|
|
|
@ -107,7 +107,7 @@ def convnet_cifar10_dataaug(reader_train, reader_test, epoch_size = 50000, max_e
|
|||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
|
||||
z.save(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
|
||||
|
||||
### Evaluation action
|
||||
epoch_size = 10000
|
||||
|
|
|
@ -87,7 +87,7 @@ def convnet_mnist(debug_output=False):
|
|||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, "ConvNet_MNIST_{}.dnn".format(epoch)))
|
||||
z.save(os.path.join(model_path, "ConvNet_MNIST_{}.dnn".format(epoch)))
|
||||
|
||||
# Load test data
|
||||
reader_test = create_reader(os.path.join(data_path, 'Test-28x28_cntk_text.txt'), False, input_dim, num_output_classes)
|
||||
|
|
|
@ -112,7 +112,7 @@ def train_and_evaluate(reader_train, reader_test, network_name, epoch_size, max_
|
|||
sample_count += trainer.previous_minibatch_sample_count # count samples processed so far
|
||||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, network_name + "_{}.dnn".format(epoch)))
|
||||
z.save(os.path.join(model_path, network_name + "_{}.dnn".format(epoch)))
|
||||
enable_profiler() # begin to collect profiler data after first epoch
|
||||
|
||||
if profiler_dir:
|
||||
|
|
|
@ -176,7 +176,7 @@ def train_fast_rcnn(debug_output=False):
|
|||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
if debug_output:
|
||||
frcn_output.save_model(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
|
||||
frcn_output.save(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
|
||||
|
||||
return frcn_output
|
||||
|
||||
|
@ -217,7 +217,7 @@ if __name__ == '__main__':
|
|||
trained_model = load_model(model_path)
|
||||
else:
|
||||
trained_model = train_fast_rcnn()
|
||||
trained_model.save_model(model_path)
|
||||
trained_model.save(model_path)
|
||||
print("Stored trained model at %s" % model_path)
|
||||
|
||||
# Evaluate the test set
|
||||
|
|
|
@ -82,7 +82,7 @@ def deconv_mnist(max_epochs=3):
|
|||
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
|
||||
|
||||
progress_printer.epoch_summary(with_metric=True)
|
||||
z.save_model(os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(epoch)))
|
||||
z.save(os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(epoch)))
|
||||
|
||||
# rename final model
|
||||
last_model_name = os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(max_epochs - 1))
|
||||
|
|
|
@ -203,7 +203,7 @@ if __name__ == '__main__':
|
|||
trained_model = train_model(_base_model_file, _feature_node_name, _last_hidden_node_name,
|
||||
_image_width, _image_height, _num_channels, _num_classes, _train_map_file,
|
||||
max_epochs, freeze=freeze_weights)
|
||||
trained_model.save_model(tl_model_file)
|
||||
trained_model.save(tl_model_file)
|
||||
print("Stored trained model at %s" % tl_model_file)
|
||||
|
||||
# Evaluate the test set
|
||||
|
|
|
@ -82,7 +82,7 @@ if __name__ == '__main__':
|
|||
trained_model = train_model(base_model_file, feature_node_name, last_hidden_node_name,
|
||||
image_width, image_height, num_channels,
|
||||
len(class_mapping), train_map_file, num_epochs=30, freeze=True)
|
||||
trained_model.save_model(new_model_file)
|
||||
trained_model.save(new_model_file)
|
||||
print("Stored trained model at %s" % tl_model_file)
|
||||
|
||||
# evaluate test images
|
||||
|
|
|
@ -233,8 +233,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
|
|||
|
||||
error1 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)
|
||||
|
||||
z.save_model("seq2seq.dnn")
|
||||
z.restore_model("seq2seq.dnn")
|
||||
z.save("seq2seq.dnn")
|
||||
z.restore("seq2seq.dnn")
|
||||
|
||||
label_seq_axis = Axis('labelAxis')
|
||||
label_sequence = sequence.slice(find_arg_by_name('raw_labels',z), 1, 0)
|
||||
|
|
|
@ -184,7 +184,7 @@ def train_lm(training_file, max_num_minibatches):
|
|||
p = 0
|
||||
e += 1
|
||||
model_filename = "models/shakespeare_epoch%d.dnn" % e
|
||||
z.save_model(model_filename)
|
||||
z.save(model_filename)
|
||||
print("Saved model to '%s'" % model_filename)
|
||||
|
||||
# get the data
|
||||
|
@ -207,7 +207,7 @@ def train_lm(training_file, max_num_minibatches):
|
|||
|
||||
# Do a final save of the model
|
||||
model_filename = "models/shakespeare_epoch%d.dnn" % e
|
||||
z.save_model(model_filename)
|
||||
z.save(model_filename)
|
||||
|
||||
|
||||
def load_and_sample(model_filename, vocab_filename, prime_text='', use_hardmax=False, length=1000, temperature=1.0):
|
||||
|
|
|
@ -751,7 +751,7 @@
|
|||
" plot_weights([(agent.brain.params['W1'], 'Episode %i $W_1$'%episode_number)], figsize=(14,5))\n",
|
||||
" break\n",
|
||||
" reward_sum = 0\n",
|
||||
"agent.brain.model.save_model('dqn.mod')"
|
||||
"agent.brain.model.save('dqn.mod')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1184,7 +1184,7 @@
|
|||
"\n",
|
||||
" observation = env.reset() # reset env\n",
|
||||
" episode_number += 1\n",
|
||||
"probability.save_model('pg.mod')"
|
||||
"probability.save('pg.mod')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
@ -9,7 +9,7 @@ import numpy as np
|
|||
import numbers
|
||||
from numbers import Number
|
||||
from . import sequence
|
||||
from .functions import CloneMethod, Function, load_model
|
||||
from .functions import CloneMethod, Function
|
||||
from .variables import Variable, Parameter, Constant
|
||||
from ..utils import sanitize_input, sanitize_shape, get_data_type, sanitize_axis, sanitize_dynamic_axes, typemap
|
||||
from ..axis import Axis
|
||||
|
|
|
@ -621,7 +621,7 @@ class Function(cntk_py.Function):
|
|||
return graph.find_by_name(self, name)
|
||||
|
||||
@typemap
|
||||
def save_model(self, filename):
|
||||
def save(self, filename):
|
||||
'''
|
||||
Save this function graph into a model file using protobuf-based
|
||||
serialization.
|
||||
|
@ -631,8 +631,11 @@ class Function(cntk_py.Function):
|
|||
'''
|
||||
return super(Function, self).save_model(filename)
|
||||
|
||||
def save_model(self, filename): # legacy name
|
||||
return self.save(filename)
|
||||
|
||||
@typemap
|
||||
def restore_model(self, filename):
|
||||
def restore(self, filename):
|
||||
'''
|
||||
Restore the models parameters (in-place) from a saved model file
|
||||
|
||||
|
@ -644,6 +647,39 @@ class Function(cntk_py.Function):
|
|||
'''
|
||||
return super(Function, self).restore_model(filename)
|
||||
|
||||
def restore_model(self, filename): # legacy name
|
||||
return self.restore(filename)
|
||||
|
||||
@staticmethod
|
||||
@typemap
|
||||
def load(filename, device=None):
|
||||
'''
|
||||
Load the model in ``filename``, that has been saved using
|
||||
:func:`~cntk.ops.functions.Function.save`.
|
||||
|
||||
Args:
|
||||
filename (str): filename to load the model from
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, default is the default device):
|
||||
instance of DeviceDescriptor
|
||||
|
||||
Returns:
|
||||
root node
|
||||
'''
|
||||
if not device:
|
||||
device = DeviceDescriptor.use_default_device()
|
||||
return cntk_py.Function.load_model(filename, device)
|
||||
|
||||
@typemap
|
||||
def load_model(filename, device=None):
|
||||
'''
|
||||
Alias for :func:`~cntk.ops.functions.Function.load`.
|
||||
'''
|
||||
return Function.load(filename, device)
|
||||
|
||||
@typemap
|
||||
def save_model(model, filename): # legacy name
|
||||
return model.save(filename)
|
||||
|
||||
|
||||
class UserFunction(Function):
|
||||
'''
|
||||
|
@ -770,22 +806,3 @@ class UserFunction(Function):
|
|||
Returns the operator name.
|
||||
'''
|
||||
return 'UserFunction'
|
||||
|
||||
|
||||
@typemap
|
||||
def load_model(filename, device=None):
|
||||
'''
|
||||
Load the model in ``filename``, that has been saved using
|
||||
:func:`~cntk.ops.functions.Function.save_model`.
|
||||
|
||||
Args:
|
||||
filename (str): filename to load the model from
|
||||
device (:class:`~cntk.device.DeviceDescriptor`, default is the default device):
|
||||
instance of DeviceDescriptor
|
||||
|
||||
Returns:
|
||||
root node
|
||||
'''
|
||||
if not device:
|
||||
device = DeviceDescriptor.use_default_device()
|
||||
return cntk_py.Function.load_model(filename, device)
|
||||
|
|
|
@ -10,6 +10,7 @@ from cntk.ops import *
|
|||
from cntk.debug import save_as_legacy_model
|
||||
from cntk.ops.functions import load_model
|
||||
|
||||
# TODO: a test for restore_model?
|
||||
|
||||
def test_load_save_constant(tmpdir):
|
||||
c = constant(value=[1,3])
|
||||
|
@ -20,19 +21,19 @@ def test_load_save_constant(tmpdir):
|
|||
assert np.allclose(result, expected)
|
||||
|
||||
filename = str(tmpdir / 'c_plus_c.mod')
|
||||
root_node.save_model(filename)
|
||||
root_node.save(filename)
|
||||
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
loaded_result = loaded_node.eval()
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
||||
filename = filename + '.legacy'
|
||||
save_as_legacy_model(root_node, filename)
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
loaded_result = loaded_node.eval()
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
||||
def test_load_save_input(tmpdir):
|
||||
def test_load_save_input_legacy_names(tmpdir):
|
||||
i1 = input_variable((1,2), name='i1')
|
||||
root_node = abs(i1)
|
||||
input1 = [[[-1,2]]]
|
||||
|
@ -68,9 +69,9 @@ def test_load_save_inputs(tmpdir):
|
|||
assert np.allclose(result, expected)
|
||||
|
||||
filename = str(tmpdir / 'i_plus_i_0.mod')
|
||||
root_node.save_model(filename)
|
||||
root_node.save(filename)
|
||||
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
|
||||
# Test specifying the input nodes by name
|
||||
loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
|
||||
|
@ -78,7 +79,7 @@ def test_load_save_inputs(tmpdir):
|
|||
|
||||
filename = filename + '.legacy'
|
||||
save_as_legacy_model(root_node, filename)
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
||||
|
@ -92,9 +93,9 @@ def test_load_save_unique_input(tmpdir):
|
|||
assert np.allclose(result, expected)
|
||||
|
||||
filename = str(tmpdir / 'i_plus_0.mod')
|
||||
root_node.save_model(filename)
|
||||
root_node.save(filename)
|
||||
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
|
||||
# Test specifying the only value for an unique input
|
||||
loaded_result = loaded_node.eval(input1)
|
||||
|
@ -102,6 +103,6 @@ def test_load_save_unique_input(tmpdir):
|
|||
|
||||
filename = filename + '.legacy'
|
||||
save_as_legacy_model(root_node, filename)
|
||||
loaded_node = load_model(filename)
|
||||
loaded_node = Function.load(filename)
|
||||
loaded_result = loaded_node.eval(input1)
|
||||
assert np.allclose(loaded_result, expected)
|
||||
|
|
Загрузка…
Ссылка в новой задаче