Added extra argument to a couple of matrix-library functions.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@1363 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
This commit is contained in:
Dan Povey 2012-09-18 05:09:15 +00:00
Родитель 45e4d2f0a0
Коммит 8163148b37
20 изменённых файлов: 232 добавлений и 171 удалений

Просмотреть файл

@ -310,7 +310,7 @@ void CuVector<Real>::AddRowSumMat(Real alpha, const CuMatrix<Real> &mat, Real be
#endif #endif
{ {
Vector<Real> tmp(mat.NumCols()); Vector<Real> tmp(mat.NumCols());
tmp.AddRowSumMat(mat.Mat()); tmp.AddRowSumMat(1.0, mat.Mat());
if(beta != 1.0) vec_.Scale(beta); if(beta != 1.0) vec_.Scale(beta);
vec_.AddVec(alpha, tmp); vec_.AddVec(alpha, tmp);
} }
@ -360,7 +360,7 @@ void CuVector<Real>::AddColSumMat(Real alpha, const CuMatrix<Real> &mat, Real be
#endif #endif
{ {
Vector<Real> tmp(mat.NumRows()); Vector<Real> tmp(mat.NumRows());
tmp.AddColSumMat(mat.Mat()); tmp.AddColSumMat(1.0, mat.Mat());
if(beta != 1.0) vec_.Scale(beta); if(beta != 1.0) vec_.Scale(beta);
vec_.AddVec(alpha,tmp); vec_.AddVec(alpha,tmp);
} }

Просмотреть файл

@ -153,7 +153,7 @@ int main(int argc, char *argv[]) {
} }
if (subtract_mean) { if (subtract_mean) {
Vector<BaseFloat> mean(features.NumCols()); Vector<BaseFloat> mean(features.NumCols());
mean.AddRowSumMat(features); mean.AddRowSumMat(1.0, features);
mean.Scale(1.0 / features.NumRows()); mean.Scale(1.0 / features.NumRows());
for (int32 i = 0; i < features.NumRows(); i++) for (int32 i = 0; i < features.NumRows(); i++)
features.Row(i).AddVec(-1.0, mean); features.Row(i).AddVec(-1.0, mean);

Просмотреть файл

@ -153,7 +153,7 @@ int main(int argc, char *argv[]) {
} }
if (subtract_mean) { if (subtract_mean) {
Vector<BaseFloat> mean(features.NumCols()); Vector<BaseFloat> mean(features.NumCols());
mean.AddRowSumMat(features); mean.AddRowSumMat(1.0, features);
mean.Scale(1.0 / features.NumRows()); mean.Scale(1.0 / features.NumRows());
for (int32 i = 0; i < features.NumRows(); i++) for (int32 i = 0; i < features.NumRows(); i++)
features.Row(i).AddVec(-1.0, mean); features.Row(i).AddVec(-1.0, mean);

Просмотреть файл

@ -153,7 +153,7 @@ int main(int argc, char *argv[]) {
} }
if (subtract_mean) { if (subtract_mean) {
Vector<BaseFloat> mean(features.NumCols()); Vector<BaseFloat> mean(features.NumCols());
mean.AddRowSumMat(features); mean.AddRowSumMat(1.0, features);
mean.Scale(1.0 / features.NumRows()); mean.Scale(1.0 / features.NumRows());
for (size_t i = 0; i < features.NumRows(); i++) for (size_t i = 0; i < features.NumRows(); i++)
features.Row(i).AddVec(-1.0, mean); features.Row(i).AddVec(-1.0, mean);

Просмотреть файл

@ -116,7 +116,7 @@ int main(int argc, char *argv[]) {
} }
if (subtract_mean) { if (subtract_mean) {
Vector<BaseFloat> mean(features.NumCols()); Vector<BaseFloat> mean(features.NumCols());
mean.AddRowSumMat(features); mean.AddRowSumMat(1.0, features);
mean.Scale(1.0 / features.NumRows()); mean.Scale(1.0 / features.NumRows());
for (int32 i = 0; i < features.NumRows(); i++) for (int32 i = 0; i < features.NumRows(); i++)
features.Row(i).AddVec(-1.0, mean); features.Row(i).AddVec(-1.0, mean);

Просмотреть файл

@ -52,7 +52,7 @@ int main(int argc, char *argv[]) {
continue; continue;
} }
Vector<BaseFloat> mean(feats.NumCols()); Vector<BaseFloat> mean(feats.NumCols());
mean.AddRowSumMat(feats); mean.AddRowSumMat(1.0, feats);
mean.Scale(1.0 / feats.NumRows()); mean.Scale(1.0 / feats.NumRows());
for (int32 i = 0; i < feats.NumRows(); i++) for (int32 i = 0; i < feats.NumRows(); i++)
feats.Row(i).AddVec(-1.0, mean); feats.Row(i).AddVec(-1.0, mean);

Просмотреть файл

@ -309,7 +309,7 @@ UnitTestEstimateFullGmm() {
Vector<BaseFloat> mean(dim); Vector<BaseFloat> mean(dim);
cov.AddMatMat(1.0, feats, kTrans, feats, kNoTrans, 0.0); cov.AddMatMat(1.0, feats, kTrans, feats, kNoTrans, 0.0);
cov.Scale(1.0 / feats.NumRows()); cov.Scale(1.0 / feats.NumRows());
mean.AddRowSumMat(feats); mean.AddRowSumMat(1.0, feats);
mean.Scale(1.0 / feats.NumRows()); mean.Scale(1.0 / feats.NumRows());
cov.AddVecVec(-1.0, mean, mean); cov.AddVecVec(-1.0, mean, mean);
BaseFloat logdet = cov.LogDet(); BaseFloat logdet = cov.LogDet();

Просмотреть файл

@ -763,8 +763,7 @@ void MatrixBase<double>::CopyFromTp(const TpMatrix<double> & M,
template<typename Real> template<typename Real>
void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<Real> &rv) { void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<Real> &rv) {
KALDI_ASSERT(rv.Dim() == num_rows_*num_cols_); if (rv.Dim() == num_rows_*num_cols_) {
if (stride_ == num_cols_) { if (stride_ == num_cols_) {
// one big copy operation. // one big copy operation.
const Real *rv_data = rv.Data(); const Real *rv_data = rv.Data();
@ -779,12 +778,19 @@ void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<Real> &rv) {
rv_data += num_cols_; rv_data += num_cols_;
} }
} }
} else if (rv.Dim() == num_cols_) {
const Real *rv_data = rv.Data();
for (MatrixIndexT r = 0; r < num_rows_; r++)
std::memcpy(RowData(r), rv_data, sizeof(Real)*num_cols_);
} else {
KALDI_ERR << "Wrong sized arguments";
}
} }
template<typename Real> template<typename Real>
template<typename OtherReal> template<typename OtherReal>
void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<OtherReal> &rv) { void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<OtherReal> &rv) {
KALDI_ASSERT(rv.Dim() == num_rows_*num_cols_); if (rv.Dim() == num_rows_*num_cols_) {
const OtherReal *rv_data = rv.Data(); const OtherReal *rv_data = rv.Data();
for (MatrixIndexT r = 0; r < num_rows_; r++) { for (MatrixIndexT r = 0; r < num_rows_; r++) {
Real *row_data = RowData(r); Real *row_data = RowData(r);
@ -793,7 +799,18 @@ void MatrixBase<Real>::CopyRowsFromVec(const VectorBase<OtherReal> &rv) {
} }
rv_data += num_cols_; rv_data += num_cols_;
} }
} else if (rv.Dim() == num_cols_) {
const OtherReal *rv_data = rv.Data();
Real *first_row_data = RowData(0);
for (MatrixIndexT c = 0; c < num_cols_; c++)
first_row_data[c] = rv_data[c];
for (MatrixIndexT r = 1; r < num_rows_; r++)
std::memcpy(RowData(r), first_row_data, sizeof(Real)*num_cols_);
} else {
KALDI_ERR << "Wrong sized arguments.";
} }
}
template template
void MatrixBase<float>::CopyRowsFromVec(const VectorBase<double> &rv); void MatrixBase<float>::CopyRowsFromVec(const VectorBase<double> &rv);
@ -802,8 +819,7 @@ void MatrixBase<double>::CopyRowsFromVec(const VectorBase<float> &rv);
template<typename Real> template<typename Real>
void MatrixBase<Real>::CopyColsFromVec(const VectorBase<Real> &rv) { void MatrixBase<Real>::CopyColsFromVec(const VectorBase<Real> &rv) {
KALDI_ASSERT(rv.Dim() == num_rows_*num_cols_); if (rv.Dim() == num_rows_*num_cols_) {
const Real *v_inc_data = rv.Data(); const Real *v_inc_data = rv.Data();
Real *m_inc_data = data_; Real *m_inc_data = data_;
@ -814,6 +830,18 @@ void MatrixBase<Real>::CopyColsFromVec(const VectorBase<Real> &rv) {
v_inc_data += num_rows_; v_inc_data += num_rows_;
m_inc_data ++; m_inc_data ++;
} }
} else if (rv.Dim() == num_rows_) {
const Real *v_inc_data = rv.Data();
Real *m_inc_data = data_;
for (MatrixIndexT r = 0; r < num_rows_; r++) {
BaseFloat value = *(v_inc_data++);
for (MatrixIndexT c = 0; c < num_cols_; c++)
m_inc_data[c] = value;
m_inc_data += stride_;
}
} else {
KALDI_ERR << "Wrong size of arguments.";
}
} }

