nnet1: integrating comments from Dan.

This commit is contained in:
vesis84 2016-05-12 02:12:01 +02:00
Родитель 8ab8fb954f
Коммит bffbe48130
34 изменённых файлов: 196 добавлений и 299 удалений

Просмотреть файл

@ -82,7 +82,7 @@ if [ $stage -le 4 ]; then
cnn_dbn=$dir/cnn_dbn.nnet
{ # Concatenate CNN layers and DBN,
num_components=$(nnet-info $feature_transform | grep -m1 num-components | awk '{print $2;}')
nnet-concat "nnet-copy --remove-first-layers=$num_components $feature_transform_dbn - |" $dbn $cnn_dbn \
nnet-concat "nnet-copy --remove-first-components=$num_components $feature_transform_dbn - |" $dbn $cnn_dbn \
2>$dir/log/concat_cnn_dbn.log || exit 1
}
# Train

Просмотреть файл

@ -66,7 +66,7 @@ if [ $stage -le 2 ]; then
# Concat 'feature_transform' with convolutional layers,
dir=exp/cnn4c
nnet-concat $dir/final.feature_transform \
"nnet-copy --remove-last-layers=$(((hid_layers+1)*2)) $dir/final.nnet - |" \
"nnet-copy --remove-last-components=$(((hid_layers+1)*2)) $dir/final.nnet - |" \
$dir/final.feature_transform_cnn
fi
@ -99,7 +99,7 @@ if [ $stage -le 5 ]; then
cnn_dbn=$dir/cnn_dbn.nnet
{ # Concatenate CNN layers and DBN,
num_components=$(nnet-info $feature_transform | grep -m1 num-components | awk '{print $2;}')
cnn="nnet-copy --remove-first-layers=$num_components $feature_transform_dbn - |"
cnn="nnet-copy --remove-first-components=$num_components $feature_transform_dbn - |"
nnet-concat "$cnn" $dbn $cnn_dbn 2>$dir/log/concat_cnn_dbn.log || exit 1
}
# Train

Просмотреть файл

@ -93,7 +93,7 @@ if [ $stage -le 4 ]; then
cnn_dbn=$dir/cnn_dbn.nnet
{ # Concatenate CNN layers and DBN,
num_components=$(nnet-info $feature_transform | grep -m1 num-components | awk '{print $2;}')
cnn="nnet-copy --remove-first-layers=$num_components $feature_transform_dbn - |"
cnn="nnet-copy --remove-first-components=$num_components $feature_transform_dbn - |"
nnet-concat "$cnn" $dbn $cnn_dbn 2>$dir/log/concat_cnn_dbn.log || exit 1
}
# Train

Просмотреть файл

@ -148,7 +148,7 @@ if [ $stage -le 2 ]; then
# Compose feature_transform for 2nd part,
nnet-initialize <(echo "<Splice> <InputDim> $bn1_dim <OutputDim> $((13*bn1_dim)) <BuildVector> -10 -5:5 10 </BuildVector>") \
$dir_part1/splice_for_bottleneck.nnet
nnet-concat $dir_part1/final.feature_transform "nnet-copy --remove-last-layers=4 $dir_part1/final.nnet - |" \
nnet-concat $dir_part1/final.feature_transform "nnet-copy --remove-last-components=4 $dir_part1/final.nnet - |" \
$dir_part1/splice_for_bottleneck.nnet $dir_part1/final.feature_transform.part1
# Train 2nd part,
$cuda_cmd $dir/log/train_nnet.log \

Просмотреть файл

@ -89,7 +89,7 @@ if [ $stage -le 4 ]; then
dir=exp/nnet5b_uc-part1
feature_transform=$dir/final.feature_transform.part1
nnet-concat $dir/final.feature_transform \
"nnet-copy --remove-last-layers=4 --binary=false $dir/final.nnet - |" \
"nnet-copy --remove-last-components=4 --binary=false $dir/final.nnet - |" \
"utils/nnet/gen_splice.py --fea-dim=80 --splice=2 --splice-step=5 |" \
$feature_transform || exit 1

Просмотреть файл

@ -78,7 +78,7 @@ if [ $stage -le 4 ]; then
nnet-initialize <(echo "<Splice> <InputDim> 80 <OutputDim> 1040 <BuildVector> -10 -5:5 10 </BuildVector>") \
$dir/splice_for_bottleneck.nnet
# Concatenate the input-transform, 1st-stage network, splicing,
nnet-concat $dir/final.feature_transform "nnet-copy --remove-last-layers=4 $dir/final.nnet - |" \
nnet-concat $dir/final.feature_transform "nnet-copy --remove-last-components=4 $dir/final.nnet - |" \
$dir/splice_for_bottleneck.nnet $feature_transform
# Train 2nd network, overall context +/-15 frames,

Просмотреть файл

