This commit is contained in:
Clemens Marschner 2016-09-09 17:37:00 +02:00
Родитель c0c4c5d9eb
Коммит 1a199c5241
3 изменённых файлов: 10 добавлений и 10 удалений

Просмотреть файл

@ -143,7 +143,7 @@ public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto result = ValueTensorFor(rank, fr);
auto result = ValueTensorFor(rank, fr);
auto input0 = InputRef(0).ValueTensorFor(rank, fr.AllowBroadcast());
auto input1 = InputRef(1).ValueTensorFor(rank, fr.AllowBroadcast());
result.AssignDifferenceOf(input0, input1);
@ -193,7 +193,7 @@ public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto result = ValueTensorFor(rank, fr);
auto result = ValueTensorFor(rank, fr);
auto input0 = InputRef(0).ValueTensorFor(rank, fr.AllowBroadcast());
auto input1 = InputRef(1).ValueTensorFor(rank, fr.AllowBroadcast());
result.AssignElementwiseProductOf(input0, input1);
@ -203,7 +203,7 @@ public:
{
size_t rank = DetermineElementwiseTensorRank();
auto gradient = GradientTensorFor(rank, fr);
auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, fr.AllowBroadcast());
auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, fr.AllowBroadcast());
auto otherInputValue = Input(1 - inputIndex)->ValueTensorFor(rank, fr.AllowBroadcast());
// if reduction then mask the respective input(s) (zero out the gaps)
@ -689,7 +689,7 @@ public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto output = ValueTensorFor( rank, fr);
auto output = ValueTensorFor( rank, fr);
auto input = TensorView<ElemType>(InputRef(0).ValuePtr(), GetTransposedTensorSliceFor(rank, fr));
output.AssignCopyOf(input);
}
@ -697,7 +697,7 @@ public:
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto outputGradient = GradientTensorFor( rank, fr);
auto outputGradient = GradientTensorFor( rank, fr);
auto inputGradient = TensorView<ElemType>(InputRef(0).GradientPtr(), GetTransposedTensorSliceFor(rank, fr));
inputGradient.AddCopyOf(outputGradient);
}

Просмотреть файл

@ -50,7 +50,7 @@ public:
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto result = ValueTensorFor(rank, fr);
auto result = ValueTensorFor(rank, fr);
auto input = InputRef(0).ValueTensorFor(rank, fr);
result.DoUnaryOpOf(0, input, 1, opForward, opSum);
}
@ -61,7 +61,7 @@ public:
// get the args
size_t rank = DetermineElementwiseTensorRank();
auto sliceOutputGrad = GradientTensorFor(rank, fr); // propagate from this one...
auto sliceOutputGrad = GradientTensorFor(rank, fr); // propagate from this one...
auto sliceInputGrad = InputRef(0).GradientTensorFor(rank, fr); // ...to this one
GradientOperationType opTypeHolder = opType; // preventing pragma warning C4127
@ -544,10 +544,10 @@ public:
if (inputIndex == 2)
{
size_t rank = DetermineElementwiseTensorRank();
auto gradient = GradientTensorFor(rank, fr);
auto gradient = GradientTensorFor(rank, fr);
auto inputGradient = InputRef(inputIndex).GradientTensorFor(rank, fr.AllowBroadcast());
auto input = InputRef(inputIndex).ValueTensorFor(rank, fr.AllowBroadcast());
auto output = ValueTensorFor(rank, fr.AllowBroadcast());
auto output = ValueTensorFor(rank, fr.AllowBroadcast());
inputGradient.AddCopyIfEqualOf(input, output, gradient);
}

Просмотреть файл

@ -126,7 +126,7 @@ template <class ElemType>
fprintf(stderr, "] %ls %s--> %s\n", m_message.c_str(), logGradientInstead ? "(gradient) " : "", InputRef(0).FormatOperationPrototype("").c_str());
InputRef(0).WriteMinibatchWithFormatting(stderr, fr, m_onlyUpToRow, m_onlyUpToT, m_formattingOptions.transpose, m_formattingOptions.isCategoryLabel, m_formattingOptions.isSparse, m_labelMapping,
sequenceSeparator, sequencePrologue, sequenceEpilogue, elementSeparator, sampleSeparator,
valueFormatString, logGradientInstead);
valueFormatString, logGradientInstead);
}
}