This commit is contained in:
jeanfad 2016-05-24 15:47:29 +02:00
Родитель 9367d6acbd
Коммит 3b2f8ed56d
3 изменённых файлов: 5 добавлений и 9 удалений

Просмотреть файл

@ -940,7 +940,7 @@ def input_numpy(value, alias=None, dynamic_axis='', name=None):
:class:`cntk.graph.ComputationNode`
'''
from .. import utils
if utils.is_tensor_list(value) or utils.is_tensor(value):
if utils.is_tensor(value) or utils.is_tensor(value):
value = np.asarray(value)
if dynamic_axis:
cntk_shape = value[0].shape[1:]

Просмотреть файл

@ -189,7 +189,7 @@ def is_tensor(data):
if isinstance(data, np.ndarray):
return True
if not isinstance(data[0], list):
if not (isinstance(data[0], list) or isinstance(data[0], np.ndarray)):
return False
data = data[0]

Просмотреть файл

@ -43,12 +43,8 @@ def eval(node):
# One param needs to be an Input() node. This will be fixed in
# CNTK soon, so that we can remove this workaround and evaluate a
# network with no inputs.
if first:
if not isinstance(val, list):
# inputs have the outmost dimension for sequence dimension
val = [val]
ir = input_numpy(val, alias=p, name=p)
if first:
ir = input_numpy([val], alias=p, name=p)
setattr(node, p, ir)
first = False
else:
@ -56,5 +52,5 @@ def eval(node):
else:
if isinstance(val, _InputComputationNodeBase) and first:
first = False
ctx.clean_up=False
return ctx.eval(node)