зеркало из https://github.com/microsoft/caffe.git
Use macro to simplify element wise cpu math functions
This commit is contained in:
Родитель
ccae3fa587
Коммит
b458b41d68
|
@ -126,22 +126,33 @@ inline char caffe_sign(Dtype val) {
|
|||
return (Dtype(0) < val) - (val < Dtype(0));
|
||||
}
|
||||
|
||||
template<typename Dtype>
|
||||
void caffe_cpu_sign(const int n, const Dtype* x, Dtype* y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = caffe_sign<Dtype>(x[i]);
|
||||
// The following two macros are modifications of DEFINE_VSL_UNARY_FUNC
|
||||
// in include/caffe/util/mkl_alternate.hpp authored by @Rowland Depp.
|
||||
// Please refer to commit 7e8ef25c7 of the boost-eigen branch.
|
||||
// Git cherry picking that commit caused a conflict hard to resolve and
|
||||
// copying that file in convenient for code reviewing.
|
||||
// So they have to be pasted here temporarily.
|
||||
#define DEFINE_CAFFE_CPU_UNARY_FUNC(name, operation) \
|
||||
template<typename Dtype> \
|
||||
void caffe_cpu_##name(const int n, const Dtype* x, Dtype* y) { \
|
||||
CHECK_GT(n, 0); CHECK(x); CHECK(y); \
|
||||
for (int i = 0; i < n; ++i) { \
|
||||
operation; \
|
||||
} \
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_CAFFE_CPU_UNARY_FUNC(name) \
|
||||
template <> \
|
||||
void caffe_cpu_##name<float>(const int n, const float* x, float* y); \
|
||||
template <> \
|
||||
void caffe_cpu_##name<double>(const int n, const double* x, double* y)
|
||||
|
||||
DEFINE_CAFFE_CPU_UNARY_FUNC(sign, y[i] = caffe_sign<Dtype>(x[i]));
|
||||
|
||||
template<typename Dtype>
|
||||
void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y);
|
||||
|
||||
template <typename Dtype>
|
||||
void caffe_cpu_fabs(const int n, const Dtype* x, Dtype* y) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = std::fabs(x[i]);
|
||||
}
|
||||
}
|
||||
DEFINE_CAFFE_CPU_UNARY_FUNC(fabs, y[i] = std::fabs(x[i]));
|
||||
|
||||
template <typename Dtype>
|
||||
void caffe_gpu_fabs(const int n, const Dtype* x, Dtype* y);
|
||||
|
|
|
@ -410,16 +410,7 @@ void caffe_gpu_asum<double>(const int n, const double* x, double* y) {
|
|||
CUBLAS_CHECK(cublasDasum(Caffe::cublas_handle(), n, x, 1, y));
|
||||
}
|
||||
|
||||
template <>
|
||||
void caffe_cpu_sign<float>(const int n, const float* x, float* y);
|
||||
|
||||
template <>
|
||||
void caffe_cpu_sign<double>(const int n, const double* x, double* y);
|
||||
|
||||
template <>
|
||||
void caffe_cpu_fabs<float>(const int n, const float* x, float* y);
|
||||
|
||||
template <>
|
||||
void caffe_cpu_fabs<double>(const int n, const double* x, double* y);
|
||||
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(sign);
|
||||
INSTANTIATE_CAFFE_CPU_UNARY_FUNC(fabs);
|
||||
|
||||
} // namespace caffe
|
||||
|
|
Загрузка…
Ссылка в новой задаче