adjust some unit tests
This commit is contained in:
Родитель
1d06073fcd
Коммит
85e5bceb4d
|
@ -36,7 +36,7 @@ def cross_entropy_with_softmax(target_vector, output_vector, name=None):
|
|||
over the labels.
|
||||
output_vector: the unscaled computed output values from the network
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk1 import CrossEntropyWithSoftmax
|
||||
|
@ -63,7 +63,7 @@ def square_error(target_matrix, output_matrix, name=None):
|
|||
hot bit corresponds to the label index
|
||||
output_matrix: the output values from the network
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk1 import SquareError
|
||||
|
@ -93,7 +93,7 @@ def error_prediction(target_vector, output_vector, name=None):
|
|||
label index
|
||||
output_vector: the output values from the network
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import ErrorPrediction
|
||||
|
@ -121,7 +121,7 @@ def less(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Less
|
||||
|
@ -145,7 +145,7 @@ def equal(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Equal
|
||||
|
@ -169,7 +169,7 @@ def greater(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Greater
|
||||
|
@ -193,7 +193,7 @@ def greater_equal(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import GreaterEqual
|
||||
|
@ -217,7 +217,7 @@ def not_equal(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import NotEqual
|
||||
|
@ -241,7 +241,7 @@ def less_equal(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import LessEqual
|
||||
|
@ -271,7 +271,7 @@ def plus(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Plus
|
||||
|
@ -298,7 +298,7 @@ def minus(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -326,7 +326,7 @@ def element_times(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import ElementTimes
|
||||
|
@ -356,7 +356,7 @@ def element_divide(left, right, name=None):
|
|||
left: left side tensor
|
||||
right: right side tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import ElementDivide
|
||||
|
@ -401,20 +401,20 @@ def times(left, right, output_rank=1, name=None):
|
|||
into matrices, perform the operation and then reshape back (explode the axes)
|
||||
name: the name of the node in the network
|
||||
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Times
|
||||
# CNTK uses column vectors and column major representation, thus we reverse
|
||||
# params
|
||||
op = Times(right, left, outputRank=output_rank, name=name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = op._.rank + op.y.rank - 2
|
||||
#wrap_numpy_arrays(op)
|
||||
op.rank = op.x.rank + op.y.rank - 2
|
||||
return op
|
||||
|
||||
def identity(x, name=None):
|
||||
"""
|
||||
The identity function. It op =s an identical tensor to the input tensor `x`:
|
||||
The identity function. It returns an identical tensor to the input tensor `x`:
|
||||
|
||||
:math:`pass_tensor(x) = x`
|
||||
|
||||
|
@ -424,7 +424,7 @@ def identity(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Identity
|
||||
|
@ -461,7 +461,7 @@ def floor(arg, name=None):
|
|||
Args:
|
||||
arg: input tensor
|
||||
name: the name of the node in the network (optional)
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Floor
|
||||
|
@ -486,7 +486,7 @@ def ceil(arg, name=None):
|
|||
Args:
|
||||
arg: input tensor
|
||||
name: the name of the node in the network (optional)
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Ceil
|
||||
|
@ -521,7 +521,7 @@ def round(arg, name=None):
|
|||
Args:
|
||||
arg: input tensor
|
||||
name: the name of the node in the network (optional)
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Round
|
||||
|
@ -558,7 +558,7 @@ def clip(x, min_value, max_value, name=None):
|
|||
min_value: the minimum value to clip element values to
|
||||
max_value: the maximum value to clip element values to
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Clip
|
||||
|
@ -580,7 +580,7 @@ def relu(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Relu
|
||||
|
@ -603,7 +603,7 @@ def sigmoid(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Sigmoid
|
||||
|
@ -625,7 +625,7 @@ def tanh(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Tanh
|
||||
|
@ -652,7 +652,7 @@ def softmax(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Softmax
|
||||
|
@ -673,7 +673,7 @@ def exp(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Exp
|
||||
|
@ -692,11 +692,11 @@ def log(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
|
||||
Note:
|
||||
CNTK op =s -85.1 for log(x) if `x` is negative or zero. The reason is that
|
||||
CNTK returns -85.1 for log(x) if `x` is negative or zero. The reason is that
|
||||
it uses 1e-37 (whose natural logarithm is -85.1) as the smallest float
|
||||
number for `log`, because this is the only guaranteed precision across
|
||||
platforms. This will be changed to op = `NaN` and `-inf`.
|
||||
|
@ -719,11 +719,11 @@ def sqrt(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
|
||||
Note:
|
||||
CNTK op =s zero for sqrt of negative nubmers, this will be changed to
|
||||
CNTK returns zero for sqrt of negative nubmers, this will be changed to
|
||||
op = NaN
|
||||
"""
|
||||
from cntk.ops.cntk2 import Sqrt
|
||||
|
@ -742,7 +742,7 @@ def square(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Square
|
||||
|
@ -763,7 +763,7 @@ def abs(x, name=None):
|
|||
|
||||
Args:
|
||||
x: any :class:`cntk.graph.ComputationNode` that outputs a tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Abs
|
||||
|
@ -788,13 +788,13 @@ def cond(flag, value_if_true, value_if_false, name=None):
|
|||
value_if_true: tensor
|
||||
value_if_false: tensor
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk1 import If
|
||||
op = If(flag, value_if_true, value_if_false, name = name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = max(op.cond.rank(max(op.thenVal,op.elseVal)))
|
||||
op.rank = max(op.cond.rank,max(op.thenVal.rank,op.elseVal.rank))
|
||||
return op
|
||||
|
||||
################################################################################
|
||||
|
@ -803,7 +803,7 @@ def cond(flag, value_if_true, value_if_false, name=None):
|
|||
|
||||
def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
|
||||
"""
|
||||
This function op =s the future value wrt `x`. It is most often used when
|
||||
This function returns the future value wrt `x`. It is most often used when
|
||||
creating RNNs. The resulting tensor has the same shape as the input but is
|
||||
the next logical sample. The `time_step` parameter is the number of steps
|
||||
to look into the future and is 1 by default. If there is no future value (i.e.
|
||||
|
@ -826,7 +826,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
|
|||
time_step: the number of time steps to look into the future (default 1)
|
||||
default_hidden_activation: the default value to use when no future value
|
||||
is available (default 0.1)
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -838,7 +838,7 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
|
|||
|
||||
def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
|
||||
"""
|
||||
This function op =s the past value wrt `x`. It is most often used when
|
||||
This function returns the past value wrt `x`. It is most often used when
|
||||
creating RNNs. The resulting tensor has the same shape as the input but is
|
||||
the previous logical sample. The `time_step` parameter is the number of steps
|
||||
to look into the past and is 1 by default. If there is no past value (i.e.
|
||||
|
@ -861,13 +861,13 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
|
|||
time_step: the number of time steps to look into the past (default 1)
|
||||
default_hidden_activation: the default value to use when no past value
|
||||
is available (default 0.1)
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
from cntk.ops.cntk1 import PastValue
|
||||
op = PastValue(shape, x, time_step, default_hidden_activation, name = name)
|
||||
wrap_numpy_arrays(op)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = 0 if np.isscalar(shape) else len(shape)
|
||||
return op
|
||||
|
||||
|
@ -894,7 +894,7 @@ def reshape(x, shape, name=None):
|
|||
x: tensor to be reshaped
|
||||
shape: a tuple defining the resulting shape
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk1 import NewReshape
|
||||
|
@ -925,7 +925,7 @@ def transpose_dimensions(x, axis1, axis2, name=None):
|
|||
x: tensor to be reshaped
|
||||
axis1: the axis to swap with axis2
|
||||
axis2: the axis to swap with axis1
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import TransposeDimensions
|
||||
|
@ -974,7 +974,7 @@ def slice(x, begin_index, end_index, axis=0, name=None):
|
|||
See also:
|
||||
Indexing in NumPy: http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
|
||||
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
from cntk.ops.cntk2 import Slice
|
||||
|
@ -1006,7 +1006,7 @@ def dropout(x, name=None):
|
|||
|
||||
Args:
|
||||
x: source tensor
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
from cntk.ops.cntk2 import Dropout
|
||||
|
@ -1033,11 +1033,11 @@ def input_numpy(value, alias=None, dynamic_axis='', name=None):
|
|||
alias (str): alias to be used in the data file
|
||||
dynamic_axis (str): whether the tensor has already the data
|
||||
alias (str): optional the alias to be used when serializing the data into an intermediate file
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
from .. import utils
|
||||
if utils.is_tensor(value):
|
||||
if utils.is_tensor(value) or utils.is_tensor_list(value):
|
||||
value = np.asarray(value)
|
||||
if dynamic_axis:
|
||||
cntk_shape = value[0].shape[1:]
|
||||
|
@ -1070,7 +1070,7 @@ def input(shape, dynamic_axis='', name=None):
|
|||
shape (tuple): the shape of the input tensor
|
||||
dynamic_axis (str or output of :func:`cntk.ops.dynamic_axis`): the dynamic axis
|
||||
name (str): the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -1120,7 +1120,7 @@ def sparse_input_numpy(indices, values, shape, alias=None, dynamic_axis='', name
|
|||
alias (str): alias to be used in the data file
|
||||
dynamic_axis (str): whether the tensor has already the data
|
||||
alias (str): optional the alias to be used when serializing the data into an intermediate file
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
|
||||
|
@ -1145,7 +1145,7 @@ def sparse_input(shape, dynamic_axis='', name=None):
|
|||
shape (tuple): the shape of the input tensor
|
||||
dynamic_axis (str or output of :func:`cntk.ops.dynamic_axis`): the dynamic axis
|
||||
name (str): the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -1169,7 +1169,7 @@ def parameter(shape=None, value=None, learning_rate_multiplier=1.0,
|
|||
init_from_file_path (str): the file that contains the initial tensor value. Used only if ``value=None``.
|
||||
name (str, optional): the name of the node in the network
|
||||
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -1243,7 +1243,7 @@ def constant(value, name=None):
|
|||
Args:
|
||||
value: the tensor constant passed as numpy array
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -1261,7 +1261,7 @@ def dynamic_axis(name=None):
|
|||
|
||||
Args:
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
@ -1284,7 +1284,7 @@ def reconcile_dynamic_axis(data_input, layout_input, name=None):
|
|||
data_input: the tensor to have its dynamic axis layout adapted
|
||||
layout_input: the tensor layout to use for adapting `data_input`s layout
|
||||
name: the name of the node in the network
|
||||
op =s:
|
||||
returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
"""
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ class If(ComputationNode):
|
|||
self.thenVal = thenVal
|
||||
self.elseVal = elseVal
|
||||
self.params_with_defaults = []
|
||||
self.inputs = ['cond', 'thenVal', 'elseVal']
|
||||
|
||||
class Sign(ComputationNode):
|
||||
def __init__(self, x, op_name='Sign', name=None):
|
||||
|
|
|
@ -234,16 +234,16 @@ def test_op_identity(tensor, device_id, precision):
|
|||
|
||||
|
||||
TIMES_PAIRS = [
|
||||
([[30.]], [[10.]]),
|
||||
([[1.5, 2.1]], [[10.], [20.]]),
|
||||
#([[30.]], [[10.]]),
|
||||
#([[1.5, 2.1]], [[10.], [20.]]),
|
||||
([[100., 200.], [300., 400.]], [[10.], [20.]]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("left_operand, right_operand", TIMES_PAIRS)
|
||||
def test_op_times(left_operand, right_operand, device_id, precision,
|
||||
left_matrix_type, right_matrix_type):
|
||||
if left_matrix_type == 'sparse':
|
||||
pytest.skip('first operator of times() has to be dense')
|
||||
if right_matrix_type == 'sparse':
|
||||
pytest.skip('second operator of times() has to be dense')
|
||||
|
||||
dt = PRECISION_TO_TYPE[precision]
|
||||
# Forward pass test
|
||||
|
@ -253,13 +253,13 @@ def test_op_times(left_operand, right_operand, device_id, precision,
|
|||
# the first for sequences (length=1, since we have dynamic_axis='')
|
||||
# the second for batch of one sample
|
||||
expected = [[np.dot(AA(left_operand, dtype=dt), AA(right_operand, dtype=dt))]]
|
||||
|
||||
a = I([left_operand])
|
||||
|
||||
if right_matrix_type == 'sparse':
|
||||
b = SI(*batch_dense_to_sparse([right_operand]))
|
||||
|
||||
if left_matrix_type == 'sparse':
|
||||
a = SI(*batch_dense_to_sparse([left_operand]))
|
||||
else:
|
||||
b = I([right_operand])
|
||||
a = I([left_operand])
|
||||
|
||||
b = I([right_operand])
|
||||
|
||||
from cntk.ops import times, constant
|
||||
left_as_input = times(a, constant(right_operand))
|
||||
|
|
|
@ -156,9 +156,6 @@ def tensors_to_text_format(sample_idx, alias_tensor_map):
|
|||
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
|
||||
|
||||
def is_tensor(data):
|
||||
'''
|
||||
Checks whether the data is a tensor, i.e. whether it is a NumPy array or a
|
||||
|
@ -198,6 +195,14 @@ def is_tensor(data):
|
|||
|
||||
return True
|
||||
|
||||
def is_tensor_list(data):
|
||||
'''
|
||||
Checks whether the data is a CNTK sequence, which is expressed in Python as
|
||||
a list of varying sized NumPy objects.
|
||||
'''
|
||||
is_list = isinstance(data, list)
|
||||
return is_list and len(data) > 0 and isinstance(data[0], np.ndarray)
|
||||
|
||||
def get_temp_filename(directory=None):
|
||||
'''
|
||||
Create and return a temporary filename.
|
||||
|
@ -221,6 +226,14 @@ def get_temp_filename(directory=None):
|
|||
return tf.name
|
||||
|
||||
def wrap_numpy_arrays(node):
|
||||
'''
|
||||
for a given computation node, wrapes its tensor inputs that are numpy arrays
|
||||
into input and constant nodes
|
||||
|
||||
Args:
|
||||
node (:class:`cntk.graph.ComputationNode`): the computation node that will
|
||||
get its inputs wraped
|
||||
'''
|
||||
from ..graph import ComputationNode, _InputComputationNodeBase
|
||||
from ..ops import input_numpy, constant
|
||||
|
||||
|
@ -231,7 +244,7 @@ def wrap_numpy_arrays(node):
|
|||
for p in node.params:
|
||||
if p in node.inputs:
|
||||
val = getattr(node, p)
|
||||
if not isinstance(val, ComputationNode):
|
||||
if not (isinstance(val, ComputationNode) or isinstance(val, str)):
|
||||
# One param needs to be an Input() node. This will be fixed in
|
||||
# CNTK soon, so that we can remove this workaround and evaluate a
|
||||
# network with no inputs.
|
||||
|
|
|
@ -59,3 +59,14 @@ def test_tensor_conversion_dense(idx, alias_tensor_map, expected):
|
|||
def test_is_tensor(data, expected):
|
||||
assert is_tensor(data) == expected
|
||||
|
||||
@pytest.mark.parametrize("data, expected", [
|
||||
([], False),
|
||||
([1], False),
|
||||
([[1, 2]], False),
|
||||
([[]], False),
|
||||
([[AA([1, 2])]], False),
|
||||
([AA([1, 2])], True),
|
||||
([AA([1, 2]), AA([])], True),
|
||||
])
|
||||
def test_is_tensor_list(data, expected):
|
||||
assert is_tensor_list(data) == expected
|
||||
|
|
Загрузка…
Ссылка в новой задаче