Add SoftMax and GELU activation functions to DMLX (#573)
Parent: 0bd9f4f0c7
Commit: d3918ea66a
@@ -959,6 +959,13 @@ namespace dml
         {
             return FusedActivation(DML_OPERATOR_ACTIVATION_CELU, alpha);
         }
+
+#if DML_TARGET_VERSION >= 0x5100
+        static FusedActivation Gelu()
+        {
+            return FusedActivation(DML_OPERATOR_ACTIVATION_GELU);
+        }
+#endif // DML_TARGET_VERSION >= 0x5100
     };

     // Implementation detail helper for determining if a list of expressions share the same GraphBuilder.
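The new Gelu() factory lets GELU be fused into any DMLX operator that accepts a FusedActivation parameter. A minimal usage sketch follows; it is not part of this commit, and the graph setup, tensor shapes, and ConvolutionBuilder chain are illustrative assumptions about a typical DirectMLX call site (requires DML_TARGET_VERSION >= 0x5100).

    #include "DirectMLX.h"

    // Sketch: fuse GELU into a convolution via the new factory method.
    // Shapes and builder usage are illustrative, not taken from the commit.
    dml::Expression BuildFusedConv(dml::Graph& graph)
    {
        auto input  = dml::InputTensor(graph, 0,
            dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 8, 32, 32 }));
        auto filter = dml::InputTensor(graph, 1,
            dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 16, 8, 3, 3 }));

        // GELU executes as part of the convolution rather than as a separate graph node.
        return dml::ConvolutionBuilder(input, filter)
            .FusedActivation(dml::FusedActivation::Gelu())
            .Build();
    }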
@@ -1855,6 +1862,18 @@ namespace dml
         DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTMAX);
     }

+#if DML_TARGET_VERSION >= 0x5100
+    inline Expression ActivationSoftmax(Expression input, Span<const uint32_t> axes)
+    {
+        DMLX_ACTIVATION_IMPL_2(ACTIVATION_SOFTMAX1, AxisCount, static_cast<uint32_t>(axes.size()), Axes, axes.data());
+    }
+
+    inline Expression ActivationGelu(Expression input)
+    {
+        DMLX_ACTIVATION_IMPL(ACTIVATION_GELU);
+    }
+#endif
+
     inline Expression ActivationSoftplus(Expression input, float steepness = 1.0f)
     {
         DMLX_ACTIVATION_IMPL_1(ACTIVATION_SOFTPLUS, Steepness, steepness);
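Behind the same version guard, both operators are also exposed as standalone graph nodes. Below is a hedged sketch of how they might be invoked; the graph setup, tensor shape, choice of reduction axis, and Span construction are illustrative assumptions, not part of the commit.

    #include "DirectMLX.h"

    // Sketch: calling the new standalone SOFTMAX1 and GELU operators.
    dml::Expression BuildActivations(dml::Graph& graph)
    {
        auto logits = dml::InputTensor(graph, 0,
            dml::TensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, { 1, 1, 4, 128 }));

        // SOFTMAX1 generalizes softmax to an arbitrary set of reduction axes;
        // here the reduction runs over the last dimension only.
        const uint32_t axes[] = { 3 };
        auto probabilities = dml::ActivationSoftmax(logits, dml::Span<const uint32_t>(axes, 1));

        // Element-wise GELU over the same input tensor.
        auto gelu = dml::ActivationGelu(logits);

        return probabilities + gelu; // Expressions compose via operator overloads.
    }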