add splice node to the python api
This commit is contained in:
Родитель
0110d4646b
Коммит
319f3c70ce
|
@ -981,15 +981,52 @@ def slice(x, begin_index, end_index, axis=0, name=None):
|
|||
else:
|
||||
cntk_axis = abs(axis) if axis<0 else x.rank - axis
|
||||
op = Slice(x, begin_index, end_index, cntk_axis, name=name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = op._.rank
|
||||
return op
|
||||
|
||||
def splice(inputs, begin_index, end_index, axis=0, name=None):
|
||||
'''
|
||||
Concatenate the input tensors along an axis.
|
||||
|
||||
Examples:
|
||||
>>> # create 2x2 matrix in a sequence of length 1 in a batch of one sample
|
||||
>>> data1 = np.asarray([[[1, 2],
|
||||
... [4, 5]]])
|
||||
>>> x = C.input_numpy(data1)
|
||||
>>> # create 3x2 matrix in a sequence of length 1 in a batch of one sample
|
||||
>>> data2 = np.asarray([[[10, 20],
|
||||
... [30, 40],
|
||||
... [50, 60]]])
|
||||
>>> y = C.input_numpy(data2)
|
||||
>>> # splice both inputs on axis=0 returns a 5x2 matrix
|
||||
>>> C.eval(C.splice([x,y], 0))
|
||||
[array([[[1, 2],
|
||||
[4, 5],
|
||||
[10, 20],
|
||||
[30, 40],
|
||||
[50, 60]]])]
|
||||
|
||||
Args:
|
||||
inputs (list): list of input tensors
|
||||
axis (int): axis along which the concatenation will be performed
|
||||
|
||||
Returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
from cntk.ops.cntk2 import Splice
|
||||
#cntk uses column major, thus it will read the indices of data passed from
|
||||
# python in reverse
|
||||
cntk_axis = abs(axis) if axis<0 else inputs[0].rank - axis
|
||||
op = Splice(inputs, cntk_axis, name=name)
|
||||
wrap_numpy_arrays(op)
|
||||
op.rank = op._[0].rank
|
||||
return op
|
||||
|
||||
################################################################################
|
||||
# training ops
|
||||
################################################################################
|
||||
|
||||
# unittests might require training and testing at the same time ? which
|
||||
# sounds more like end2end test ?
|
||||
|
||||
def dropout(x, name=None):
|
||||
"""
|
||||
Compute a new tensor with `dropoutRate` perecent set to zero. The values
|
||||
|
@ -998,10 +1035,6 @@ def dropout(x, name=None):
|
|||
|
||||
The output tensor has the same shape as `x`, but with `dropoutRate` of the
|
||||
elements set to zero (droped out).
|
||||
|
||||
|
||||
Examples:
|
||||
TBA
|
||||
|
||||
Args:
|
||||
x: source tensor
|
||||
|
|
|
@ -19,6 +19,15 @@ class Slice(ComputationNode):
|
|||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Splice(ComputationNode):
|
||||
def __init__(self, _, axis=1, op_name='CNTK2.Splice',
|
||||
name=None):
|
||||
super(Splice, self).__init__(params=['_', 'axis'], op_name=op_name, name=name)
|
||||
self._ = _
|
||||
self.axis = axis
|
||||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Ceil(ComputationNode):
|
||||
def __init__(self, _, op_name='CNTK2.Ceil', name=None):
|
||||
super(Ceil, self).__init__(params=['_'], op_name=op_name, name=name)
|
||||
|
|
|
@ -261,6 +261,7 @@ def wrap_numpy_arrays(node):
|
|||
for p in node.params:
|
||||
if p in node.inputs:
|
||||
val = getattr(node, p)
|
||||
#TODO: add support to list of numpy arrays, e.g. Splice()
|
||||
if not (isinstance(val, ComputationNode) or isinstance(val, str)):
|
||||
# One param needs to be an Input() node. This will be fixed in
|
||||
# CNTK soon, so that we can remove this workaround and evaluate a
|
||||
|
|
|
@ -297,6 +297,15 @@ class Slice(ComputationNode):
|
|||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Splice(ComputationNode):
|
||||
def __init__(self, _, axis=1, op_name='CNTK2.Splice',
|
||||
name=None):
|
||||
super(Splice, self).__init__(params=['_', 'axis'], op_name=op_name, name=name)
|
||||
self._ = _
|
||||
self.axis = axis
|
||||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Ceil(ComputationNode):
|
||||
def __init__(self, _, op_name='CNTK2.Ceil', name=None):
|
||||
super(Ceil, self).__init__(params=['_'], op_name=op_name, name=name)
|
||||
|
|
Загрузка…
Ссылка в новой задаче