Added Log Computational Node
This commit is contained in:
Родитель
f5e41bb3bb
Коммит
04f519b1d0
|
@ -1921,174 +1921,341 @@ namespace Microsoft { namespace MSR { namespace CNTK {
|
|||
// Explicit instantiations of TanhNode for the two supported element types.
// (Diff-viewer marker lines removed; statements restored to plain C++.)
template class TanhNode<float>;
template class TanhNode<double>;
||||
template<class ElemType>
|
||||
class LogNode : public ComputationNode<ElemType>
|
||||
{
|
||||
typedef ComputationNode<ElemType>* ComputationNodePtr;
|
||||
template<class ElemType>
|
||||
class LogNode : public ComputationNode<ElemType>
|
||||
{
|
||||
typedef ComputationNode<ElemType>* ComputationNodePtr;
|
||||
|
||||
|
||||
public:
|
||||
LogNode(const short deviceId=AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfLog(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L""? CreateUniqNodeName() : name);
|
||||
m_deviceId = deviceId;
|
||||
MoveMatricesToDevice(deviceId);
|
||||
InitRecurrentNode();
|
||||
}
|
||||
public:
|
||||
LogNode(const short deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfLog(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
m_deviceId = deviceId;
|
||||
MoveMatricesToDevice(deviceId);
|
||||
InitRecurrentNode();
|
||||
}
|
||||
|
||||
LogNode(File& fstream, const size_t modelVersion, const short deviceId=AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfLog(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L""? CreateUniqNodeName() : name);
|
||||
LoadFromFile(fstream, modelVersion, deviceId);
|
||||
}
|
||||
LogNode(File& fstream, const size_t modelVersion, const short deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfLog(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
LoadFromFile(fstream, modelVersion, deviceId);
|
||||
}
|
||||
|
||||
virtual const std::wstring OperationName() const {return TypeName();}
|
||||
static const std::wstring TypeName() {return L"Log";}
|
||||
virtual const std::wstring OperationName() const { return TypeName(); }
|
||||
static const std::wstring TypeName() { return L"Log"; }
|
||||
|
||||
|
||||
virtual void ComputeInputPartial(const size_t inputIndex)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Log only has one input.");
|
||||
ComputeInputPartialS(m_gradientOfLog, Inputs(0)->GradientValues(), Inputs(0)->FunctionValues(), GradientValues());
|
||||
}
|
||||
virtual void ComputeInputPartial(const size_t inputIndex)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Log only has one input.");
|
||||
ComputeInputPartialS(m_gradientOfLog, Inputs(0)->GradientValues(), Inputs(0)->FunctionValues(), GradientValues());
|
||||
}
|
||||
|
||||
virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Log only has one input.");
|
||||
virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Log only has one input.");
|
||||
|
||||
Matrix<ElemType> sliceInputGrad = Inputs(0)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputGrad = GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceInputGrad = Inputs(0)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputGrad = GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
ComputeInputPartialS(m_gradientOfLog, sliceInputGrad, sliceInputValue, sliceOutputGrad);
|
||||
}
|
||||
ComputeInputPartialS(m_gradientOfLog, sliceInputGrad, sliceInputValue, sliceOutputGrad);
|
||||
}
|
||||
|
||||
static void WINAPI ComputeInputPartialS(Matrix<ElemType>& gradientOfLog, Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& inputFunctionValues, const Matrix<ElemType>& gradientValues)
|
||||
{
|
||||
gradientOfLog.AssignElementInverseOf(inputFunctionValues); // 1/x (x is input to log(x))
|
||||
static void WINAPI ComputeInputPartialS(Matrix<ElemType>& gradientOfLog, Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& inputFunctionValues, const Matrix<ElemType>& gradientValues)
|
||||
{
|
||||
gradientOfLog.AssignElementInverseOf(inputFunctionValues); // 1/x (x is input to log(x))
|
||||
|
||||
inputGradientValues.AddElementProductOf(gradientValues, gradientOfLog);
|
||||
}
|
||||
inputGradientValues.AddElementProductOf(gradientValues, gradientOfLog);
|
||||
}
|
||||
|
||||
// GetTaskDescriptor - Get a task descriptor for this node
|
||||
// taskType - task type we are generating a task for
|
||||
virtual TaskDescriptor<ElemType>* GetPTaskDescriptor(TaskType taskType, size_t inputIndex=0) const
|
||||
{
|
||||
TaskDescriptor<ElemType>* descriptor = new TaskDescriptor<ElemType>(this, taskType, inputIndex);
|
||||
switch(taskType)
|
||||
{
|
||||
case taskComputeInputPartial:
|
||||
descriptor->MatrixParam(m_gradientOfLog, "GradientOfLog", paramOptionsInput | paramOptionsTemporary);
|
||||
descriptor->GradientParam(0, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->GradientParam();
|
||||
descriptor->SetFunction((FARPROC)ComputeInputPartialS);
|
||||
break;
|
||||
case taskEvaluate:
|
||||
descriptor->FunctionParam();
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->SetFunction((FARPROC)EvaluateThisNodeS);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
throw std::logic_error("Unsupported task requested");
|
||||
}
|
||||
return descriptor;
|
||||
}
|
||||
// GetTaskDescriptor - Get a task descriptor for this node
|
||||
// taskType - task type we are generating a task for
|
||||
virtual TaskDescriptor<ElemType>* GetPTaskDescriptor(TaskType taskType, size_t inputIndex = 0) const
|
||||
{
|
||||
TaskDescriptor<ElemType>* descriptor = new TaskDescriptor<ElemType>(this, taskType, inputIndex);
|
||||
switch (taskType)
|
||||
{
|
||||
case taskComputeInputPartial:
|
||||
descriptor->MatrixParam(m_gradientOfLog, "GradientOfLog", paramOptionsInput | paramOptionsTemporary);
|
||||
descriptor->GradientParam(0, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->GradientParam();
|
||||
descriptor->SetFunction((FARPROC)ComputeInputPartialS);
|
||||
break;
|
||||
case taskEvaluate:
|
||||
descriptor->FunctionParam();
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->SetFunction((FARPROC)EvaluateThisNodeS);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
throw std::logic_error("Unsupported task requested");
|
||||
}
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
virtual void EvaluateThisNode()
|
||||
{
|
||||
EvaluateThisNodeS(m_functionValues, Inputs(0)->FunctionValues());
|
||||
}
|
||||
virtual void EvaluateThisNode()
|
||||
{
|
||||
EvaluateThisNodeS(m_functionValues, Inputs(0)->FunctionValues());
|
||||
}
|
||||
|
||||
virtual void EvaluateThisNode(const size_t timeIdxInSeq)
|
||||
{
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
virtual void EvaluateThisNode(const size_t timeIdxInSeq)
|
||||
{
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
EvaluateThisNodeS(sliceOutputValue, sliceInputValue);
|
||||
}
|
||||
EvaluateThisNodeS(sliceOutputValue, sliceInputValue);
|
||||
}
|
||||
|
||||
static void WINAPI EvaluateThisNodeS(Matrix<ElemType>& functionValues, const Matrix<ElemType>& inputFunctionValues)
|
||||
{
|
||||
functionValues.AssignLogOf(inputFunctionValues);
|
||||
static void WINAPI EvaluateThisNodeS(Matrix<ElemType>& functionValues, const Matrix<ElemType>& inputFunctionValues)
|
||||
{
|
||||
functionValues.AssignLogOf(inputFunctionValues);
|
||||
#if NANCHECK
|
||||
functionValues.HasNan("Log");
|
||||
functionValues.HasNan("Log");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Validate()
|
||||
{
|
||||
PrintSelfBeforeValidation();
|
||||
virtual void Validate()
|
||||
{
|
||||
PrintSelfBeforeValidation();
|
||||
|
||||
if (m_children.size() != 1)
|
||||
throw std::logic_error("Log operation should have one input.");
|
||||
if (m_children.size() != 1)
|
||||
throw std::logic_error("Log operation should have one input.");
|
||||
|
||||
if (Inputs(0)->FunctionValues().GetNumElements() == 0)
|
||||
throw std::logic_error("Log operation: the input node has 0 element.");
|
||||
if (Inputs(0)->FunctionValues().GetNumElements() == 0)
|
||||
throw std::logic_error("Log operation: the input node has 0 element.");
|
||||
|
||||
FunctionValues().Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
m_gradientOfLog.Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
CopyImageSizeFromInputs();
|
||||
}
|
||||
FunctionValues().Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
m_gradientOfLog.Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
CopyImageSizeFromInputs();
|
||||
}
|
||||
|
||||
virtual void AttachInputs(const ComputationNodePtr singleInput)
|
||||
{
|
||||
m_children.resize(1);
|
||||
m_children[0] = singleInput;
|
||||
}
|
||||
virtual void AttachInputs(const ComputationNodePtr singleInput)
|
||||
{
|
||||
m_children.resize(1);
|
||||
m_children[0] = singleInput;
|
||||
}
|
||||
|
||||
virtual void MoveMatricesToDevice(const short deviceId)
|
||||
{
|
||||
ComputationNode<ElemType>::MoveMatricesToDevice(deviceId);
|
||||
virtual void MoveMatricesToDevice(const short deviceId)
|
||||
{
|
||||
ComputationNode<ElemType>::MoveMatricesToDevice(deviceId);
|
||||
|
||||
if (deviceId != AUTOPLACEMATRIX)
|
||||
{
|
||||
if (m_gradientOfLog.GetDeviceId() != deviceId)
|
||||
m_gradientOfLog.TransferFromDeviceToDevice(m_gradientOfLog.GetDeviceId(), deviceId);
|
||||
}
|
||||
}
|
||||
if (deviceId != AUTOPLACEMATRIX)
|
||||
{
|
||||
if (m_gradientOfLog.GetDeviceId() != deviceId)
|
||||
m_gradientOfLog.TransferFromDeviceToDevice(m_gradientOfLog.GetDeviceId(), deviceId);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
ComputationNode<ElemType>::CopyTo(nodeP, newName, flags);
|
||||
LogNode<ElemType>* node = (LogNode<ElemType>*) nodeP;
|
||||
virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
ComputationNode<ElemType>::CopyTo(nodeP, newName, flags);
|
||||
LogNode<ElemType>* node = (LogNode<ElemType>*) nodeP;
|
||||
|
||||
if (flags & CopyNodeFlags::copyNodeValue)
|
||||
{
|
||||
node->m_gradientOfLog = m_gradientOfLog;
|
||||
}
|
||||
}
|
||||
if (flags & CopyNodeFlags::copyNodeValue)
|
||||
{
|
||||
node->m_gradientOfLog = m_gradientOfLog;
|
||||
}
|
||||
}
|
||||
|
||||
// copy constructor
|
||||
LogNode(const LogNode<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags)
|
||||
: ComputationNode(node->m_deviceId), m_gradientOfLog(node->m_deviceId)
|
||||
{
|
||||
node->CopyTo(this, newName, flags);
|
||||
}
|
||||
// copy constructor
|
||||
LogNode(const LogNode<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags)
|
||||
: ComputationNode(node->m_deviceId), m_gradientOfLog(node->m_deviceId)
|
||||
{
|
||||
node->CopyTo(this, newName, flags);
|
||||
}
|
||||
|
||||
virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
const std::wstring& name = (newName == L"")?NodeName():newName;
|
||||
|
||||
ComputationNodePtr node = new LogNode<ElemType>(this, name, flags);
|
||||
return node;
|
||||
}
|
||||
virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
const std::wstring& name = (newName == L"") ? NodeName() : newName;
|
||||
|
||||
private:
|
||||
Matrix<ElemType> m_gradientOfLog;
|
||||
};
|
||||
ComputationNodePtr node = new LogNode<ElemType>(this, name, flags);
|
||||
return node;
|
||||
}
|
||||
|
||||
template class LogNode<float>;
|
||||
template class LogNode<double>;
|
||||
private:
|
||||
Matrix<ElemType> m_gradientOfLog;
|
||||
};
|
||||
|
||||
template class LogNode<float>;
|
||||
template class LogNode<double>;
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
template<class ElemType>
|
||||
class ExpNode : public ComputationNode<ElemType>
|
||||
{
|
||||
typedef ComputationNode<ElemType>* ComputationNodePtr;
|
||||
|
||||
|
||||
public:
|
||||
ExpNode(const short deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfExp(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
m_deviceId = deviceId;
|
||||
MoveMatricesToDevice(deviceId);
|
||||
InitRecurrentNode();
|
||||
}
|
||||
|
||||
ExpNode(File& fstream, const size_t modelVersion, const short deviceId = AUTOPLACEMATRIX, const std::wstring name = L"")
|
||||
: ComputationNode(deviceId), m_gradientOfExp(deviceId)
|
||||
{
|
||||
m_nodeName = (name == L"" ? CreateUniqNodeName() : name);
|
||||
LoadFromFile(fstream, modelVersion, deviceId);
|
||||
}
|
||||
|
||||
virtual const std::wstring OperationName() const { return TypeName(); }
|
||||
static const std::wstring TypeName() { return L"Exp"; }
|
||||
|
||||
|
||||
virtual void ComputeInputPartial(const size_t inputIndex)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Exp only has one input.");
|
||||
ComputeInputPartialS(m_gradientOfExp, Inputs(0)->GradientValues(), Inputs(0)->FunctionValues(), GradientValues());
|
||||
}
|
||||
|
||||
virtual void ComputeInputPartial(const size_t inputIndex, const size_t timeIdxInSeq)
|
||||
{
|
||||
if (inputIndex != 0)
|
||||
throw std::invalid_argument("Exp only has one input.");
|
||||
|
||||
Matrix<ElemType> sliceInputGrad = Inputs(0)->GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputGrad = GradientValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
ComputeInputPartialS(m_gradientOfExp, sliceInputGrad, sliceInputValue, sliceOutputGrad);
|
||||
}
|
||||
|
||||
static void WINAPI ComputeInputPartialS(Matrix<ElemType>& gradientOfExp, Matrix<ElemType>& inputGradientValues, const Matrix<ElemType>& inputFunctionValues, const Matrix<ElemType>& gradientValues)
|
||||
{
|
||||
gradientOfExp.AssignExpOf(inputFunctionValues); // Exp(x) is its own partial
|
||||
|
||||
inputGradientValues.AddElementProductOf(gradientValues, gradientOfExp);
|
||||
}
|
||||
|
||||
// GetTaskDescriptor - Get a task descriptor for this node
|
||||
// taskType - task type we are generating a task for
|
||||
virtual TaskDescriptor<ElemType>* GetPTaskDescriptor(TaskType taskType, size_t inputIndex = 0) const
|
||||
{
|
||||
TaskDescriptor<ElemType>* descriptor = new TaskDescriptor<ElemType>(this, taskType, inputIndex);
|
||||
switch (taskType)
|
||||
{
|
||||
case taskComputeInputPartial:
|
||||
descriptor->MatrixParam(m_gradientOfExp, "GradientOfExp", paramOptionsInput | paramOptionsTemporary);
|
||||
descriptor->GradientParam(0, paramOptionsInput | paramOptionsOutput | paramOptionsInitialize);
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->GradientParam();
|
||||
descriptor->SetFunction((FARPROC)ComputeInputPartialS);
|
||||
break;
|
||||
case taskEvaluate:
|
||||
descriptor->FunctionParam();
|
||||
descriptor->FunctionParam(0, paramOptionsInput);
|
||||
descriptor->SetFunction((FARPROC)EvaluateThisNodeS);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
throw std::logic_error("Unsupported task requested");
|
||||
}
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
virtual void EvaluateThisNode()
|
||||
{
|
||||
EvaluateThisNodeS(m_functionValues, Inputs(0)->FunctionValues());
|
||||
}
|
||||
|
||||
virtual void EvaluateThisNode(const size_t timeIdxInSeq)
|
||||
{
|
||||
Matrix<ElemType> sliceInputValue = Inputs(0)->FunctionValues().ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
Matrix<ElemType> sliceOutputValue = m_functionValues.ColumnSlice(timeIdxInSeq * m_samplesInRecurrentStep, m_samplesInRecurrentStep);
|
||||
|
||||
EvaluateThisNodeS(sliceOutputValue, sliceInputValue);
|
||||
}
|
||||
|
||||
static void WINAPI EvaluateThisNodeS(Matrix<ElemType>& functionValues, const Matrix<ElemType>& inputFunctionValues)
|
||||
{
|
||||
functionValues.AssignExpOf(inputFunctionValues);
|
||||
#if NANCHECK
|
||||
functionValues.HasNan("Exp");
|
||||
#endif
|
||||
}
|
||||
|
||||
virtual void Validate()
|
||||
{
|
||||
PrintSelfBeforeValidation();
|
||||
|
||||
if (m_children.size() != 1)
|
||||
throw std::logic_error("Exp operation should have one input.");
|
||||
|
||||
if (Inputs(0)->FunctionValues().GetNumElements() == 0)
|
||||
throw std::logic_error("Exp operation: the input node has 0 element.");
|
||||
|
||||
FunctionValues().Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
m_gradientOfExp.Resize(Inputs(0)->FunctionValues().GetNumRows(), Inputs(0)->FunctionValues().GetNumCols());
|
||||
CopyImageSizeFromInputs();
|
||||
}
|
||||
|
||||
virtual void AttachInputs(const ComputationNodePtr singleInput)
|
||||
{
|
||||
m_children.resize(1);
|
||||
m_children[0] = singleInput;
|
||||
}
|
||||
|
||||
virtual void MoveMatricesToDevice(const short deviceId)
|
||||
{
|
||||
ComputationNode<ElemType>::MoveMatricesToDevice(deviceId);
|
||||
|
||||
if (deviceId != AUTOPLACEMATRIX)
|
||||
{
|
||||
if (m_gradientOfExp.GetDeviceId() != deviceId)
|
||||
m_gradientOfExp.TransferFromDeviceToDevice(m_gradientOfExp.GetDeviceId(), deviceId);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
ComputationNode<ElemType>::CopyTo(nodeP, newName, flags);
|
||||
ExpNode<ElemType>* node = (ExpNode<ElemType>*) nodeP;
|
||||
|
||||
if (flags & CopyNodeFlags::copyNodeValue)
|
||||
{
|
||||
node->m_gradientOfExp = m_gradientOfExp;
|
||||
}
|
||||
}
|
||||
|
||||
// copy constructor
|
||||
ExpNode(const ExpNode<ElemType>* node, const std::wstring& newName, const CopyNodeFlags flags)
|
||||
: ComputationNode(node->m_deviceId), m_gradientOfExp(node->m_deviceId)
|
||||
{
|
||||
node->CopyTo(this, newName, flags);
|
||||
}
|
||||
|
||||
virtual ComputationNodePtr Duplicate(const std::wstring& newName, const CopyNodeFlags flags) const
|
||||
{
|
||||
const std::wstring& name = (newName == L"") ? NodeName() : newName;
|
||||
|
||||
ComputationNodePtr node = new ExpNode<ElemType>(this, name, flags);
|
||||
return node;
|
||||
}
|
||||
|
||||
private:
|
||||
Matrix<ElemType> m_gradientOfExp;
|
||||
};
|
||||
|
||||
template class ExpNode<float>;
|
||||
template class ExpNode<double>;
|
||||
|
||||
|
||||
template<class ElemType>
|
||||
class CosineNode : public ComputationNode<ElemType>
|
||||
{
|
||||
typedef ComputationNode<ElemType>* ComputationNodePtr;
|
||||
|
|
Загрузка…
Ссылка в новой задаче