Add SoftMax and GELU activation functions to DMLX (#573)

This commit is contained in:
Patrice Vignola 2024-04-11 14:48:39 -07:00 коммит произвёл GitHub
Родитель 0bd9f4f0c7
Коммит d3918ea66a
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 19 добавлений и 0 удалений

Просмотреть файл

@ -959,6 +959,13 @@ namespace dml
{
return FusedActivation(DML_OPERATOR_ACTIVATION_CELU, alpha);
}
#if DML_TARGET_VERSION >= 0x5100
static FusedActivation Gelu()
{
return FusedActivation(DML_OPERATOR_ACTIVATION_GELU);
}
#endif // DML_TARGET_VERSION >= 0x5100
};
// Implementation detail helper for determining if a list of expressions share the same GraphBuilder.
@ -1855,6 +1862,18 @@ namespace dml
DMLX_ACTIVATION_IMPL(ACTIVATION_SOFTMAX);
}
#if DML_TARGET_VERSION >= 0x5100
inline Expression ActivationSoftmax(Expression input, Span<const uint32_t> axes)
{
DMLX_ACTIVATION_IMPL_2(ACTIVATION_SOFTMAX1, AxisCount, static_cast<uint32_t>(axes.size()), Axes, axes.data());
}
inline Expression ActivationGelu(Expression input)
{
DMLX_ACTIVATION_IMPL(ACTIVATION_GELU);
}
#endif
inline Expression ActivationSoftplus(Expression input, float steepness = 1.0f)
{
DMLX_ACTIVATION_IMPL_1(ACTIVATION_SOFTPLUS, Steepness, steepness);