addressed CR feedback (new overload for PastValueNode constructor)

This commit is contained in:
Frank Seide 2016-09-15 08:18:00 -07:00
Родитель c6685896d6
Коммит b0ade9c1c1
2 изменённых файлов: 10 добавлений и 3 удалений

Просмотреть файл

@ -972,11 +972,10 @@ namespace CNTK
Variable initialStateVar = functionInputs[1];
size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
float dummyInitialStateValue = 0.0f; // Not really used but then why are we forced to pass this?
if (op == PrimitiveOpType::PastValue)
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), functionName, dummyInitialStateValue, AsTensorShape(inputOperandVar.Shape()), offset);
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), functionName, AsTensorShape(inputOperandVar.Shape()), offset);
else
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), functionName, dummyInitialStateValue, AsTensorShape(inputOperandVar.Shape()), offset);
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), functionName, AsTensorShape(inputOperandVar.Shape()), offset);
break;
}

Просмотреть файл

@ -117,6 +117,10 @@ public:
: Base(deviceId, name, fixedInitialStateScalarValue, sampleLayout, timeStep)
{
}
PastValueNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& sampleLayout, size_t timeStep)
: Base(deviceId, name, (ElemType)0, sampleLayout, timeStep)
{
}
PastValueNode(DEVICEID_TYPE deviceId, const wstring& name, ElemType fixedInitialStateScalarValue, size_t numRows, size_t timeStep)
: PastValueNode(deviceId, name, fixedInitialStateScalarValue, TensorShape(numRows), timeStep)
{
@ -147,6 +151,10 @@ public:
: Base(deviceId, name, fixedInitialStateScalarValue, sampleLayout, timeStep)
{
}
FutureValueNode(DEVICEID_TYPE deviceId, const wstring& name, const TensorShape& sampleLayout, size_t timeStep)
: Base(deviceId, name, (ElemType)0, sampleLayout, timeStep)
{
}
FutureValueNode(DEVICEID_TYPE deviceId, const wstring& name, ElemType fixedInitialStateScalarValue, size_t numRows, size_t timeStep)
: FutureValueNode(deviceId, name, fixedInitialStateScalarValue, TensorShape(numRows), timeStep)
{