This commit is contained in:
Cheng Tang 2017-07-21 16:41:35 -07:00
Родитель 8d1bc2465b
Коммит f4e9df1242
4 изменённых файлов: 14 добавлений и 11 удалений

Просмотреть файл

@ -1281,8 +1281,7 @@ namespace CNTK
if (operand.Shape().Dimensions().size() == 0)
LogicError("ToBatch: the input can not be scalar.");
auto additionalProperties = Dictionary();
return UnaryOp(PrimitiveOpType::ToBatch, operand, std::move(additionalProperties), name);
return UnaryOp(PrimitiveOpType::ToBatch, operand, Dictionary(), name);
}
FunctionPtr UnpackBatch(const Variable& operand, const std::wstring& name)
@ -1290,8 +1289,7 @@ namespace CNTK
if (operand.DynamicAxes().size() > 1)
LogicError("UnpackBatch: only support input with batch axis itself.");
auto additionalProperties = Dictionary();
return UnaryOp(PrimitiveOpType::UnpackBatch, operand, std::move(additionalProperties), name);
return UnaryOp(PrimitiveOpType::UnpackBatch, operand, Dictionary(), name);
}
FunctionPtr GumbelRandom(const NDShape& shape, DataType dataType, double loc, double scale, unsigned long seed, const std::wstring& name)

Просмотреть файл

@ -433,7 +433,11 @@ namespace CNTK
if (!(m_inputs[0].IsConstant() || m_inputs[0].IsParameter()))
InvalidArgument("AssignNode: Ref operand must be constant or parameter only.");
//delay the check for free dimension
if (m_inputs[0].Shape() != m_inputs[1].Shape() && !m_inputs[0].Shape().HasFreeDimension() && !m_inputs[1].Shape().HasFreeDimension())
if (m_inputs[0].Shape() != m_inputs[1].Shape() &&
!m_inputs[0].Shape().HasFreeDimension() &&
!m_inputs[1].Shape().HasFreeDimension() &&
!m_inputs[0].Shape().HasInferredDimension() &&
!m_inputs[1].Shape().HasInferredDimension())
{
InvalidArgument("AssignNode: All inputs should have same sample layout.");
}

Просмотреть файл

@ -500,7 +500,7 @@ class ToBatchAxisNode : public ComputationNodeNonLooping<ElemType>, public NumIn
{
typedef ComputationNodeNonLooping<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() {
return L"AttachDynamicAxis";
return L"ToBatchAxisNode";
}
public:
ToBatchAxisNode(DEVICEID_TYPE deviceId, const wstring& name)
@ -567,7 +567,8 @@ public:
if (!m_pMBLayout)
{
m_pMBLayout = make_shared<MBLayout>(1, 0, ComputationNodeBase::DefaultNoSequenceAxisName); // this generates a new layout
m_pMBLayout = make_shared<MBLayout>(); // this generates a new layout
m_pMBLayout->SetUniqueAxisName(ComputationNodeBase::DefaultNoSequenceAxisName);
}
auto sampleLayout = Input(0)->GetSampleLayout();
@ -600,7 +601,7 @@ class UnpackBatchAixsNode : public ComputationNodeNonLooping<ElemType>, public N
{
typedef ComputationNodeNonLooping<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() {
return L"DetachDynamicAxis";
return L"UnpackBatchAixs";
}
public:
UnpackBatchAixsNode(DEVICEID_TYPE deviceId, const wstring& name)

Просмотреть файл

@ -492,7 +492,7 @@ def test_convert_dynamic_axis():
const_a = C.unpack_batch(y)
assert len(const_a.dynamic_axes) == 0
assert const_a.shape == (-3, 2, 3)
assert const_a.shape == (C.FreeDimension, 2, 3)
f = C.assign(a, const_a)
z = x + 1
@ -505,10 +505,10 @@ def test_convert_dynamic_axis():
x = C.input_variable((2,3))
const_x = C.unpack_batch(x)
assert len(const_x.dynamic_axes) == 0
assert const_x.shape == (-3, 2, 3)
assert const_x.shape == (C.FreeDimension, 2, 3)
const_y = C.reshape(const_x, (-1, 3))
assert const_y.shape == (-3, 3)
assert const_y.shape == (C.FreeDimension, 3)
y = C.to_batch(const_y)
assert len(y.dynamic_axes) == 1
assert y.shape == (3,)