Merge pull request #3705 from Veason-silverbullet/lewlu/msra-face-dist
Lewlu/msra face dist
Commit 5836d74b6c
@@ -81,6 +81,15 @@ DistributedAdditiveFullConnectionLayer {inDim, outDim, weightNormalize=true, bia
    apply (labelSequence, outProbVectorSequence) = DistributedAdditiveFullConnection(labelSequence, W, outProbVectorSequence, weightNormalize=weightNormalize, bias=bias, scale=scale)
}.apply

# DistributedArcMarginProductLayer -- create a distributed arc margin product layer
# The output is decomposed. The next layer must deal with the decomposed output tensors
DistributedArcMarginProductLayer {inDim, outDim, bias=0, scale=1, init='glorotUniform', initValueScale=1} =
{
    W = ParameterTensor {_ConcatArrays (inDim, outDim), init=init, initValueScale=initValueScale, distribute=true}

    apply (labelSequence, outProbVectorSequence) = DistributedArcMarginProduct(labelSequence, W, outProbVectorSequence, bias=bias, scale=scale)
}.apply

# MarginInnerProductLayer -- create a marginInnerProduct projection layer
# Note: outputDimension may describe a tensor as well.
MarginInnerProductLayer {inputDimension, outputDimension, base, gamma, power, lambdaMin, coefficient, init='glorotUniform', initValueScale=1} =

@@ -106,6 +115,15 @@ AdditiveFullConnectionLayer {inputDimension, outputDimension, weightNormalize=tr
    apply (labelSequence, outProbVectorSequence) = AdditiveFullConnection(labelSequence, outProbVectorSequence, W, outputDimension=outputDimension, weightNormalize=weightNormalize, bias=bias, annealBias=annealBias, biasBase=biasBase, biasGamma=biasGamma, biasPower=biasPower, biasMin=biasMin, biasMax=biasMax)
}.apply

# ArcMarginProductLayer -- create a arc margin product layer
# Note: outputDimension may describe a tensor as well.
ArcMarginProductLayer {inputDimension, outputDimension, bias=0, init='glorotUniform', initValueScale=1} =
{
    W = ParameterTensor {_ConcatArrays (outputDimension, inputDimension), init=init, initValueScale=initValueScale}

    apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, bias=bias)
}.apply

# GlobalConcatLayer -- create a concat layer, which uses temporary global memory to save the output
# Note: outputDimension may describe a tensor as well.
GlobalConcatLayer {blockIndex, growthRate, segmentIndex, segmentNum} =

@@ -460,11 +478,12 @@ DistributedFullyConnected_v2 = CNTK2.DistributedFullyConnected_v2
DistributedCrossEntropyWithSoftmax = CNTK2.DistributedCrossEntropyWithSoftmax
DistributedClassificationError = CNTK2.DistributedClassificationError
DistributedAdditiveFullConnection = CNTK2.DistributedAdditiveFullConnection
DistributedArcMarginProduct = CNTK2.DistributedArcMarginProduct
MarginInnerProduct = CNTK2.MarginInnerProduct
FeatureNormalize = CNTK2.FeatureNormalize
AdditiveFullConnection = CNTK2.AdditiveFullConnection
ArcMarginProduct = CNTK2.ArcMarginProduct
CenterLoss = CNTK2.CenterLoss
ChannelMultiply = CNTK2.ChannelMultiply
GlobalConcat = CNTK2.GlobalConcat
Dropout = CNTK2.Dropout
ElementTimes = CNTK2.ElementTimes

@@ -592,6 +611,12 @@ CNTK2 = [
    DistributedAdditiveFullConnection(labelSequence, W, outProbVectorSequence, weightNormalize=true, bias=0, scale=1, tag='') =
        new ComputationNode [ operation = 'DistributedAdditiveFullConnection' ; inputs = _AsNodes (labelSequence : W : outProbVectorSequence) /*plus the function args*/ ]

    #
    # Distributed arcMarginProduct node, the input labels and probability must be decomposed
    #
    DistributedArcMarginProduct(labelSequence, W, outProbVectorSequence, bias=0, scale=1, tag='') =
        new ComputationNode [ operation = 'DistributedArcMarginProduct' ; inputs = _AsNodes (labelSequence : W : outProbVectorSequence) /*plus the function args*/ ]

    #
    # MarginInnerProduct node
    #

@@ -610,6 +635,12 @@ CNTK2 = [
    AdditiveFullConnection(labelSequence, outProbVectorSequence, W, outputDimension=0, weightNormalize=true, bias=0, annealBias=false, biasBase=0, biasGamma=0, biasPower=0, biasMin=0, biasMax=0, tag='') =
        new ComputationNode [ operation = 'AdditiveFullConnection' ; inputs = _AsNodes (labelSequence : outProbVectorSequence : W) /*plus the function args*/ ]

    #
    # ArcMarginProduct node
    #
    ArcMarginProduct(labelSequence, outProbVectorSequence, W, outputDimension=0, bias=0, tag='') =
        new ComputationNode [ operation = 'ArcMarginProduct' ; inputs = _AsNodes (labelSequence : outProbVectorSequence : W) /*plus the function args*/ ]

    #
    # CenterLoss node
    #

@@ -617,12 +648,6 @@ CNTK2 = [
    if axis==0 then new ComputationNode [ operation = 'CenterLoss' ; inputs = _AsNodes (labelSequence : outProbVectorSequence) /*plus the function args*/ ]
    else [ tag1 = tag; out = Minus (ReduceLogSum (outProbVectorSequence, axis=axis), ReduceSum (labelSequence .* outProbVectorSequence, axis=axis), tag=tag1) ].out

    #
    # ChannelMultiply node
    #
    ChannelMultiply(feature, weight, tag='') =
        new ComputationNode [ operation = 'ChannelMultiply' ; inputs = _AsNodes (feature : weight) /*plus the function args*/ ]

    #
    # GlobalConcat node
    #
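For reference, the margin these two layers apply (our notation, inferred from the node implementations further down in this diff): m is the layer's bias argument in radians, s is the scale argument of the distributed variant (the plain ArcMarginProductLayer applies no scaling), w_j is the j-th class-weight column after unit normalization, and the feature x is assumed unit-normalized upstream (e.g. by FeatureNormalize). For a sample with target class y:

$$\mathrm{logit}_j = \begin{cases} s\,\cos(\theta_y + m) = s\,(\cos\theta_y\cos m - \sin\theta_y\sin m), & j = y \\ s\,\cos\theta_j, & j \neq y \end{cases} \qquad \cos\theta_j = \hat{w}_j^{\top} x.$$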
@@ -4351,15 +4351,17 @@ namespace CNTK

    CNTK_API FunctionPtr DistributedAdditiveFullConnection(const Variable& targets, const Variable& weight, const Variable& prediction, bool weightNormalize, double bias, double scale, const std::wstring& name = L"");

    CNTK_API FunctionPtr DistributedArcMarginProduct(const Variable& targets, const Variable& weight, const Variable& prediction, double bias, double scale, const std::wstring& name = L"");

    CNTK_API FunctionPtr MarginInnerProduct(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, double base, double gamma, double power, double lambdaMin, size_t marginCoefficient, const std::wstring& name = L"");

    CNTK_API FunctionPtr FeatureNormalize(const Variable& feature, size_t normalizeType, const std::wstring& name = L"");

    CNTK_API FunctionPtr AdditiveFullConnection(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, bool weightNormalize, double bias, bool annealBias, double biasBase, double biasGamma, double biasPower, double biasMin, double biasMax, const std::wstring& name = L"");

    CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");
    CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, double bias, const std::wstring& name = L"");

    CNTK_API FunctionPtr ChannelMultiply(const Variable& prediction, const Variable& targets, const std::wstring& name = L"");
    CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");

    CNTK_API FunctionPtr GlobalConcat(const Variable& feature, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring& name = L"");
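For orientation, a minimal sketch of how the ArcMarginProduct factory declared above might be wired into a model with the V2 C++ API; the shapes, the 0.5 margin, and all names below are illustrative assumptions, not part of this change:

    // Hypothetical usage sketch (not from this PR): ArcFace-style classification head.
    #include "CNTKLibrary.h"
    using namespace CNTK;

    FunctionPtr BuildArcFaceHead(const Variable& features, const Variable& labels,
                                 size_t featureDim, size_t numClasses,
                                 const DeviceDescriptor& device)
    {
        // Assumed layout: features holds unit-normalized embeddings, labels is one-hot.
        auto W = Parameter(NDShape({ numClasses, featureDim }), DataType::Float,
                           GlorotUniformInitializer(), device, L"arcW");
        // bias is the additive angular margin m in radians; 0.5 is only an example value.
        auto logits = ArcMarginProduct(features, labels, W, /*bias=*/0.5, L"arcMargin");
        // The margin-adjusted logits would then typically feed CrossEntropyWithSoftmax(logits, labels).
        return logits;
    }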
@@ -125,14 +125,15 @@ namespace CNTK
    {PrimitiveOpType::MarginInnerProduct, L"MarginInnerProduct"},
    {PrimitiveOpType::FeatureNormalize, L"FeatureNormalize"},
    {PrimitiveOpType::AdditiveFullConnection, L"AdditiveFullConnection"},
    {PrimitiveOpType::ArcMarginProduct, L"ArcMarginProduct" },
    {PrimitiveOpType::CenterLoss, L"CenterLoss" },
    {PrimitiveOpType::ChannelMultiply, L"ChannelMultiply" },
    {PrimitiveOpType::GlobalConcat, L"GlobalConcat"},
    {PrimitiveOpType::DistributedFullyConnected, L"DistributedFullyConnected" },
    {PrimitiveOpType::DistributedFullyConnected_v2, L"DistributedFullyConnected_v2" },
    {PrimitiveOpType::DistributedCrossEntropyWithSoftmax, L"DistributedCrossEntropyWithSoftmax" },
    {PrimitiveOpType::DistributedClassificationError, L"DistributedClassificationError" },
    {PrimitiveOpType::DistributedAdditiveFullConnection, L"DistributedAdditiveFullConnection" }
    {PrimitiveOpType::DistributedAdditiveFullConnection, L"DistributedAdditiveFullConnection" },
    {PrimitiveOpType::DistributedArcMarginProduct, L"DistributedArcMarginProduct" }
};

inline const std::wstring& PrimitiveOpTypeName(PrimitiveOpType opType)

@@ -154,6 +154,7 @@ namespace CNTK
    CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasPower;
    CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMin;
    CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMax;
    CNTK_API static const std::wstring AttributeArcMarginProductBias;
    CNTK_API static const std::wstring AttributeCenterLossLambda;
    CNTK_API static const std::wstring AttributeCenterLossAlpha;
    CNTK_API static const std::wstring AttributeCenterLossLabelDim;

@@ -165,6 +166,8 @@ namespace CNTK
    CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionWeightNormalize;
    CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionBias;
    CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionScale;
    CNTK_API static const std::wstring AttributeDistributedArcMarginProductBias;
    CNTK_API static const std::wstring AttributeDistributedArcMarginProductScale;

    CNTK_API static const std::vector<std::wstring> s_rngStateAttributes;
};

@@ -110,14 +110,15 @@ namespace CNTK
    MarginInnerProduct = 197,
    FeatureNormalize = 198,
    AdditiveFullConnection = 199,
    CenterLoss = 200,
    ChannelMultiply = 201,
    ArcMarginProduct = 200,
    CenterLoss = 201,
    GlobalConcat = 202,
    DistributedFullyConnected = 203,
    DistributedFullyConnected_v2 = 204,
    DistributedCrossEntropyWithSoftmax = 205,
    DistributedClassificationError = 206,
    DistributedAdditiveFullConnection = 207,
    DistributedArcMarginProduct = 208,
    // New op types should only be appended to the end of this list
    UnknownOP
    // and UnknownOP should always be last.
@@ -347,6 +347,14 @@ namespace CNTK

        opType = PrimitiveOpType::DistributedAdditiveFullConnection;
    }
    else if (node->OperationName() == OperationNameOf(DistributedArcMarginProductNode))
    {
        auto distributedArcMarginProductNode = node->As<DistributedArcMarginProductNode<ElementType>>();
        primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias] = distributedArcMarginProductNode->m_bias;
        primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale] = distributedArcMarginProductNode->m_scale;

        opType = PrimitiveOpType::DistributedArcMarginProduct;
    }
    else if (node->OperationName() == OperationNameOf(MarginInnerProductNode))
    {
        auto marginInnerProductNode = node->As<MarginInnerProductNode<ElementType>>();

@@ -381,6 +389,13 @@ namespace CNTK

        opType = PrimitiveOpType::AdditiveFullConnection;
    }
    else if (node->OperationName() == OperationNameOf(ArcMarginProductNode))
    {
        auto arcMarginProductNode = node->As<ArcMarginProductNode<ElementType>>();
        primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = arcMarginProductNode->m_bias;

        opType = PrimitiveOpType::ArcMarginProduct;
    }
    else if (node->OperationName() == OperationNameOf(CenterLossNode))
    {
        auto centerLossNode = node->As<CenterLossNode<ElementType>>();

@@ -391,8 +406,6 @@ namespace CNTK

        opType = PrimitiveOpType::CenterLoss;
    }
    else if (node->OperationName() == OperationNameOf(ChannelMultiplyNode))
        opType = PrimitiveOpType::ChannelMultiply;
    else if (node->OperationName() == OperationNameOf(GlobalConcatNode))
    {
        auto globalConcatNode = node->As<GlobalConcatNode<ElementType>>();

@@ -1182,6 +1182,13 @@ namespace CNTK
        ASSIGN_NEW_NODE(DistributedAdditiveFullConnectionNode, network->GetDeviceId(), internalNodeName, weightNormalize, bias, scale);
        break;
    }
    case PrimitiveOpType::DistributedArcMarginProduct:
    {
        auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias].Value<double>();
        auto scale = functionConfig[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale].Value<double>();
        ASSIGN_NEW_NODE(DistributedArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias, scale);
        break;
    }
    case PrimitiveOpType::MarginInnerProduct:
    {
        auto outputDimension = functionConfig[PrimitiveFunctionAttribute::AttributeMarginInnerProductOutputDimension].Value<size_t>();

@@ -1213,6 +1220,12 @@ namespace CNTK
        ASSIGN_NEW_NODE(AdditiveFullConnectionNode, network->GetDeviceId(), internalNodeName, outputDimension, weightNormalize, bias, annealBias, biasBase, biasGamma, biasPower, biasMin, biasMax);
        break;
    }
    case PrimitiveOpType::ArcMarginProduct:
    {
        auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductBias].Value<double>();
        ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias);
        break;
    }
    case PrimitiveOpType::CenterLoss:
    {
        auto lambda = functionConfig[PrimitiveFunctionAttribute::AttributeCenterLossLambda].Value<double>();

@@ -1222,11 +1235,6 @@ namespace CNTK
        ASSIGN_NEW_NODE(CenterLossNode, network->GetDeviceId(), internalNodeName, lambda, alpha, labelDim, normalize);
        break;
    }
    case PrimitiveOpType::ChannelMultiply:
    {
        ASSIGN_NEW_NODE(ChannelMultiplyNode, network->GetDeviceId(), internalNodeName);
        break;
    }
    case PrimitiveOpType::GlobalConcat:
    {
        auto blockIndex = functionConfig[PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex].Value<size_t>();
@@ -2170,6 +2170,15 @@ namespace CNTK
        return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::DistributedAdditiveFullConnection, operands, std::move(additionalProperties), name), name);
    }

    FunctionPtr DistributedArcMarginProduct(const Variable& labels, const Variable& weight, const Variable& features, double bias, double scale, const std::wstring& name)
    {
        std::vector<Variable> operands = { labels, weight, features };
        auto additionalProperties = Dictionary();
        additionalProperties[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias] = bias;
        additionalProperties[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale] = scale;
        return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::DistributedArcMarginProduct, operands, std::move(additionalProperties), name), name);
    }

    FunctionPtr MarginInnerProduct(const Variable& features, const Variable& labels, const Variable& weight, size_t outputDimension, double base, double gamma, double power, double lambdaMin, size_t marginCoefficient, const std::wstring& name)
    {
        std::vector<Variable> operands = {features, labels, weight};

@@ -2206,6 +2215,14 @@ namespace CNTK
        return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::AdditiveFullConnection, operands, std::move(additionalProperties), name), name);
    }

    FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, double bias, const std::wstring& name)
    {
        std::vector<Variable> operands = { features, labels, weight };
        auto additionalProperties = Dictionary();
        additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = bias;
        return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ArcMarginProduct, operands, std::move(additionalProperties), name), name);
    }

    FunctionPtr CenterLoss(const Variable& features, const Variable& labels, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name)
    {
        auto additionalProperties = Dictionary();

@@ -2216,14 +2233,6 @@ namespace CNTK
        return BinaryOp(PrimitiveOpType::CenterLoss, labels, features, std::move(additionalProperties), name);
    }

    FunctionPtr ChannelMultiply(const Variable& feature, const Variable& weight, const std::wstring& name)
    {
        Variable featureCopy = feature;
        Variable weightCopy = weight;
        auto additionalProperties = Dictionary();
        return BinaryOp(PrimitiveOpType::ChannelMultiply, featureCopy, weightCopy, std::move(additionalProperties), name);
    }

    FunctionPtr GlobalConcat(const Variable& feature, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring& name)
    {
        Variable featureCopy = feature;
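A rough sketch of how the DistributedArcMarginProduct factory above could be called; it presupposes the multi-GPU/MPI setup the underlying node requires (one weight shard per rank, labels and logits kept decomposed), and every name, dimension, and value here is illustrative rather than taken from this change:

    // Hypothetical per-rank wiring (not from this PR). With P ranks, each rank owns
    // numClassesPerRank columns of the full classifier weight.
    #include "CNTKLibrary.h"
    using namespace CNTK;

    FunctionPtr BuildDistributedArcFaceHead(const Variable& labels, const Variable& features,
                                            size_t featureDim, size_t numClassesPerRank,
                                            const DeviceDescriptor& device)
    {
        // Per-rank weight shard; the node all-gathers features across ranks and emits
        // only this rank's slice of the margin-adjusted logits.
        auto W = Parameter(NDShape({ featureDim, numClassesPerRank }), DataType::Float,
                           GlorotUniformInitializer(), device, L"distArcW");
        // bias = angular margin m (radians), scale = logit scale s; 0.5 and 64 are common
        // ArcFace settings, used here purely as an example.
        return DistributedArcMarginProduct(labels, W, features, /*bias=*/0.5, /*scale=*/64.0,
                                           L"distArcMargin");
    }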
@@ -135,11 +135,12 @@ namespace CNTK
        (op == PrimitiveOpType::DistributedCrossEntropyWithSoftmax) ||
        (op == PrimitiveOpType::DistributedClassificationError) ||
        (op == PrimitiveOpType::DistributedAdditiveFullConnection) ||
        (op == PrimitiveOpType::DistributedArcMarginProduct) ||
        (op == PrimitiveOpType::MarginInnerProduct) ||
        (op == PrimitiveOpType::FeatureNormalize) ||
        (op == PrimitiveOpType::AdditiveFullConnection) ||
        (op == PrimitiveOpType::ArcMarginProduct) ||
        (op == PrimitiveOpType::CenterLoss) ||
        (op == PrimitiveOpType::ChannelMultiply) ||
        (op == PrimitiveOpType::GlobalConcat) ||
        (op == PrimitiveOpType::CrossEntropyWithSoftmax) ||
        (op == PrimitiveOpType::LatticeSequenceWithSoftmax) ||

@@ -970,6 +971,12 @@ namespace CNTK
        outputShape = NDShape{};
        break;
    }
    case PrimitiveOpType::DistributedArcMarginProduct:
    {
        assert(m_inputs.size() == 3);
        outputShape = NDShape{};
        break;
    }
    case PrimitiveOpType::MarginInnerProduct:
    {
        assert(m_inputs.size() == 3);

@@ -988,10 +995,10 @@ namespace CNTK
        outputShape = NDShape{};
        break;
    }
    case PrimitiveOpType::ChannelMultiply:
    case PrimitiveOpType::ArcMarginProduct:
    {
        assert(m_inputs.size() == 2);
        outputShape = m_inputs[0].Shape();
        assert(m_inputs.size() == 3);
        outputShape = NDShape{};
        break;
    }
    case PrimitiveOpType::GlobalConcat:
@@ -130,26 +130,29 @@ namespace CNTK
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeMarginInnerProductLambdaMin = L"marginInnerProductOutputLambdaMin";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeMarginInnerProductMarginCoefficient = L"marginCoefficient";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeFeatureNormalizeNormalizeType = L"featureNormalizeType";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionOutputDimension = L"outputDimension";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionWeightNormalize = L"weightNormalize";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBias = L"bias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionAnnealBias = L"annealBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasBase = L"biasBase";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasGamma = L"biasGamma";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"biasPower";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"biasMin";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"biasMax";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionOutputDimension = L"additiveFullConnectionOutputDimension";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionWeightNormalize = L"additiveFullConnectionWeightNormalize";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBias = L"additiveFullConnectionBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionAnnealBias = L"additiveFullConnectionAnnealBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasBase = L"additiveFullConnectionBiasBase";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasGamma = L"additiveFullConnectionBiasGamma";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"additiveFullConnectionBiasPower";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"additiveFullConnectionBiasMin";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"additiveFullConnectionBiasMax";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductBias = L"arcMarginProductBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLambda = L"centerLossLambda";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossAlpha = L"centerLossAlpha";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLabelDim = L"centerLossLabelDim";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossNormalize = L"centerLossNormalize";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex = L"blockIndex";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatGrowthRate = L"growthRate";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentIndex = L"segmentIndex";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentNum = L"segmentNum";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionWeightNormalize = L"weightNormalize";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionBias = L"bias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionScale = L"scale";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex = L"globalConcatBlockIndex";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatGrowthRate = L"globalConcatGrowthRate";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentIndex = L"globalConcatSegmentIndex";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentNum = L"globalConcatSegmentNum";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionWeightNormalize = L"distributedAdditiveFullConnectionWeightNormalize";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionBias = L"distributedAdditiveFullConnectionBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionScale = L"distributedAdditiveFullConnectionScale";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias = L"distributedArcMarginProductBias";
    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale = L"distributedArcMarginProductScale";

    /*static*/ const std::vector<std::wstring> PrimitiveFunctionAttribute::s_rngStateAttributes =
    { PrimitiveFunctionAttribute::AttributeNameRngSeed,
@@ -131,11 +131,12 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
    else if (nodeType == OperationNameOf(DistributedCrossEntropyWithSoftmaxNode)) return New<DistributedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(DistributedClassificationErrorNode)) return New<DistributedClassificationErrorNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(DistributedAdditiveFullConnectionNode)) return New<DistributedAdditiveFullConnectionNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(DistributedArcMarginProductNode)) return New<DistributedArcMarginProductNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(MarginInnerProductNode)) return New<MarginInnerProductNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(FeatureNormalizeNode)) return New<FeatureNormalizeNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(AdditiveFullConnectionNode)) return New<AdditiveFullConnectionNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(ArcMarginProductNode)) return New<ArcMarginProductNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(CenterLossNode)) return New<CenterLossNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(ChannelMultiplyNode)) return New<ChannelMultiplyNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(GlobalConcatNode)) return New<GlobalConcatNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(LogisticNode)) return New<LogisticNode<ElemType>>(forward<_Types>(_Args)...);
    else if (nodeType == OperationNameOf(SumColumnElementsNode)) return New<SumColumnElementsNode<ElemType>>(forward<_Types>(_Args)...);

@@ -518,6 +519,12 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Distr
    return net.AddNodeToNetAndAttachInputs(New<DistributedAdditiveFullConnectionNode<ElemType>>(net.GetDeviceId(), nodeName, weightNormalize, bias, scale), { a, b, c });
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::DistributedArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, ElemType scale, const std::wstring nodeName)
{
    return net.AddNodeToNetAndAttachInputs(New<DistributedArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias, scale), { a, b, c });
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName)
{

@@ -537,15 +544,15 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Addit
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName)
{
    return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName, lambda, alpha, labelDim, normalize), { a, b });
    return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias), { a, b, c });
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ChannelMultiply(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName)
{
    return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName), { a, b });
    return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName, lambda, alpha, labelDim, normalize), { a, b });
}

template <class ElemType>

@@ -205,11 +205,12 @@ public:
    ComputationNodePtr DistributedCrossEntropyWithSoftmax(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
    ComputationNodePtr DistributedClassificationError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
    ComputationNodePtr DistributedAdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, bool weightNormalize, ElemType bias, ElemType scale, const std::wstring nodeName = L"");
    ComputationNodePtr DistributedArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, ElemType scale, const std::wstring nodeName = L"");
    ComputationNodePtr MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName = L"");
    ComputationNodePtr FeatureNormalize(const ComputationNodePtr a, size_t normalizeType, const std::wstring nodeName = L"");
    ComputationNodePtr AdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, bool weightNormalize, ElemType bias, bool annealBias, ElemType biasBase, ElemType biasGamma, ElemType biasPower, ElemType biasMin, ElemType biasMax, const std::wstring nodeName = L"");
    ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName = L"");
    ComputationNodePtr CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName = L"");
    ComputationNodePtr ChannelMultiply(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
    ComputationNodePtr GlobalConcat(const ComputationNodePtr a, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring nodeName = L"");
    ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
    ComputationNodePtr Tan(const ComputationNodePtr a, const std::wstring nodeName = L"");
@@ -22,24 +22,9 @@
using namespace std;

//#define __DISTRIBUTED_PROFILE__

namespace Microsoft { namespace MSR { namespace CNTK {

#ifdef __DISTRIBUTED_PROFILE__
double distributedAdditiveFullConnectionTime = 0.0;
std::chrono::time_point<std::chrono::system_clock> distributedAdditiveFullConnectionStartTime;
std::chrono::time_point<std::chrono::system_clock> distributedAdditiveFullConnectionEndTime;
int distributedAdditiveFullConnectionCnt = 0;
double distributedCrossEntropyWithSoftmaxTime = 0.0;
std::chrono::time_point<std::chrono::system_clock> distributedCrossEntropyWithSoftmaxStartTime;
std::chrono::time_point<std::chrono::system_clock> distributedCrossEntropyWithSoftmaxEndTime;
int distributedCrossEntropyWithSoftmaxCnt = 0;
#endif

// This source file contains methods related to evaluation (forward prop, backprop), network validation, and matrix memory allocation (memory sharing).

// -----------------------------------------------------------------------

@@ -163,14 +148,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
{
    if (node->IsOutOfDateWrtInputs())
    {
#ifdef __DISTRIBUTED_PROFILE__
        if (node->OperationName() == L"DistributedAdditiveFullConnection")
            distributedAdditiveFullConnectionStartTime = std::chrono::system_clock::now();
        else if (node->OperationName() == L"DistributedCrossEntropyWithSoftmax")
            distributedCrossEntropyWithSoftmaxStartTime = std::chrono::system_clock::now();
#endif

        node->BeginForwardProp();
        node->BeginTiming(false /*backward*/);
        node->ForwardProp(fr.WithLayout(node->GetMBLayout()));

@@ -182,30 +159,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
        // Extreme Tracing, part 1/4
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
            DumpNode(node, /*dumpGradient=*/false);

#ifdef __DISTRIBUTED_PROFILE__
        if (node->OperationName() == L"DistributedAdditiveFullConnection")
        {
            distributedAdditiveFullConnectionEndTime = std::chrono::system_clock::now();
            distributedAdditiveFullConnectionTime += (std::chrono::duration<double>(distributedAdditiveFullConnectionEndTime - distributedAdditiveFullConnectionStartTime)).count();
            if (++distributedAdditiveFullConnectionCnt % 100 == 0)
            {
                fprintf(stderr, "Iteration [%d-%d]: distributedAdditiveFullConnection forward time = %.8gs\n", distributedAdditiveFullConnectionCnt - 99, distributedAdditiveFullConnectionCnt, distributedAdditiveFullConnectionTime);
                distributedAdditiveFullConnectionTime = 0.0;
            }
        }
        else if (node->OperationName() == L"DistributedCrossEntropyWithSoftmax")
        {
            distributedCrossEntropyWithSoftmaxEndTime = std::chrono::system_clock::now();
            distributedCrossEntropyWithSoftmaxTime += (std::chrono::duration<double>(distributedCrossEntropyWithSoftmaxEndTime - distributedCrossEntropyWithSoftmaxStartTime)).count();
            if (++distributedCrossEntropyWithSoftmaxCnt % 100 == 0)
            {
                fprintf(stderr, "Iteration [%d-%d]: distributedCrossEntropyWithSoftmax forward time = %.8gs\n", distributedCrossEntropyWithSoftmaxCnt - 99, distributedCrossEntropyWithSoftmaxCnt, distributedCrossEntropyWithSoftmaxTime);
                distributedCrossEntropyWithSoftmaxTime = 0.0;
            }
        }
#endif
    }
}
@@ -483,7 +483,8 @@ void ComputationNodeBase::ValidateBinaryReduce(bool isFinalValidationPass)
{
    // It is for DistributedCrossEntropyWithSoftmaxNode
    if (Input(0)->OperationName() != L"DistributedFullyConnected_v2" && Input(1)->OperationName() != L"DistributedFullyConnected_v2" &&
        Input(0)->OperationName() != L"DistributedAdditiveFullConnection" && Input(1)->OperationName() != L"DistributedAdditiveFullConnection")
        Input(0)->OperationName() != L"DistributedAdditiveFullConnection" && Input(1)->OperationName() != L"DistributedAdditiveFullConnection" &&
        Input(0)->OperationName() != L"DistributedArcMarginProduct" && Input(1)->OperationName() != L"DistributedArcMarginProduct")
    {
        string s1 = Input(0)->GetSampleLayout();
        string s2 = Input(1)->GetSampleLayout();

@@ -1729,7 +1729,7 @@ protected:
    {
        rows = GetSampleMatrixNumRows();
        cols = GetSampleMatrixNumCols();
        if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
        if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
            cols *= Globals::GetProcessNum();
    }
    else

@@ -1856,7 +1856,7 @@ public:
    size_t matrixSize = m_sampleLayout.GetNumElements();
    if (IsValueSharable() && !m_isValueSparse)
    {
        if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
        if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
            matrixSize *= Globals::GetProcessNum();
        RequestMatrixFromPool(m_value, matrixPool, matrixSize, HasMBLayout());
    }

@@ -1897,7 +1897,7 @@ public:
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
    size_t matrixSize = m_sampleLayout.GetNumElements();
    if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
    if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
        matrixSize *= Globals::GetProcessNum();
    RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());
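A worked size example may help read the GetProcessNum() scaling above (the numbers are illustrative, not taken from this change): for a distributed node whose sample layout has d = 512 elements, running on P = 8 ranks with a per-rank minibatch of b = 64 columns, the pooled matrices are requested with matrixSize = d × P = 4096 and the column count becomes b × P = 512, i.e. the node's value and gradient buffers are sized for the full all-gathered batch rather than the local shard.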
@@ -361,6 +361,10 @@ template class DistributedAdditiveFullConnectionNode<float>;
template class DistributedAdditiveFullConnectionNode<double>;
template class DistributedAdditiveFullConnectionNode<half>;

template class DistributedArcMarginProductNode<float>;
template class DistributedArcMarginProductNode<double>;
template class DistributedArcMarginProductNode<half>;

template class MarginInnerProductNode<float>;
template class MarginInnerProductNode<double>;
template class MarginInnerProductNode<half>;

@@ -373,14 +377,14 @@ template class AdditiveFullConnectionNode<float>;
template class AdditiveFullConnectionNode<double>;
template class AdditiveFullConnectionNode<half>;

template class ArcMarginProductNode<float>;
template class ArcMarginProductNode<double>;
template class ArcMarginProductNode<half>;

template class CenterLossNode<float>;
template class CenterLossNode<double>;
template class CenterLossNode<half>;

template class ChannelMultiplyNode<float>;
template class ChannelMultiplyNode<double>;
template class ChannelMultiplyNode<half>;

template class GlobalMemoryBlock<float>;
template class GlobalMemoryBlock<double>;
template class GlobalMemoryBlock<half>;

@@ -31,6 +31,7 @@ static const wstring RandomDistributionTypeGumbel = L"gumbel";
static const wstring RandomDistributionTypeBernoulli = L"bernoulli";

static const double m_PI = acos(-1.0);
// Distributed fully connected layer (Y = W'X + b)
// Input(0): W, Input(1): X, Input(2): b
// Output is not decomposed
@@ -736,6 +737,224 @@ public:
    shared_ptr<Matrix<ElemType>> m_WNorm;
};

// Distributed additive angular margin product layer
// Input(0): labels, Input(1): W, Input(2): X
// Input and output are both decomposed
template <class ElemType>
class DistributedArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
{
    typedef ComputationNodeNonLooping<ElemType> Base;
    UsingComputationNodeMembersBoilerplate;
    static const std::wstring TypeName()
    {
        return L"DistributedArcMarginProduct";
    }

public:
    DistributedArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
        : DistributedArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"), configp->Get(L"scale"))
    {
        AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
    }

    DistributedArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0, double scale = 1.0)
        : Base(deviceId, name), m_bias(bias), m_scale(scale), m_rank(Globals::GetRank()), m_processNum(Globals::GetProcessNum()), m_minibatchSize(0), m_distGradAggPtr(NULL)
    {
        if (m_bias < 0 || m_bias >= m_PI)
            LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
        m_threshold = cos(m_PI - m_bias);
        m_cosBias = cos(m_bias);
        m_sinBias = sin(m_bias);
#ifdef CPUONLY
        LogicError("CPUONLY is not supported in DistributedArcMarginProductNode.");
#endif
    }

    ~DistributedArcMarginProductNode()
    {
        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
            DistributedGatheredLabels<ElemType>::initializeNodePtr = NULL;
    }

    virtual void UpdateFunctionMBSize() override
    {
        size_t minibatchSize = InputRef(0).Value().GetNumCols();
        if (m_minibatchSize != minibatchSize)
        {
            if (1 == m_processNum)
                LogicError("Multi Gpus and mpi is needed in distributed FC.");
            m_distGradAggPtr = (IDistGradAggregator<ElemType>*) Globals::GetDistGradAggPtr();
            bool minibatchSizeEqual = m_distGradAggPtr->DistributedCheck(m_minibatchSize, m_processNum);
            if (!minibatchSizeEqual)
                LogicError("With AllGather op, minibatch size in each Gpu must be the same.");
            m_minibatchSize = minibatchSize;
            m_batchSize = m_minibatchSize * m_processNum;
        }
        m_temp1->Resize(m_inputDim, m_batchSize); // Aggregated X
        m_tempValue->Resize(1, m_batchSize);
        m_flag->Resize(1, m_batchSize);
        m_WNorm->Resize(1, m_outputDim);
        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
            DistributedGatheredLabels<ElemType>::setMinibatchSize(m_minibatchSize);
    }

    virtual void BackpropToNonLooping(size_t inputIndex) override
    {
        if (1 == inputIndex) // for W
        {
            Matrix<ElemType>::Scale((ElemType)m_scale, Gradient());
            Matrix<ElemType>::DistributedArcLabelAddBackprop(*DistributedGatheredLabels<ElemType>::m_gatheredLabels, (ElemType)m_cosBias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Gradient(), m_outputDim * m_rank, m_outputDim * (m_rank + 1) - 1);

            auto& W_gradient = InputRef(1).Gradient();
            Matrix<ElemType>::Multiply(*m_temp1, false, Gradient(), true, W_gradient);
        }
        else if (2 == inputIndex) // for X
        {
            auto& W = InputRef(1).Value();
            Matrix<ElemType>::Multiply(W, false, Gradient(), false, *m_temp1);
            m_distGradAggPtr->DistributedAllReduce(*m_temp1, MPI_SUM);
            auto& X_gradient = InputRef(2).Gradient();
            X_gradient.SetValue(m_temp1->ColumnSlice(m_minibatchSize * m_rank, m_minibatchSize));
        }
    }

    virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
    {
        auto& W = InputRef(1).Value();
        auto& X = InputRef(2).Value();
        W.VectorNorm2(*m_WNorm, true);
        W.RowElementDivideBy(*m_WNorm);

        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
            DistributedGatheredLabels<ElemType>::gatherDistributedLabels(InputRef(0).Value());
        m_distGradAggPtr->DistributedAllGather(X, *m_temp1, m_inputDim * m_minibatchSize);
        Matrix<ElemType>::Multiply(W, true, *m_temp1, false, Value());
        if (Environment().IsTraining())
        {
            m_flag->SetValue(0);
            Matrix<ElemType>::DistributedArcLabelAdd(*DistributedGatheredLabels<ElemType>::m_gatheredLabels, (ElemType)m_threshold, (ElemType)m_bias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Value(), m_outputDim * m_rank, m_outputDim * (m_rank + 1) - 1);
        }
        Matrix<ElemType>::Scale((ElemType)m_scale, Value());
    }

    virtual bool OutputUsedInComputingInputNodesGradients() const override
    {
        return false;
    }
    virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
    {
        if (0 == childIndex) // for labels
            return false;
        else if (1 == childIndex) // for W
            return true;
        else // for X
            return false;
    }

    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
    {
        Base::Validate(isFinalValidationPass);

        InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
        m_inputDim = InputRef(1).Value().GetNumRows();
        m_outputDim = InputRef(1).Value().GetNumCols();
        SetDims(TensorShape(m_outputDim), HasMBLayout());

        DistributedGatheredLabels<ElemType>::setInitializeNode(this);
    }

    virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
    {
        Base::CopyTo(nodeP, newName, flags);
        if (flags & CopyNodeFlags::copyNodeValue)
        {
            auto node = dynamic_pointer_cast<DistributedArcMarginProductNode<ElemType>>(nodeP);
            node->m_rank = m_rank;
            node->m_processNum = m_processNum;
            node->m_inputDim = m_inputDim;
            node->m_outputDim = m_outputDim;
            node->m_minibatchSize = m_minibatchSize;
            node->m_batchSize = m_batchSize;
            node->m_bias = m_bias;
            node->m_threshold = m_threshold;
            node->m_cosBias = m_cosBias;
            node->m_sinBias = m_sinBias;
            node->m_scale = m_scale;
            node->m_distGradAggPtr = m_distGradAggPtr;
            node->m_temp1->SetValue(*m_temp1);
            node->m_tempValue->SetValue(*m_tempValue);
            node->m_flag->SetValue(*m_flag);
        }
    }

    virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
    {
        Base::RequestMatricesBeforeForwardProp(matrixPool);
        RequestMatrixFromPool(m_WNorm, matrixPool);
        RequestMatrixFromPool(m_temp1, matrixPool, m_inputDim * m_processNum, true);
        RequestMatrixFromPool(m_flag, matrixPool, m_inputDim * m_processNum, true);
        RequestMatrixFromPool(m_tempValue, matrixPool, m_inputDim * m_processNum, true);
        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
        {
            RequestMatrixFromPool(DistributedGatheredLabels<ElemType>::m_gatheredLabels, matrixPool, m_processNum, true);
            RequestMatrixFromPool(DistributedGatheredLabels<ElemType>::m_labels, matrixPool, 1, true);
        }
    }

    virtual void ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool)
    {
        Base::ReleaseMatricesAfterForwardProp(matrixPool);
        ReleaseMatrixToPool(m_WNorm, matrixPool);
        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
            ReleaseMatrixToPool(DistributedGatheredLabels<ElemType>::m_labels, matrixPool);
    }

    virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
        ReleaseMatrixToPool(m_temp1, matrixPool);
        ReleaseMatrixToPool(m_flag, matrixPool);
        ReleaseMatrixToPool(m_tempValue, matrixPool);
        if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
            ReleaseMatrixToPool(DistributedGatheredLabels<ElemType>::m_gatheredLabels, matrixPool);
    }

    void Save(File& fstream) const override
    {
        Base::Save(fstream);
        fstream << m_bias << m_scale;
    }

    void Load(File& fstream, size_t modelVersion) override
    {
        Base::Load(fstream, modelVersion);
        fstream >> m_bias >> m_scale;

        if (m_bias < 0 || m_bias >= m_PI)
            LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
        m_threshold = cos(m_PI - m_bias);
        m_cosBias = cos(m_bias);
        m_sinBias = sin(m_bias);
    }

    size_t m_rank;
    size_t m_processNum;
    size_t m_inputDim;
    size_t m_outputDim;
    size_t m_minibatchSize;
    size_t m_batchSize;
    double m_bias;
    double m_threshold;
    double m_cosBias;
    double m_sinBias;
    double m_scale;
    IDistGradAggregator<ElemType>* m_distGradAggPtr;
    shared_ptr<Matrix<ElemType>> m_temp1;
    shared_ptr<Matrix<ElemType>> m_tempValue; // Matrix(1, m_batchSize)
    shared_ptr<Matrix<ElemType>> m_flag; // Matrix(1, m_batchSize)
    shared_ptr<Matrix<ElemType>> m_WNorm;
};

// Implements A-Softmax as described in:
// SphereFace: DeepHypersphereEmbeddingforFaceRecognition [Weiyang Liu, Yandong Wen, Zhiding Yu, Ming Li, Bhiksha Raj, Le Song]
// https://arxiv.org/abs/1704.08063
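A compact summary of the data layout the node above implements (our notation, read off the forward pass; nothing here goes beyond the code): with P MPI ranks, per-rank minibatch b, feature dimension inputDim, and outputDim classes owned by each rank, rank r holds a weight shard W_r of size inputDim × outputDim (columns unit-normalized in place), all-gathers the per-rank feature blocks X_0 … X_{P-1} (each inputDim × b) into m_temp1, and writes

$$\mathrm{Value} = s \cdot W_r^{\top}\,[\,X_0 \;|\; \cdots \;|\; X_{P-1}\,] \in \mathbb{R}^{\,\mathrm{outputDim} \times bP},$$

i.e. only its own slice of the class dimension over the full gathered batch (m_batchSize = b·P), with m = m_bias and s = m_scale. The angular margin is applied inside Matrix::DistributedArcLabelAdd only to samples whose gathered label falls in this rank's class range [r·outputDim, (r+1)·outputDim); the exact handling of the threshold case cos θ_y < cos(π − m) lives in that kernel and is not shown in this diff. This decomposition is also why the ComputationNode memory hooks earlier in this change multiply the node's matrix sizes by Globals::GetProcessNum().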
@@ -1146,9 +1365,6 @@ public:
    shared_ptr<Matrix<ElemType>> m_tempMatrix; // Matrix(k,m)
};

// Implements AM-Softmax as described in:
// Additive Margin Softmax for Face Verification [Feng Wang, Weiyang Liu, Haijun Liu, Jian Cheng]
// https://arxiv.org/abs/1801.05599
template <class ElemType>
class FeatureNormalizeNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<1>
{

@@ -1269,6 +1485,9 @@ public:
    shared_ptr<Matrix<ElemType>> m_temp1; // Matrix(1, minibatchSize)
};

// Implements AM-Softmax as described in:
// Additive Margin Softmax for Face Verification [Feng Wang, Weiyang Liu, Haijun Liu, Jian Cheng]
// https://arxiv.org/abs/1801.05599
template <class ElemType>
class AdditiveFullConnectionNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
{
@ -1461,6 +1680,175 @@ public:
|
|||
shared_ptr<Matrix<ElemType>> m_weightMagnitude; // Matrix(k,1)
|
||||
};
|
||||
|
||||
// Implements additive angular margin product as described in:
|
||||
// ArcFace: Additive Angular Margin Loss for Deep Face Recognition [Jiankang Deng, Jia Guo, Niannan Xue, Stefanos Zafeiriou]
|
||||
// https://arxiv.org/abs/1801.07698
|
||||
template <class ElemType>
|
||||
class ArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
|
||||
{
|
||||
typedef ComputationNodeNonLooping<ElemType> Base;
|
||||
UsingComputationNodeMembersBoilerplate;
|
||||
static const std::wstring TypeName()
|
||||
{
|
||||
return L"ArcMarginProduct";
|
||||
}
|
||||
|
||||
public:
|
||||
ArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
|
||||
: ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"))
|
||||
{
|
||||
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
|
||||
}
|
||||
|
||||
ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0)
|
||||
: Base(deviceId, name), m_bias(bias)
|
||||
{
|
||||
if (m_bias < 0 || m_bias >= m_PI)
|
||||
LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
|
||||
m_threshold = cos(m_PI - m_bias);
|
||||
m_cosBias = cos(m_bias);
|
||||
m_sinBias = sin(m_bias);
|
||||
}
|
||||
|
||||
virtual void UpdateFunctionMBSize() override
|
||||
{
|
||||
m_minibatchSize = InputRef(0).Value().GetNumCols();
|
||||
m_label->Resize(1, m_minibatchSize);
|
||||
m_tempValue->Resize(1, m_minibatchSize);
|
||||
m_flag->Resize(1, m_minibatchSize);
|
||||
m_weightMagnitude->Resize(InputRef(2).Value().GetNumRows(), 1);
|
||||
}
|
||||
|
||||
virtual void BackpropToNonLooping(size_t inputIndex) override
|
||||
{
|
||||
if (1 == inputIndex)
|
||||
{
|
||||
Matrix<ElemType>::ArcLabelAddBackprop(*m_label, (ElemType)m_cosBias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Gradient());
|
||||
|
||||
FrameRange fr(InputRef(1).GetMBLayout());
|
||||
auto X_gradient = InputRef(1).GradientFor(fr);
|
||||
auto& weight = InputRef(2).Value();
|
||||
Matrix<ElemType>::Multiply(weight, true, Gradient(), false, X_gradient);
|
||||
}
|
||||
else if (2 == inputIndex)
|
||||
{
|
||||
            FrameRange fr(InputRef(1).GetMBLayout());
            auto X = InputRef(1).ValueFor(fr);
            auto& weightGradient = InputRef(2).Gradient();
            Matrix<ElemType>::Multiply(Gradient(), false, X, true, weightGradient);
        }
    }

    virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
    {
        FrameRange fr(InputRef(0).GetMBLayout());
        InputRef(0).MaskedValueFor(fr).VectorMax(*m_label, *m_tempValue, true /*isColWise*/);
        auto X = InputRef(1).ValueFor(fr);
        auto& weight = InputRef(2).Value();

        weight.VectorNorm2(*m_weightMagnitude, false);
        weight.ColumnElementDivideBy(*m_weightMagnitude);

        Matrix<ElemType>::Multiply(weight, false, X, false, Value());

        if (Environment().IsTraining())
        {
            m_flag->SetValue(0);
            Matrix<ElemType>::ArcLabelAdd(*m_label, (ElemType)m_threshold, (ElemType)m_bias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Value());
        }
    }

    virtual bool OutputUsedInComputingInputNodesGradients() const override
    {
        return false;
    }
    virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
    {
        if (0 == childIndex)
            return false;
        return true;
    }

    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
    {
        Base::Validate(isFinalValidationPass);
        InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);

        SetDims(TensorShape(InputRef(2).Value().GetNumRows()), HasMBLayout());
    }

    virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
    {
        Base::CopyTo(nodeP, newName, flags);
        if (flags & CopyNodeFlags::copyNodeValue)
        {
            auto node = dynamic_pointer_cast<ArcMarginProductNode<ElemType>>(nodeP);
            node->m_minibatchSize = m_minibatchSize;
            node->m_bias = m_bias;
            node->m_threshold = m_threshold;
            node->m_cosBias = m_cosBias;
            node->m_sinBias = m_sinBias;
            node->m_label->SetValue(*m_label);
            node->m_tempValue->SetValue(*m_tempValue);
            node->m_flag->SetValue(*m_flag);
            node->m_weightMagnitude->SetValue(*m_weightMagnitude);
        }
    }

    // request matrices needed to do node function value evaluation
    virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
    {
        Base::RequestMatricesBeforeForwardProp(matrixPool);
        RequestMatrixFromPool(m_label, matrixPool, 1, true, true, false);
        RequestMatrixFromPool(m_tempValue, matrixPool, 1, true, true, false);
        RequestMatrixFromPool(m_flag, matrixPool, 1, true, true, false);
        RequestMatrixFromPool(m_weightMagnitude, matrixPool);
    }

    virtual void ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool)
    {
        Base::ReleaseMatricesAfterForwardProp(matrixPool);
        ReleaseMatrixToPool(m_weightMagnitude, matrixPool);
    }

    virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
        ReleaseMatrixToPool(m_label, matrixPool);
        ReleaseMatrixToPool(m_tempValue, matrixPool);
        ReleaseMatrixToPool(m_flag, matrixPool);
    }

    void Save(File& fstream) const override
    {
        Base::Save(fstream);
        fstream << m_bias;
    }

    void Load(File& fstream, size_t modelVersion) override
    {
        Base::Load(fstream, modelVersion);
        fstream >> m_bias;

        if (m_bias < 0 || m_bias >= m_PI)
            LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
        m_threshold = cos(m_PI - m_bias);
        m_cosBias = cos(m_bias);
        m_sinBias = sin(m_bias);
    }

    size_t m_minibatchSize;
    double m_bias;
    double m_threshold;
    double m_cosBias;
    double m_sinBias;

    shared_ptr<Matrix<ElemType>> m_label;
    shared_ptr<Matrix<ElemType>> m_tempValue;
    shared_ptr<Matrix<ElemType>> m_flag;
    shared_ptr<Matrix<ElemType>> m_weightMagnitude;
};
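Editor's sketch of the margin arithmetic used by ForwardPropNonLooping/ArcLabelAdd, written as equations. This assumes the incoming features are already L2-normalized upstream so each target-class logit is a cosine \(\cos\theta\) (the node itself normalizes only the weight columns); \(m\) stands for m_bias, so m_threshold = \(\cos(\pi - m)\), m_cosBias = \(\cos m\), m_sinBias = \(\sin m\):

    \cos\theta \;\mapsto\; \cos(\theta + m) = \cos\theta\cos m - \sin\theta\sin m   \quad\text{if } \cos\theta > \cos(\pi - m)
    \cos\theta \;\mapsto\; \cos\theta - m\sin m                                     \quad\text{otherwise (m_flag set, margin gradient skipped)}

For samples that took the margin branch, backprop scales the target-logit gradient by

    \frac{d\,\cos(\theta + m)}{d\,\cos\theta} \;=\; \cos m + \sin m \cdot \frac{\cos\theta}{\sqrt{1 - \cos^2\theta}},

which is the cosBias + sinBias * x / (sqrt(1 - x*x) + 1e-12) factor in ArcLabelAddBackprop; x caches the pre-margin cosine and the 1e-12 term guards against division by zero.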

// Implements Center-Loss as described in:
// A Discriminative Feature Learning Approach for Deep Face Recognition [Yandong Wen, Kaipeng Zhang, Zhifeng Li, Yu Qiao]
// https://ydwen.github.io/papers/WenECCV16.pdf

@ -1625,145 +2013,6 @@ public:
    bool initFlag = true;
};

// Implements Squeeze-and-Excitation operation as described in:
// Squeeze-and-Excitation Networks [Jie Hu, Li Shen, Gang Sun]
// https://arxiv.org/pdf/1709.01507.pdf
template <class ElemType>
class ChannelMultiplyNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<2>
{
    typedef ComputationNodeNonLooping<ElemType> Base;
    UsingComputationNodeMembersBoilerplate;
    static const std::wstring TypeName()
    {
        return L"ChannelMultiply";
    }

public:
    ChannelMultiplyNode(const ScriptableObjects::IConfigRecordPtr configp)
        : ChannelMultiplyNode(configp->Get(L"deviceId"), L"<placeholder>")
    {
        AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
    }

    ChannelMultiplyNode(DEVICEID_TYPE deviceId, const wstring& name)
        : Base(deviceId, name)
    {
    }

    virtual void BackpropToNonLooping(size_t inputIndex) override
    {
        if (0 == inputIndex)
        {
            FrameRange fr(InputRef(0).GetMBLayout());
            auto X_gradient = InputRef(0).GradientFor(fr);
            auto weight = InputRef(1).ValueFor(fr);

            Matrix<ElemType>::ChannelMultiply(Gradient(), weight, X_gradient, m_featureSize);
        }
        else
        {
            FrameRange fr(InputRef(0).GetMBLayout());
            auto X = InputRef(0).ValueFor(fr);
            auto weight_gradient = InputRef(1).GradientFor(fr);
            weight_gradient.SetValue((ElemType)0);
            m_buffer->Resize(m_featureSize * m_channels, X.GetNumCols());

            Matrix<ElemType>::ChannelMultiplyScaleBackprop(Gradient(), X, weight_gradient, *m_buffer, m_featureSize, m_N);
        }
    }

    virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
    {
        FrameRange fr(InputRef(0).GetMBLayout());
        auto X = InputRef(0).ValueFor(fr);
        auto weight = InputRef(1).ValueFor(fr);

        Matrix<ElemType>::ChannelMultiply(X, weight, Value(), m_featureSize);
    }

    virtual bool OutputUsedInComputingInputNodesGradients() const override
    {
        return false;
    }
    virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override
    {
        return true;
    }

    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
    {
        Base::Validate(isFinalValidationPass);
        InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
        SetDims(Input(0));

        auto dims0 = Input(0)->GetSampleLayout().GetDims();
        auto dims1 = Input(1)->GetSampleLayout().GetDims();
        if (dims0.size() != 3)
            LogicError("ChannelMultiplyNode : input[0] dimension not equals to 3 \n");
        size_t temp = 1;
        for (size_t i(0); i < dims1.size(); ++i)
            temp *= dims1[i];
        if (dims0[2] != temp)
            LogicError("ChannelMultiplyNode : input channel not match %d v.s. %d\n", (int)dims0[2], (int)temp);

        m_featureSize = dims0[0] * dims0[1];
        m_channels = dims0[2];

        size_t featureSize = m_featureSize;
        m_N = 1;
        assert(featureSize != 0);
        while (featureSize)
        {
            m_N *= 2;
            featureSize /= 2;
        }
        m_N /= 2;
    }

    virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
    {
        Base::CopyTo(nodeP, newName, flags);
        if (flags & CopyNodeFlags::copyNodeValue)
        {
            auto node = dynamic_pointer_cast<ChannelMultiplyNode<ElemType>>(nodeP);
            node->m_buffer->SetValue(*m_buffer);
            node->m_featureSize = m_featureSize;
            node->m_channels = m_channels;
            node->m_N = m_N;
        }
    }

    virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
    {
        Base::RequestMatricesBeforeBackprop(matrixPool);
        RequestMatrixFromPool(m_buffer, matrixPool, m_featureSize * m_channels, true, false, false);
    }

    virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
        ReleaseMatrixToPool(m_buffer, matrixPool);
    }

    void Save(File& fstream) const override
    {
        Base::Save(fstream);
        fstream << m_featureSize << m_channels;
    }

    void Load(File& fstream, size_t modelVersion) override
    {
        Base::Load(fstream, modelVersion);
        fstream >> m_featureSize >> m_channels;
    }

    shared_ptr<Matrix<ElemType>> m_buffer;
    size_t m_featureSize;
    size_t m_channels;
    size_t m_N;
};
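For reference, the element mapping ChannelMultiply applies is small enough to sketch outside the Matrix wrappers. The following standalone function is illustrative only (the name, float element type, and column-major layout are assumptions of this sketch, not part of the commit): each column is one sample whose rows are grouped channel-major, so element i belongs to channel i / featureSize and is scaled by that channel's weight, i.e. value(i, j) = X(i, j) * weight(i / featureSize, j).

#include <cstddef>
#include <vector>

// Hypothetical sketch of the per-channel (Squeeze-and-Excitation) scaling.
// X:      (featureSize * channels) x cols, column-major
// weight:  channels x cols, column-major (one scale per channel per sample)
std::vector<float> ChannelMultiplySketch(const std::vector<float>& X,
                                         const std::vector<float>& weight,
                                         std::size_t featureSize, std::size_t channels, std::size_t cols)
{
    std::vector<float> value(X.size());
    const std::size_t rows = featureSize * channels; // rows per sample
    for (std::size_t j = 0; j < cols; ++j)           // one column per sample
        for (std::size_t i = 0; i < rows; ++i)       // element i belongs to channel i / featureSize
            value[j * rows + i] = X[j * rows + i] * weight[j * channels + i / featureSize];
    return value;
}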


// -----------------------------------------------------------------------
// SquareErrorNode (left, right)

@ -472,17 +472,17 @@ public:

#pragma endregion

#pragma region CenterLoss
#pragma region ArcMarginProduct

    static void ClassCount(const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& counter);
    static void ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value);

    static void ArcLabelAddBackprop(const CPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient);

#pragma endregion

#pragma region SqueezeAndExcitation
#pragma region CenterLoss

    static void ChannelMultiply(const CPUMatrix<ElemType>& X, const CPUMatrix<ElemType>& weight, CPUMatrix<ElemType>& value, size_t featureSize);

    static void ChannelMultiplyScaleBackprop(const CPUMatrix<ElemType>& gradient, const CPUMatrix<ElemType>& X, CPUMatrix<ElemType>& weight_gradient, size_t featureSize);
    static void ClassCount(const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& counter);

#pragma endregion

@ -518,6 +518,10 @@ public:

    static void DistributedLabelAdd(const CPUMatrix<ElemType>& labels, ElemType bias, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAdd(const CPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAddBackprop(const CPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex);

#pragma endregion

@ -5430,6 +5430,61 @@ void CPUMatrix<ElemType>::LabelAdd(const CPUMatrix<ElemType>& label, ElemType bi

#pragma endregion

#pragma region ArcMarginProduct

template <class ElemType>
void CPUMatrix<ElemType>::ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value)
{
    size_t minibatchSize = value.GetNumCols();
    size_t outputDimension = value.GetNumRows();
    size_t labelValue;
    ElemType* labelPtr = label.Data();
    ElemType* flagPtr = flag.Data();
    ElemType* xPtr = x.Data();
    ElemType* valuePtr = value.Data();

    for (size_t i(0); i < minibatchSize; ++i)
    {
        labelValue = static_cast<size_t>(labelPtr[i]);
        size_t index = i * outputDimension + labelValue;

        if (valuePtr[index] > threshold)
        {
            xPtr[i] = valuePtr[index];
            valuePtr[index] = cos(acos(valuePtr[index]) + bias);
        }
        else
        {
            valuePtr[index] -= bias * sinBias;
            flagPtr[i] = 1.0f;
        }
    }
}

template <class ElemType>
void CPUMatrix<ElemType>::ArcLabelAddBackprop(const CPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient)
{
    size_t minibatchSize = gradient.GetNumCols();
    size_t outputDimension = gradient.GetNumRows();
    size_t labelValue;
    ElemType* labelPtr = label.Data();
    ElemType* flagPtr = flag.Data();
    ElemType* xPtr = x.Data();
    ElemType* gradientPtr = gradient.Data();

    for (size_t i(0); i < minibatchSize; ++i)
    {
        if (flagPtr[i] < 0.5f)
        {
            labelValue = static_cast<size_t>(labelPtr[i]);
            size_t index = i * outputDimension + labelValue;
            gradientPtr[index] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
        }
    }
}

#pragma endregion

#pragma region CenterLoss

template <class ElemType>

@ -5457,57 +5512,6 @@ void CPUMatrix<ElemType>::ClassCount(const CPUMatrix<ElemType>& label, const CPU

#pragma endregion

#pragma region SqueezeAndExcitation

template <class ElemType>
void CPUMatrix<ElemType>::ChannelMultiply(const CPUMatrix<ElemType>& X, const CPUMatrix<ElemType>& weight, CPUMatrix<ElemType>& value, size_t featureSize)
{
    long n = (long)X.GetNumCols();
    long m = (long)X.GetNumRows();

#pragma omp parallel for
    for (long j = 0; j < n; j++)
    {
        // four-way unrolling
        for (long i = 0; i < (m & ~3); i += 4)
        {
            value(i, j) = X(i, j) * weight(i / featureSize, j);
            value(i + 1, j) = X(i + 1, j) * weight((i + 1) / featureSize, j);
            value(i + 2, j) = X(i + 2, j) * weight((i + 2) / featureSize, j);
            value(i + 3, j) = X(i + 3, j) * weight((i + 3) / featureSize, j);
        }
        // handle remaining stuffs
        for (long i = m & ~3; i < m; i++)
        {
            value(i, j) = X(i, j) * weight(i / featureSize, j);
        }
    }
}

template <class ElemType>
void CPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const CPUMatrix<ElemType>& gradient, const CPUMatrix<ElemType>& X, CPUMatrix<ElemType>& weight_gradient, size_t featureSize)
{
    long n = (long)X.GetNumCols();
    long m = (long)X.GetNumRows();

#pragma omp parallel for
    for (long j = 0; j < n; j++)
    {
        // four-way unrolling
        for (long i = 0; i < (m & ~3); i += 4)
        {
            weight_gradient(i / featureSize, j) += gradient(i, j) * X(i, j);
        }
        // handle remaining stuffs
        for (long i = m & ~3; i < m; i++)
        {
            weight_gradient(i / featureSize, j) += gradient(i, j) * X(i, j);
        }
    }
}

#pragma endregion

#pragma region LabelSmoothing

template <class ElemType>

@ -5829,6 +5833,54 @@ void CPUMatrix<ElemType>::DistributedLabelAdd(const CPUMatrix<ElemType>& labels,
    }
}

template <class ElemType>
void CPUMatrix<ElemType>::DistributedArcLabelAdd(const CPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
    long cols = (long)value.GetNumCols();
    long rows = (long)value.GetNumRows();
    ElemType* labelsPtr = labels.Data();
    ElemType* flagPtr = flag.Data();
    ElemType* xPtr = x.Data();
    ElemType* valuePtr = value.Data();

    for (long i = 0; i < cols; ++i) // one label per minibatch column
    {
        long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex)
        {
            if (valuePtr[index] > threshold)
            {
                xPtr[i] = valuePtr[index];
                valuePtr[index] = cos(acos(valuePtr[index]) + bias);
            }
            else
            {
                valuePtr[index] -= bias * sinBias;
                flagPtr[i] = 1.0f;
            }
        }
    }
}

template <class ElemType>
void CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const CPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
    long cols = (long)gradient.GetNumCols();
    long rows = (long)gradient.GetNumRows();
    ElemType* labelsPtr = labels.Data();
    ElemType* flagPtr = flag.Data();
    ElemType* xPtr = x.Data();
    ElemType* gradientPtr = gradient.Data();

    for (long i = 0; i < cols; ++i) // one label per minibatch column
    {
        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i] < 0.5f)
        {
            gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
        }
    }
}

#pragma endregion
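To make the distributed index arithmetic concrete: each rank owns the output-class slice [startIndex, endIndex] (the checks treat endIndex as inclusive), stored as local rows 0 .. endIndex - startIndex of its value/gradient matrix, so for minibatch column i the target logit sits at

    index = i \cdot rows + (\mathrm{label}_i - \mathrm{startIndex}).

Illustrative numbers (not from the commit): with 10,000 classes per rank, the rank owning startIndex = 10000, endIndex = 19999 adjusts a sample labelled 12345 at local row 2345; on every other rank that label falls outside [startIndex, endIndex] and the column is left untouched.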

@ -3765,14 +3765,14 @@ template <class ElemType>
__global__ void _labelAdd(CUDA_LONG outputDimension, const ElemType* label, ElemType bias, ElemType* value, CUDA_LONG numElements)
{
    CUDA_LONG id = GridDim::GetLinearThreadId();
    if (id < numElements)
    {
        CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
        CUDA_LONG index = id * outputDimension + labelValue;
        if (value[index] <= -bias)
            return;
        value[index] += bias;
    }
    if (id >= numElements)
        return;

    CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
    CUDA_LONG index = id * outputDimension + labelValue;
    if (value[index] <= -bias)
        return;
    value[index] += bias;
}

template <class ElemType>

@ -3789,6 +3789,70 @@ void GPUMatrix<ElemType>::LabelAdd(const GPUMatrix<ElemType>& label, ElemType bi

#pragma endregion

#pragma region CenterLoss

template <class ElemType>
__global__ void _arcLabelAdd(CUDA_LONG outputDimension, const ElemType* label, ElemType threshold, ElemType bias, ElemType sinBias, ElemType* flag, ElemType* x, ElemType* value, CUDA_LONG numElements)
{
    CUDA_LONG id = GridDim::GetLinearThreadId();
    if (id >= numElements)
        return;

    CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
    CUDA_LONG index = id * outputDimension + labelValue;

    if (value[index] > threshold)
    {
        x[id] = value[index];
        value[index] = cosf(acosf(value[index]) + bias);
    }
    else
    {
        value[index] -= bias * sinBias;
        flag[id] = 1.0f;
    }
}

template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value)
{
    CUDA_LONG minibatchSize = (CUDA_LONG)value.GetNumCols();
    CUDA_LONG outputDimension = (CUDA_LONG)value.GetNumRows();

    int blocksPerGrid = (int)ceil(1.0 * minibatchSize / GridDim::maxThreadsPerBlock);
    label.PrepareDevice();
    SyncGuard syncGuard;
    _arcLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (outputDimension, label.Data(), threshold, bias, sinBias, flag.Data(), x.Data(), value.Data(), minibatchSize);
}

template <class ElemType>
__global__ void _arcLabelAddBackprop(CUDA_LONG outputDimension, const ElemType* label, ElemType cosBias, ElemType sinBias, ElemType* flag, ElemType* x, ElemType* gradient, CUDA_LONG numElements)
{
    CUDA_LONG id = GridDim::GetLinearThreadId();
    if (id >= numElements)
        return;
    if (flag[id] > 0.5f)
        return;

    CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
    CUDA_LONG index = id * outputDimension + labelValue;

    gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}

template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient)
{
    CUDA_LONG minibatchSize = (CUDA_LONG)gradient.GetNumCols();
    CUDA_LONG outputDimension = (CUDA_LONG)gradient.GetNumRows();

    int blocksPerGrid = (int)ceil(1.0 * minibatchSize / GridDim::maxThreadsPerBlock);
    label.PrepareDevice();
    SyncGuard syncGuard;
    _arcLabelAddBackprop<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (outputDimension, label.Data(), cosBias, sinBias, flag.Data(), x.Data(), gradient.Data(), minibatchSize);
}

#pragma endregion

#pragma region CenterLoss

@ -3818,68 +3882,6 @@ void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPU

#pragma endregion

#pragma region SqueezeAndExcitation

template <class ElemType>
__global__ void _channelMultiply(const ElemType* X, const ElemType* weight, ElemType* value, CUDA_LONG featureSize, const CUDA_LONG numElements)
{
    CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
    if (id >= numElements)
        return;

    value[id] = X[id] * weight[id / featureSize];
}

template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize)
{
    CUDA_LONG numElements = X.GetNumElements();

    SyncGuard syncGuard;
    GridDim grid(numElements);
    _channelMultiply<ElemType> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(X.Data(), weight.Data(), value.Data(), (CUDA_LONG)featureSize, numElements);
}

template <class ElemType>
__global__ void _channelMultiplyScaleBackprop(const ElemType* gradient, const ElemType* X, ElemType* weight_gradient, ElemType* buffer, CUDA_LONG featureSize, CUDA_LONG N, const CUDA_LONG numElements)
{
    CUDA_LONG id = GridDim::GetLinearThreadId();
    if (id < numElements)
    {
        CUDA_LONG i = id / featureSize; // Channel i
        CUDA_LONG j = id % featureSize; // Element j

        if (j < N)
            buffer[id] = gradient[id] * X[id];
        __syncthreads();
        if (j >= N)
            buffer[id - N] += gradient[id] * X[id];
        __syncthreads();

        for (CUDA_LONG k = N >> 1; k > 0; k >>= 1)
        {
            if (j < k)
                buffer[id] += buffer[id + k];
            __syncthreads();
        }

        if (0 == j)
            weight_gradient[i] = buffer[id];
    }
}

template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N)
{
    CUDA_LONG numElements = gradient.GetNumElements();

    SyncGuard syncGuard;
    GridDim grid(numElements);
    _channelMultiplyScaleBackprop<ElemType> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(gradient.Data(), X.Data(), weight_gradient.Data(), buffer.Data(), (CUDA_LONG)featureSize, (CUDA_LONG)N, numElements);
}

#pragma endregion
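What the _channelMultiplyScaleBackprop reduction computes, written out (editor's reading, not part of the commit): the per-channel scale gradient for sample j is the sum of gradient-times-input over that channel's featureSize elements,

    \frac{\partial L}{\partial w_{c,j}} \;=\; \sum_{k=0}^{\mathrm{featureSize}-1} \mathrm{gradient}(c\cdot\mathrm{featureSize}+k,\; j)\,\cdot\, X(c\cdot\mathrm{featureSize}+k,\; j).

N is set in ChannelMultiplyNode::Validate to the largest power of two not exceeding featureSize; threads with in-channel offset j >= N first fold their product into slot j - N of the buffer, the loop then halves k in a binary-tree reduction over the first N slots, and the j == 0 thread writes the channel total into weight_gradient. Since __syncthreads() synchronizes only within a thread block, this appears to assume that a channel's featureSize elements do not straddle block boundaries.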

#pragma region LabelSmoothing

template <class ElemType>

@ -4318,9 +4320,7 @@ __global__ void _distributedLabelAdd(const ElemType* labels, ElemType bias, Elem
        return;

    CUDA_LONG label = (CUDA_LONG)labels[id];
    if (label < startIndex)
        return;
    if (label > endIndex)
    if (label < startIndex || label > endIndex)
        return;
    if (value[id * rows + label - startIndex] <= -bias)
        return;

@ -4340,6 +4340,69 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
    _distributedLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), bias, value.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}

template <class ElemType>
__global__ void _distributedArcLabelAdd(const ElemType* labels, const ElemType threshold, const ElemType bias, const ElemType sinBias, ElemType* flag, ElemType* x, ElemType* value, CUDA_LONG rows, CUDA_LONG startIndex, CUDA_LONG endIndex, CUDA_LONG numElements)
{
    CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
    if (id >= numElements)
        return;

    CUDA_LONG label = (CUDA_LONG)labels[id];
    if (label < startIndex || label > endIndex)
        return;

    CUDA_LONG index = id * rows + label - startIndex;
    if (value[index] > threshold)
    {
        x[id] = value[index];
        value[index] = cosf(acosf(value[index]) + bias);
    }
    else
    {
        value[index] -= bias * sinBias;
        flag[id] = 1.0f;
    }
}

template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
    CUDA_LONG cols = value.GetNumCols();
    CUDA_LONG rows = value.GetNumRows();

    int blocksPerGrid = (int)ceil(1.0 * cols / GridDim::maxThreadsPerBlock);
    labels.PrepareDevice();
    SyncGuard syncGuard;
    _distributedArcLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), threshold, bias, sinBias, flag.Data(), x.Data(), value.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}

template <class ElemType>
__global__ void _distributedArcLabelAddBackprop(const ElemType* labels, const ElemType cosBias, const ElemType sinBias, ElemType* flag, ElemType* x, ElemType* gradient, CUDA_LONG rows, CUDA_LONG startIndex, CUDA_LONG endIndex, CUDA_LONG numElements)
{
    CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
    if (id >= numElements)
        return;

    CUDA_LONG label = (CUDA_LONG)labels[id];
    if (label < startIndex || label > endIndex || flag[id] > 0.5f)
        return;

    CUDA_LONG index = id * rows + label - startIndex;
    gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}

template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
    CUDA_LONG cols = gradient.GetNumCols();
    CUDA_LONG rows = gradient.GetNumRows();

    int blocksPerGrid = (int)ceil(1.0 * cols / GridDim::maxThreadsPerBlock);
    labels.PrepareDevice();
    SyncGuard syncGuard;
    _distributedArcLabelAddBackprop<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), cosBias, sinBias, flag.Data(), x.Data(), gradient.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}

#pragma endregion

@ -588,17 +588,17 @@ public:

#pragma endregion

#pragma region CenterLoss
#pragma region ArcMarginProduct

    static void ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter);
    static void ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value);

    static void ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient);

#pragma endregion

#pragma region SqueezeAndExcitation
#pragma region CenterLoss

    static void ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize);

    static void ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N);
    static void ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter);

#pragma endregion

@ -634,6 +634,10 @@ public:

    static void DistributedLabelAdd(const GPUMatrix<ElemType>& labels, ElemType bias, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex);

#pragma endregion

@ -5203,6 +5203,32 @@ template <class ElemType>

#pragma endregion

#pragma region ArcMarginProduct

template <class ElemType>
/*static*/ void Matrix<ElemType>::ArcLabelAdd(const Matrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value)
{
    DISPATCH_MATRIX_ON_FLAG(&value,
                            &value,
                            CPUMatrix<ElemType>::ArcLabelAdd(*(label.m_CPUMatrix), threshold, bias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(value.m_CPUMatrix)),
                            GPUMatrix<ElemType>::ArcLabelAdd(*(label.m_GPUMatrix), threshold, bias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(value.m_GPUMatrix)),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

template <class ElemType>
/*static*/ void Matrix<ElemType>::ArcLabelAddBackprop(const Matrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient)
{
    DISPATCH_MATRIX_ON_FLAG(&gradient,
                            &gradient,
                            CPUMatrix<ElemType>::ArcLabelAddBackprop(*(label.m_CPUMatrix), cosBias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(gradient.m_CPUMatrix)),
                            GPUMatrix<ElemType>::ArcLabelAddBackprop(*(label.m_GPUMatrix), cosBias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(gradient.m_GPUMatrix)),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

#pragma endregion

#pragma region CenterLoss

template <class ElemType>

@ -5218,32 +5244,6 @@ template <class ElemType>

#pragma endregion

#pragma region SqueezeAndExcitation

template <class ElemType>
/*static*/ void Matrix<ElemType>::ChannelMultiply(const Matrix<ElemType>& X, const Matrix<ElemType>& weight, const Matrix<ElemType>& value, size_t featureSize)
{
    DISPATCH_MATRIX_ON_FLAG(&value,
                            &value,
                            CPUMatrix<ElemType>::ChannelMultiply(*(X.m_CPUMatrix), *(weight.m_CPUMatrix), *(value.m_CPUMatrix), featureSize),
                            GPUMatrix<ElemType>::ChannelMultiply(*(X.m_GPUMatrix), *(weight.m_GPUMatrix), *(value.m_GPUMatrix), featureSize),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

template <class ElemType>
/*static*/ void Matrix<ElemType>::ChannelMultiplyScaleBackprop(const Matrix<ElemType>& gradient, const Matrix<ElemType>& X, const Matrix<ElemType>& weight_gradient, const Matrix<ElemType>& buffer, size_t featureSize, size_t N)
{
    DISPATCH_MATRIX_ON_FLAG(&weight_gradient,
                            &weight_gradient,
                            CPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(*(gradient.m_CPUMatrix), *(X.m_CPUMatrix), *(weight_gradient.m_CPUMatrix), featureSize),
                            GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(*(gradient.m_GPUMatrix), *(X.m_GPUMatrix), *(weight_gradient.m_GPUMatrix), *(buffer.m_GPUMatrix), featureSize, N),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

#pragma endregion

#pragma region LabelSmoothing

template <class ElemType>

@ -5425,6 +5425,29 @@ template <class ElemType>
                            NOT_IMPLEMENTED);
}

template <class ElemType>
/*static*/ void Matrix<ElemType>::DistributedArcLabelAdd(const Matrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
    DISPATCH_MATRIX_ON_FLAG(&value,
                            &value,
                            CPUMatrix<ElemType>::DistributedArcLabelAdd(*(labels.m_CPUMatrix), threshold, bias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(value.m_CPUMatrix), startIndex, endIndex),
                            GPUMatrix<ElemType>::DistributedArcLabelAdd(*(labels.m_GPUMatrix), threshold, bias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(value.m_GPUMatrix), startIndex, endIndex),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

template <class ElemType>
/*static*/ void Matrix<ElemType>::DistributedArcLabelAddBackprop(const Matrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
    DISPATCH_MATRIX_ON_FLAG(&gradient,
                            &gradient,
                            CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(*(labels.m_CPUMatrix), cosBias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(gradient.m_CPUMatrix), startIndex, endIndex),
                            GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(*(labels.m_GPUMatrix), cosBias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(gradient.m_GPUMatrix), startIndex, endIndex),
                            NOT_IMPLEMENTED,
                            NOT_IMPLEMENTED);
}

#pragma endregion

@ -628,17 +628,17 @@ public:

#pragma endregion

#pragma region CenterLoss
#pragma region ArcMarginProduct

    static void ClassCount(const Matrix<ElemType>& label, const Matrix<ElemType>& counter);
    static void ArcLabelAdd(const Matrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value);

    static void ArcLabelAddBackprop(const Matrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient);

#pragma endregion

#pragma region SqueezeAndExcitation
#pragma region CenterLoss

    static void ChannelMultiply(const Matrix<ElemType>& X, const Matrix<ElemType>& weight, const Matrix<ElemType>& value, size_t featureSize);

    static void ChannelMultiplyScaleBackprop(const Matrix<ElemType>& gradient, const Matrix<ElemType>& X, const Matrix<ElemType>& weight_gradient, const Matrix<ElemType>& buffer, size_t featureSize, size_t N);
    static void ClassCount(const Matrix<ElemType>& label, const Matrix<ElemType>& counter);

#pragma endregion

@ -674,6 +674,10 @@ public:

    static void DistributedLabelAdd(const Matrix<ElemType>& labels, ElemType bias, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAdd(const Matrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex);

    static void DistributedArcLabelAddBackprop(const Matrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient, size_t startIndex, size_t endIndex);

#pragma endregion

@ -2166,24 +2166,24 @@ void GPUMatrix<ElemType>::LabelAdd(const GPUMatrix<ElemType>& label, ElemType bi

#pragma endregion

#pragma region CenterLoss
#pragma region ArcMarginProduct

template <class ElemType>
void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter)
void GPUMatrix<ElemType>::ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value)
{
}

template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient)
{
}

#pragma endregion

#pragma region SqueezeAndExcitation
#pragma region CenterLoss

template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize)
{
}

template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N)
void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter)
{
}

@ -2260,6 +2260,16 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
{
}

template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
}

template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
}

#pragma endregion