Merge pull request #863 from jeffdonahue/lint-check-caffe-fns

add lint check for functions with Caffe alternatives (memcpy, memset)
This commit is contained in:
Jeff Donahue 2014-08-12 12:54:29 -07:00
Родитель 2d2055a4b2 b2b1ee6317
Коммит 570e6b8ffc
8 изменённых файлов: 67 добавлений и 20 удалений

Просмотреть файл

@ -62,6 +62,7 @@ NONGEN_CXX_SRCS := $(shell find \
examples \
tools \
-name "*.cpp" -or -name "*.hpp" -or -name "*.cu" -or -name "*.cuh")
LINT_SCRIPT := scripts/cpp_lint.py
LINT_OUTPUT_DIR := $(BUILD_DIR)/.lint
LINT_EXT := lint.txt
LINT_OUTPUTS := $(addsuffix .$(LINT_EXT), $(addprefix $(LINT_OUTPUT_DIR)/, $(NONGEN_CXX_SRCS)))
@ -322,9 +323,9 @@ $(EMPTY_LINT_REPORT): $(LINT_OUTPUTS) | $(BUILD_DIR)
$(RM) $(NONEMPTY_LINT_REPORT); \
echo "No lint errors!";
$(LINT_OUTPUTS): $(LINT_OUTPUT_DIR)/%.lint.txt : % | $(LINT_OUTPUT_DIR)
$(LINT_OUTPUTS): $(LINT_OUTPUT_DIR)/%.lint.txt : % $(LINT_SCRIPT) | $(LINT_OUTPUT_DIR)
@ mkdir -p $(dir $@)
@ python ./scripts/cpp_lint.py $< 2>&1 \
@ python $(LINT_SCRIPT) $< 2>&1 \
| grep -v "^Done processing " \
| grep -v "^Total errors found: 0" \
> $@ \

Просмотреть файл

@ -6,6 +6,7 @@
#include "glog/logging.h"
#include "caffe/common.hpp"
#include "caffe/util/device_alternate.hpp"
#include "caffe/util/mkl_alternate.hpp"
@ -38,6 +39,10 @@ void caffe_copy(const int N, const Dtype *X, Dtype *Y);
template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype *X);
inline void caffe_memset(const size_t N, const int alpha, void* X) {
memset(X, alpha, N); // NOLINT(caffe/alt_fn)
}
template <typename Dtype>
void caffe_add_scalar(const int N, const Dtype alpha, Dtype *X);
@ -165,6 +170,14 @@ void caffe_gpu_memcpy(const size_t N, const void *X, void *Y);
template <typename Dtype>
void caffe_gpu_set(const int N, const Dtype alpha, Dtype *X);
inline void caffe_gpu_memset(const size_t N, const int alpha, void* X) {
#ifndef CPU_ONLY
CUDA_CHECK(cudaMemset(X, alpha, N)); // NOLINT(caffe/alt_fn)
#else
NO_GPU;
#endif
}
template <typename Dtype>
void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype *X);

Просмотреть файл

@ -154,6 +154,7 @@ _ERROR_CATEGORIES = [
'build/namespaces',
'build/printf_format',
'build/storage_class',
'caffe/alt_fn',
'caffe/random_fn',
'legal/copyright',
'readability/alt_tokens',
@ -1559,6 +1560,37 @@ def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error):
'Use C++11 raw strings or concatenation instead.')
caffe_alt_function_list = (
('memset', ['caffe_set', 'caffe_memset']),
('cudaMemset', ['caffe_gpu_set', 'caffe_gpu_memset']),
('memcpy', ['caffe_copy', 'caffe_memcpy']),
('cudaMemcpy', ['caffe_copy', 'caffe_gpu_memcpy']),
)
def CheckCaffeAlternatives(filename, clean_lines, linenum, error):
"""Checks for C(++) functions for which a Caffe substitute should be used.
For certain native C functions (memset, memcpy), there is a Caffe alternative
which should be used instead.
Args:
filename: The name of the current file.
clean_lines: A CleansedLines instance containing the file.
linenum: The number of the line to check.
error: The function to call with any errors found.
"""
line = clean_lines.elided[linenum]
for function, alts in caffe_alt_function_list:
ix = line.find(function + '(')
if ix >= 0 and (ix == 0 or (not line[ix - 1].isalnum() and
line[ix - 1] not in ('_', '.', '>'))):
disp_alts = ['%s(...)' % alt for alt in alts]
error(filename, linenum, 'caffe/alt_fn', 2,
'Use Caffe function %s instead of %s(...).' %
(' or '.join(disp_alts), function))
c_random_function_list = (
'rand(',
'rand_r(',
@ -4560,6 +4592,7 @@ def ProcessLine(filename, file_extension, clean_lines, line,
CheckForNonStandardConstructs(filename, clean_lines, line,
nesting_state, error)
CheckVlogArguments(filename, clean_lines, line, error)
CheckCaffeAlternatives(filename, clean_lines, line, error)
CheckCaffeRandom(filename, clean_lines, line, error)
CheckPosixThreading(filename, clean_lines, line, error)
CheckInvalidIncrement(filename, clean_lines, line, error)

Просмотреть файл

@ -103,12 +103,11 @@ Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
current_row_ = 0;
}
memcpy(&(*top)[0]->mutable_cpu_data()[i * data_count],
&data_blob_.cpu_data()[current_row_ * data_count],
sizeof(Dtype) * data_count);
memcpy(&(*top)[1]->mutable_cpu_data()[i * label_data_count],
&label_blob_.cpu_data()[current_row_ * label_data_count],
sizeof(Dtype) * label_data_count);
caffe_copy(data_count, &data_blob_.cpu_data()[current_row_ * data_count],
&(*top)[0]->mutable_cpu_data()[i * data_count]);
caffe_copy(label_data_count,
&label_blob_.cpu_data()[current_row_ * label_data_count],
&(*top)[1]->mutable_cpu_data()[i * label_data_count]);
}
return Dtype(0.);
}