Просмотреть файл

@ -130,14 +130,17 @@ class MatrixBase {
MatrixTransposeType Trans = kNoTrans); MatrixTransposeType Trans = kNoTrans);
/// Inverse of vec() operator. Copies vector into matrix, row-by-row. /// Inverse of vec() operator. Copies vector into matrix, row-by-row.
/// Note that rv.Dim() must equal NumRows()*NumCols(). /// Note that rv.Dim() must either equal NumRows()*NumCols() or
/// NumCols()-- this has two modes of operation.
void CopyRowsFromVec(const VectorBase<Real> &v); void CopyRowsFromVec(const VectorBase<Real> &v);
template<typename OtherReal> template<typename OtherReal>
void CopyRowsFromVec(const VectorBase<OtherReal> &v); void CopyRowsFromVec(const VectorBase<OtherReal> &v);
/// Copies vector into matrix, column-by-column. /// Copies vector into matrix, column-by-column.
/// Note that rv.Dim() must equal NumRows()*NumCols(). /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows();
/// this has two modes of operation.
void CopyColsFromVec(const VectorBase<Real> &v); void CopyColsFromVec(const VectorBase<Real> &v);
/// Copy vector into specific column of matrix. /// Copy vector into specific column of matrix.
void CopyColFromVec(const VectorBase<Real> &v, const MatrixIndexT col); void CopyColFromVec(const VectorBase<Real> &v, const MatrixIndexT col);
/// Copy vector into specific row of matrix. /// Copy vector into specific row of matrix.
@ -745,6 +748,12 @@ std::istream & operator >> (std::istream & In, MatrixBase<Real> & M);
template<typename Real> template<typename Real>
std::istream & operator >> (std::istream & In, Matrix<Real> & M); std::istream & operator >> (std::istream & In, Matrix<Real> & M);
template<class Real>
bool SameDim(const MatrixBase<Real> &M, const MatrixBase<Real> &N) {
return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols());
}
/// @} end of \addtogroup matrix_funcs_io /// @} end of \addtogroup matrix_funcs_io

