This commit is contained in:
REDMOND\sayanpa 2017-02-06 17:15:05 -08:00
Родитель 777c9cca07 4625934aec
Коммит a335e0a604
2 изменённых файлов: 19 добавлений и 2 удалений

Просмотреть файл

@ -581,6 +581,7 @@ public:
{
auto node = dynamic_pointer_cast<RowStackNode<ElemType>>(nodeP);
node->m_firstIndices = m_firstIndices;
node->m_spliceDim = m_spliceDim;
}
}

Просмотреть файл

@ -13,7 +13,7 @@ import pytest
from ..functions import *
from ...trainer import *
from ...initializer import glorot_uniform
from .. import constant, parameter, input_variable, placeholder_variable, times, plus, past_value, sequence, as_composite, combine
from .. import constant, parameter, input_variable, placeholder_variable, times, plus, past_value, sequence, as_composite, combine, convolution, splice
from ... import InferredDimension
from .ops_test_utils import compare_lists_of_np_arrays, AA
@ -21,7 +21,6 @@ def test_variable_forwarding():
op = constant(value=2, shape=(3,4)) + 1
assert op.shape == (3,4)
def test_replace_placeholders():
p = placeholder_variable(shape=(1,))
i = input_variable(shape=(1,),
@ -214,6 +213,23 @@ def test_clone_with_function_in_substitution_map():
just_b = t_plus_b.clone('clone', {t : p})
t_plus_b_clone = just_b.clone('share', {p : t})
def test_clone_with_slice():
i1 = input_variable((2,2), name='i1')
i2 = input_variable((2,2), name='i2')
x = splice((i1,i2), 0)
W = constant(1, (4,1), name='W')
y = convolution(W, x)
assert(y.shape == (4,2))
from ..functions import CloneMethod
x1 = input_variable((2,1), name='x1')
x2 = input_variable((2,1), name='x2')
p1 = placeholder_variable()
p2 = placeholder_variable()
y_cloned = y.clone('clone', {i1:p1, i2:p2})
y2 = y_cloned(x1, x2)
assert(y2.shape == (4,1))
def test_as_composite():
input_dim = 1
proj_dim = 2