addressed CR feedback (new overload for PastValueNode constructor)
This commit is contained in:
Родитель
c6685896d6
Коммит
b0ade9c1c1
|
@ -972,11 +972,10 @@ namespace CNTK
|
|||
Variable initialStateVar = functionInputs[1];
|
||||
|
||||
size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
|
||||
float dummyInitialStateValue = 0.0f; // Not really used but then why are we forced to pass this?
|
||||
if (op == PrimitiveOpType::PastValue)
|
||||
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), functionName, dummyInitialStateValue, AsTensorShape(inputOperandVar.Shape()), offset);
|
||||
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), functionName, AsTensorShape(inputOperandVar.Shape()), offset);
|
||||
else
|
||||
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), functionName, dummyInitialStateValue, AsTensorShape(inputOperandVar.Shape()), offset);
|
||||
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), functionName, AsTensorShape(inputOperandVar.Shape()), offset);
|
||||
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -117,6 +117,10 @@ public:
|
|||
: Base(deviceId, name, fixedInitialStateScalarValue, sampleLayout, timeStep)
|
||||
{
|
||||
}
|
||||
PastValueNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& sampleLayout, size_t timeStep)
|
||||
: Base(deviceId, name, (ElemType)0, sampleLayout, timeStep)
|
||||
{
|
||||
}
|
||||
PastValueNode(DEVICEID_TYPE deviceId, const wstring& name, ElemType fixedInitialStateScalarValue, size_t numRows, size_t timeStep)
|
||||
: PastValueNode(deviceId, name, fixedInitialStateScalarValue, TensorShape(numRows), timeStep)
|
||||
{
|
||||
|
@ -147,6 +151,10 @@ public:
|
|||
: Base(deviceId, name, fixedInitialStateScalarValue, sampleLayout, timeStep)
|
||||
{
|
||||
}
|
||||
FutureValueNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& sampleLayout, size_t timeStep)
|
||||
: Base(deviceId, name, (ElemType)0, sampleLayout, timeStep)
|
||||
{
|
||||
}
|
||||
FutureValueNode(DEVICEID_TYPE deviceId, const wstring& name, ElemType fixedInitialStateScalarValue, size_t numRows, size_t timeStep)
|
||||
: FutureValueNode(deviceId, name, fixedInitialStateScalarValue, TensorShape(numRows), timeStep)
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче