renamed Times inputRank to inferInputRank
This commit is contained in:
Родитель
abe66249e7
Коммит
de864bffb4
|
@ -335,7 +335,7 @@ static ConfigValuePtr NodeOp(const ExpressionPtr &e, ConfigValuePtr leftVal, Con
|
|||
{
|
||||
let one = MakePrimitiveConfigValuePtr(1.0, leftFailFn, exprPath);
|
||||
config->Add(L"outputRank", leftFailFn, one);
|
||||
config->Add(L"inputRank", leftFailFn, one);
|
||||
config->Add(L"inferInputRank", leftFailFn, one);
|
||||
}
|
||||
// instantiate the ComputationNode
|
||||
let value = ConfigValuePtr(rtInfo->construct(config), MakeFailFn(e->location), exprPath);
|
||||
|
|
|
@ -37,12 +37,12 @@ LinearLayer {outDim, bias = true, init='heNormal', initValueScale=1, inputRank=0
|
|||
outputRank = Length (_AsArray (outDim)) # support outputs with tensor layouts
|
||||
apply (x) =
|
||||
if bias
|
||||
then Times (W, x, outputRank=outputRank, inputRank=inputRank) + b
|
||||
else Times (W, x, outputRank=outputRank, inputRank=inputRank)
|
||||
then Times (W, x, outputRank=outputRank, inferInputRank=inputRank) + b
|
||||
else Times (W, x, outputRank=outputRank, inferInputRank=inputRank)
|
||||
}.apply
|
||||
|
||||
# DenseLayer -- create a fully-connected layer with optional non-linearity
|
||||
DenseLayer{outDim, bias = true, activation=(x=>x), init='heNormal', initValueScale=1, inputRank=0} = Sequential ( LinearLayer{outDim, bias=bias, init=init, initValueScale=initValueScale, inputRank=inputRank} : activation )
|
||||
# NOTE: LinearLayer declares this option as 'inputRank' (it maps it onto Times' inferInputRank internally), so it must be forwarded under that name.
DenseLayer{outDim, bias = true, activation=(x=>x), init='heNormal', initValueScale=1, inputRank=0} = Sequential ( LinearLayer{outDim, bias=bias, init=init, initValueScale=initValueScale, inputRank=inputRank} : activation )
|
||||
|
||||
# EmbeddingLayer -- create a linear embedding layer
|
||||
EmbeddingLayer {outDim, # dimension of embedding
|
||||
|
@ -326,7 +326,7 @@ CNTK2 = [
|
|||
|
||||
// 4. Tensor operations
|
||||
// Changes: Matrix -> Tensor. A -> x, B -> y. Data must come on y ("default parameter") hence not using _
|
||||
Times(x, y, outputRank=1, inputRank=1, tag='') = new ComputationNode [ operation = 'Times' ; inputs = ( x : y ) /*plus the function args*/ ]
|
||||
Times(x, y, outputRank=1, inferInputRank=1, tag='') = new ComputationNode [ operation = 'Times' ; inputs = ( x : y ) /*plus the function args*/ ]
|
||||
|
||||
// 5. Elementwise operations.
|
||||
// Changes: "Matrix" -> "Tensor"; left input -> _; Clip: move input to front. ElementDivide/Times: anotherTensor -> y
|
||||
|
|
|
@ -238,8 +238,8 @@ class TimesNodeBase : public ComputationNode<ElemType>, public NumInputs<2>
|
|||
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembers; using Base::OperationName; \
|
||||
|
||||
public:
|
||||
TimesNodeBase(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1, int inputRank = 1)
|
||||
: Base(deviceId, name), m_outputRank(outputRank), m_inputRank(inputRank)
|
||||
TimesNodeBase(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1, int inferInputRank = 1)
|
||||
: Base(deviceId, name), m_outputRank(outputRank), m_inferInputRank(inferInputRank)
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -250,7 +250,7 @@ public:
|
|||
{
|
||||
auto node = dynamic_pointer_cast<TimesNodeBase<ElemType, m_transpose>>(nodeP);
|
||||
node->m_outputRank = m_outputRank;
|
||||
node->m_inputRank = m_inputRank;
|
||||
node->m_inferInputRank = m_inferInputRank;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -258,7 +258,7 @@ public:
|
|||
{
|
||||
Base::Save(fstream);
|
||||
fstream << m_outputRank;
|
||||
fstream << m_inputRank;
|
||||
fstream << m_inferInputRank;
|
||||
}
|
||||
|
||||
virtual void Load(File& fstream, size_t modelVersion) override
|
||||
|
@ -269,9 +269,9 @@ public:
|
|||
else
|
||||
m_outputRank = 1;
|
||||
if (modelVersion >= CNTK_MODEL_VERSION_11)
|
||||
fstream >> m_inputRank;
|
||||
fstream >> m_inferInputRank;
|
||||
else
|
||||
m_inputRank = 1;
|
||||
m_inferInputRank = 1;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -427,28 +427,28 @@ public:
|
|||
InvalidArgument("%ls %ls operation: The outputRank (%d) dimensions in left argument's shape [%s] must not be 0.", NodeName().c_str(), OperationName().c_str(), (int)m_outputRank, dimsAstring.c_str());
|
||||
|
||||
// infer rank of dimsA
|
||||
// For purpose of dimension inference, Times() accepts an optional parameter inputRank (default 1).
|
||||
// The first 'inputRank' axes are considered those that the matrix product should reduce over,
|
||||
// For purpose of dimension inference, Times() accepts an optional parameter inferInputRank (default 1).
|
||||
// The first 'inferInputRank' axes are considered those that the matrix product should reduce over,
|
||||
// while the remaining axes are kept (Times() is applied one by one, like a "map" operation).
|
||||
// Importantly, inputRank <= 0 will be interpreted from the end. Hence, inputRank=-1 denotes
|
||||
// Importantly, inferInputRank <= 0 will be interpreted from the end. Hence, inferInputRank=-1 denotes
|
||||
// that the one last axis will not be reduced over.
|
||||
// And inputRank=0 means to reduce over all input axes, e.g. for an image input that
|
||||
// And inferInputRank=0 means to reduce over all input axes, e.g. for an image input that
|
||||
// should be flattened.
|
||||
// Examples:
|
||||
// [I x Inferred] * [J x K], inputRank=1 --> Inferred := J, result is [I x K]
|
||||
// [I x Inferred] * [W x H x C], inputRank=1 --> Inferred := W, result is [I x H x C] (not desired)
|
||||
// [I x Inferred] * [W x H x C], inputRank=0 --> Inferred := W x H x C, result is [I] (desired)
|
||||
// [I x Inferred] * [W x H x C x R], inputRank=-1 --> Inferred := W x H x C, result is [I x R] (desired)
|
||||
// [I x Inferred] * [J x K], inferInputRank=1 --> Inferred := J, result is [I x K]
|
||||
// [I x Inferred] * [W x H x C], inferInputRank=1 --> Inferred := W, result is [I x H x C] (not desired)
|
||||
// [I x Inferred] * [W x H x C], inferInputRank=0 --> Inferred := W x H x C, result is [I] (desired)
|
||||
// [I x Inferred] * [W x H x C x R], inferInputRank=-1 --> Inferred := W x H x C, result is [I x R] (desired)
|
||||
// In each case,
|
||||
// * if the output tensor is too short *and* the last dimension is 0, it will be extended
|
||||
// * output tensor dimensions that are not 0 are not touched
|
||||
if (dimsA.back() == 0) // if last entry is 0, we infer the tensor rank as well
|
||||
{
|
||||
if (abs(m_inputRank) > dimsB.size())
|
||||
InvalidArgument("%ls %ls operation: 'inputDims' argument %d exceeds rank of second operand [%s].", NodeName().c_str(), OperationName().c_str(), m_inputRank, dimsBstring.c_str());
|
||||
size_t inputRank = (size_t)(m_inputRank > 0 ? m_inputRank : (int)dimsB.size() + m_inputRank);
|
||||
if (abs(m_inferInputRank) > dimsB.size())
|
||||
InvalidArgument("%ls %ls operation: 'inputDims' argument %d exceeds rank of second operand [%s].", NodeName().c_str(), OperationName().c_str(), m_inferInputRank, dimsBstring.c_str());
|
||||
size_t inferInputRank = (size_t)(m_inferInputRank > 0 ? m_inferInputRank : (int)dimsB.size() + m_inferInputRank);
|
||||
assert(dimsA.size() == m_outputRank + numReductionDims);
|
||||
while (numReductionDims < inputRank)
|
||||
while (numReductionDims < inferInputRank)
|
||||
{
|
||||
dimsA.push_back(0);
|
||||
numReductionDims++;
|
||||
|
@ -502,7 +502,7 @@ public:
|
|||
|
||||
private:
|
||||
size_t m_outputRank;
|
||||
int m_inputRank; // can be negative to indicate counting from end
|
||||
int m_inferInputRank; // can be negative to indicate counting from end
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
|
@ -529,12 +529,12 @@ class TimesNode : public TimesNodeBase<ElemType, false>
|
|||
static const std::wstring TypeName() { return L"Times"; }
|
||||
|
||||
public:
|
||||
TimesNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1, int inputRank = 1)
|
||||
: Base(deviceId, name, outputRank, inputRank)
|
||||
TimesNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1, int inferInputRank = 1)
|
||||
: Base(deviceId, name, outputRank, inferInputRank)
|
||||
{
|
||||
}
|
||||
TimesNode(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
: TimesNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"outputRank"), configp->Get(L"inputRank"))
|
||||
: TimesNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"outputRank"), configp->Get(L"inferInputRank"))
|
||||
{
|
||||
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
@ -562,7 +562,7 @@ class TransposeTimesNode : public TimesNodeBase<ElemType, true>
|
|||
public:
|
||||
DeclareConstructorFromConfigWithNumInputs(TransposeTimesNode);
|
||||
TransposeTimesNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputRank = 1)
|
||||
: Base(deviceId, name, outputRank, /*inputRank=*/1)
|
||||
: Base(deviceId, name, outputRank, /*inferInputRank=*/1)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
|
Загрузка…
Ссылка в новой задаче