add get_rank
This commit is contained in:
Родитель
3cbc1676ed
Коммит
5eedf54171
|
@ -4,7 +4,7 @@
|
|||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
from ..utils import wrap_numpy_arrays
|
||||
from ..utils import wrap_numpy_arrays, get_rank
|
||||
|
||||
################################################################################
|
||||
# convolution ops
|
||||
|
@ -833,7 +833,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
|
|||
from cntk.ops.cntk1 import FutureValue
|
||||
op = FutureValue(shape, x, time_step, default_hidden_activation, name = name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
|
||||
|
@ -868,7 +868,7 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
|
|||
from cntk.ops.cntk1 import PastValue
|
||||
op = PastValue(shape, x, time_step, default_hidden_activation, name = name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
################################################################################
|
||||
|
@ -903,7 +903,7 @@ def reshape(x, shape, name=None):
|
|||
shape = tuple(reversed(shape))
|
||||
op = NewReshape(x, shape, 0, 0, name = name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
def transpose_dimensions(x, axis1, axis2, name=None):
|
||||
|
@ -1079,7 +1079,7 @@ def input(shape, dynamic_axis='', name=None):
|
|||
shape = tuple(reversed(shape))
|
||||
op = Input(shape, dynamicAxis=dynamic_axis, name=name)
|
||||
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
def sparse_input_numpy(indices, values, shape, alias=None, dynamic_axis='', name=None):
|
||||
|
@ -1153,7 +1153,7 @@ def sparse_input(shape, dynamic_axis='', name=None):
|
|||
# cntk uses column major, thus we reverse the shape
|
||||
shape = tuple(reversed(shape))
|
||||
op = SparseInput(shape, dynamicAxis=dynamic_axis, name=name)
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
|
||||
|
@ -1190,7 +1190,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
|
|||
learningRateMultiplier=learning_rate_multiplier,
|
||||
name=name)
|
||||
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
op.rank = get_rank(shape)
|
||||
return op
|
||||
|
||||
"""
|
||||
|
@ -1232,7 +1232,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
|
|||
init='fromLiteral',
|
||||
initFromLiteral=s.getvalue().decode())
|
||||
|
||||
op.rank = 0 if np.isscalar(param_shape) else len(param_shape)
|
||||
op.rank = get_rank(param_shape)
|
||||
return op
|
||||
|
||||
def constant(value, name=None):
|
||||
|
|
|
@ -225,6 +225,24 @@ def get_temp_filename(directory=None):
|
|||
|
||||
return tf.name
|
||||
|
||||
def get_rank(shape):
|
||||
'''
|
||||
computes the rank of a tensor.
|
||||
|
||||
Args:
|
||||
shape: it is either a tuple or an integer.
|
||||
|
||||
Returns: the rank of the tensor.
|
||||
|
||||
'''
|
||||
if np.isscalar(shape):
|
||||
if shape == 1:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
else:
|
||||
return len(shape)
|
||||
|
||||
def wrap_numpy_arrays(node):
|
||||
'''
|
||||
for a given computation node, wrapes its tensor inputs that are numpy arrays
|
||||
|
|
Загрузка…
Ссылка в новой задаче