arcface bug fix

This commit is contained in:
Parent: 7055de854b
Commit: 0d8f549c7c
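In brief, inferred from the hunks below: the forward pass used to overwrite the target logit with cos(acos(x) + bias) before caching the original cosine in x, so backprop scaled the gradient using the post-margin value; the fix caches x first. The redundant outputDimension parameter is also dropped end to end (it duplicates the weight's first dimension), and the distributed CPU loops and flag indexing are repaired.

Why the cached pre-margin cosine matters: the backprop factor is the derivative of the forward transform, assuming the standard ArcFace margin with x = cos(theta) and m = bias:

```latex
\frac{d}{dx}\cos(\arccos x + m)
  = \frac{\sin(\arccos x + m)}{\sqrt{1 - x^{2}}}
  = \cos m + \frac{x \sin m}{\sqrt{1 - x^{2}}}
```

This is exactly the factor cosBias + sinBias * x / (sqrt(1 - x * x) + 1e-12) applied in the backprop kernels below, where the 1e-12 guards against division by zero at x = ±1; it is only correct when x holds the cosine from before the margin was added.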
@@ -121,7 +121,7 @@ ArcMarginProductLayer {inputDimension, outputDimension, bias=0, init='glorotUnif
 {
     W = ParameterTensor {_ConcatArrays (outputDimension, inputDimension), init=init, initValueScale=initValueScale}

-    apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, outputDimension=outputDimension, bias=bias)
+    apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, bias=bias)
 }.apply

 # GlobalConcatLayer -- create a concat layer, which uses temporary global memory to save the output
@@ -4359,7 +4359,7 @@ namespace CNTK

     CNTK_API FunctionPtr AdditiveFullConnection(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, bool weightNormalize, double bias, bool annealBias, double biasBase, double biasGamma, double biasPower, double biasMin, double biasMax, const std::wstring& name = L"");

-    CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, double bias, const std::wstring& name = L"");
+    CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, double bias, const std::wstring& name = L"");

     CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");

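For context, a minimal call-site sketch of the revised API; the shapes, variable names, and the 0.5 margin are illustrative assumptions, not taken from this commit. The class count is no longer passed explicitly; it is implied by the weight parameter's first dimension.

```cpp
#include "CNTKLibrary.h"

// Hypothetical head construction: ArcMarginProduct is the function declared
// in the hunk above; the rest is standard CNTK V2 API.
CNTK::FunctionPtr BuildArcMarginHead()
{
    auto features = CNTK::InputVariable({ 512 }, CNTK::DataType::Float, L"features");
    auto labels   = CNTK::InputVariable({ 1 }, CNTK::DataType::Float, L"labels");
    // Weight of shape (numClasses, featureDim); the class count is now
    // inferred from this parameter instead of an outputDimension argument.
    auto weight = CNTK::Parameter({ 10000, 512 }, CNTK::DataType::Float,
                                  CNTK::GlorotUniformInitializer());
    return CNTK::ArcMarginProduct(features, labels, weight, /*bias=*/0.5);
}
```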
@@ -154,7 +154,6 @@ namespace CNTK
         CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasPower;
         CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMin;
         CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMax;
-        CNTK_API static const std::wstring AttributeArcMarginProductOutputDimension;
         CNTK_API static const std::wstring AttributeArcMarginProductBias;
         CNTK_API static const std::wstring AttributeCenterLossLambda;
         CNTK_API static const std::wstring AttributeCenterLossAlpha;
@@ -392,7 +392,6 @@ namespace CNTK
         else if (node->OperationName() == OperationNameOf(ArcMarginProductNode))
         {
             auto arcMarginProductNode = node->As<ArcMarginProductNode<ElementType>>();
-            primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension] = arcMarginProductNode->m_outputDimension;
             primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = arcMarginProductNode->m_bias;

             opType = PrimitiveOpType::ArcMarginProduct;
@@ -1222,9 +1222,8 @@ namespace CNTK
         }
         case PrimitiveOpType::ArcMarginProduct:
         {
-            auto outputDimension = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension].Value<size_t>();
             auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductBias].Value<double>();
-            ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, outputDimension, bias);
+            ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias);
             break;
         }
         case PrimitiveOpType::CenterLoss:
@@ -2215,11 +2215,10 @@ namespace CNTK
         return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::AdditiveFullConnection, operands, std::move(additionalProperties), name), name);
     }

-    FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, size_t outputDimension, double bias, const std::wstring& name)
+    FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, double bias, const std::wstring& name)
     {
         std::vector<Variable> operands = { features, labels, weight };
         auto additionalProperties = Dictionary();
-        additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension] = outputDimension;
         additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = bias;
         return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ArcMarginProduct, operands, std::move(additionalProperties), name), name);
     }
@@ -139,7 +139,6 @@ namespace CNTK
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"additiveFullConnectionBiasPower";
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"additiveFullConnectionBiasMin";
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"additiveFullConnectionBiasMax";
-    /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension = L"arcMarginProductOutputDimension";
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductBias = L"arcMarginProductBias";
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLambda = L"centerLossLambda";
     /*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossAlpha = L"centerLossAlpha";
@@ -544,9 +544,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Addit
 }

 template <class ElemType>
-shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType bias, const std::wstring nodeName)
+shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName)
 {
-    return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, outputDimension, bias), { a, b, c });
+    return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias), { a, b, c });
 }

 template <class ElemType>
@@ -209,7 +209,7 @@ public:
     ComputationNodePtr MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName = L"");
     ComputationNodePtr FeatureNormalize(const ComputationNodePtr a, size_t normalizeType, const std::wstring nodeName = L"");
     ComputationNodePtr AdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, bool weightNormalize, ElemType bias, bool annealBias, ElemType biasBase, ElemType biasGamma, ElemType biasPower, ElemType biasMin, ElemType biasMax, const std::wstring nodeName = L"");
-    ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType bias, const std::wstring nodeName = L"");
+    ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName = L"");
     ComputationNodePtr CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName = L"");
     ComputationNodePtr GlobalConcat(const ComputationNodePtr a, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring nodeName = L"");
     ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
@@ -918,15 +918,13 @@ public:
     void Save(File& fstream) const override
     {
         Base::Save(fstream);
-        fstream << m_bias;
-        fstream << m_scale;
+        fstream << m_bias << m_scale;
     }

     void Load(File& fstream, size_t modelVersion) override
     {
         Base::Load(fstream, modelVersion);
-        fstream >> m_bias;
-        fstream >> m_scale;
+        fstream >> m_bias >> m_scale;

         if (m_bias < 0 || m_bias >= m_PI)
             LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@@ -1693,13 +1691,13 @@ class ArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*

 public:
     ArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
-        : ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"outputDimension"), configp->Get(L"bias"))
+        : ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"))
     {
         AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
     }

-    ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputDimension = 0, double bias = 0.0)
-        : Base(deviceId, name), m_outputDimension(outputDimension), m_bias(bias)
+    ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0)
+        : Base(deviceId, name), m_bias(bias)
     {
         if (m_bias < 0 || m_bias >= m_PI)
             LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@@ -1711,11 +1709,10 @@ public:
     virtual void UpdateFunctionMBSize() override
     {
         m_minibatchSize = InputRef(0).Value().GetNumCols();

         m_label->Resize(1, m_minibatchSize);
         m_tempValue->Resize(1, m_minibatchSize);
         m_flag->Resize(1, m_minibatchSize);
-        m_weightMagnitude->Resize(m_outputDimension, 1); // Matrix(k,1)
+        m_weightMagnitude->Resize(InputRef(2).Value().GetNumRows(), 1);
     }

     virtual void BackpropToNonLooping(size_t inputIndex) override
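Input 2 here is the weight W, whose shape is (outputDimension, inputDimension) per the BrainScript layer above, so InputRef(2).Value().GetNumRows() recovers the same k that the removed m_outputDimension member stored; the dimension no longer needs to be configured, serialized, or copied alongside the node.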
@@ -1783,7 +1780,6 @@ public:
     {
         auto node = dynamic_pointer_cast<ArcMarginProductNode<ElemType>>(nodeP);
         node->m_minibatchSize = m_minibatchSize;
-        node->m_outputDimension = m_outputDimension;
         node->m_bias = m_bias;
         node->m_threshold = m_threshold;
         node->m_cosBias = m_cosBias;
@@ -1822,13 +1818,13 @@ public:
     void Save(File& fstream) const override
     {
         Base::Save(fstream);
-        fstream << m_outputDimension << m_bias;
+        fstream << m_bias;
     }

     void Load(File& fstream, size_t modelVersion) override
     {
         Base::Load(fstream, modelVersion);
-        fstream >> m_outputDimension >> m_bias;
+        fstream >> m_bias;

         if (m_bias < 0 || m_bias >= m_PI)
             LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@@ -1837,17 +1833,16 @@ public:
         m_sinBias = sin(m_bias);
     }

-    size_t m_minibatchSize; // m
-    size_t m_outputDimension; // k
+    size_t m_minibatchSize;
     double m_bias;
     double m_threshold;
     double m_cosBias;
     double m_sinBias;

-    shared_ptr<Matrix<ElemType>> m_label; // Matrix(1,m)
-    shared_ptr<Matrix<ElemType>> m_tempValue; // Matrix(1,m)
-    shared_ptr<Matrix<ElemType>> m_flag; // Matrix(1,m)
-    shared_ptr<Matrix<ElemType>> m_weightMagnitude; // Matrix(k,1)
+    shared_ptr<Matrix<ElemType>> m_label;
+    shared_ptr<Matrix<ElemType>> m_tempValue;
+    shared_ptr<Matrix<ElemType>> m_flag;
+    shared_ptr<Matrix<ElemType>> m_weightMagnitude;
 };

 // Implements Center-Loss as described in:
@@ -5450,8 +5450,8 @@ void CPUMatrix<ElemType>::ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType

         if (valuePtr[index] > threshold)
         {
-            valuePtr[index] = cos(acos(valuePtr[index]) + bias);
             xPtr[i] = valuePtr[index];
+            valuePtr[index] = cos(acos(valuePtr[index]) + bias);
         }
         else
         {
@@ -5478,7 +5478,7 @@ void CPUMatrix<ElemType>::ArcLabelAddBackprop(const CPUMatrix<ElemType>& label,
         {
             labelValue = static_cast<size_t>(labelPtr[i]);
             size_t index = i * outputDimension + labelValue;
-            gradientPtr[index] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
+            gradientPtr[index] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
         }
     }
 }
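The ordering fix in ArcLabelAdd is what makes this factor valid: x must hold the pre-margin cosine. A minimal standalone check, a sketch rather than part of the commit (the 0.5 margin is arbitrary), comparing the closed form against a central difference of the forward transform:

```cpp
#include <cmath>
#include <cstdio>

int main()
{
    const double bias = 0.5; // arbitrary margin in [0, PI)
    const double eps = 1e-6;
    for (int k = -3; k <= 3; ++k)
    {
        double x = 0.3 * k; // a pre-margin cosine, cos(theta)
        // central difference of the forward transform cos(acos(x) + bias)
        double numeric = (cos(acos(x + eps) + bias) - cos(acos(x - eps) + bias)) / (2 * eps);
        // closed form used by ArcLabelAddBackprop above
        double analytic = cos(bias) + sin(bias) * x / (sqrt(1 - x * x) + 1e-12);
        printf("x=% .1f  numeric=% .6f  analytic=% .6f\n", x, numeric, analytic);
    }
    return 0;
}
```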
@@ -5843,34 +5843,21 @@ void CPUMatrix<ElemType>::DistributedArcLabelAdd(const CPUMatrix<ElemType>& labe
     ElemType* xPtr = x.Data();
     ElemType* valuePtr = value.Data();

-    // four-way unrolling
-    for (long i = 0; i < (cols & ~3); i += 4)
-    {
-        long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
-        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && valuePtr[index] > threshold)
-        {
-            valuePtr[index] = cos(acos(valuePtr[index]) + bias);
-            xPtr[i] = valuePtr[index];
-        }
-        else
-        {
-            valuePtr[index] -= bias * sinBias;
-            flagPtr[i] = 1.0f;
-        }
-    }
-    // handle remaining stuffs
-    for (long i = cols & ~3; i < cols; i++)
-    {
-        long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
-        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && valuePtr[index] > threshold)
-        {
-            valuePtr[index] = cos(acos(valuePtr[index]) + bias);
-            xPtr[i] = valuePtr[index];
-        }
-        else
-        {
-            valuePtr[index] -= bias * sinBias;
-            flagPtr[i] = 1.0f;
-        }
-    }
+    for (long i = 0; i < cols; i++)
+    {
+        long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
+        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex)
+        {
+            if (valuePtr[index] > threshold)
+            {
+                xPtr[i] = valuePtr[index];
+                valuePtr[index] = cos(acos(valuePtr[index]) + bias);
+            }
+            else
+            {
+                valuePtr[index] -= bias * sinBias;
+                flagPtr[i] = 1.0f;
+            }
+        }
+    }
 }
@@ -5885,20 +5872,11 @@ void CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const CPUMatrix<ElemTyp
     ElemType* xPtr = x.Data();
     ElemType* gradientPtr = gradient.Data();

-    // four-way unrolling
-    for (long i = 0; i < (cols & ~3); i += 4)
-    {
-        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i * rows + ((long)labelsPtr[i]) - startIndex] < 0.5f)
-        {
-            gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
-        }
-    }
-    // handle remaining stuffs
-    for (long i = cols & ~3; i < cols; i++)
-    {
-        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i * rows + ((long)labelsPtr[i]) - startIndex] < 0.5f)
-        {
-            gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
-        }
-    }
+    for (long i = 0; i < cols; i++)
+    {
+        if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i] < 0.5f)
+        {
+            gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
+        }
+    }
 }
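Three repairs land in these two distributed CPU loops: the old "four-way unrolling" advanced i by 4 while touching only column i, so three of every four columns were skipped (and the tail loop shared the same ordering bug as ArcLabelAdd); the forward now caches x before applying the margin; and backprop reads the per-column flag as flagPtr[i], consistent with m_flag being resized to 1 x minibatchSize, rather than indexing it like the value matrix.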
@@ -3803,8 +3803,8 @@ __global__ void _arcLabelAdd(CUDA_LONG outputDimension, const ElemType* label, E

     if (value[index] > threshold)
     {
-        value[index] = cosf(acosf(value[index]) + bias);
         x[id] = value[index];
+        value[index] = cosf(acosf(value[index]) + bias);
     }
     else
     {
@@ -3837,7 +3837,7 @@ __global__ void _arcLabelAddBackprop(CUDA_LONG outputDimension, const ElemType*
     CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
     CUDA_LONG index = id * outputDimension + labelValue;

-    gradient[index] *= cosBias + x[id] * sinBias / (sqrtf(1 - x[id] * x[id]) + 1e-12);
+    gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
 }

 template <class ElemType>
@@ -4354,8 +4354,8 @@ __global__ void _distributedArcLabelAdd(const ElemType* labels, const ElemType t
     CUDA_LONG index = id * rows + label - startIndex;
     if (value[index] > threshold)
     {
-        value[index] = cosf(acosf(value[index]) + bias);
         x[id] = value[index];
+        value[index] = cosf(acosf(value[index]) + bias);
     }
     else
     {
@@ -4384,14 +4384,11 @@ __global__ void _distributedArcLabelAddBackprop(const ElemType* labels, const El
         return;

     CUDA_LONG label = (CUDA_LONG)labels[id];
-    if (label < startIndex || label > endIndex)
-        return;
-
-    if (flag[id] > 0.5)
+    if (label < startIndex || label > endIndex || flag[id] > 0.5f)
         return;

     CUDA_LONG index = id * rows + label - startIndex;
-    gradient[index] *= cosBias + x[id] * sinBias / (sqrtf(1 - x[id] * x[id]) + 1e-12);
+    gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
 }

 template <class ElemType>
@@ -2260,11 +2260,6 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
 {
 }

 template <class ElemType>
-void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType bias, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
-{
-}
-
-template <class ElemType>
 void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
 {