@ -455,7 +455,7 @@ steps/nnet/train_scheduler.sh \
${config:+ --config $config} \
$nnet_init "$feats_tr" "$feats_cv" "$labels_tr" "$labels_cv" $dir
echo "$0: Successfuly finished! '$dir'"
echo "$0: Successfuly finished. '$dir'"
sleep 3
exit 0

Просмотреть файл

@ -233,6 +233,22 @@ inline T CuArray<T>::Max() const {
}
/**
 * Read the array contents from a stream.
 * The values are first parsed into a temporary std::vector by
 * ReadIntegerVector(), and then copied into this array.
 */
template<typename T>
void CuArray<T>::Read(std::istream& in, bool binary) {
  std::vector<T> staging;
  ReadIntegerVector(in, binary, &staging);
  // same effect as assigning the std::vector to '*this',
  this->CopyFromVec(staging);
}
/**
 * Write the array contents to a stream.
 * The elements are copied into a temporary std::vector of matching size,
 * which is then serialized by WriteIntegerVector().
 */
template<typename T>
void CuArray<T>::Write(std::ostream& out, bool binary) const {
  // the temporary is sized to Dim() before the copy,
  std::vector<T> staging(this->Dim());
  this->CopyToVec(&staging);
  WriteIntegerVector(out, binary, staging);
}
/**
* Print the vector to stream
*/
@ -248,23 +264,6 @@ std::ostream &operator << (std::ostream &out, const CuArray<T> &vec) {
return out;
}
/**
 * Wrapper for reading a CuArray from a stream:
 * parse into a temporary std::vector, then copy it into '*vec'.
 */
template<typename T>
void ReadIntegerVector(std::istream& in, bool binary, CuArray<T>* vec) {
  std::vector<T> staging;
  ReadIntegerVector(in, binary, &staging);
  // same effect as assigning the std::vector to '*vec',
  vec->CopyFromVec(staging);
}
/**
 * Wrapper for writing a CuArray to a stream:
 * copy the elements into a temporary std::vector of matching size,
 * then serialize that vector.
 */
template<typename T>
void WriteIntegerVector(std::ostream& out, bool binary, const CuArray<T>& vec) {
  // the temporary is sized to vec.Dim() before the copy,
  std::vector<T> staging(vec.Dim());
  vec.CopyToVec(&staging);
  WriteIntegerVector(out, binary, staging);
}
} // namespace kaldi
#endif

Просмотреть файл

