This commit is contained in:
Zhiming Mao 2019-06-20 16:43:35 +08:00
Parent 7055de854b
Commit 0d8f549c7c
13 changed files with 42 additions and 82 deletions

View file

@ -121,7 +121,7 @@ ArcMarginProductLayer {inputDimension, outputDimension, bias=0, init='glorotUnif
{
W = ParameterTensor {_ConcatArrays (outputDimension, inputDimension), init=init, initValueScale=initValueScale}
apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, outputDimension=outputDimension, bias=bias)
apply (labelSequence, outProbVectorSequence) = ArcMarginProduct(labelSequence, outProbVectorSequence, W, bias=bias)
}.apply
# GlobalConcatLayer -- create a concat layer, which uses temporary global memory to save the output

View file

@ -4359,7 +4359,7 @@ namespace CNTK
CNTK_API FunctionPtr AdditiveFullConnection(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, bool weightNormalize, double bias, bool annealBias, double biasBase, double biasGamma, double biasPower, double biasMin, double biasMax, const std::wstring& name = L"");
CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, size_t outputDimension, double bias, const std::wstring& name = L"");
CNTK_API FunctionPtr ArcMarginProduct(const Variable& prediction, const Variable& targets, const Variable& weight, double bias, const std::wstring& name = L"");
CNTK_API FunctionPtr CenterLoss(const Variable& prediction, const Variable& targets, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring& name = L"");

View file

@ -154,7 +154,6 @@ namespace CNTK
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasPower;
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMin;
CNTK_API static const std::wstring AttributeAdditiveFullConnectionBiasMax;
CNTK_API static const std::wstring AttributeArcMarginProductOutputDimension;
CNTK_API static const std::wstring AttributeArcMarginProductBias;
CNTK_API static const std::wstring AttributeCenterLossLambda;
CNTK_API static const std::wstring AttributeCenterLossAlpha;

View file

@ -392,7 +392,6 @@ namespace CNTK
else if (node->OperationName() == OperationNameOf(ArcMarginProductNode))
{
auto arcMarginProductNode = node->As<ArcMarginProductNode<ElementType>>();
primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension] = arcMarginProductNode->m_outputDimension;
primitiveFunctionConfigParameters[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = arcMarginProductNode->m_bias;
opType = PrimitiveOpType::ArcMarginProduct;

View file

@ -1222,9 +1222,8 @@ namespace CNTK
}
case PrimitiveOpType::ArcMarginProduct:
{
auto outputDimension = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension].Value<size_t>();
auto bias = functionConfig[PrimitiveFunctionAttribute::AttributeArcMarginProductBias].Value<double>();
ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, outputDimension, bias);
ASSIGN_NEW_NODE(ArcMarginProductNode, network->GetDeviceId(), internalNodeName, bias);
break;
}
case PrimitiveOpType::CenterLoss:

View file

@ -2215,11 +2215,10 @@ namespace CNTK
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::AdditiveFullConnection, operands, std::move(additionalProperties), name), name);
}
FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, size_t outputDimension, double bias, const std::wstring& name)
FunctionPtr ArcMarginProduct(const Variable& features, const Variable& labels, const Variable& weight, double bias, const std::wstring& name)
{
std::vector<Variable> operands = { features, labels, weight };
auto additionalProperties = Dictionary();
additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension] = outputDimension;
additionalProperties[PrimitiveFunctionAttribute::AttributeArcMarginProductBias] = bias;
return AsComposite(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::ArcMarginProduct, operands, std::move(additionalProperties), name), name);
}

View file

@ -139,7 +139,6 @@ namespace CNTK
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasPower = L"additiveFullConnectionBiasPower";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMin = L"additiveFullConnectionBiasMin";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeAdditiveFullConnectionBiasMax = L"additiveFullConnectionBiasMax";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductOutputDimension = L"arcMarginProductOutputDimension";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeArcMarginProductBias = L"arcMarginProductBias";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossLambda = L"centerLossLambda";
/*static*/ const std::wstring PrimitiveFunctionAttribute::AttributeCenterLossAlpha = L"centerLossAlpha";

View file

@ -544,9 +544,9 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Addit
}
template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType bias, const std::wstring nodeName)
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, outputDimension, bias), { a, b, c });
return net.AddNodeToNetAndAttachInputs(New<ArcMarginProductNode<ElemType>>(net.GetDeviceId(), nodeName, bias), { a, b, c });
}
template <class ElemType>

View file

