Integrate wilrich/logregacc into master

This commit is contained in:
Project Philly 2016-05-10 05:18:02 -07:00
Parent 3f4a000efc 6130ebacf5
Commit 9541f38ae3
3 changed files: 79 additions and 37 deletions

View file

@@ -66,7 +66,7 @@ def train_eval_logistic_regression_from_file(criterion_name=None,
 def test_logistic_regression_from_file(device_id):
     result = train_eval_logistic_regression_from_file('crit_node', 'eval_node', device_id)
-    TOLERANCE_ABSOLUTE = 1E-02
+    TOLERANCE_ABSOLUTE = 1E-06
     assert np.allclose(result['perplexity'], 1.55153792, atol=TOLERANCE_ABSOLUTE)
     assert np.allclose(result['crit_node'], 0.43924664, atol=TOLERANCE_ABSOLUTE)
     assert np.allclose(result['eval_node'], 3.26340137, atol=TOLERANCE_ABSOLUTE)
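
Note: the expected values above are internally consistent: the 'perplexity' equals
exp of the 'crit_node' cross-entropy value, which suggests CNTK reports perplexity
as e^(average criterion). A quick NumPy sanity check (illustrative, not part of
the commit):

    import numpy as np
    # exp(criterion) should reproduce the reported perplexity
    assert np.isclose(np.exp(0.43924664), 1.55153792, atol=1e-5)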

View file

@@ -16,28 +16,47 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
 import numpy as np
 import cntk as C
 
+train_N = 1000
+test_N = 500
+
+# Mapping 2 numbers to 3 classes
+feature_dim = 2
+num_classes = 3
+
+def synthetic_data(N, feature_dim, num_classes):
+    # Create synthetic data using NumPy.
+    Y = np.random.randint(size=(N, 1), low=0, high=num_classes)
+
+    # Make sure that the data is separable
+    X = (np.random.randn(N, feature_dim)+3) * (Y+1)
+
+    # converting class 0 into the vector "1 0 0",
+    # class 1 into vector "0 1 0", ...
+    class_ind = [Y==class_number for class_number in range(num_classes)]
+    Y = np.asarray(np.hstack(class_ind), dtype=int)
+    return X, Y
+
 def train_eval_logistic_regression_with_numpy(criterion_name=None,
         eval_name=None, device_id=-1):
 
     # for repro and tests :-)
     np.random.seed(1)
 
-    N = 500
-    d = 250
+    train_X, train_y = synthetic_data(train_N, feature_dim, num_classes)
+    test_X, test_y = synthetic_data(test_N, feature_dim, num_classes)
 
-    # create synthetic data using numpy
-    X = np.random.randn(N, d)
-    Y = np.random.randint(size=(N, 1), low=0, high=2)
-    Y = np.hstack((Y, 1-Y))
-
-    # set up the training data for CNTK
-    x = C.input_numpy(X)
-    y = C.input_numpy(Y)
+    # Set up the training data for CNTK. Before writing the CNTK configuration,
+    # the data will be attached to X.reader.batch and y.reader.batch and then
+    # serialized.
+    X = C.input_numpy(train_X)
+    y = C.input_numpy(train_y)
 
     # define our network -- one weight tensor and a bias
-    W = C.parameter(value=np.zeros(shape=(2, d)))
-    b = C.parameter(value=np.zeros(shape=(2, 1)))
-    out = C.times(W, x) + b
+    W = C.parameter(value=np.zeros(shape=(num_classes, feature_dim)))
+    b = C.parameter(value=np.zeros(shape=(num_classes, 1)))
+    out = C.times(W, X) + b
 
     ce = C.cross_entropy_with_softmax(y, out)
     ce.tag = 'criterion'
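
Note: the class_ind/hstack lines in synthetic_data convert integer labels to
one-hot rows without an explicit loop. A minimal standalone illustration
(not part of the commit):

    import numpy as np
    Y = np.array([[0], [2], [1]])           # integer labels, shape (N, 1)
    class_ind = [Y == c for c in range(3)]  # three (N, 1) boolean masks
    one_hot = np.asarray(np.hstack(class_ind), dtype=int)
    print(one_hot)                          # [[1 0 0] [0 0 1] [0 1 0]]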
@@ -47,14 +66,18 @@ def train_eval_logistic_regression_with_numpy(criterion_name=None,
     eval.tag = 'eval'
     eval.name = eval_name
 
-    my_sgd = C.SGDParams(epoch_size=0, minibatch_size=25, learning_rates_per_mb=0.1, max_epochs=3)
+    my_sgd = C.SGDParams(epoch_size=0, minibatch_size=25,
+            learning_rates_per_mb=0.1, max_epochs=3)
 
-    with C.LocalExecutionContext('logreg') as ctx:
+    with C.LocalExecutionContext('logreg', clean_up=False) as ctx:
         ctx.device_id = device_id
 
         ctx.train(
             root_nodes=[ce,eval],
             training_params=my_sgd)
 
+        # For testing, we attach the test data to the input nodes.
+        X.reader.batch, y.reader.batch = test_X, test_y
+
         result = ctx.test(root_nodes=[ce,eval])
 
         return result
@@ -63,10 +86,11 @@ def test_logistic_regression_with_numpy(device_id):
     result = train_eval_logistic_regression_with_numpy('crit_node',
             'eval_node', device_id)
 
-    TOLERANCE_ABSOLUTE = 1E-02
-    assert np.allclose(result['perplexity'], 1.55057073, atol=TOLERANCE_ABSOLUTE)
-    assert np.allclose(result['crit_node'], 0.43862308, atol=TOLERANCE_ABSOLUTE)
-    assert np.allclose(result['eval_node'], 1.16664551, atol=TOLERANCE_ABSOLUTE)
+    TOLERANCE_ABSOLUTE = 1E-06
+    print(result)
+    assert np.allclose(result['perplexity'], 2.33378225, atol=TOLERANCE_ABSOLUTE)
+    assert np.allclose(result['crit_node'], 0.84749023, atol=TOLERANCE_ABSOLUTE)
+    assert np.allclose(result['eval_node'], 2.69121655, atol=TOLERANCE_ABSOLUTE)
 
 if __name__ == "__main__":
     print(train_eval_logistic_regression_with_numpy())
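
Note: the tolerance can drop from 1E-02 to 1E-06 presumably because
np.random.seed(1) makes synthetic_data fully deterministic, so repeated test
runs see bit-identical inputs. A quick determinism check (illustrative only):

    import numpy as np
    def sample():
        np.random.seed(1)
        return np.random.randn(4, 2)
    assert np.array_equal(sample(), sample())  # same seed => same data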

View file

@@ -24,9 +24,6 @@ class AbstractReader(with_metaclass(ABCMeta)):
     def __ne__(self, x): return x is not self
 
-    def _to_aggregate_form():
-        pass
-
 class UCIFastReader(AbstractReader):
@@ -133,16 +130,29 @@ class CNTKTextFormatReader(AbstractReader):
         if randomize is None:
             randomize = 'none'
         self.randomize = randomize.lower()
-        assert self.randomize in ['auto', 'none']
+        if not self.randomize in ['auto', 'none']:
+            raise ValueError('parameter "randomize" can be only "auto", ' +
+                    '"none", or None. You gave: %s'%randomize)
         self.skip_sequence_ids = bool(skip_sequence_ids)
         self.max_errors = int(max_errors)
-        assert self.max_errors >= 0
+        if not self.max_errors >= 0:
+            raise ValueError('parameter "max_errors" has to be an integer ' +
+                    'greater than or equal to 0. You gave: %s'%max_errors)
         self.trace_level = int(trace_level)
-        assert self.trace_level in [0,1,2]
+        if not self.trace_level in [0,1,2]:
+            raise ValueError('parameter "trace_level" has to be an integer ' +
+                    'from [0, 1, 2]. You gave: %s'%str(trace_level))
         self.chunk_size_in_bytes = int(chunk_size_in_bytes)
-        assert self.chunk_size_in_bytes > 0
+        if not self.chunk_size_in_bytes > 0:
+            raise ValueError('parameter "chunk_size_in_bytes" has to be an integer ' +
+                    'greater than zero. You gave: %s'%str(chunk_size_in_bytes))
         self.num_chunks_to_cache = int(num_chunks_to_cache)
-        assert self.chunk_size_in_bytes >= 0
+        if self.chunk_size_in_bytes < 0:
+            raise ValueError('parameter "chunk_size_in_bytes" has to be an integer ' +
+                    'greater than or equal to zero. You gave: %s'\
+                    %str(self.chunk_size_in_bytes))
 
     def map(self, node_or_name, **kw):
         '''
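
Note: swapping the asserts for explicit ValueError raises hardens the checks:
assert statements are stripped entirely when Python runs with -O, and the error
message now carries the offending value. A minimal sketch of the pattern used
above (hypothetical helper name):

    def check_trace_level(trace_level):
        trace_level = int(trace_level)
        if trace_level not in [0, 1, 2]:
            raise ValueError('parameter "trace_level" has to be an integer '
                             'from [0, 1, 2]. You gave: %s' % str(trace_level))
        return trace_level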
@@ -211,10 +221,14 @@ class CNTKTextFormatReader(AbstractReader):
         '''
         if input_map.has_unmapped():
+            if len(input_map.node_map) > 0:
+                raise ValueError('you cannot have inputs initialized with '+
+                        'NumPy arrays together with inputs that are ' +
+                        ' initialized with a custom reader')
+
             input_map._serialize_unmapped_nodes(
                 input_map.unmapped_nodes, self.filename)
 
         for node_or_name, param_dict in input_map.node_map.items():
             if (isinstance(node_or_name, ComputationNode)):
                 name = node_or_name.name
@@ -477,10 +491,16 @@ class InputMap(object):
         from .utils import get_temp_filename
         filename = get_temp_filename(get_context().directory)
 
-        assert not self.node_map
+        if len(self.node_map) > 0:
+            raise ValueError('you cannot have inputs initialized with '+
+                    'NumPy arrays together with inputs that are ' +
+                    ' initialized with a custom reader')
+
         self._serialize_unmapped_nodes(filename)
 
-        r = CNTKTextFormatReader(filename)
+        # All the data we got, was through NumPy. In this case, we assume
+        # that all the required randomization has happened already.
+        r = CNTKTextFormatReader(filename, randomize=None)
 
         return r._to_config_description(self)
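
Note: this ties back to the constructor change above; randomize=None is
normalized to 'none', so data that arrived through NumPy (assumed already
shuffled) is serialized without further randomization. Expected behavior under
the validation shown earlier (sketch, assuming this module's reader class):

    r = CNTKTextFormatReader(filename, randomize=None)    # self.randomize == 'none'
    r = CNTKTextFormatReader(filename, randomize='auto')  # self.randomize == 'auto'
    r = CNTKTextFormatReader(filename, randomize='typo')  # raises ValueError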
@@ -509,8 +529,9 @@ class InputMap(object):
         sample_sizes = collections.defaultdict(list)
         used_aliases = set()
         for node in self.unmapped_nodes:
-            assert node._is_input()
-            assert isinstance(node.reader, LazyInputReader)
+            is_lazy_input = isinstance(node.reader, LazyInputReader)
+            if not (node._is_input() and is_lazy_input):
+                raise ValueError('expected NumPy input, but got "%s"'%str(node))
             l = node.reader
@@ -553,10 +574,7 @@ class InputMap(object):
             value_shape = shapes_in_tensor.pop()
             l.shape = value_shape if value_shape else (1,)
 
-            assert node not in self.node_map
-            self.node_map[node] = {
-                'alias': l.input_alias,
-            }
+            self.node_map[node] = { 'alias': l.input_alias }
 
         # make sure all inputs have same sample size
         if len(sample_sizes) != 1: