example with numpy data
This commit is contained in:
Родитель
ea012e808a
Коммит
5f9778b841
|
@ -17,11 +17,11 @@ import numpy as np
|
|||
DATATYPE = np.float32
|
||||
|
||||
class Trainer(Trainer):
|
||||
"""
|
||||
'''
|
||||
Trainer to train the specified 'model' with the specified `training_loss` as the training criterion,
|
||||
the specified `evaluation_function` as the criterion for evaluating the trained model's quality, and using the specified set
|
||||
of `parameters` for updating the model's parameters using computed gradients.
|
||||
"""
|
||||
'''
|
||||
def __init__(self, model, loss_function, eval_function, parameters):
|
||||
if isinstance(model, Variable):
|
||||
model = model.owner
|
||||
|
@ -30,3 +30,35 @@ class Trainer(Trainer):
|
|||
if isinstance(eval_function, Variable):
|
||||
eval_function = eval_function.owner
|
||||
super(Trainer, self).__init__(model, loss_function, eval_function, parameters)
|
||||
|
||||
def train_minibatch(self, arguments, device=None):
|
||||
'''
|
||||
Optimize model parameters using the specified 'arguments' minibatch of training samples.
|
||||
Returns false if all parameter learners indicate end of learning (through their Update method's return value).
|
||||
Args:
|
||||
arguments (dict): map from input variables to the data, data should be either numpy
|
||||
arrays or cntk.Value instances returned by a minibatch source
|
||||
device (:class:`cntk.DeviceDescriptor`): the device descriptor that contains the type and id of the device
|
||||
Returns:
|
||||
bool
|
||||
'''
|
||||
if not device:
|
||||
device=DeviceDescriptor.use_default_device()
|
||||
arguments = sanitize_var_map(arguments, add_batch_axis=True)
|
||||
super(Trainer, self).train_minibatch(arguments, device)
|
||||
|
||||
def test_minibatch(self, arguments, device=None):
|
||||
'''
|
||||
Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer
|
||||
Returns the average evaluation criterion value per sample for the tested minibatch of samples
|
||||
Args:
|
||||
arguments (dict): map from input variables to the data, data should be either numpy
|
||||
arrays or cntk.Value instances returned by a minibatch source
|
||||
device (:class:`cntk.DeviceDescriptor`): the device descriptor that contains the type and id of the device
|
||||
Returns:
|
||||
float
|
||||
'''
|
||||
if not device:
|
||||
device=DeviceDescriptor.use_default_device()
|
||||
arguments = sanitize_var_map(arguments, add_batch_axis=True)
|
||||
super(Trainer, self).train_minibatch(arguments, device)
|
||||
|
|
|
@ -307,7 +307,7 @@ def pad_to_dense(batch):
|
|||
Z[idx, :len(seq)] += seq
|
||||
return Z
|
||||
|
||||
def sanitize_batch(batch, data_type, dev):
|
||||
def sanitize_batch(batch, data_type=None, dev=None):
|
||||
"""
|
||||
Convert to Value with `data_type`. If the samples in `batch` have different
|
||||
sequence lengths, pad them to max sequence length and create a mask.
|
||||
|
@ -347,7 +347,7 @@ def sanitize_batch(batch, data_type, dev):
|
|||
batch = pad_to_dense(batch)
|
||||
|
||||
# If it still is not an NumPy array, try brute force...
|
||||
if not isinstance(batch, np.ndarray) or batch.dtype != data_type:
|
||||
if not isinstance(batch, np.ndarray):
|
||||
batch = np.asarray(batch, dtype=data_type)
|
||||
|
||||
'''
|
||||
|
@ -371,7 +371,7 @@ def sanitize_batch(batch, data_type, dev):
|
|||
|
||||
return value
|
||||
|
||||
def sanitize_var_map(input_map, precision_numpy, device):
|
||||
def sanitize_var_map(input_map, precision_numpy=None, device=None, add_batch_axis=False):
|
||||
'''
|
||||
Sanitizes a dictionary of `Variable`s to input data such that it can be
|
||||
handed off to the `Forward` method.
|
||||
|
@ -380,6 +380,7 @@ def sanitize_var_map(input_map, precision_numpy, device):
|
|||
input_map (`dict`): `Variable` to input (NumPy array or simple list of lists)
|
||||
precision_numpy : np.float32 or np.float64
|
||||
device: CNTK DeviceDescriptor
|
||||
add_batch_axis (bool): if the data does not have the batch axis, add it before creating NDArrayView
|
||||
|
||||
Returns:
|
||||
`dict` that maps variables to sanitized batches
|
||||
|
@ -387,16 +388,22 @@ def sanitize_var_map(input_map, precision_numpy, device):
|
|||
var_map = {}
|
||||
if input_map:
|
||||
for var, batch in input_map.items():
|
||||
if isinstance(batch, np.ndarray):
|
||||
if batch.dtype not in (np.float32, np.float64):
|
||||
raise ValueError('only float32 and float64 are supported')
|
||||
batch = sanitize_batch(batch, precision_numpy, device)
|
||||
else:
|
||||
if is_tensor(batch):
|
||||
batch = np.asarray(batch, dtype=precision_numpy)
|
||||
batch = create_Value_from_NumPy(batch, device)
|
||||
else:
|
||||
from ..cntk_py import Value
|
||||
if not isinstance(batch, Value):
|
||||
if add_batch_axis:
|
||||
batch = [batch]
|
||||
if isinstance(batch, np.ndarray):
|
||||
if batch.dtype == np.int:
|
||||
batch = batch.astype(np.float32)
|
||||
if batch.dtype not in (np.float32, np.float64):
|
||||
raise ValueError('only float32 and float64 are supported')
|
||||
batch = sanitize_batch(batch, precision_numpy, device)
|
||||
else:
|
||||
if is_tensor(batch):
|
||||
batch = np.asarray(batch, dtype=precision_numpy)
|
||||
batch = create_Value_from_NumPy(batch, device)
|
||||
else:
|
||||
batch = sanitize_batch(batch, precision_numpy, device)
|
||||
|
||||
var_map[var] = batch
|
||||
|
||||
|
|
|
@ -11,6 +11,19 @@ from cntk import learning_rates_per_sample, DeviceDescriptor, Trainer, sgd_learn
|
|||
from cntk.ops import input_variable, cross_entropy_with_softmax, combine, classification_error, sigmoid
|
||||
from examples.common.nn import fully_connected_classifier_net, print_training_progress
|
||||
|
||||
def generate_random_data(sample_dize, feature_dim, num_classes):
|
||||
# Create synthetic data using NumPy.
|
||||
Y = np.random.randint(size=(sample_dize, 1), low=0, high=num_classes)
|
||||
|
||||
# Make sure that the data is separable
|
||||
X = (np.random.randn(sample_dize, feature_dim)+3) * (Y+1)
|
||||
X = X.astype(np.float32)
|
||||
# converting class 0 into the vector "1 0 0",
|
||||
# class 1 into vector "0 1 0", ...
|
||||
class_ind = [Y==class_number for class_number in range(num_classes)]
|
||||
Y = np.asarray(np.hstack(class_ind), dtype=np.float32)
|
||||
return X, Y
|
||||
|
||||
# Creates and trains a feedforward classification model
|
||||
def ffnet():
|
||||
input_dim = 2
|
||||
|
@ -28,17 +41,6 @@ def ffnet():
|
|||
ce = cross_entropy_with_softmax(netout, label)
|
||||
pe = classification_error(netout, label)
|
||||
|
||||
rel_path = r"../../../../Examples/Other/Simple2d/Data/SimpleDataTrain_cntk_text.txt"
|
||||
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), rel_path)
|
||||
feature_stream_name = 'features'
|
||||
labels_stream_name = 'labels'
|
||||
|
||||
mb_source = text_format_minibatch_source(path, [
|
||||
StreamConfiguration( feature_stream_name, input_dim ),
|
||||
StreamConfiguration( labels_stream_name, num_output_classes)])
|
||||
features_si = mb_source.stream_info(feature_stream_name)
|
||||
labels_si = mb_source.stream_info(labels_stream_name)
|
||||
|
||||
# Instantiate the trainer object to drive the model training
|
||||
lr = learning_rates_per_sample(0.02)
|
||||
trainer = Trainer(netout, ce, pe, [sgd_learner(netout.owner.parameters(), lr)])
|
||||
|
@ -50,11 +52,9 @@ def ffnet():
|
|||
num_minibatches_to_train = (num_samples_per_sweep * num_sweeps_to_train_with) / minibatch_size
|
||||
training_progress_output_freq = 20
|
||||
for i in range(0, int(num_minibatches_to_train)):
|
||||
mb = mb_source.get_next_minibatch(minibatch_size)
|
||||
|
||||
features, labels = generate_random_data(minibatch_size, input_dim, num_output_classes)
|
||||
# Specify the mapping of input variables in the model to actual minibatch data to be trained with
|
||||
arguments = {input : mb[features_si].m_data, label : mb[labels_si].m_data}
|
||||
trainer.train_minibatch(arguments)
|
||||
trainer.train_minibatch({input : features, label : labels})
|
||||
print_training_progress(trainer, i, training_progress_output_freq)
|
||||
|
||||
if __name__=='__main__':
|
Загрузка…
Ссылка в новой задаче