fast tanh (#3255)

Parent: 29b0b4c11d
Commit: 165aa0dbbd
@@ -31,6 +31,7 @@
 #include "tvm/tvm.h"
 #include "tvm/ir.h"
 #include "tvm/ir_pass.h"
+#include "broadcast.h"

 namespace topi {
 using namespace tvm;
@@ -46,7 +47,6 @@ using namespace tvm;
 }

 TOPI_DECLARE_UNARY_OP(exp);
-TOPI_DECLARE_UNARY_OP(tanh);
 TOPI_DECLARE_UNARY_OP(sigmoid);
 TOPI_DECLARE_UNARY_OP(sqrt);
 TOPI_DECLARE_UNARY_OP(log);
@@ -56,6 +56,74 @@ TOPI_DECLARE_UNARY_OP(round);
 TOPI_DECLARE_UNARY_OP(trunc);
 TOPI_DECLARE_UNARY_OP(abs);

+/*
+ * \brief Fast_tanh_float implementation from Eigen
+ * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
+ */
+inline Tensor fast_tanh_float(const Tensor& in,
+                              std::string name,
+                              std::string tag) {
+  // Clamp the inputs to the range [-9, 9] since anything outside
+  // this range is +/-1.0f in single-precision.
+  auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));
+
+  // The monomial coefficients of the numerator polynomial (odd).
+  auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
+  auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
+  auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
+  auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
+  auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
+  auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
+  auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);
+
+  // The monomial coefficients of the denominator polynomial (even).
+  auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
+  auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
+  auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
+  auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);
+
+  return compute(x->shape,
+                 [&](const Array<Var>& i) {
+                   auto x2 = x(i) * x(i);
+                   auto p = x2 * alpha_13 + alpha_11;
+                   p = x2 * p + alpha_9;
+                   p = x2 * p + alpha_7;
+                   p = x2 * p + alpha_5;
+                   p = x2 * p + alpha_3;
+                   p = x2 * p + alpha_1;
+                   p = x(i) * p;
+
+                   auto q = x2 * beta_6 + beta_4;
+                   q = x2 * q + beta_2;
+                   q = x2 * q + beta_0;
+                   return p / q;
+                 },
+                 name, tag);
+}
+
+/*!
+ * \brief Creates an operation that returns the hyperbolic tangent of a given tensor
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is tanh
+ */
+inline Tensor tanh(const Tensor& x,
+                   std::string name = "T_tanh",
+                   std::string tag = kElementWise) {
+  if (x->dtype == Float(32)) {
+    // invoke fast_tanh_float implementation
+    return fast_tanh_float(x, name, tag);
+  } else {
+    // fallback to default implementation
+    return compute(x->shape, [&](const Array<Var>& i) {
+      return ::tvm::tanh(x(i));
+    }, name, tag);
+  }
+}
+
 /*!
  * \brief Creates an operation that returns identity of a given tensor
  *
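The kernel added above implements Eigen's rational approximation tanh(x) ~= p(x) / q(x), with a degree-13 odd numerator and a degree-6 even denominator, both evaluated in x^2 by Horner's rule after the input is clamped to [-9, 9] (outside that range, float32 tanh saturates to +/-1). The following standalone NumPy sketch is illustrative only and not part of the commit; it mirrors the evaluation order so the coefficients can be checked against np.tanh:

import numpy as np

# Odd-power numerator coefficients alpha_1 .. alpha_13, copied from the kernel above.
ALPHA = [4.89352455891786e-03, 6.37261928875436e-04, 1.48572235717979e-05,
         5.12229709037114e-08, -8.60467152213735e-11, 2.00018790482477e-13,
         -2.76076847742355e-16]
# Even-power denominator coefficients beta_0 .. beta_6.
BETA = [4.89352518554385e-03, 2.26843463243900e-03, 1.18534705686654e-04,
        1.19825839466702e-06]

def fast_tanh(v):
    x = np.clip(v, -9.0, 9.0)          # beyond +/-9, float32 tanh is +/-1.0
    x2 = x * x
    p = ALPHA[-1]
    for a in reversed(ALPHA[:-1]):     # Horner in x^2 over the odd polynomial
        p = x2 * p + a
    p = x * p                          # restore the odd factor of x
    q = BETA[-1]
    for b in reversed(BETA[:-1]):      # Horner in x^2 over the even polynomial
        q = x2 * q + b
    return p / q

v = np.linspace(-10, 10, 10001).astype("float32")
print(np.abs(fast_tanh(v) - np.tanh(v)).max())  # maximum deviation from NumPy's tanh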
@@ -29,13 +29,21 @@ def test_util():


 def test_ewise():
-    m = tvm.var('m')
-    l = tvm.var('l')
-    A = tvm.placeholder((m, l), name='A')
-
-    shape = (20, 3)
-
-    def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
+    def test_apply(
+        func,
+        name,
+        f_numpy,
+        low,
+        high,
+        shape=(20, 3),
+        dtype=tvm.float32,
+        check_round=False,
+        skip_name_check=False,
+    ):
+        m = tvm.var("m")
+        l = tvm.var("l")
+        A = tvm.placeholder((m, l), dtype=dtype, name="A")
+
         B = func(A)
         assert tuple(B.shape) == tuple(A.shape)
         if not skip_name_check:
@@ -63,7 +71,6 @@ def test_ewise():
         for device in get_all_backend():
             check_device(device)

-
     test_apply(topi.floor, "floor", np.floor, -100, 100)
     test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
     test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
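The check_device helper invoked by this loop is defined earlier in the test file and is untouched by the diff. For orientation only, a typical TOPI elementwise checker of that era looks roughly like the sketch below; the schedule call, tolerances, and variable names here are assumptions, not the file's actual code:

def check_device(device):
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    # Build an injective schedule for the elementwise op and run it.
    with tvm.target.create(device):
        s = topi.generic.schedule_injective(B)
    foo = tvm.build(s, [A, B], device, name=name)
    a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype)
    b_np = f_numpy(a_np)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros_like(b_np), ctx)
    foo(a, b)
    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)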
@@ -71,11 +78,12 @@ def test_ewise():
     test_apply(topi.abs, "fabs", np.abs, -100, 100)
     test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
     test_apply(topi.exp, "exp", np.exp, -1, 1)
-    test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
-    test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
+    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128))
+    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128), dtype="float64")
+    test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1)
     test_apply(topi.log, "log", np.log, 0, 100)
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
-    test_apply(topi.rsqrt, "rsqrt", lambda x:np.ones_like(x)/np.sqrt(x), 0, 100, skip_name_check=True)
+    test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)


 def test_cast():
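End to end, topi.tanh now routes float32 tensors through the approximation and everything else through the exact ::tvm::tanh intrinsic, which is what the two new tanh test lines above exercise (default float32 and dtype="float64"). A minimal driver for the new operator, assuming the 0.x-era tvm/topi Python API:

import numpy as np
import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n,), dtype="float32", name="A")  # float32 takes the fast path
B = topi.tanh(A)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

x = np.random.uniform(-10, 10, size=(1024,)).astype("float32")
a = tvm.nd.array(x)
b = tvm.nd.array(np.zeros_like(x))
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.tanh(x), rtol=1e-5, atol=1e-5)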