fix comments in cr
This commit is contained in:
Родитель
8d1bc2465b
Коммит
f4e9df1242
|
@ -1281,8 +1281,7 @@ namespace CNTK
|
|||
if (operand.Shape().Dimensions().size() == 0)
|
||||
LogicError("ToBatch: the input can not be scalar.");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
return UnaryOp(PrimitiveOpType::ToBatch, operand, std::move(additionalProperties), name);
|
||||
return UnaryOp(PrimitiveOpType::ToBatch, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr UnpackBatch(const Variable& operand, const std::wstring& name)
|
||||
|
@ -1290,8 +1289,7 @@ namespace CNTK
|
|||
if (operand.DynamicAxes().size() > 1)
|
||||
LogicError("UnpackBatch: only support input with batch axis itself.");
|
||||
|
||||
auto additionalProperties = Dictionary();
|
||||
return UnaryOp(PrimitiveOpType::UnpackBatch, operand, std::move(additionalProperties), name);
|
||||
return UnaryOp(PrimitiveOpType::UnpackBatch, operand, Dictionary(), name);
|
||||
}
|
||||
|
||||
FunctionPtr GumbelRandom(const NDShape& shape, DataType dataType, double loc, double scale, unsigned long seed, const std::wstring& name)
|
||||
|
|
|
@ -433,7 +433,11 @@ namespace CNTK
|
|||
if (!(m_inputs[0].IsConstant() || m_inputs[0].IsParameter()))
|
||||
InvalidArgument("AssignNode: Ref operand must be constant or parameter only.");
|
||||
//delay the check for free dimension
|
||||
if (m_inputs[0].Shape() != m_inputs[1].Shape() && !m_inputs[0].Shape().HasFreeDimension() && !m_inputs[1].Shape().HasFreeDimension())
|
||||
if (m_inputs[0].Shape() != m_inputs[1].Shape() &&
|
||||
!m_inputs[0].Shape().HasFreeDimension() &&
|
||||
!m_inputs[1].Shape().HasFreeDimension() &&
|
||||
!m_inputs[0].Shape().HasInferredDimension() &&
|
||||
!m_inputs[1].Shape().HasInferredDimension())
|
||||
{
|
||||
InvalidArgument("AssignNode: All inputs should have same sample layout.");
|
||||
}
|
||||
|
|
|
@ -500,7 +500,7 @@ class ToBatchAxisNode : public ComputationNodeNonLooping<ElemType>, public NumIn
|
|||
{
|
||||
typedef ComputationNodeNonLooping<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() {
|
||||
return L"AttachDynamicAxis";
|
||||
return L"ToBatchAxisNode";
|
||||
}
|
||||
public:
|
||||
ToBatchAxisNode(DEVICEID_TYPE deviceId, const wstring& name)
|
||||
|
@ -567,7 +567,8 @@ public:
|
|||
|
||||
if (!m_pMBLayout)
|
||||
{
|
||||
m_pMBLayout = make_shared<MBLayout>(1, 0, ComputationNodeBase::DefaultNoSequenceAxisName); // this generates a new layout
|
||||
m_pMBLayout = make_shared<MBLayout>(); // this generates a new layout
|
||||
m_pMBLayout->SetUniqueAxisName(ComputationNodeBase::DefaultNoSequenceAxisName);
|
||||
}
|
||||
|
||||
auto sampleLayout = Input(0)->GetSampleLayout();
|
||||
|
@ -600,7 +601,7 @@ class UnpackBatchAixsNode : public ComputationNodeNonLooping<ElemType>, public N
|
|||
{
|
||||
typedef ComputationNodeNonLooping<ElemType> Base; UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName() {
|
||||
return L"DetachDynamicAxis";
|
||||
return L"UnpackBatchAixs";
|
||||
}
|
||||
public:
|
||||
UnpackBatchAixsNode(DEVICEID_TYPE deviceId, const wstring& name)
|
||||
|
|
|
@ -492,7 +492,7 @@ def test_convert_dynamic_axis():
|
|||
|
||||
const_a = C.unpack_batch(y)
|
||||
assert len(const_a.dynamic_axes) == 0
|
||||
assert const_a.shape == (-3, 2, 3)
|
||||
assert const_a.shape == (C.FreeDimension, 2, 3)
|
||||
|
||||
f = C.assign(a, const_a)
|
||||
z = x + 1
|
||||
|
@ -505,10 +505,10 @@ def test_convert_dynamic_axis():
|
|||
x = C.input_variable((2,3))
|
||||
const_x = C.unpack_batch(x)
|
||||
assert len(const_x.dynamic_axes) == 0
|
||||
assert const_x.shape == (-3, 2, 3)
|
||||
assert const_x.shape == (C.FreeDimension, 2, 3)
|
||||
|
||||
const_y = C.reshape(const_x, (-1, 3))
|
||||
assert const_y.shape == (-3, 3)
|
||||
assert const_y.shape == (C.FreeDimension, 3)
|
||||
y = C.to_batch(const_y)
|
||||
assert len(y.dynamic_axes) == 1
|
||||
assert y.shape == (3,)
|
||||
|
|
Загрузка…
Ссылка в новой задаче