@ -44,7 +44,7 @@ void CuArray<int32>::Set(const int32 &value) {
dim3 dimGrid(n_blocks(Dim(), CU2DBLOCK));
::MatrixDim d = { 1, Dim(), Dim() };
cudaI32_set_const(dimGrid, dimBlock, data_, value, d);
cuda_int32_set_const(dimGrid, dimBlock, data_, value, d);
CU_SAFE_CALL(cudaGetLastError());
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
@ -69,7 +69,7 @@ void CuArray<int32>::Add(const int32 &value) {
dim3 dimGrid(n_blocks(Dim(), CU2DBLOCK));
::MatrixDim d = { 1, Dim(), Dim() };
cudaI32_add(dimGrid, dimBlock, data_, value, d);
cuda_int32_add(dimGrid, dimBlock, data_, value, d);
CU_SAFE_CALL(cudaGetLastError());
CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());

Просмотреть файл

@ -121,6 +121,10 @@ class CuArray {
CuArray<T> &operator= (const std::vector<T> &in) {
this->CopyFromVec(in); return *this;
}
/// I/O
void Read(std::istream &is, bool binary);
void Write(std::ostream &is, bool binary) const;
private:
MatrixIndexT dim_; ///< dimension of the vector
@ -133,17 +137,8 @@ class CuArray {
template<typename T>
std::ostream &operator << (std::ostream &out, const CuArray<T> &vec);
/// Wrapper for reading,
template<typename T>
void ReadIntegerVector(std::istream& in, bool binary, CuArray<T>* vec);
/// Wrapper for writing,
template<typename T>
void WriteIntegerVector(std::ostream& out, bool binary, const CuArray<T>& vec);
} // namespace
#include "cudamatrix/cu-array-inl.h"
#endif

Просмотреть файл

@ -35,8 +35,8 @@ extern "C" {
/*********************************************************
* int32 CUDA kernel calls (no template wrapper)
*/
void cudaI32_set_const(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, MatrixDim d);
void cudaI32_add(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, MatrixDim d);
void cuda_int32_set_const(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, MatrixDim d);
void cuda_int32_add(dim3 Gr, dim3 Bl, int32_cuda *mat, int32_cuda value, MatrixDim d);

Просмотреть файл

@ -2103,10 +2103,10 @@ static void _diff_xent(const int32_cuda* vec_tgt, Real* mat_net_out, Real* vec_l
/*
* "int32"
*/
void cudaI32_set_const(dim3 Gr, dim3 Bl, int32_cuda* mat, int32_cuda value, MatrixDim d) {
void cuda_int32_set_const(dim3 Gr, dim3 Bl, int32_cuda* mat, int32_cuda value, MatrixDim d) {
_set_const<<<Gr,Bl>>>(mat,value,d);
}
void cudaI32_add(dim3 Gr, dim3 Bl, int32_cuda* mat, int32_cuda value, MatrixDim d) {
void cuda_int32_add(dim3 Gr, dim3 Bl, int32_cuda* mat, int32_cuda value, MatrixDim d) {
_add<<<Gr,Bl>>>(mat,value,d);
}

Просмотреть файл

@ -35,20 +35,15 @@ namespace nnet1 {
class Softmax : public Component {
public:
Softmax(int32 dim_in, int32 dim_out) :
Softmax(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~Softmax()
{ }
Component* Copy() const {
return new Softmax(*this);
}
ComponentType GetType() const {
return kSoftmax;
}
Component* Copy() const { return new Softmax(*this); }
ComponentType GetType() const { return kSoftmax; }
void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) {
@ -72,20 +67,15 @@ class Softmax : public Component {
class BlockSoftmax : public Component {
public:
BlockSoftmax(int32 dim_in, int32 dim_out) :
BlockSoftmax(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~BlockSoftmax()
{ }
Component* Copy() const {
return new BlockSoftmax(*this);
}
ComponentType GetType() const {
return kBlockSoftmax;
}
Component* Copy() const { return new BlockSoftmax(*this); }
ComponentType GetType() const { return kBlockSoftmax; }
void InitData(std::istream &is) {
// parse config
@ -175,20 +165,15 @@ class BlockSoftmax : public Component {
class Sigmoid : public Component {
public:
Sigmoid(int32 dim_in, int32 dim_out) :
Sigmoid(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~Sigmoid()
{ }
Component* Copy() const {
return new Sigmoid(*this);
}
ComponentType GetType() const {
return kSigmoid;
}
Component* Copy() const { return new Sigmoid(*this); }
ComponentType GetType() const { return kSigmoid; }
void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) {
@ -209,20 +194,15 @@ class Sigmoid : public Component {
class Tanh : public Component {
public:
Tanh(int32 dim_in, int32 dim_out) :
Tanh(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~Tanh()
{ }
Component* Copy() const {
return new Tanh(*this);
}
ComponentType GetType() const {
return kTanh;
}
Component* Copy() const { return new Tanh(*this); }
ComponentType GetType() const { return kTanh; }
void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) {
@ -243,7 +223,7 @@ class Tanh : public Component {
class Dropout : public Component {
public:
Dropout(int32 dim_in, int32 dim_out) :
Dropout(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
dropout_retention_(0.5)
{ }
@ -251,13 +231,8 @@ class Dropout : public Component {
~Dropout()
{ }
Component* Copy() const {
return new Dropout(*this);
}
ComponentType GetType() const {
return kDropout;
}
Component* Copy() const { return new Dropout(*this); }
ComponentType GetType() const { return kDropout; }
void InitData(std::istream &is) {
is >> std::ws; // eat-up whitespace
@ -308,9 +283,7 @@ class Dropout : public Component {
in_diff->Scale(1.0/dropout_retention_);
}
BaseFloat GetDropoutRetention() {
return dropout_retention_;
}
BaseFloat GetDropoutRetention() { return dropout_retention_; }
void SetDropoutRetention(BaseFloat dr) {
dropout_retention_ = dr;

Просмотреть файл

@ -32,11 +32,11 @@ namespace nnet1 {
class AffineTransform : public UpdatableComponent {
public:
AffineTransform(int32 dim_in, int32 dim_out)
: UpdatableComponent(dim_in, dim_out),
linearity_(dim_out, dim_in), bias_(dim_out),
linearity_corr_(dim_out, dim_in), bias_corr_(dim_out),
max_norm_(0.0)
AffineTransform(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
linearity_(dim_out, dim_in), bias_(dim_out),
linearity_corr_(dim_out, dim_in), bias_corr_(dim_out),
max_norm_(0.0)
{ }
~AffineTransform()
{ }
@ -216,18 +216,14 @@ class AffineTransform : public UpdatableComponent {
}
/// Accessors to the component parameters,
const CuVectorBase<BaseFloat>& GetBias() const {
return bias_;
}
const CuVectorBase<BaseFloat>& GetBias() const { return bias_; }
/// Overwrite the bias vector with a copy of 'bias',
/// the dimension of 'bias' is asserted to match the current bias dimension.
void SetBias(const CuVectorBase<BaseFloat>& bias) {
KALDI_ASSERT(bias.Dim() == bias_.Dim());
bias_.CopyFromVec(bias);
}
const CuMatrixBase<BaseFloat>& GetLinearity() const {
return linearity_;
}
const CuMatrixBase<BaseFloat>& GetLinearity() const { return linearity_; }
void SetLinearity(const CuMatrixBase<BaseFloat>& linearity) {
KALDI_ASSERT(linearity.NumRows() == linearity_.NumRows());

Просмотреть файл

@ -40,11 +40,11 @@ namespace nnet1 {
*/
class AveragePooling2DComponent : public Component {
public:
AveragePooling2DComponent(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out),
fmap_x_len_(0), fmap_y_len_(0),
pool_x_len_(0), pool_y_len_(0),
pool_x_step_(0), pool_y_step_(0)
AveragePooling2DComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
fmap_x_len_(0), fmap_y_len_(0),
pool_x_len_(0), pool_y_len_(0),
pool_x_step_(0), pool_y_step_(0)
{ }
~AveragePooling2DComponent()
{ }
@ -199,8 +199,8 @@ class AveragePooling2DComponent : public Component {
private:
int32 fmap_x_len_, fmap_y_len_,
pool_x_len_, pool_y_len_,
pool_x_step_, pool_y_step_;
pool_x_len_, pool_y_len_,
pool_x_step_, pool_y_step_;
};
} // namespace nnet1

Просмотреть файл

@ -39,9 +39,13 @@ namespace nnet1 {
*/
class AveragePoolingComponent : public Component {
public:
AveragePoolingComponent(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out), pool_size_(0), pool_step_(0), pool_stride_(0)
AveragePoolingComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
pool_size_(0),
pool_step_(0),
pool_stride_(0)
{ }
~AveragePoolingComponent()
{ }

Просмотреть файл

@ -49,7 +49,7 @@ namespace nnet1 {
class BLstmProjectedStreams : public UpdatableComponent {
public:
BLstmProjectedStreams(int32 input_dim, int32 output_dim) :
BLstmProjectedStreams(int32 input_dim, int32 output_dim):
UpdatableComponent(input_dim, output_dim),
ncell_(0),
nrecur_(static_cast<int32>(output_dim/2)),
@ -61,17 +61,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
~BLstmProjectedStreams()
{ }
Component* Copy() const {
return new BLstmProjectedStreams(*this);
}
ComponentType GetType() const {
return kBLstmProjectedStreams;
}
Component* Copy() const { return new BLstmProjectedStreams(*this); }
ComponentType GetType() const { return kBLstmProjectedStreams; }
/// set the utterance length used for parallel training
void SetSeqLengths(const std::vector<int32> &sequence_lengths) {
sequence_lengths_ = sequence_lengths;
sequence_lengths_ = sequence_lengths;
}
void InitData(std::istream &is) {

Просмотреть файл

@ -101,7 +101,7 @@ class Component {
/// Generic interface of a component,
public:
Component(int32 input_dim, int32 output_dim) :
Component(int32 input_dim, int32 output_dim):
input_dim_(input_dim),
output_dim_(output_dim)
{ }
@ -201,7 +201,7 @@ class Component {
*/
class UpdatableComponent : public Component {
public:
UpdatableComponent(int32 input_dim, int32 output_dim) :
UpdatableComponent(int32 input_dim, int32 output_dim):
Component(input_dim, output_dim),
learn_rate_coef_(1.0),
bias_learn_rate_coef_(1.0)

Просмотреть файл

@ -71,7 +71,7 @@ namespace nnet1 {
*/
class Convolutional2DComponent : public UpdatableComponent {
public:
Convolutional2DComponent(int32 dim_in, int32 dim_out) :
Convolutional2DComponent(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
fmap_x_len_(0), fmap_y_len_(0),
filt_x_len_(0), filt_y_len_(0),
@ -82,13 +82,8 @@ class Convolutional2DComponent : public UpdatableComponent {
~Convolutional2DComponent()
{ }
Component* Copy() const {
return new Convolutional2DComponent(*this);
}
ComponentType GetType() const {
return kConvolutional2DComponent;
}
Component* Copy() const { return new Convolutional2DComponent(*this); }
ComponentType GetType() const { return kConvolutional2DComponent; }
void InitData(std::istream &is) {
// define options

Просмотреть файл

@ -65,7 +65,7 @@ namespace nnet1 {
*/
class ConvolutionalComponent : public UpdatableComponent {
public:
ConvolutionalComponent(int32 dim_in, int32 dim_out) :
ConvolutionalComponent(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
patch_dim_(0),
patch_step_(0),
@ -76,13 +76,8 @@ class ConvolutionalComponent : public UpdatableComponent {
~ConvolutionalComponent()
{ }
Component* Copy() const {
return new ConvolutionalComponent(*this);
}
ComponentType GetType() const {
return kConvolutionalComponent;
}
Component* Copy() const { return new ConvolutionalComponent(*this); }
ComponentType GetType() const { return kConvolutionalComponent; }
void InitData(std::istream &is) {
// define options

Просмотреть файл

@ -42,7 +42,7 @@ namespace nnet1 {
*/
class FramePoolingComponent : public UpdatableComponent {
public:
FramePoolingComponent(int32 dim_in, int32 dim_out) :
FramePoolingComponent(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
feature_dim_(0),
normalize_(false)
@ -51,13 +51,8 @@ class FramePoolingComponent : public UpdatableComponent {
~FramePoolingComponent()
{ }
Component* Copy() const {
return new FramePoolingComponent(*this);
}
ComponentType GetType() const {
return kFramePoolingComponent;
}
Component* Copy() const { return new FramePoolingComponent(*this); }
ComponentType GetType() const { return kFramePoolingComponent; }
/**
* Here the offsets are w.r.t. central frames, which has offset 0.

Просмотреть файл

@ -36,21 +36,16 @@ namespace nnet1 {
class KlHmm : public Component {
public:
KlHmm(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out),
kl_stats_(dim_out, dim_in, kSetZero)
KlHmm(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
kl_stats_(dim_out, dim_in, kSetZero)
{ }
~KlHmm()
{ }
Component* Copy() const {
return new KlHmm(*this);
}
ComponentType GetType() const {
return kKlHmm;
}
Component* Copy() const { return new KlHmm(*this); }
ComponentType GetType() const { return kKlHmm; }
void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) {
@ -153,9 +148,6 @@ class KlHmm : public Component {
CuMatrix<BaseFloat> kl_inv_q_;
};
} // namespace nnet1
} // namespace kaldi

Просмотреть файл

@ -32,7 +32,7 @@ namespace nnet1 {
class LinearTransform : public UpdatableComponent {
public:
LinearTransform(int32 dim_in, int32 dim_out) :
LinearTransform(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
linearity_(dim_out, dim_in),
linearity_corr_(dim_out, dim_in)
@ -41,13 +41,8 @@ class LinearTransform : public UpdatableComponent {
~LinearTransform()
{ }
Component* Copy() const {
return new LinearTransform(*this);
}
ComponentType GetType() const {
return kLinearTransform;
}
Component* Copy() const { return new LinearTransform(*this); }
ComponentType GetType() const { return kLinearTransform; }
void InitData(std::istream &is) {
// define options
@ -196,9 +191,7 @@ class LinearTransform : public UpdatableComponent {
}
/// Accessors to the component parameters
const CuMatrixBase<BaseFloat>& GetLinearity() {
return linearity_;
}
const CuMatrixBase<BaseFloat>& GetLinearity() { return linearity_; }
void SetLinearity(const CuMatrixBase<BaseFloat>& linearity) {
KALDI_ASSERT(linearity.NumRows() == linearity_.NumRows());
@ -206,9 +199,7 @@ class LinearTransform : public UpdatableComponent {
linearity_.CopyFromMat(linearity);
}
const CuMatrixBase<BaseFloat>& GetLinearityCorr() {
return linearity_corr_;
}
const CuMatrixBase<BaseFloat>& GetLinearityCorr() { return linearity_corr_; }
private:
CuMatrix<BaseFloat> linearity_;

Просмотреть файл

@ -61,7 +61,7 @@ class LossItf {
class Xent : public LossItf {
public:
Xent() :
Xent():
frames_progress_(0.0),
xentropy_progress_(0.0),
entropy_progress_(0.0)
@ -124,9 +124,15 @@ class Xent : public LossItf {
class Mse : public LossItf {
public:
Mse() : frames_(0.0), loss_(0.0),
frames_progress_(0.0), loss_progress_(0.0) { }
~Mse() { }
Mse():
frames_(0.0),
loss_(0.0),
frames_progress_(0.0),
loss_progress_(0.0)
{ }
~Mse()
{ }
/// Evaluate mean square error using target-matrix,
void Eval(const VectorBase<BaseFloat> &frame_weights,
@ -164,7 +170,9 @@ class Mse : public LossItf {
class MultiTaskLoss : public LossItf {
public:
MultiTaskLoss() { }
MultiTaskLoss()
{ }
~MultiTaskLoss() {
while (loss_vec_.size() > 0) {
delete loss_vec_.back();

Просмотреть файл

@ -48,7 +48,7 @@ namespace nnet1 {
class LstmProjectedStreams : public UpdatableComponent {
public:
LstmProjectedStreams(int32 input_dim, int32 output_dim) :
LstmProjectedStreams(int32 input_dim, int32 output_dim):
UpdatableComponent(input_dim, output_dim),
ncell_(0),
nrecur_(output_dim),
@ -60,13 +60,8 @@ class LstmProjectedStreams : public UpdatableComponent {
~LstmProjectedStreams()
{ }
Component* Copy() const {
return new LstmProjectedStreams(*this);
}
ComponentType GetType() const {
return kLstmProjectedStreams;
}
Component* Copy() const { return new LstmProjectedStreams(*this); }
ComponentType GetType() const { return kLstmProjectedStreams; }
void InitData(std::istream &is) {
// define options,

Просмотреть файл

@ -40,22 +40,18 @@ namespace nnet1 {
*/
class MaxPooling2DComponent : public Component {
public:
MaxPooling2DComponent(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out),
fmap_x_len_(0), fmap_y_len_(0),
pool_x_len_(0), pool_y_len_(0), pool_x_step_(0), pool_y_step_(0)
MaxPooling2DComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
fmap_x_len_(0), fmap_y_len_(0),
pool_x_len_(0), pool_y_len_(0),
pool_x_step_(0), pool_y_step_(0)
{ }
~MaxPooling2DComponent()
{ }
Component* Copy() const {
return new MaxPooling2DComponent(*this);
}
ComponentType GetType() const {
return kMaxPooling2DComponent;
}
Component* Copy() const { return new MaxPooling2DComponent(*this); }
ComponentType GetType() const { return kMaxPooling2DComponent; }
void InitData(std::istream &is) {
// parse config
@ -219,8 +215,8 @@ class MaxPooling2DComponent : public Component {
private:
int32 fmap_x_len_, fmap_y_len_,
pool_x_len_, pool_y_len_,
pool_x_step_, pool_y_step_;
pool_x_len_, pool_y_len_,
pool_x_step_, pool_y_step_;
};
} // namespace nnet1

Просмотреть файл

@ -39,20 +39,18 @@ namespace nnet1 {
*/
class MaxPoolingComponent : public Component {
public:
MaxPoolingComponent(int32 dim_in, int32 dim_out)
: Component(dim_in, dim_out), pool_size_(0), pool_step_(0), pool_stride_(0)
MaxPoolingComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
pool_size_(0),
pool_step_(0),
pool_stride_(0)
{ }
~MaxPoolingComponent()
{ }
Component* Copy() const {
return new MaxPoolingComponent(*this);
}
ComponentType GetType() const {
return kMaxPoolingComponent;
}
Component* Copy() const { return new MaxPoolingComponent(*this); }
ComponentType GetType() const { return kMaxPoolingComponent; }
void InitData(std::istream &is) {
// parse config

Просмотреть файл

@ -35,28 +35,18 @@ namespace nnet1 {
class ParallelComponent : public UpdatableComponent {
public:
ParallelComponent(int32 dim_in, int32 dim_out)
: UpdatableComponent(dim_in, dim_out)
ParallelComponent(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out)
{ }
~ParallelComponent()
{ }
Component* Copy() const {
return new ParallelComponent(*this);
}
Component* Copy() const { return new ParallelComponent(*this); }
ComponentType GetType() const { return kParallelComponent; }
ComponentType GetType() const {
return kParallelComponent;
}
const Nnet& GetNestedNnet(int32 id) const {
return nnet_.at(id);
}
Nnet& GetNestedNnet(int32 id) {
return nnet_.at(id);
}
const Nnet& GetNestedNnet(int32 id) const { return nnet_.at(id); }
Nnet& GetNestedNnet(int32 id) { return nnet_.at(id); }
void InitData(std::istream &is) {
// define options

Просмотреть файл

@ -37,9 +37,11 @@ struct PdfPriorOptions {
BaseFloat prior_scale;
BaseFloat prior_floor;
PdfPriorOptions() : class_frame_counts(""),
prior_scale(1.0),
prior_floor(1e-10) {}
PdfPriorOptions():
class_frame_counts(""),
prior_scale(1.0),
prior_floor(1e-10)
{ }
void Register(OptionsItf *opts) {
opts->Register("class-frame-counts", &class_frame_counts,

Просмотреть файл

@ -40,7 +40,7 @@ struct NnetDataRandomizerOptions {
int32 randomizer_seed;
int32 minibatch_size;
NnetDataRandomizerOptions() :
NnetDataRandomizerOptions():
randomizer_size(32768),
randomizer_seed(777),
minibatch_size(256)
@ -86,12 +86,12 @@ class RandomizerMask {
*/
class MatrixRandomizer {
public:
MatrixRandomizer() :
MatrixRandomizer():
data_begin_(0),
data_end_(0)
{ }
explicit MatrixRandomizer(const NnetDataRandomizerOptions &conf) :
explicit MatrixRandomizer(const NnetDataRandomizerOptions &conf):
data_begin_(0),
data_end_(0)
{
@ -147,12 +147,12 @@ class MatrixRandomizer {
/// Randomizes elements of a vector according to a mask
class VectorRandomizer {
public:
VectorRandomizer() :
VectorRandomizer():
data_begin_(0),
data_end_(0)
{ }
explicit VectorRandomizer(const NnetDataRandomizerOptions &conf) :
explicit VectorRandomizer(const NnetDataRandomizerOptions &conf):
data_begin_(0),
data_end_(0)
{
@ -208,12 +208,12 @@ class VectorRandomizer {
template<typename T>
class StdVectorRandomizer {
public:
StdVectorRandomizer() :
StdVectorRandomizer():
data_begin_(0),
data_end_(0)
{ }
explicit StdVectorRandomizer(const NnetDataRandomizerOptions &conf) :
explicit StdVectorRandomizer(const NnetDataRandomizerOptions &conf):
data_begin_(0),
data_end_(0)
{

Просмотреть файл

@ -39,7 +39,7 @@ class RbmBase : public Component {
Gaussian
} RbmNodeType;
RbmBase(int32 dim_in, int32 dim_out) :
RbmBase(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
@ -95,7 +95,7 @@ class RbmBase : public Component {
class Rbm : public RbmBase {
public:
Rbm(int32 dim_in, int32 dim_out) :
Rbm(int32 dim_in, int32 dim_out):
RbmBase(dim_in, dim_out)
{ }

Просмотреть файл

@ -41,7 +41,7 @@ namespace nnet1 {
*/
class SimpleSentenceAveragingComponent : public Component {
public:
SimpleSentenceAveragingComponent(int32 dim_in, int32 dim_out) :
SimpleSentenceAveragingComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out),
gradient_boost_(100.0),
shrinkage_(0.0),
@ -179,8 +179,8 @@ class SimpleSentenceAveragingComponent : public Component {
/** Deprecated!!!, keeping it as Katka Zmolikova used it in JSALT 2015 */
class SentenceAveragingComponent : public UpdatableComponent {
public:
SentenceAveragingComponent(int32 dim_in, int32 dim_out)
: UpdatableComponent(dim_in, dim_out), learn_rate_factor_(100.0)
SentenceAveragingComponent(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out), learn_rate_factor_(100.0)
{ }
~SentenceAveragingComponent()
{ }

Просмотреть файл

@ -33,12 +33,15 @@ struct NnetTrainOptions {
BaseFloat momentum;
BaseFloat l2_penalty;
BaseFloat l1_penalty;
// default values
NnetTrainOptions() : learn_rate(0.008),
momentum(0.0),
l2_penalty(0.0),
l1_penalty(0.0)
{ }
NnetTrainOptions():
learn_rate(0.008),
momentum(0.0),
l2_penalty(0.0),
l1_penalty(0.0)
{ }
// register options
void Register(OptionsItf *opts) {
opts->Register("learn-rate", &learn_rate, "Learning rate");
@ -46,6 +49,7 @@ struct NnetTrainOptions {
opts->Register("l2-penalty", &l2_penalty, "L2 penalty (weight decay)");
opts->Register("l1-penalty", &l1_penalty, "L1 penalty (promote sparsity)");
}
// print for debug purposes
friend std::ostream& operator<<(std::ostream& os, const NnetTrainOptions& opts) {
os << "RbmTrainOptions : "
@ -66,15 +70,18 @@ struct RbmTrainOptions {
int32 momentum_steps;
int32 momentum_step_period;
BaseFloat l2_penalty;
// default values
RbmTrainOptions() : learn_rate(0.4),
momentum(0.5),
momentum_max(0.9),
momentum_steps(40),
momentum_step_period(500000),
// 500000 * 40 = 55h of linear increase of momentum
l2_penalty(0.0002)
{ }
RbmTrainOptions():
learn_rate(0.4),
momentum(0.5),
momentum_max(0.9),
momentum_steps(40),
momentum_step_period(500000),
// 500000 * 40 = 55h of linear increase of momentum
l2_penalty(0.0002)
{ }
// register options
void Register(OptionsItf *opts) {
opts->Register("learn-rate", &learn_rate, "Learning rate");
@ -91,6 +98,7 @@ struct RbmTrainOptions {
opts->Register("l2-penalty", &l2_penalty,
"L2 penalty (weight decay, increases mixing-rate)");
}
// print for debug purposes
friend std::ostream& operator<<(std::ostream& os, const RbmTrainOptions& opts) {
os << "RbmTrainOptions : "

Просмотреть файл

@ -41,20 +41,15 @@ namespace nnet1 {
*/
class Splice: public Component {
public:
Splice(int32 dim_in, int32 dim_out) :
Splice(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~Splice()
{ }
Component* Copy() const {
return new Splice(*this);
}
ComponentType GetType() const {
return kSplice;
}
Component* Copy() const { return new Splice(*this); }
ComponentType GetType() const { return kSplice; }
void InitData(std::istream &is) {
// define options,
@ -64,7 +59,7 @@ class Splice: public Component {
while (is >> std::ws, !is.eof()) {
ReadToken(is, false, &token);
/**/ if (token == "<ReadVector>") {
ReadIntegerVector(is, false, &frame_offsets_);
frame_offsets_.Read(is, false);
} else if (token == "<BuildVector>") {
// Parse the list of 'matlab-like' indices:
// <BuildVector> 1:1:1000 1 2 3 1:10 </BuildVector>
@ -92,12 +87,12 @@ class Splice: public Component {
}
void ReadData(std::istream &is, bool binary) {
ReadIntegerVector(is, binary, &frame_offsets_);
frame_offsets_.Read(is, binary);
KALDI_ASSERT(frame_offsets_.Dim() * InputDim() == OutputDim());
}
void WriteData(std::ostream &os, bool binary) const {
WriteIntegerVector(os, binary, frame_offsets_);
frame_offsets_.Write(os, binary);
}
std::string Info() const {
@ -131,20 +126,15 @@ class Splice: public Component {
*/
class CopyComponent: public Component {
public:
CopyComponent(int32 dim_in, int32 dim_out) :
CopyComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~CopyComponent()
{ }
Component* Copy() const {
return new CopyComponent(*this);
}
ComponentType GetType() const {
return kCopy;
}
Component* Copy() const { return new CopyComponent(*this); }
ComponentType GetType() const { return kCopy; }
void InitData(std::istream &is) {
// define options,
@ -154,7 +144,7 @@ class CopyComponent: public Component {
while (is >> std::ws, !is.eof()) {
ReadToken(is, false, &token);
/**/ if (token == "<ReadVector>") {
ReadIntegerVector(is, false, &copy_from_indices_);
copy_from_indices_.Read(is, false);
} else if (token == "<BuildVector>") {
// <BuildVector> 1:1:1000 1:1:1000 1 2 3 1:10 </BuildVector>
// 'matlab-like' indexing, read the colon-separated lists:
@ -188,7 +178,7 @@ class CopyComponent: public Component {
}
void ReadData(std::istream &is, bool binary) {
ReadIntegerVector(is, binary, &copy_from_indices_);
copy_from_indices_.Read(is, binary);
KALDI_ASSERT(copy_from_indices_.Dim() == OutputDim());
copy_from_indices_.Add(-1); // -1 from each element,
}
@ -196,7 +186,7 @@ class CopyComponent: public Component {
void WriteData(std::ostream &os, bool binary) const {
CuArray<int32> tmp(copy_from_indices_);
tmp.Add(1); // +1 to each element,
WriteIntegerVector(os, binary, tmp);
tmp.Write(os, binary);
}
std::string Info() const {
@ -234,20 +224,15 @@ class CopyComponent: public Component {
*/
class LengthNormComponent: public Component {
public:
LengthNormComponent(int32 dim_in, int32 dim_out) :
LengthNormComponent(int32 dim_in, int32 dim_out):
Component(dim_in, dim_out)
{ }
~LengthNormComponent()
{ }
Component* Copy() const {
return new LengthNormComponent(*this);
}
ComponentType GetType() const {
return kLengthNormComponent;
}
Component* Copy() const { return new LengthNormComponent(*this); }
ComponentType GetType() const { return kLengthNormComponent; }
void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) {
@ -286,20 +271,16 @@ class LengthNormComponent: public Component {
*/
class AddShift : public UpdatableComponent {
public:
AddShift(int32 dim_in, int32 dim_out) :
UpdatableComponent(dim_in, dim_out), shift_data_(dim_in)
AddShift(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
shift_data_(dim_in)
{ }
~AddShift()
{ }
Component* Copy() const {
return new AddShift(*this);
}
ComponentType GetType() const {
return kAddShift;
}
Component* Copy() const { return new AddShift(*this); }
ComponentType GetType() const { return kAddShift; }
void InitData(std::istream &is) {
// define options
@ -389,9 +370,7 @@ class AddShift : public UpdatableComponent {
shift_data_.AddVec(-lr * learn_rate_coef_, shift_data_grad_);
}
void SetLearnRateCoef(float c) {
learn_rate_coef_ = c;
}
void SetLearnRateCoef(float c) { learn_rate_coef_ = c; }
protected:
CuVector<BaseFloat> shift_data_;
@ -405,7 +384,7 @@ class AddShift : public UpdatableComponent {
*/
class Rescale : public UpdatableComponent {
public:
Rescale(int32 dim_in, int32 dim_out) :
Rescale(int32 dim_in, int32 dim_out):
UpdatableComponent(dim_in, dim_out),
scale_data_(dim_in)
{ }
@ -413,13 +392,8 @@ class Rescale : public UpdatableComponent {
~Rescale()
{ }
Component* Copy() const {
return new Rescale(*this);
}
ComponentType GetType() const {
return kRescale;
}
Component* Copy() const { return new Rescale(*this); }
ComponentType GetType() const { return kRescale; }
void InitData(std::istream &is) {
// define options
@ -512,17 +486,13 @@ class Rescale : public UpdatableComponent {
scale_data_.AddVec(-lr * learn_rate_coef_, scale_data_grad_);
}
void SetLearnRateCoef(float c) {
learn_rate_coef_ = c;
}
void SetLearnRateCoef(float c) { learn_rate_coef_ = c; }
protected:
CuVector<BaseFloat> scale_data_;
CuVector<BaseFloat> scale_data_grad_;
};
} // namespace nnet1
} // namespace kaldi