diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index ab1cee17..fd9de876 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -119,6 +119,23 @@ Dtype caffe_cpu_asum(const int n, const Dtype* x); template void caffe_gpu_asum(const int n, const Dtype* x, Dtype* y); +// the branchless, type-safe version from +// http://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c +template +inline char caffe_sign(Dtype val) { + return (Dtype(0) < val) - (val < Dtype(0)); +} + +template +void caffe_cpu_sign(const int n, const Dtype* x, Dtype* y) { + for (int i = 0; i < n; ++i) { + y[i] = caffe_sign(x[i]); + } +} + +template +void caffe_gpu_sign(const int n, const Dtype* x, Dtype* y); + } // namespace caffe diff --git a/src/caffe/test/test_math_functions.cpp b/src/caffe/test/test_math_functions.cpp index ba8bfe72..09b4aa67 100644 --- a/src/caffe/test/test_math_functions.cpp +++ b/src/caffe/test/test_math_functions.cpp @@ -1,7 +1,7 @@ // Copyright 2014 kloudkl@github -#include // for uint32_t & uint64_t -#include // for std::fabs +#include // for uint32_t & uint64_t +#include // for std::fabs #include "gtest/gtest.h" #include "caffe/blob.hpp" @@ -67,7 +67,7 @@ REF_HAMMING_DIST(double, uint64_t); typedef ::testing::Types Dtypes; TYPED_TEST_CASE(MathFunctionsTest, Dtypes); -TYPED_TEST(MathFunctionsTest, TestHammingDistance) { +TYPED_TEST(MathFunctionsTest, TestHammingDistance){ int n = this->blob_bottom_->count(); const TypeParam* x = this->blob_bottom_->cpu_data(); const TypeParam* y = this->blob_top_->cpu_data(); @@ -98,4 +98,25 @@ TYPED_TEST(MathFunctionsTest, TestAsumGPU){ CHECK_LT((gpu_asum - std_asum) / std_asum, 1e-2); } -} // namespace caffe +TYPED_TEST(MathFunctionsTest, TestSignCPU){ + int n = this->blob_bottom_->count(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + caffe_cpu_sign(n, x, this->blob_bottom_->mutable_cpu_diff()); + const TypeParam* signs = this->blob_bottom_->cpu_diff(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0)); + } +} + +TYPED_TEST(MathFunctionsTest, TestSignGPU){ + int n = this->blob_bottom_->count(); + caffe_gpu_sign(n, this->blob_bottom_->gpu_data(), + this->blob_bottom_->mutable_gpu_diff()); + const TypeParam* signs = this->blob_bottom_->cpu_diff(); + const TypeParam* x = this->blob_bottom_->cpu_data(); + for (int i = 0; i < n; ++i) { + CHECK_EQ(signs[i], x[i] > 0 ? 1 : (x[i] < 0 ? -1 : 0)); + } +} + +} diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu index 5491e246..5aff39fd 100644 --- a/src/caffe/util/math_functions.cu +++ b/src/caffe/util/math_functions.cu @@ -1,4 +1,5 @@ // Copyright 2013 Yangqing Jia +// Copyright 2014 kloudkl@github #include #include @@ -33,5 +34,24 @@ void caffe_gpu_mul(const int N, const double* a, N, a, b, y); } +template +__global__ void sign_kernel(const int n, const Dtype* x, Dtype* y) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < n) { + y[index] = (Dtype(0) < x[index]) - (x[index] < Dtype(0)); + } +} + +template <> +void caffe_gpu_sign(const int n, const float* x, float* y) { + sign_kernel<<>>( + n, x, y); +} + +template <> +void caffe_gpu_sign(const int n, const double* x, double* y) { + sign_kernel<<>>( + n, x, y); +} } // namespace caffe