jeanfad 2016-06-02 11:11:38 +02:00
Parent 3cbc1676ed
Commit 5eedf54171
2 changed files with 26 additions and 8 deletions

View file

@@ -4,7 +4,7 @@
# ==============================================================================
import numpy as np
-from ..utils import wrap_numpy_arrays
+from ..utils import wrap_numpy_arrays, get_rank
################################################################################
# convolution ops
@@ -833,7 +833,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
    from cntk.ops.cntk1 import FutureValue
    op = FutureValue(shape, x, time_step, default_hidden_activation, name = name)
    wrap_numpy_arrays(op)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op

def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
@@ -868,7 +868,7 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
    from cntk.ops.cntk1 import PastValue
    op = PastValue(shape, x, time_step, default_hidden_activation, name = name)
    wrap_numpy_arrays(op)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op
################################################################################
@@ -903,7 +903,7 @@ def reshape(x, shape, name=None):
    shape = tuple(reversed(shape))
    op = NewReshape(x, shape, 0, 0, name = name)
    wrap_numpy_arrays(op)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op

def transpose_dimensions(x, axis1, axis2, name=None):
@@ -1079,7 +1079,7 @@ def input(shape, dynamic_axis='', name=None):
    shape = tuple(reversed(shape))
    op = Input(shape, dynamicAxis=dynamic_axis, name=name)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op

def sparse_input_numpy(indices, values, shape, alias=None, dynamic_axis='', name=None):
@@ -1153,7 +1153,7 @@ def sparse_input(shape, dynamic_axis='', name=None):
    # cntk uses column major, thus we reverse the shape
    shape = tuple(reversed(shape))
    op = SparseInput(shape, dynamicAxis=dynamic_axis, name=name)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op

def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
@@ -1190,7 +1190,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
        learningRateMultiplier=learning_rate_multiplier,
        name=name)
-    op.rank = 0 if np.isscalar(shape) else len(shape)
+    op.rank = get_rank(shape)
    return op
"""
@@ -1232,7 +1232,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
        init='fromLiteral',
        initFromLiteral=s.getvalue().decode())
-    op.rank = 0 if np.isscalar(param_shape) else len(param_shape)
+    op.rank = get_rank(param_shape)
    return op

def constant(value, name=None):
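
The change is the same in every hunk above: the inline rank computation `0 if np.isscalar(shape) else len(shape)` is replaced by the new `get_rank` helper imported from `..utils`. Note that the two expressions are not equivalent for scalar shapes other than 1. A minimal sketch of the difference, mirroring the helper as added in this commit; the `old_rank` lambda is only for comparison and is not part of the change:

import numpy as np

def get_rank(shape):
    # mirrors the helper added to cntk utils in this commit
    if np.isscalar(shape):
        return 0 if shape == 1 else 1
    return len(shape)

old_rank = lambda shape: 0 if np.isscalar(shape) else len(shape)

print(get_rank(1), old_rank(1))            # 0 0 -- a bare 1 stays rank 0
print(get_rank(5), old_rank(5))            # 1 0 -- any other bare integer is now rank 1
print(get_rank((2, 3)), old_rank((2, 3)))  # 2 2 -- tuple shapes are unchanged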

View file

@@ -225,6 +225,24 @@ def get_temp_filename(directory=None):
    return tf.name
+def get_rank(shape):
+    '''
+    Computes the rank of a tensor given its shape.
+
+    Args:
+        shape: the shape of the tensor, either a tuple or an integer
+
+    Returns: the rank of the tensor
+    '''
+    if np.isscalar(shape):
+        if shape == 1:
+            return 0
+        else:
+            return 1
+    else:
+        return len(shape)
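
One consequence of the definition above worth noting: a bare integer and the corresponding 1-tuple are ranked differently. A few illustrative assertions, not part of the commit, assuming the helper is importable as `cntk.utils.get_rank`:

from cntk.utils import get_rank

assert get_rank(1) == 0        # a bare 1 denotes a true scalar
assert get_rank((1,)) == 1     # a 1-tuple is a rank-1 tensor
assert get_rank(5) == 1        # any other bare integer is treated as a vector
assert get_rank((2, 3)) == 2   # tuples report their length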
def wrap_numpy_arrays(node):
    '''
    For a given computation node, wraps its tensor inputs that are numpy arrays