Merge pull request #252 from kloudkl/hdf5_output_layer

Hdf5 output layer
This commit is contained in:
Sergey Karayev 2014-03-24 00:47:58 -07:00
Родитель 699b557c75 ebf90c31c4
Коммит d3e4c21d91
8 изменённых файлов: 338 добавлений и 0 удалений

Просмотреть файл

@ -15,6 +15,8 @@
using std::string;
using ::google::protobuf::Message;
#define HDF5_NUM_DIMS 4
namespace caffe {
void ReadProtoFromTextFile(const char* filename,
@ -60,6 +62,10 @@ void hdf5_load_nd_dataset(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
Blob<Dtype>* blob);
template <typename Dtype>
void hdf5_save_nd_dataset(
const hid_t file_id, const string dataset_name, const Blob<Dtype>& blob);
} // namespace caffe
#endif // CAFFE_UTIL_IO_H_

Просмотреть файл

@ -15,6 +15,9 @@
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"
#define HDF5_DATA_DATASET_NAME "data"
#define HDF5_DATA_LABEL_NAME "label"
namespace caffe {
@ -477,6 +480,33 @@ class HDF5DataLayer : public Layer<Dtype> {
};
template <typename Dtype>
class HDF5OutputLayer : public Layer<Dtype> {
public:
explicit HDF5OutputLayer(const LayerParameter& param);
virtual ~HDF5OutputLayer();
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
inline std::string file_name() const { return file_name_; }
protected:
virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual void SaveBlobs();
std::string file_name_;
hid_t file_id_;
Blob<Dtype> data_blob_;
Blob<Dtype> label_blob_;
};
template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
public:

Просмотреть файл

@ -37,6 +37,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
return new FlattenLayer<Dtype>(param);
} else if (type == "hdf5_data") {
return new HDF5DataLayer<Dtype>(param);
} else if (type == "hdf5_output") {
return new HDF5OutputLayer<Dtype>(param);
} else if (type == "images") {
return new ImagesLayer<Dtype>(param);
} else if (type == "im2col") {

Просмотреть файл

@ -0,0 +1,88 @@
// Copyright 2014 BVLC and contributors.
/*
Contributors:
- kloudkl@github, 2014.
*/
#include <vector>
#include "hdf5.h"
#include "hdf5_hl.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
using std::vector;
template <typename Dtype>
HDF5OutputLayer<Dtype>::HDF5OutputLayer(const LayerParameter& param)
: Layer<Dtype>(param),
file_name_(param.hdf5_output_param().file_name()) {
/* create a HDF5 file */
file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
H5P_DEFAULT);
CHECK_GE(file_id_, 0) << "Failed to open HDF5 file" << file_name_;
}
template <typename Dtype>
HDF5OutputLayer<Dtype>::~HDF5OutputLayer<Dtype>() {
herr_t status = H5Fclose(file_id_);
CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_;
}
template <typename Dtype>
void HDF5OutputLayer<Dtype>::SaveBlobs() {
// TODO: no limit on the number of blobs
LOG(INFO) << "Saving HDF5 file" << file_name_;
CHECK_EQ(data_blob_.num(), label_blob_.num()) <<
"data blob and label blob must have the same batch size";
hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_);
hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_);
LOG(INFO) << "Successfully saved " << data_blob_.num() << " rows";
}
template <typename Dtype>
void HDF5OutputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// TODO: no limit on the number of blobs
CHECK_EQ(bottom.size(), 2) << "HDF5OutputLayer takes two blobs as input.";
CHECK_EQ(top->size(), 0) << "HDF5OutputLayer takes no output blobs.";
}
template <typename Dtype>
Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_GE(bottom.size(), 2);
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
bottom[1]->height(), bottom[1]->width());
const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
for (int i = 0; i < bottom[0]->num(); ++i) {
memcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
&bottom[0]->cpu_data()[i * data_datum_dim],
sizeof(Dtype) * data_datum_dim);
memcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
&bottom[1]->cpu_data()[i * label_datum_dim],
sizeof(Dtype) * label_datum_dim);
}
SaveBlobs();
return Dtype(0.);
}
template <typename Dtype>
void HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return;
}
INSTANTIATE_CLASS(HDF5OutputLayer);
} // namespace caffe

Просмотреть файл

@ -0,0 +1,53 @@
// Copyright 2014 BVLC and contributors.
/*
Contributors:
- kloudkl@github, 2014.
*/
#include <vector>
#include "hdf5.h"
#include "hdf5_hl.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
using std::vector;
template <typename Dtype>
Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_GE(bottom.size(), 2);
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
bottom[1]->height(), bottom[1]->width());
const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
const int label_datum_dim = bottom[1]->count() / bottom[1]->num();
for (int i = 0; i < bottom[0]->num(); ++i) {
CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
&bottom[0]->gpu_data()[i * data_datum_dim],
sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
&bottom[1]->gpu_data()[i * label_datum_dim],
sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost));
}
SaveBlobs();
return Dtype(0.);
}
template <typename Dtype>
void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
return;
}
INSTANTIATE_CLASS(HDF5OutputLayer);
} // namespace caffe