@ -209,7 +209,7 @@ public:
ComputationNodePtr MarginInnerProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType base, ElemType gamma, ElemType power, ElemType lambdaMin, size_t marginCoefficient, const std::wstring nodeName = L"");
ComputationNodePtr FeatureNormalize(const ComputationNodePtr a, size_t normalizeType, const std::wstring nodeName = L"");
ComputationNodePtr AdditiveFullConnection(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, bool weightNormalize, ElemType bias, bool annealBias, ElemType biasBase, ElemType biasGamma, ElemType biasPower, ElemType biasMin, ElemType biasMax, const std::wstring nodeName = L"");
ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, size_t outputDimension, ElemType bias, const std::wstring nodeName = L"");
ComputationNodePtr ArcMarginProduct(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, ElemType bias, const std::wstring nodeName = L"");
ComputationNodePtr CenterLoss(const ComputationNodePtr a, const ComputationNodePtr b, double lambda, double alpha, size_t labelDim, bool normalize, const std::wstring nodeName = L"");
ComputationNodePtr GlobalConcat(const ComputationNodePtr a, size_t blockIndex, size_t growthRate, size_t segmentIndex, size_t segmentNum, const std::wstring nodeName = L"");
ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");

View file

@ -918,15 +918,13 @@ public:
void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_bias;
fstream << m_scale;
fstream << m_bias << m_scale;
}
void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_bias;
fstream >> m_scale;
fstream >> m_bias >> m_scale;
if (m_bias < 0 || m_bias >= m_PI)
LogicError("DistributedArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@ -1693,13 +1691,13 @@ class ArcMarginProductNode : public ComputationNodeNonLooping /*ComputationNode*
public:
ArcMarginProductNode(const ScriptableObjects::IConfigRecordPtr configp)
: ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"outputDimension"), configp->Get(L"bias"))
: ArcMarginProductNode(configp->Get(L"deviceId"), L"<placeholder>", configp->Get(L"bias"))
{
AttachInputsFromConfig(configp, this->GetExpectedNumInputs());
}
ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, size_t outputDimension = 0, double bias = 0.0)
: Base(deviceId, name), m_outputDimension(outputDimension), m_bias(bias)
ArcMarginProductNode(DEVICEID_TYPE deviceId, const wstring& name, double bias = 0.0)
: Base(deviceId, name), m_bias(bias)
{
if (m_bias < 0 || m_bias >= m_PI)
LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@ -1711,11 +1709,10 @@ public:
virtual void UpdateFunctionMBSize() override
{
m_minibatchSize = InputRef(0).Value().GetNumCols();
m_label->Resize(1, m_minibatchSize);
m_tempValue->Resize(1, m_minibatchSize);
m_flag->Resize(1, m_minibatchSize);
m_weightMagnitude->Resize(m_outputDimension, 1); // Matrix(k,1)
m_weightMagnitude->Resize(InputRef(2).Value().GetNumRows(), 1);
}
virtual void BackpropToNonLooping(size_t inputIndex) override
@ -1783,7 +1780,6 @@ public:
{
auto node = dynamic_pointer_cast<ArcMarginProductNode<ElemType>>(nodeP);
node->m_minibatchSize = m_minibatchSize;
node->m_outputDimension = m_outputDimension;
node->m_bias = m_bias;
node->m_threshold = m_threshold;
node->m_cosBias = m_cosBias;
@ -1822,13 +1818,13 @@ public:
void Save(File& fstream) const override
{
Base::Save(fstream);
fstream << m_outputDimension << m_bias;
fstream << m_bias;
}
void Load(File& fstream, size_t modelVersion) override
{
Base::Load(fstream, modelVersion);
fstream >> m_outputDimension >> m_bias;
fstream >> m_bias;
if (m_bias < 0 || m_bias >= m_PI)
LogicError("ArcMarginProductNode: bias(%.8g) not in range [0, PI)", m_bias);
@ -1837,17 +1833,16 @@ public:
m_sinBias = sin(m_bias);
}
size_t m_minibatchSize; // m
size_t m_outputDimension; // k
size_t m_minibatchSize;
double m_bias;
double m_threshold;
double m_cosBias;
double m_sinBias;
shared_ptr<Matrix<ElemType>> m_label; // Matrix(1,m)
shared_ptr<Matrix<ElemType>> m_tempValue; // Matrix(1,m)
shared_ptr<Matrix<ElemType>> m_flag; // Matrix(1,m)
shared_ptr<Matrix<ElemType>> m_weightMagnitude; // Matrix(k,1)
shared_ptr<Matrix<ElemType>> m_label;
shared_ptr<Matrix<ElemType>> m_tempValue;
shared_ptr<Matrix<ElemType>> m_flag;
shared_ptr<Matrix<ElemType>> m_weightMagnitude;
};
// Implements Center-Loss as described in:

View file

@ -5450,8 +5450,8 @@ void CPUMatrix<ElemType>::ArcLabelAdd(const CPUMatrix<ElemType>& label, ElemType
if (valuePtr[index] > threshold)
{
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
xPtr[i] = valuePtr[index];
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
}
else
{
@ -5478,7 +5478,7 @@ void CPUMatrix<ElemType>::ArcLabelAddBackprop(const CPUMatrix<ElemType>& label,
{
labelValue = static_cast<size_t>(labelPtr[i]);
size_t index = i * outputDimension + labelValue;
gradientPtr[index] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
gradientPtr[index] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
}
}
}
@ -5843,34 +5843,21 @@ void CPUMatrix<ElemType>::DistributedArcLabelAdd(const CPUMatrix<ElemType>& labe
ElemType* xPtr = x.Data();
ElemType* valuePtr = value.Data();
// four-way unrolling
for (long i = 0; i < (cols & ~3); i += 4)
for (long i = 0; i < cols; i += 4)
{
long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && valuePtr[index] > threshold)
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex)
{
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
xPtr[i] = valuePtr[index];
}
else
{
valuePtr[index] -= bias * sinBias;
flagPtr[i] = 1.0f;
}
}
// handle remaining stuffs
for (long i = cols & ~3; i < cols; i++)
{
long index = i * rows + ((long)labelsPtr[i]) - (long)startIndex;
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && valuePtr[index] > threshold)
{
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
xPtr[i] = valuePtr[index];
}
else
{
valuePtr[index] -= bias * sinBias;
flagPtr[i] = 1.0f;
if (valuePtr[index] > threshold)
{
xPtr[i] = valuePtr[index];
valuePtr[index] = cos(acos(valuePtr[index]) + bias);
}
else
{
valuePtr[index] -= bias * sinBias;
flagPtr[i] = 1.0f;
}
}
}
}
@ -5885,20 +5872,11 @@ void CPUMatrix<ElemType>::DistributedArcLabelAddBackprop(const CPUMatrix<ElemTyp
ElemType* xPtr = x.Data();
ElemType* gradientPtr = gradient.Data();
// four-way unrolling
for (long i = 0; i < (cols & ~3); i += 4)
for (long i = 0; i < cols; i += 4)
{
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i * rows + ((long)labelsPtr[i]) - startIndex] < 0.5f)
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i] < 0.5f)
{
gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
}
}
// handle remaining stuffs
for (long i = cols & ~3; i < cols; i++)
{
if (labelsPtr[i] >= startIndex && labelsPtr[i] <= endIndex && flagPtr[i * rows + ((long)labelsPtr[i]) - startIndex] < 0.5f)
{
gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + xPtr[i] * sinBias / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
gradientPtr[i * rows + ((long)labelsPtr[i]) - startIndex] *= cosBias + sinBias * xPtr[i] / (sqrt(1 - xPtr[i] * xPtr[i]) + 1e-12);
}
}
}

