Integrate wilrich/slice into master
This commit is contained in:
Коммит
6cef7b091e
|
@ -61,6 +61,18 @@ CNTK2 = [
|
|||
// 3. Shape operations
|
||||
// Changes: NewReshape -> Reshape, input -> _, dims -> shape
|
||||
Reshape(_, shape, beginAxis=0, endAxis=0, tag='') = new ComputationNode [ operation = 'Reshape' ; inputs = _ ; shape = new TensorShape [ /*shape*/ ] /*plus the function args*/ ]
|
||||
Slice(_, beginIndex, endIndex, axis=1, tag='') =
|
||||
if axis < 0 then [ # time axis: specify -1
|
||||
beginFlags = if beginIndex > 0 then BS.Boolean.Not (BS.Loop.IsFirstN (beginIndex, _)) else BS.Loop.IsLastN (-beginIndex, _)
|
||||
endFlags = if endIndex > 0 then BS.Loop.IsFirstN (endIndex, _) else BS.Boolean.Not (BS.Loop.IsLastN (-endIndex, _))
|
||||
flags = if beginIndex == 0 then endFlags
|
||||
else if endIndex == 0 then beginFlags
|
||||
else BS.Boolean.And (beginFlags, endFlags)
|
||||
out = if beginIndex == 0 && endIndex == 0
|
||||
then _
|
||||
else BS.Sequences.Gather (flags, _)
|
||||
].out
|
||||
else new ComputationNode [ operation = 'Slice' ; inputs = _ /*plus the function args*/ ] # non-time axis
|
||||
|
||||
// 4. Tensor operations
|
||||
// Changes: Matrix -> Tensor. A -> x, B -> y. Data must come on y ("default parameter") hence not using _
|
||||
|
@ -249,7 +261,7 @@ SumElements(matrix, tag='') = new ComputationNode [ operation = 'SumElements' ;
|
|||
# ^^ TODO: Rename to ReduceSumMB?
|
||||
Tanh(z, tag='') = new ComputationNode [ operation = 'Tanh' ; inputs = z /*plus the function args*/ ]
|
||||
TimeReverse(vectorSequence, tag='') = new ComputationNode [ operation = 'TimeReverse' ; inputs = vectorSequence /*plus the function args*/ ]
|
||||
Trace (node, say='', logFrequency=traceFrequency, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='') = new ComputationNode [ operation = 'Trace' ; inputs = node ]
|
||||
Trace (node, say='', logFrequency=100, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='') = new ComputationNode [ operation = 'Trace' ; inputs = node ]
|
||||
TransposeTimes(leftMatrix, rightMatrix, tag='') = new ComputationNode [ operation = 'TransposeTimes' ; inputs = (leftMatrix : rightMatrix) /*plus the function args*/ ]
|
||||
Where(cond, tag='') = new ComputationNode [ operation = 'Where' ; inputs = cond /*plus the function args*/ ]
|
||||
|
||||
|
|
|
@ -368,7 +368,7 @@ class LocalExecutionContext(AbstractContext):
|
|||
name (str): context name
|
||||
device_id (int): whether to use CPU (-1) or GPU if `device_id>=0`, in which case it denotes the GPU index
|
||||
precision (str): either float or double
|
||||
clean_up: whether the temporary directory should be removed when the context is left
|
||||
clean_up (bool): whether the temporary directory should be removed when the context is left
|
||||
'''
|
||||
|
||||
def __init__(self, name,
|
||||
|
@ -389,7 +389,6 @@ class LocalExecutionContext(AbstractContext):
|
|||
del _CONTEXT[self.name]
|
||||
if self.clean_up:
|
||||
sh.rmtree(self.directory)
|
||||
|
||||
|
||||
def _call_cntk(self, config_file_name, config_content, action_name):
|
||||
'''
|
||||
|
|
|
@ -49,9 +49,7 @@ def train_eval_logistic_regression_from_file(criterion_name=None,
|
|||
my_sgd = C.SGDParams(
|
||||
epoch_size=0, minibatch_size=25, learning_rates_per_mb=0.1, max_epochs=3)
|
||||
|
||||
with C.LocalExecutionContext('logreg') as ctx:
|
||||
ctx.device_id = device_id
|
||||
|
||||
with C.LocalExecutionContext('logreg', device_id=device_id, clean_up=True) as ctx:
|
||||
ctx.train(
|
||||
root_nodes=[ce, eval],
|
||||
training_params=my_sgd,
|
||||
|
|
|
@ -69,9 +69,7 @@ def train_eval_logistic_regression_with_numpy(criterion_name=None,
|
|||
my_sgd = C.SGDParams(epoch_size=0, minibatch_size=25,
|
||||
learning_rates_per_mb=0.1, max_epochs=3)
|
||||
|
||||
with C.LocalExecutionContext('logreg', clean_up=True) as ctx:
|
||||
ctx.device_id = device_id
|
||||
|
||||
with C.LocalExecutionContext('logreg_numpy', device_id=device_id, clean_up=True) as ctx:
|
||||
ctx.train(
|
||||
root_nodes=[ce,eval],
|
||||
training_params=my_sgd)
|
||||
|
|
|
@ -123,20 +123,61 @@ class ComputationNode(object):
|
|||
def __abs__(self):
|
||||
return ops.abs(self)
|
||||
|
||||
def __getitem__(self, so):
|
||||
if so.stop == None:
|
||||
raise ValueError('The stop index has to be provided')
|
||||
def __getitem__(self, key):
|
||||
if isinstance(key, int):
|
||||
# Case 1: e.g. data[3] -> key=3
|
||||
return ops.slice(self, key, key+1, axis=0)
|
||||
|
||||
if isinstance(so, int):
|
||||
return RowSlice(self, so, 1)
|
||||
elif isinstance(key, slice):
|
||||
# Case 2: e.g. data[2:4] -> key will be a slice object
|
||||
if key.step is not None:
|
||||
raise TypeError('step argument is not supported')
|
||||
if not isinstance(key.stop, int):
|
||||
raise TypeError('end index has to be of type int, not "%s"'%type(key.stop))
|
||||
|
||||
elif isinstance(so, slice):
|
||||
if so.step not in {1, None}:
|
||||
raise ValueError("RowSlice does not support strides")
|
||||
if isinstance(key.start, int):
|
||||
if key.stop<=key.start:
|
||||
raise ValueError('end index has to be greater than start index')
|
||||
return ops.slice(self, key.start or 0, key.stop or 0, axis=0)
|
||||
|
||||
start = so.start or 0
|
||||
elif isinstance(key, (tuple, list)):
|
||||
# Case 3: e.g. data[2:4,1:,1:7] -> key will be an iterable of ints
|
||||
# (case 1) or slices (case 2)
|
||||
# objects.
|
||||
# FIXME: we need to check that len(key) equals the node's rank
|
||||
node = self
|
||||
for ax_counter, so in enumerate(key):
|
||||
if isinstance(so, int):
|
||||
# Proceed as case 1
|
||||
node = ops.slice(node, so, so+1, axis=ax_counter)
|
||||
|
||||
elif isinstance(so, slice):
|
||||
# Proceed as case 2
|
||||
if so.step is not None:
|
||||
raise TypeError('step argument is not supported')
|
||||
if isinstance(so.start, int) and isinstance(so.stop, int):
|
||||
if so.stop<=so.start:
|
||||
raise ValueError('end index has to be greater than start index')
|
||||
if so.start is None and so.stop is None:
|
||||
continue
|
||||
node = ops.slice(node, so.start or 0, so.stop or 0, axis=ax_counter)
|
||||
elif isinstance(so, list):
|
||||
# Case 3b: e.g. data[[0],[2,3]] aka "advanced indexing" ->
|
||||
# so = ([0], [2,3])
|
||||
# In NumPy we would have another dimension, but since
|
||||
# data[0].shape != data[[0]].shape == data[[[0]]].shape ==
|
||||
# we decided to have all shapes like data[0] in this case
|
||||
for idx in so:
|
||||
if not isinstance(idx, int):
|
||||
raise IndexError('indices have to be of type int and not "%s"'%type(idx))
|
||||
node = ops.slice(node, idx, idx+1, axis=ax_counter)
|
||||
else:
|
||||
raise IndexError('type "%s" is not supported as index'%type(so))
|
||||
|
||||
return node
|
||||
else:
|
||||
raise TypeError('index must be int or slice, not {}'.format(type(key).__name__))
|
||||
|
||||
return RowSlice(self, start, so.stop - start)
|
||||
|
||||
# TODO more __operators__
|
||||
|
||||
|
|
|
@ -602,6 +602,49 @@ def reshape(x, shape, name=None):
|
|||
from cntk.ops.cntk1 import NewReshape
|
||||
return NewReshape(x, shape, 0, 0, name = name)
|
||||
|
||||
def slice(x, begin_index, end_index, axis=0, name=None):
|
||||
'''
|
||||
Slice the input along an axis.
|
||||
|
||||
Note:
|
||||
`axis` is zero-based as in Numpy, in contrast to CNTK, where 1 is the first axis.
|
||||
|
||||
Examples:
|
||||
>>> # create 2x3 matrix in a sequence of length 1 in a batch of one sample
|
||||
>>> data = np.asarray([[[1, 2, -3],
|
||||
... [4, 5, 6]]])
|
||||
>>> x = C.input_numpy(data)
|
||||
>>> # slice index 1 (second) at first axis
|
||||
>>> C.eval(C.slice(x, 1, 2, 0))
|
||||
[array([[[ 4., 5., 6.]]])]
|
||||
>>> # slice index 0 (first) at second axis
|
||||
>>> C.eval(C.slice(x, 0, 1, 1))
|
||||
[array([[[ 1.],
|
||||
[ 4.]]])]
|
||||
|
||||
NumPy's way of slicing works, too:
|
||||
|
||||
Examples:
|
||||
>>> C.eval(x[1])
|
||||
[array([[[ 4., 5., 6.]]])]
|
||||
>>> C.eval(x[:,:2,:])
|
||||
[array([[[ 1., 2.],
|
||||
[ 4., 5.]]])]
|
||||
|
||||
Args:
|
||||
arg: input tensor
|
||||
axis (int): axis along which `begin_index` and `end_index` will be used to slice the data.
|
||||
|
||||
See also:
|
||||
Indexing in NumPy: http://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
|
||||
|
||||
Returns:
|
||||
:class:`cntk.graph.ComputationNode`
|
||||
'''
|
||||
from cntk.ops.cntk2 import Slice
|
||||
cntk_axis = axis+1
|
||||
return Slice(x, begin_index, end_index, cntk_axis, name=name)
|
||||
|
||||
################################################################################
|
||||
# training ops
|
||||
################################################################################
|
||||
|
@ -785,3 +828,4 @@ def reconcile_dynamic_axis(data_input, layout_input, name=None):
|
|||
|
||||
from cntk.ops.cntk1 import ReconcileDynamicAxis
|
||||
return ReconcileDynamicAxis(data_input, layout_input, name=name)
|
||||
|
||||
|
|
|
@ -60,7 +60,6 @@ class If(ComputationNode):
|
|||
self.thenVal = thenVal
|
||||
self.elseVal = elseVal
|
||||
self.params_with_defaults = []
|
||||
self.inputs = ['cond', 'thenVal', 'elseVal']
|
||||
|
||||
class Sign(ComputationNode):
|
||||
def __init__(self, x, op_name='Sign', name=None):
|
||||
|
@ -635,9 +634,9 @@ class RectifiedLinear(ComputationNode):
|
|||
self.params_with_defaults = []
|
||||
self.inputs = ['z']
|
||||
|
||||
class ReducePlus(ComputationNode):
|
||||
def __init__(self, z, axis=0, op_name='ReducePlus', name=None):
|
||||
super(ReducePlus, self).__init__(params=['z', 'axis'], op_name=op_name, name=name)
|
||||
class ReduceSum(ComputationNode):
|
||||
def __init__(self, z, axis=0, op_name='ReduceSum', name=None):
|
||||
super(ReduceSum, self).__init__(params=['z', 'axis'], op_name=op_name, name=name)
|
||||
self.z = z
|
||||
self.axis = axis
|
||||
self.params_with_defaults = ['axis']
|
||||
|
@ -674,13 +673,6 @@ class Sin(ComputationNode):
|
|||
self.params_with_defaults = []
|
||||
self.inputs = ['z']
|
||||
|
||||
class Softmax(ComputationNode):
|
||||
def __init__(self, z, op_name='Softmax', name=None):
|
||||
super(Softmax, self).__init__(params=['z'], op_name=op_name, name=name)
|
||||
self.z = z
|
||||
self.params_with_defaults = []
|
||||
self.inputs = ['z']
|
||||
|
||||
class Hardmax(ComputationNode):
|
||||
def __init__(self, z, op_name='Hardmax', name=None):
|
||||
super(Hardmax, self).__init__(params=['z'], op_name=op_name, name=name)
|
||||
|
|
|
@ -8,6 +8,17 @@
|
|||
|
||||
from cntk.graph import ComputationNode, _InputComputationNodeBase, _ImageInputComputationNodeBase
|
||||
|
||||
class Slice(ComputationNode):
|
||||
def __init__(self, _, beginIndex, endIndex, axis=1, op_name='CNTK2.Slice',
|
||||
name=None):
|
||||
super(Slice, self).__init__(params=['_', 'beginIndex', 'endIndex', 'axis'], op_name=op_name, name=name)
|
||||
self._ = _
|
||||
self.beginIndex = beginIndex
|
||||
self.endIndex = endIndex
|
||||
self.axis = axis
|
||||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Ceil(ComputationNode):
|
||||
def __init__(self, _, op_name='CNTK2.Ceil', name=None):
|
||||
super(Ceil, self).__init__(params=['_'], op_name=op_name, name=name)
|
||||
|
|
|
@ -9,8 +9,7 @@ import pytest
|
|||
from .ops_test_utils import unittest_helper, AA, I, precision, PRECISION_TO_TYPE
|
||||
from ...graph import *
|
||||
from ...reader import *
|
||||
from .. import reshape
|
||||
|
||||
import cntk as C
|
||||
|
||||
RESHAPE_TEST_CASES = [
|
||||
#(inputShape, outputShape, expectedOutputShape)
|
||||
|
@ -41,7 +40,7 @@ def test_op_reshape(inputShape, outputShape, expectedOutputShape, device_id, pre
|
|||
a = I([input_tensor])
|
||||
|
||||
# reshape into output shape
|
||||
reshaped_input = reshape(a, outputShape)
|
||||
reshaped_input = C.reshape(a, outputShape)
|
||||
|
||||
unittest_helper(reshaped_input, None, [[expected_tensor]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
@ -58,14 +57,135 @@ def test_op_reshape(inputShape, outputShape, expectedOutputShape, device_id, pre
|
|||
a = I([input_tensor])
|
||||
|
||||
# reshape into output shape
|
||||
reshaped_input = reshape(a, outputShape)
|
||||
reshaped_input = C.reshape(a, outputShape)
|
||||
|
||||
some_factor = 100
|
||||
weight = some_factor * expected_tensor
|
||||
output = reshaped_input * weight
|
||||
weight = expected_tensor * some_factor
|
||||
|
||||
output = reshaped_input * weight
|
||||
expected_gradient = input_tensor * some_factor
|
||||
|
||||
unittest_helper(output, None, [[expected_gradient]], device_id = device_id,
|
||||
precision=precision, clean_up=True, backward_pass=True, input_node=a)
|
||||
|
||||
|
||||
SLICE_TEST_CASES = [
|
||||
#(input_data, slice_params(beg_index, end_index,axis), expected_result)
|
||||
([[1,2],[-3,4]], (1,2,0), [[-3,4]]),
|
||||
([[1,2],[-3,4]], (1,2,1), [[2],[4]]),
|
||||
]
|
||||
@pytest.mark.parametrize("input_data, slice_params, expected_result", SLICE_TEST_CASES)
|
||||
def test_op_slice(input_data, slice_params, expected_result, device_id, precision):
|
||||
# Forward pass test
|
||||
#==================
|
||||
# We compute the expected output for the forward pass.
|
||||
# We need two surrounding brackets:
|
||||
# The first for sequences (length=1, since we have dynamic_axis='').
|
||||
# The second for batch of one sample.
|
||||
|
||||
a = I([input_data])
|
||||
def op_slice(x, beg_index, end_index, axis):
|
||||
return x[beg_index:end_index]
|
||||
|
||||
def _ax_slices(x, beg_index, end_index, axis):
|
||||
'''
|
||||
Creates a NumPy slicing array from slice operator's arguments
|
||||
'''
|
||||
ax_slices = []
|
||||
for i in range(0, len(x.shape)):
|
||||
if i==axis:
|
||||
if end_index >= x.shape[i]:
|
||||
ax_slices.append([beg_index,])
|
||||
else:
|
||||
ax_slices.append([beg_index,end_index])
|
||||
else:
|
||||
ax_slices.append(slice(None)) # corresponds to ':'
|
||||
return ax_slices
|
||||
|
||||
|
||||
# slice using the operator
|
||||
result = C.slice(a, *slice_params)
|
||||
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
# slice using the overload
|
||||
ax_slices = _ax_slices(a, *slice_params)
|
||||
result = a[ax_slices]
|
||||
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=False, backward_pass=False)
|
||||
# Backward pass test
|
||||
# ==================
|
||||
# The gradient of the slice operator is a tensor of the same shape as the
|
||||
# input tensor, having 1 for elements that were taken and 0 for elements
|
||||
# that were dropped.
|
||||
|
||||
def grad_slice(x, beg_index, end_index, axis):
|
||||
res = np.zeros_like(x)
|
||||
ax_slices = _ax_slices(x, beg_index, end_index, axis)
|
||||
res[ax_slices] = x[ax_slices]
|
||||
res[res!=0] = 1
|
||||
return res
|
||||
|
||||
expected_gradient = grad_slice(np.asarray(input_data), *slice_params)
|
||||
|
||||
unittest_helper(result, None, [[expected_gradient]], device_id = device_id,
|
||||
precision=precision, clean_up=True, backward_pass=True, input_node=a)
|
||||
|
||||
def test_op_slice_overload(device_id, precision):
|
||||
# Testing ComputationNode's __getitem__ more thoroughly
|
||||
|
||||
input_data = np.arange(12).reshape(2,3,2)
|
||||
# array([[[ 0, 1],
|
||||
# [ 2, 3],
|
||||
# [ 4, 5]],
|
||||
# [[ 6, 7],
|
||||
# [ 8, 9],
|
||||
# [10, 11]]])
|
||||
a = I([input_data])
|
||||
|
||||
# simple index slicing
|
||||
result = a[1]
|
||||
|
||||
expected_result = \
|
||||
np.asarray([[
|
||||
[ 6, 7],
|
||||
[ 8, 9],
|
||||
[10, 11]]])
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
# slice a range along the middle axis
|
||||
result = a[:,1:,:]
|
||||
|
||||
expected_result = \
|
||||
np.asarray([[
|
||||
[ 2, 3],
|
||||
[ 4, 5]],
|
||||
[
|
||||
[ 8, 9],
|
||||
[10, 11]]])
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
# slice at the end
|
||||
result = a[:,:,1]
|
||||
|
||||
expected_result = \
|
||||
np.asarray([[
|
||||
[ 1],
|
||||
[ 3],
|
||||
[ 5]],
|
||||
[[ 7],
|
||||
[ 9],
|
||||
[11]]])
|
||||
unittest_helper(result, None, [[expected_result]], device_id=device_id,
|
||||
precision=precision, clean_up=True, backward_pass=False)
|
||||
|
||||
# do we properly handle bad user input?
|
||||
with pytest.raises(ValueError):
|
||||
result = a[:,:,2:1]
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
result = a[1,object(),2]
|
||||
|
|
|
@ -204,13 +204,13 @@ class CNTKTextFormatReader(AbstractReader):
|
|||
configuration = {
|
||||
'readerType': self.reader_type,
|
||||
'file': self.filename,
|
||||
'randomize': self.randomize,
|
||||
'randomize': str(self.randomize).lower(),
|
||||
'skipSequenceIds': str(self.skip_sequence_ids).lower(),
|
||||
'maxErrors': self.max_errors,
|
||||
'traceLevel': self.trace_level,
|
||||
'chunkSizeInBytes': self.chunk_size_in_bytes,
|
||||
'keepDataInMemory': self.keepDataInMemory,
|
||||
'frameMode': self.frameMode
|
||||
'keepDataInMemory': str(self.keepDataInMemory).lower(),
|
||||
'frameMode': str(self.frameMode).lower()
|
||||
}
|
||||
|
||||
template = '''
|
||||
|
|
|
@ -45,12 +45,13 @@ def test_overload_types(root_node, expected):
|
|||
|
||||
|
||||
def test_overload_exception():
|
||||
with pytest.raises(ValueError):
|
||||
C(range(0, 10))[:]
|
||||
c = C(list(range(0, 10)))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
C(range(0, 10))[0:3:2]
|
||||
with pytest.raises(TypeError):
|
||||
c[:]
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
c[0:3:2]
|
||||
|
||||
def _to_list(desc):
|
||||
return [line.strip() for line in desc.split('\n')]
|
||||
|
|
|
@ -15,11 +15,11 @@ import sys
|
|||
REGEX_STANDARD = re.compile(r'(?P<operator>\w+)\((?P<operands>.*?)\) = .*')
|
||||
REGEX_COMPNODE = re.compile(
|
||||
r'(?P<operator>\w+) ?\((?P<operands>.*?)\)\s*=\s*new\s*ComputationNode\s*\[\s*(?P<inputs>.*?inputs\s*=.*?[;|\/])?')
|
||||
REGEX_ALIAS = re.compile(r'(?P<operator>[A-Z]\w*)\s*=\s*(?P<alias>\w+)\s*(//.*|)')
|
||||
REGEX_ALIAS = re.compile(r'(?P<operator>[A-Z]\w*)\s*=\s*(?P<alias>\w+)\s*(//.*|#.*|)$')
|
||||
# ElementDivide(aMatrix, anotherMatrix, tag='') = ElementTimes(aMatrix,
|
||||
# Reciprocal(anotherMatrix))
|
||||
REGEX_INSTANTIATION = re.compile(
|
||||
r'(?P<operator>\w+)\((?P<operands>.*?)\)\s*=\s*(?P<inst_operator>\w+)\s*\((?P<inst_operands>.*?)\)\s*(//.*|)')
|
||||
r'(?P<operator>\w+)\((?P<operands>.*?)\)\s*=\s*(?P<inst_operator>\w+)\s*\((?P<inst_operands>.*?)\)\s*(//.*|#.*|)')
|
||||
|
||||
REGEX_COMMENT = re.compile(r'/\*.*\*/')
|
||||
|
||||
|
@ -285,6 +285,17 @@ CNTK2_MANUAL_PREFIX = """\
|
|||
|
||||
from cntk.graph import ComputationNode, _InputComputationNodeBase, _ImageInputComputationNodeBase
|
||||
|
||||
class Slice(ComputationNode):
|
||||
def __init__(self, _, beginIndex, endIndex, axis=1, op_name='CNTK2.Slice',
|
||||
name=None):
|
||||
super(Slice, self).__init__(params=['_', 'beginIndex', 'endIndex', 'axis'], op_name=op_name, name=name)
|
||||
self._ = _
|
||||
self.beginIndex = beginIndex
|
||||
self.endIndex = endIndex
|
||||
self.axis = axis
|
||||
self.inputs = ['_']
|
||||
self.params_with_defaults = ['axis']
|
||||
|
||||
class Ceil(ComputationNode):
|
||||
def __init__(self, _, op_name='CNTK2.Ceil', name=None):
|
||||
super(Ceil, self).__init__(params=['_'], op_name=op_name, name=name)
|
||||
|
@ -361,7 +372,11 @@ def convert_bs_to_python(bs_fn, out_dir):
|
|||
comp_match = REGEX_COMPNODE.match(line)
|
||||
if comp_match:
|
||||
ns = 'CNTK2.' if part_of_file==CNTK2_SECT else ''
|
||||
op = CompNodeOperator(comp_match, ns)
|
||||
try:
|
||||
op = CompNodeOperator(comp_match, ns)
|
||||
except ValueError:
|
||||
print('ERROR while parsing: %s'%line)
|
||||
continue
|
||||
if op.name in OPERATORS_TO_IGNORE and part_of_file==COMP_NODE_SECT:
|
||||
continue
|
||||
pyf.write(str(op) + '\n')
|
||||
|
|
Загрузка…
Ссылка в новой задаче