jeanfad 2016-06-10 14:44:16 +02:00
Parent 2c4c17c56f
Commit 15aa19019e
3 changed files: 70 additions and 68 deletions

View file

@@ -26,15 +26,38 @@ IntDiv(x, y) = new NumericFunction [ what = 'IntDiv' ; args = (x:y) ]
##############################################################################
# comparison functions
# aliases
##############################################################################
Less = CNTK2.Less
Equal = CNTK2.Equal
Greater = CNTK2.Greater
GreaterEqual = CNTK2.GreaterEqual
NotEqual = CNTK2.NotEqual
LessEqual = CNTK2.LessEqual
Less = CNTK2.Less
Equal = CNTK2.Equal
Greater = CNTK2.Greater
GreaterEqual = CNTK2.GreaterEqual
NotEqual = CNTK2.NotEqual
LessEqual = CNTK2.LessEqual
Splice = CNTK2.Splice
NewReshape = CNTK2.Reshape
Slice = CNTK2.Slice
TransposeDimensions = CNTK2.TransposeDimensions
Times = CNTK2.Times
Abs = CNTK2.Abs
Ceil = CNTK2.Ceil
CrossEntropyWithSoftmax = CNTK2.CrossEntropyWithSoftmax
Dropout = CNTK2.Dropout
ElementTimes = CNTK2.ElementTimes
ElementDivide = CNTK2.ElementDivide
ErrorPrediction = CNTK2.ErrorPrediction
Exp = CNTK2.Exp
Floor = CNTK2.Floor
Log = CNTK2.Log
Minus = CNTK2.Minus
Pass = CNTK2.Identity
Plus = CNTK2.Plus
RectifiedLinear = CNTK2.Relu
ReduceSum = CNTK2.ReduceSum
ReduceLogSum = CNTK2.ReduceLogSum
Round = CNTK2.Round
Sigmoid = CNTK2.Sigmoid
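These aliases are plain name bindings: each legacy top-level name now forwards to its CNTK2 counterpart, so both spellings denote the same node factory (e.g. Pass = CNTK2.Identity above). A minimal Python sketch of the same forwarding pattern (illustrative names, not CNTK code):

import_sketch = """
class CNTK2:
    @staticmethod
    def Identity(x):
        return x

Pass = CNTK2.Identity                  # legacy alias bound to the namespaced function
assert Pass(42) == CNTK2.Identity(42)  # both names resolve to the same callable
"""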
##############################################################################
# ComputationNodes
@@ -120,12 +143,11 @@ CNTK2 = [
    Square(_, tag='') = ElementTimes(_, _, tag=tag)
    Tanh(_, tag='') = new ComputationNode [ operation = 'Tanh' ; inputs = _ /*plus the function args*/ ]
    // 6. Reductions
    ReduceSum (_, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = _ ; reductionOp = "Sum" /*plus the function args*/ ]
    // 6. Reductions
    # the following is a temporary workaround until we have the C++ version
    ReduceLogSum (_, axis=0, tag='') = if axis != 0 then Fail("ReduceLogSum for now only supports axis=0.")
                                       else [ tag1=tag ; axis1=axis ; out = RowSlice (0, 1, _ - LogSoftmax (_), tag=tag1) ].out
    ReduceSum (_, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = _ ; reductionOp = "Sum" /*plus the function args*/ ]
    // 7. Control flow (if, composite etc.)
    // None so far
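A quick NumPy check of the identity behind the ReduceLogSum workaround above (a sketch, not CNTK code): since log_softmax(z) = z - logsumexp(z), every entry of z - LogSoftmax(z) equals logsumexp(z), so RowSlice(0, 1, ...) recovers the reduce-log-sum scalar:

import numpy as np

def logsumexp(z):
    m = z.max()                        # shift by the max for numerical stability
    return m + np.log(np.exp(z - m).sum())

z = np.array([0.5, -1.2, 3.0])
log_softmax = z - logsumexp(z)
workaround = (z - log_softmax)[0:1]    # mirrors RowSlice(0, 1, z - LogSoftmax(z))
assert np.allclose(workaround, logsumexp(z))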
@@ -150,7 +172,7 @@ CNTK2 = [
    CrossEntropyWithSoftmax(_, outProbVectorSequence, tag='') = new ComputationNode [ operation = 'CrossEntropyWithSoftmax' ; inputs = (_ : outProbVectorSequence) /*plus the function args*/ ]
    ErrorPrediction(_, outVectorSequence, topN=1, tag='') = new ComputationNode [ operation = 'ErrorPrediction' ; inputs = if topN == 1 then (_ : outVectorSequence) else (_ : outVectorSequence : Constant (topN)) /*plus the function args*/ ]
    // 13. Comparison nodes
    // 12. Comparison nodes
    Less(_, y, tag='') = new ComputationNode [ operation = 'Less' ; inputs = (_ : y) /*plus the function args*/ ]
    Equal(_, y, tag='') = new ComputationNode [ operation = 'Equal' ; inputs = (_ : y) /*plus the function args*/ ]
    Greater(_, y, tag='') = new ComputationNode [ operation = 'Greater' ; inputs = (_ : y) /*plus the function args*/ ]
@@ -158,8 +180,7 @@ CNTK2 = [
    NotEqual(_, y, tag='') = new ComputationNode [ operation = 'NotEqual' ; inputs = (_ : y) /*plus the function args*/ ]
    LessEqual(_, y, tag='') = new ComputationNode [ operation = 'LessEqual' ; inputs = (_ : y) /*plus the function args*/ ]
    // 13. Others
    // 12. Others
    // 13. Others
    Identity(_, tag='') = new ComputationNode [ operation = 'Pass' ; inputs = _ /*plus the function args*/ ]
]
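The comparison nodes above compute elementwise predicates; assuming they yield 0/1-valued tensors (as their NumPy counterparts do when cast), the semantics are:

import numpy as np

x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 2.0, 2.0])
print((x < y).astype(float))    # Less         -> [1. 0. 0.]
print((x == y).astype(float))   # Equal        -> [0. 1. 0.]
print((x >= y).astype(float))   # GreaterEqual -> [0. 1. 1.]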
@@ -184,36 +205,12 @@ Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag=
RowSlice(beginIndex, numRows, input, tag='') = Slice(beginIndex, beginIndex + numRows, input, axis = 1)
RowRepeat(input, numRepeats, tag='') = new ComputationNode [ operation = 'RowRepeat' ; inputs = input /*plus the function args*/ ]
RowStack(inputs, tag='') = new ComputationNode [ operation = 'RowStack' /*plus the function args*/ ]
Splice (inputs, axis=1, tag='') = # TODO: This is a workaround. RowStack itself shall interpret 'axis' and be renamed to Splice().
    if axis < 1 then Fail('Splice does not yet implement splicing the time axis.')
    else if axis == 1 then [tag1=tag; out = RowStack (inputs, tag=tag1)].out
    else [ # workaround: swap 'axis' to first position, RowStack, swap back
        ArrayTransposeDimensions (inputs, axis1, axis2) = [ # transpose each element of a BS array
            inputsT[i:0..Length(inputs)-1] = TransposeDimensions (inputs[i], axis1, axis2)
        ].inputsT
        out = [tag1=tag; out=TransposeDimensions (RowStack (ArrayTransposeDimensions (inputs, 1, axis)), 1, axis, tag=tag)].out
    ].out
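The same swap-stack-swap trick in NumPy terms (a sketch; note BrainScript axes are 1-based while NumPy's are 0-based):

import numpy as np

def splice_via_first_axis(arrays, axis):
    # swap 'axis' to the front, concatenate along the first axis, swap back
    moved = [np.moveaxis(a, axis, 0) for a in arrays]
    return np.moveaxis(np.concatenate(moved, axis=0), 0, axis)

a, b = np.ones((2, 3)), np.zeros((2, 5))
assert np.array_equal(splice_via_first_axis([a, b], 1),
                      np.concatenate([a, b], axis=1))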
Reshape(input, numRows, imageWidth = 0, imageHeight = 0, imageChannels = 0, tag='') = new ComputationNode [ operation = 'LegacyReshape' ; inputs = input /*plus the function args*/ ]
NewReshape(input, dims, beginAxis=0, endAxis=0, tag='') = new ComputationNode [ operation = 'Reshape' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]
ReshapeDimension(x, axis, tensorShape) = NewReshape(x, tensorShape, beginAxis=axis, endAxis=axis + 1)
FlattenDimensions(x, axis, num) = NewReshape(x, 0, beginAxis=axis, endAxis=axis + num)
Slice(beginIndex, endIndex, input, axis=1, tag='') =
    if axis < 0 then [ # time axis: specify -1
        beginFlags = if beginIndex > 0 then BS.Boolean.Not (BS.Loop.IsFirstN (beginIndex, input)) else BS.Loop.IsLastN (-beginIndex, input)
        endFlags   = if endIndex   > 0 then BS.Loop.IsFirstN (endIndex, input) else BS.Boolean.Not (BS.Loop.IsLastN (-endIndex, input))
        flags = if      beginIndex == 0 then endFlags
                else if endIndex   == 0 then beginFlags
                else                         BS.Boolean.And (beginFlags, endFlags)
        out = if beginIndex == 0 && endIndex == 0
              then input
              else BS.Sequences.Gather (flags, input)
    ].out
    else new ComputationNode [ operation = 'Slice' ; inputs = input /*plus the function args*/ ] # non-time axis
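A NumPy sketch of the time-axis branch (axis < 0): build begin/end masks over the time steps and gather the frames where both hold, which is what the IsFirstN/IsLastN flags and Gather above implement:

import numpy as np

seq = np.arange(5.0)                           # a sequence of 5 time steps
begin_index, end_index = 1, 4
t = np.arange(len(seq))
flags = (t >= begin_index) & (t < end_index)   # beginFlags AND endFlags
print(seq[flags])                              # [1. 2. 3.] -- frames 1..3 gathered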
SplitDimension(x, axis, N) = ReshapeDimension(x, axis, 0:N)
TransposeDimensions(input, axis1, axis2, tag='') = new ComputationNode [ operation = 'TransposeDimensions' ; inputs = input /*plus the function args*/ ]
# TODO: make input the last arg!
Transpose(x) = TransposeDimensions(x, 1, 2)
Times(A, B, outputRank=1, tag='') = new ComputationNode [ operation = 'Times' ; inputs = ( A : B ) /*plus the function args*/ ]
Logistic(label, probability, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability) /*plus the function args*/ ]
WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability : instanceWeight) /*plus the function args*/ ]
ReconcileDynamicAxis(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileDynamicAxis' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]
@@ -231,8 +228,6 @@ ClassificationError = ErrorPrediction
Delay = PastValue
BatchNormalization(input, scale, bias, runMean, runInvStdDev, spatial, normalizationTimeConstant = 0, blendTimeConstant = 0, epsilon = 0.00001, useCntkEngine = true, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]
Abs(x, tag='') = new ComputationNode [ operation = 'Abs' ; inputs = x /*plus the function args*/ ]
Ceil(x, tag='') = Negate(Floor(Negate(x)), tag=tag)
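Ceil is reduced to Floor through the identity ceil(x) = -floor(-x), which a quick check confirms:

import math

for x in (2.0, 2.3, -2.3, -2.0):
    assert -math.floor(-x) == math.ceil(x)   # ceil(x) = -floor(-x)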
ClassBasedCrossEntropyWithSoftmax(labelClassDescriptorVectorSequence, mainInputInfo, mainWeight, classLogProbsBeforeSoftmax, tag='') = new ComputationNode [ operation = 'ClassBasedCrossEntropyWithSoftmax' ; inputs = (labelClassDescriptorVectorSequence : mainInputInfo : mainWeight : classLogProbsBeforeSoftmax) /*plus the function args*/ ]
Clip(minValue, maxValue, x, tag='') = new ComputationNode [ operation = 'Clip' ; inputs = (minValue : maxValue : x) /* plus the function args*/ ]
ColumnElementTimes(aVectorSequence, anotherVectorSequence, tag='') = new ComputationNode [ operation = 'ColumnElementTimes' ; inputs = (aVectorSequence : anotherVectorSequence) /*plus the function args*/ ]
@@ -241,50 +236,33 @@ CosDistance(aVectorSequence, anotherVectorSequence, tag='') = new ComputationNod
CosDistanceWithNegativeSamples(aVectorSequence, anotherVectorSequence, numShifts, numNegSamples, tag='') = new ComputationNode [ operation = 'CosDistanceWithNegativeSamples' ; inputs = (aVectorSequence : anotherVectorSequence : numShifts : numNegSamples) /*plus the function args*/ ]
Cosine(x, tag='') = new ComputationNode [ operation = 'Cosine' ; inputs = x /*plus the function args*/ ]
CrossEntropy(refProbVectorSequence, outProbVectorSequence, tag='') = new ComputationNode [ operation = 'CrossEntropy' ; inputs = (refProbVectorSequence : outProbVectorSequence) /*plus the function args*/ ]
CrossEntropyWithSoftmax(labelVectorSequence, outProbVectorSequence, tag='') = new ComputationNode [ operation = 'CrossEntropyWithSoftmax' ; inputs = (labelVectorSequence : outProbVectorSequence) /*plus the function args*/ ]
# once ReduceLogSum becomes proper C++, CrossEntropyWithSoftmax() will become this:
NewCrossEntropyWithSoftmax (labelSequence, z, tag='') = [ tag1 = tag; out = Minus (ReduceLogSum (z), ReduceSum (labelSequence .* z), tag=tag1) ].out
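For a one-hot labelSequence this is the standard cross entropy with softmax, since -log softmax(z)[k] = logsumexp(z) - z[k] and ReduceSum(label .* z) picks out z[k]. A NumPy check of the identity (a sketch, not CNTK code):

import numpy as np

def logsumexp(z):
    m = z.max()
    return m + np.log(np.exp(z - m).sum())

z = np.array([0.5, -1.2, 3.0])
label = np.array([0.0, 0.0, 1.0])              # one-hot, true class k = 2
lhs = logsumexp(z) - (label * z).sum()         # Minus(ReduceLogSum(z), ReduceSum(label .* z))
rhs = -np.log(np.exp(z[2]) / np.exp(z).sum())  # -log softmax(z)[k]
assert np.allclose(lhs, rhs)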
DiagTimes(diagonalMatrixAsColumnVector, matrix, tag='') = new ComputationNode [ operation = 'DiagTimes' ; inputs = (diagonalMatrixAsColumnVector : matrix) /*plus the function args*/ ]
// TODO: DiagTimes = ElementTimes
Dropout(activationVectorSequence, tag='') = new ComputationNode [ operation = 'Dropout' ; inputs = activationVectorSequence /*plus the function args*/ ]
ElementTimes(aMatrix, anotherMatrix, tag='') = new ComputationNode [ operation = 'ElementTimes' ; inputs = (aMatrix : anotherMatrix) /*plus the function args*/ ]
ElementDivide(aMatrix, anotherMatrix, tag='') = ElementTimes(aMatrix, Reciprocal(anotherMatrix), tag=tag)
ErrorPrediction = CNTK2.ErrorPrediction
Exp(x, tag='') = new ComputationNode [ operation = 'Exp' ; inputs = x /*plus the function args*/ ]
Floor(x, tag='') = new ComputationNode [ operation = 'Floor' ; inputs = x /*plus the function args*/ ]
GatherPacked(indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'GatherPacked' ; inputs = (indexSequence : sourceData) /*plus the function args*/ ]
GMMLogLikelihood(unnormalizedPriorVector, meansAsRows, logStdDevAsRows, dataVectorSequence, tag='') = new ComputationNode [ operation = 'GMMLogLikelihood' ; inputs = (unnormalizedPriorVector : meansAsRows : logStdDevAsRows : dataVectorSequence) /*plus the function args*/ ]
InvStdDev(dataVectorSequence, tag='') = new ComputationNode [ operation = 'InvStdDev' ; inputs = dataVectorSequence /*plus the function args*/ ]
KhatriRaoProduct(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'KhatriRaoProduct' ; inputs = (leftMatrix : rightMatrix) /*plus the function args*/ ]
Log(x, tag='') = new ComputationNode [ operation = 'Log' ; inputs = x /*plus the function args*/ ]
LogPlus(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'LogPlus' ; inputs = (leftMatrix : rightMatrix) /*plus the function args*/ ]
LogSoftmax(z, tag='') = new ComputationNode [ operation = 'LogSoftmax' ; inputs = z /*plus the function args*/ ]
# TODO: ^^ along axis, like Softmax
MatrixL1Reg(matrix, tag='') = new ComputationNode [ operation = 'MatrixL1Reg' ; inputs = matrix /*plus the function args*/ ]
MatrixL2Reg(matrix, tag='') = new ComputationNode [ operation = 'MatrixL2Reg' ; inputs = matrix /*plus the function args*/ ]
Mean(dataVectorSequence, tag='') = new ComputationNode [ operation = 'Mean' ; inputs = dataVectorSequence /*plus the function args*/ ]
Minus(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'Minus' ; inputs = (leftMatrix : rightMatrix) /*plus the function args*/ ]
Negate(input, tag='') = new ComputationNode [ operation = 'Negate' ; inputs = input /*plus the function args*/ ]
PackedIndex(targetObject, indexSequence, tag='') = new ComputationNode [ operation = 'PackedIndex' ; inputs = (targetObject : indexSequence) /*plus the function args*/ ]
Pass(x, tag='') = new ComputationNode [ operation = 'Pass' ; inputs = x /*plus the function args*/ ]
PerDimMeanVarDeNormalization(dataVectorSequence, meanVector, invStdDevVector, tag='') = new ComputationNode [ operation = 'PerDimMeanVarDeNormalization' ; inputs = (dataVectorSequence : meanVector : invStdDevVector) /*plus the function args*/ ]
PerDimMeanVarNormalization(dataVectorSequence, meanVector, invStdDevVector, tag='') = new ComputationNode [ operation = 'PerDimMeanVarNormalization' ; inputs = (dataVectorSequence : meanVector : invStdDevVector) /*plus the function args*/ ]
Plus(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'Plus' ; inputs = (leftMatrix : rightMatrix) /*plus the function args*/ ]
Reciprocal(z, tag='') = new ComputationNode [ operation = 'Reciprocal' ; inputs = z /*plus the function args*/ ]
RectifiedLinear(z, tag='') = new ComputationNode [ operation = 'RectifiedLinear' ; inputs = z /*plus the function args*/ ]
ReduceSum (z, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = z ; reductionOp = "Sum" /*plus the function args*/ ]
# the following is a temporary workaround until we have the C++ version
ReduceLogSum (z, axis=0, tag='') = if axis != 0 then Fail("ReduceLogSum for now only supports axis=0.")
                                   else [ tag1=tag ; axis1=axis ; out = RowSlice (0, 1, z - LogSoftmax (z), tag=tag1) ].out
//# the following is a temporary workaround until we have the C++ version
#ReduceLogSum (z, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = z ; reductionOp = "LogSum" /*plus the function args*/ ]
#ReduceMean (z, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = z ; reductionOp = "Mean" /*plus the function args*/ ]
#ReduceMax (z, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = z ; reductionOp = "Max" /*plus the function args*/ ]
#ReduceMin (z, axis=0, tag='') = new ComputationNode [ operation = 'ReduceElements' ; inputs = z ; reductionOp = "Min" /*plus the function args*/ ]
Round(x, tag='') = Floor(Plus(x, ConstantTensor(0.5, (1))), tag=tag)
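Round is floor(x + 0.5), i.e. round-half-up. Note this is a different convention from NumPy's np.round, which rounds halves to even:

import math

def round_half_up(x):
    return math.floor(x + 0.5)    # mirrors Round above

assert round_half_up(2.5) == 3    # half rounds up
assert round_half_up(-2.5) == -2  # floor(-2.0) == -2
# np.round(2.5) == 2.0 -- banker's rounding, not what Round implements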
Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = 'Scale' ; inputs = (scalarScalingFactor : matrix) /*plus the function args*/ ]
# TODO: Scale = ElementTimes
ScatterPacked(cond, indexSequence, sourceData, tag='') = new ComputationNode [ operation = 'ScatterPacked' ; inputs = (cond : indexSequence : sourceData) /*plus the function args*/ ]
Sigmoid(z, tag='') = new ComputationNode [ operation = 'Sigmoid' ; inputs = z /*plus the function args*/ ]
Sin(z, tag='') = new ComputationNode [ operation = 'Sin' ; inputs = z /*plus the function args*/ ]
Softmax (z, axis=0, tag='') = # TODO: replace this with more efficient version below once we have ReduceLogSum
    if axis == 0 then new ComputationNode [ operation = 'Softmax' ; inputs = z /*plus the function args*/ ]
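For reference, the usual numerically stable softmax that the axis=0 branch computes (a NumPy sketch):

import numpy as np

def softmax(z):
    z = z - z.max()                # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

assert np.allclose(softmax(np.array([1.0, 2.0, 3.0])).sum(), 1.0)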

View file

@@ -416,6 +416,7 @@ def identity(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -547,8 +548,8 @@ def clip(x, min_value, max_value, name=None):
Args:
x: tensor to be clipped
min_value: the minimum value to clip element values to
max_value: the maximum value to clip element values to
min_value: a scalar or a tensor which represents the minimum value to clip element values to
max_value: a scalar or a tensor which represents the maximum value to clip element values to
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
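The NumPy equivalent for scalar bounds (per the docstring, CNTK additionally accepts tensor bounds, which clip elementwise):

import numpy as np

x = np.array([-2.0, 0.5, 7.0])
print(np.clip(x, 0.0, 1.0))   # [0.  0.5 1. ] -- values forced into [min_value, max_value]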
@@ -572,6 +573,7 @@ def relu(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -595,6 +597,7 @@ def sigmoid(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -617,6 +620,7 @@ def tanh(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -644,6 +648,7 @@ def softmax(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -665,6 +670,7 @@ def exp(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -684,6 +690,7 @@ def log(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
@@ -711,6 +718,7 @@ def sqrt(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
@@ -734,6 +742,7 @@ def square(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -755,6 +764,7 @@ def abs(x, name=None):
Args:
x: numpy array or any :class:`cntk.graph.ComputationNode` that outputs a tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -815,8 +825,9 @@ def future_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None
Args:
shape (tuple): dimensions of the input `x`; the shape will be inferred if zero is passed.
x: the tensor (or its name) from which the future value is obtained.
time_step: the number of time steps to look into the future (default 1)
default_hidden_activation: the default value to use when no future value is available (default 0.1)
time_step (int): the number of time steps to look into the future (default 1)
default_hidden_activation (number): the default value to use when no future value is available (default 0.1)
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -849,8 +860,9 @@ def past_value(shape, x, time_step=1, default_hidden_activation=0.1, name=None):
Args:
shape (tuple): dimensions of the input `x`; the shape will be inferred if zero is passed.
x: the tensor (or its name) from which the past value is obtained
time_step: the number of time steps to look into the past (default 1)
default_hidden_activation: the default value to use when no past value is available (default 0.1)
time_step (int): the number of time steps to look into the past (default 1)
default_hidden_activation (number): the default value to use when no past value is available (default 0.1)
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -910,6 +922,7 @@ def transpose_dimensions(x, axis1, axis2, name=None):
x: tensor to be reshaped
axis1 (int): the axis to swap with axis2
axis2 (int): the axis to swap with axis1
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -954,7 +967,8 @@ def slice(x, begin_index, end_index, axis=0, name=None):
begin_index (int): the index along axis where the slicing starts
end_index (int): the index along axis where the slicing ends
axis (int or str): axis along which `begin_index` and `end_index` will be used. If axis is of type `str` then the time axis will be used.
name (str): the name of the node in the network
See also:
Indexing in NumPy: http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
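In NumPy terms, slicing along a static axis corresponds to ordinary interval indexing:

import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]])
# slice(x, 1, 2, axis=0) corresponds to x[1:2, :]
print(x[1:2, :])   # [[4 5 6]]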
@@ -1000,6 +1014,7 @@ def splice(inputs, axis=0, name=None):
Args:
inputs (list): list of input tensors
axis (int): axis along which the concatenation will be performed
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
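The NumPy counterpart is np.concatenate; the inputs must agree on every axis except the splice axis:

import numpy as np

a, b = np.ones((2, 2)), np.zeros((1, 2))
print(np.concatenate([a, b], axis=0).shape)   # (3, 2)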
@@ -1049,6 +1064,7 @@ def reduce_sum(x, axis=0, name=None):
Args:
x: input tensor
axis (int): axis along which the reduction will be performed
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
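NumPy analogue; note that CNTK keeps the reduced axis with size 1 (cf. the keepdims=True reference implementation in the test file at the end of this commit):

import numpy as np

x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
print(np.sum(x, axis=0, keepdims=True))   # [[4. 6.]] -- reduced axis kept as size 1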
@@ -1077,6 +1093,7 @@ def reduce_log_sum(inputs, name=None):
Args:
inputs: input tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
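The reduction computes the log of the sum of exponents; SciPy ships a numerically stable reference for comparison:

import numpy as np
from scipy.special import logsumexp

x = np.array([0.5, -1.2, 3.0])
print(logsumexp(x))   # == np.log(np.exp(x).sum()), computed stably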
@@ -1108,6 +1125,8 @@ def dropout(x, name=None):
Args:
x: source tensor
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -1135,6 +1154,8 @@ def input_numpy(value, alias=None, dynamic_axis='', name=None):
alias (str): alias to be used in the data file
dynamic_axis (str): whether the tensor already has the data
alias (str): optional; the alias to be used when serializing the data into an intermediate file
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
'''
@@ -1172,6 +1193,7 @@ def input(shape, dynamic_axis='', name=None):
shape (tuple): the shape of the input tensor
dynamic_axis (str or output of :func:`cntk.ops.dynamic_axis`): the dynamic axis
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
"""
@@ -1222,6 +1244,8 @@ def sparse_input_numpy(indices, values, shape, alias=None, dynamic_axis='', name
alias (str): alias to be used in the data file
dynamic_axis (str): whether the tensor already has the data
alias (str): optional; the alias to be used when serializing the data into an intermediate file
name (str): the name of the node in the network
Returns:
:class:`cntk.graph.ComputationNode`
'''
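A sketch of the (indices, values, shape) triple the function expects, densified for illustration (the exact sparse layout is an assumption here):

import numpy as np

indices = [0, 2]           # positions of the non-zero entries
values  = [1.0, 3.0]
shape   = (4,)
dense = np.zeros(shape)
dense[indices] = values
print(dense)               # [1. 0. 3. 0.]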

View file

@@ -30,13 +30,13 @@ def test_op_reduce_sum(input_data, axis, device_id, precision):
    # The first for sequences (length=1, since we have dynamic_axis='').
    # The second for batch of one sample.
    # keepdims = True as CNTK keep them as well
    # keepdims = True as CNTK keeps them as well
    def reduce_sum(x, axis, keepdims=True):
        x_aa = AA(x)
        if axis == len(x_aa.shape):
            return [AA(np.reshape(np.add.reduce(np.ravel(x_aa)), (1,1)))]
        return [[AA(np.add.reduce(x, axis, dtype=PRECISION_TO_TYPE[precision],
                                  keepdims=keepdims))]]
            return [np.reshape(np.add.reduce(np.ravel(x_aa)), (1,1))]
        return [[np.add.reduce(x_aa, axis, dtype=PRECISION_TO_TYPE[precision],
                               keepdims=keepdims)]]
    expected_result = reduce_sum(input_data, axis)