View file

@ -3803,8 +3803,8 @@ __global__ void _arcLabelAdd(CUDA_LONG outputDimension, const ElemType* label, E
if (value[index] > threshold)
{
value[index] = cosf(acosf(value[index]) + bias);
x[id] = value[index];
value[index] = cosf(acosf(value[index]) + bias);
}
else
{
@ -3837,7 +3837,7 @@ __global__ void _arcLabelAddBackprop(CUDA_LONG outputDimension, const ElemType*
CUDA_LONG labelValue = static_cast<CUDA_LONG>(label[id]);
CUDA_LONG index = id * outputDimension + labelValue;
gradient[index] *= cosBias + x[id] * sinBias / (sqrtf(1 - x[id] * x[id]) + 1e-12);
gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}
template <class ElemType>
@ -4354,8 +4354,8 @@ __global__ void _distributedArcLabelAdd(const ElemType* labels, const ElemType t
CUDA_LONG index = id * rows + label - startIndex;
if (value[index] > threshold)
{
value[index] = cosf(acosf(value[index]) + bias);
x[id] = value[index];
value[index] = cosf(acosf(value[index]) + bias);
}
else
{
@ -4384,14 +4384,11 @@ __global__ void _distributedArcLabelAddBackprop(const ElemType* labels, const El
return;
CUDA_LONG label = (CUDA_LONG)labels[id];
if (label < startIndex || label > endIndex)
return;
if (flag[id] > 0.5)
if (label < startIndex || label > endIndex || flag[id] > 0.5f)
return;
CUDA_LONG index = id * rows + label - startIndex;
gradient[index] *= cosBias + x[id] * sinBias / (sqrtf(1 - x[id] * x[id]) + 1e-12);
gradient[index] *= cosBias + sinBias * x[id] / (sqrtf(1 - x[id] * x[id]) + 1e-12);
}
template <class ElemType>

View file

@ -2260,11 +2260,6 @@ void GPUMatrix<ElemType>::DistributedLabelAdd(const GPUMatrix<ElemType>& labels,
{
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType bias, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{
}
template <class ElemType>
void GPUMatrix<ElemType>::DistributedArcLabelAdd(const GPUMatrix<ElemType>& labels, ElemType threshold, ElemType bias, ElemType sinBias, const GPUMatrix<ElemType>& flag, const GPUMatrix<ElemType>& x, const GPUMatrix<ElemType>& value, size_t startIndex, size_t endIndex)
{