CNTK v2 library: Add binary CrossEntropy operator
Parent: 2ca454bd84
Commit: e2cf02a609
@@ -2699,6 +2699,16 @@ namespace CNTK
         return TransposeTimes(leftOperand, rightOperand, /*outputRank =*/ 1, name);
     }
 
+    ///
+    /// Create an instance of the CNTK built-in operation to compute binary cross-entropy for specified input operands.
+    ///
+    CNTK_API FunctionPtr BinaryCrossEntropy(const Variable& prediction, const Variable& targets, const std::wstring& name = L"");
+
+    ///
+    /// Create an instance of the CNTK built-in operation to compute weighted binary cross-entropy for specified input operands.
+    ///
+    CNTK_API FunctionPtr WeightedBinaryCrossEntropy(const Variable& prediction, const Variable& targets, const Variable& weights, const std::wstring& name = L"");
+
     ///
     /// Create an instance of the CNTK built-in operation to compute squared-error for specified input operands.
     ///
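The new declarations slot in beside SquaredError. A minimal usage sketch of the added API (hedged: the shapes, variable names, and the Sigmoid-on-input model are illustrative assumptions, not part of this commit):

#include "CNTKLibrary.h"
using namespace CNTK;

// Illustrative wiring of the new losses (shapes/names are assumptions).
FunctionPtr MakeBceLoss()
{
    auto prediction = Sigmoid(InputVariable({ 1 }, DataType::Float, L"netOutput"));
    auto targets = InputVariable({ 1 }, DataType::Float, L"targets");  // ground truth, 0 or 1
    return BinaryCrossEntropy(prediction, targets, L"bce");
}

FunctionPtr MakeWeightedBceLoss()
{
    auto prediction = Sigmoid(InputVariable({ 1 }, DataType::Float, L"netOutput"));
    auto targets = InputVariable({ 1 }, DataType::Float, L"targets");
    auto weights = InputVariable({ 1 }, DataType::Float, L"weights");  // per-example weight
    return WeightedBinaryCrossEntropy(prediction, targets, weights, L"wbce");
}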
@@ -232,6 +232,8 @@ namespace CNTK
             primitiveFunctionConfigParameters[PrimitiveFunction::AttributeNameOffset] = (size_t)node->As<FutureValueNode<ElementType>>()->TimeStep();
             opType = PrimitiveOpType::FutureValue;
         }
+        else if (node->OperationName() == OperationNameOf(LogisticNode))
+            opType = PrimitiveOpType::Logistic;
         else if (node->OperationName() == OperationNameOf(SquareErrorNode))
             opType = PrimitiveOpType::SquaredError;
         else if (node->OperationName() == OperationNameOf(CrossEntropyWithSoftmaxNode))
@@ -634,10 +634,16 @@ namespace CNTK
             if (outputDataType == DataType::Unknown)
                 outputDataType = firstKnownInputDataType;
 
-            // We currently require that the inputs' dynamic axes if any match
+            // We currently require that the inputs' dynamic axes, if any, match
             std::vector<Axis> outputDynamicAxes;
-            if ((op == PrimitiveOpType::SumAll) || (op == PrimitiveOpType::SquaredError) || (op == PrimitiveOpType::CrossEntropyWithSoftmax) || (op == PrimitiveOpType::ClassificationError))
+            if ((op == PrimitiveOpType::SumAll) ||
+                (op == PrimitiveOpType::SquaredError) ||
+                (op == PrimitiveOpType::CrossEntropyWithSoftmax) ||
+                (op == PrimitiveOpType::ClassificationError) ||
+                (op == PrimitiveOpType::Logistic))
+            {
                 outputDynamicAxes = std::vector<Axis>({});
+            }
             else if (op == PrimitiveOpType::Where)
             {
                 if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewDynamicAxes))
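Logistic joins the ops whose output carries no dynamic axes: these losses reduce over all samples in the minibatch, so the result is a single aggregate rather than a per-sample value. A quick check of that property (sketch; reuses the hypothetical MakeBceLoss helper from the first example):

#include <cassert>

void CheckLossShape()
{
    FunctionPtr loss = MakeBceLoss();                 // illustrative helper sketched above
    assert(loss->Output().DynamicAxes().empty());     // no batch/sequence axis left
    assert(loss->Output().Shape().TotalSize() == 1);  // scalar aggregate
}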
@@ -889,9 +895,9 @@ namespace CNTK
         case PrimitiveOpType::Convolution:
         {
             assert(inputs.size() == 2);
             auto& strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
             auto& lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
             auto& upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
             auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
             auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
             bool transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
@@ -900,23 +906,24 @@ namespace CNTK
 
             NDShape outputMapCount, kernelShape;
             std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(inputs[0].Shape(), inputs[1].Shape());
             auto originalKernelShape = kernelShape;
             outputShape = ConvolutionOpOutputShape(op, inputs[1].Shape(), kernelShape, outputMapCount, strides, sharing, autoPadding, lowerPad, upperPad, transpose, inferDimensions);
             if (originalKernelShape != kernelShape)
             {
                 for (size_t i = 0; i < kernelShape.Rank(); ++i)
                     inputs[0].m_dataFields->m_shape[i] = kernelShape[i];
             }
 
             functionConfig[PrimitiveFunction::AttributeNameSharing] = AsDictionaryValueVector(sharing);
             functionConfig[PrimitiveFunction::AttributeNameAutoPadding] = AsDictionaryValueVector(autoPadding);
             break;
         }
+        case PrimitiveOpType::Logistic:
         case PrimitiveOpType::SquaredError:
         case PrimitiveOpType::CrossEntropyWithSoftmax:
         case PrimitiveOpType::ClassificationError:
         {
-            if (op == PrimitiveOpType::ClassificationError)
+            if ((op == PrimitiveOpType::ClassificationError) || (op == PrimitiveOpType::Logistic))
                 assert(inputs.size() >= 2);
             else
                 assert(inputs.size() == 2);
@@ -929,9 +936,9 @@ namespace CNTK
             if (predictionShape != labelsShape)
                 RuntimeError("Prediction output operand's shape %S is incompatible with label operand's shape %S for the %S operation", AsStringForErrorReporting(predictionShape).c_str(), AsStringForErrorReporting(labelsShape).c_str(), PrimitiveOpTypeName(op).c_str());
 
             std::vector<int> reductionAxes;
             for (int i = 0; i < (int)inputs[0].Shape().Rank(); ++i)
                 reductionAxes.push_back(i);
 
             outputShape = ReductionOpOutputShape(op, predictionShape, reductionAxes, /*preserveReductionAxes =*/ false);
             break;
@@ -1613,6 +1620,9 @@ namespace CNTK
                 computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples);
                 break;
             }
+            case PrimitiveOpType::Logistic:
+                computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
+                break;
             case PrimitiveOpType::SquaredError:
                 computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                 break;
@@ -2792,6 +2802,18 @@ namespace CNTK
         return BinaryOp(PrimitiveOpType::TransposeTimes, leftOperand, rightOperand, std::move(additionalProperties), name);
     }
 
+    FunctionPtr BinaryCrossEntropy(const Variable& prediction, const Variable& targets, const std::wstring& name)
+    {
+        std::vector<Variable> operands = { prediction, targets };
+        return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Logistic, operands, Dictionary(), name), name);
+    }
+
+    FunctionPtr WeightedBinaryCrossEntropy(const Variable& prediction, const Variable& targets, const Variable& weights, const std::wstring& name)
+    {
+        std::vector<Variable> operands = { prediction, targets, weights };
+        return CompositeFunction::Create(MakeSharedObject<PrimitiveFunction>(PrimitiveOpType::Logistic, operands, Dictionary(), name), name);
+    }
+
     FunctionPtr SquaredError(const Variable& prediction, const Variable& targets, const std::wstring& name)
     {
         auto difference = Minus(prediction, targets);
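Both wrappers build a composite around the Logistic primitive; the weighted form simply passes a third operand. For reference, the quantity the op evaluates per example is the standard (optionally weighted) binary cross-entropy. A plain scalar sketch of that formula, not the LogisticNode implementation:

#include <cmath>

// Reference semantics only: p is the predicted probability in (0,1),
// t is the 0/1 target, w an optional per-example weight.
double BinaryCrossEntropyRef(double p, double t, double w = 1.0)
{
    return -w * (t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
}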
@@ -65,7 +65,7 @@ namespace CNTK
         {PrimitiveOpType::Times, L"Times"},
         {PrimitiveOpType::TransposeTimes, L"TransposeTimes"},
         {PrimitiveOpType::Convolution, L"Convolution"},
-        {PrimitiveOpType::SquaredError, L"SquaredError"},
+        { PrimitiveOpType::SquaredError, L"SquaredError" },
         {PrimitiveOpType::CrossEntropyWithSoftmax, L"CrossEntropyWithSoftmax"},
         {PrimitiveOpType::ClassificationError, L"ClassificationError"},
         {PrimitiveOpType::PastValue, L"PastValue"},
@@ -79,6 +79,7 @@ namespace CNTK
         {PrimitiveOpType::RandomSample, L"RandomSample"},
         {PrimitiveOpType::RandomSampleInclusionFrequency, L"RandomSampleInclusionFrequency"},
         {PrimitiveOpType::ROIPooling, L"ROIPooling"},
+        { PrimitiveOpType::Logistic, L"Logistic" },
     };
 
     inline const std::wstring& PrimitiveOpTypeName(PrimitiveOpType opType)
@@ -103,7 +104,15 @@ namespace CNTK
             if (numFunctionInputs > 2)
                 indexMap.insert({2, 2});
         }
-        else if ((op == PrimitiveOpType::CrossEntropyWithSoftmax) || (op == PrimitiveOpType::GatherPacked))
+        else if (op == PrimitiveOpType::Logistic)
+        {
+            indexMap = std::unordered_map<size_t, size_t>({ { 0, 1 }, { 1, 0 } });
+            if (numFunctionInputs > 2)
+                indexMap.insert({ 2, 2 });
+        }
+        else if (op == PrimitiveOpType::CrossEntropyWithSoftmax)
+            indexMap = std::unordered_map<size_t, size_t>({ { 0, 1 }, { 1, 0 } });
+        else if (op == PrimitiveOpType::GatherPacked)
             indexMap = std::unordered_map<size_t, size_t>({ { 0, 1 }, { 1, 0 } });
         else if (op == PrimitiveOpType::ScatterPacked)
             indexMap = std::unordered_map<size_t, size_t>({ { 0, 2 }, { 1, 1 }, { 2, 0 } });
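The index map handles operand ordering for the Logistic case: the V2 function takes (prediction, targets[, weights]) while the legacy node wants the label operand first, so indices 0 and 1 swap and the optional third operand passes through unchanged. A self-contained sketch of applying such a map (the mapping direction here is an assumption for illustration):

#include <unordered_map>
#include <vector>

// Illustrative only: route V2 operand i to legacy node input indexMap[i].
std::vector<int> RemapOperands(const std::vector<int>& v2Operands)
{
    std::unordered_map<size_t, size_t> indexMap({ { 0, 1 }, { 1, 0 }, { 2, 2 } });
    std::vector<int> nodeInputs(v2Operands.size());
    for (const auto& kv : indexMap)
        if (kv.first < v2Operands.size())
            nodeInputs[kv.second] = v2Operands[kv.first];
    return nodeInputs;
}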
@@ -57,6 +57,7 @@ namespace CNTK
         RandomSample = 45,
         RandomSampleInclusionFrequency = 46,
         ROIPooling = 47,
+        Logistic = 48,
         // New op types should only be appended to the end of this list.
     };
 }
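The append-only rule presumably exists because these op codes are persisted numerically in saved models, so reordering would silently change the meaning of existing files. A sketch of the kind of guard that expresses this (hypothetical check, not in the commit):

// Hypothetical guard: a serialized op code of 48 must keep decoding as Logistic.
static_assert(static_cast<unsigned int>(PrimitiveOpType::Logistic) == 48,
              "PrimitiveOpType values are persisted; append, never reorder");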
@@ -395,7 +395,7 @@ public:
     // If input data is sparse, then gradient is block sparse.
     if (InputRef(1).Value().GetMatrixType() == SPARSE && InputRef(0).Gradient().GetMatrixType() == DENSE && Gradient().GetMatrixType() == DENSE)
     {
-        // We need a sparse matrix for the gradient. However, we should allocate a new one instead of switching the type in place
+        // We need a sparse matrix for the gradient. We allocate a new one instead of switching the type in place
         // since switching in place may affect other nodes who share this matrix due to memory sharing
         auto& currentInput0GradientMatrixRef = InputRef(0).Gradient();
         auto newInput0SparseGradientMatrix = std::make_shared<Matrix<ElemType>>(currentInput0GradientMatrixRef.GetNumRows(),
@@ -556,7 +556,7 @@ public:
     {
         Input(0)->CreateGradientMatrixIfNull();
 
-        // We need a sparse matrix for the gradient. However, we should allocate a new one instead of switching the type in place
+        // We need a sparse matrix for the gradient. We allocate a new one instead of switching the type in place
         // since switching in place may affect other nodes who share this matrix due to memory sharing
         auto& currentInput0GradientMatrixRef = InputRef(0).Gradient();
         if (currentInput0GradientMatrixRef.GetMatrixType() != SPARSE)
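The reworded comments in these two hunks describe the same pattern: the memory-sharing planner can hand one Matrix object to several nodes, so converting a shared matrix's type in place could corrupt other readers; the code rebinds the node to a freshly allocated sparse matrix instead. Schematically (names are illustrative; SetGradientMatrix is a hypothetical setter standing in for the real rebinding call):

// Rebind, don't convert: other nodes may alias the old matrix.
auto newSparseGradient = std::make_shared<Matrix<ElemType>>(
    currentInput0GradientMatrixRef.GetNumRows(),
    currentInput0GradientMatrixRef.GetNumCols(),
    currentInput0GradientMatrixRef.GetDeviceId(),
    SPARSE, matrixFormatSparseBlockCol);
InputRef(0).SetGradientMatrix(newSparseGradient);  // hypothetical setter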
@@ -126,7 +126,7 @@ void RandomSampleNode<ElemType>::ForwardPropNonLooping()
     if (ValueAsMatrix().GetMatrixType() != SPARSE)
     {
         // BUGBUG: matrix type should be configured during validation
-        // We should allocate a new one instead of switching the type in place since switching in place may
+        // Note: We allocate a new one instead of switching the type in place since switching in place may
         // affect other nodes who share this matrix due to memory sharing
         auto newSparseValueMatrix = std::make_shared<Matrix<ElemType>>(ValueAsMatrix().GetNumRows(), ValueAsMatrix().GetNumCols(), CPUDEVICE, SPARSE, matrixFormatSparseCSC);
 #ifdef _MSC_VER
@@ -140,10 +140,7 @@ void RandomSampleNode<ElemType>::ForwardPropNonLooping()
 
     // TODO: Should we prepare the CSC data directly on the CPU and move it in one go?
     // Currently the reader will place the data onto the GPU. It will then be pulled on-demand to the CPU once (and cached there).
-    valueMatrix.TransferToDeviceIfNotThere(CPUDEVICE, /*ismoved =*/ true/*means: BOTH state not ok */, /*emptyTransfer =*/ true, /*updatePreferredDevice =*/ false);
-
-    // BUGUBUG: This is a no-op; was the intent to change the preferred device to CPU?
-    valueMatrix.SetDevice(CPUDEVICE);
+    valueMatrix.TransferToDeviceIfNotThere(CPUDEVICE, /*ismoved =*/ true/*means: BOTH state not ok */, /*emptyTransfer =*/ true, /*updatePreferredDevice =*/ true);
     valueMatrix.Reset();
 
     // Get vector with indices of randomly sampled classes
@@ -59,9 +59,52 @@ def alias(x, name=''):
     return alias(x, name)
 
 ##########################################################################
-# evaluation ops
+# loss and evaluation ops
 ##########################################################################
 
+@typemap
+def binary_cross_entropy(output, target, name=''):
+    r'''
+    This operation computes the binary cross entropy between the ``output`` and ``target``.
+
+    Example:
+        TBA
+
+    Args:
+        output: the computed posterior probability from the network
+        target: ground-truth label, 0 or 1
+        name (`str`, optional): the name of the Function instance in the network
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import binary_cross_entropy
+    dtype = get_data_type(output, target)
+    output = sanitize_input(output, dtype)
+    target = sanitize_input(target, dtype)
+    return binary_cross_entropy(output, target, name)
+
+@typemap
+def weighted_binary_cross_entropy(output, target, weight, name=''):
+    r'''
+    This operation computes the weighted binary cross entropy between the ``output`` and ``target``.
+
+    Example:
+        TBA
+
+    Args:
+        output: the computed posterior probability from the network
+        target: ground-truth label, 0 or 1
+        weight: weight of each example
+        name (`str`, optional): the name of the Function instance in the network
+    Returns:
+        :class:`cntk.ops.functions.Function`
+    '''
+    from cntk.cntk_py import weighted_binary_cross_entropy
+    dtype = get_data_type(output, target, weight)
+    output = sanitize_input(output, dtype)
+    target = sanitize_input(target, dtype)
+    weight = sanitize_input(weight, dtype)
+    return weighted_binary_cross_entropy(output, target, weight, name)
+
 @typemap
 def cross_entropy_with_softmax(output_vector, target_vector, axis=-1, name=''):
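These Python wrappers delegate through the generated cntk_py module to the C++ entry points added above, so binary_cross_entropy(output, target) builds the same Logistic composite as the C++ API. Sticking to C++ for consistency with the earlier sketches, the equivalent call is:

// Equivalent of the Python op binary_cross_entropy(output, target, name='bce'),
// with prediction/targets as in the first sketch.
FunctionPtr loss = BinaryCrossEntropy(prediction, targets, L"bce");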