fix is tensor test
This commit is contained in:
Родитель
9367d6acbd
Коммит
3b2f8ed56d
|
@ -940,7 +940,7 @@ def input_numpy(value, alias=None, dynamic_axis='', name=None):
|
|||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
from .. import utils
|
||||
if utils.is_tensor_list(value) or utils.is_tensor(value):
|
||||
if utils.is_tensor(value) or utils.is_tensor(value):
|
||||
value = np.asarray(value)
|
||||
if dynamic_axis:
|
||||
cntk_shape = value[0].shape[1:]
|
||||
|
|
|
@ -189,7 +189,7 @@ def is_tensor(data):
|
|||
if isinstance(data, np.ndarray):
|
||||
return True
|
||||
|
||||
if not isinstance(data[0], list):
|
||||
if not (isinstance(data[0], list) or isinstance(data[0], np.ndarray)):
|
||||
return False
|
||||
|
||||
data = data[0]
|
||||
|
|
|
@ -43,12 +43,8 @@ def eval(node):
|
|||
# One param needs to be an Input() node. This will be fixed in
|
||||
# CNTK soon, so that we can remove this workaround and evaluate a
|
||||
# network with no inputs.
|
||||
if first:
|
||||
if not isinstance(val, list):
|
||||
# inputs have the outmost dimension for sequence dimension
|
||||
val = [val]
|
||||
|
||||
ir = input_numpy(val, alias=p, name=p)
|
||||
if first:
|
||||
ir = input_numpy([val], alias=p, name=p)
|
||||
setattr(node, p, ir)
|
||||
first = False
|
||||
else:
|
||||
|
@ -56,5 +52,5 @@ def eval(node):
|
|||
else:
|
||||
if isinstance(val, _InputComputationNodeBase) and first:
|
||||
first = False
|
||||
ctx.clean_up=False
|
||||
|
||||
return ctx.eval(node)
|
||||
|
|
Загрузка…
Ссылка в новой задаче