Integrate chazhang/rowstack_clone into master
This commit is contained in:
Коммит
4625934aec
|
@ -581,6 +581,7 @@ public:
|
|||
{
|
||||
auto node = dynamic_pointer_cast<RowStackNode<ElemType>>(nodeP);
|
||||
node->m_firstIndices = m_firstIndices;
|
||||
node->m_spliceDim = m_spliceDim;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ import pytest
|
|||
from ..functions import *
|
||||
from ...trainer import *
|
||||
from ...initializer import glorot_uniform
|
||||
from .. import constant, parameter, input_variable, placeholder_variable, times, plus, past_value, sequence, as_composite, combine
|
||||
from .. import constant, parameter, input_variable, placeholder_variable, times, plus, past_value, sequence, as_composite, combine, convolution, splice
|
||||
from ... import InferredDimension
|
||||
from .ops_test_utils import compare_lists_of_np_arrays, AA
|
||||
|
||||
|
@ -21,7 +21,6 @@ def test_variable_forwarding():
|
|||
op = constant(value=2, shape=(3,4)) + 1
|
||||
assert op.shape == (3,4)
|
||||
|
||||
|
||||
def test_replace_placeholders():
|
||||
p = placeholder_variable(shape=(1,))
|
||||
i = input_variable(shape=(1,),
|
||||
|
@ -214,6 +213,23 @@ def test_clone_with_function_in_substitution_map():
|
|||
just_b = t_plus_b.clone('clone', {t : p})
|
||||
t_plus_b_clone = just_b.clone('share', {p : t})
|
||||
|
||||
def test_clone_with_slice():
|
||||
i1 = input_variable((2,2), name='i1')
|
||||
i2 = input_variable((2,2), name='i2')
|
||||
x = splice((i1,i2), 0)
|
||||
W = constant(1, (4,1), name='W')
|
||||
y = convolution(W, x)
|
||||
assert(y.shape == (4,2))
|
||||
|
||||
from ..functions import CloneMethod
|
||||
x1 = input_variable((2,1), name='x1')
|
||||
x2 = input_variable((2,1), name='x2')
|
||||
p1 = placeholder_variable()
|
||||
p2 = placeholder_variable()
|
||||
y_cloned = y.clone('clone', {i1:p1, i2:p2})
|
||||
y2 = y_cloned(x1, x2)
|
||||
assert(y2.shape == (4,1))
|
||||
|
||||
def test_as_composite():
|
||||
input_dim = 1
|
||||
proj_dim = 2
|
||||
|
|
Загрузка…
Ссылка в новой задаче