diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp index 7bf78977..e5405727 100644 --- a/include/caffe/util/io.hpp +++ b/include/caffe/util/io.hpp @@ -15,6 +15,8 @@ using std::string; using ::google::protobuf::Message; +#define HDF5_NUM_DIMS 4 + namespace caffe { void ReadProtoFromTextFile(const char* filename, @@ -60,6 +62,10 @@ void hdf5_load_nd_dataset( hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, Blob* blob); +template +void hdf5_save_nd_dataset( + const hid_t file_id, const string dataset_name, const Blob& blob); + } // namespace caffe #endif // CAFFE_UTIL_IO_H_ diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index 3ac69f97..053d7a40 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -142,4 +142,30 @@ void hdf5_load_nd_dataset(hid_t file_id, const char* dataset_name_, file_id, dataset_name_, blob->mutable_cpu_data()); } +template <> +void hdf5_save_nd_dataset( + const hid_t file_id, const string dataset_name, const Blob& blob) { + hsize_t dims[HDF5_NUM_DIMS]; + dims[0] = blob.num(); + dims[1] = blob.channels(); + dims[2] = blob.height(); + dims[3] = blob.width(); + herr_t status = H5LTmake_dataset_float( + file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data()); + CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name; +} + +template <> +void hdf5_save_nd_dataset( + const hid_t file_id, const string dataset_name, const Blob& blob) { + hsize_t dims[HDF5_NUM_DIMS]; + dims[0] = blob.num(); + dims[1] = blob.channels(); + dims[2] = blob.height(); + dims[3] = blob.width(); + herr_t status = H5LTmake_dataset_double( + file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data()); + CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name; +} + } // namespace caffe