This commit is contained in:
Yangqing Jia 2013-09-13 09:48:34 -07:00
Родитель af8f6ad65c
Коммит 690332bfac
6 изменённых файлов: 132 добавлений и 86 удалений

Просмотреть файл

@ -41,7 +41,7 @@ all: $(NAME)
test: $(TEST_NAME)
$(TEST_NAME): $(OBJS) $(TEST_OBJS)
$(CXX) $(TEST_OBJS) $(OBJS) -o $(TEST_NAME) $(LDFLAGS) $(WARNINGS)
$(CXX) $(OBJS) $(TEST_OBJS) -o $(TEST_NAME) $(LDFLAGS) $(WARNINGS)
./$(TEST_NAME)
$(NAME): $(PROTO_GEN_CC) $(OBJS)
@ -51,8 +51,8 @@ $(PROTO_GEN_CC): $(PROTO_SRCS)
protoc $(PROTO_SRCS) --cpp_out=.
clean:
$(RM) $(NAME)
$(RM) $(OBJS)
$(RM) $(PROTO_GEN_HEADER) $(PROTO_GEN_CC)
@- $(RM) $(NAME) $(TEST_NAME)
@- $(RM) $(OBJS) $(TEST_OBJS)
@- $(RM) $(PROTO_GEN_HEADER) $(PROTO_GEN_CC)
distclean: clean

Просмотреть файл

@ -1,72 +0,0 @@
#include "caffeine/blob.hpp"
#include "caffeine/common.hpp"
#include "caffeine/syncedmem.hpp"
namespace caffeine {
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
num_ = num;
channels_ = channels;
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
}
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() {
check_data();
return data_->cpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() {
check_data();
return data_->gpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_diff() {
check_diff();
return diff_->cpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_diff() {
check_diff();
return diff_->gpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
check_data();
return data_->mutable_cpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
check_data();
return data_->mutable_gpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_diff() {
check_diff();
return diff_->mutable_cpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_diff() {
check_diff();
return diff_->mutable_gpu_data();
}
template <typename Dtype>
void Blob<Dtype>::update() {
}
} // namespace caffeine

Просмотреть файл

@ -15,16 +15,19 @@ class Blob {
: num_(0), channels_(0), height_(0), width_(0), count_(0), data_(),
diff_() {};
explicit Blob(const int num, const int channels, const int height,
const int width) {
Reshape(num, channels, height, width);
};
~Blob() {};
const int width)
: num_(num), channels_(channels), height_(height), width_(width),
count_(num * channels * height * width),
data_(new SyncedMemory(count_ * sizeof(Dtype))),
diff_(new SyncedMemory(count_ * sizeof(Dtype))) {};
virtual ~Blob() {};
void Reshape(const int num, const int channels, const int height,
const int width);
const int width);
inline int num() { return num_; }
inline int channels() { return channels_; }
inline int height() { return height_; }
inline int width() { return width_; }
inline int count() {return count_; }
const Dtype* cpu_data();
const Dtype* gpu_data();
@ -47,6 +50,71 @@ class Blob {
int count_;
}; // class Blob
template <typename Dtype>
void Blob<Dtype>::Reshape(const int num, const int channels, const int height,
const int width) {
num_ = num;
channels_ = channels;
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
}
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() {
check_data();
return data_->cpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() {
check_data();
return data_->gpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_diff() {
check_diff();
return diff_->cpu_data();
}
template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_diff() {
check_diff();
return diff_->gpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
check_data();
return data_->mutable_cpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
check_data();
return data_->mutable_gpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_diff() {
check_diff();
return diff_->mutable_cpu_data();
}
template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_diff() {
check_diff();
return diff_->mutable_gpu_data();
}
template <typename Dtype>
void Blob<Dtype>::update() {
}
} // namespace caffeine
#endif // CAFFEINE_BLOB_HPP_

Просмотреть файл

@ -0,0 +1,48 @@
#include <cstring>
#include <cuda_runtime.h>
#include "gtest/gtest.h"
#include "caffeine/common.hpp"
#include "caffeine/blob.hpp"
namespace caffeine {
template <typename Dtype>
class BlobSimpleTest : public ::testing::Test {
protected:
BlobSimpleTest()
: blob_(new Blob<Dtype>()),
blob_preshaped_(new Blob<Dtype>(2, 3, 4, 5)) {};
virtual ~BlobSimpleTest() { delete blob_; delete blob_preshaped_; }
Blob<Dtype>* const blob_;
Blob<Dtype>* const blob_preshaped_;
};
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(BlobSimpleTest, Dtypes);
TYPED_TEST(BlobSimpleTest, TestInitialization) {
EXPECT_TRUE(this->blob_);
EXPECT_TRUE(this->blob_preshaped_);
EXPECT_EQ(this->blob_preshaped_->num(), 2);
EXPECT_EQ(this->blob_preshaped_->channels(), 3);
EXPECT_EQ(this->blob_preshaped_->height(), 4);
EXPECT_EQ(this->blob_preshaped_->width(), 5);
EXPECT_EQ(this->blob_preshaped_->count(), 120);
EXPECT_EQ(this->blob_->num(), 0);
EXPECT_EQ(this->blob_->channels(), 0);
EXPECT_EQ(this->blob_->height(), 0);
EXPECT_EQ(this->blob_->width(), 0);
EXPECT_EQ(this->blob_->count(), 0);
}
TYPED_TEST(BlobSimpleTest, TestReshape) {
this->blob_->Reshape(2, 3, 4, 5);
EXPECT_EQ(this->blob_->num(), 2);
EXPECT_EQ(this->blob_->channels(), 3);
EXPECT_EQ(this->blob_->height(), 4);
EXPECT_EQ(this->blob_->width(), 5);
EXPECT_EQ(this->blob_->count(), 120);
}
}

Просмотреть файл

@ -0,0 +1,6 @@
#include "gtest/gtest.h"
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

Просмотреть файл

@ -42,11 +42,7 @@ TEST_F(SyncedMemoryTest, TestGPUWrite) {
for (int i = 0; i < mem.size(); ++i) {
EXPECT_EQ(((char*)cpu_data)[i], 1);
}
EXPECT_EQ(mem.head(), SyncedMemory::SYNCED);
}
}
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}