Просмотреть файл

@ -551,7 +551,7 @@ Real VectorBase<Real>::SumLog() const {
} }
template<typename Real> template<typename Real>
void VectorBase<Real>::AddRowSumMat(const MatrixBase<Real> &rM) { void VectorBase<Real>::AddRowSumMat(Real alpha, const MatrixBase<Real> &rM) {
// note the double accumulator // note the double accumulator
double sum; double sum;
KALDI_ASSERT(dim_ == rM.NumCols()); KALDI_ASSERT(dim_ == rM.NumCols());
@ -560,12 +560,12 @@ void VectorBase<Real>::AddRowSumMat(const MatrixBase<Real> &rM) {
for (MatrixIndexT j = 0; j < rM.NumRows(); j++) { for (MatrixIndexT j = 0; j < rM.NumRows(); j++) {
sum += rM(j, i); sum += rM(j, i);
} }
data_[i] += sum; data_[i] += alpha * sum;
} }
} }
template<typename Real> template<typename Real>
void VectorBase<Real>::AddColSumMat(const MatrixBase<Real> &rM) { void VectorBase<Real>::AddColSumMat(Real alpha, const MatrixBase<Real> &rM) {
// note the double accumulator // note the double accumulator
double sum; double sum;
KALDI_ASSERT(dim_ == rM.NumRows()); KALDI_ASSERT(dim_ == rM.NumRows());
@ -574,7 +574,7 @@ void VectorBase<Real>::AddColSumMat(const MatrixBase<Real> &rM) {
for (MatrixIndexT j = 0; j < rM.NumCols(); j++) { for (MatrixIndexT j = 0; j < rM.NumCols(); j++) {
sum += rM(i, j); sum += rM(i, j);
} }
data_[i] += sum; data_[i] += alpha * sum;
} }
} }

