renamed Function.save_model() to save(), and restore_model() likewise;

made load_model a static method Function.load()
This commit is contained in:
Frank Seide 2017-02-09 17:24:39 -08:00
Родитель adf0a9238a
Коммит 3b02db423d
15 изменённых файлов: 66 добавлений и 48 удалений

Просмотреть файл

@ -129,7 +129,7 @@ def convnetlrn_cifar10_dataaug(reader_train, reader_test, epoch_size=50000, max_
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
z.save(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
### Evaluation action
epoch_size = 10000

Просмотреть файл

@ -97,7 +97,7 @@ def convnet_cifar10(debug_output=False):
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_{}.dnn".format(epoch)))
z.save(os.path.join(model_path, "ConvNet_CIFAR10_{}.dnn".format(epoch)))
# Load test data
reader_test = create_reader(os.path.join(data_path, 'Test_cntk_text.txt'), False, input_dim, num_output_classes)

Просмотреть файл

@ -107,7 +107,7 @@ def convnet_cifar10_dataaug(reader_train, reader_test, epoch_size = 50000, max_e
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
z.save(os.path.join(model_path, "ConvNet_CIFAR10_DataAug_{}.dnn".format(epoch)))
### Evaluation action
epoch_size = 10000

Просмотреть файл

@ -87,7 +87,7 @@ def convnet_mnist(debug_output=False):
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, "ConvNet_MNIST_{}.dnn".format(epoch)))
z.save(os.path.join(model_path, "ConvNet_MNIST_{}.dnn".format(epoch)))
# Load test data
reader_test = create_reader(os.path.join(data_path, 'Test-28x28_cntk_text.txt'), False, input_dim, num_output_classes)

Просмотреть файл

@ -112,7 +112,7 @@ def train_and_evaluate(reader_train, reader_test, network_name, epoch_size, max_
sample_count += trainer.previous_minibatch_sample_count # count samples processed so far
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, network_name + "_{}.dnn".format(epoch)))
z.save(os.path.join(model_path, network_name + "_{}.dnn".format(epoch)))
enable_profiler() # begin to collect profiler data after first epoch
if profiler_dir:

Просмотреть файл

@ -176,7 +176,7 @@ def train_fast_rcnn(debug_output=False):
progress_printer.epoch_summary(with_metric=True)
if debug_output:
frcn_output.save_model(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
frcn_output.save(os.path.join(abs_path, "Output", "frcn_py_%s.model" % (epoch+1)))
return frcn_output
@ -217,7 +217,7 @@ if __name__ == '__main__':
trained_model = load_model(model_path)
else:
trained_model = train_fast_rcnn()
trained_model.save_model(model_path)
trained_model.save(model_path)
print("Stored trained model at %s" % model_path)
# Evaluate the test set

Просмотреть файл

@ -82,7 +82,7 @@ def deconv_mnist(max_epochs=3):
progress_printer.update_with_trainer(trainer, with_metric=True) # log progress
progress_printer.epoch_summary(with_metric=True)
z.save_model(os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(epoch)))
z.save(os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(epoch)))
# rename final model
last_model_name = os.path.join(model_path, "07_Deconvolution_PY_{}.model".format(max_epochs - 1))

Просмотреть файл

@ -203,7 +203,7 @@ if __name__ == '__main__':
trained_model = train_model(_base_model_file, _feature_node_name, _last_hidden_node_name,
_image_width, _image_height, _num_channels, _num_classes, _train_map_file,
max_epochs, freeze=freeze_weights)
trained_model.save_model(tl_model_file)
trained_model.save(tl_model_file)
print("Stored trained model at %s" % tl_model_file)
# Evaluate the test set

Просмотреть файл

@ -82,7 +82,7 @@ if __name__ == '__main__':
trained_model = train_model(base_model_file, feature_node_name, last_hidden_node_name,
image_width, image_height, num_channels,
len(class_mapping), train_map_file, num_epochs=30, freeze=True)
trained_model.save_model(new_model_file)
trained_model.save(new_model_file)
print("Stored trained model at %s" % tl_model_file)
# evaluate test images

Просмотреть файл

@ -233,8 +233,8 @@ def sequence_to_sequence_translator(debug_output=False, run_test=False):
error1 = translator_test_error(z, trainer, input_vocab_dim, label_vocab_dim)
z.save_model("seq2seq.dnn")
z.restore_model("seq2seq.dnn")
z.save("seq2seq.dnn")
z.restore("seq2seq.dnn")
label_seq_axis = Axis('labelAxis')
label_sequence = sequence.slice(find_arg_by_name('raw_labels',z), 1, 0)

Просмотреть файл

@ -184,7 +184,7 @@ def train_lm(training_file, max_num_minibatches):
p = 0
e += 1
model_filename = "models/shakespeare_epoch%d.dnn" % e
z.save_model(model_filename)
z.save(model_filename)
print("Saved model to '%s'" % model_filename)
# get the data
@ -207,7 +207,7 @@ def train_lm(training_file, max_num_minibatches):
# Do a final save of the model
model_filename = "models/shakespeare_epoch%d.dnn" % e
z.save_model(model_filename)
z.save(model_filename)
def load_and_sample(model_filename, vocab_filename, prime_text='', use_hardmax=False, length=1000, temperature=1.0):

