Mirror of https://github.com/microsoft/caffe.git

Merge pull request #252 from kloudkl/hdf5_output_layer

Hdf5 output layer

Commit d3e4c21d91
@@ -15,6 +15,8 @@
using std::string;
using ::google::protobuf::Message;

#define HDF5_NUM_DIMS 4

namespace caffe {

void ReadProtoFromTextFile(const char* filename,
@@ -60,6 +62,10 @@ void hdf5_load_nd_dataset(
    hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
    Blob<Dtype>* blob);

template <typename Dtype>
void hdf5_save_nd_dataset(
    const hid_t file_id, const string dataset_name, const Blob<Dtype>& blob);

}  // namespace caffe

#endif  // CAFFE_UTIL_IO_H_

@@ -15,6 +15,9 @@
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#define HDF5_DATA_DATASET_NAME "data"
#define HDF5_DATA_LABEL_NAME "label"

namespace caffe {

@@ -477,6 +480,33 @@ class HDF5DataLayer : public Layer<Dtype> {
};


template <typename Dtype>
class HDF5OutputLayer : public Layer<Dtype> {
 public:
  explicit HDF5OutputLayer(const LayerParameter& param);
  virtual ~HDF5OutputLayer();
  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  inline std::string file_name() const { return file_name_; }

 protected:
  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      vector<Blob<Dtype>*>* top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
  virtual void SaveBlobs();

  std::string file_name_;
  hid_t file_id_;
  Blob<Dtype> data_blob_;
  Blob<Dtype> label_blob_;
};


template <typename Dtype>
class SoftmaxLayer : public Layer<Dtype> {
 public:

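For orientation, a minimal usage sketch (not part of the diff) of how the layer declared above can be driven outside of a Net, mirroring the unit test added later in this commit; the helper name DumpToHDF5 and the /tmp path are illustrative only.

// Sketch only: push two existing blobs through HDF5OutputLayer.
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/vision_layers.hpp"

void DumpToHDF5(caffe::Blob<float>* data, caffe::Blob<float>* label) {
  caffe::LayerParameter param;
  param.mutable_hdf5_output_param()->set_file_name("/tmp/dump.h5");  // illustrative path

  std::vector<caffe::Blob<float>*> bottom;
  bottom.push_back(data);
  bottom.push_back(label);
  std::vector<caffe::Blob<float>*> top;  // stays empty: the layer takes no output blobs

  // Scoped so the destructor runs and H5Fclose flushes the file to disk.
  {
    caffe::HDF5OutputLayer<float> layer(param);
    layer.SetUp(bottom, &top);
    layer.Forward(bottom, &top);  // copies the bottoms into host blobs and calls SaveBlobs()
  }
}
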
@@ -37,6 +37,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
    return new FlattenLayer<Dtype>(param);
  } else if (type == "hdf5_data") {
    return new HDF5DataLayer<Dtype>(param);
  } else if (type == "hdf5_output") {
    return new HDF5OutputLayer<Dtype>(param);
  } else if (type == "images") {
    return new ImagesLayer<Dtype>(param);
  } else if (type == "im2col") {

@@ -0,0 +1,88 @@
// Copyright 2014 BVLC and contributors.
/*
Contributors:
- kloudkl@github, 2014.
*/

#include <vector>

#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
using std::vector;

template <typename Dtype>
HDF5OutputLayer<Dtype>::HDF5OutputLayer(const LayerParameter& param)
    : Layer<Dtype>(param),
      file_name_(param.hdf5_output_param().file_name()) {
  /* create a HDF5 file */
  file_id_ = H5Fcreate(file_name_.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
                       H5P_DEFAULT);
  CHECK_GE(file_id_, 0) << "Failed to open HDF5 file " << file_name_;
}

template <typename Dtype>
HDF5OutputLayer<Dtype>::~HDF5OutputLayer<Dtype>() {
  herr_t status = H5Fclose(file_id_);
  CHECK_GE(status, 0) << "Failed to close HDF5 file " << file_name_;
}

template <typename Dtype>
void HDF5OutputLayer<Dtype>::SaveBlobs() {
  // TODO: no limit on the number of blobs
  LOG(INFO) << "Saving HDF5 file " << file_name_;
  CHECK_EQ(data_blob_.num(), label_blob_.num()) <<
      "data blob and label blob must have the same batch size";
  hdf5_save_nd_dataset(file_id_, HDF5_DATA_DATASET_NAME, data_blob_);
  hdf5_save_nd_dataset(file_id_, HDF5_DATA_LABEL_NAME, label_blob_);
  LOG(INFO) << "Successfully saved " << data_blob_.num() << " rows";
}

template <typename Dtype>
void HDF5OutputLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
                                   vector<Blob<Dtype>*>* top) {
  // TODO: no limit on the number of blobs
  CHECK_EQ(bottom.size(), 2) << "HDF5OutputLayer takes two blobs as input.";
  CHECK_EQ(top->size(), 0) << "HDF5OutputLayer takes no output blobs.";
}

template <typename Dtype>
Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                                          vector<Blob<Dtype>*>* top) {
  CHECK_GE(bottom.size(), 2);
  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
  data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
                     bottom[0]->height(), bottom[0]->width());
  label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
                      bottom[1]->height(), bottom[1]->width());
  const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
  const int label_datum_dim = bottom[1]->count() / bottom[1]->num();

  for (int i = 0; i < bottom[0]->num(); ++i) {
    memcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
           &bottom[0]->cpu_data()[i * data_datum_dim],
           sizeof(Dtype) * data_datum_dim);
    memcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
           &bottom[1]->cpu_data()[i * label_datum_dim],
           sizeof(Dtype) * label_datum_dim);
  }
  SaveBlobs();
  return Dtype(0.);
}

template <typename Dtype>
void HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  return;
}

INSTANTIATE_CLASS(HDF5OutputLayer);

}  // namespace caffe

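SaveBlobs() always writes the two datasets named by HDF5_DATA_DATASET_NAME ("data") and HDF5_DATA_LABEL_NAME ("label"). As a sketch (not part of the diff), a file produced by this layer can be read back with the HDF5 high-level C API alone; the path below is illustrative.

// Sketch: read the "data" dataset back from a file produced by this layer.
#include <vector>

#include "hdf5.h"
#include "hdf5_hl.h"

int main() {
  hid_t file_id = H5Fopen("/tmp/dump.h5", H5F_ACC_RDONLY, H5P_DEFAULT);
  if (file_id < 0) return 1;

  hsize_t dims[4];  // num, channels, height, width, in the order hdf5_save_nd_dataset writes them
  H5T_class_t type_class;
  size_t type_size;
  if (H5LTget_dataset_info(file_id, "data", dims, &type_class, &type_size) < 0) return 1;

  std::vector<float> buffer(dims[0] * dims[1] * dims[2] * dims[3]);
  if (H5LTread_dataset_float(file_id, "data", &buffer[0]) < 0) return 1;

  H5Fclose(file_id);
  return 0;
}
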
@@ -0,0 +1,53 @@
// Copyright 2014 BVLC and contributors.
/*
Contributors:
- kloudkl@github, 2014.
*/

#include <vector>

#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"

namespace caffe {
using std::vector;

template <typename Dtype>
Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
                                          vector<Blob<Dtype>*>* top) {
  CHECK_GE(bottom.size(), 2);
  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
  data_blob_.Reshape(bottom[0]->num(), bottom[0]->channels(),
                     bottom[0]->height(), bottom[0]->width());
  label_blob_.Reshape(bottom[1]->num(), bottom[1]->channels(),
                      bottom[1]->height(), bottom[1]->width());
  const int data_datum_dim = bottom[0]->count() / bottom[0]->num();
  const int label_datum_dim = bottom[1]->count() / bottom[1]->num();

  for (int i = 0; i < bottom[0]->num(); ++i) {
    CUDA_CHECK(cudaMemcpy(&data_blob_.mutable_cpu_data()[i * data_datum_dim],
        &bottom[0]->gpu_data()[i * data_datum_dim],
        sizeof(Dtype) * data_datum_dim, cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaMemcpy(&label_blob_.mutable_cpu_data()[i * label_datum_dim],
        &bottom[1]->gpu_data()[i * label_datum_dim],
        sizeof(Dtype) * label_datum_dim, cudaMemcpyDeviceToHost));
  }
  SaveBlobs();
  return Dtype(0.);
}

template <typename Dtype>
void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  return;
}

INSTANTIATE_CLASS(HDF5OutputLayer);

}  // namespace caffe

@@ -125,6 +125,12 @@ message LayerParameter {
  // the other dimensions must be the same for all the bottom blobs.
  // By default it will concatenate blobs along the channels dimension.
  optional uint32 concat_dim = 65 [default = 1];

  optional HDF5OutputParameter hdf5_output_param = 1001;
}

message HDF5OutputParameter {
  optional string file_name = 1;
}

message LayerConnection {

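Since hdf5_output_param is an ordinary LayerParameter field, it can also be filled from protobuf text format, the same syntax used in .prototxt network definitions. A sketch (not part of the diff), assuming LayerParameter's usual name/type string fields; the layer name and file path are illustrative.

// Sketch: fill the new field from protobuf text format.
#include <string>

#include <google/protobuf/text_format.h>

#include "caffe/proto/caffe.pb.h"

bool ParseHDF5OutputParam(caffe::LayerParameter* param) {
  const std::string text =
      "name: 'save_features'\n"                                // illustrative layer name
      "type: 'hdf5_output'\n"                                  // string matched by GetLayer above
      "hdf5_output_param { file_name: '/tmp/features.h5' }\n"; // illustrative path
  return google::protobuf::TextFormat::ParseFromString(text, param);
}
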
@@ -0,0 +1,127 @@
// Copyright 2014 kloudkl@github

#include <cuda_runtime.h>
#include <string>
#include <vector>

#include "gtest/gtest.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/test/test_caffe_main.hpp"

namespace caffe {
using std::string;
using std::vector;

extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;

template <typename Dtype>
class HDF5OutputLayerTest : public ::testing::Test {
 protected:
  HDF5OutputLayerTest()
      : output_file_name_("/tmp/test_hdf5_output_layer-sample_data.hdf5"),
        input_file_name_("src/caffe/test/test_data/sample_data.h5"),
        blob_data_(new Blob<Dtype>()),
        blob_label_(new Blob<Dtype>()),
        num_(5),
        channels_(8),
        height_(5),
        width_(5) {
  }
  virtual void SetUp() {
  }

  virtual ~HDF5OutputLayerTest() {
    delete blob_data_;
    delete blob_label_;
  }

  void CheckBlobEqual(const Blob<Dtype>& b1, const Blob<Dtype>& b2);

  string output_file_name_;
  string input_file_name_;
  Blob<Dtype>* const blob_data_;
  Blob<Dtype>* const blob_label_;
  vector<Blob<Dtype>*> blob_bottom_vec_;
  vector<Blob<Dtype>*> blob_top_vec_;
  int num_;
  int channels_;
  int height_;
  int width_;
};

template <typename Dtype>
void HDF5OutputLayerTest<Dtype>::CheckBlobEqual(
    const Blob<Dtype>& b1, const Blob<Dtype>& b2) {
  EXPECT_EQ(b1.num(), b2.num());
  EXPECT_EQ(b1.channels(), b2.channels());
  EXPECT_EQ(b1.height(), b2.height());
  EXPECT_EQ(b1.width(), b2.width());
  for (int n = 0; n < b1.num(); ++n) {
    for (int c = 0; c < b1.channels(); ++c) {
      for (int h = 0; h < b1.height(); ++h) {
        for (int w = 0; w < b1.width(); ++w) {
          EXPECT_EQ(b1.data_at(n, c, h, w), b2.data_at(n, c, h, w));
        }
      }
    }
  }
}

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(HDF5OutputLayerTest, Dtypes);

TYPED_TEST(HDF5OutputLayerTest, TestForward) {
  LOG(INFO) << "Loading HDF5 file " << this->input_file_name_;
  hid_t file_id = H5Fopen(this->input_file_name_.c_str(), H5F_ACC_RDONLY,
                          H5P_DEFAULT);
  ASSERT_GE(file_id, 0) << "Failed to open HDF5 file " <<
      this->input_file_name_;
  hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
                       this->blob_data_);
  hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
                       this->blob_label_);
  herr_t status = H5Fclose(file_id);
  EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
      this->input_file_name_;
  this->blob_bottom_vec_.push_back(this->blob_data_);
  this->blob_bottom_vec_.push_back(this->blob_label_);

  Caffe::Brew modes[] = { Caffe::CPU, Caffe::GPU };
  for (int m = 0; m < 2; ++m) {
    Caffe::set_mode(modes[m]);
    LayerParameter param;
    param.mutable_hdf5_output_param()->set_file_name(this->output_file_name_);
    // This code block ensures that the layer is destructed and
    // the output hdf5 file is closed.
    {
      HDF5OutputLayer<TypeParam> layer(param);
      EXPECT_EQ(layer.file_name(), this->output_file_name_);
      layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
      layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
    }
    hid_t file_id = H5Fopen(this->output_file_name_.c_str(), H5F_ACC_RDONLY,
                            H5P_DEFAULT);
    ASSERT_GE(file_id, 0) << "Failed to open HDF5 file " <<
        this->output_file_name_;

    Blob<TypeParam>* blob_data = new Blob<TypeParam>();
    hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
                         blob_data);
    this->CheckBlobEqual(*(this->blob_data_), *blob_data);

    Blob<TypeParam>* blob_label = new Blob<TypeParam>();
    hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
                         blob_label);
    this->CheckBlobEqual(*(this->blob_label_), *blob_label);

    herr_t status = H5Fclose(file_id);
    EXPECT_GE(status, 0) << "Failed to close HDF5 file " <<
        this->output_file_name_;
  }
}

}  // namespace caffe

@@ -142,4 +142,30 @@ void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
      file_id, dataset_name_, blob->mutable_cpu_data());
}

template <>
void hdf5_save_nd_dataset<float>(
    const hid_t file_id, const string dataset_name, const Blob<float>& blob) {
  hsize_t dims[HDF5_NUM_DIMS];
  dims[0] = blob.num();
  dims[1] = blob.channels();
  dims[2] = blob.height();
  dims[3] = blob.width();
  herr_t status = H5LTmake_dataset_float(
      file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
  CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
}

template <>
void hdf5_save_nd_dataset<double>(
    const hid_t file_id, const string dataset_name, const Blob<double>& blob) {
  hsize_t dims[HDF5_NUM_DIMS];
  dims[0] = blob.num();
  dims[1] = blob.channels();
  dims[2] = blob.height();
  dims[3] = blob.width();
  herr_t status = H5LTmake_dataset_double(
      file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
  CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
}

}  // namespace caffe

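These specializations can also be called directly on any 4-D blob without going through HDF5OutputLayer. A short sketch (not part of the diff); the file path and dataset name are illustrative.

// Sketch: save a blob to HDF5 by calling the new saver directly.
#include "hdf5.h"
#include "hdf5_hl.h"

#include "caffe/blob.hpp"
#include "caffe/util/io.hpp"

void SaveOneBlob() {
  caffe::Blob<float> blob(2, 3, 4, 5);  // num, channels, height, width
  // ... fill blob.mutable_cpu_data() ...
  hid_t file_id = H5Fcreate("/tmp/blob.h5", H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
  caffe::hdf5_save_nd_dataset<float>(file_id, "data", blob);
  H5Fclose(file_id);
}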