diff --git a/egs/swbd/s5b/local/run_nnet2.sh b/egs/swbd/s5b/local/run_nnet2.sh new file mode 100755 index 000000000..8436e9bac --- /dev/null +++ b/egs/swbd/s5b/local/run_nnet2.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# This runs on the 100 hour subset. + +. cmd.sh + +( # TODO: change 5a to 5a2. + if [ ! -f exp/nnet5a2/final.mdl ]; then + steps/nnet2/train_tanh.sh --stage 215 \ + --mix-up 8000 \ + --initial-learning-rate 0.01 --final-learning-rate 0.001 \ + --num-jobs-nnet 16 --num-hidden-layers 4 \ + --hidden-layer-dim 1024 \ + --cmd "$decode_cmd" \ + data/train_100k_nodup data/lang exp/tri4a exp/nnet5a2 || exit 1; + fi + + for lm_suffix in tg fsh_tgpr; do + steps/decode_nnet_cpu.sh --cmd "$decode_cmd" --nj 30 \ + --config conf/decode.config --transform-dir exp/tri4a/decode_eval2000_sw1_${lm_suffix} \ + exp/tri4a/graph_sw1_${lm_suffix} data/eval2000 exp/nnet5a2/decode_eval2000_sw1_${lm_suffix} & + done +) + diff --git a/egs/swbd/s5b/local/run_nnet2_gpu.sh b/egs/swbd/s5b/local/run_nnet2_gpu.sh index c78cf1218..4b838fb52 100755 --- a/egs/swbd/s5b/local/run_nnet2_gpu.sh +++ b/egs/swbd/s5b/local/run_nnet2_gpu.sh @@ -1,17 +1,19 @@ #!/bin/bash # This runs on the 100 hour subset. This version of the recipe runs on GPUs. -# We assume you have 8 GPU machines. You have to use --num-threads 1 so it will -# use the version of the code that can use GPUs. -# We assume the queue is set up as in JHU (or as in the "Kluster" project -# on Sourceforge) where "gpu" is a consumable resource that you can set to -# number of GPU cards a machine has. +# We assume you have 8 GPU cards. You have to use --num-threads 1 so it will +# use the version of the code that can use GPUs (the -parallel training code +# cannot use GPUs unless we make further modifications as the CUDA model assumes +# a single thread per GPU context, and we're not currently set up to create multiple +# GPU contexts. We assume the queue is set up as in JHU (or +# as in the "Kluster" project on Sourceforge) where "gpu" is a consumable +# resource that you can set to number of GPU cards a machine has. . cmd.sh ( if [ ! -f exp/nnet5b/final.mdl ]; then - steps/nnet2/train_tanh.sh --cmd "$decode_cmd -l gpu=1" --parallel-opts "" --stage 0 \ + steps/nnet2/train_tanh.sh --cmd "$decode_cmd -l gpu=1" --parallel-opts "" --stage 253 \ --num-threads 1 \ --mix-up 8000 \ --initial-learning-rate 0.01 --final-learning-rate 0.001 \ @@ -26,4 +28,3 @@ exp/tri4a/graph_sw1_${lm_suffix} data/eval2000 exp/nnet5b/decode_eval2000_sw1_${lm_suffix} & done ) - diff --git a/egs/swbd/s5b/local/run_nnet2_gpu2.sh b/egs/swbd/s5b/local/run_nnet2_gpu2.sh new file mode 100755 index 000000000..d65e24a87 --- /dev/null +++ b/egs/swbd/s5b/local/run_nnet2_gpu2.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# This runs on the 100 hour subset. This version of the recipe runs on GPUs. +# We assume you have 8 GPU cards. You have to use --num-threads 1 so it will +# use the version of the code that can use GPUs (the -parallel training code +# cannot use GPUs unless we make further modifications as the CUDA model assumes +# a single thread per GPU context, and we're not currently set up to create multiple +# GPU contexts. We assume the queue is set up as in JHU (or +# as in the "Kluster" project on Sourceforge) where "gpu" is a consumable +# resource that you can set to number of GPU cards a machine has. + +. cmd.sh + +( + if [ ! -f exp/nnet5b/final.mdl ]; then + steps/nnet2/train_tanh.sh --cmd "$decode_cmd -l gpu=1" --parallel-opts "" --io-opts "-tc 5 -l gpu=0" --stage -3 \ + --num-threads 1 --minibatch-size 512 --max-change 40.0 --mix-up 8000 \ + --initial-learning-rate 0.01 --final-learning-rate 0.001 \ + --num-jobs-nnet 8 --num-hidden-layers 4 \ + --hidden-layer-dim 1024 \ + data/train_100k_nodup data/lang exp/tri4a exp/nnet5b2 || exit 1; + fi + + for lm_suffix in tg fsh_tgpr; do + steps/decode_nnet_cpu.sh --cmd "$decode_cmd" --nj 30 \ + --config conf/decode.config --transform-dir exp/tri4a/decode_eval2000_sw1_${lm_suffix} \ + exp/tri4a/graph_sw1_${lm_suffix} data/eval2000 exp/nnet5b2/decode_eval2000_sw1_${lm_suffix} & + done +) diff --git a/egs/swbd/s5b/local/run_nnet_cpu.sh b/egs/swbd/s5b/local/run_nnet_cpu.sh index a326b9662..ff58101d4 100755 --- a/egs/swbd/s5b/local/run_nnet_cpu.sh +++ b/egs/swbd/s5b/local/run_nnet_cpu.sh @@ -6,7 +6,7 @@ ( if [ ! -f exp/nnet5a/final.mdl ]; then - steps/train_nnet_cpu.sh \ + steps/train_nnet_cpu.sh --stage 103 \ --mix-up 8000 \ --initial-learning-rate 0.01 --final-learning-rate 0.001 \ --num-jobs-nnet 16 --num-hidden-layers 4 \ diff --git a/egs/wsj/s5/steps/nnet2/train_block.sh b/egs/wsj/s5/steps/nnet2/train_block.sh index aa6e2e725..bbb89da30 100755 --- a/egs/wsj/s5/steps/nnet2/train_block.sh +++ b/egs/wsj/s5/steps/nnet2/train_block.sh @@ -62,7 +62,7 @@ max_change=10.0 mix_up=0 # Number of components to mix up to (should be > #tree leaves, if # specified.) num_threads=16 -parallel_opts="-pe smp $num_threads" # using a smallish #threads by default, out of stability concerns. +parallel_opts="-pe smp 16 -l ram_free=1G,mem_free=1G" # by default we use 16 threads; this lets the queue know. # note: parallel_opts doesn't automatically get adjusted if you adjust num-threads. cleanup=true egs_dir= @@ -105,8 +105,9 @@ if [ $# != 4 ]; then echo " --num-threads # Number of parallel threads per job (will affect results" echo " # as well as speed; may interact with batch size; if you increase" echo " # this, you may want to decrease the batch size." - echo " --parallel-opts # extra options to pass to e.g. queue.pl for processes that" - echo " # use multiple threads." + echo " --parallel-opts # extra options to pass to e.g. queue.pl for processes that" + echo " # use multiple threads... note, you might have to reduce mem_free,ram_free" + echo " # versus your defaults, because it gets multiplied by the -pe smp argument." echo " --io-opts # Options given to e.g. queue.pl for jobs that do a lot of I/O." echo " --minibatch-size # Size of minibatch to process (note: product with --num-threads" echo " # should not get too large, e.g. >2k)." diff --git a/egs/wsj/s5/steps/nnet2/train_tanh.sh b/egs/wsj/s5/steps/nnet2/train_tanh.sh index fcbdbaf2a..a33f4217e 100755 --- a/egs/wsj/s5/steps/nnet2/train_tanh.sh +++ b/egs/wsj/s5/steps/nnet2/train_tanh.sh @@ -50,7 +50,7 @@ num_hidden_layers=3 stage=-5 -io_opts="-tc 5" # for jobs with a lot of I/O, limits the number running at one time. +io_opts="-tc 5" # for jobs with a lot of I/O, limits the number running at one time. These don't splice_width=4 # meaning +- 4 frames on each side for second LDA randprune=4.0 # speeds up LDA. alpha=4.0 @@ -58,7 +58,7 @@ max_change=10.0 mix_up=0 # Number of components to mix up to (should be > #tree leaves, if # specified.) num_threads=16 -parallel_opts="-pe smp $num_threads" # using a smallish #threads by default, out of stability concerns. +parallel_opts="-pe smp 16 -l ram_free=1G,mem_free=1G" # by default we use 16 threads; this lets the queue know. # note: parallel_opts doesn't automatically get adjusted if you adjust num-threads. cleanup=true egs_dir= @@ -101,8 +101,9 @@ if [ $# != 4 ]; then echo " --num-threads # Number of parallel threads per job (will affect results" echo " # as well as speed; may interact with batch size; if you increase" echo " # this, you may want to decrease the batch size." - echo " --parallel-opts # extra options to pass to e.g. queue.pl for processes that" - echo " # use multiple threads." + echo " --parallel-opts # extra options to pass to e.g. queue.pl for processes that" + echo " # use multiple threads... note, you might have to reduce mem_free,ram_free" + echo " # versus your defaults, because it gets multiplied by the -pe smp argument." echo " --io-opts # Options given to e.g. queue.pl for jobs that do a lot of I/O." echo " --minibatch-size # Size of minibatch to process (note: product with --num-threads" echo " # should not get too large, e.g. >2k)." @@ -170,7 +171,7 @@ if [ $stage -le -3 ] && [ -z "$egs_dir" ]; then echo "$0: calling get_egs.sh" [ ! -z $spk_vecs_dir ] && spk_vecs_opt="--spk-vecs-dir $spk_vecs_dir"; steps/nnet2/get_egs.sh $spk_vecs_opt --samples-per-iter $samples_per_iter --num-jobs-nnet $num_jobs_nnet \ - --splice-width $splice_width --stage $get_egs_stage --cmd "$cmd" $egs_opts \ + --splice-width $splice_width --stage $get_egs_stage --cmd "$cmd" $egs_opts --io-opts "$io_opts" \ $data $lang $alidir $dir || exit 1; fi diff --git a/src/Makefile b/src/Makefile index 9075fcb9d..799dc70a7 100644 --- a/src/Makefile +++ b/src/Makefile @@ -9,6 +9,14 @@ SUBDIRS = base matrix util feat tree thread gmm tied transform sgmm \ fstext hmm lm decoder lat cudamatrix nnet \ bin fstbin gmmbin fgmmbin tiedbin sgmmbin featbin \ nnetbin latbin sgmm2 sgmm2bin nnet2 nnet2bin kwsbin + +MEMTESTDIRS = base matrix util feat tree thread gmm tied transform sgmm \ + fstext hmm lm decoder lat nnet \ + bin fstbin gmmbin fgmmbin tiedbin sgmmbin featbin \ + nnetbin latbin sgmm2 sgmm2bin nnet-cpu nnet-cpubin kwsbin + +CUDAMEMTESTDIR = cudamatrix + SUBDIRS_LIB = $(filter-out %bin, $(SUBDIRS)) @@ -97,8 +105,11 @@ ext_test: $(addsuffix /test, $(EXT_SUBDIRS)) %/test: % mklibdir $(MAKE) -C $< test +cudavalgrind: + -for x in $(CUDAMEMTESTDIR); do $(MAKE) -C $$x valgrind || { echo "valgrind on $$x failed"; exit 1; }; done + valgrind: - -for x in $(SUBDIRS); do $(MAKE) -C $$x valgrind || { echo "valgrind on $$x failed"; exit 1; }; done + -for x in $(MEMTESTDIRS); do $(MAKE) -C $$x valgrind || { echo "valgrind on $$x failed"; exit 1; }; done depend: $(addsuffix /depend, $(SUBDIRS)) diff --git a/src/cudamatrix/Makefile b/src/cudamatrix/Makefile index 8355c6146..1ea14a86a 100644 --- a/src/cudamatrix/Makefile +++ b/src/cudamatrix/Makefile @@ -12,7 +12,7 @@ LDFLAGS += $(CUDA_LDFLAGS) LDLIBS += $(CUDA_LDLIBS) TESTFILES = cu-vector-test cu-matrix-test cu-math-test cu-test cu-sp-matrix-test cu-packed-matrix-test cu-tp-matrix-test \ - cu-block-matrix-test cu-matrix-speed-test cu-vector-speed-test cu-sp-matrix-speed-test + cu-block-matrix-test cu-matrix-speed-test cu-vector-speed-test cu-sp-matrix-speed-test cu-array-test OBJFILES = cu-device.o cu-math.o cu-matrix.o cu-packed-matrix.o cu-sp-matrix.o \ diff --git a/src/cudamatrix/cu-array-inl.h b/src/cudamatrix/cu-array-inl.h index 7775e47f9..9e2e478ff 100644 --- a/src/cudamatrix/cu-array-inl.h +++ b/src/cudamatrix/cu-array-inl.h @@ -1,6 +1,7 @@ // cudamatrix/cu-array-inl.h // Copyright 2009-2012 Karel Vesely +// 2013 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // @@ -113,7 +114,7 @@ void CuArray::CopyToVec(std::vector *dst) const { if (CuDevice::Instantiate().Enabled()) { Timer tim; CU_SAFE_CALL(cudaMemcpy(&dst->front(), Data(), dim_*sizeof(T), cudaMemcpyDeviceToHost)); - CuDevice::Instantiate().AccuProfile("CuArray::CopyToVecD2H",tim.Elapsed()); + CuDevice::Instantiate().AccuProfile("CuArray::CopyToVecD2H", tim.Elapsed()); } else #endif { @@ -129,7 +130,7 @@ void CuArray::SetZero() { if (CuDevice::Instantiate().Enabled()) { Timer tim; CU_SAFE_CALL(cudaMemset(data_, 0, dim_ * sizeof(T))); - CuDevice::Instantiate().AccuProfile("CuArray::SetZero",tim.Elapsed()); + CuDevice::Instantiate().AccuProfile("CuArray::SetZero", tim.Elapsed()); } else #endif { @@ -184,6 +185,24 @@ inline void CuArray::Set(const int32 &value) { } } +template +void CuArray::CopyFromArray(const CuArray &src) { + this->Resize(src.Dim(), kUndefined); + if (dim_ == 0) return; +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + Timer tim; + CU_SAFE_CALL(cudaMemcpy(this->data_, src.data_, dim_ * sizeof(T), + cudaMemcpyDeviceToDevice)); + CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed()); + } else +#endif + { + memcpy(this->data_, src.data_, dim_ * sizeof(T)); + } +} + + } // namespace kaldi #endif diff --git a/src/cudamatrix/cu-array-test.cc b/src/cudamatrix/cu-array-test.cc new file mode 100644 index 000000000..f9c14978d --- /dev/null +++ b/src/cudamatrix/cu-array-test.cc @@ -0,0 +1,124 @@ +// cudamatrix/cu-array-test.cc + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "cudamatrix/cu-array.h" + +using namespace kaldi; + + +namespace kaldi { + + + + +template +static void UnitTestCuArray() { + for (int32 i = 0; i < 30; i++) { + int32 size = rand() % 5; + size = size * size * size; // Have a good distribution of sizes, including >256. + int32 size2 = rand() % 4; + std::vector vec(size); + std::vector garbage_vec(size2); // We just use garbage_vec to make sure + // we sometimes resize from empty, + // sometimes not. + + int32 byte_size = size * sizeof(T); + std::vector rand_c(byte_size); + for (size_t i = 0; i < byte_size; i++) + rand_c[i] = rand() % 256; + if (!vec.empty()) { + std::memcpy((void*)&(vec[0]), (void*)&(rand_c[0]), + byte_size); + } + + { // test constructor from vector and CopyToVec. + CuArray cu_vec(vec); + std::vector vec2; + cu_vec.CopyToVec(&vec2); + KALDI_ASSERT(vec2 == vec); + } + + { // test assignment operator from CuArray. + CuArray cu_vec(vec); + CuArray cu_vec2(garbage_vec); + cu_vec2 = cu_vec; + std::vector vec2; + cu_vec2.CopyToVec(&vec2); + KALDI_ASSERT(vec2 == vec); + KALDI_ASSERT(cu_vec2.Dim() == int32(vec2.size())); // test Dim() + } + + { // test resize with resize_type = kSetZero. + CuArray cu_vec(vec); + cu_vec.Resize(size, kSetZero); + std::vector vec2(vec); + + if (!vec2.empty()) + std::memset(&(vec2[0]), 0, vec2.size() * sizeof(T)); + std::vector vec3; + cu_vec.CopyToVec(&vec3); + KALDI_ASSERT(vec2 == vec3); // testing equality of zero arrays. + } + + if (sizeof(T) == sizeof(int32) && size > 0) { // test Set for type int32, or same size. + CuArray cu_vec(vec); + cu_vec.Set(vec[0]); + for (size_t i = 1; i < vec.size(); i++) vec[i] = vec[0]; + std::vector vec2; + cu_vec.CopyToVec(&vec2); + KALDI_ASSERT(vec2 == vec); + } + } +} + + +} // namespace kaldi + + +int main() { + for (int32 loop = 0; loop < 2; loop++) { +#if HAVE_CUDA == 1 + if (loop == 0) + CuDevice::Instantiate().SelectGpuId(-1); // -1 means no GPU + else + CuDevice::Instantiate().SelectGpuId(-2); // -2 .. automatic selection +#endif + + //kaldi::UnitTestCuArray(); + kaldi::UnitTestCuArray(); + kaldi::UnitTestCuArray(); + kaldi::UnitTestCuArray >(); + + if (loop == 0) + KALDI_LOG << "Tests without GPU use succeeded.\n"; + else + KALDI_LOG << "Tests with GPU use (if available) succeeded.\n"; + } +#if HAVE_CUDA == 1 + CuDevice::Instantiate().PrintProfile(); +#endif + return 0; +} diff --git a/src/cudamatrix/cu-array.h b/src/cudamatrix/cu-array.h index f27c6c746..6f759494e 100644 --- a/src/cudamatrix/cu-array.h +++ b/src/cudamatrix/cu-array.h @@ -49,7 +49,10 @@ class CuArray { /// Constructor from CPU-based int vector explicit CuArray(const std::vector &src): dim_(0), data_(NULL) { CopyFromVec(src); } - + + explicit CuArray(const CuArray &src): + dim_(0), data_(NULL) { CopyFromArray(src); } + /// Destructor ~CuArray() { Destroy(); } @@ -73,6 +76,9 @@ class CuArray { /// and any constructors or assignment operators are not called. void CopyFromVec(const std::vector &src); + /// This function resizes if needed. + void CopyFromArray(const CuArray &src); + /// This function resizes *dst if needed. On resize of "dst", the STL vector /// may call copy-constructors, initializers, and assignment operators for /// existing objects (which will be overwritten), but the copy from GPU to CPU @@ -88,6 +94,14 @@ class CuArray { /// assignment operators or destructors are not called. This is NOT IMPLEMENTED /// YET except for T == int32 (the current implementation will just crash). void Set(const T &value); + + CuArray &operator= (const CuArray &in) { + this->CopyFromArray(in); return *this; + } + + CuArray &operator= (const std::vector &in) { + this->CopyFromVec(in); return *this; + } private: MatrixIndexT dim_; ///< dimension of the vector diff --git a/src/cudamatrix/cu-block-matrix.cc b/src/cudamatrix/cu-block-matrix.cc index 3afcbb66b..1559cd681 100644 --- a/src/cudamatrix/cu-block-matrix.cc +++ b/src/cudamatrix/cu-block-matrix.cc @@ -165,6 +165,7 @@ void CuBlockMatrix::Write(std::ostream &os, bool binary) const { WriteToken(os, binary, ""); } + template void CuBlockMatrix::Read(std::istream &is, bool binary) { Destroy(); diff --git a/src/cudamatrix/cu-block-matrix.h b/src/cudamatrix/cu-block-matrix.h index b437f2ffb..d36d6afbd 100644 --- a/src/cudamatrix/cu-block-matrix.h +++ b/src/cudamatrix/cu-block-matrix.h @@ -44,6 +44,7 @@ namespace kaldi { 'primary' home remains on the CPU.. what we mean by this is that while the data remains on the GPU, the "primary" version of the Matrix object that holds the pointers will remain on the CPU. + We just copy it over to the GPU whenever it is changed. */ template @@ -94,10 +95,15 @@ class CuBlockMatrix { /// Copies elements within the block structure from matrix M, discarding others. - /// Note: this has not been impelemented in a very efficient way, it's used only + /// Note: this has not been implemented in a very efficient way, it's used only /// for testing. void CopyFromMat(const CuMatrix &M); + /// Normalizes the columns of *this so that each one sums to one. + /// On error (e.g. inf's), will set the column to a constant value that + /// sums to one. + void NormalizeColumns(); + void Swap(CuBlockMatrix *other); protected: diff --git a/src/cudamatrix/cu-device.cc b/src/cudamatrix/cu-device.cc index 80cbd04bb..3e888591d 100644 --- a/src/cudamatrix/cu-device.cc +++ b/src/cudamatrix/cu-device.cc @@ -79,7 +79,7 @@ void CuDevice::SelectGpuId(int32 gpu_id, bool abort_on_error) { // Check that we have a gpu available int32 n_gpu = 0; cudaGetDeviceCount(&n_gpu); - if(n_gpu == 0 && gpu_id == -2) { + if(n_gpu == 0) { // If we do automatic selection and no GPU is found, we run on a CPU if (abort_on_error) { KALDI_ERR << "No CUDA capable GPU was detected"; @@ -89,16 +89,6 @@ void CuDevice::SelectGpuId(int32 gpu_id, bool abort_on_error) { return; } } - if(n_gpu == 0) { - if (abort_on_error) { - KALDI_ERR << "No CUDA capable GPU was detected."; - } else { - KALDI_WARN << "No CUDA capable GPU detected, while explicitly asked for gpu-id '" - << gpu_id << "'.CUDA will NOT be used!!!"; - active_gpu_id_ = -2; - return; - } - } // Now we know that there is a GPU in the system, // and we don't want to have it disabled. @@ -390,7 +380,7 @@ void CuDevice::PrintProfile() { for(it = profile_map_.begin(); it != profile_map_.end(); ++it) pairs.push_back(std::make_pair(it->second, it->first)); std::sort(pairs.begin(), pairs.end()); - size_t max_print = 15, start_pos = (pairs.size() > max_print ? + size_t max_print = 15, start_pos = (pairs.size() <= max_print ? 0 : pairs.size() - max_print); for (size_t i = start_pos; i < pairs.size(); i++) os << pairs[i].second << "\t" << pairs[i].first << "s\n"; diff --git a/src/cudamatrix/cu-kernels-ansi.h b/src/cudamatrix/cu-kernels-ansi.h index f3b8339a4..80cfa4b2c 100644 --- a/src/cudamatrix/cu-kernels-ansi.h +++ b/src/cudamatrix/cu-kernels-ansi.h @@ -148,6 +148,9 @@ void cudaF_comp_obj_deriv(dim3 Gr,dim3 Bl, MatrixElement* x, int s, const void cudaF_transpose_matrix(dim3 Gr, dim3 Bl, float* mat, MatrixDim d); void cudaF_sy_add_tr2(dim3 Gr, dim3 Bl, float alpha, float beta, const float* T, MatrixDim tdim, float *S, MatrixDim sdim); +void cudaF_sum_column_ranges(dim3 Gr, dim3 Bl, float *data, MatrixDim dim, + const float *src_data, MatrixDim src_dim, + const Int32Pair *indices); /********************************************************* @@ -277,6 +280,10 @@ void cudaD_comp_obj_deriv(dim3 Gr,dim3 Bl, MatrixElement* x, int s, cons void cudaD_transpose_matrix(dim3 Gr, dim3 Bl, double* mat, MatrixDim d); void cudaD_sy_add_tr2(dim3 Gr, dim3 Bl, double alpha, double beta, const double* T, MatrixDim tdim, double *S, MatrixDim sdim); +void cudaD_sum_column_ranges(dim3 Gr, dim3 Bl, double *data, MatrixDim dim, + const double *src_data, MatrixDim src_dim, + const Int32Pair *indices); + } // extern "C" diff --git a/src/cudamatrix/cu-kernels.cu b/src/cudamatrix/cu-kernels.cu index 4b0ef2347..16f5b87a7 100644 --- a/src/cudamatrix/cu-kernels.cu +++ b/src/cudamatrix/cu-kernels.cu @@ -1297,6 +1297,64 @@ static void _block_add_mat_mat(CuBlockMatrixData *B_cu_data, int num_blocks, } +template +__global__ +static void _blockadd_mat_blockmat_trans(Real *data, MatrixDim dim, const Real *A_data, int A_num_rows, int A_num_cols, + int A_row_stride, int A_col_stride, const CuBlockMatrixData *B_cu_data, + int B_num_blocks, Real alpha, Real beta) { + int i = blockIdx.x * blockDim.x + threadIdx.x; // row-index into "data" + int j = blockIdx.y * blockDim.y + threadIdx.y; // block-index into B. + if (i >= A_num_rows || j >= B_num_blocks) return; + + const CuBlockMatrixData &cu_data = B_cu_data[j]; + + // BT means B transposed. + int BT_row_start = cu_data.col_offset, + BT_col_start = cu_data.row_offset, + BT_num_rows = cu_data.matrix_dim.cols, + BT_num_cols = cu_data.matrix_dim.rows, + BT_col_stride = cu_data.matrix_dim.stride; + const Real *B_data = static_cast(cu_data.matrix_data); // Cast from void; + // we avoided a bunch of hassle by doing this (relates to Ansi-C requirement). + + for (int k = 0; k < BT_num_cols; k++) { + const Real *this_BT_col = B_data + k * BT_col_stride; + const Real *this_A_row = A_data + i * A_row_stride + BT_row_start * A_col_stride; + // this_A_row points to the element A[i][BT_row_start], it's really just + // part of this row of A. + Real sum = 0.0; + for (int l = 0; l < BT_num_rows; l++) // l indexes rows of B. + sum += this_BT_col[l] * this_A_row[l * A_col_stride]; + + int index = i * dim.stride + (k + BT_col_start); + data[index] = alpha * sum + beta * data[index]; + } +} + + +// Since this is a newer kernel, x is the row-index and y is the +// column-index. +template +__global__ +static void _sum_column_ranges(Real *data, MatrixDim dim, + const Real *src_data, + MatrixDim src_dim, + const Int32Pair *indices) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + int col = blockIdx.y * blockDim.y + threadIdx.y; + if (row >= dim.rows || col >= dim.cols) + return; + int dest_index = row * dim.stride + col, + src_start_index = row * src_dim.stride + indices[col].first, + src_end_index = row * src_dim.stride + indices[col].second; + Real sum = 0.0; + for (int index = src_start_index; index < src_end_index; index++) + sum += src_data[index]; + data[dest_index] = sum; +} + + + template __global__ static void _soft_hinge(Real*y, const Real*x, MatrixDim d, int src_stride) { @@ -2047,6 +2105,11 @@ void cudaF_copy_col_from_mat_fd(int Gr, int Bl, float* v, int col, const float* _copy_col_from_mat_fd<<>>(v,col,mat,dmat,dim); } +void cudaF_sum_column_ranges(dim3 Gr, dim3 Bl, float *data, MatrixDim dim, + const float *src_data, MatrixDim src_dim, + const Int32Pair *indices) { + _sum_column_ranges<<>>(data, dim, src_data, src_dim, indices); +} @@ -2407,6 +2470,11 @@ void cudaD_copy_rows_from_vec(dim3 Gr, dim3 Bl, double *mat_out, MatrixDim d_out _copy_rows_from_vec<<>>(mat_out, d_out, v_in); } +void cudaD_sum_column_ranges(dim3 Gr, dim3 Bl, double *data, MatrixDim dim, + const double *src_data, MatrixDim src_dim, + const Int32Pair *indices) { + _sum_column_ranges<<>>(data, dim, src_data, src_dim, indices); +} /* Some conversion kernels for which it's more convenient to not name them F or D. */ diff --git a/src/cudamatrix/cu-kernels.h b/src/cudamatrix/cu-kernels.h index 50312cc79..0cb1a42b4 100644 --- a/src/cudamatrix/cu-kernels.h +++ b/src/cudamatrix/cu-kernels.h @@ -207,7 +207,12 @@ inline void cuda_take_lower(dim3 Gr, dim3 Bl, const float* x, float* y, MatrixDi inline void cuda_take_upper(dim3 Gr, dim3 Bl, const float* x, float* y, MatrixDim d_in) { cudaF_take_upper(Gr,Bl,x,y,d_in); } inline void cuda_take_mean(dim3 Gr, dim3 Bl, const float* x, float* y, MatrixDim d_in) { cudaF_take_mean(Gr,Bl,x,y,d_in); } inline void cuda_comp_obj_deriv(dim3 Gr, dim3 Bl, MatrixElement* x, int32 size, const float* z, MatrixDim d, float* z2, MatrixDim d2, float* t) {cudaF_comp_obj_deriv(Gr,Bl,x,size,z,d,z2,d2,t); } -inline void cuda_comp_obj_deriv(dim3 Gr, dim3 Bl, MatrixElement* x, int32 size, const double* z, MatrixDim d, double* z2, MatrixDim d2, double* t) {cudaD_comp_obj_deriv(Gr,Bl,x,size,z,d,z2,d2,t); } +inline void cuda_sum_column_ranges(dim3 Gr, dim3 Bl, float *data, MatrixDim dim, + const float *src_data, MatrixDim src_dim, + const Int32Pair *indices) { + cudaF_sum_column_ranges(Gr, Bl, data, dim, src_data, src_dim, indices); +} + // double versions @@ -347,6 +352,11 @@ inline void cuda_copy_from_sp(int Gr, int Bl, const double* x, double* y, int d_ inline void cuda_take_lower(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in) { cudaD_take_lower(Gr,Bl,x,y,d_in); } inline void cuda_take_upper(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in) { cudaD_take_upper(Gr,Bl,x,y,d_in); } inline void cuda_take_mean(dim3 Gr, dim3 Bl, const double* x, double* y, MatrixDim d_in) { cudaD_take_mean(Gr,Bl,x,y,d_in); } +inline void cuda_comp_obj_deriv(dim3 Gr, dim3 Bl, MatrixElement* x, int32 size, const double* z, MatrixDim d, double* z2, MatrixDim d2, double* t) {cudaD_comp_obj_deriv(Gr,Bl,x,size,z,d,z2,d2,t); } +inline void cuda_sum_column_ranges(dim3 Gr, dim3 Bl, double *data, MatrixDim dim, + const double *src_data, MatrixDim src_dim, const Int32Pair *indices) { + cudaD_sum_column_ranges(Gr, Bl, data, dim, src_data, src_dim, indices); +} // Also include some template-friendly wrappers of cublas functions: diff --git a/src/cudamatrix/cu-math.h b/src/cudamatrix/cu-math.h index 76d7bd791..33feb4967 100644 --- a/src/cudamatrix/cu-math.h +++ b/src/cudamatrix/cu-math.h @@ -67,6 +67,7 @@ void Splice(const CuMatrix &src, /// The matrices src and tgt must have the same dimensions and /// the dimension of copy_from_indices must equal the number of columns /// in the src matrix. As a result, tgt(i, j) == src(i, copy_from_indices[j]). +/// Also see CuMatrix::CopyCols(), which is more general. template void Copy(const CuMatrix &src, const CuArray ©_from_indices, diff --git a/src/cudamatrix/cu-matrix-test.cc b/src/cudamatrix/cu-matrix-test.cc index 5434f8f72..9c0cbacee 100644 --- a/src/cudamatrix/cu-matrix-test.cc +++ b/src/cudamatrix/cu-matrix-test.cc @@ -340,6 +340,47 @@ template void UnitTestCuMatrixCopyCross2() { } } +template +static void UnitTestCuMatrixSumColumnRanges() { + for (MatrixIndexT p = 0; p < 10; p++) { + MatrixIndexT num_cols1 = 10 + rand() % 10, + num_cols2 = 10 + rand() % 10, + num_rows = 10 + rand() % 10; + Matrix src(num_rows, num_cols1); + Matrix dst(num_rows, num_cols2); + std::vector indices(num_cols2); + for (MatrixIndexT i = 0; i < num_cols2; i++) { + indices[i].first = rand() % num_cols1; + int32 headroom = num_cols1 - indices[i].first, + size = (rand() % headroom) + 1; + indices[i].second = indices[i].first + size; + KALDI_ASSERT(indices[i].second >= indices[i].first && + indices[i].second <= num_cols1 && + indices[i].first >= 0); + // In the test we allow second == first. + } + src.SetRandn(); + // Simple computation: + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < num_cols2; j++) { + int32 start = indices[j].first, end = indices[j].second; + Real sum = 0.0; + for (MatrixIndexT j2 = start; j2 < end; j2++) + sum += src(i, j2); + dst(i, j) = sum; + } + } + CuMatrix cu_src(src); + CuMatrix cu_dst(num_rows, num_cols2, kUndefined); + CuArray indices_tmp(indices); + cu_dst.SumColumnRanges(cu_src, indices_tmp); + Matrix dst2(cu_dst); + AssertEqual(dst, dst2); + } +} + + + template static void UnitTestCuMatrixCopyCols() { for (MatrixIndexT p = 0; p < 10; p++) { @@ -353,8 +394,13 @@ static void UnitTestCuMatrixCopyCols() { std::vector reorder(num_cols2); for (int32 i = 0; i < num_cols2; i++) reorder[i] = -1 + (rand() % (num_cols1 + 1)); - - N.CopyCols(M, reorder); + + if (rand() % 2 == 0) { + N.CopyCols(M, reorder); + } else { + CuArray cuda_reorder(reorder); + N.CopyCols(M, cuda_reorder); + } for (int32 i = 0; i < num_rows; i++) for (int32 j = 0; j < num_cols2; j++) @@ -1640,6 +1686,7 @@ template void CudaMatrixUnitTest() { UnitTestCuMatrixCopyFromTp(); UnitTestCuMatrixAddMatTp(); UnitTestCuMatrixCopyCols(); + UnitTestCuMatrixSumColumnRanges(); UnitTestCuMatrixCopyRows(); UnitTestCuMatrixCopyRowsFromVec(); UnitTestCuMatrixAddTpMat(); diff --git a/src/cudamatrix/cu-matrix.cc b/src/cudamatrix/cu-matrix.cc index b7f934e51..906116bfa 100644 --- a/src/cudamatrix/cu-matrix.cc +++ b/src/cudamatrix/cu-matrix.cc @@ -1622,6 +1622,7 @@ void VectorBase::CopyRowsFromMat(const CuMatrixBase &mat); template void VectorBase::CopyRowsFromMat(const CuMatrixBase &mat); + template void CuMatrixBase::CopyCols(const CuMatrixBase &src, const std::vector &reorder) { @@ -1650,6 +1651,30 @@ void CuMatrixBase::CopyCols(const CuMatrixBase &src, } } +template +void CuMatrixBase::CopyCols(const CuMatrixBase &src, + const CuArray &reorder) { +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + KALDI_ASSERT(reorder.Dim() == NumCols()); + KALDI_ASSERT(NumRows() == src.NumRows()); + Timer tim; + dim3 dimBlock(CU2DBLOCK, CU2DBLOCK); + // This kernel, as it is newer has the (x,y) dims as (rows,cols). + dim3 dimGrid(n_blocks(NumRows(), CU2DBLOCK), n_blocks(NumCols(), CU2DBLOCK)); + cuda_copy_cols(dimGrid, dimBlock, data_, src.Data(), reorder.Data(), Dim(), src.Stride()); + CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed()); + } else +#endif + { + std::vector reorder_cpu; + reorder.CopyToVec(&reorder_cpu); + Mat().CopyCols(src.Mat(), reorder_cpu); + } +} + + template void CuMatrixBase::CopyRows(const CuMatrixBase &src, const std::vector &reorder) { @@ -1678,6 +1703,46 @@ void CuMatrixBase::CopyRows(const CuMatrixBase &src, } } + +template +void CuMatrixBase::SumColumnRanges(const CuMatrixBase &src, + const CuArray &indices) { + KALDI_ASSERT(static_cast(indices.Dim()) == NumCols()); + KALDI_ASSERT(NumRows() == src.NumRows()); + if (NumRows() == 0) return; +#if HAVE_CUDA == 1 + if (CuDevice::Instantiate().Enabled()) { + + Timer tim; + dim3 dimBlock(CU2DBLOCK, CU2DBLOCK); + // This kernel, as it is newer has the (x,y) dims as (rows,cols). + dim3 dimGrid(n_blocks(NumRows(), CU2DBLOCK), n_blocks(NumCols(), CU2DBLOCK)); + cuda_sum_column_ranges(dimGrid, dimBlock, data_, Dim(), src.Data(), src.Dim(), indices.Data()); + CU_SAFE_CALL(cudaGetLastError()); + CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed()); + } else +#endif + { // Implement here for the CPU.. + int32 num_rows = this->num_rows_, num_cols = this->num_cols_, + this_stride = this->stride_, src_stride = src.stride_; + Real *data = this->data_; + const Real *src_data = src.data_; + const Int32Pair *indices_data = indices.Data(); + for (int32 row = 0; row < num_rows; row++) { + for (int32 col = 0; col < num_cols; col++) { + int32 start_col = indices_data[col].first, + end_col = indices_data[col].second; + Real sum = 0.0; + for (int32 src_col = start_col; src_col < end_col; src_col++) + sum += src_data[row * src_stride + src_col]; + data[row * this_stride + col] = sum; + } + } + } +} + + + template void CuMatrixBase::CopyLowerToUpper() { KALDI_ASSERT(num_cols_ == num_rows_); diff --git a/src/cudamatrix/cu-matrix.h b/src/cudamatrix/cu-matrix.h index 1c5ff91f0..9ea0367aa 100644 --- a/src/cudamatrix/cu-matrix.h +++ b/src/cudamatrix/cu-matrix.h @@ -91,6 +91,11 @@ class CuMatrixBase { void CopyCols(const CuMatrixBase &src, const std::vector &indices); + /// Version of CopyCols that takes CuArray argument. + void CopyCols(const CuMatrixBase &src, + const CuArray &indices); + + /// Copies row r from row indices[r] of src. /// As a special case, if indexes[i] <== -1, sets row i to zero /// "reorder".size() must equal this->NumRows(), @@ -100,6 +105,13 @@ class CuMatrixBase { const std::vector &indices); + /// For each row r of this and for each column c, sets (*this)(r, c) to the + /// sum \sum_j src(r, j), where j ranges from indices[c].first through + /// indices[c].second - 1. + void SumColumnRanges(const CuMatrixBase &src, + const CuArray &indices); + + friend Real TraceMatMat(const CuMatrixBase &A, const CuMatrixBase &B, MatrixTransposeType trans); diff --git a/src/cudamatrix/cu-matrixdim.h b/src/cudamatrix/cu-matrixdim.h index a8bef28dd..32df913b4 100644 --- a/src/cudamatrix/cu-matrixdim.h +++ b/src/cudamatrix/cu-matrixdim.h @@ -81,6 +81,10 @@ extern "C" { // decided to make this a void* pointer. } CuBlockMatrixData; + typedef struct Int32Pair { + int32_cuda first; + int32_cuda second; + } Int32Pair; } #endif diff --git a/src/cudamatrix/cu-vector.h b/src/cudamatrix/cu-vector.h index d270c01b3..16c3f0702 100644 --- a/src/cudamatrix/cu-vector.h +++ b/src/cudamatrix/cu-vector.h @@ -221,6 +221,7 @@ class CuVector: public CuVectorBase { CuVector(MatrixIndexT dim, MatrixResizeType t = kSetZero) { Resize(dim, t); } CuVector(const CuVectorBase &v); + CuVector(const VectorBase &v); explicit CuVector(const CuVector &v) : CuVectorBase() { Resize(v.Dim(), kUndefined); diff --git a/src/fstext/determinize-lattice-inl.h b/src/fstext/determinize-lattice-inl.h index 7dd2c407a..1ac954a3a 100644 --- a/src/fstext/determinize-lattice-inl.h +++ b/src/fstext/determinize-lattice-inl.h @@ -226,12 +226,14 @@ template class LatticeStringRepository { typedef unordered_set SetType; void RebuildHelper(const Entry *to_add, SetType *tmp_set) { - if (to_add == NULL) return; - else { + while(true) { + if (to_add == NULL) return; typename SetType::iterator iter = tmp_set->find(to_add); if (iter == tmp_set->end()) { // not in tmp_set. tmp_set->insert(to_add); - RebuildHelper(to_add->parent, tmp_set); // make sure parent there. + to_add = to_add->parent; // and loop. + } else { + return; } } } diff --git a/src/fstext/determinize-star-inl.h b/src/fstext/determinize-star-inl.h index 7810a0518..5bd9ddf85 100644 --- a/src/fstext/determinize-star-inl.h +++ b/src/fstext/determinize-star-inl.h @@ -88,7 +88,7 @@ template class StringRepository { else if (id>=single_symbol_start) { v->resize(1); (*v)[0] = id - single_symbol_start; } else { - assert(id>=string_start && id < static_cast(vec_.size())); + assert(id >= string_start && id < static_cast(vec_.size())); *v = *(vec_[id]); } } diff --git a/src/nnet/Makefile b/src/nnet/Makefile index 851edf044..bbd3e32c9 100644 --- a/src/nnet/Makefile +++ b/src/nnet/Makefile @@ -10,7 +10,7 @@ LDLIBS += $(CUDA_LDLIBS) TESTFILES = nnet-test OBJFILES = nnet-nnet.o nnet-component.o nnet-loss.o nnet-cache.o \ - nnet-cache-tgtmat.o nnet-loss-prior.o nnet-pdf-prior.o + nnet-cache-tgtmat.o nnet-cache-conf.o nnet-loss-prior.o nnet-pdf-prior.o LIBNAME = kaldi-nnet diff --git a/src/nnet/nnet-activation.h b/src/nnet/nnet-activation.h index dfe141b7c..480f6eb19 100644 --- a/src/nnet/nnet-activation.h +++ b/src/nnet/nnet-activation.h @@ -1,6 +1,6 @@ // nnet/nnet-activation.h -// Copyright 2011 Karel Vesely +// Copyright 2011-2013 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -30,15 +30,14 @@ namespace nnet1 { class Softmax : public Component { public: - Softmax(int32 dim_in, int32 dim_out, Nnet *nnet) - : Component(dim_in, dim_out, nnet) + Softmax(int32 dim_in, int32 dim_out) + : Component(dim_in, dim_out) { } ~Softmax() { } - ComponentType GetType() const { - return kSoftmax; - } + Component* Copy() const { return new Softmax(*this); } + ComponentType GetType() const { return kSoftmax; } void PropagateFnc(const CuMatrix &in, CuMatrix *out) { // y = e^x_j/sum_j(e^x_j) @@ -60,15 +59,14 @@ class Softmax : public Component { class Sigmoid : public Component { public: - Sigmoid(int32 dim_in, int32 dim_out, Nnet *nnet) - : Component(dim_in, dim_out, nnet) + Sigmoid(int32 dim_in, int32 dim_out) + : Component(dim_in, dim_out) { } ~Sigmoid() { } - ComponentType GetType() const { - return kSigmoid; - } + Component* Copy() const { return new Sigmoid(*this); } + ComponentType GetType() const { return kSigmoid; } void PropagateFnc(const CuMatrix &in, CuMatrix *out) { // y = 1/(1+e^-x) @@ -86,15 +84,14 @@ class Sigmoid : public Component { class Tanh : public Component { public: - Tanh(int32 dim_in, int32 dim_out, Nnet *nnet) - : Component(dim_in, dim_out, nnet) + Tanh(int32 dim_in, int32 dim_out) + : Component(dim_in, dim_out) { } ~Tanh() { } - ComponentType GetType() const { - return kTanh; - } + Component* Copy() const { return new Tanh(*this); } + ComponentType GetType() const { return kTanh; } void PropagateFnc(const CuMatrix &in, CuMatrix *out) { // y = (e^x - e^(-x)) / (e^x + e^(-x)) @@ -112,15 +109,14 @@ class Tanh : public Component { class Dropout : public Component { public: - Dropout(int32 dim_in, int32 dim_out, Nnet *nnet): - Component(dim_in, dim_out, nnet) + Dropout(int32 dim_in, int32 dim_out): + Component(dim_in, dim_out) { } ~Dropout() { } - ComponentType GetType() const { - return kDropout; - } + Component* Copy() const { return new Dropout(*this); } + ComponentType GetType() const { return kDropout; } void PropagateFnc(const CuMatrix &in, CuMatrix *out) { out->CopyFromMat(in); diff --git a/src/nnet/nnet-affine-transform.h b/src/nnet/nnet-affine-transform.h index 971c47880..490ece0ac 100644 --- a/src/nnet/nnet-affine-transform.h +++ b/src/nnet/nnet-affine-transform.h @@ -1,6 +1,6 @@ // nnet/nnet-affine-transform.h -// Copyright 2011 Karel Vesely +// Copyright 2011 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -31,17 +31,16 @@ namespace nnet1 { class AffineTransform : public UpdatableComponent { public: - AffineTransform(int32 dim_in, int32 dim_out, Nnet *nnet) - : UpdatableComponent(dim_in, dim_out, nnet), + AffineTransform(int32 dim_in, int32 dim_out) + : UpdatableComponent(dim_in, dim_out), linearity_(dim_out, dim_in), bias_(dim_out), linearity_corr_(dim_out, dim_in), bias_corr_(dim_out) { } ~AffineTransform() { } - ComponentType GetType() const { - return kAffineTransform; - } + Component* Copy() const { return new AffineTransform(*this); } + ComponentType GetType() const { return kAffineTransform; } void ReadData(std::istream &is, bool binary) { linearity_.Read(is, binary); diff --git a/src/nnet/nnet-cache-conf.cc b/src/nnet/nnet-cache-conf.cc new file mode 100644 index 000000000..567d8e855 --- /dev/null +++ b/src/nnet/nnet-cache-conf.cc @@ -0,0 +1,247 @@ +// nnet/nnet-cache-conf.cc + +// Copyright 2013 Brno University of Technology (author: Karel Vesely) + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "nnet/nnet-cache-conf.h" + +#include "cudamatrix/cu-math.h" + +#include + +namespace kaldi { +namespace nnet1 { + + + +void CacheConf::Init(int32 cachesize, int32 bunchsize) { + + KALDI_ASSERT(cachesize>0); + if(cachesize > 8388479) { + KALDI_ERR << "CacheConfsize " << cachesize << " too large, use cachesize smaller than 8388480."; + } + KALDI_ASSERT(bunchsize>0); + KALDI_ASSERT(cachesize>=bunchsize); + + if ((cachesize % bunchsize) != 0) { + KALDI_ERR << "Non divisible cachesize by bunchsize"; + } + + cachesize_ = cachesize; + bunchsize_ = bunchsize; + + state_ = EMPTY; + + filling_pos_ = 0; + emptying_pos_ = 0; + + randomized_ = false; +} + + + +void CacheConf::AddData(const CuMatrix &features, const std::vector &targets, const Vector &confidence) { + if (state_ == FULL) { + KALDI_ERR << "Cannot add data, cache already full"; + } + + KALDI_ASSERT(features.NumRows() == static_cast(targets.size())); + KALDI_ASSERT(features.NumRows() == static_cast(confidence.Dim())); + + int32 dim_fea = features.NumCols(); + + // lazy buffers allocation + if (features_.NumRows() != cachesize_) { + features_.Resize(cachesize_, dim_fea); + targets_.resize(cachesize_); + confidence_.Resize(cachesize_); + } + + // warn if segment longer than half-cache + // (frame level shuffling accross sentences will be poor) + if (features.NumRows() > cachesize_/4) { + KALDI_WARN << "Too long segment or small cachesize!" + << " (cache-size " << cachesize_ << ") < (4 x" + << " segment-size " << features.NumRows() << ")."; + } + + // change state + if (state_ == EMPTY) { + state_ = FILLING; filling_pos_ = 0; + + // check for leftover from previous segment + int leftover = features_leftover_.NumRows(); + // check if leftover is not bigger than half-cachesize + if (leftover > cachesize_/2) { + KALDI_WARN << "Truncating " + << leftover - cachesize_/2 + << " frames from leftover of previous segment " + << "(max leftover " << cachesize_/2 << ")."; + leftover = cachesize_/2; + } + // prefill cache with leftover + if (leftover > 0) { + features_.RowRange(0, leftover).CopyFromMat( + features_leftover_.RowRange(0, leftover) + ); + + std::copy(targets_leftover_.begin(), + targets_leftover_.begin() + leftover, + targets_.begin()); + + confidence_.Range(0, leftover).CopyFromVec( + confidence_leftover_.Range(0, leftover) + ); + + features_leftover_.Resize(0, 0); + targets_leftover_.resize(0); + confidence_leftover_.Resize(0); + filling_pos_ += leftover; + } + } + + KALDI_ASSERT(state_ == FILLING); + KALDI_ASSERT(features.NumRows() == static_cast(targets.size())); + + int cache_space = cachesize_ - filling_pos_; + int feature_length = features.NumRows(); + int fill_rows = (cache_space 0); + + // copy the data to cache + features_.RowRange(filling_pos_, fill_rows).CopyFromMat( + features.RowRange(0, fill_rows) + ); + + std::copy(targets.begin(), + targets.begin()+fill_rows, + targets_.begin()+filling_pos_); + + confidence_.Range(filling_pos_,fill_rows). + CopyFromVec(confidence.Range(0,fill_rows)); + + // copy leftovers + if (leftover > 0) { + features_leftover_.Resize(leftover, dim_fea); + features_leftover_.CopyFromMat( + features.RowRange(fill_rows, leftover) + ); + + KALDI_ASSERT(targets.end()-(targets.begin()+fill_rows)==leftover); + targets_leftover_.resize(leftover); + std::copy(targets.begin()+fill_rows, + targets.end(), + targets_leftover_.begin()); + + confidence_leftover_.Resize(leftover); + confidence_leftover_.CopyFromVec(confidence.Range(fill_rows,leftover)); + } + + // update cursor + filling_pos_ += fill_rows; + + // change state + if (filling_pos_ == cachesize_) { + state_ = FULL; + } +} + + + +void CacheConf::Randomize() { + KALDI_ASSERT(state_ == FULL || state_ == FILLING); + + // lazy initialization of the output buffers + features_random_.Resize(cachesize_, features_.NumCols()); + targets_random_.resize(cachesize_); + confidence_random_.Resize(cachesize_); + + // generate random series of integers + randmask_.resize(filling_pos_); + GenerateRandom randomizer; + for(int32 i=0; i *features, std::vector *targets, Vector *confidence) { + if (state_ == EMPTY) { + KALDI_ERR << "GetBunch on empty cache!!!"; + } + + // change state if full... + if (state_ == FULL) { + state_ = EMPTYING; emptying_pos_ = 0; + } + + // final cache is not completely filled + if (state_ == FILLING) { + state_ = EMPTYING; emptying_pos_ = 0; + } + + KALDI_ASSERT(state_ == EMPTYING); + + const CuMatrixBase &features_ref = (randomized_ ? + features_random_ : features_); + const std::vector &targets_ref = (randomized_ ? + targets_random_ : targets_); + const Vector &confidence_ref = (randomized_ ? + confidence_random_ : confidence_); + + // init the output + features->Resize(bunchsize_, features_.NumCols()); + targets->resize(bunchsize_); + confidence->Resize(bunchsize_); + + // copy the output + features->CopyFromMat(features_ref.RowRange(emptying_pos_, bunchsize_)); + + std::copy(targets_ref.begin() + emptying_pos_, + targets_ref.begin() + emptying_pos_ + bunchsize_, + targets->begin()); + + confidence->CopyFromVec(confidence_ref.Range(emptying_pos_, bunchsize_)); + + // update position + emptying_pos_ += bunchsize_; + + // If we're done, change state to EMPTY + if (emptying_pos_ > filling_pos_ - bunchsize_) { + // we don't have more complete bunches... + state_ = EMPTY; + } +} + + +} // namespace nnet1 +} // namespace kaldi diff --git a/src/nnet/nnet-cache-conf.h b/src/nnet/nnet-cache-conf.h new file mode 100644 index 000000000..594b4c4b3 --- /dev/null +++ b/src/nnet/nnet-cache-conf.h @@ -0,0 +1,107 @@ +// nnet/nnet-cache-conf.h + +// Copyright 2012 Brno University of Technology (author: Karel Vesely) + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_NNET_NNET_CACHE_CONF_H_ +#define KALDI_NNET_NNET_CACHE_CONF_H_ + +#include "base/kaldi-math.h" +#include "cudamatrix/cu-matrix.h" +#include "cudamatrix/cu-math.h" + +namespace kaldi { +namespace nnet1 { + +/** + * The feature-target pair cache + */ +class CacheConf { + typedef enum { EMPTY, FILLING, FULL, EMPTYING } State; + + public: + CacheConf() : state_(EMPTY), filling_pos_(0), emptying_pos_(0), + cachesize_(0), bunchsize_(0), randomized_(false) + { } + ~CacheConf() { } + + /// Initialize the cache + void Init(int32 cachesize, int32 bunchsize); + + /// Add data to cache + void AddData(const CuMatrix &features, const std::vector &targets, const Vector &confidence); + /// Randomizes the cache + void Randomize(); + /// Get the bunch of training data from cache + void GetBunch(CuMatrix *features, std::vector *targets, Vector *confidence); + + + /// Returns true if the cache was completely filled + bool Full() { + return (state_ == FULL); + } + + /// Returns true if the cache is empty + bool Empty() { + return (state_ == EMPTY || filling_pos_ < bunchsize_); + } + + /// Returns true if the cache is empty + bool Randomized() { + return randomized_; + } + + + private: + struct GenerateRandom { + int32 operator()(int32 max) const { + // return lrand48() % max; + return RandInt(0, max-1); + } + }; + + State state_; ///< Current state of the cache + + int32 filling_pos_; ///< Number of frames filled to cache by AddData + int32 emptying_pos_; ///< Number of frames given by cache by GetBunch + + int32 cachesize_; ///< Size of cache + int32 bunchsize_; ///< Size of bunch + + bool randomized_; + + CuMatrix features_; ///< Feature cache + CuMatrix features_random_; ///< Feature cache + CuMatrix features_leftover_; ///< Feature cache + + std::vector targets_; ///< Desired vector cache + std::vector targets_random_; ///< Desired vector cache + std::vector targets_leftover_; ///< Desired vector cache + + Vector confidence_; + Vector confidence_random_; + Vector confidence_leftover_; + + std::vector randmask_; + CuArray randmask_device_; + +}; + + +} // namespace nnet1 +} // namespace kaldi + +#endif diff --git a/src/nnet/nnet-cache-tgtmat.cc b/src/nnet/nnet-cache-tgtmat.cc index 00051ec6c..b514984cd 100644 --- a/src/nnet/nnet-cache-tgtmat.cc +++ b/src/nnet/nnet-cache-tgtmat.cc @@ -1,6 +1,6 @@ // nnet/nnet-cache-tgtmat.cc -// Copyright 2011 Karel Vesely +// Copyright 2011 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-cache-tgtmat.h b/src/nnet/nnet-cache-tgtmat.h index 63292fd11..741ce3a29 100644 --- a/src/nnet/nnet-cache-tgtmat.h +++ b/src/nnet/nnet-cache-tgtmat.h @@ -1,6 +1,6 @@ // nnet/nnet-cache-tgtmat.h -// Copyright 2012 Karel Vesely +// Copyright 2012 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-cache.cc b/src/nnet/nnet-cache.cc index 01c480e38..eff3751a2 100644 --- a/src/nnet/nnet-cache.cc +++ b/src/nnet/nnet-cache.cc @@ -1,6 +1,6 @@ // nnet/nnet-cache.cc -// Copyright 2011 Karel Vesely +// Copyright 2011 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-cache.h b/src/nnet/nnet-cache.h index 1907d7925..8f80d5d7e 100644 --- a/src/nnet/nnet-cache.h +++ b/src/nnet/nnet-cache.h @@ -1,6 +1,6 @@ // nnet/nnet-cache.h -// Copyright 2012 Karel Vesely +// Copyright 2012 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-component.cc b/src/nnet/nnet-component.cc index 9d88cf6c1..c8cd1ffb4 100644 --- a/src/nnet/nnet-component.cc +++ b/src/nnet/nnet-component.cc @@ -1,6 +1,6 @@ // nnet/nnet-component.cc -// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) +// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -64,14 +64,24 @@ Component::ComponentType Component::MarkerToType(const std::string &s) { } -Component* Component::Read(std::istream &is, bool binary, Nnet *nnet) { +Component* Component::Read(std::istream &is, bool binary) { int32 dim_out, dim_in; std::string token; int first_char = Peek(is, binary); if (first_char == EOF) return NULL; - ReadToken(is, binary, &token); + ReadToken(is, binary, &token); + + // Skip optional initial token + if(token == "") { + ReadToken(is, binary, &token); // Next token is a Component + } + // Finish reading when optional terminal token appears + if(token == "") { + return NULL; + } + Component::ComponentType comp_type = Component::MarkerToType(token); ReadBasicType(is, binary, &dim_out); @@ -80,34 +90,34 @@ Component* Component::Read(std::istream &is, bool binary, Nnet *nnet) { Component *p_comp=NULL; switch (comp_type) { case Component::kAffineTransform : - p_comp = new AffineTransform(dim_in, dim_out, nnet); + p_comp = new AffineTransform(dim_in, dim_out); break; case Component::kSoftmax : - p_comp = new Softmax(dim_in, dim_out, nnet); + p_comp = new Softmax(dim_in, dim_out); break; case Component::kSigmoid : - p_comp = new Sigmoid(dim_in, dim_out, nnet); + p_comp = new Sigmoid(dim_in, dim_out); break; case Component::kTanh : - p_comp = new Tanh(dim_in, dim_out, nnet); + p_comp = new Tanh(dim_in, dim_out); break; case Component::kDropout : - p_comp = new Dropout(dim_in, dim_out, nnet); + p_comp = new Dropout(dim_in, dim_out); break; case Component::kRbm : - p_comp = new Rbm(dim_in, dim_out, nnet); + p_comp = new Rbm(dim_in, dim_out); break; case Component::kSplice : - p_comp = new Splice(dim_in, dim_out, nnet); + p_comp = new Splice(dim_in, dim_out); break; case Component::kCopy : - p_comp = new Copy(dim_in, dim_out, nnet); + p_comp = new CopyComponent(dim_in, dim_out); break; case Component::kAddShift : - p_comp = new AddShift(dim_in, dim_out, nnet); + p_comp = new AddShift(dim_in, dim_out); break; case Component::kRescale : - p_comp = new Rescale(dim_in, dim_out, nnet); + p_comp = new Rescale(dim_in, dim_out); break; case Component::kUnknown : default : diff --git a/src/nnet/nnet-component.h b/src/nnet/nnet-component.h index 4839ee4e5..6722089da 100644 --- a/src/nnet/nnet-component.h +++ b/src/nnet/nnet-component.h @@ -1,6 +1,6 @@ // nnet/nnet-component.h -// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) +// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -34,27 +34,17 @@ namespace kaldi { namespace nnet1 { -// declare the nnet class so we can declare pointer -struct NnetTrainOptions; -class Nnet; - - /** - * Abstract class, basic element of the network, - * it is a box with defined inputs, outputs, - * and tranformation functions interface. - * - * It is able to propagate and backpropagate - * exact implementation is to be implemented in descendants. - * - * The data buffers are not included - * and will be managed from outside. + * Abstract class, building block of the network. + * It is able to propagate (PropagateFnc: compute the output based on its input) + * and backpropagate (BackpropagateFnc: i.e. transform loss derivative w.r.t. output to derivative w.r.t. the input) + * the formulas are implemented in descendant classes (AffineTransform,Sigmoid,Softmax,...). */ class Component { - // Polymorphic Component RTTI + /// Component type identification mechanism public: - /// Types of the net components + /// Types of the Components typedef enum { kUnknown = 0x0, @@ -77,23 +67,27 @@ class Component { kRescale, kLog } ComponentType; - /// Pair of type and marker + /// A pair of type and marker struct key_value { const Component::ComponentType key; const char *value; }; - /// Mapping of types and markers + /// Mapping of types and markers (the table is defined in nnet-component.cc) static const struct key_value kMarkerMap[]; /// Convert component type to marker static const char* TypeToMarker(ComponentType t); /// Convert marker to component type static ComponentType MarkerToType(const std::string &s); - - Component(int32 input_dim, int32 output_dim, Nnet *nnet) - : input_dim_(input_dim), output_dim_(output_dim), nnet_(nnet) { } - virtual ~Component() { } - + + /// General interface of a component public: + Component(int32 input_dim, int32 output_dim) + : input_dim_(input_dim), output_dim_(output_dim) { } + virtual ~Component() { } + + /// Copy component (deep copy). + virtual Component* Copy() const = 0; + /// Get Type Identification of the component virtual ComponentType GetType() const = 0; /// Check if contains trainable parameters @@ -110,28 +104,29 @@ class Component { return output_dim_; } - /// Perform forward pass propagateion Input->Output + /// Perform forward pass propagation Input->Output void Propagate(const CuMatrix &in, CuMatrix *out); /// Perform backward pass propagation, out_diff -> in_diff - /// '&in' and '&out' will often be unused... + /// '&in' and '&out' will sometimes be unused... void Backpropagate(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff); /// Read component from stream - static Component* Read(std::istream &is, bool binary, Nnet *nnet); + static Component* Read(std::istream &is, bool binary); /// Write component to stream void Write(std::ostream &os, bool binary) const; /// Optionally print some additional info virtual std::string Info() const { return ""; } - // abstract interface for propagation/backpropagation + + /// Abstract interface for propagation/backpropagation protected: - /// Forward pass transformation (to be implemented by descendents...) + /// Forward pass transformation (to be implemented by descending class...) virtual void PropagateFnc(const CuMatrix &in, CuMatrix *out) = 0; - /// Backward pass transformation (to be implemented by descendents...) + /// Backward pass transformation (to be implemented by descending class...) virtual void BackpropagateFnc(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, @@ -144,26 +139,24 @@ class Component { virtual void WriteData(std::ostream &os, bool binary) const { } - // data members + /// Data members protected: int32 input_dim_; ///< Size of input vectors int32 output_dim_; ///< Size of output vectors - Nnet *nnet_; ///< Pointer to the whole network - private: - KALDI_DISALLOW_COPY_AND_ASSIGN(Component); + protected: + //KALDI_DISALLOW_COPY_AND_ASSIGN(Component); }; /** - * Class UpdatableComponent is a Component which has - * trainable parameters and contains SGD training - * hyper-parameters (learnrate, momenutm, L2, L1) + * Class UpdatableComponent is a Component which has trainable parameters, + * contains SGD training hyper-parameters in NnetTrainOptions. */ class UpdatableComponent : public Component { public: - UpdatableComponent(int32 input_dim, int32 output_dim, Nnet *nnet) - : Component(input_dim, output_dim, nnet) { } + UpdatableComponent(int32 input_dim, int32 output_dim) + : Component(input_dim, output_dim) { } virtual ~UpdatableComponent() { } /// Check if contains trainable parameters @@ -176,7 +169,7 @@ class UpdatableComponent : public Component { const CuMatrix &diff) = 0; /// Sets the training options to the component - void SetTrainOptions(const NnetTrainOptions &opts) { + virtual void SetTrainOptions(const NnetTrainOptions &opts) { opts_ = opts; } /// Gets the training options from the component @@ -190,18 +183,17 @@ class UpdatableComponent : public Component { }; - - inline void Component::Propagate(const CuMatrix &in, CuMatrix *out) { + // Check the dims if (input_dim_ != in.NumCols()) { - KALDI_ERR << "Nonmatching dims, component:" << input_dim_ << " data:" << in.NumCols(); + KALDI_ERR << "Non-matching dims, component:" << input_dim_ << " data:" << in.NumCols(); } - + // Allocate target buffer if (output_dim_ != out->NumCols() || in.NumRows() != out->NumRows()) { out->Resize(in.NumRows(), output_dim_); } - + // Call the propagation implementation of the component PropagateFnc(in, out); } @@ -210,27 +202,26 @@ inline void Component::Backpropagate(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { - //check the dims + // Check the dims if (output_dim_ != out_diff.NumCols()) { - KALDI_ERR << "Nonmatching output dims, component:" << output_dim_ + KALDI_ERR << "Non-matching output dims, component:" << output_dim_ << " data:" << out_diff.NumCols(); } - //allocate buffer + // Allocate target buffer if (input_dim_ != in_diff->NumCols() || out_diff.NumRows() != in_diff->NumRows()) { in_diff->Resize(out_diff.NumRows(), input_dim_); } - //asserts on the dims + // Asserts on the dims KALDI_ASSERT((in.NumRows() == out.NumRows()) && (in.NumRows() == out_diff.NumRows()) && (in.NumRows() == in_diff->NumRows())); KALDI_ASSERT(in.NumCols() == in_diff->NumCols()); KALDI_ASSERT(out.NumCols() == out_diff.NumCols()); - //call the backprop implementation of the component + // Call the backprop implementation of the component BackpropagateFnc(in, out, out_diff, in_diff); } - } // namespace nnet1 } // namespace kaldi diff --git a/src/nnet/nnet-loss-prior.cc b/src/nnet/nnet-loss-prior.cc index 78eb5ff2c..5938caa00 100644 --- a/src/nnet/nnet-loss-prior.cc +++ b/src/nnet/nnet-loss-prior.cc @@ -1,6 +1,6 @@ // nnet/nnet-loss-prior.cc -// Copyright 2012 Karel Vesely +// Copyright 2012 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-loss-prior.h b/src/nnet/nnet-loss-prior.h index 4f23b6ceb..6fadcce33 100644 --- a/src/nnet/nnet-loss-prior.h +++ b/src/nnet/nnet-loss-prior.h @@ -1,6 +1,6 @@ // nnet/nnet-loss-prior.h -// Copyright 2012 Karel Vesely +// Copyright 2012 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-loss.cc b/src/nnet/nnet-loss.cc index 7c7231f02..9ce8659f8 100644 --- a/src/nnet/nnet-loss.cc +++ b/src/nnet/nnet-loss.cc @@ -1,6 +1,6 @@ // nnet/nnet-loss.cc -// Copyright 2011 Karel Vesely +// Copyright 2011 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-loss.h b/src/nnet/nnet-loss.h index 02d401381..ac19b8d17 100644 --- a/src/nnet/nnet-loss.h +++ b/src/nnet/nnet-loss.h @@ -1,6 +1,6 @@ // nnet/nnet-loss.h -// Copyright 2011 Karel Vesely +// Copyright 2011 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // diff --git a/src/nnet/nnet-nnet.cc b/src/nnet/nnet-nnet.cc index d4eaed44e..05fb39dc3 100644 --- a/src/nnet/nnet-nnet.cc +++ b/src/nnet/nnet-nnet.cc @@ -1,6 +1,6 @@ // nnet/nnet-nnet.cc -// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) +// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -27,59 +27,66 @@ namespace kaldi { namespace nnet1 { +Nnet::~Nnet() { + for(int32 i=0; i &in, CuMatrix *out) { KALDI_ASSERT(NULL != out); - if (LayerCount() == 0) { + if (NumComponents() == 0) { out->Resize(in.NumRows(), in.NumCols()); out->CopyFromMat(in); return; } // we need at least L+1 input buffers - KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1); + KALDI_ASSERT((int32)propagate_buf_.size() >= NumComponents()+1); propagate_buf_[0].Resize(in.NumRows(), in.NumCols()); propagate_buf_[0].CopyFromMat(in); - for(int32 i=0; i<(int32)nnet_.size(); i++) { - nnet_[i]->Propagate(propagate_buf_[i], &propagate_buf_[i+1]); + for(int32 i=0; i<(int32)components_.size(); i++) { + components_[i]->Propagate(propagate_buf_[i], &propagate_buf_[i+1]); } - CuMatrix &mat = propagate_buf_[nnet_.size()]; + CuMatrix &mat = propagate_buf_[components_.size()]; out->Resize(mat.NumRows(), mat.NumCols()); out->CopyFromMat(mat); } void Nnet::Backpropagate(const CuMatrix &out_diff, CuMatrix *in_diff) { - if(LayerCount() == 0) { KALDI_ERR << "Cannot backpropagate on empty network"; } + if(NumComponents() == 0) { KALDI_ERR << "Cannot backpropagate on empty network"; } // we need at least L+1 input bufers - KALDI_ASSERT((int32)propagate_buf_.size() >= LayerCount()+1); + KALDI_ASSERT((int32)propagate_buf_.size() >= NumComponents()+1); // we need at least L-1 error derivative bufers - KALDI_ASSERT((int32)backpropagate_buf_.size() >= LayerCount()-1); + KALDI_ASSERT((int32)backpropagate_buf_.size() >= NumComponents()-1); ////////////////////////////////////// // Backpropagation // - // don't copy the out_diff to buffers, use it as is... - int32 i = nnet_.size()-1; - nnet_.back()->Backpropagate(propagate_buf_[i], propagate_buf_[i+1], + // we don't copy the out_diff to buffers, we use it as it is... + int32 i = components_.size()-1; + components_.back()->Backpropagate(propagate_buf_[i], propagate_buf_[i+1], out_diff, &backpropagate_buf_[i-1]); - if (nnet_[i]->IsUpdatable()) { - UpdatableComponent *uc = dynamic_cast(nnet_[i]); + if (components_[i]->IsUpdatable()) { + UpdatableComponent *uc = dynamic_cast(components_[i]); uc->Update(propagate_buf_[i], out_diff); } // backpropagate by using buffers for(i--; i >= 1; i--) { - nnet_[i]->Backpropagate(propagate_buf_[i], propagate_buf_[i+1], + components_[i]->Backpropagate(propagate_buf_[i], propagate_buf_[i+1], backpropagate_buf_[i], &backpropagate_buf_[i-1]); - if (nnet_[i]->IsUpdatable()) { - UpdatableComponent *uc = dynamic_cast(nnet_[i]); + if (components_[i]->IsUpdatable()) { + UpdatableComponent *uc = dynamic_cast(components_[i]); uc->Update(propagate_buf_[i], backpropagate_buf_[i]); } } @@ -87,13 +94,13 @@ void Nnet::Backpropagate(const CuMatrix &out_diff, CuMatrixBackpropagate(propagate_buf_[0], propagate_buf_[1], + components_[0]->Backpropagate(propagate_buf_[0], propagate_buf_[1], backpropagate_buf_[0], in_diff); } // update the first layer - if (nnet_[0]->IsUpdatable()) { - UpdatableComponent *uc = dynamic_cast(nnet_[0]); + if (components_[0]->IsUpdatable()) { + UpdatableComponent *uc = dynamic_cast(components_[0]); uc->Update(propagate_buf_[0], backpropagate_buf_[0]); } @@ -106,14 +113,14 @@ void Nnet::Backpropagate(const CuMatrix &out_diff, CuMatrix &in, CuMatrix *out) { KALDI_ASSERT(NULL != out); - if (LayerCount() == 0) { + if (NumComponents() == 0) { out->Resize(in.NumRows(), in.NumCols()); out->CopyFromMat(in); return; } - if (LayerCount() == 1) { - nnet_[0]->Propagate(in, out); + if (NumComponents() == 1) { + components_[0]->Propagate(in, out); return; } @@ -122,27 +129,81 @@ void Nnet::Feedforward(const CuMatrix &in, CuMatrix *out) // propagate by using exactly 2 auxiliary buffers int32 L = 0; - nnet_[L]->Propagate(in, &propagate_buf_[L%2]); - for(L++; L<=LayerCount()-2; L++) { - nnet_[L]->Propagate(propagate_buf_[(L-1)%2], &propagate_buf_[L%2]); + components_[L]->Propagate(in, &propagate_buf_[L%2]); + for(L++; L<=NumComponents()-2; L++) { + components_[L]->Propagate(propagate_buf_[(L-1)%2], &propagate_buf_[L%2]); } - nnet_[L]->Propagate(propagate_buf_[(L-1)%2], out); + components_[L]->Propagate(propagate_buf_[(L-1)%2], out); // release the buffers we don't need anymore propagate_buf_[0].Resize(0,0); propagate_buf_[1].Resize(0,0); } + +int32 Nnet::OutputDim() const { + KALDI_ASSERT(!components_.empty()); + return components_.back()->OutputDim(); +} + +int32 Nnet::InputDim() const { + KALDI_ASSERT(!components_.empty()); + return components_.front()->InputDim(); +} + +const Component& Nnet::GetComponent(int32 component) const { + KALDI_ASSERT(static_cast(component) < components_.size()); + return *(components_[component]); +} + +Component& Nnet::GetComponent(int32 component) { + KALDI_ASSERT(static_cast(component) < components_.size()); + return *(components_[component]); +} + +void Nnet::SetComponent(int32 c, Component *component) { + KALDI_ASSERT(static_cast(c) < components_.size()); + delete components_[c]; + components_[c] = component; + Check(); // Check that all the dimensions still match up. +} + +void Nnet::AppendComponent(Component* dynamically_allocated_comp) { + components_.push_back(dynamically_allocated_comp); + Check(); +} + +void Nnet::AppendNnet(const Nnet& nnet_to_append) { + for(int32 i=0; i* wei_copy) { wei_copy->Resize(NumParams()); int32 pos = 0; //copy the params - for(int32 n=0; nIsUpdatable()) { - switch(nnet_[n]->GetType()) { + for(int32 n=0; nIsUpdatable()) { + switch(components_[n]->GetType()) { case Component::kAffineTransform : { //get the weights from CuMatrix to Matrix const CuMatrix& cu_mat = - dynamic_cast(nnet_[n])->GetLinearity(); + dynamic_cast(components_[n])->GetLinearity(); Matrix mat(cu_mat.NumRows(),cu_mat.NumCols()); cu_mat.CopyToMat(&mat); //copy the the matrix row-by-row to the vector @@ -151,7 +212,7 @@ void Nnet::GetWeights(Vector* wei_copy) { pos += mat_size; //get the biases from CuVector to Vector const CuVector& cu_vec = - dynamic_cast(nnet_[n])->GetBias(); + dynamic_cast(components_[n])->GetBias(); Vector vec(cu_vec.Dim()); cu_vec.CopyToVec(&vec); //append biases to the supervector @@ -161,7 +222,7 @@ void Nnet::GetWeights(Vector* wei_copy) { default : KALDI_ERR << "Unimplemented access to parameters " << "of updatable component " - << Component::TypeToMarker(nnet_[n]->GetType()); + << Component::TypeToMarker(components_[n]->GetType()); } } } @@ -172,12 +233,12 @@ void Nnet::GetWeights(Vector* wei_copy) { void Nnet::SetWeights(const Vector& wei_src) { KALDI_ASSERT(wei_src.Dim() == NumParams()); int32 pos = 0; - for(int32 n=0; nIsUpdatable()) { - switch(nnet_[n]->GetType()) { + for(int32 n=0; nIsUpdatable()) { + switch(components_[n]->GetType()) { case Component::kAffineTransform : { //get the component - AffineTransform* aff_t = dynamic_cast(nnet_[n]); + AffineTransform* aff_t = dynamic_cast(components_[n]); //we need weight matrix with original dimensions const CuMatrix& cu_mat = aff_t->GetLinearity(); Matrix mat(cu_mat.NumRows(),cu_mat.NumCols()); @@ -205,7 +266,7 @@ void Nnet::SetWeights(const Vector& wei_src) { default : KALDI_ERR << "Unimplemented access to parameters " << "of updatable component " - << Component::TypeToMarker(nnet_[n]->GetType()); + << Component::TypeToMarker(components_[n]->GetType()); } } } @@ -217,13 +278,13 @@ void Nnet::GetGradient(Vector* grad_copy) { grad_copy->Resize(NumParams()); int32 pos = 0; //copy the params - for(int32 n=0; nIsUpdatable()) { - switch(nnet_[n]->GetType()) { + for(int32 n=0; nIsUpdatable()) { + switch(components_[n]->GetType()) { case Component::kAffineTransform : { //get the weights from CuMatrix to Matrix const CuMatrix& cu_mat = - dynamic_cast(nnet_[n])->GetLinearityCorr(); + dynamic_cast(components_[n])->GetLinearityCorr(); Matrix mat(cu_mat.NumRows(),cu_mat.NumCols()); cu_mat.CopyToMat(&mat); //copy the the matrix row-by-row to the vector @@ -232,7 +293,7 @@ void Nnet::GetGradient(Vector* grad_copy) { pos += mat_size; //get the biases from CuVector to Vector const CuVector& cu_vec = - dynamic_cast(nnet_[n])->GetBiasCorr(); + dynamic_cast(components_[n])->GetBiasCorr(); Vector vec(cu_vec.Dim()); cu_vec.CopyToVec(&vec); //append biases to the supervector @@ -242,7 +303,7 @@ void Nnet::GetGradient(Vector* grad_copy) { default : KALDI_ERR << "Unimplemented access to parameters " << "of updatable component " - << Component::TypeToMarker(nnet_[n]->GetType()); + << Component::TypeToMarker(components_[n]->GetType()); } } } @@ -252,14 +313,14 @@ void Nnet::GetGradient(Vector* grad_copy) { int32 Nnet::NumParams() const { int32 n_params = 0; - for(int32 n=0; nIsUpdatable()) { - switch(nnet_[n]->GetType()) { + for(int32 n=0; nIsUpdatable()) { + switch(components_[n]->GetType()) { case Component::kAffineTransform : - n_params += (1 + nnet_[n]->InputDim()) * nnet_[n]->OutputDim(); + n_params += (1 + components_[n]->InputDim()) * components_[n]->OutputDim(); break; default : - KALDI_WARN << Component::TypeToMarker(nnet_[n]->GetType()) + KALDI_WARN << Component::TypeToMarker(components_[n]->GetType()) << "is updatable, but its parameter count not implemented"; } } @@ -268,40 +329,95 @@ int32 Nnet::NumParams() const { } -void Nnet::Read(std::istream &in, bool binary) { +void Nnet::Read(const std::string &file) { + bool binary; + Input in(file, &binary); + Read(in.Stream(), binary); + in.Close(); + // Warn if the NN is empty + if(NumComponents() == 0) { + KALDI_WARN << "The network '" << file << "' is empty."; + } +} + + +void Nnet::Read(std::istream &is, bool binary) { // get the network layers from a factory Component *comp; - while (NULL != (comp = Component::Read(in, binary, this))) { - if (LayerCount() > 0 && nnet_.back()->OutputDim() != comp->InputDim()) { + while (NULL != (comp = Component::Read(is, binary))) { + if (NumComponents() > 0 && components_.back()->OutputDim() != comp->InputDim()) { KALDI_ERR << "Dimensionality mismatch!" - << " Previous layer output:" << nnet_.back()->OutputDim() + << " Previous layer output:" << components_.back()->OutputDim() << " Current layer input:" << comp->InputDim(); } - nnet_.push_back(comp); + components_.push_back(comp); } // create empty buffers - propagate_buf_.resize(LayerCount()+1); - backpropagate_buf_.resize(LayerCount()-1); + propagate_buf_.resize(NumComponents()+1); + backpropagate_buf_.resize(NumComponents()-1); // reset learn rate opts_.learn_rate = 0.0; + + Check(); //check consistency (dims...) +} + + +void Nnet::Write(const std::string &file, bool binary) { + Output out(file, binary, true); + Write(out.Stream(), binary); + out.Close(); +} + + +void Nnet::Write(std::ostream &os, bool binary) { + Check(); + WriteToken(os, binary, ""); + if(binary == false) os << std::endl; + for(int32 i=0; iWrite(os, binary); + } + WriteToken(os, binary, ""); + if(binary == false) os << std::endl; } std::string Nnet::Info() const { std::ostringstream ostr; - ostr << "num-components " << LayerCount() << std::endl; + ostr << "num-components " << NumComponents() << std::endl; ostr << "input-dim " << InputDim() << std::endl; ostr << "output-dim " << OutputDim() << std::endl; ostr << "number-of-parameters " << static_cast(NumParams())/1e6 << " millions" << std::endl; - for (int32 i = 0; i < LayerCount(); i++) + for (int32 i = 0; i < NumComponents(); i++) ostr << "component " << i+1 << " : " - << Component::TypeToMarker(nnet_[i]->GetType()) - << ", input-dim " << nnet_[i]->InputDim() - << ", output-dim " << nnet_[i]->OutputDim() - << ", " << nnet_[i]->Info() << std::endl; + << Component::TypeToMarker(components_[i]->GetType()) + << ", input-dim " << components_[i]->InputDim() + << ", output-dim " << components_[i]->OutputDim() + << ", " << components_[i]->Info() << std::endl; return ostr.str(); } - + + +void Nnet::Check() const { + for (size_t i = 0; i + 1 < components_.size(); i++) { + KALDI_ASSERT(components_[i] != NULL); + int32 output_dim = components_[i]->OutputDim(), + next_input_dim = components_[i+1]->InputDim(); + KALDI_ASSERT(output_dim == next_input_dim); + } +} + + +void Nnet::SetTrainOptions(const NnetTrainOptions& opts) { + opts_ = opts; + //set values to individual components + for (int32 l=0; l(GetComponent(l)).SetTrainOptions(opts_); + } + } +} + + } // namespace nnet1 } // namespace kaldi diff --git a/src/nnet/nnet-nnet.h b/src/nnet/nnet-nnet.h index e4361c456..a277acbf6 100644 --- a/src/nnet/nnet-nnet.h +++ b/src/nnet/nnet-nnet.h @@ -1,6 +1,6 @@ // nnet/nnet-nnet.h -// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) +// Copyright 2011-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -51,31 +51,28 @@ class Nnet { int32 InputDim() const; /// Dimensionality of network outputs (posteriors | bn-features | etc.) int32 OutputDim() const; - - /// Returns number of layers in the network - int32 LayerCount() const { - return nnet_.size(); - } - /// Access to an individual layer (unprotected) - Component* Layer(int32 index) { - return nnet_[index]; - } - /// Get the position of a layer in the network - int32 IndexOfLayer(const Component& comp) const; + + /// Returns number of components-- think of this as similar to # of layers, but + /// e.g. the nonlinearity and the linear part count as separate components, + /// so the number of components will be more than the number of layers. + int32 NumComponents() const { return components_.size(); } + + const Component& GetComponent(int32 c) const; + Component& GetComponent(int32 c); + + /// Sets the c'th component to "component", taking ownership of the pointer + /// and deleting the corresponding one that we own. + void SetComponent(int32 c, Component *component); - /// Add another layer - /// Warning : the Nnet over-takes responsibility for freeing the memory - /// so use dynamically allocated Component only! - void AppendLayer(Component* dynamically_allocated_comp); - /// Concatenate the network - /// Warning : this is destructive, the arg src_nnet_will_be_empty - /// will be empty network after calling this method - void Concatenate(Nnet* src_nnet_will_be_empty); - /// Remove layer (checks for meaningful dimensions after removal) - void RemoveLayer(int32 index); - void RemoveLastLayer() { - RemoveLayer(LayerCount()-1); - } + /// Appends this component to the components already in the neural net. + /// Takes ownership of the pointer + void AppendComponent(Component *dynamically_allocated_comp); + /// Append another network to the current one (copy components). + void AppendNnet(const Nnet& nnet_to_append); + + /// Remove component + void RemoveComponent(int32 c); + void RemoveLastComponent() { RemoveComponent(NumComponents()-1); } /// Access to forward pass buffers const std::vector >& PropagateBuffer() const { @@ -86,7 +83,7 @@ class Nnet { return backpropagate_buf_; } - /// get the number of parameters in the network + /// Get the number of parameters in the network int32 NumParams() const; /// Get the network weights in a supervector void GetWeights(Vector* wei_copy); @@ -103,8 +100,11 @@ class Nnet { void Write(const std::string &file, bool binary); /// Write MLP to stream void Write(std::ostream &out, bool binary); + /// Create string with human readable description of the nnet instance std::string Info() const; + /// Consistency check. + void Check() const; /// Set training hyper-parameters to the network and its UpdatableComponent(s) void SetTrainOptions(const NnetTrainOptions& opts); @@ -114,11 +114,9 @@ class Nnet { } private: - /// NnetType is alias to vector of components - typedef std::vector NnetType; - /// Vector which contains all the layers composing the network network, - /// also non-linearities (sigmoid|softmax|tanh|...) are considered as layers. - NnetType nnet_; + /// Vector which contains all the components composing the neural network, + /// the components are for example: AffineTransform, Sigmoid, Softmax + std::vector components_; std::vector > propagate_buf_; ///< buffers for forward pass std::vector > backpropagate_buf_; ///< buffers for backward pass @@ -130,109 +128,6 @@ class Nnet { }; -inline Nnet::~Nnet() { - // delete all the components - NnetType::iterator it; - for(it=nnet_.begin(); it!=nnet_.end(); ++it) { - delete *it; - } -} - - -inline int32 Nnet::InputDim() const { - if (LayerCount() == 0) { - KALDI_ERR << "No layers in MLP"; - } - return nnet_.front()->InputDim(); -} - - -inline int32 Nnet::OutputDim() const { - if (LayerCount() <= 0) { - KALDI_ERR << "No layers in MLP"; - } - return nnet_.back()->OutputDim(); -} - - -inline int32 Nnet::IndexOfLayer(const Component &comp) const { - for(int32 i=0; i 0) { - KALDI_ASSERT(OutputDim() == dynamically_allocated_comp->InputDim()); - } - nnet_.push_back(dynamically_allocated_comp); -} - - -inline void Nnet::Concatenate(Nnet* src_nnet_will_be_empty) { - if(LayerCount() > 0) { - KALDI_ASSERT(OutputDim() == src_nnet_will_be_empty->InputDim()); - } - nnet_.insert(nnet_.end(), - src_nnet_will_be_empty->nnet_.begin(), - src_nnet_will_be_empty->nnet_.end()); - src_nnet_will_be_empty->nnet_.clear(); -} - - -inline void Nnet::RemoveLayer(int32 index) { - //make sure we don't break the dimensionalities in the nnet - KALDI_ASSERT(index < LayerCount()); - KALDI_ASSERT(index == LayerCount()-1 || Layer(index)->InputDim() == Layer(index)->OutputDim()); - //remove element from the vector - Component* ptr = nnet_[index]; - nnet_.erase(nnet_.begin()+index); - delete ptr; -} - - -inline void Nnet::Read(const std::string &file) { - bool binary; - Input in(file, &binary); - Read(in.Stream(), binary); - in.Close(); - // Warn if the NN is empty - if(LayerCount() == 0) { - KALDI_WARN << "The network '" << file << "' is empty."; - } -} - - -inline void Nnet::Write(const std::string &file, bool binary) { - Output out(file, binary, true); - Write(out.Stream(), binary); - out.Close(); -} - - -inline void Nnet::Write(std::ostream &out, bool binary) { - for(int32 i=0; iWrite(out, binary); - } -} - - -inline void Nnet::SetTrainOptions(const NnetTrainOptions& opts) { - opts_ = opts; - //set values to individual components - for (int32 l=0; lIsUpdatable()) { - dynamic_cast(Layer(l))->SetTrainOptions(opts_); - } - } -} - - } // namespace nnet1 } // namespace kaldi diff --git a/src/nnet/nnet-rbm.h b/src/nnet/nnet-rbm.h index af97b3a1c..32073789b 100644 --- a/src/nnet/nnet-rbm.h +++ b/src/nnet/nnet-rbm.h @@ -1,6 +1,6 @@ // nnet/nnet-rbm.h -// Copyright 2012-2013 Brno University of Technology (Author: Karel Vesely) +// Copyright 2012-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -35,8 +35,8 @@ class RbmBase : public UpdatableComponent { GAUSSIAN } RbmNodeType; - RbmBase(int32 dim_in, int32 dim_out, Nnet *nnet) - : UpdatableComponent(dim_in, dim_out, nnet) + RbmBase(int32 dim_in, int32 dim_out) + : UpdatableComponent(dim_in, dim_out) { } /*Is included in Component:: itf @@ -85,8 +85,8 @@ class RbmBase : public UpdatableComponent { // RBMs use RbmUpdate(.) void Update(const CuMatrix &input, const CuMatrix &diff) { } // RBMs use option class RbmTrainOptions - void SetTrainOptions(const NnetTrainOptions&); - const NnetTrainOptions& GetTrainOptions() const; + void SetTrainOptions(const NnetTrainOptions&) { } + const NnetTrainOptions& GetTrainOptions() const { } NnetTrainOptions opts_; // //// @@ -97,15 +97,14 @@ class RbmBase : public UpdatableComponent { class Rbm : public RbmBase { public: - Rbm(int32 dim_in, int32 dim_out, Nnet *nnet) - : RbmBase(dim_in, dim_out, nnet) + Rbm(int32 dim_in, int32 dim_out) + : RbmBase(dim_in, dim_out) { } ~Rbm() { } - ComponentType GetType() const { - return kRbm; - } + Component* Copy() const { return new Rbm(*this); } + ComponentType GetType() const { return kRbm; } void ReadData(std::istream &is, bool binary) { std::string vis_node_type, hid_node_type; @@ -164,7 +163,7 @@ class Rbm : public RbmBase { void BackpropagateFnc(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { - KALDI_ERR << "Cannot backpropagate through RBM!" + KALDI_ERR << "Cannot back-propagate through RBM!" << "Better convert it to and "; } virtual void Update(const CuMatrix &input, @@ -226,9 +225,9 @@ class Rbm : public RbmBase { // should be about the same. The model is particularly sensitive at the very // beginning of the CD-1 training. // - // We compute varinace of a)input minibatch b)reconstruction. + // We compute variance of a)input mini-batch b)reconstruction. // When the ratio b)/a) is larger than 2, we: - // 1. scale down the weights and biases by b)/a) (for next minibatch b)/a) gets 1.0) + // 1. scale down the weights and biases by b)/a) (for next mini-batch b)/a) gets 1.0) // 2. shrink learning rate by 0.9x // 3. reset the momentum buffer // @@ -255,7 +254,7 @@ class Rbm : public RbmBase { pos_vis_stddev.MulElements(pos_vis_mean_h); pos_vis_stddev.Scale(-1.0); pos_vis_stddev.AddVec(1.0/pos_vis.NumRows(),pos_vis_second_h); - /* set negtive values to zero before the square root */ + /* set negative values to zero before the square root */ for (int32 i=0; i &vec) { vec_aux.MulElements(vec); // (vec-mean)^3 Real skewness = vec_aux.Sum() / pow(variance, 3.0/2.0) / vec.Dim(); // kurtosis (peakedness) - // - makes sence for symmetric distributions (skewness is zero) + // - makes sense for symmetric distributions (skewness is zero) // - positive : 'sharper peak' than Normal distribution - // - negtive : 'heavier tails' than Normal distribution + // - negative : 'heavier tails' than Normal distribution // - zero : same peakedness as the Normal distribution vec_aux.MulElements(vec); // (vec-mean)^4 Real kurtosis = vec_aux.Sum() / (variance * variance) / vec.Dim() - 3.0; @@ -66,6 +66,9 @@ std::string MomentStatistics(const Vector &vec) { return ostr.str(); } +/** + * Overload MomentStatistics to Matrix + */ template std::string MomentStatistics(const Matrix &mat) { Vector vec(mat.NumRows()*mat.NumCols()); @@ -73,6 +76,9 @@ std::string MomentStatistics(const Matrix &mat) { return MomentStatistics(vec); } +/** + * Overload MomentStatistics to CuVector + */ template std::string MomentStatistics(const CuVector &vec) { Vector vec_host(vec.Dim()); @@ -80,6 +86,9 @@ std::string MomentStatistics(const CuVector &vec) { return MomentStatistics(vec_host); } +/** + * Overload MomentStatistics to CuMatrix + */ template std::string MomentStatistics(const CuMatrix &mat) { Matrix mat_host(mat.NumRows(),mat.NumCols()); @@ -96,26 +105,25 @@ std::string MomentStatistics(const CuMatrix &mat) { */ class Splice : public Component { public: - Splice(int32 dim_in, int32 dim_out, Nnet *nnet) - : Component(dim_in, dim_out, nnet) + Splice(int32 dim_in, int32 dim_out) + : Component(dim_in, dim_out) { } ~Splice() { } - ComponentType GetType() const { - return kSplice; - } + Component* Copy() const { return new Splice(*this); } + ComponentType GetType() const { return kSplice; } void ReadData(std::istream &is, bool binary) { - //read double vector + // read double vector Vector vec_d; vec_d.Read(is, binary); - //convert to int vector + // convert to int vector std::vector vec_i(vec_d.Dim()); for(int32 i=0; i vec_d; vec_d.Read(is, binary); - //subtract 1 + // subtract 1 vec_d.Add(-1.0); - //convert to int vector + // convert to int vector std::vector vec_i(vec_d.Dim()); for(int32 i=0; i &in, CuMatrix *out) { out->CopyFromMat(in); - //rescale the data + // rescale the data out->MulColsVec(scale_data_); } void BackpropagateFnc(const CuMatrix &in, const CuMatrix &out, const CuMatrix &out_diff, CuMatrix *in_diff) { in_diff->CopyFromMat(out_diff); - //derivative gets also scaled by the scale_data_ + // derivative gets also scaled by the scale_data_ in_diff->MulColsVec(scale_data_); } - //Data accessors + // Data accessors const CuVector& GetScaleVec() { return scale_data_; } @@ -331,9 +336,6 @@ class Rescale : public Component { - - - } // namespace nnet1 } // namespace kaldi diff --git a/src/nnet2/mixup-nnet.cc b/src/nnet2/mixup-nnet.cc index 3bc323889..7f447faf9 100644 --- a/src/nnet2/mixup-nnet.cc +++ b/src/nnet2/mixup-nnet.cc @@ -24,20 +24,9 @@ namespace kaldi { namespace nnet2 { -static BaseFloat GetFirstLearningRate(const Nnet &nnet) { - for (int32 c = 0; c < nnet.NumComponents(); c++) { - const UpdatableComponent *uc = - dynamic_cast(&(nnet.GetComponent(c))); - if (uc != NULL) - return uc->LearningRate(); - } - KALDI_ERR << "Neural net has no updatable components"; - return 0.0; -} - /** This function makes sure the neural net ends with a - MixtureProbComponent. If it doesn't, it adds one + SumGroupComponent. If it doesn't, it adds one (with a single mixture/matrix corresponding to each output element.) [Before doing so, it makes sure that the last layer is a SoftmaxLayer, which is what @@ -48,24 +37,20 @@ static BaseFloat GetFirstLearningRate(const Nnet &nnet) { static void GiveNnetCorrectTopology(Nnet *nnet, AffineComponent **affine_component, SoftmaxComponent **softmax_component, - MixtureProbComponent **mixture_prob_component) { + SumGroupComponent **sum_group_component) { int32 nc = nnet->NumComponents(); KALDI_ASSERT(nc > 0); Component* component = &(nnet->GetComponent(nc - 1)); - if ((*mixture_prob_component = - dynamic_cast(component)) == NULL) { - KALDI_LOG << "Adding MixtureProbComponent to neural net."; + if ((*sum_group_component = + dynamic_cast(component)) == NULL) { + KALDI_LOG << "Adding SumGroupComponent to neural net."; int32 dim = component->OutputDim(); // Give it the same learning rate as the first updatable layer we have. - BaseFloat learning_rate = GetFirstLearningRate(*nnet), - diag_element = 0.999; // actually it's a don't care. std::vector sizes(dim, 1); // a vector of all ones, of dimension "dim". - *mixture_prob_component = new MixtureProbComponent(); - (*mixture_prob_component)->Init(learning_rate, - diag_element, - sizes); - nnet->Append(*mixture_prob_component); + *sum_group_component = new SumGroupComponent(); + (*sum_group_component)->Init(sizes); + nnet->Append(*sum_group_component); nc++; } component = &(nnet->GetComponent(nc - 2)); @@ -84,14 +69,16 @@ static void GiveNnetCorrectTopology(Nnet *nnet, /** This function works as follows. We first make sure the neural net has the correct topology, so its - last component is a MixtureProbComponent. + last component is a SumGroupComponent. - We then get the counts for each matrix in the MixtureProbComponent (these + We then get the counts for each matrix in the SumGroupComponent (these will either correspond to leaves in the decision tree, or level-1 leaves, if we have a 2-level-tree system). We work out the total count for each of these matrices, by getting the count from the SoftmaxComponent. - - Then, for each matrix in the Mixturemixture-prob component, we + + We then increase, if necessary, the dimensions that the SumGroupComponent sums + over increase the dimension of the SoftmaxComponent if necessary, and duplicate + and then perturb the relevant rows of the AffineComponent. */ @@ -100,18 +87,18 @@ void MixupNnet(const NnetMixupConfig &mixup_config, Nnet *nnet) { AffineComponent *affine_component = NULL; SoftmaxComponent *softmax_component = NULL; - MixtureProbComponent *mixture_prob_component = NULL; + SumGroupComponent *sum_group_component = NULL; GiveNnetCorrectTopology(nnet, &affine_component, &softmax_component, - &mixture_prob_component); // Adds a MixtureProbComponent if needed. + &sum_group_component); // Adds a SumGroupComponent if needed. softmax_component->MixUp(mixup_config.num_mixtures, mixup_config.power, mixup_config.min_count, mixup_config.perturb_stddev, affine_component, - mixture_prob_component); + sum_group_component); nnet->Check(); // Checks that dimensions all match up. } @@ -120,15 +107,16 @@ void MixupNnet(const NnetMixupConfig &mixup_config, void SoftmaxComponent::MixUp(int32 num_mixtures, BaseFloat power, BaseFloat min_count, - BaseFloat perturb_stddev, + BaseFloat perturb_stddev, AffineComponent *ac, - MixtureProbComponent *mc) { - + SumGroupComponent *sc) { // "counts" is derived from this->counts_ by summing. - Vector counts(mc->params_.size()); + std::vector old_sizes; + sc->GetSizes(&old_sizes); + Vector counts(old_sizes.size()); int32 old_dim = 0; - for (size_t i = 0; i < mc->params_.size(); i++) { - int32 this_input_dim = mc->params_[i].NumCols(); + for (size_t i = 0; i < old_sizes.size(); i++) { + int32 this_input_dim = old_sizes[i]; BaseFloat this_tot_count = 0.0; /// Total the count out of /// all the output dims of the softmax layer that correspond /// to this mixture. We'll use this total to allocate new quasi-Gaussians. @@ -141,16 +129,18 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, std::vector targets; // #mixtures for each state. + // Get the target number of mixtures for each state. GetSplitTargets(counts, num_mixtures, power, min_count, &targets); - KALDI_ASSERT(targets.size() == mc->params_.size()); - // floor each target to the current #mixture components. + KALDI_ASSERT(targets.size() == old_sizes.size()); + std::vector new_sizes(old_sizes.size()); for (size_t i = 0; i < targets.size(); i++) - targets[i] = std::max(targets[i], mc->params_[i].NumCols()); - int32 new_dim = std::accumulate(targets.begin(), targets.end(), + new_sizes[i] = std::max(targets[i], old_sizes[i]); + int32 new_dim = std::accumulate(new_sizes.begin(), new_sizes.end(), static_cast(0)), affine_input_dim = ac->InputDim(); KALDI_ASSERT(new_dim >= old_dim); + sc->Init(new_sizes); // bias and linear terms from affine component: Vector old_bias_term(ac->bias_params_); @@ -165,11 +155,10 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, // respectively. They get incremented in the following loop. int32 old_offset = 0, new_offset = 0; Vector old_counts(this->value_sum_); - for (size_t i = 0; i < mc->params_.size(); i++) { - const CuMatrix &this_old_params(mc->params_[i]); - int32 this_old_dim = this_old_params.NumCols(), - this_new_dim = targets[i], - this_cur_dim = this_old_dim; // this_cur_dim is loop variable. + for (size_t i = 0; i < old_sizes.size(); i++) { + int32 this_old_dim = old_sizes[i], + this_new_dim = new_sizes[i], + this_cur_dim = this_old_dim; // this_cur_dim is loop variable. SubMatrix this_old_linear_term(old_linear_term, old_offset, this_old_dim, @@ -184,8 +173,6 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, old_offset, this_old_dim), this_new_counts(new_counts, new_offset, this_new_dim); - Matrix this_new_params(this_old_params.NumRows(), - this_new_dim); // Copy the same-dimensional part of the parameters and counts. this_new_linear_term.Range(0, this_old_dim, 0, affine_input_dim). @@ -195,8 +182,6 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, this_new_counts.Range(0, this_old_dim). CopyFromVec(this_old_counts); // this_new_params is the mixture weights. - this_new_params.Range(0, this_old_params.NumRows(), 0, this_old_dim). - CopyFromMat(this_old_params); // Add the new components... for (; this_cur_dim < this_new_dim; this_cur_dim++) { BaseFloat *count_begin = this_new_counts.Data(), @@ -216,13 +201,9 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, new_vec.AddVec(-perturb_stddev, rand); this_new_bias_term(max_index) += log(0.5); this_new_bias_term(new_index) = this_new_bias_term(max_index); - // now copy the column of the MixtureProbComponent parameters. - for (int32 j = 0; j < this_new_params.NumRows(); j++) - this_new_params(j, new_index) = this_new_params(j, max_index); } old_offset += this_old_dim; new_offset += this_new_dim; - mc->params_[i] = this_new_params; } KALDI_ASSERT(old_offset == old_dim && new_offset == new_dim); ac->SetParams(new_bias_term, new_linear_term); @@ -230,8 +211,6 @@ void SoftmaxComponent::MixUp(int32 num_mixtures, this->value_sum_.CopyFromVec(new_counts); this->count_ = this->value_sum_.Sum(); this->dim_ = new_dim; - mc->input_dim_ = new_dim; // keep this up to date. - // We already updated mc->params_. KALDI_LOG << "Mixed up from dimension of " << old_dim << " to " << new_dim << " in the softmax layer."; } diff --git a/src/nnet2/mixup-nnet.h b/src/nnet2/mixup-nnet.h index 5d6cae5ab..a3b2c6d2f 100644 --- a/src/nnet2/mixup-nnet.h +++ b/src/nnet2/mixup-nnet.h @@ -27,15 +27,12 @@ namespace kaldi { namespace nnet2 { -/** Configuration class that controls neural net "mixupage" which is actually a - scaling on the parameters of each of the updatable layers. - */ struct NnetMixupConfig { BaseFloat power; BaseFloat min_count; int32 num_mixtures; BaseFloat perturb_stddev; - + NnetMixupConfig(): power(0.25), min_count(1000.0), num_mixtures(-1), perturb_stddev(0.01) { } @@ -58,9 +55,8 @@ struct NnetMixupConfig { This function does something similar to Gaussian mixture splitting for GMMs, except applied to the output layer of the neural network. We create additional outputs, which will be summed over using a - MixtureProbComponent (if one does not already exist, it will be - added.) - */ + SumGroupComponent. +*/ void MixupNnet(const NnetMixupConfig &mixup_config, Nnet *nnet); diff --git a/src/nnet2/nnet-component-test.cc b/src/nnet2/nnet-component-test.cc index 45f5d0501..375090b67 100644 --- a/src/nnet2/nnet-component-test.cc +++ b/src/nnet2/nnet-component-test.cc @@ -532,6 +532,27 @@ void UnitTestMixtureProbComponent() { } } + +void UnitTestSumGroupComponent() { + std::vector sizes; + int32 num_sizes = 1 + rand() % 5; + for (int32 i = 0; i < num_sizes; i++) + sizes.push_back(1 + rand() % 5); + + { + SumGroupComponent component; + component.Init(sizes); + UnitTestGenericComponentInternal(component); + } + { + const char *str = "sizes=3:4:5"; + SumGroupComponent component; + component.InitFromString(str); + UnitTestGenericComponentInternal(component); + } +} + + void UnitTestDctComponent() { int32 m = 1 + rand() % 4, n = 1 + rand() % 4, dct_dim = m, dim = m * n; @@ -760,6 +781,7 @@ int main() { UnitTestBlockAffineComponent(); UnitTestBlockAffineComponentPreconditioned(); UnitTestMixtureProbComponent(); + UnitTestSumGroupComponent(); UnitTestDctComponent(); UnitTestFixedLinearComponent(); UnitTestFixedAffineComponent(); diff --git a/src/nnet2/nnet-component.cc b/src/nnet2/nnet-component.cc index ccd7d43f6..17a198b0e 100644 --- a/src/nnet2/nnet-component.cc +++ b/src/nnet2/nnet-component.cc @@ -75,6 +75,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) { ans = new AffinePreconInputComponent(); } else if (component_type == "MixtureProbComponent") { ans = new MixtureProbComponent(); + } else if (component_type == "SumGroupComponent") { + ans = new SumGroupComponent(); } else if (component_type == "BlockAffineComponent") { ans = new BlockAffineComponent(); } else if (component_type == "BlockAffineComponentPreconditioned") { @@ -2922,6 +2924,92 @@ void MixtureProbComponent::UnVectorize(const VectorBase ¶ms) { KALDI_ASSERT(offset == params.Dim()); } +void SumGroupComponent::Init(const std::vector &sizes) { + KALDI_ASSERT(!sizes.empty()); + std::vector cpu_vec(sizes.size()); + std::vector reverse_cpu_vec; + int32 cur_index = 0; + for (size_t i = 0; i < sizes.size(); i++) { + KALDI_ASSERT(sizes[i] > 0); + cpu_vec[i].first = cur_index; + cpu_vec[i].second = cur_index + sizes[i]; + cur_index += sizes[i]; + for (int32 j = cpu_vec[i].first; j < cpu_vec[i].second; j++) + reverse_cpu_vec.push_back(i); + } + this->indexes_ = cpu_vec; + this->reverse_indexes_ = reverse_cpu_vec; + this->input_dim_ = cur_index; + this->output_dim_ = sizes.size(); +} + +void SumGroupComponent::InitFromString(std::string args) { + std::string orig_args(args); + std::vector sizes; + bool ok = ParseFromString("sizes", &args, &sizes); + + if (!ok || !args.empty() || sizes.empty()) + KALDI_ERR << "Invalid initializer for layer of type " + << Type() << ": \"" << orig_args << "\""; + this->Init(sizes); +} + +Component* SumGroupComponent::Copy() const { + SumGroupComponent *ans = new SumGroupComponent(); + ans->indexes_ = indexes_; + ans->reverse_indexes_ = reverse_indexes_; + ans->input_dim_ = input_dim_; + ans->output_dim_ = output_dim_; + return ans; +} + +void SumGroupComponent::Read(std::istream &is, bool binary) { + ExpectOneOrTwoTokens(is, binary, "", ""); + std::vector sizes; + ReadIntegerVector(is, binary, &sizes); + ExpectToken(is, binary, ""); + this->Init(sizes); +} + +void SumGroupComponent::GetSizes(std::vector *sizes) const { + std::vector indexes; + indexes_.CopyToVec(&indexes); + sizes->resize(indexes.size()); + for (size_t i = 0; i < indexes.size(); i++) { + (*sizes)[i] = indexes[i].second - indexes[i].first; + if (i == 0) { KALDI_ASSERT(indexes[i].first == 0); } + else { KALDI_ASSERT(indexes[i].first == indexes[i-1].second); } + KALDI_ASSERT(indexes[i].second > indexes[i].first); + (*sizes)[i] = indexes[i].second - indexes[i].first; + } +} + +void SumGroupComponent::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); + WriteToken(os, binary, ""); + std::vector sizes; + this->GetSizes(&sizes); + WriteIntegerVector(os, binary, sizes); + WriteToken(os, binary, ""); +} + +void SumGroupComponent::Propagate(const CuMatrixBase &in, + int32 num_chunks, + CuMatrix *out) const { + out->Resize(in.NumRows(), this->OutputDim(), kUndefined); + out->SumColumnRanges(in, indexes_); +} + +void SumGroupComponent::Backprop(const CuMatrixBase &, // in_value, + const CuMatrixBase &, // out_value, + const CuMatrixBase &out_deriv, + int32 num_chunks, + Component *to_update, + CuMatrix *in_deriv) const { + in_deriv->Resize(out_deriv.NumRows(), InputDim()); + in_deriv->CopyCols(out_deriv, reverse_indexes_); +} + std::string SpliceComponent::Info() const { std::stringstream stream; diff --git a/src/nnet2/nnet-component.h b/src/nnet2/nnet-component.h index 961912d3b..7b7fa93c6 100644 --- a/src/nnet2/nnet-component.h +++ b/src/nnet2/nnet-component.h @@ -450,7 +450,7 @@ class ScaleComponent: public Component { -class MixtureProbComponent; // Forward declaration. +class SumGroupComponent; // Forward declaration. class AffineComponent; // Forward declaration. class SoftmaxComponent: public NonlinearComponent { @@ -472,12 +472,13 @@ class SoftmaxComponent: public NonlinearComponent { Component *to_update, // may be identical to "this". CuMatrix *in_deriv) const; - void MixUp(int32 num_mixtures, // implemented in mixup-nnet.cc + void MixUp(int32 num_mixtures, BaseFloat power, BaseFloat min_count, BaseFloat perturb_stddev, AffineComponent *ac, - MixtureProbComponent *mc); + SumGroupComponent *sc); + virtual Component* Copy() const { return new SoftmaxComponent(*this); } private: SoftmaxComponent &operator = (const SoftmaxComponent &other); // Disallow. @@ -1227,8 +1228,6 @@ class BlockAffineComponentPreconditioned: public BlockAffineComponent { // one for each row). class MixtureProbComponent: public UpdatableComponent { - friend class SoftmaxComponent; // Mixing-up done by a function - // in that class. public: virtual int32 InputDim() const { return input_dim_; } virtual int32 OutputDim() const { return output_dim_; } @@ -1275,6 +1274,53 @@ class MixtureProbComponent: public UpdatableComponent { int32 output_dim_; }; + +// SumGroupComponent is used to sum up groups of posteriors. +// It's used to introduce a kind of Gaussian-mixture-model-like +// idea into neural nets. This is basically a degenerate case of +// MixtureProbComponent; we had to implement it separately to +// be efficient for CUDA (we can use this one regardless whether +// we have CUDA or not; it's the normal case we want anyway). +class SumGroupComponent: public Component { +public: + virtual int32 InputDim() const { return input_dim_; } + virtual int32 OutputDim() const { return output_dim_; } + void Init(const std::vector &sizes); // the vector is of the input dim + // (>= 1) for each output dim. + void GetSizes(std::vector *sizes) const; // Get a vector saying, for + // each output-dim, how many + // inputs were summed over. + virtual void InitFromString(std::string args); + SumGroupComponent() { } + virtual std::string Type() const { return "SumGroupComponent"; } + virtual bool BackpropNeedsInput() const { return false; } + virtual bool BackpropNeedsOutput() const { return false; } + virtual void Propagate(const CuMatrixBase &in, + int32 num_chunks, + CuMatrix *out) const; + // Note: in_value and out_value are both dummy variables. + virtual void Backprop(const CuMatrixBase &in_value, + const CuMatrixBase &out_value, + const CuMatrixBase &out_deriv, + int32 num_chunks, + Component *to_update, // may be identical to "this". + CuMatrix *in_deriv) const; + virtual Component* Copy() const; + virtual void Read(std::istream &is, bool binary); + virtual void Write(std::ostream &os, bool binary) const; + +private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SumGroupComponent); + // Note: Int32Pair is just struct{ int32 first; int32 second }; it's defined + // in cu-matrixdim.h as extern "C" which is needed for the CUDA interface. + CuArray indexes_; // for each output index, the (start, end) input + // index. + CuArray reverse_indexes_; // for each input index, the output index. + int32 input_dim_; + int32 output_dim_; +}; + + /// PermuteComponent does a random permutation of the dimensions. Useful in /// conjunction with block-diagonal transforms. class PermuteComponent: public Component { diff --git a/src/nnetbin/Makefile b/src/nnetbin/Makefile index 3c1861e00..006ef98ce 100644 --- a/src/nnetbin/Makefile +++ b/src/nnetbin/Makefile @@ -12,7 +12,7 @@ BINFILES = nnet-train-xent-hardlab-perutt \ nnet-train-mmi-sequential \ nnet-train-mpe-sequential \ rbm-train-cd1-frmshuff rbm-convert-to-nnet \ - nnet-forward nnet-copy nnet1-info nnet-concat \ + nnet-forward nnet-copy nnet-info nnet-concat \ transf-to-nnet cmvn-to-nnet OBJFILES = diff --git a/src/nnetbin/cmvn-to-nnet.cc b/src/nnetbin/cmvn-to-nnet.cc index f34987027..fe59ee019 100644 --- a/src/nnetbin/cmvn-to-nnet.cc +++ b/src/nnetbin/cmvn-to-nnet.cc @@ -93,7 +93,7 @@ int main(int argc, char *argv[]) { //create the shift component { - AddShift* shift_component = new AddShift(shift.Dim(), shift.Dim(), &nnet); + AddShift* shift_component = new AddShift(shift.Dim(), shift.Dim()); //the pointer will be given to the nnet, so we don't need to call delete //convert Vector to CuVector @@ -103,12 +103,12 @@ int main(int argc, char *argv[]) { shift_component->SetShiftVec(cu_shift); //append layer to the nnet - nnet.AppendLayer(shift_component); + nnet.AppendComponent(shift_component); } //create the scale component { - Rescale* scale_component = new Rescale(scale.Dim(), scale.Dim(), &nnet); + Rescale* scale_component = new Rescale(scale.Dim(), scale.Dim()); //the pointer will be given to the nnet, so we don't need to call delete //convert Vector to CuVector @@ -118,9 +118,8 @@ int main(int argc, char *argv[]) { scale_component->SetScaleVec(cu_scale); //append layer to the nnet - nnet.AppendLayer(scale_component); + nnet.AppendComponent(scale_component); } - //write the nnet { diff --git a/src/nnetbin/nnet-concat.cc b/src/nnetbin/nnet-concat.cc index eea0b49cd..8afec1276 100644 --- a/src/nnetbin/nnet-concat.cc +++ b/src/nnetbin/nnet-concat.cc @@ -1,6 +1,6 @@ // nnetbin/nnet-concat.cc -// Copyright 2012 Karel Vesely +// Copyright 2012-2013 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) { nnet_next.Read(ki.Stream(), binary_read); } //append nnet_next to the network nnet - nnet.Concatenate(&nnet_next); + nnet.AppendNnet(nnet_next); } //finally write the nnet to disk diff --git a/src/nnetbin/nnet-copy.cc b/src/nnetbin/nnet-copy.cc index 75c327b43..ea03955e8 100644 --- a/src/nnetbin/nnet-copy.cc +++ b/src/nnetbin/nnet-copy.cc @@ -1,6 +1,6 @@ // nnetbin/nnet-copy.cc -// Copyright 2012 Karel Vesely +// Copyright 2012 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -64,14 +64,14 @@ int main(int argc, char *argv[]) { // optionally remove N first layers if(remove_first_layers > 0) { for(int32 i=0; i 0) { for(int32 i=0; iGetType() == Component::kSoftmax) { + if(no_softmax && nnet.GetComponent(nnet.NumComponents()-1).GetType() == Component::kSoftmax) { KALDI_LOG << "Removing softmax from the nnet " << model_filename; - nnet.RemoveLayer(nnet.LayerCount()-1); + nnet.RemoveComponent(nnet.NumComponents()-1); } //check for some non-sense option combinations if(apply_log && no_softmax) { KALDI_ERR << "Nonsense option combination : --apply-log=true and --no-softmax=true"; } - if(apply_log && nnet.Layer(nnet.LayerCount()-1)->GetType() != Component::kSoftmax) { + if(apply_log && nnet.GetComponent(nnet.NumComponents()-1).GetType() != Component::kSoftmax) { KALDI_ERR << "Used --apply-log=true, but nnet " << model_filename << " does not have as last component!"; } diff --git a/src/nnetbin/nnet1-info.cc b/src/nnetbin/nnet-info.cc similarity index 100% rename from src/nnetbin/nnet1-info.cc rename to src/nnetbin/nnet-info.cc diff --git a/src/nnetbin/nnet-train-mmi-sequential.cc b/src/nnetbin/nnet-train-mmi-sequential.cc index 0d46a2c71..82327819a 100644 --- a/src/nnetbin/nnet-train-mmi-sequential.cc +++ b/src/nnetbin/nnet-train-mmi-sequential.cc @@ -1,6 +1,6 @@ // nnetbin/nnet-train-mmi-sequential.cc -// Copyright 2012-2013 Karel Vesely +// Copyright 2012-2013 Brno University of Technology (author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // @@ -173,9 +173,9 @@ int main(int argc, char *argv[]) { Nnet nnet; nnet.Read(model_filename); // using activations directly: remove softmax, if present - if (nnet.Layer(nnet.LayerCount()-1)->GetType() == Component::kSoftmax) { + if (nnet.GetComponent(nnet.NumComponents()-1).GetType() == Component::kSoftmax) { KALDI_LOG << "Removing softmax from the nnet " << model_filename; - nnet.RemoveLayer(nnet.LayerCount()-1); + nnet.RemoveComponent(nnet.NumComponents()-1); } else { KALDI_LOG << "The nnet was without softmax " << model_filename; } @@ -424,7 +424,7 @@ int main(int argc, char *argv[]) { //add back the softmax KALDI_LOG << "Appending the softmax " << target_model_filename; - nnet.AppendLayer(new Softmax(nnet.OutputDim(),nnet.OutputDim(),&nnet)); + nnet.AppendComponent(new Softmax(nnet.OutputDim(),nnet.OutputDim())); //store the nnet nnet.Write(target_model_filename, binary); diff --git a/src/nnetbin/nnet-train-mpe-sequential.cc b/src/nnetbin/nnet-train-mpe-sequential.cc index 5e0d3082f..527f19ba7 100644 --- a/src/nnetbin/nnet-train-mpe-sequential.cc +++ b/src/nnetbin/nnet-train-mpe-sequential.cc @@ -1,6 +1,6 @@ // nnetbin/nnet-train-mpe-sequential.cc -// Copyright 2011-2013 Karel Vesely; Arnab Ghoshal +// Copyright 2011-2013 Brno University of Technology (author: Karel Vesely); Arnab Ghoshal // See ../../COPYING for clarification regarding multiple authors // @@ -175,9 +175,9 @@ int main(int argc, char *argv[]) { Nnet nnet; nnet.Read(model_filename); // using activations directly: remove softmax, if present - if (nnet.Layer(nnet.LayerCount()-1)->GetType() == Component::kSoftmax) { + if (nnet.GetComponent(nnet.NumComponents()-1).GetType() == Component::kSoftmax) { KALDI_LOG << "Removing softmax from the nnet " << model_filename; - nnet.RemoveLayer(nnet.LayerCount()-1); + nnet.RemoveComponent(nnet.NumComponents()-1); } else { KALDI_LOG << "The nnet was without softmax " << model_filename; } @@ -357,7 +357,7 @@ int main(int argc, char *argv[]) { // add the softmax layer back before writing KALDI_LOG << "Appending the softmax " << target_model_filename; - nnet.AppendLayer(new Softmax(nnet.OutputDim(),nnet.OutputDim(),&nnet)); + nnet.AppendComponent(new Softmax(nnet.OutputDim(),nnet.OutputDim())); //store the nnet nnet.Write(target_model_filename, binary); diff --git a/src/nnetbin/rbm-convert-to-nnet.cc b/src/nnetbin/rbm-convert-to-nnet.cc index 80802c2b5..b5893fd29 100644 --- a/src/nnetbin/rbm-convert-to-nnet.cc +++ b/src/nnetbin/rbm-convert-to-nnet.cc @@ -57,9 +57,9 @@ int main(int argc, char *argv[]) { nnet.Read(ki.Stream(), binary_read); } - KALDI_ASSERT(nnet.LayerCount() == 1); - KALDI_ASSERT(nnet.Layer(0)->GetType() == Component::kRbm); - RbmBase& rbm = dynamic_cast(*nnet.Layer(0)); + KALDI_ASSERT(nnet.NumComponents() == 1); + KALDI_ASSERT(nnet.GetComponent(0).GetType() == Component::kRbm); + RbmBase& rbm = dynamic_cast(nnet.GetComponent(0)); { Output ko(model_out_filename, binary_write); diff --git a/src/nnetbin/rbm-train-cd1-frmshuff.cc b/src/nnetbin/rbm-train-cd1-frmshuff.cc index ee01ddc64..38e689ddc 100644 --- a/src/nnetbin/rbm-train-cd1-frmshuff.cc +++ b/src/nnetbin/rbm-train-cd1-frmshuff.cc @@ -103,9 +103,9 @@ int main(int argc, char *argv[]) { Nnet nnet; nnet.Read(model_filename); - KALDI_ASSERT(nnet.LayerCount()==1); - KALDI_ASSERT(nnet.Layer(0)->GetType() == Component::kRbm); - RbmBase &rbm = dynamic_cast(*nnet.Layer(0)); + KALDI_ASSERT(nnet.NumComponents()==1); + KALDI_ASSERT(nnet.GetComponent(0).GetType() == Component::kRbm); + RbmBase &rbm = dynamic_cast(nnet.GetComponent(0)); // Configure the RBM // first get make some options easy to access: diff --git a/src/nnetbin/transf-to-nnet.cc b/src/nnetbin/transf-to-nnet.cc index 945762e38..88b2f608e 100644 --- a/src/nnetbin/transf-to-nnet.cc +++ b/src/nnetbin/transf-to-nnet.cc @@ -61,7 +61,7 @@ int main(int argc, char *argv[]) { //we will put the transform to the nnet Nnet nnet; //create affine transform layer - AffineTransform* layer = new AffineTransform(transform.NumCols(),transform.NumRows(),&nnet); + AffineTransform* layer = new AffineTransform(transform.NumCols(),transform.NumRows()); //the pointer will be given to the nnet, so we don't need to call delete //convert Matrix to CuMatrix @@ -71,7 +71,7 @@ int main(int argc, char *argv[]) { layer->SetLinearity(cu_transform); //append layer to the nnet - nnet.AppendLayer(layer); + nnet.AppendComponent(layer); //write the nnet {