Просмотреть файл

@ -22,7 +22,7 @@ inline void SyncedMemory::to_cpu() {
switch (head_) {
case UNINITIALIZED:
CaffeMallocHost(&cpu_ptr_, size_);
memset(cpu_ptr_, 0, size_);
caffe_memset(size_, 0, cpu_ptr_);
head_ = HEAD_AT_CPU;
own_cpu_data_ = true;
break;
@ -49,7 +49,7 @@ inline void SyncedMemory::to_gpu() {
switch (head_) {
case UNINITIALIZED:
CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
CUDA_CHECK(cudaMemset(gpu_ptr_, 0, size_));
caffe_gpu_memset(size_, 0, gpu_ptr_);
head_ = HEAD_AT_GPU;
break;
case HEAD_AT_CPU:

Просмотреть файл

@ -55,14 +55,14 @@ TEST_F(SyncedMemoryTest, TestCPUWrite) {
SyncedMemory mem(10);
void* cpu_data = mem.mutable_cpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
memset(cpu_data, 1, mem.size());
caffe_memset(mem.size(), 1, cpu_data);
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((static_cast<char*>(cpu_data))[i], 1);
}
// do another round
cpu_data = mem.mutable_cpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
memset(cpu_data, 2, mem.size());
caffe_memset(mem.size(), 2, cpu_data);
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((static_cast<char*>(cpu_data))[i], 2);
}
@ -74,7 +74,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
SyncedMemory mem(10);
void* cpu_data = mem.mutable_cpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
memset(cpu_data, 1, mem.size());
caffe_memset(mem.size(), 1, cpu_data);
const void* gpu_data = mem.gpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
// check if values are the same
@ -86,7 +86,7 @@ TEST_F(SyncedMemoryTest, TestGPURead) {
// do another round
cpu_data = mem.mutable_cpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_CPU);
memset(cpu_data, 2, mem.size());
caffe_memset(mem.size(), 2, cpu_data);
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((static_cast<char*>(cpu_data))[i], 2);
}
@ -104,7 +104,7 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
SyncedMemory mem(10);
void* gpu_data = mem.mutable_gpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU);
CUDA_CHECK(cudaMemset(gpu_data, 1, mem.size()));
caffe_gpu_memset(mem.size(), 1, gpu_data);
const void* cpu_data = mem.cpu_data();
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((static_cast<const char*>(cpu_data))[i], 1);
@ -113,7 +113,7 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
gpu_data = mem.mutable_gpu_data();
EXPECT_EQ(mem.head(), SyncedMemory::HEAD_AT_GPU);
CUDA_CHECK(cudaMemset(gpu_data, 2, mem.size()));
caffe_gpu_memset(mem.size(), 2, gpu_data);
cpu_data = mem.cpu_data();
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ((static_cast<const char*>(cpu_data))[i], 2);

Просмотреть файл

@ -56,7 +56,7 @@ void caffe_axpy<double>(const int N, const double alpha, const double* X,
template <typename Dtype>
void caffe_set(const int N, const Dtype alpha, Dtype* Y) {
if (alpha == 0) {
memset(Y, 0, sizeof(Dtype) * N);
memset(Y, 0, sizeof(Dtype) * N); // NOLINT(caffe/alt_fn)
return;
}
for (int i = 0; i < N; ++i) {
@ -87,12 +87,13 @@ void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
if (X != Y) {
if (Caffe::mode() == Caffe::GPU) {
#ifndef CPU_ONLY
// NOLINT_NEXT_LINE(caffe/alt_fn)
CUDA_CHECK(cudaMemcpy(Y, X, sizeof(Dtype) * N, cudaMemcpyDefault));
#else
NO_GPU;
#endif
} else {
memcpy(Y, X, sizeof(Dtype) * N);
memcpy(Y, X, sizeof(Dtype) * N); // NOLINT(caffe/alt_fn)
}
}
}

Просмотреть файл

@ -78,7 +78,7 @@ void caffe_gpu_axpy<double>(const int N, const double alpha, const double* X,
void caffe_gpu_memcpy(const size_t N, const void* X, void* Y) {
if (X != Y) {
CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault));
CUDA_CHECK(cudaMemcpy(Y, X, N, cudaMemcpyDefault)); // NOLINT(caffe/alt_fn)
}
}
@ -152,7 +152,7 @@ __global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) {
template <typename Dtype>
void caffe_gpu_set(const int N, const Dtype alpha, Dtype* Y) {
if (alpha == 0) {
CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N));
CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N)); // NOLINT(caffe/alt_fn)
return;
}
// NOLINT_NEXT_LINE(whitespace/operators)