Просмотреть файл

@ -125,6 +125,12 @@ message LayerParameter {
// the other dimensions must be the same for all the bottom blobs.
// By default it will concatenate blobs along the channels dimension.
optional uint32 concat_dim = 65 [default = 1];
optional HDF5OutputParameter hdf5_output_param = 1001;
}
message HDF5OutputParameter {
optional string file_name = 1;
}
message LayerConnection {

Просмотреть файл

@ -0,0 +1,127 @@
// Copyright 2014 kloudkl@github
#include <cuda_runtime.h>
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/test/test_caffe_main.hpp"
namespace caffe {
using std::string;
using std::vector;
extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
template <typename Dtype>
class HDF5OutputLayerTest : public ::testing::Test {
protected:
HDF5OutputLayerTest()
: output_file_name_("/tmp/test_hdf5_output_layer-sample_data.hdf5"),
input_file_name_("src/caffe/test/test_data/sample_data.h5"),
blob_data_(new Blob<Dtype>()),
blob_label_(new Blob<Dtype>()),
num_(5),
channels_(8),
height_(5),
width_(5) {
}
virtual void SetUp() {
}
virtual ~HDF5OutputLayerTest() {
delete blob_data_;
delete blob_label_;
}
void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);
string output_file_name_;
string input_file_name_;
Blob<Dtype>* const blob_data_;
Blob<Dtype>* const blob_label_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
int num_;
int channels_;
int height_;
int width_;
};
template <typename Dtype>
void HDF5OutputLayerTest<Dtype>::CheckBlobEqual(
const Blob<Dtype>& b1, const Blob<Dtype>& b2) {
EXPECT_EQ(b1.num(), b2.num());
EXPECT_EQ(b1.channels(), b2.channels());
EXPECT_EQ(b1.height(), b2.height());
EXPECT_EQ(b1.width(), b2.width());
for (int n = 0; n < b1.num(); ++n) {
for (int c = 0; c < b1.channels(); ++c) {
for (int h = 0; h < b1.height(); ++h) {
for (int w = 0; w < b1.width(); ++w) {
EXPECT_EQ(b1.data_at(n, c, h, w), b1.data_at(n, c, h, w));
}
}
}
}
}
typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(HDF5OutputLayerTest, Dtypes);
TYPED_TEST(HDF5OutputLayerTest, TestForward) {
LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
H5P_DEFAULT);
ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
this->input_file_name_;
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
this->blob_data_);
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
this->blob_label_);
herr_t status = H5Fclose(file_id);
EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
this->input_file_name_;
this->blob_bottom_vec_.push_back(this->blob_data_);
this->blob_bottom_vec_.push_back(this->blob_label_);
Caffe::Brew modes[] = { Caffe::CPU, Caffe::GPU };
for (int m = 0; m < 2; ++m) {
Caffe::set_mode(modes[m]);
LayerParameter param;
param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
// This code block ensures that the layer is deconstructed and
// the output hdf5 file is closed.
{
HDF5OutputLayer<TypeParam> layer(param);
EXPECT_EQ(layer.file_name(), this->output_file_name_);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
}
hid_t file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
H5P_DEFAULT);
ASSERT_GE(file_id, 0) << "Failed to open HDF5 file" <<
this->input_file_name_;
Blob<TypeParam>* blob_data = new Blob<TypeParam>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
blob_data);
this->CheckBlobEqual(*(this->blob_data_), *blob_data);
Blob<TypeParam>* blob_label = new Blob<TypeParam>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
blob_label);
this->CheckBlobEqual(*(this->blob_label_), *blob_label);
herr_t status = H5Fclose(file_id);
EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
this->output_file_name_;
}
}
} // namespace caffe

Просмотреть файл

@ -142,4 +142,30 @@ void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
file_id, dataset_name_, blob->mutable_cpu_data());
}
template <>
void hdf5_save_nd_dataset<float>(
const hid_t file_id, const string dataset_name, const Blob<float>& blob) {
hsize_t dims[HDF5_NUM_DIMS];
dims[0] = blob.num();
dims[1] = blob.channels();
dims[2] = blob.height();
dims[3] = blob.width();
herr_t status = H5LTmake_dataset_float(
file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
}
template <>
void hdf5_save_nd_dataset<double>(
const hid_t file_id, const string dataset_name, const Blob<double>& blob) {
hsize_t dims[HDF5_NUM_DIMS];
dims[0] = blob.num();
dims[1] = blob.channels();
dims[2] = blob.height();
dims[3] = blob.width();
herr_t status = H5LTmake_dataset_double(
file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
}
} // namespace caffe