Merge pull request #3705 from Veason-silverbullet/lewlu/msra-face-dist

Lewlu/msra face dist
Lewei Lu 2019-06-28 21:23:08 +08:00 committed by GitHub
Parents 2a203c6a03 3807c2a14e
Commit 5836d74b6c
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
24 changed files: 875 additions and 428 deletions

View file

@ -81,6 +81,15 @@ DistributedAdditiveFullConnectionLayer {inDim, outDim, weightNormalize=true, bia
apply (labelSequence, outProbVectorSequence) = DistributedAdditiveFullConnection(labelSequence, W, outProbVectorSequence, weightNormalize=weightNormalize, bias=bias, scale=scale)
}.apply
# DistributedArcMarginProductLayer -- create a distributed arc margin product layer
# The output is decomposed. The next layer must deal with the decomposed output tensors
DistributedArcMarginProductLayer {inDim, outDim, bias=0, scale=1, init='glorotUniform', initValueScale=1} =
{
W = ParameterTensor {_ConcatArrays (inDim, outDim), init=init, initValueScale=initValueScale, distribute=true}
apply (labelSequence, outProbVectorSequence) = DistributedArcMarginProduct(labelSequence, W, outProbVectorSequence, bias=bias, scale=scale)
}.apply
# MarginInnerProductLayer -- create a marginInnerProduct projection layer
# Note: outputDimension may describe a tensor as well.
MarginInnerProductLayer {inputDimension, outputDimension, base, gamma, power, lambdaMin, coefficient, init='glorotUniform', initValueScale=1} =
@ -106,6 +115,15 @@ AdditiveFullConnectionLayer {inputDimension, outputDimension, weightNormalize=tr
apply (labelSequence, outProbVectorSequence) = AdditiveFullConnection(labelSequence, outProbVectorSequence, W, outputDimension=outputDimension, weightNormalize=weightNormalize, bias=bias, annealBias=annealBias, biasBase=biasBase, biasGamma=biasGamma, biasPower=biasPower, biasMin=biasMin, biasMax=biasMax)
}.apply
# ArcMarginProductLayer -- create an arc margin product layer
# Note: outputDimension may describe a tensor as well.
ArcMarginProductLayer {inputDimension, outputDimension, bias=0, init='glorotUniform', initValueScale=1} =
{
W = ParameterTensor {_ConcatArrays (outputDimension, inputDimension), init=init, initValueScale=initValueScale}
apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, bias=bias)
}.apply
# GlobalConcatLayer -- create a concat layer, which uses temporary global memory to save the output
# Note: outputDimension may describe a tensor as well.
GlobalConcatLayer {blockIndex, growthRate, segmentIndex, segmentNum} =
@ -460,11 +478,12 @@ DistributedFullyConnected_v2 = CNTK2.DistributedFullyConnected_v2
DistributedCrossEntropyWithSoftmax = CNTK2.DistributedCrossEntropyWithSoftmax
DistributedClassificationError = CNTK2.DistributedClassificationError
DistributedAdditiveFullConnection = CNTK2.DistributedAdditiveFullConnection
DistributedArcMarginProduct = CNTK2.DistributedArcMarginProduct
MarginInnerProduct = CNTK2.MarginInnerProduct
FeatureNormalize = CNTK2.FeatureNormalize
AdditiveFullConnection = CNTK2.AdditiveFullConnection
ArcMarginProduct = CNTK2.ArcMarginProduct
CenterLoss = CNTK2.CenterLoss
ChannelMultiply = CNTK2.ChannelMultiply
GlobalConcat = CNTK2.GlobalConcat
Dropout = CNTK2.Dropout
ElementTimes = CNTK2.ElementTimes
@ -592,6 +611,12 @@ CNTK2 = [
DistributedAdditiveFullConnection(labelSequence, W, outProbVectorSequence, weightNormalize=true, bias=0, scale=1, tag='') =
new ComputationNode [ operation = 'DistributedAdditiveFullConnection' ; inputs = _AsNodes (labelSequence : W : outProbVectorSequence) /*plus the function args*/ ]
#
# Distributed ArcMarginProduct node; the input labels and probabilities must be decomposed
#
DistributedArcMarginProduct(labelSequence, W, outProbVectorSequence, bias=0, scale=1, tag='') =
new ComputationNode [ operation = 'DistributedArcMarginProduct' ; inputs = _AsNodes (labelSequence : W : outProbVectorSequence) /*plus the function args*/ ]
#
# MarginInnerProduct node
#
@ -610,6 +635,12 @@ CNTK2 = [
AdditiveFullConnection(labelSequence, outProbVectorSequence, W, outputDimension=0, weightNormalize=true, bias=0, annealBias=false, biasBase=0, biasGamma=0, biasPower=0, biasMin=0, biasMax=0,tag='') =
new ComputationNode [ operation = 'AdditiveFullConnection' ; inputs = _AsNodes (labelSequence : outProbVectorSequence : W) /*plus the function args*/ ]
#
# ArcMarginProduct node
#
ArcMarginProduct(labelSequence, outProbVectorSequence, W, outputDimension=0, bias=0, tag='') =
new ComputationNode [ operation = 'ArcMarginProduct' ; inputs = _AsNodes (labelSequence : outProbVectorSequence : W) /*plus the function args*/ ]
#
# CenterLoss node
#
@ -617,12 +648,6 @@ CNTK2 = [
if axis==0 then new ComputationNode [ operation = 'CenterLoss' ; inputs = _AsNodes (labelSequence : outProbVectorSequence) /*plus the function args*/ ]
else [ tag1 = tag; out = Minus (ReduceLogSum (outProbVectorSequence, axis=axis), ReduceSum (labelSequence .* outProbVectorSequence, axis=axis), tag=tag1) ].out
#
# ChannelMultiply node
#
ChannelMultiply(feature, weight, tag='') =
new ComputationNode [ operation = 'ChannelMultiply' ; inputs = _AsNodes (feature : weight) /*plus the function args*/ ]
#
# GlobalConcat node
#

View file

@ -4351,15 +4351,17 @@ namespace CNTK
CNTK_API FunctionPtr DistributedAdditiveFullConnection(const Variable& targets, const Variable& weight, const Variable& prediction, bool weightNormalize, double bias, double scale, const std::wstring& name = L"");
CNTK_API FunctionPtr DistributedArcMarginProduct(const Variable& targets, const Variable& weight, const Variable& prediction, double bias, double scale, const std::wstring& name = L"");
CNTK_API FunctionPtr MarginInnerProduct(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, double base, double gamma, double power, double lambdaMin, size_t marginCoefficient, const std::wstring& name = L"");
CNTK_API FunctionPtr FeatureNormalize(const Variable& feature, size_t normalizeType, const std::wstring& name = L"");
CNTK_API FunctionPtr AdditiveFullConnection(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, bool weightNormalize, double bias, bool annealBias, double biasBase, double biasGamma, double biasPower, double biasMin, double biasMax, const std::wstring& name = L"");
CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");
CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, double bias, const std::wstring& name = L"");
CNTK_API FunctionPtr ChannelMultiply(const Variable& prediction, const Variable& targets, const std::wstring& name = L"");
CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");
CNTK_API FunctionPtr GlobalConcat(const Variable& feature, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring& name = L"");
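The two new factory functions mirror the existing AdditiveFullConnection entry points. A minimal, hedged sketch of how the non-distributed variant might be called from the V2 C++ API follows; the variable names, dimensions, and margin value are illustrative assumptions and not part of this change:

// Hypothetical usage sketch (illustrative only; not part of this PR).
using namespace CNTK;
const size_t featureDim = 512, numClasses = 10000;                          // assumed dimensions
auto device   = DeviceDescriptor::UseDefaultDevice();
auto features = InputVariable({ featureDim }, DataType::Float, L"features"); // embeddings, assumed L2-normalized upstream
auto labels   = InputVariable({ numClasses }, DataType::Float, L"labels");   // one-hot labels
auto W        = Parameter({ numClasses, featureDim }, DataType::Float,
                          GlorotUniformInitializer(), device, L"W");
// 'bias' is the additive angular margin m (in radians) added to the target-class angle.
FunctionPtr arcLogits = ArcMarginProduct(features, labels, W, /*bias=*/0.5, L"arcMargin");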

View file

@ -125,14 +125,15 @@ namespace CNTK
{PrimitiveOpType::MarginInnerProduct, L"MarginInnerProduct"},
{PrimitiveOpType::FeatureNormalize, L"FeatureNormalize"},
{PrimitiveOpType::AdditiveFullConnection, L"AdditiveFullConnection"},
{PrimitiveOpType::ArcMarginProduct, L"ArcMarginProduct" },
{PrimitiveOpType::CenterLoss, L"CenterLoss" },
{PrimitiveOpType::ChannelMultiply, L"ChannelMultiply" },
{PrimitiveOpType::GlobalConcat, L"GlobalConcat"},
{PrimitiveOpType::DistributedFullyConnected, L"DistributedFullyConnected" },
{PrimitiveOpType::DistributedFullyConnected_v2, L"DistributedFullyConnected_v2" },
{PrimitiveOpType::DistributedCrossEntropyWithSoftmax, L"DistributedCrossEntropyWithSoftmax" },
{PrimitiveOpType::DistributedClassificationError, L"DistributedClassificationError" },
{PrimitiveOpType::DistributedAdditiveFullConnection, L"DistributedAdditiveFullConnection" }
{PrimitiveOpType::DistributedAdditiveFullConnection, L"DistributedAdditiveFullConnection" },
{PrimitiveOpType::DistributedArcMarginProduct, L"DistributedArcMarginProduct" }
};
inline const std::wstring& PrimitiveOpTypeName(PrimitiveOpType opType)

View file

@ -154,6 +154,7 @@ namespace CNTK
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasPower;
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMin;
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMax;
CNTK_API static const std::wstring AttributeArcMarginProductBias;
CNTK_API static const std::wstring AttributeCenterLossLambda;
CNTK_API static const std::wstring AttributeCenterLossAlpha;
CNTK_API static const std::wstring AttributeCenterLossLabelDim;
@ -165,6 +166,8 @@ namespace CNTK
CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionWeightNormalize;
CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionBias;
CNTK_API static const std::wstring AttributeDistributedAdditiveFullConnectionScale;
CNTK_API static const std::wstring AttributeDistributedArcMarginProductBias;
CNTK_API static const std::wstring AttributeDistributedArcMarginProductScale;
CNTK_API static const std::vector<std::wstring> s_rngStateAttributes;
};

View file

@ -110,14 +110,15 @@ namespace CNTK
MarginInnerProduct = 197,
FeatureNormalize = 198,
AdditiveFullConnection = 199,
CenterLoss = 200,
ChannelMultiply = 201,
ArcMarginProduct = 200,
CenterLoss = 201,
GlobalConcat = 202,
DistributedFullyConnected = 203,
DistributedFullyConnected_v2 = 204,
DistributedCrossEntropyWithSoftmax = 205,
DistributedClassificationError = 206,
DistributedAdditiveFullConnection = 207,
DistributedArcMarginProduct = 208,
// New op types should only be appended to the end of this list
UnknownOP
// and UnknownOP should always be last.

View file

@ -347,6 +347,14 @@ namespace CNTK
opType = PrimitiveOpType::DistributedAdditiveFullConnection;
}
else if (node->OperationName() == OperationNameOf(DistributedArcMarginProductNode))
{
auto distributedArcMarginProductNode = node->As<DistributedArcMarginProductNode<ElementType>>();
primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias] = distributedArcMarginProductNode->m_bias;
primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale] = distributedArcMarginProductNode->m_scale;
opType = PrimitiveOpType::DistributedArcMarginProduct;
}
else if (node->OperationName() == OperationNameOf(MarginInnerProductNode))
{
auto marginInnerProductNode = node->As<MarginInnerProductNode<ElementType>>();
@ -381,6 +389,13 @@ namespace CNTK
opType = PrimitiveOpType::AdditiveFullConnection;
}
else if (node->OperationName() == OperationNameOf(ArcMarginProductNode))
{
auto arcMarginProductNode = node->As<ArcMarginProductNode<ElementType>>();
primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = arcMarginProductNode->m_bias;
opType = PrimitiveOpType::ArcMarginProduct;
}
else if (node->OperationName() == OperationNameOf(CenterLossNode))
{
auto centerLossNode = node->As<CenterLossNode<ElementType>>();
@ -391,8 +406,6 @@ namespace CNTK
opType = PrimitiveOpType::CenterLoss;
}
else if (node->OperationName() == OperationNameOf(ChannelMultiplyNode))
opType = PrimitiveOpType::ChannelMultiply;
else if (node->OperationName() == OperationNameOf(GlobalConcatNode))
{
auto globalConcatNode = node->As<GlobalConcatNode<ElementType>>();

View file

@ -1182,6 +1182,13 @@ namespace CNTK
ASSIGN_NEW_NODE(DistributedAdditiveFullConnectionNode, network->GetDeviceId(), internalNodeName, weightNormalize, bias, scale);
break;
}
case PrimitiveOpType::DistributedArcMarginProduct:
{
auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias].Value<double>();
auto scale = functionConfig[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale].Value<double>();
ASSIGN_NEW_NODE(DistributedArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias, scale);
break;
}
case PrimitiveOpType::MarginInnerProduct:
{
auto outputDimension = functionConfig[PrimitiveFunctionAttribute::AttributeMarginInnerProductOutputDimension].Value<size_t>();
@ -1213,6 +1220,12 @@ namespace CNTK
ASSIGN_NEW_NODE(AdditiveFullConnectionNode, network->GetDeviceId(), internalNodeName, outputDimension, weightNormalize, bias, annealBias, biasBase, biasGamma, biasPower, biasMin, biasMax);
break;
}
case PrimitiveOpType::ArcMarginProduct:
{
auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductBias].Value<double>();
ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias);
break;
}
case PrimitiveOpType::CenterLoss:
{
auto lambda = functionConfig[PrimitiveFunctionAttribute::AttributeCenterLossLambda].Value<double>();
@ -1222,11 +1235,6 @@ namespace CNTK
ASSIGN_NEW_NODE(CenterLossNode, network->GetDeviceId(), internalNodeName, lambda, alpha, labelDim, normalize);
break;
}
case PrimitiveOpType::ChannelMultiply:
{
ASSIGN_NEW_NODE(ChannelMultiplyNode, network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::GlobalConcat:
{
auto blockIndex = functionConfig[PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex].Value<size_t>();

View file

@ -2170,6 +2170,15 @@ namespace CNTK
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::DistributedAdditiveFullConnection, operands, std::move(additionalProperties), name), name);
}
FunctionPtr DistributedArcMarginProduct(const Variable& labels, const Variable& weight, const Variable& features, double bias, double scale, const std::wstring& name)
{
std::vector<Variable> operands = { labels, weight, features };
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias] = bias;
additionalProperties[PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale] = scale;
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::DistributedArcMarginProduct, operands, std::move(additionalProperties), name), name);
}
FunctionPtr MarginInnerProduct(const Variable& features, const Variable& labels, const Variable& weight, size_t outputDimension, double base, double gamma, double power, double lambdaMin, size_t marginCoefficient, const std::wstring& name)
{
std::vector<Variable> operands = {features, labels, weight};
@ -2206,6 +2215,14 @@ namespace CNTK
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::AdditiveFullConnection, operands, std::move(additionalProperties), name), name);
}
FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, double bias, const std::wstring& name)
{
std::vector<Variable> operands = { features, labels, weight };
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = bias;
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ArcMarginProduct, operands, std::move(additionalProperties), name), name);
}
FunctionPtr CenterLoss(const Variable& features, const Variable& labels, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name)
{
auto additionalProperties = Dictionary();
@ -2216,14 +2233,6 @@ namespace CNTK
return BinaryOp(PrimitiveOpType::CenterLoss, labels, features, std::move(additionalProperties), name);
}
FunctionPtr ChannelMultiply(const Variable& feature, const Variable& weight, const std::wstring& name)
{
Variable featureCopy = feature;
Variable weightCopy = weight;
auto additionalProperties = Dictionary();
return BinaryOp(PrimitiveOpType::ChannelMultiply, featureCopy, weightCopy, std::move(additionalProperties), name);
}
FunctionPtr GlobalConcat(const Variable& feature, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring& name)
{
Variable featureCopy = feature;

View file

@ -135,11 +135,12 @@ namespace CNTK
(op == PrimitiveOpType::DistributedCrossEntropyWithSoftmax) ||
(op == PrimitiveOpType::DistributedClassificationError) ||
(op == PrimitiveOpType::DistributedAdditiveFullConnection) ||
(op == PrimitiveOpType::DistributedArcMarginProduct) ||
(op == PrimitiveOpType::MarginInnerProduct) ||
(op == PrimitiveOpType::FeatureNormalize) ||
(op == PrimitiveOpType::AdditiveFullConnection) ||
(op == PrimitiveOpType::ArcMarginProduct) ||
(op == PrimitiveOpType::CenterLoss) ||
(op == PrimitiveOpType::ChannelMultiply) ||
(op == PrimitiveOpType::GlobalConcat) ||
(op == PrimitiveOpType::CrossEntropyWithSoftmax) ||
(op == PrimitiveOpType::LatticeSequenceWithSoftmax) ||
@ -970,6 +971,12 @@ namespace CNTK
outputShape = NDShape{};
break;
}
case PrimitiveOpType::DistributedArcMarginProduct:
{
assert(m_inputs.size() == 3);
outputShape = NDShape{};
break;
}
case PrimitiveOpType::MarginInnerProduct:
{
assert(m_inputs.size() == 3);
@ -988,10 +995,10 @@ namespace CNTK
outputShape = NDShape{};
break;
}
case PrimitiveOpType::ChannelMultiply:
case PrimitiveOpType::ArcMarginProduct:
{
assert(m_inputs.size() == 2);
outputShape = m_inputs[0].Shape();
assert(m_inputs.size() == 3);
outputShape = NDShape{};
break;
}
case PrimitiveOpType::GlobalConcat:

View file

@ -130,26 +130,29 @@ namespace CNTK
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeMarginInnerProductLambdaMin = L"marginInnerProductOutputLambdaMin";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeMarginInnerProductMarginCoefficient = L"marginCoefficient";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeFeatureNormalizeNormalizeType = L"featureNormalizeType";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionOutputDimension = L"outputDimension";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionWeightNormalize = L"weightNormalize";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBias = L"bias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionAnnealBias = L"annealBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasBase = L"biasBase";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasGamma = L"biasGamma";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"biasPower";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"biasMin";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"biasMax";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionOutputDimension = L"additiveFullConnectionOutputDimension";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionWeightNormalize = L"additiveFullConnectionWeightNormalize";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBias = L"additiveFullConnectionBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionAnnealBias = L"additiveFullConnectionAnnealBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasBase = L"additiveFullConnectionBiasBase";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasGamma = L"additiveFullConnectionBiasGamma";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"additiveFullConnectionBiasPower";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"additiveFullConnectionBiasMin";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"additiveFullConnectionBiasMax";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductBias = L"arcMarginProductBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLambda = L"centerLossLambda";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossAlpha = L"centerLossAlpha";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLabelDim = L"centerLossLabelDim";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossNormalize = L"centerLossNormalize";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex = L"blockIndex";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatGrowthRate = L"growthRate";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentIndex = L"segmentIndex";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentNum = L"segmentNum";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionWeightNormalize = L"weightNormalize";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionBias = L"bias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionScale = L"scale";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatBlockIndex = L"globalConcatBlockIndex";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatGrowthRate = L"globalConcatGrowthRate";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentIndex = L"globalConcatSegmentIndex";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeGlobalConcatSegmentNum = L"globalConcatSegmentNum";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionWeightNormalize = L"distributedAdditiveFullConnectionWeightNormalize";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionBias = L"distributedAdditiveFullConnectionBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedAdditiveFullConnectionScale = L"distributedAdditiveFullConnectionScale";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductBias = L"distributedArcMarginProductBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeDistributedArcMarginProductScale = L"distributedArcMarginProductScale";
/*static*/ const std::vector<std::wstring> PrimitiveFunctionAttribute::s_rngStateAttributes =
{ PrimitiveFunctionAttribute::AttributeNameRngSeed,

View file

@ -131,11 +131,12 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
else if (nodeType == OperationNameOf(DistributedCrossEntropyWithSoftmaxNode)) return New<DistributedCrossEntropyWithSoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(DistributedClassificationErrorNode)) return New<DistributedClassificationErrorNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(DistributedAdditiveFullConnectionNode)) return New<DistributedAdditiveFullConnectionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(DistributedArcMarginProductNode)) return New<DistributedArcMarginProductNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(MarginInnerProductNode)) return New<MarginInnerProductNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(FeatureNormalizeNode)) return New<FeatureNormalizeNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(AdditiveFullConnectionNode)) return New<AdditiveFullConnectionNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ArcMarginProductNode)) return New<ArcMarginProductNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(CenterLossNode)) return New<CenterLossNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(ChannelMultiplyNode)) return New<ChannelMultiplyNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(GlobalConcatNode)) return New<GlobalConcatNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(LogisticNode)) return New<LogisticNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SumColumnElementsNode)) return New<SumColumnElementsNode<ElemType>>(forward<_Types>(_Args)...);
@ -518,6 +519,12 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Distr
return net.AddNodeToNetAndAttachInputs(New<DistributedAdditiveFullConnectionNode<ElemType>>(net.GetDeviceId(), nodeName, weightNormalize, bias, scale), { a, b, c });
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::DistributedArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, ElemType scale, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<DistributedArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias, scale), { a, b, c });
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName)
{
@ -537,15 +544,15 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Addit
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName, lambda, alpha, labelDim, normalize), { a, b });
return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias), { a, b, c });
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ChannelMultiply(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName), { a, b });
return net.AddNodeToNetAndAttachInputs(New<CenterLossNode<ElemType>>(net.GetDeviceId(), nodeName, lambda, alpha, labelDim, normalize), { a, b });
}
template <class ElemType>

View file

@ -205,11 +205,12 @@ public:
ComputationNodePtr DistributedCrossEntropyWithSoftmax(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr DistributedClassificationError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr DistributedAdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, bool weightNormalize, ElemType bias, ElemType scale, const std::wstring nodeName = L"");
ComputationNodePtr DistributedArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, ElemType scale, const std::wstring nodeName = L"");
ComputationNodePtr MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName = L"");
ComputationNodePtr FeatureNormalize(const ComputationNodePtr a, size_t normalizeType, const std::wstring nodeName = L"");
ComputationNodePtr AdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, bool weightNormalize, ElemType bias, bool annealBias, ElemType biasBase, ElemType biasGamma, ElemType biasPower, ElemType biasMin, ElemType biasMax, const std::wstring nodeName = L"");
ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName = L"");
ComputationNodePtr CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName = L"");
ComputationNodePtr ChannelMultiply(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr GlobalConcat(const ComputationNodePtr a, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring nodeName = L"");
ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Tan(const ComputationNodePtr a, const std::wstring nodeName = L"");

View file

@ -22,24 +22,9 @@
using namespace std;
//#define __DISTRIBUTED_PROFILE__
namespace Microsoft { namespace MSR { namespace CNTK {
#ifdef __DISTRIBUTED_PROFILE__
double distributedAdditiveFullConnectionTime = 0.0;
std::chrono::time_point<std::chrono::system_clock> distributedAdditiveFullConnectionStartTime;
std::chrono::time_point<std::chrono::system_clock> distributedAdditiveFullConnectionEndTime;
int distributedAdditiveFullConnectionCnt = 0;
double distributedCrossEntropyWithSoftmaxTime = 0.0;
std::chrono::time_point<std::chrono::system_clock> distributedCrossEntropyWithSoftmaxStartTime;
std::chrono::time_point<std::chrono::system_clock> distributedCrossEntropyWithSoftmaxEndTime;
int distributedCrossEntropyWithSoftmaxCnt = 0;
#endif
// This source file contains methods related to evaluation (forward prop, backprop), network validation, and matrix memory allocation (memory sharing).
// -----------------------------------------------------------------------
@ -163,14 +148,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
{
if (node->IsOutOfDateWrtInputs())
{
#ifdef __DISTRIBUTED_PROFILE__
if (node->OperationName() == L"DistributedAdditiveFullConnection")
distributedAdditiveFullConnectionStartTime = std::chrono::system_clock::now();
else if (node->OperationName() == L"DistributedCrossEntropyWithSoftmax")
distributedCrossEntropyWithSoftmaxStartTime = std::chrono::system_clock::now();
#endif
node->BeginForwardProp();
node->BeginTiming(false /*backward*/);
node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
@ -182,30 +159,6 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
// Extreme Tracing, part 1/4
if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
DumpNode(node, /*dumpGradient=*/false);
#ifdef __DISTRIBUTED_PROFILE__
if (node->OperationName() == L"DistributedAdditiveFullConnection")
{
distributedAdditiveFullConnectionEndTime = std::chrono::system_clock::now();
distributedAdditiveFullConnectionTime += (std::chrono::duration<double>(distributedAdditiveFullConnectionEndTime - distributedAdditiveFullConnectionStartTime)).count();
if (++distributedAdditiveFullConnectionCnt % 100 == 0)
{
fprintf(stderr, "Iteration [%d-%d]: distributedAdditiveFullConnection forward time = %.8gs\n", distributedAdditiveFullConnectionCnt - 99, distributedAdditiveFullConnectionCnt, distributedAdditiveFullConnectionTime);
distributedAdditiveFullConnectionTime = 0.0;
}
}
else if (node->OperationName() == L"DistributedCrossEntropyWithSoftmax")
{
distributedCrossEntropyWithSoftmaxEndTime = std::chrono::system_clock::now();
distributedCrossEntropyWithSoftmaxTime += (std::chrono::duration<double>(distributedCrossEntropyWithSoftmaxEndTime - distributedCrossEntropyWithSoftmaxStartTime)).count();
if (++distributedCrossEntropyWithSoftmaxCnt % 100 == 0)
{
fprintf(stderr, "Iteration [%d-%d]: distributedCrossEntropyWithSoftmax forward time = %.8gs\n", distributedCrossEntropyWithSoftmaxCnt - 99, distributedCrossEntropyWithSoftmaxCnt, distributedCrossEntropyWithSoftmaxTime);
distributedCrossEntropyWithSoftmaxTime = 0.0;
}
}
#endif
}
}

View file

@ -483,7 +483,8 @@ void ComputationNodeBase::ValidateBinaryReduce(bool isFinalValidationPass)
{
// It is for DistributedCrossEntropyWithSoftmaxNode
if (Input(0)->OperationName() != L"DistributedFullyConnected_v2" && Input(1)->OperationName() != L"DistributedFullyConnected_v2" &&
Input(0)->OperationName() != L"DistributedAdditiveFullConnection" && Input(1)->OperationName() != L"DistributedAdditiveFullConnection")
Input(0)->OperationName() != L"DistributedAdditiveFullConnection" && Input(1)->OperationName() != L"DistributedAdditiveFullConnection" &&
Input(0)->OperationName() != L"DistributedArcMarginProduct" && Input(1)->OperationName() != L"DistributedArcMarginProduct")
{
string s1 = Input(0)->GetSampleLayout();
string s2 = Input(1)->GetSampleLayout();

View file

@ -1729,7 +1729,7 @@ protected:
{
rows = GetSampleMatrixNumRows();
cols = GetSampleMatrixNumCols();
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
cols *= Globals::GetProcessNum();
}
else
@ -1856,7 +1856,7 @@ public:
size_t matrixSize = m_sampleLayout.GetNumElements();
if (IsValueSharable() && !m_isValueSparse)
{
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
matrixSize *= Globals::GetProcessNum();
RequestMatrixFromPool(m_value, matrixPool, matrixSize, HasMBLayout());
}
@ -1897,7 +1897,7 @@ public:
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
{
size_t matrixSize = m_sampleLayout.GetNumElements();
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection")
if (OperationName() == L"DistributedFullyConnected_v2" || OperationName() == L"DistributedAdditiveFullConnection" || OperationName() == L"DistributedArcMarginProduct")
matrixSize *= Globals::GetProcessNum();
RequestMatrixFromPool(m_gradient, matrixPool, matrixSize, HasMBLayout(), /*isWorkSpace*/false, ParentGradientReused() || IsGradientReused());

View file

@ -361,6 +361,10 @@ template class DistributedAdditiveFullConnectionNode<float>;
template class DistributedAdditiveFullConnectionNode<double>;
template class DistributedAdditiveFullConnectionNode<half>;
template class DistributedArcMarginProductNode<float>;
template class DistributedArcMarginProductNode<double>;
template class DistributedArcMarginProductNode<half>;
template class MarginInnerProductNode<float>;
template class MarginInnerProductNode<double>;
template class MarginInnerProductNode<half>;
@ -373,14 +377,14 @@ template class AdditiveFullConnectionNode<float>;
template class AdditiveFullConnectionNode<double>;
template class AdditiveFullConnectionNode<half>;
template class ArcMarginProductNode<float>;
template class ArcMarginProductNode<double>;
template class ArcMarginProductNode<half>;
template class CenterLossNode<float>;
template class CenterLossNode<double>;
template class CenterLossNode<half>;
template class ChannelMultiplyNode<float>;
template class ChannelMultiplyNode<double>;
template class ChannelMultiplyNode<half>;
template class GlobalMemoryBlock<float>;
template class GlobalMemoryBlock<double>;
template class GlobalMemoryBlock<half>;

View file

@ -31,6 +31,7 @@ static const wstring RandomDistributionTypeGumbel = L"gumbel";
static const wstring RandomDistributionTypeBernoulli = L"bernoulli";
static const double m_PI = acos(-1.0);
// Distributed fully connected layer (Y = W'X + b)
// Input(0): W, Input(1): X, Input(2): b
// Output is not decomposed
@ -736,6 +737,224 @@ public:
shared_ptr<Matrix<ElemType>> m_WNorm;
};
// Distributed additive angular margin product layer
// Input(0): labels, Input(1): W, Input(2): X
// Input and output are both decomposed
template <class ElemType>
class DistributedArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
{
typedef ComputationNodeNonLooping<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"DistributedArcMarginProduct";
}
public:
DistributedArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
: DistributedArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"), configp->Get(L"scale"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
DistributedArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0, double scale = 1.0)
: Base(deviceId, name), m_bias(bias), m_scale(scale), m_rank(Globals::GetRank()), m_processNum(Globals::GetProcessNum()), m_minibatchSize(0), m_distGradAggPtr(NULL)
{
if (m_bias < 0 || m_bias >= m_PI)
LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
m_threshold = cos(m_PI - m_bias);
m_cosBias = cos(m_bias);
m_sinBias = sin(m_bias);
#ifdef CPUONLY
LogicError("CPUONLY is not supported in DistributedArcMarginProductNode.");
#endif
}
~DistributedArcMarginProductNode()
{
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
DistributedGatheredLabels<ElemType>::initializeNodePtr = NULL;
}
virtual void UpdateFunctionMBSize() override
{
size_t minibatchSize = InputRef(0).Value().GetNumCols();
if (m_minibatchSize != minibatchSize)
{
if (1 == m_processNum)
LogicError("Multi Gpus and mpi is needed in distributed FC.");
m_distGradAggPtr = (IDistGradAggregator<ElemType>*) Globals::GetDistGradAggPtr();
bool minibatchSizeEqual = m_distGradAggPtr->DistributedCheck(m_minibatchSize, m_processNum);
if (!minibatchSizeEqual)
LogicError("With AllGather op, minibatch size in each Gpu must be the same.");
m_minibatchSize = minibatchSize;
m_batchSize = m_minibatchSize * m_processNum;
}
m_temp1->Resize(m_inputDim, m_batchSize); // Aggregated X
m_tempValue->Resize(1, m_batchSize);
m_flag->Resize(1, m_batchSize);
m_WNorm->Resize(1, m_outputDim);
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
DistributedGatheredLabels<ElemType>::setMinibatchSize(m_minibatchSize);
}
virtual void BackpropToNonLooping(size_t inputIndex) override
{
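// Hedged summary of the gradient flow below:
//  - inputIndex 1 (W): scale the incoming gradient by m_scale, undo the angular-margin substitution on the
//    target-class entries owned by this rank (DistributedArcLabelAddBackprop), then W_grad = X_gathered * grad^T.
//  - inputIndex 2 (X): propagate through the product as W * grad, AllReduce-sum across ranks, and keep only
//    this rank's column slice of the minibatch.
//  - inputIndex 0 (labels): no gradient.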
if (1 == inputIndex) // for W
{
Matrix<ElemType>::Scale((ElemType)m_scale, Gradient());
Matrix<ElemType>::DistributedArcLabelAddBackprop(*DistributedGatheredLabels<ElemType>::m_gatheredLabels, (ElemType)m_cosBias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Gradient(), m_outputDim * m_rank, m_outputDim * (m_rank + 1) - 1);
auto& W_gradient = InputRef(1).Gradient();
Matrix<ElemType>::Multiply(*m_temp1, false, Gradient(), true, W_gradient);
}
else if (2 == inputIndex) // for X
{
auto& W = InputRef(1).Value();
Matrix<ElemType>::Multiply(W, false, Gradient(), false, *m_temp1);
m_distGradAggPtr->DistributedAllReduce(*m_temp1, MPI_SUM);
auto& X_gradient = InputRef(2).Gradient();
X_gradient.SetValue(m_temp1->ColumnSlice(m_minibatchSize * m_rank, m_minibatchSize));
}
}
virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
{
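// Hedged summary of the ArcFace forward pass below:
//  1. L2-normalize each class-weight column of W (this rank owns m_outputDim classes).
//  2. Gather all ranks' label slices, then AllGather the per-rank feature slices X into m_temp1
//     (m_inputDim x m_batchSize); X is assumed to be L2-normalized upstream (e.g. by FeatureNormalize).
//  3. Value() = W^T * X_gathered, i.e. cosine logits for this rank's class slice.
//  4. During training, replace the target-class logit cos(theta) by cos(theta + m) when
//     cos(theta) > cos(pi - m), otherwise by cos(theta) - m*sin(m) (see DistributedArcLabelAdd).
//  5. Scale all logits by m_scale.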
auto& W = InputRef(1).Value();
auto& X = InputRef(2).Value();
W.VectorNorm2(*m_WNorm, true);
W.RowElementDivideBy(*m_WNorm);
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
DistributedGatheredLabels<ElemType>::gatherDistributedLabels(InputRef(0).Value());
m_distGradAggPtr->DistributedAllGather(X, *m_temp1, m_inputDim * m_minibatchSize);
Matrix<ElemType>::Multiply(W, true, *m_temp1, false, Value());
if (Environment().IsTraining())
{
m_flag->SetValue(0);
Matrix<ElemType>::DistributedArcLabelAdd(*DistributedGatheredLabels<ElemType>::m_gatheredLabels, (ElemType)m_threshold, (ElemType)m_bias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Value(), m_outputDim * m_rank, m_outputDim * (m_rank + 1) - 1);
}
Matrix<ElemType>::Scale((ElemType)m_scale, Value());
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
return false;
}
virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
{
if (0 == childIndex) // for labels
return false;
else if (1 == childIndex) // for W
return true;
else // for X
return false;
}
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
m_inputDim = InputRef(1).Value().GetNumRows();
m_outputDim = InputRef(1).Value().GetNumCols();
SetDims(TensorShape(m_outputDim), HasMBLayout());
DistributedGatheredLabels<ElemType>::setInitializeNode(this);
}
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<DistributedArcMarginProductNode<ElemType>>(nodeP);
node->m_rank = m_rank;
node->m_processNum = m_processNum;
node->m_inputDim = m_inputDim;
node->m_outputDim = m_outputDim;
node->m_minibatchSize = m_minibatchSize;
node->m_batchSize = m_batchSize;
node->m_bias = m_bias;
node->m_threshold = m_threshold;
node->m_cosBias = m_cosBias;
node->m_sinBias = m_sinBias;
node->m_scale = m_scale;
node->m_distGradAggPtr = m_distGradAggPtr;
node->m_temp1->SetValue(*m_temp1);
node->m_tempValue->SetValue(*m_tempValue);
node->m_flag->SetValue(*m_flag);
}
}
virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_WNorm, matrixPool);
RequestMatrixFromPool(m_temp1, matrixPool, m_inputDim * m_processNum, true);
RequestMatrixFromPool(m_flag, matrixPool, m_inputDim * m_processNum, true);
RequestMatrixFromPool(m_tempValue, matrixPool, m_inputDim * m_processNum, true);
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
{
RequestMatrixFromPool(DistributedGatheredLabels<ElemType>::m_gatheredLabels, matrixPool, m_processNum, true);
RequestMatrixFromPool(DistributedGatheredLabels<ElemType>::m_labels, matrixPool, 1, true);
}
}
virtual void ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterForwardProp(matrixPool);
ReleaseMatrixToPool(m_WNorm, matrixPool);
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
ReleaseMatrixToPool(DistributedGatheredLabels<ElemType>::m_labels, matrixPool);
}
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_temp1, matrixPool);
ReleaseMatrixToPool(m_flag, matrixPool);
ReleaseMatrixToPool(m_tempValue, matrixPool);
if (DistributedGatheredLabels<ElemType>::isInitializeNode(this))
ReleaseMatrixToPool(DistributedGatheredLabels<ElemType>::m_gatheredLabels, matrixPool);
}
void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_bias << m_scale;
}
void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_bias >> m_scale;
if (m_bias < 0 || m_bias >= m_PI)
LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
m_threshold = cos(m_PI - m_bias);
m_cosBias = cos(m_bias);
m_sinBias = sin(m_bias);
}
size_t m_rank;
size_t m_processNum;
size_t m_inputDim;
size_t m_outputDim;
size_t m_minibatchSize;
size_t m_batchSize;
double m_bias;
double m_threshold;
double m_cosBias;
double m_sinBias;
double m_scale;
IDistGradAggregator<ElemType>* m_distGradAggPtr;
shared_ptr<Matrix<ElemType>> m_temp1;
shared_ptr<Matrix<ElemType>> m_tempValue; // Matrix(1, m_batchSize)
shared_ptr<Matrix<ElemType>> m_flag; // Matrix(1, m_batchSize)
shared_ptr<Matrix<ElemType>> m_WNorm;
};
// Implements A-Softmax as described in:
// SphereFace: Deep Hypersphere Embedding for Face Recognition [Weiyang Liu, Yandong Wen, Zhiding Yu, Ming Li, Bhiksha Raj, Le Song]
// https://arxiv.org/abs/1704.08063
@ -1146,9 +1365,6 @@ public:
shared_ptr<Matrix<ElemType>> m_tempMatrix; // Matrix(k,m)
};
// Implements AM-Softmax as described in:
// Additive Margin Softmax for Face Verification [Feng Wang, Weiyang Liu, Haijun Liu, Jian Cheng]
// https://arxiv.org/abs/1801.05599
template <class ElemType>
class FeatureNormalizeNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<1>
{
@ -1269,6 +1485,9 @@ public:
shared_ptr<Matrix<ElemType>> m_temp1; // Matrix(1, minibatchSize)
};
// Implements AM-Softmax as described in:
// Additive Margin Softmax for Face Verification [Feng Wang, Weiyang Liu, Haijun Liu, Jian Cheng]
// https://arxiv.org/abs/1801.05599
template <class ElemType>
class AdditiveFullConnectionNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
{
@ -1461,6 +1680,175 @@ public:
shared_ptr<Matrix<ElemType>> m_weightMagnitude; // Matrix(k,1)
};
// Implements additive angular margin product as described in:
// ArcFace: Additive Angular Margin Loss for Deep Face Recognition [Jiankang Deng, Jia Guo, Niannan Xue, Stefanos Zafeiriou]
// https://arxiv.org/abs/1801.07698
template <class ElemType>
class ArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<3>
{
typedef ComputationNodeNonLooping<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"ArcMarginProduct";
}
public:
ArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
: ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0)
: Base(deviceId, name), m_bias(bias)
{
if (m_bias < 0 || m_bias >= m_PI)
LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
m_threshold = cos(m_PI - m_bias);
m_cosBias = cos(m_bias);
m_sinBias = sin(m_bias);
}
virtual void UpdateFunctionMBSize() override
{
m_minibatchSize = InputRef(0).Value().GetNumCols();
m_label->Resize(1, m_minibatchSize);
m_tempValue->Resize(1, m_minibatchSize);
m_flag->Resize(1, m_minibatchSize);
m_weightMagnitude->Resize(InputRef(2).Value().GetNumRows(), 1);
}
virtual void BackpropToNonLooping(size_t inputIndex) override
{
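// Hedged summary: inputIndex 1 (features X): undo the margin substitution on the target-class entries
// (ArcLabelAddBackprop), then X_grad = W^T * grad; inputIndex 2 (W): W_grad = grad * X^T; labels get no gradient.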
if (1 == inputIndex)
{
Matrix<ElemType>::ArcLabelAddBackprop(*m_label, (ElemType)m_cosBias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Gradient());
FrameRange fr(InputRef(1).GetMBLayout());
auto X_gradient = InputRef(1).GradientFor(fr);
auto& weight = InputRef(2).Value();
Matrix<ElemType>::Multiply(weight, true, Gradient(), false, X_gradient);
}
else if (2 == inputIndex)
{
FrameRange fr(InputRef(1).GetMBLayout());
auto X = InputRef(1).ValueFor(fr);
auto& weightGradient = InputRef(2).Gradient();
Matrix<ElemType>::Multiply(Gradient(), false, X, true, weightGradient);
}
}
virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
{
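// Hedged summary: recover the target class index per sample (VectorMax over the one-hot labels),
// L2-normalize each class-weight row of W, compute cosine logits Value() = W * X (features assumed
// L2-normalized upstream), and, when training, apply the additive angular margin to the target-class
// logit via ArcLabelAdd.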
FrameRange fr(InputRef(0).GetMBLayout());
InputRef(0).MaskedValueFor(fr).VectorMax(*m_label, *m_tempValue, true /*isColWise*/);
auto X = InputRef(1).ValueFor(fr);
auto& weight = InputRef(2).Value();
weight.VectorNorm2(*m_weightMagnitude, false);
weight.ColumnElementDivideBy(*m_weightMagnitude);
Matrix<ElemType>::Multiply(weight, false, X, false, Value());
if (Environment().IsTraining())
{
m_flag->SetValue(0);
Matrix<ElemType>::ArcLabelAdd(*m_label, (ElemType)m_threshold, (ElemType)m_bias, (ElemType)m_sinBias, *m_flag, *m_tempValue, Value());
}
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
return false;
}
virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
{
if (0 == childIndex)
return false;
return true;
}
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
SetDims(TensorShape(InputRef(2).Value().GetNumRows()), HasMBLayout());
}
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<ArcMarginProductNode<ElemType>>(nodeP);
node->m_minibatchSize = m_minibatchSize;
node->m_bias = m_bias;
node->m_threshold = m_threshold;
node->m_cosBias = m_cosBias;
node->m_sinBias = m_sinBias;
node->m_label->SetValue(*m_label);
node->m_tempValue->SetValue(*m_tempValue);
node->m_flag->SetValue(*m_flag);
node->m_weightMagnitude->SetValue(*m_weightMagnitude);
}
}
// request matrices needed to do node function value evaluation
virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeForwardProp(matrixPool);
RequestMatrixFromPool(m_label, matrixPool, 1, true, true, false);
RequestMatrixFromPool(m_tempValue, matrixPool, 1, true, true, false);
RequestMatrixFromPool(m_flag, matrixPool, 1, true, true, false);
RequestMatrixFromPool(m_weightMagnitude, matrixPool);
}
virtual void ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterForwardProp(matrixPool);
ReleaseMatrixToPool(m_weightMagnitude, matrixPool);
}
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_label, matrixPool);
ReleaseMatrixToPool(m_tempValue, matrixPool);
ReleaseMatrixToPool(m_flag, matrixPool);
}
void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_bias;
}
void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_bias;
if (m_bias < 0 || m_bias >= m_PI)
LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
m_threshold = cos(m_PI - m_bias);
m_cosBias = cos(m_bias);
m_sinBias = sin(m_bias);
}
size_t m_minibatchSize;
double m_bias;
double m_threshold;
double m_cosBias;
double m_sinBias;
shared_ptr<Matrix<ElemType>> m_label;
shared_ptr<Matrix<ElemType>> m_tempValue;
shared_ptr<Matrix<ElemType>> m_flag;
shared_ptr<Matrix<ElemType>> m_weightMagnitude;
};
// Implements Center-Loss as described in:
// A Discriminative Feature Learning Approach for Deep Face Recognition [Yandong Wen, Kaipeng Zhang, Zhifeng Li, Yu Qiao]
// https://ydwen.github.io/papers/WenECCV16.pdf
@ -1625,145 +2013,6 @@ public:
bool initFlag = true;
};
// Implements Squeeze-and-Excitation operation as described in:
// Squeeze-and-Excitation Networks [Jie Hu, Li Shen, Gang Sun]
// https://arxiv.org/pdf/1709.01507.pdf
template <class ElemType>
class ChannelMultiplyNode : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public NumInputs<2>
{
typedef ComputationNodeNonLooping<ElemType> Base;
UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName()
{
return L"ChannelMultiply";
}
public:
ChannelMultiplyNode(const ScriptableObjects::IConfigRecordPtr configp)
: ChannelMultiplyNode(configp->Get(L"deviceId"), L"<placeholder>")
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
ChannelMultiplyNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
{
}
virtual void BackpropToNonLooping(size_t inputIndex) override
{
if (0 == inputIndex)
{
FrameRange fr(InputRef(0).GetMBLayout());
auto X_gradient = InputRef(0).GradientFor(fr);
auto weight = InputRef(1).ValueFor(fr);
Matrix<ElemType>::ChannelMultiply(Gradient(), weight, X_gradient, m_featureSize);
}
else
{
FrameRange fr(InputRef(0).GetMBLayout());
auto X = InputRef(0).ValueFor(fr);
auto weight_gradient = InputRef(1).GradientFor(fr);
weight_gradient.SetValue((ElemType)0);
m_buffer->Resize(m_featureSize * m_channels, X.GetNumCols());
Matrix<ElemType>::ChannelMultiplyScaleBackprop(Gradient(), X, weight_gradient, *m_buffer, m_featureSize, m_N);
}
}
virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
{
FrameRange fr(InputRef(0).GetMBLayout());
auto X = InputRef(0).ValueFor(fr);
auto weight = InputRef(1).ValueFor(fr);
Matrix<ElemType>::ChannelMultiply(X, weight, Value(), m_featureSize);
}
virtual bool OutputUsedInComputingInputNodesGradients() const override
{
return false;
}
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override
{
return true;
}
virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
Base::Validate(isFinalValidationPass);
InferMBLayoutFromInputsForStandardCase(isFinalValidationPass);
SetDims(Input(0));
auto dims0 = Input(0)->GetSampleLayout().GetDims();
auto dims1 = Input(1)->GetSampleLayout().GetDims();
if (dims0.size() != 3)
LogicError("ChannelMultiplyNode : input[0] dimension not equals to 3 \n");
size_t temp = 1;
for (size_t i(0); i < dims1.size(); ++i)
temp *= dims1[i];
if (dims0[2] != temp)
LogicError("ChannelMultiplyNode : input channel not match %d v.s. %d\n", (int)dims0[2], (int)temp);
m_featureSize = dims0[0] * dims0[1];
m_channels = dims0[2];
size_t featureSize = m_featureSize;
m_N = 1;
assert(featureSize != 0);
while (featureSize)
{
m_N *= 2;
featureSize /= 2;
}
m_N /= 2;
}
virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
{
Base::CopyTo(nodeP, newName, flags);
if (flags & CopyNodeFlags::copyNodeValue)
{
auto node = dynamic_pointer_cast<ChannelMultiplyNode<ElemType>>(nodeP);
node->m_buffer->SetValue(*m_buffer);
node->m_featureSize = m_featureSize;
node->m_channels = m_channels;
node->m_N = m_N;
}
}
virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
{
Base::RequestMatricesBeforeBackprop(matrixPool);
RequestMatrixFromPool(m_buffer, matrixPool, m_featureSize * m_channels, true, false, false);
}
virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
{
Base::ReleaseMatricesAfterBackprop(matrixPool);
ReleaseMatrixToPool(m_buffer, matrixPool);
}
void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_featureSize << m_channels;
}
void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_featureSize >> m_channels;
}
shared_ptr<Matrix<ElemType>> m_buffer;
size_t m_featureSize;
size_t m_channels;
size_t m_N;
};
// -----------------------------------------------------------------------
// SquareErrorNode (left, right)

View file

@ -472,17 +472,17 @@ public:
#pragma endregion
#pragma region CenterLoss
#pragma region ArcMarginProduct
static void ClassCount(const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& counter);
static void ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value);
static void ArcLabelAddBackprop(const CPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient);
#pragma endregion
#pragma region SqueezeAndExcitation
#pragma region CenterLoss
static void ChannelMultiply(const CPUMatrix<ElemType>& X, const CPUMatrix<ElemType>& weight, CPUMatrix<ElemType>& value, size_t featureSize);
static void ChannelMultiplyScaleBackprop(const CPUMatrix<ElemType>& gradient, const CPUMatrix<ElemType>& X, CPUMatrix<ElemType>& weight_gradient, size_t featureSize);
static void ClassCount(const CPUMatrix<ElemType>& label, const CPUMatrix<ElemType>& counter);
#pragma endregion
@ -518,6 +518,10 @@ public:
static void DistributedLabelAdd(const CPUMatrix<ElemType>& labels, ElemType bias, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAdd(const CPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAddBackprop(const CPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex);
#pragma endregion


@ -5430,6 +5430,61 @@ void CPUMatrix<ElemType>::LabelAdd(const CPUMatrix<ElemType>& label, ElemType bi
#pragma endregion
#pragma region ArcMarginProduct
template <class ElemType>
void CPUMatrix<ElemType>::ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value)
{
size_t minibatchSize = value.GetNumCols();
size_t outputDimension = value.GetNumRows();
size_t labelValue;
ElemType* labelPtr = label.Data();
ElemType* flagPtr = flag.Data();
ElemType* xPtr = x.Data();
ElemType* valuePtr = value.Data();
for (size_t i(0); i < minibatchSize; ++i)
{
labelValue = static_cast<size_t>(labelPtr[i]);
size_t index = i * outputDimension + labelValue;
if (valuePtr[index] > threshold)
{
xPtr[i] = valuePtr[index];
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
}
else
{
valuePtr[index] -= bias * sinBias;
flagPtr[i] = 1.0f;
}
}
}
template <class ElemType>
void CPUMatrix<ElemType>::ArcLabelAddBackprop(const CPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient)
{
size_t minibatchSize = gradient.GetNumCols();
size_t outputDimension = gradient.GetNumRows();
size_t labelValue;
ElemType* labelPtr = label.Data();
ElemType* flagPtr = flag.Data();
ElemType* xPtr = x.Data();
ElemType* gradientPtr = gradient.Data();
for (size_t i(0); i < minibatchSize; ++i)
{
if (flagPtr[i] < 0.5f)
{
labelValue = static_cast<size_t>(labelPtr[i]);
size_t index = i * outputDimension + labelValue;
gradientPtr[index] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
}
}
}
#pragma endregion
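For reference (this note is ours, not part of the diff): writing the target logit as t = cos θ_y, and assuming bias = m (the angular margin in radians), cosBias = cos m, sinBias = sin m, with threshold typically chosen as cos(π − m), ArcLabelAdd applies the ArcFace-style additive angular margin with an easy-margin fallback, and ArcLabelAddBackprop scales the incoming gradient by the corresponding derivative. The 1e-12 term guards the square root; samples that took the fallback branch have flag = 1 and keep their gradient unchanged, since the fallback is linear in t.

\[
t' =
\begin{cases}
\cos(\theta_y + m) = t\cos m - \sqrt{1 - t^2}\,\sin m, & t > \text{threshold} \\
t - m\sin m, & \text{otherwise}
\end{cases}
\qquad
\frac{\partial\,\cos(\theta_y + m)}{\partial t} = \cos m + \sin m\,\frac{t}{\sqrt{1 - t^2}}.
\]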
#pragma region CenterLoss
template <class ElemType>
@ -5457,57 +5512,6 @@ void CPUMatrix<ElemType>::ClassCount(const CPUMatrix<ElemType>& label, const CPU
#pragma endregion
#pragma region SqueezeAndExcitation
template <class ElemType>
void CPUMatrix<ElemType>::ChannelMultiply(const CPUMatrix<ElemType>& X, const CPUMatrix<ElemType>& weight, CPUMatrix<ElemType>& value, size_t featureSize)
{
long n = (long)X.GetNumCols();
long m = (long)X.GetNumRows();
#pragma omp parallel for
for (long j = 0; j < n; j++)
{
// four-way unrolling
for (long i = 0; i < (m & ~3); i += 4)
{
value(i, j) = X(i, j) * weight(i / featureSize, j);
value(i + 1, j) = X(i + 1, j) * weight((i + 1) / featureSize, j);
value(i + 2, j) = X(i + 2, j) * weight((i + 2) / featureSize, j);
value(i + 3, j) = X(i + 3, j) * weight((i + 3) / featureSize, j);
}
// handle the remaining elements
for (long i = m & ~3; i < m; i++)
{
value(i, j) = X(i, j) * weight(i / featureSize, j);
}
}
}
template <class ElemType>
void CPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const CPUMatrix<ElemType>& gradient, const CPUMatrix<ElemType>& X, CPUMatrix<ElemType>& weight_gradient, size_t featureSize)
{
long n = (long)X.GetNumCols();
long m = (long)X.GetNumRows();
#pragma omp parallel for
for (long j = 0; j < n; j++)
{
// four-way unrolling
for (long i = 0; i < (m & ~3); i += 4)
{
weight_gradient(i / featureSize, j) += gradient(i, j) * X(i, j);
}
// handle the remaining elements
for (long i = m & ~3; i < m; i++)
{
weight_gradient(i / featureSize, j) += gradient(i, j) * X(i, j);
}
}
}
#pragma endregion
#pragma region LabelSmoothing
template <class ElemType>
@ -5829,6 +5833,54 @@ void CPUMatrix<ElemType>::DistributedLabelAdd(const CPUMatrix<ElemType>& labels,
}
}
template <class ElemType>
void CPUMatrix<ElemType>::DistributedArcLabelAdd(const CPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
long cols = (long)value.GetNumCols();
long rows = (long)value.GetNumRows();
ElemType* labelsPtr = labels.Data();
ElemType* flagPtr = flag.Data();
ElemType* xPtr = x.Data();
ElemType* valuePtr = value.Data();
for (long i = 0; i < cols; ++i)
{
long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex)
{
if (valuePtr[index] > threshold)
{
xPtr[i] = valuePtr[index];
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
}
else
{
valuePtr[index] -= bias * sinBias;
flagPtr[i] = 1.0f;
}
}
}
}
template <class ElemType>
void CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const CPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const CPUMatrix<ElemType>& flag, const CPUMatrix<ElemType>& x, const CPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
long cols = (long)gradient.GetNumCols();
long rows = (long)gradient.GetNumRows();
ElemType* labelsPtr = labels.Data();
ElemType* flagPtr = flag.Data();
ElemType* xPtr = x.Data();
ElemType* gradientPtr = gradient.Data();
for (long i = 0; i < cols; ++i)
{
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i] < 0.5f)
{
gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
}
}
}
#pragma endregion
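A small sketch (ours) of the index mapping these distributed variants rely on, assuming each worker owns the contiguous class range [startIndex, endIndex] (endIndex inclusive, per the comparisons above) and stores only its local rows per sample:

// Assumed shard layout; helper names are ours, not part of the diff.
struct ClassShard { long startIndex, endIndex; };

// Does this worker own the sample's target class?
inline bool OwnsLabel(const ClassShard& s, long label)
{
    return label >= s.startIndex && label <= s.endIndex;
}

// Column-major offset of (sample, label) in the local slice,
// matching the "i * rows + label - startIndex" expression above.
inline long LocalOffset(long sample, long rows, long label, const ClassShard& s)
{
    return sample * rows + (label - s.startIndex);
}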


@ -3765,14 +3765,14 @@ template <class ElemType>
__global__ void _labelAdd(CUDA_LONG outputDimension, const ElemType* label, ElemType bias, ElemType* value, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id < numElements)
{
CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
CUDA_LONG index = id * outputDimension + labelValue;
if (value[index] <= -bias)
return;
value[index] += bias;
}
if (id >= numElements)
return;
CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
CUDA_LONG index = id * outputDimension + labelValue;
if (value[index] <= -bias)
return;
value[index] += bias;
}
template <class ElemType>
@ -3789,6 +3789,70 @@ void GPUMatrix<ElemType>::LabelAdd(const GPUMatrix<ElemType>& label, ElemType bi
#pragma endregion
#pragma region CenterLoss
template <class ElemType>
__global__ void _arcLabelAdd(CUDA_LONG outputDimension, const ElemType* label, ElemType threshold, ElemType bias, ElemType sinBias, ElemType* flag, ElemType* x, ElemType* value, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
CUDA_LONG index = id * outputDimension + labelValue;
if (value[index] > threshold)
{
x[id] = value[index];
value[index] = cosf(acosf(value[index]) + bias);
}
else
{
value[index] -= bias * sinBias;
flag[id] = 1.0f;
}
}
template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value)
{
CUDA_LONG minibatchSize = (CUDA_LONG)value.GetNumCols();
CUDA_LONG outputDimension = (CUDA_LONG)value.GetNumRows();
int blocksPerGrid = (int)ceil(1.0 * minibatchSize / GridDim::maxThreadsPerBlock);
label.PrepareDevice();
SyncGuard syncGuard;
_arcLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (outputDimension, label.Data(), threshold, bias, sinBias, flag.Data(), x.Data(), value.Data(), minibatchSize);
}
template <class ElemType>
__global__ void _arcLabelAddBackprop(CUDA_LONG outputDimension, const ElemType* label, ElemType cosBias, ElemType sinBias, ElemType* flag, ElemType* x, ElemType* gradient, CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id >= numElements)
return;
if (flag[id] > 0.5f)
return;
CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
CUDA_LONG index = id * outputDimension + labelValue;
gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}
template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient)
{
CUDA_LONG minibatchSize = (CUDA_LONG)gradient.GetNumCols();
CUDA_LONG outputDimension = (CUDA_LONG)gradient.GetNumRows();
int blocksPerGrid = (int)ceil(1.0 * minibatchSize / GridDim::maxThreadsPerBlock);
label.PrepareDevice();
SyncGuard syncGuard;
_arcLabelAddBackprop<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (outputDimension, label.Data(), cosBias, sinBias, flag.Data(), x.Data(), gradient.Data(), minibatchSize);
}
#pragma endregion
#pragma region CenterLoss
@ -3818,68 +3882,6 @@ void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPU
#pragma endregion
#pragma region SqueezeAndExcitation
template <class ElemType>
__global__ void _channelMultiply(const ElemType* X, const ElemType* weight, ElemType* value, CUDA_LONG featureSize, const CUDA_LONG numElements)
{
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= numElements)
return;
value[id] = X[id] * weight[id / featureSize];
}
template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize)
{
CUDA_LONG numElements = X.GetNumElements();
SyncGuard syncGuard;
GridDim grid(numElements);
_channelMultiply<ElemType> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(X.Data(), weight.Data(), value.Data(), (CUDA_LONG)featureSize, numElements);
}
template <class ElemType>
__global__ void _channelMultiplyScaleBackprop(const ElemType* gradient, const ElemType* X, ElemType* weight_gradient, ElemType* buffer, CUDA_LONG featureSize, CUDA_LONG N, const CUDA_LONG numElements)
{
CUDA_LONG id = GridDim::GetLinearThreadId();
if (id < numElements)
{
CUDA_LONG i = id / featureSize; // Channel i
CUDA_LONG j = id % featureSize; // Element j
if (j < N)
buffer[id] = gradient[id] * X[id];
__syncthreads();
if (j >= N)
buffer[id - N] += gradient[id] * X[id];
__syncthreads();
for (CUDA_LONG k = N >> 1; k > 0; k >>= 1)
{
if (j < k)
buffer[id] += buffer[id + k];
__syncthreads();
}
if (0 == j)
weight_gradient[i] = buffer[id];
}
}
template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N)
{
CUDA_LONG numElements = gradient.GetNumElements();
SyncGuard syncGuard;
GridDim grid(numElements);
_channelMultiplyScaleBackprop<ElemType> << <grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream >> >(gradient.Data(), X.Data(), weight_gradient.Data(), buffer.Data(), (CUDA_LONG)featureSize, (CUDA_LONG)N, numElements);
}
#pragma endregion
#pragma region LabelSmoothing
template <class ElemType>
@ -4318,9 +4320,7 @@ __global__ void _distributedLabelAdd(const ElemType* labels, ElemType bias, Elem
return;
CUDA_LONG label = (CUDA_LONG)labels[id];
if (label < startIndex)
return;
if (label > endIndex)
if (label < startIndex || label > endIndex)
return;
if (value[id * rows + label - startIndex] <= -bias)
return;
@ -4340,6 +4340,69 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
_distributedLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), bias, value.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}
template <class ElemType>
__global__ void _distributedArcLabelAdd(const ElemType* labels, const ElemType threshold, const ElemType bias, const ElemType sinBias, ElemType* flag, ElemType* x, ElemType* value, CUDA_LONG rows, CUDA_LONG startIndex, CUDA_LONG endIndex, CUDA_LONG numElements)
{
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= numElements)
return;
CUDA_LONG label = (CUDA_LONG)labels[id];
if (label < startIndex || label > endIndex)
return;
CUDA_LONG index = id * rows + label - startIndex;
if (value[index] > threshold)
{
x[id] = value[index];
value[index] = cosf(acosf(value[index]) + bias);
}
else
{
value[index] -= bias * sinBias;
flag[id] = 1.0f;
}
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
CUDA_LONG cols = (CUDA_LONG)value.GetNumCols();
CUDA_LONG rows = (CUDA_LONG)value.GetNumRows();
int blocksPerGrid = (int)ceil(1.0 * cols / GridDim::maxThreadsPerBlock);
labels.PrepareDevice();
SyncGuard syncGuard;
_distributedArcLabelAdd<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), threshold, bias, sinBias, flag.Data(), x.Data(), value.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}
template <class ElemType>
__global__ void _distributedArcLabelAddBackprop(const ElemType* labels, const ElemType cosBias, const ElemType sinBias, ElemType* flag, ElemType* x, ElemType* gradient, CUDA_LONG rows, CUDA_LONG startIndex, CUDA_LONG endIndex, CUDA_LONG numElements)
{
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= numElements)
return;
CUDA_LONG label = (CUDA_LONG)labels[id];
if (label < startIndex || label > endIndex || flag[id] > 0.5f)
return;
CUDA_LONG index = id * rows + label - startIndex;
gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
CUDA_LONG cols = (CUDA_LONG)gradient.GetNumCols();
CUDA_LONG rows = (CUDA_LONG)gradient.GetNumRows();
int blocksPerGrid = (int)ceil(1.0 * cols / GridDim::maxThreadsPerBlock);
labels.PrepareDevice();
SyncGuard syncGuard;
_distributedArcLabelAddBackprop<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream >> > (labels.Data(), cosBias, sinBias, flag.Data(), x.Data(), gradient.Data(), rows, (CUDA_LONG)startIndex, (CUDA_LONG)endIndex, cols);
}
#pragma endregion
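The host wrappers above launch one thread per minibatch sample (one column of value/gradient). A hedged sketch of the same block-count arithmetic using integer ceiling division instead of the floating-point ceil used above (helper name is ours):

// Number of blocks needed so that samples threads are covered.
inline int BlocksForSamples(int samples, int threadsPerBlock)
{
    return (samples + threadsPerBlock - 1) / threadsPerBlock;
}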


@ -588,17 +588,17 @@ public:
#pragma endregion
#pragma region CenterLoss
#pragma region ArcMarginProduct
static void ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter);
static void ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value);
static void ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient);
#pragma endregion
#pragma region SqueezeAndExcitation
#pragma region CenterLoss
static void ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize);
static void ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N);
static void ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter);
#pragma endregion
@ -634,6 +634,10 @@ public:
static void DistributedLabelAdd(const GPUMatrix<ElemType>& labels, ElemType bias, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex);
#pragma endregion


@ -5203,6 +5203,32 @@ template <class ElemType>
#pragma endregion
#pragma region ArcMarginProduct
template <class ElemType>
/*static*/ void Matrix<ElemType>::ArcLabelAdd(const Matrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value)
{
DISPATCH_MATRIX_ON_FLAG(&value,
&value,
CPUMatrix<ElemType>::ArcLabelAdd(*(label.m_CPUMatrix), threshold, bias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(value.m_CPUMatrix)),
GPUMatrix<ElemType>::ArcLabelAdd(*(label.m_GPUMatrix), threshold, bias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(value.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
template <class ElemType>
/*static*/ void Matrix<ElemType>::ArcLabelAddBackprop(const Matrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient)
{
DISPATCH_MATRIX_ON_FLAG(&gradient,
&gradient,
CPUMatrix<ElemType>::ArcLabelAddBackprop(*(label.m_CPUMatrix), cosBias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(gradient.m_CPUMatrix)),
GPUMatrix<ElemType>::ArcLabelAddBackprop(*(label.m_GPUMatrix), cosBias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(gradient.m_GPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
#pragma endregion
#pragma region CenterLoss
template <class ElemType>
@ -5218,32 +5244,6 @@ template <class ElemType>
#pragma endregion
#pragma region SqueezeAndExcitation
template <class ElemType>
/*static*/ void Matrix<ElemType>::ChannelMultiply(const Matrix<ElemType>& X, const Matrix<ElemType>& weight, const Matrix<ElemType>& value, size_t featureSize)
{
DISPATCH_MATRIX_ON_FLAG(&value,
&value,
CPUMatrix<ElemType>::ChannelMultiply(*(X.m_CPUMatrix), *(weight.m_CPUMatrix), *(value.m_CPUMatrix), featureSize),
GPUMatrix<ElemType>::ChannelMultiply(*(X.m_GPUMatrix), *(weight.m_GPUMatrix), *(value.m_GPUMatrix), featureSize),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
template <class ElemType>
/*static*/ void Matrix<ElemType>::ChannelMultiplyScaleBackprop(const Matrix<ElemType>& gradient, const Matrix<ElemType>& X, const Matrix<ElemType>& weight_gradient, const Matrix<ElemType>& buffer, size_t featureSize, size_t N)
{
DISPATCH_MATRIX_ON_FLAG(&weight_gradient,
&weight_gradient,
CPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(*(gradient.m_CPUMatrix), *(X.m_CPUMatrix), *(weight_gradient.m_CPUMatrix), featureSize),
GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(*(gradient.m_GPUMatrix), *(X.m_GPUMatrix), *(weight_gradient.m_GPUMatrix), *(buffer.m_GPUMatrix), featureSize, N),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
#pragma endregion
#pragma region LabelSmoothing
template <class ElemType>
@ -5425,6 +5425,29 @@ template <class ElemType>
NOT_IMPLEMENTED);
}
template <class ElemType>
/*static*/ void Matrix<ElemType>::DistributedArcLabelAdd(const Matrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
DISPATCH_MATRIX_ON_FLAG(&value,
&value,
CPUMatrix<ElemType>::DistributedArcLabelAdd(*(labels.m_CPUMatrix), threshold, bias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(value.m_CPUMatrix), startIndex, endIndex),
GPUMatrix<ElemType>::DistributedArcLabelAdd(*(labels.m_GPUMatrix), threshold, bias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(value.m_GPUMatrix), startIndex, endIndex),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
template <class ElemType>
/*static*/ void Matrix<ElemType>::DistributedArcLabelAddBackprop(const Matrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
DISPATCH_MATRIX_ON_FLAG(&gradient,
&gradient,
CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(*(labels.m_CPUMatrix), cosBias, sinBias, *(flag.m_CPUMatrix), *(x.m_CPUMatrix), *(gradient.m_CPUMatrix), startIndex, endIndex),
GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(*(labels.m_GPUMatrix), cosBias, sinBias, *(flag.m_GPUMatrix), *(x.m_GPUMatrix), *(gradient.m_GPUMatrix), startIndex, endIndex),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}
#pragma endregion


@ -628,17 +628,17 @@ public:
#pragma endregion
#pragma region CenterLoss
#pragma region ArcMarginProduct
static void ClassCount(const Matrix<ElemType>& label, const Matrix<ElemType>& counter);
static void ArcLabelAdd(const Matrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value);
static void ArcLabelAddBackprop(const Matrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient);
#pragma endregion
#pragma region SqueezeAndExcitation
#pragma region CenterLoss
static void ChannelMultiply(const Matrix<ElemType>& X, const Matrix<ElemType>& weight, const Matrix<ElemType>& value, size_t featureSize);
static void ChannelMultiplyScaleBackprop(const Matrix<ElemType>& gradient, const Matrix<ElemType>& X, const Matrix<ElemType>& weight_gradient, const Matrix<ElemType>& buffer, size_t featureSize, size_t N);
static void ClassCount(const Matrix<ElemType>& label, const Matrix<ElemType>& counter);
#pragma endregion
@ -674,6 +674,10 @@ public:
static void DistributedLabelAdd(const Matrix<ElemType>& labels, ElemType bias, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAdd(const Matrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& value, size_t startIndex, size_t endIndex);
static void DistributedArcLabelAddBackprop(const Matrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const Matrix<ElemType>& flag, const Matrix<ElemType>& x, const Matrix<ElemType>& gradient, size_t startIndex, size_t endIndex);
#pragma endregion


@ -2166,24 +2166,24 @@ void GPUMatrix<ElemType>::LabelAdd(const GPUMatrix<ElemType>& label, ElemType bi
#pragma endregion
#pragma region CenterLoss
#pragma region ArcMarginProduct
template <class ElemType>
void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter)
void GPUMatrix<ElemType>::ArcLabelAdd(const GPUMatrix<ElemType>& label, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value)
{
}
template <class ElemType>
void GPUMatrix<ElemType>::ArcLabelAddBackprop(const GPUMatrix<ElemType>& label, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient)
{
}
#pragma endregion
#pragma region SqueezeAndExcitation
#pragma region CenterLoss
template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiply(const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight, const GPUMatrix<ElemType>& value, size_t featureSize)
{
}
template <class ElemType>
void GPUMatrix<ElemType>::ChannelMultiplyScaleBackprop(const GPUMatrix<ElemType>& gradient, const GPUMatrix<ElemType>& X, const GPUMatrix<ElemType>& weight_gradient, const GPUMatrix<ElemType>& buffer, size_t featureSize, size_t N)
void GPUMatrix<ElemType>::ClassCount(const GPUMatrix<ElemType>& label, const GPUMatrix<ElemType>& counter)
{
}
@ -2260,6 +2260,16 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
{
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const GPUMatrix<ElemType>& labels, ElemType cosBias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& gradient, size_t startIndex, size_t endIndex)
{
}
#pragma endregion