Просмотреть файл

@ -242,11 +242,11 @@ class VectorBase {
/// negative. /// negative.
Real SumLog() const; Real SumLog() const;
/// Adds sum of the rows of M to existing contents. /// Adds sum of the rows of M to existing contents, times alpha.
void AddRowSumMat(const MatrixBase<Real>& M); void AddRowSumMat(Real alpha, const MatrixBase<Real>& M);
/// Adds sum of the columns of M to existing contents. /// Adds sum of the columns of M to existing contents.
void AddColSumMat(const MatrixBase<Real>& M); void AddColSumMat(Real alpha, const MatrixBase<Real>& M);
/// Returns log(sum(exp())) without exp overflow /// Returns log(sum(exp())) without exp overflow
/// If prune > 0.0, ignores terms less than the max - prune. /// If prune > 0.0, ignores terms less than the max - prune.

Просмотреть файл

@ -209,6 +209,27 @@ static void UnitTestSpAddVec() {
} }
} }
template<class Real> static void UnitTestCopyRowsAndCols() {
// Test other mode of CopyRowsFromVec, and CopyColsFromVec,
// where vector is duplicated.
for (int32 i = 0; i < 30; i++) {
int32 dimM = 1 + rand() % 5, dimN = 1 + rand() % 5;
Vector<float> w(dimN); // test cross-type version of
// CopyRowsFromVec.
Vector<Real> v(dimM);
Matrix<Real> M(dimM, dimN), N(dimM, dimN);
InitRand(&v);
InitRand(&w);
M.CopyColsFromVec(v);
N.CopyRowsFromVec(w);
for (int32 r = 0; r < dimM; r++) {
for (int32 c = 0; c < dimN; c++) {
KALDI_ASSERT(M(r, c) == v(r));
KALDI_ASSERT(N(r, c) == w(c));
}
}
}
}
template<class Real> static void UnitTestSpliceRows() { template<class Real> static void UnitTestSpliceRows() {
@ -354,8 +375,10 @@ static void UnitTestSimpleForVec() { // testing some simple operaters on vector
Matrix<Real> M(dimM, dimN); Matrix<Real> M(dimM, dimN);
InitRand(&M); InitRand(&M);
Vector<Real> Vr(dimN), Vc(dimM); Vector<Real> Vr(dimN), Vc(dimM);
Vr.AddRowSumMat(M); Vr.AddRowSumMat(0.5, M);
Vc.AddColSumMat(M); Vc.AddColSumMat(0.5, M);
Vr.Scale(2.0);
Vc.Scale(2.0);
Vector<Real> V2r(dimN), V2c(dimM); Vector<Real> V2r(dimN), V2c(dimM);
for (MatrixIndexT k = 0; k < dimM; k++) { for (MatrixIndexT k = 0; k < dimM; k++) {
@ -2972,6 +2995,7 @@ template<class Real> static void MatrixUnitTest() {
UnitTestMat2Vec<Real>(); UnitTestMat2Vec<Real>();
UnitTestSpLogExp<Real>(); UnitTestSpLogExp<Real>();
KALDI_LOG << " Point H"; KALDI_LOG << " Point H";
UnitTestCopyRowsAndCols<Real>();
UnitTestSpliceRows<Real>(); UnitTestSpliceRows<Real>();
UnitTestAddSp<Real>(); UnitTestAddSp<Real>();
UnitTestRemoveRow<Real>(); UnitTestRemoveRow<Real>();

Просмотреть файл

@ -72,7 +72,7 @@ class BiasedLinearity : public UpdatableComponent {
// compute gradient // compute gradient
linearity_corr_.AddMatMat(1.0, err, kTrans, input, kNoTrans, momentum_); linearity_corr_.AddMatMat(1.0, err, kTrans, input, kNoTrans, momentum_);
bias_corr_.Scale(momentum_); bias_corr_.Scale(momentum_);
bias_corr_.AddRowSumMat(err); bias_corr_.AddRowSumMat(1.0, err);
// l2 regularization // l2 regularization
if (l2_penalty_ != 0.0) { if (l2_penalty_ != 0.0) {
linearity_.AddMat(-learn_rate_*l2_penalty_*input.NumRows(), linearity_); linearity_.AddMat(-learn_rate_*l2_penalty_*input.NumRows(), linearity_);

Просмотреть файл

@ -128,7 +128,7 @@ void Rnnlm::Backpropagate(const MatrixBase<BaseFloat> &in_err) {
// update layer2 // update layer2
W2_.AddMatMat(-learn_rate_, h2_, kTrans, in_err, kNoTrans, 1.0); W2_.AddMatMat(-learn_rate_, h2_, kTrans, in_err, kNoTrans, 1.0);
b2_corr_.SetZero(); b2_corr_.SetZero();
b2_corr_.AddRowSumMat(in_err); b2_corr_.AddRowSumMat(1.0, in_err);
b2_.AddVec(-learn_rate_, b2_corr_); b2_.AddVec(-learn_rate_, b2_corr_);
// LAYER1 // LAYER1
@ -136,7 +136,7 @@ void Rnnlm::Backpropagate(const MatrixBase<BaseFloat> &in_err) {
U1_corr_.SetZero(); U1_corr_.SetZero();
b1_corr_.SetZero(); b1_corr_.SetZero();
// accumulate gradient for layer1 // accumulate gradient for layer1
b1_corr_.AddRowSumMat(e2_); b1_corr_.AddRowSumMat(1.0, e2_);
for(int32 r=0; r<e2_.NumRows();r++) { for(int32 r=0; r<e2_.NumRows();r++) {
V1_corr_.Row(in_seq_[r]-1).AddVec(1.0, e2_.Row(r)); V1_corr_.Row(in_seq_[r]-1).AddVec(1.0, e2_.Row(r));
} }
@ -174,7 +174,7 @@ void Rnnlm::Backpropagate(const MatrixBase<BaseFloat> &in_err) {
} }
// accumulate graidient // accumulate graidient
b1_corr_.AddRowSumMat(E); b1_corr_.AddRowSumMat(1.0, E);
for(int32 r=0; r<E.NumRows();r++) { for(int32 r=0; r<E.NumRows();r++) {
// :TODO: IS IT CORRECT? // :TODO: IS IT CORRECT?
// V1_corr_.Row(in_seq_[r+step]).AddVec(1.0,E.Row(r)); // V1_corr_.Row(in_seq_[r+step]).AddVec(1.0,E.Row(r));

Просмотреть файл

@ -273,8 +273,8 @@ double EbwAmSgmmUpdater::UpdateM(const MleAmSgmmAccs &num_accs,
Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I); Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I);
for (int32 j = 0; j < num_accs.num_states_; j++) { for (int32 j = 0; j < num_accs.num_states_; j++) {
num_count_vec.AddRowSumMat(num_accs.gamma_[j]); num_count_vec.AddRowSumMat(1.0, num_accs.gamma_[j]);
den_count_vec.AddRowSumMat(den_accs.gamma_[j]); den_count_vec.AddRowSumMat(1.0, den_accs.gamma_[j]);
} }
for (int32 i = 0; i < I; i++) { for (int32 i = 0; i < I; i++) {
@ -363,8 +363,8 @@ double EbwAmSgmmUpdater::UpdateWParallel(const MleAmSgmmAccs &num_accs,
Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I); Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I);
for (int32 j = 0; j < num_accs.num_states_; j++) { for (int32 j = 0; j < num_accs.num_states_; j++) {
num_count_vec.AddRowSumMat(num_accs.gamma_[j]); num_count_vec.AddRowSumMat(1.0, num_accs.gamma_[j]);
den_count_vec.AddRowSumMat(den_accs.gamma_[j]); den_count_vec.AddRowSumMat(1.0, den_accs.gamma_[j]);
} }
// Get the F_i and g_i quantities-- this is done in parallel (multi-core), // Get the F_i and g_i quantities-- this is done in parallel (multi-core),
@ -448,8 +448,8 @@ double EbwAmSgmmUpdater::UpdateN(const MleAmSgmmAccs &num_accs,
Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I); Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I);
for (int32 j = 0; j < num_accs.num_states_; j++) { for (int32 j = 0; j < num_accs.num_states_; j++) {
num_count_vec.AddRowSumMat(num_accs.gamma_[j]); num_count_vec.AddRowSumMat(1.0, num_accs.gamma_[j]);
den_count_vec.AddRowSumMat(den_accs.gamma_[j]); den_count_vec.AddRowSumMat(1.0, den_accs.gamma_[j]);
} }
for (int32 i = 0; i < I; i++) { for (int32 i = 0; i < I; i++) {
@ -522,8 +522,8 @@ double EbwAmSgmmUpdater::UpdateVars(const MleAmSgmmAccs &num_accs,
Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I); Vector<double> num_count_vec(I), den_count_vec(I), impr_vec(I);
for (int32 j = 0; j < num_accs.num_states_; j++) { for (int32 j = 0; j < num_accs.num_states_; j++) {
num_count_vec.AddRowSumMat(num_accs.gamma_[j]); num_count_vec.AddRowSumMat(1.0, num_accs.gamma_[j]);
den_count_vec.AddRowSumMat(den_accs.gamma_[j]); den_count_vec.AddRowSumMat(1.0, den_accs.gamma_[j]);
} }
for (int32 i = 0; i < I; i++) { for (int32 i = 0; i < I; i++) {

Просмотреть файл

@ -538,7 +538,7 @@ BaseFloat AmSgmm2::LogLikelihood(const Sgmm2PerFrameDerivedVars &per_frame_vars,
substate_cache.remaining_log_like = max; substate_cache.remaining_log_like = max;
int32 num_substates = loglikes.NumCols(); int32 num_substates = loglikes.NumCols();
substate_cache.likes.Resize(num_substates); // zeroes it. substate_cache.likes.Resize(num_substates); // zeroes it.
substate_cache.likes.AddRowSumMat(loglikes); // add likelihoods [not in log!] for substate_cache.likes.AddRowSumMat(1.0, loglikes); // add likelihoods [not in log!] for
// each column [i.e. summing over the rows], so we get the sum for // each column [i.e. summing over the rows], so we get the sum for
// each substate index. You have to multiply by exp(remaining_log_like) // each substate index. You have to multiply by exp(remaining_log_like)
// to get a real likelihood. // to get a real likelihood.
@ -613,7 +613,7 @@ void AmSgmm2::SplitSubstatesInGroup(const Vector<BaseFloat> &pdf_occupancies,
int32 split_m; // substate to split. int32 split_m; // substate to split.
{ {
Vector<BaseFloat> substate_count(tgt_M); Vector<BaseFloat> substate_count(tgt_M);
substate_count.AddRowSumMat(c_j); substate_count.AddRowSumMat(1.0, c_j);
BaseFloat *data = substate_count.Data(); BaseFloat *data = substate_count.Data();
split_m = std::max_element(data, data+cur_M) - data; split_m = std::max_element(data, data+cur_M) - data;
} }

Просмотреть файл

@ -50,10 +50,10 @@ void EbwAmSgmm2Updater::Update(const MleAmSgmm2Accs &num_accs,
Vector<double> gamma_num(num_accs.num_gaussians_); Vector<double> gamma_num(num_accs.num_gaussians_);
for (int32 j1 = 0; j1 < num_accs.num_groups_; j1++) for (int32 j1 = 0; j1 < num_accs.num_groups_; j1++)
gamma_num.AddRowSumMat(num_accs.gamma_[j1]); gamma_num.AddRowSumMat(1.0, num_accs.gamma_[j1]);
Vector<double> gamma_den(den_accs.num_gaussians_); Vector<double> gamma_den(den_accs.num_gaussians_);
for (int32 j1 = 0; j1 < den_accs.num_groups_; j1++) for (int32 j1 = 0; j1 < den_accs.num_groups_; j1++)
gamma_den.AddRowSumMat(den_accs.gamma_[j1]); gamma_den.AddRowSumMat(1.0, den_accs.gamma_[j1]);
BaseFloat tot_impr = 0.0; BaseFloat tot_impr = 0.0;

Просмотреть файл

@ -617,7 +617,7 @@ void MleAmSgmm2Updater::Update(const MleAmSgmm2Accs &accs,
Vector<double> gamma_i(accs.num_gaussians_); Vector<double> gamma_i(accs.num_gaussians_);
for (int32 j1 = 0; j1 < accs.num_groups_; j1++) for (int32 j1 = 0; j1 < accs.num_groups_; j1++)
gamma_i.AddRowSumMat(accs.gamma_[j1]); // add sum of rows of gamma_i.AddRowSumMat(1.0, accs.gamma_[j1]); // add sum of rows of
// accs.gamma_[j1], to gamma_i. // accs.gamma_[j1], to gamma_i.
if (flags & kSgmmPhoneProjections) if (flags & kSgmmPhoneProjections)

Просмотреть файл

@ -63,7 +63,7 @@ void LdaEstimate::Estimate(int32 target_dim,
// total covariance // total covariance
double sum = zero_acc_.Sum(); double sum = zero_acc_.Sum();
Vector<double> total_mean(dim); Vector<double> total_mean(dim);
total_mean.AddRowSumMat(first_acc_); total_mean.AddRowSumMat(1.0, first_acc_);
total_mean.Scale(1/sum); total_mean.Scale(1/sum);
SpMatrix<double> total_covar(total_second_acc_); SpMatrix<double> total_covar(total_second_acc_);
total_covar.Scale(1/sum); total_covar.Scale(1/sum);