Просмотреть файл

@ -751,7 +751,7 @@
" plot_weights([(agent.brain.params['W1'], 'Episode %i $W_1$'%episode_number)], figsize=(14,5))\n",
" break\n",
" reward_sum = 0\n",
"agent.brain.model.save_model('dqn.mod')"
"agent.brain.model.save('dqn.mod')"
]
},
{
@ -1184,7 +1184,7 @@
"\n",
" observation = env.reset() # reset env\n",
" episode_number += 1\n",
"probability.save_model('pg.mod')"
"probability.save('pg.mod')"
]
},
{

Просмотреть файл

@ -9,7 +9,7 @@ import numpy as np
import numbers
from numbers import Number
from . import sequence
from .functions import CloneMethod, Function, load_model
from .functions import CloneMethod, Function
from .variables import Variable, Parameter, Constant
from ..utils import sanitize_input, sanitize_shape, get_data_type, sanitize_axis, sanitize_dynamic_axes, typemap
from ..axis import Axis

Просмотреть файл

@ -621,7 +621,7 @@ class Function(cntk_py.Function):
return graph.find_by_name(self, name)
@typemap
def save_model(self, filename):
def save(self, filename):
'''
Save this function graph into a model file using protobuf-based
serialization.
@ -631,8 +631,11 @@ class Function(cntk_py.Function):
'''
return super(Function, self).save_model(filename)
def save_model(self, filename): # legacy name
return self.save(filename)
@typemap
def restore_model(self, filename):
def restore(self, filename):
'''
Restore the models parameters (in-place) from a saved model file
@ -644,6 +647,39 @@ class Function(cntk_py.Function):
'''
return super(Function, self).restore_model(filename)
def restore_model(self, filename): # legacy name
return self.restore(filename)
@staticmethod
@typemap
def load(filename, device=None):
'''
Load the model in ``filename``, that has been saved using
:func:`~cntk.ops.functions.Function.save`.
Args:
filename (str): filename to load the model from
device (:class:`~cntk.device.DeviceDescriptor`, default is the default device):
instance of DeviceDescriptor
Returns:
root node
'''
if not device:
device = DeviceDescriptor.use_default_device()
return cntk_py.Function.load_model(filename, device)
@typemap
def load_model(filename, device=None):
'''
Alias for :func:`~cntk.ops.functions.Function.load`.
'''
return Function.load(filename, device)
@typemap
def save_model(model, filename): # legacy name
return model.save(filename)
class UserFunction(Function):
'''
@ -770,22 +806,3 @@ class UserFunction(Function):
Returns the operator name.
'''
return 'UserFunction'
@typemap
def load_model(filename, device=None):
'''
Load the model in ``filename``, that has been saved using
:func:`~cntk.ops.functions.Function.save_model`.
Args:
filename (str): filename to load the model from
device (:class:`~cntk.device.DeviceDescriptor`, default is the default device):
instance of DeviceDescriptor
Returns:
root node
'''
if not device:
device = DeviceDescriptor.use_default_device()
return cntk_py.Function.load_model(filename, device)

Просмотреть файл

@ -10,6 +10,7 @@ from cntk.ops import *
from cntk.debug import save_as_legacy_model
from cntk.ops.functions import load_model
# TODO: a test for restore_model?
def test_load_save_constant(tmpdir):
c = constant(value=[1,3])
@ -20,19 +21,19 @@ def test_load_save_constant(tmpdir):
assert np.allclose(result, expected)
filename = str(tmpdir / 'c_plus_c.mod')
root_node.save_model(filename)
root_node.save(filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
loaded_result = loaded_node.eval()
assert np.allclose(loaded_result, expected)
filename = filename + '.legacy'
save_as_legacy_model(root_node, filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
loaded_result = loaded_node.eval()
assert np.allclose(loaded_result, expected)
def test_load_save_input(tmpdir):
def test_load_save_input_legacy_names(tmpdir):
i1 = input_variable((1,2), name='i1')
root_node = abs(i1)
input1 = [[[-1,2]]]
@ -68,9 +69,9 @@ def test_load_save_inputs(tmpdir):
assert np.allclose(result, expected)
filename = str(tmpdir / 'i_plus_i_0.mod')
root_node.save_model(filename)
root_node.save(filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
# Test specifying the input nodes by name
loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
@ -78,7 +79,7 @@ def test_load_save_inputs(tmpdir):
filename = filename + '.legacy'
save_as_legacy_model(root_node, filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
assert np.allclose(loaded_result, expected)
@ -92,9 +93,9 @@ def test_load_save_unique_input(tmpdir):
assert np.allclose(result, expected)
filename = str(tmpdir / 'i_plus_0.mod')
root_node.save_model(filename)
root_node.save(filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
# Test specifying the only value for an unique input
loaded_result = loaded_node.eval(input1)
@ -102,6 +103,6 @@ def test_load_save_unique_input(tmpdir):
filename = filename + '.legacy'
save_as_legacy_model(root_node, filename)
loaded_node = load_model(filename)
loaded_node = Function.load(filename)
loaded_result = loaded_node.eval(input1)
assert np.allclose(loaded_result, expected)