implemented TransposeNode for arbitrary dimensions (not serializing yet);

added plumbing for tensor Times
Frank Seide 2016-02-22 17:45:46 -08:00
Parent 1fb343866b
Commit 407c5a5546
10 changed files with 220 additions and 209 deletions

View file

@@ -324,20 +324,20 @@ static ConfigValuePtr NodeOp(const ExpressionPtr &e, ConfigValuePtr leftVal, Con
// When they fetch their parameters, they should only look in this record, not in any parent scope (if they don't find what they are looking for, it's a bug in this routine here).
// The values themselves are already in ConfigValuePtr form, so we won't need any scope lookups there either.
config->Add(L"operation", MakeFailFn(e->location), ConfigValuePtr(make_shared<String>(operationName), MakeFailFn(e->location), exprPath));
let leftFailFn = leftVal.GetFailFn(); // report any error for this Constant object as belonging to the scalar factor's expression
vector<ConfigValuePtr> inputs;
if (operationName == L"Scale")
{
// if we scale, the first operand is a Double, and we must convert that into a 1x1 Constant
// TODO: apply this more generally to all operators
auto constantConfig = make_shared<ConfigRecord>(config, MakeFailFn(e->location));
let leftFailFn = leftVal.GetFailFn(); // report any error for this Constant object as belonging to the scalar factor's expression
constantConfig->Add(L"operation", leftFailFn, ConfigValuePtr(make_shared<String>(L"LearnableParameter"), leftFailFn, exprPath));
let one = MakePrimitiveConfigValuePtr(1.0, leftVal.GetFailFn(), exprPath);
let one = MakePrimitiveConfigValuePtr(1.0, leftFailFn, exprPath);
constantConfig->Add(L"rows", leftFailFn, one);
constantConfig->Add(L"cols", leftFailFn, one);
//constantConfig->Add(L"shape", leftFailFn, one); // BUGBUG: rows,cols is no longer right, we need a TensorShape here
constantConfig->Add(L"value", leftFailFn, leftVal);
constantConfig->Add(L"needGradient", leftFailFn, MakePrimitiveConfigValuePtr(false, leftVal.GetFailFn(), exprPath));
constantConfig->Add(L"needGradient", leftFailFn, MakePrimitiveConfigValuePtr(false, leftFailFn, exprPath));
let value = ConfigValuePtr(rtInfo->construct(constantConfig), leftFailFn, exprPath);
let valueWithName = dynamic_cast<HasName *>(value.get());
if (valueWithName)
@@ -347,8 +347,13 @@ static ConfigValuePtr NodeOp(const ExpressionPtr &e, ConfigValuePtr leftVal, Con
inputs.push_back(leftVal);
if (operationName != L"Negate") // Negate only has one input (rightVal is a nullptr)
inputs.push_back(rightVal);
config->Add(L"inputs", leftVal.GetFailFn(), ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftVal.GetFailFn(), exprPath));
config->Add(L"tag", leftVal.GetFailFn(), ConfigValuePtr(make_shared<String>(), leftVal.GetFailFn(), exprPath)); // infix nodes have no tag
config->Add(L"inputs", leftFailFn, ConfigValuePtr(make_shared<ConfigArray>(0, move(inputs)), leftFailFn, exprPath));
config->Add(L"tag", leftFailFn, ConfigValuePtr(make_shared<String>(), leftFailFn, exprPath)); // infix nodes have no tag
if (operationName == L"Times")
{
let one = MakePrimitiveConfigValuePtr(1.0, leftFailFn, exprPath);
config->Add(L"outputRank", leftFailFn, one);
}
// instantiate the ComputationNode
let value = ConfigValuePtr(rtInfo->construct(config), MakeFailFn(e->location), exprPath);
let valueWithName = dynamic_cast<HasName *>(value.get());

View file

@@ -33,40 +33,43 @@ wstring commonMacros =
L"LogPrior(labels) = Log(Mean(labels)) \n";
wstring computationNodes = // TODO: use actual TypeName() here? would first need to make it a wide string; we should also extract those two methods into the base macro
L"LearnableParameter(rows, cols, needGradient = true, init = 'uniform'/*|fixedValue|gaussian|fromFile*/, initValueScale = 1, value = 0, initFromFilePath = '', initOnCPUOnly=true, randomSeed=-1, tag='') = new ComputationNode [ operation = 'LearnableParameter' ; shape = new TensorShape [ dims = (rows : cols) ] /*plus the function args*/ ]\n"
L"Parameter = LearnableParameter // deprecated \n"
L"ParameterTensor(dims, needGradient = true, init = 'uniform'/*|fixedValue|gaussian|fromFile*/, initValueScale = 1, value = 0, initFromFilePath = '', initOnCPUOnly=true, randomSeed=-1, tag='') = new ComputationNode [ operation = 'LearnableParameter' ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
// TODO: ImageParameter?
// ^^ already works; vv untested
L"Input(dims, tag='feature') = new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]\n" // note: naming a little inconsistent // TODO: re-test after flag change
L"SparseInput(dims, tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]\n"
L"ImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', tag='feature') = new ComputationNode [ operation = 'InputValue' ; isImage = true /*plus the function args*/ ]\n"
L"SparseImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; isImage = true /*plus the function args*/ ]\n"
L"Constant(val, rows = 1, cols = 1, tag='') = Parameter(rows, cols, needGradient = false, init = 'fixedValue', value = val) \n"
L"PastValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'PastValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
L"FutureValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'FutureValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
// TODO: ^^ DelayedValues no longer need to know their dimension. That is inferred in Validation.
L"Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag='') = new ComputationNode [ operation = 'Shift' ; inputs = (input : boundaryValue) /*plus the function args*/ ]\n"
L"RowSlice(startIndex, numRows, input, needGradient = false, tag='') = new ComputationNode [ operation = 'RowSlice' ; inputs = input /*plus the function args*/ ]\n"
L"RowRepeat(input, numRepeats, needGradient = false, tag='') = new ComputationNode [ operation = 'RowRepeat' ; inputs = input /*plus the function args*/ ]\n"
L"RowStack(inputs, tag='') = new ComputationNode [ operation = 'RowStack' /*plus the function args*/ ]\n"
L"Reshape(input, numRows, imageWidth = 0, imageHeight = 0, imageChannels = 0, tag='') = new ComputationNode [ operation = 'LegacyReshape' ; inputs = input /*plus the function args*/ ]\n"
L"NewReshape(input, dims, beginDim=0, endDim=0, tag='') = new ComputationNode [ operation = 'Reshape' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
L"ReshapeDimension(x, dim, tensorShape) = NewReshape(x, tensorShape, beginDim=dim, endDim=dim + 1) \n"
L"FlattenDimensions(x, dim, num) = NewReshape(x, 0, beginDim=dim, endDim=dim + num) \n"
L"SplitDimension(x, dim, N) = ReshapeDimension(x, dim, 0:N) \n"
L"Logistic(label, probability, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability) /*plus the function args*/ ]\n"
L"WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability : instanceWeight) /*plus the function args*/ ]\n"
L"ReconcileMBLayout(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileMBLayout' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]\n"
L"Convolution(weightNode, inputValueNode, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, zeroPadding = false, maxTempMemSizeInSamples = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
L"MaxPooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxPooling' ; inputs = input /*plus the function args*/ ]\n"
L"AveragePooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'AveragePooling' ; inputs = input /*plus the function args*/ ]\n"
// TODO: define DelayedValue, with negative delay for future; cannot do this yet, need to be able to say something like delay = -(^.delay)
// aliases
L"ColumnwiseCrossProduct = KhatriRaoProduct // deprecated \n" // TODO: should it be deprecated? It is described as easier to understand in the CNTKBook.
L"ClassificationError = ErrorPrediction \n"
L"Delay = PastValue \n" // TODO: should it allow negative offsets and an if test here?
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor = 1.0, epsilon = 0.00001, useCntkEngine = true, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
L"LearnableParameter(rows, cols, needGradient = true, init = 'uniform'/*|fixedValue|gaussian|fromFile*/, initValueScale = 1, value = 0, initFromFilePath = '', initOnCPUOnly=true, randomSeed=-1, tag='') = new ComputationNode [ operation = 'LearnableParameter' ; shape = new TensorShape [ dims = (rows : cols) ] /*plus the function args*/ ]\n"
L"Parameter = LearnableParameter // deprecated \n"
L"ParameterTensor(dims, needGradient = true, init = 'uniform'/*|fixedValue|gaussian|fromFile*/, initValueScale = 1, value = 0, initFromFilePath = '', initOnCPUOnly=true, randomSeed=-1, tag='') = new ComputationNode [ operation = 'LearnableParameter' ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
// TODO: ImageParameter?
// ^^ already works; vv untested
L"Input(dims, tag='feature') = new ComputationNode [ operation = 'InputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]\n" // note: naming a little inconsistent // TODO: re-test after flag change
L"SparseInput(dims, tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; shape = new TensorShape [ /*dims*/ ] ; isImage = false /*plus the function args*/ ]\n"
L"ImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', tag='feature') = new ComputationNode [ operation = 'InputValue' ; isImage = true /*plus the function args*/ ]\n"
L"SparseImageInput(imageWidth, imageHeight, imageChannels, imageLayout='CHW', tag='feature') = new ComputationNode [ operation = 'SparseInputValue' ; isImage = true /*plus the function args*/ ]\n"
L"Constant(val, rows = 1, cols = 1, tag='') = Parameter(rows, cols, needGradient = false, init = 'fixedValue', value = val) \n"
L"PastValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'PastValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
L"FutureValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'FutureValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
// TODO: ^^ DelayedValues no longer need to know their dimension. That is inferred in Validation.
L"Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag='') = new ComputationNode [ operation = 'Shift' ; inputs = (input : boundaryValue) /*plus the function args*/ ]\n"
L"RowSlice(startIndex, numRows, input, needGradient = false, tag='') = new ComputationNode [ operation = 'RowSlice' ; inputs = input /*plus the function args*/ ]\n"
L"RowRepeat(input, numRepeats, needGradient = false, tag='') = new ComputationNode [ operation = 'RowRepeat' ; inputs = input /*plus the function args*/ ]\n"
L"RowStack(inputs, tag='') = new ComputationNode [ operation = 'RowStack' /*plus the function args*/ ]\n"
L"Reshape(input, numRows, imageWidth = 0, imageHeight = 0, imageChannels = 0, tag='') = new ComputationNode [ operation = 'LegacyReshape' ; inputs = input /*plus the function args*/ ]\n"
L"NewReshape(input, dims, beginDim=0, endDim=0, tag='') = new ComputationNode [ operation = 'Reshape' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]\n"
L"ReshapeDimension(x, dim, tensorShape) = NewReshape(x, tensorShape, beginDim=dim, endDim=dim + 1) \n"
L"FlattenDimensions(x, dim, num) = NewReshape(x, 0, beginDim=dim, endDim=dim + num) \n"
L"SplitDimension(x, dim, N) = ReshapeDimension(x, dim, 0:N) \n"
L"TransposeDimensions(input, dim1, dim2, tag='') = new ComputationNode [ operation = 'TransposeDimensions' ; inputs = input /*plus the function args*/ ]\n"
L"Transpose(x) = TransposeDimensions(x, 1, 2)\n"
L"Times(A, B, outputRank=1) = new ComputationNode [ operation = 'Times' ; inputs = ( A : B ) /*plus the function args*/ ]\n"
L"Logistic(label, probability, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability) /*plus the function args*/ ]\n"
L"WeightedLogistic(label, probability, instanceWeight, tag='') = new ComputationNode [ operation = 'Logistic' ; inputs = (label : probability : instanceWeight) /*plus the function args*/ ]\n"
L"ReconcileMBLayout(dataInput, layoutInput, tag='') = new ComputationNode [ operation = 'ReconcileMBLayout' ; inputs = (dataInput : layoutInput) /*plus the function args*/ ]\n"
L"Convolution(weightNode, inputValueNode, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, zeroPadding = false, maxTempMemSizeInSamples = 0, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'Convolution' ; inputs = (weightNode : inputValueNode) /*plus the function args*/ ]\n"
L"MaxPooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'MaxPooling' ; inputs = input /*plus the function args*/ ]\n"
L"AveragePooling(input, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'AveragePooling' ; inputs = input /*plus the function args*/ ]\n"
// TODO: define DelayedValue, with negative delay for future; cannot do this yet, need to be able to say something like delay = -(^.delay)
// aliases
L"ColumnwiseCrossProduct = KhatriRaoProduct // deprecated \n" // TODO: should it be deprecated? It is described as easier to understand in the CNTKBook.
L"ClassificationError = ErrorPrediction \n"
L"Delay = PastValue \n" // TODO: should it allow negative offsets and an if test here?
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor = 1.0, epsilon = 0.00001, useCntkEngine = true, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
// standard nodes. We use macros to define these strings.
#define UnaryStandardNode(Op, a) L## #Op L"(" L## #a L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = " L## #a L" /*plus the function args*/ ]\n"
#define BinaryStandardNode(Op, a, b) L## #Op L"(" L## #a L", " L## #b L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = (" L## #a L" : " L## #b L") /*plus the function args*/ ]\n"
@@ -120,10 +123,6 @@ wstring computationNodes = // TODO: use actual TypeName() here? would first need
UnaryStandardNode(SumElements, matrix)
UnaryStandardNode(Tanh, z)
UnaryStandardNode(TimeReverse, vectorSequence)
BinaryStandardNode(Times, leftMatrix, rightMatrix)
#ifdef COMING_SOON
UnaryStandardNode(Transpose, matrix)
#endif
BinaryStandardNode(TransposeTimes, leftMatrix, rightMatrix)
// those nodes are deprecated, we won't implement them in BS:
//BinaryStandardNode(NoiseContrastiveEstimationNode)
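
Editor's note: a sketch of how the two token-pasting macros above expand (my illustration, not part of the diff). UnaryStandardNode(Tanh, z) stringizes its arguments (#Op) and pastes on the wide-string prefix (L## #Op), yielding the literal fragments

    L"Tanh" L"(" L"z" L", tag='') = new ComputationNode [ operation = '" L"Tanh" L"' ; inputs = " L"z" L" /*plus the function args*/ ]\n"

which adjacent-literal concatenation collapses into the single BrainScript definition

    L"Tanh(z, tag='') = new ComputationNode [ operation = 'Tanh' ; inputs = z /*plus the function args*/ ]\n"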

View file

@@ -209,10 +209,9 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
else if (EqualInsensitive(nodeType, OperationNameOf(SumElementsNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(TanhNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(TimesNode))) ret = true;
#ifdef COMING_SOON
else if (EqualInsensitive(nodeType, OperationNameOf(TransposeNode))) ret = true;
#endif
//else if (EqualInsensitive(nodeType, OperationNameOf(TransposeDimensionsNode))) ret = true; // not supported from NDL, use Transpose()
else if (EqualInsensitive(nodeType, OperationNameOf(TransposeTimesNode))) ret = true;
// legacy names:
else if (EqualInsensitive(nodeType, L"ColumnElementTimes")) ret = true;
else if (EqualInsensitive(nodeType, L"Constant", L"Const")) ret = true;
else if (EqualInsensitive(nodeType, L"ImageInput", L"Image")) ret = true;
@@ -220,6 +219,7 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
else if (EqualInsensitive(nodeType, L"RowElementTimes")) ret = true;
else if (EqualInsensitive(nodeType, L"Scale")) ret = true;
else if (EqualInsensitive(nodeType, L"SparseImageInput", L"SparseImage")) ret = true;
else if (EqualInsensitive(nodeType, L"Transpose")) ret = true;
// return the actual node name in the parameter if we found something
if (ret)

View file

@@ -97,7 +97,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildFFDNNFromDescription(
w = builder.CreateLearnableParameter(L"W0", m_layerSizes[1], m_layerSizes[0]);
m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
b = builder.CreateLearnableParameter(L"B0", m_layerSizes[1], 1);
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, L"W0*features"), b, L"W0*features+B0"), 0, L"H1");
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, 1, L"W0*features"), b, L"W0*features+B0"), 0, L"H1");
if (m_addDropoutNodes)
input = builder.Dropout(output, L"DropH1");
@@ -116,7 +116,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildFFDNNFromDescription(
w = builder.CreateLearnableParameter(nameOfW, m_layerSizes[i + 1], m_layerSizes[i]);
m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
b = builder.CreateLearnableParameter(nameOfB, m_layerSizes[i + 1], 1);
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus), i, nameOfH);
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus), i, nameOfH);
if (m_addDropoutNodes)
input = builder.Dropout(output, L"Drop" + nameOfH);
@@ -134,7 +134,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildFFDNNFromDescription(
w = builder.CreateLearnableParameter(nameOfW, m_layerSizes[numHiddenLayers + 1], m_layerSizes[numHiddenLayers]);
m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
b = builder.CreateLearnableParameter(nameOfB, m_layerSizes[numHiddenLayers + 1], 1);
output = builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus);
output = builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus);
m_net->RenameNode(output, L"HLast");
label = builder.CreateInputNode(L"labels", m_layerSizes[numHiddenLayers + 1]);
@@ -264,7 +264,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildRNNFromDescription()
label = builder.CreateInputNode(L"labels", m_layerSizes[numHiddenLayers + 1]);
AddTrainAndEvalCriterionNodes(input, label, w, L"criterion", L"eval");
output = builder.Times(w, input, L"outputs");
output = builder.Times(w, input, 1, L"outputs");
m_net->OutputNodes().push_back(output);
@@ -379,7 +379,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildClassEntropyRNNFromDe
clsweight = builder.CreateLearnableParameter(L"WeightForClassPostProb", m_nbrCls, m_layerSizes[numHiddenLayers]);
m_net->InitLearnableParameters(clsweight, m_uniformInit, randomSeed++, m_initValueScale);
clslogpostprob = builder.Times(clsweight, input, L"ClassPostProb");
clslogpostprob = builder.Times(clsweight, input, 1, L"ClassPostProb");
output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeClassBasedCrossEntropy", L"EvalNodeClassBasedCrossEntrpy",
clslogpostprob);
@@ -481,7 +481,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildConditionalLSTMNetwor
clsweight = builder.CreateLearnableParameter(L"WeightForClassPostProb", m_nbrCls, m_layerSizes[numHiddenLayers]);
m_net->InitLearnableParameters(clsweight, m_uniformInit, randomSeed++, m_initValueScale);
clslogpostprob = builder.Times(clsweight, input, L"ClassPostProb");
clslogpostprob = builder.Times(clsweight, input, 1, L"ClassPostProb");
output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeClassBasedCrossEntropy", L"EvalNodeClassBasedCrossEntrpy",
clslogpostprob);
@@ -601,7 +601,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildLogBilinearNetworkFro
label = builder.CreateInputNode(L"labels", m_layerSizes[numHiddenLayers + 1]);
AddTrainAndEvalCriterionNodes(input, label, w);
output = builder.Times(w, input, L"outputs");
output = builder.Times(w, input, 1, L"outputs");
m_net->OutputNodes().push_back(output);
@@ -1128,7 +1128,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildClassLSTMNetworkFromD
clsweight = builder.CreateLearnableParameter(L"WeightForClassPostProb", m_nbrCls, m_layerSizes[numHiddenLayers]);
m_net->InitLearnableParameters(clsweight, m_uniformInit, randomSeed++, m_initValueScale);
clslogpostprob = builder.Times(clsweight, input, L"ClassPostProb");
clslogpostprob = builder.Times(clsweight, input, 1, L"ClassPostProb");
output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeClassBasedCrossEntropy", L"EvalNodeClassBasedCrossEntrpy",
clslogpostprob);
@@ -1296,7 +1296,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildLSTMNetworkFromDescri
label = builder.CreateInputNode(L"labels", m_layerSizes[numHiddenLayers + 1]);
AddTrainAndEvalCriterionNodes(input, label, w);
output = builder.Times(w, input, L"outputs");
output = builder.Times(w, input, 1, L"outputs");
if (m_needPrior)
{
@@ -1414,7 +1414,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNCELSTMNetworkFromDes
bias = builder.CreateLearnableParameter(L"BiasVector", 1, m_layerSizes[m_layerSizes.size() - 1]);
bias->Value().SetValue((ElemType) -std::log(m_layerSizes[m_layerSizes.size() - 1]));
// m_net->InitLearnableParameters(bias, m_uniformInit, randomSeed++, std::log(m_layerSizes[m_layerSizes.size() - 1])* m_initValueScale);
// clslogpostprob = builder.Times(clsweight, input, L"ClassPostProb");
// clslogpostprob = builder.Times(clsweight, input, 1, L"ClassPostProb");
output = AddTrainAndEvalCriterionNodes(input, label, w, L"TrainNodeNCEBasedCrossEntropy", L"EvalNodeNCEBasedCrossEntrpy", bias);
@@ -1532,17 +1532,17 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNetworkFromDbnFile(co
if (layerType == "perceptron")
{
fprintf(stderr, "DBN: Reading (%lu x %lu) perceptron\n", (unsigned long) wts.GetNumRows(), (unsigned long) wts.GetNumCols());
output = builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus);
output = builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus);
}
else if (layerType == "rbmisalinearbernoulli")
{
fprintf(stderr, "DBN: Reading (%lu x %lu) linear layer\n", (unsigned long) wts.GetNumRows(), (unsigned long) wts.GetNumCols());
output = builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus);
output = builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus);
}
else // assume rbmbernoullibernoulli
{
fprintf(stderr, "DBN: Reading (%lu x %lu) non-linear layer\n", (unsigned long) wts.GetNumRows(), (unsigned long) wts.GetNumCols());
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus), i, nameOfH);
output = ApplyNonlinearFunction(builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus), i, nameOfH);
if (m_addDropoutNodes)
input = builder.Dropout(output, L"Drop" + nameOfH);
}
@@ -1589,7 +1589,7 @@ ComputationNetworkPtr SimpleNetworkBuilder<ElemType>::BuildNetworkFromDbnFile(co
w = builder.CreateLearnableParameter(nameOfW, outputLayerSize, penultimateSize);
m_net->InitLearnableParameters(w, m_uniformInit, randomSeed++, m_initValueScale);
b = builder.CreateLearnableParameter(nameOfB, outputLayerSize, 1);
output = builder.Plus(builder.Times(w, input, nameOfTimes), b, nameOfPlus);
output = builder.Plus(builder.Times(w, input, 1, nameOfTimes), b, nameOfPlus);
m_net->RenameNode(output, L"HLast");
if (m_needPrior)

View file

@@ -74,7 +74,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// - allows output and input tensors (TimesNode will get optional parameter how many leading dims to not contract), e.g.
// A[U,V,I,J] * B[I,J,S,T] -> C[U,V,S,T], c_uvst = sum_ij a_uvij * b_ijst
// - for now this operation must be flattenable so as to be implementable as SGEMM (may extend in the future)
// - tensor transpose -> TransposeNode
// - tensor transpose -> TransposeDimensionsNode
// - swaps any two dimensions. This does not change the column-major definition, i.e. requires a memory copy.
// - special case: swapping between sample and MBLayout, e.g. turn a sample dimension to a time dimension
// - Validate() stage will automatically infer tensor dimensions from inputs, and also infer downwards into LearnableParameters where requested
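
Editor's note: the contraction above reduces to a single SGEMM once the matching dimensions are flattened. A minimal self-contained sketch (my illustration, not code from this commit; a plain triple loop stands in for the SGEMM call):

    // c_uvst = sum_ij a_uvij * b_ijst, all arrays dense and column-major
    #include <cstddef>
    void TensorTimesAsGemm(const float* A, const float* B, float* C,
                           size_t U, size_t V, size_t I, size_t J, size_t S, size_t T)
    {
        const size_t M = U * V; // leading dims of A, flattened -> output rows
        const size_t K = I * J; // contracted dims shared by A and B
        const size_t N = S * T; // trailing dims of B, flattened -> output cols
        for (size_t n = 0; n < N; n++)
            for (size_t m = 0; m < M; m++)
            {
                float sum = 0;
                for (size_t k = 0; k < K; k++)
                    sum += A[m + k * M] * B[k + n * K]; // column-major indexing
                C[m + n * M] = sum;
            }
    }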
@@ -643,6 +643,14 @@ public:
NarrowTo(k, (size_t)bounds.first[k], (size_t)bounds.second[k]);
return *this;
}
// swap two existing dimensions (implements transposition)
void SwapDimsInPlace(size_t i, size_t j)
{
if (i == j) // this is OK
return;
std::swap(m_dims[i], m_dims[j]);
std::swap(m_strides[i], m_strides[j]);
}
// compare two TensorShapes, whether they are compatible, considering padding and broadcasting
bool IsElementwiseCompatibleWith(const TensorShape& other) const
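
Editor's note: a self-contained sketch (my code, not CNTK's TensorShape) of why this dims/strides swap is a zero-copy view change: an element's memory offset is the stride-weighted sum of its indices, so swapping dims[] together with strides[] reindexes the same buffer.

    #include <cstddef>
    #include <utility>
    #include <vector>
    struct MiniShape
    {
        std::vector<size_t> dims, strides; // column-major; dense strides are {1, dims[0], dims[0]*dims[1], ...}
        void SwapDimsInPlace(size_t i, size_t j)
        {
            if (i == j)
                return;
            std::swap(dims[i], dims[j]);
            std::swap(strides[i], strides[j]); // same memory, new traversal order
        }
        size_t Locate(const std::vector<size_t>& index) const // element offset into the buffer
        {
            size_t offset = 0;
            for (size_t k = 0; k < dims.size(); k++)
                offset += index[k] * strides[k];
            return offset;
        }
    };
    // e.g. a dense [3 x 4] has strides {1, 3}; after SwapDimsInPlace(0, 1) it is
    // a [4 x 3] view with strides {3, 1}: element (r, c) now reads old (c, r).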

View file

@@ -69,6 +69,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(PastValueNode)) return New<PastValueNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PerDimMeanVarDeNormalizationNode)) return New<PerDimMeanVarDeNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TransposeDimensionsNode)) return New<TransposeDimensionsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(PlusNode)) return New<PlusNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ReconcileMBLayoutNode)) return New<ReconcileMBLayoutNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(RectifiedLinearNode)) return New<RectifiedLinearNode<ElemType>>(forward<_Types>(_Args)...);
@@ -91,17 +92,16 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(SumElementsNode)) return New<SumElementsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TanhNode)) return New<TanhNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TimesNode)) return New<TimesNode<ElemType>>(forward<_Types>(_Args)...);
#ifdef COMING_SOON
else if (nodeType == OperationNameOf(TransposeNode)) return New<TransposeNode<ElemType>>(forward<_Types>(_Args)...);
#endif
else if (nodeType == OperationNameOf(TransposeDimensionsNode)) return New<TransposeDimensionsNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(TransposeTimesNode)) return New<TransposeTimesNode<ElemType>>(forward<_Types>(_Args)...);
// old names we also support
// legacy names we also support for backward compatibility of model files
else if (nodeType == L"ColumnElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"Delay") return New<PastValueNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"PerDimMeanVarNormalizationNode") return New<PerDimMeanVarNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"PerDimMeanVarDeNormalizationNode") return New<PerDimMeanVarDeNormalizationNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"RowElementTimes") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"Scale") return New<ElementTimesNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == L"Transpose") return New<TransposeDimensionsNode<ElemType>>(forward<_Types>(_Args)...);
#if 1
else if (nodeType == OperationNameOf(LegacyReshapeNode)) return New<LegacyReshapeNode<ElemType>>(forward<_Types>(_Args)...);
#endif
@@ -463,18 +463,16 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Sum(c
return net.AddNodeToNetAndAttachInputs(New<SumElementsNode<ElemType>>(net.GetDeviceId(), nodeName), a);
}
#ifdef COMING_SOON
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Transpose(const ComputationNodePtr matrix, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::TransposeDimensions(const ComputationNodePtr matrix, int dim1, int dim2, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<TransposeNode<ElemType>>(net.GetDeviceId(), nodeName), matrix);
return net.AddNodeToNetAndAttachInputs(New<TransposeDimensionsNode<ElemType>>(net.GetDeviceId(), nodeName, dim1, dim2), matrix);
}
#endif
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Times(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Times(const ComputationNodePtr a, const ComputationNodePtr b, size_t outputRank, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<TimesNode<ElemType>>(net.GetDeviceId(), nodeName), a, b);
return net.AddNodeToNetAndAttachInputs(New<TimesNode<ElemType>>(net.GetDeviceId(), nodeName, outputRank), a, b);
}
template <class ElemType>

View file

@@ -125,10 +125,8 @@ public:
ComputationNodePtr SquareError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Tanh(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Times(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
#ifdef COMING_SOON
ComputationNodePtr Transpose(const ComputationNodePtr matrix, const std::wstring nodeName = L"");
#endif
ComputationNodePtr Times(const ComputationNodePtr a, const ComputationNodePtr b, size_t outputRank = 1, const std::wstring nodeName = L"");
ComputationNodePtr TransposeDimensions(const ComputationNodePtr matrix, int dim1, int dim2, const std::wstring nodeName = L"");
ComputationNodePtr TransposeTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
#if 1 // legacy
ComputationNodePtr LegacyReshape(const ComputationNodePtr a, const size_t num_rows, const TensorShape& imageLayout, const std::wstring nodeName = L"");

View file

@@ -366,7 +366,7 @@ public:
// - as a Matrix reference
// - actual object is a 2D tensor without MB Layout
// - ValueAsMatrix(), GradientAsMatrix() returns tensor as a 2D Matrix object
// - nodes that do this are: TimesNode, DiagTimesNode, ConvolutionNode, NoiseContrastiveEstimationNode, ClassBasedCrossEntropyWithSoftmaxNode, TransposeNode, DiagonalNode
// - nodes that do this are: TimesNode, DiagTimesNode, ConvolutionNode, NoiseContrastiveEstimationNode, ClassBasedCrossEntropyWithSoftmaxNode, TransposeDimensionsNode, DiagonalNode
//
// How values are stored:
//
@@ -919,18 +919,18 @@ public:
// creation from configuration
// Nodes with NumInputs<> should say DeclareConstructorFromConfigWithNumInputs(ClassName), and nodes without DeclareConstructorFromConfig(ClassName).
// The macro will forward to the regular constructor of the node (which may do more than just calling the base constructor), and then attach the inputs from config.
#define DeclareConstructorFromConfig(C) \
C(const ScriptableObjects::IConfigRecordPtr configp) \
: C(configp->Get(L"deviceId"), L"<placeholder>") \
{ \
AttachInputs(configp); \
}
#define DeclareConstructorFromConfigWithNumInputs(C) \
C(const ScriptableObjects::IConfigRecordPtr configp) \
: C(configp->Get(L"deviceId"), L"<placeholder>") \
{ \
AttachInputs(configp, this->GetExpectedNumInputs()); \
}
#define DeclareConstructorFromConfig(C) \
C(const ScriptableObjects::IConfigRecordPtr configp) \
: C(configp->Get(L"deviceId"), L"<placeholder>") \
{ \
AttachInputs(configp); \
}
// helper to load m_value from a stream
// This function updates the dimensions to a 2D matrix.
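
Editor's note: for a concrete reading, DeclareConstructorFromConfigWithNumInputs(PlusNode) expands (my expansion, not part of the diff) to

    PlusNode(const ScriptableObjects::IConfigRecordPtr configp)
        : PlusNode(configp->Get(L"deviceId"), L"<placeholder>")
    {
        AttachInputs(configp, this->GetExpectedNumInputs());
    }

A node that must read extra config parameters writes this constructor out by hand instead, as TimesNode does below to fetch 'outputRank'.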

View file

@@ -44,6 +44,15 @@ public:
{
}
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto result = ValueTensorFor(rank, fr);
auto input0 = Input(0)->ValueTensorFor(rank, fr.AllowBroadcast());
auto input1 = Input(1)->ValueTensorFor(rank, fr.AllowBroadcast());
result.AssignSumOf(input0, input1);
}
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
@@ -56,15 +65,6 @@ public:
inputGradient.AddCopyOf(gradient);
}
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto result = ValueTensorFor(rank, fr);
auto input0 = Input(0)->ValueTensorFor(rank, fr.AllowBroadcast());
auto input1 = Input(1)->ValueTensorFor(rank, fr.AllowBroadcast());
result.AssignSumOf(input0, input1);
}
};
template class PlusNode<float>;
@@ -175,9 +175,9 @@ template class NegateNode<float>;
template class NegateNode<double>;
// -----------------------------------------------------------------------
// TimesNodeBase (A, B)
// TimesNodeBase (A, B, outputRank=1)
// shared code of TimesNode and TransposeTimesNode (which transposes A)
// right operand and output can have MB layout, while left operand cannot
// Right operand and output can have MB layout, while left operand cannot.
// -----------------------------------------------------------------------
template <class ElemType, bool m_transpose>
@@ -187,7 +187,7 @@ class TimesNodeBase : public ComputationNode<ElemType>, public NumInputs<2>
UsingComputationNodeMembers;
public:
TimesNodeBase(DEVICEID_TYPE deviceId, const wstring& name)
TimesNodeBase(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1/*TODO: complete this*/)
: Base(deviceId, name)
{
}
@@ -220,12 +220,8 @@ public:
}
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
// The TimesNode does not require its output value for computing
// the gradients of its input nodes
return false;
}
virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
// but both inputs are
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
@@ -302,8 +298,19 @@ public:
};
// -----------------------------------------------------------------------
// TimesNode (A, B)
// right operand and output can have MB layout, while left operand cannot
// TimesNode (A, B, outputRank=1) -- matrix product
// Right operand and output can have MB layout, while left operand cannot.
// This is generalized to tensors, in that B's leading dimension(s) must match
// the trailing dimension(s) of A, and those dimensions get flattened into a single
// dimension, after which the regular matrix product is applied.
// Leading dimension(s) of A and trailing ones of B remain unchanged.
// For example:
// [I x J x K] * [J x K x L x *] = [I x L x *]
// How many dimensions must match is controlled by an optional parameter
// 'outputRank', where all but the first 'outputRank' dimensions of A must match.
// 'outputRank' defaults to 1, which is used in the above example.
// Example for outputRank = 2:
// [I x J x K] * [K x L x M x *] = [I x J x L x M x *]
// -----------------------------------------------------------------------
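
Editor's note: to make the rule concrete, a hedged sketch of the resulting shape computation (an illustrative helper, not the node's actual Validate() code, which this commit only begins to plumb through):

    #include <stdexcept>
    #include <vector>
    std::vector<size_t> TimesOutputShape(const std::vector<size_t>& dimsA,
                                         const std::vector<size_t>& dimsB,
                                         size_t outputRank)
    {
        const size_t numContracted = dimsA.size() - outputRank; // dims shared by A and B
        for (size_t k = 0; k < numContracted; k++)
            if (dimsA[outputRank + k] != dimsB[k])
                throw std::invalid_argument("Times: contracted dimensions must match");
        std::vector<size_t> dimsC(dimsA.begin(), dimsA.begin() + outputRank);   // surviving dims of A
        dimsC.insert(dimsC.end(), dimsB.begin() + numContracted, dimsB.end());  // trailing dims of B
        return dimsC;
    }
    // TimesOutputShape({I,J,K}, {J,K,L}, 1) == {I,L}
    // TimesOutputShape({I,J,K}, {K,L,M}, 2) == {I,J,L,M}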
template <class ElemType>
@@ -311,17 +318,18 @@ class TimesNode : public TimesNodeBase<ElemType, false>
{
typedef TimesNodeBase<ElemType, false> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"Times";
}
static const std::wstring TypeName() { return L"Times"; }
public:
DeclareConstructorFromConfigWithNumInputs(TimesNode);
TimesNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
TimesNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1)
: Base(deviceId, name, outputRank)
{
}
TimesNode(const ScriptableObjects::IConfigRecordPtr configp)
: TimesNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"outputRank"))
{
AttachInputs(configp, this->GetExpectedNumInputs());
}
};
template class TimesNode<float>;
@@ -329,7 +337,9 @@ template class TimesNode<double>;
// -----------------------------------------------------------------------
// TransposeTimesNode (A', B)
// right operand and output can have MB layout, while left operand cannot
// Right operand and output can have MB layout, while left operand cannot.
// This differs from TimesNode in that A is transposed, where A must be a
// rank-1 or rank-2 tensor.
// -----------------------------------------------------------------------
template <class ElemType>
@@ -337,15 +347,12 @@ class TransposeTimesNode : public TimesNodeBase<ElemType, true>
{
typedef TimesNodeBase<ElemType, true> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"TransposeTimes";
}
static const std::wstring TypeName() { return L"TransposeTimes"; }
public:
DeclareConstructorFromConfigWithNumInputs(TransposeTimesNode);
TransposeTimesNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
: Base(deviceId, name, /*outputRank=*/1)
{
}
};
@@ -650,97 +657,107 @@ public:
template class SumColumnElementsNode<float>;
template class SumColumnElementsNode<double>;
#ifdef COMING_SOON // known bug in backprop; generalize to tensor
// -----------------------------------------------------------------------
// TransposeNode (input matrix)
// TODO: extend towards tensor transpose (swap 2 dimensions, incl. time)
// TransposeDimensionsNode (input, dim1, dim2)
// - swaps index dimensions dim1 and dim2. The values are 1-based; 1 stands for the leading dimension.
// - new dimensions can be created; e.g. a column vector can be transposed into a row vector, which is a [1 x N] tensor
// - transposing into the time dimension is currently not supported
// - internally implemented with tensor lib by shuffling dimensions with their strides
// - input may be minibatch data or not
// Transpose (input) = TransposeDimensions (input, 1, 2)
// -----------------------------------------------------------------------
template <class ElemType>
class TransposeNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<1>
class TransposeDimensionsNode : public ComputationNode /*ComputationNode*/<ElemType>, public NumInputs<1>
{
typedef ComputationNodeNonLooping<ElemType> Base;
typedef ComputationNode<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
static const std::wstring TypeName() { return L"TransposeDimensions"; }
public:
TransposeDimensionsNode(DEVICEID_TYPE deviceId, const wstring& name, int dim1 = 1, int dim2 = 2)
: Base(deviceId, name), m_dim1(dim1), m_dim2(dim2)
{
return L"Transpose";
}
TransposeDimensionsNode(const ScriptableObjects::IConfigRecordPtr configp)
: TransposeDimensionsNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"dim1"), configp->Get(L"dim2"))
{
AttachInputs(configp, this->GetExpectedNumInputs());
}
// TODO: Save and Load for m_dim1/m_dim2; if old model version then default to 1:2
private:
// compute the transposed tensor shape (in-place)
void TransposeShape(TensorShape& shape) const
{
assert(m_dim1 > 0 && m_dim2 > 0);
size_t i = m_dim1 - 1;
size_t j = m_dim2 - 1;
shape.SwapDimsInPlace(i, j);
}
// get the transposed output shape for the current input
TensorShape GetTransposedTensorSliceFor(size_t rank, const FrameRange& fr)
{
auto shape = Input(0)->GetTensorSliceFor(rank, fr);
TransposeShape(shape);
return shape;
}
public:
DeclareConstructorFromConfigWithNumInputs(TransposeNode);
TransposeNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
size_t rank = DetermineElementwiseTensorRank();
auto input = Input(0)->ValueTensorFor(rank, fr);
auto output = TensorView<ElemType>(Value(), GetTransposedTensorSliceFor(rank, fr));
output.AssignCopyOf(input);
}
virtual void /*ComputationNodeNonLooping::*/ BackpropToNonLooping(size_t /*inputIndex*/) override
virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
{
auto& inputGradientValues = Input(0)->GradientAsMatrix();
auto& gradientValues = GradientAsMatrix();
#if DUMPOUTPUT
gradientValues.Print("Gradient-in");
inputGradientValues.Print("child Gradient-in/out");
inputFunctionValues.Print("child Function values");
#endif
const Matrix<ElemType>& ones = ConstOnes(inputGradientValues.GetNumRows(), inputGradientValues.GetNumRows(), inputGradientValues.GetDeviceId());
// BUGBUG: This should be ^^ Identity(). This will be fixed once we switch to the more generic tensor Transpose operation, which can handle this easily.
Matrix<ElemType>::MultiplyAndAdd(ones, false, gradientValues, true, inputGradientValues);
#if DUMPOUTPUT
inputGradientValues.Print("child Gradient-out");
#endif
InvalidArgument("TransposeNode::BackpropTo() has a known bug. it is not functional.");
size_t rank = DetermineElementwiseTensorRank();
auto inputGradient = Input(0)->GradientTensorFor(rank, fr);
auto outputGradient = TensorView<ElemType>(Gradient(), GetTransposedTensorSliceFor(rank, fr));
inputGradient.AddCopyOf(outputGradient);
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
#if DUMPOUTPUT
return true;
#else
// The TransposeNode does not require its output value for computing
// the gradients of its input nodes
return false;
#endif
}
virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
{
// The TransposeNode does not require any of its inputs' values for computing
// the gradients of its input nodes
UNREFERENCED_PARAMETER(childIndex);
return false;
}
virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
{
#if DUMPOUTPUT
Input(0)->ValueAsMatrix().Print("TransposeNode- Input0");
#endif
ValueAsMatrix().AssignTransposeOf(Input(0)->ValueAsMatrix());
#if NANCHECK
Value().HasNan("Transpose");
#endif
#if DUMPOUTPUT
ValueAsMatrix().Print("TransposeNode");
#endif
}
virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override { return false; }
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
if (Input(0)->HasMBLayout())
InvalidArgument("%ls %ls operation cannot operate on minibatch data (which have a layout)", NodeName().c_str(), OperationName().c_str());
m_pMBLayout = nullptr; // this node does not hold mini-batch data
assert(m_inputs.size() == 1);
ComputationNodeBase::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase();
size_t rows0 = Input(0)->GetAsMatrixNumRows(), cols0 = Input(0)->GetAsMatrixNumCols();
SetDims(TensorShape(cols0, rows0), false);
// input shape
auto shape = Input(0)->GetSampleLayout();
// validate indices
if (m_dim1 < 1 || m_dim2 < 1)
InvalidArgument("%ls %ls operation: Indices to transpose must be >= 1.", NodeName().c_str(), OperationName().c_str());
size_t i = m_dim1 - 1;
size_t j = m_dim2 - 1;
if (i >= shape.GetRank() && j >= shape.GetRank())
InvalidArgument("%ls %ls operation: At least one index must refer to an existing index.", NodeName().c_str(), OperationName().c_str());
// pad
// Permutation is allowed to create new dimensions, specifically to be able to transpose a [N] column vector into a [1 x N] row vector.
// One can also use SplitDimensions() for this, but this seems a natural thing to do.
size_t maxij = std::max(i, j);
if (maxij >= shape.GetRank())
shape.PadRankInPlace(maxij + 1);
// apply the permutation
TransposeShape(shape);
// init with dimensions only (dropping the strides), since the actual minibatch tensor fetched later may use different strides
SetDims(TensorShape(shape.GetDims()), HasMBLayout());
}
private:
int m_dim1, m_dim2; // the two dimensions (axes, 1-based) to swap
};
template class TransposeNode<float>;
template class TransposeNode<double>;
#endif
template class TransposeDimensionsNode<float>;
template class TransposeDimensionsNode<double>;
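
Editor's note: a hedged sketch of the Validate() shape computation above (an illustrative helper, not the node's code): axes are 1-based, and the shape is padded with singleton dimensions so that, e.g., a [N] column vector transposes into a [1 x N] row vector.

    #include <algorithm>
    #include <utility>
    #include <vector>
    std::vector<size_t> TransposedDims(std::vector<size_t> dims, size_t dim1, size_t dim2)
    {
        size_t i = dim1 - 1, j = dim2 - 1;                     // 1-based axes -> 0-based
        dims.resize(std::max({dims.size(), i + 1, j + 1}), 1); // pad new axes with dimension 1
        std::swap(dims[i], dims[j]);
        return dims;
    }
    // TransposedDims({5}, 1, 2) == {1, 5} -- column vector becomes a row vector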
// -----------------------------------------------------------------------
// CosDistanceNode (left, right)

View file

@@ -28,7 +28,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// - just replaces metadata m_sampleLayout, does not change data values
// - one dimension may be specified as 0 and will be inferred
// - optional beginDim/endDim denote to only replace a sub-range of dims, for implementing ReshapeDimension() and FlattenRank()
// - may not be applied to time; use Permute() or Transpose()
//
// Derived operations:
//
@@ -41,9 +40,6 @@ namespace Microsoft { namespace MSR { namespace CNTK {
//
// SplitDimension(x, dim, N) = ReshapeDimension(x, dim, 0:N)
// - splits a dimension into a new tensor dimension, injecting them into a new dimension
// - to split stacked frames into a new time dimension:
// insert new time dim with ReshapeDimension(., -1, 0:1), SplitDimension(., dim, N), Transpose(., dim+1, -1), then Select(., dim+1, 0) away the new time dim
// This would make 4 copies presently. We may need a compound C++ node for now.
// - note: to split into multiple outputs (like tf.split()), use a BrainScript loop with Slice().
// -----------------------------------------------------------------------
@@ -166,14 +162,8 @@ public:
Input(inputIndex)->GradientFor(fr).SetValue(GradientFor(fr));
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
return false;
}
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override
{
return false;
}
virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override { return false; }
private:
TensorShape m_replacementSampleLayout; // user-specified dimensions to replace dimensions [beginDim, endDim]
@@ -754,7 +744,7 @@ public:
#define UsingReinterpretNodeBaseMembers UsingComputationNodeMembersBoilerplate
// TODO: This ReshapeNode is currently not used. Its function will be taken over by Transpose and the Reshape that follows this one below.
// TODO: This ReshapeNode should no longer be used. Its function will be taken over by Transpose and the Reshape that follows this one below.
// -----------------------------------------------------------------------
// LegacyReshapeNode (input) -- reinterpret input matrix as having different dimensions
@@ -1090,7 +1080,6 @@ reshaping
- just replaces metadata m_sampleLayout
- one dimension may be specified as 0 and will be inferred
- optional beginDim/endDim denote to only replace a sub-range of dims, for implementing ReshapeDimension() and FlattenRank()
- may not be applied to time; use Permute() or Transpose()
- ReshapeDimension(x, dim, tensorShape) = Reshape(x, tensorShape, beginDim=dim, endDim=dim+1)
- reinterprets one dimension as multiple, where the number of elements remains the same
- one of the new dimensions may be specified as 0 and will be inferred
@@ -1098,9 +1087,6 @@ reshaping
- replace two or more consecutive dims by a single dim with the same number of elements
- SplitDimension(x, dim, N) = ReshapeDimension(x, dim, 0:N)
- splits a dimension into a new tensor dimension, injecting them into a new dimension
- to split stacked frames into a new time dimension:
insert new time dim with ReshapeDimension(., -1, 0:1), SplitDimension(., dim, N), Transpose(., dim+1, -1), then Select(., dim+1, 0) away the new time dim
This would make 4 copies presently. We may need a compound C++ node for now.
- note: to split into multiple outputs (like tf.split()), use a BrainScript loop with Slice().
- Slicing --all implemented in C++ by SliceNode
- Slice(x, dim, begin, end, stride=1, phase=0)
@@ -1144,16 +1130,15 @@ reshaping
- generalizes CNTK RowRepeat(x, numRepeats) = Repeat(x, 1, numRepeats)
- to repeat multiple, specify vectors, e.g. Repeat(x, dim1:dim2, numRepeats1:numRepeats2)
- like tf.tile() and Matlab's repmat()
- Transposition (permuting dims): --implemented in C++ by PermuteDimensionsNode
- PermuteDimensionsOf(x, dim1:dim2:...:dimN)
- dims are rotated to dim2:dim3:...:dimN:dim1; other dims remain untouched
To rotate the other way round, specify them in opposite order.
We specify it this way to be able to reference the time dimension without having to know the rank of the m_sampleLayout.
- time dims must have a constant duration for all items in the minibatch
- internally implemented with tensor lib by shuffling dimensions with their strides --TODO: check if TensorShape optimization is still correct
- Transpose(x, dim1, dim2) = PermuteDimensions(x, dim1:dim2)
- any two dimensions; including time (must have constant duration)
- Transposition
- TransposeDimensions (input, dim1, dim2)
- swaps index dimensions dim1 and dim2. The values are 1-based; 1 stands for the leading dimension.
- new dimensions can be created; e.g. a column vector can be transposed into a row vector, which is a [1 x N] tensor
- transposing into the time dimension is currently not supported
- internally implemented with tensor lib by shuffling dimensions with their strides
- input may be minibatch data or not
- like torch.transpose()
- Transpose (input) = TransposeDimensions (input, 1, 2)
- Re-indexing: --implemented by ReindexRankNode and SliceNode
- ReindexDimension(x, dim, indexVector)
- splice x[..., indexVector[0], ...], x[..., indexVector[1], ...], etc. with indexVector[.] at given dim
@@ -1161,6 +1146,7 @@ reshaping
- DownsampleDimension(x, dim, n, phase=0) = Slice(x, dim, 0, 0, stride=n)
- select every n-th element, starting with index 'phase'
- time dims allowed. Phase is then a modulus w.r.t. where a sequence is inside the minibatch (may require a ReconcileLayout() before to match layouts)
- TODO: use a bool vector for the time dimensions
- ReverseDimension(x, dim) = Slice(x, dim, -1, 0, stride=-1)
- reverses the direction of a dim
- when applied to time dims, this creates a new layout (which is also flipped)