Use OnlineFeatureInterface instead of hardcoded OnlineNnet2FeaturePipeline

in nnet2 and nnet3 decoders to allow more flexible feature pipelines.
This commit is contained in:
Nickolay Shmyrev 2016-05-03 19:44:52 +02:00
Родитель 8305c4c7cb
Коммит 47b7a6b5c7
15 изменённых файлов: 65 добавлений и 9 удалений

Просмотреть файл

@ -57,6 +57,9 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
// Reports whether `frame` is the final frame: input must have been marked
// finished and `frame` must be the index of the last frame that is ready.
virtual bool IsLastFrame(int32 frame) const {
  if (!input_finished_)
    return false;
  const int32 last_ready_index = NumFramesReady() - 1;
  return frame == last_ready_index;
}
// Frame shift in seconds: the frame options' frame_shift_ms (milliseconds)
// scaled by 1e-3.
virtual BaseFloat FrameShiftInSeconds() const {
  const double seconds_per_ms = 1.0e-03;
  return static_cast<BaseFloat>(
      computer_.GetFrameOptions().frame_shift_ms * seconds_per_ms);
}
// The number of ready frames is the number of entries currently held in
// features_.
virtual int32 NumFramesReady() const { return features_.size(); }
@ -140,6 +143,10 @@ class OnlineMatrixFeature: public OnlineFeatureInterface {
// Feature dimension is the column count of mat_.
virtual int32 Dim() const { return mat_.NumCols(); }
// Always reports 0.01 seconds, independent of the contents of mat_.
// NOTE(review): presumably the conventional 10 ms frame shift; nothing here
// verifies that the matrix was extracted at that rate -- confirm with callers.
virtual BaseFloat FrameShiftInSeconds() const {
return 0.01f;
}
// One frame per row of mat_.
virtual int32 NumFramesReady() const { return mat_.NumRows(); }
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) {
@ -150,6 +157,7 @@ class OnlineMatrixFeature: public OnlineFeatureInterface {
return (frame + 1 == mat_.NumRows());
}
private:
const MatrixBase<BaseFloat> &mat_;
};
@ -291,13 +299,15 @@ class OnlineCmvn: public OnlineFeatureInterface {
// Delegates to the source feature: returns whatever src_ reports for `frame`.
virtual bool IsLastFrame(int32 frame) const {
return src_->IsLastFrame(frame);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
// The online cmvn does not introduce any additional latency, so the number of
// ready frames is exactly what the source src_ reports.
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);
//
// Next, functions that are not in the interface.
//
@ -421,6 +431,9 @@ class OnlineSpliceFrames: public OnlineFeatureInterface {
// Delegates to the source feature: returns whatever src_ reports for `frame`.
virtual bool IsLastFrame(int32 frame) const {
return src_->IsLastFrame(frame);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
virtual int32 NumFramesReady() const;
@ -451,6 +464,9 @@ class OnlineTransform: public OnlineFeatureInterface {
// Delegates to the source feature: returns whatever src_ reports for `frame`.
virtual bool IsLastFrame(int32 frame) const {
return src_->IsLastFrame(frame);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
// Same number of ready frames as the source src_ reports.
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
@ -482,6 +498,9 @@ class OnlineDeltaFeature: public OnlineFeatureInterface {
// Delegates to the source feature: returns whatever src_ reports for `frame`.
virtual bool IsLastFrame(int32 frame) const {
return src_->IsLastFrame(frame);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
virtual int32 NumFramesReady() const;
@ -510,6 +529,9 @@ class OnlineCacheFeature: public OnlineFeatureInterface {
// Delegates to the source feature: returns whatever src_ reports for `frame`.
virtual bool IsLastFrame(int32 frame) const {
return src_->IsLastFrame(frame);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
// Same number of ready frames as the source src_ reports.
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
@ -541,6 +563,10 @@ class OnlineAppendFeature: public OnlineFeatureInterface {
// A frame counts as last as soon as either source reports it as its last
// frame; src2_ is only queried when src1_ says it is not.
virtual bool IsLastFrame(int32 frame) const {
  if (src1_->IsLastFrame(frame))
    return true;
  return src2_->IsLastFrame(frame);
}
// Reports the frame shift of src1_ only. src2_ is assumed to have the same
// frame shift; that assumption is not checked here.
virtual BaseFloat FrameShiftInSeconds() const {
return src1_->FrameShiftInSeconds();
}
virtual int32 NumFramesReady() const {
return std::min(src1_->NumFramesReady(), src2_->NumFramesReady());

Просмотреть файл

@ -576,6 +576,8 @@ class OnlinePitchFeatureImpl {
explicit OnlinePitchFeatureImpl(const PitchExtractionOptions &opts);
int32 Dim() const { return 2; }
BaseFloat FrameShiftInSeconds() const;
int32 NumFramesReady() const;
@ -879,6 +881,10 @@ bool OnlinePitchFeatureImpl::IsLastFrame(int32 frame) const {
return (input_finished_ && frame + 1 == T);
}
// Converts the frame shift stored in opts_ from milliseconds to seconds.
BaseFloat OnlinePitchFeatureImpl::FrameShiftInSeconds() const {
  const double seconds_per_ms = 1.0e-03;
  return static_cast<BaseFloat>(opts_.frame_shift_ms * seconds_per_ms);
}
int32 OnlinePitchFeatureImpl::NumFramesReady() const {
int32 num_frames = lag_nccf_.size(),
latency = frames_latency_;
@ -1171,6 +1177,10 @@ bool OnlinePitchFeature::IsLastFrame(int32 frame) const {
return impl_->IsLastFrame(frame);
}
// Forwards to the implementation object impl_.
BaseFloat OnlinePitchFeature::FrameShiftInSeconds() const {
return impl_->FrameShiftInSeconds();
}
// Forwards both arguments unchanged to impl_->GetFrame().
void OnlinePitchFeature::GetFrame(int32 frame, VectorBase<BaseFloat> *feat) {
impl_->GetFrame(frame, feat);
}

Просмотреть файл

@ -301,6 +301,8 @@ class OnlinePitchFeature: public OnlineBaseFeature {
virtual int32 Dim() const { return 2; /* (NCCF, pitch) */ }
virtual int32 NumFramesReady() const;
virtual BaseFloat FrameShiftInSeconds() const;
virtual bool IsLastFrame(int32 frame) const;
@ -336,6 +338,9 @@ class OnlineProcessPitch: public OnlineFeatureInterface {
else
return src_->IsLastFrame(frame - opts_.delay);
}
// The frame shift is taken unchanged from the source feature src_.
virtual BaseFloat FrameShiftInSeconds() const {
return src_->FrameShiftInSeconds();
}
virtual int32 NumFramesReady() const;

Просмотреть файл

@ -74,10 +74,15 @@ class OnlineFeatureInterface {
/// the class.
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) = 0;
/// Returns the frame shift in seconds. Helps to estimate duration from
/// frame counts.
virtual BaseFloat FrameShiftInSeconds() const = 0;
/// Virtual destructor. Note: constructors that take another member of
/// type OnlineFeatureInterface are not expected to take ownership of
/// that pointer; the caller needs to keep track of that manually.
virtual ~OnlineFeatureInterface() { }
};

Просмотреть файл

@ -152,6 +152,10 @@ int32 OnlineIvectorFeature::NumFramesReady() const {
return lda_->NumFramesReady();
}
// Delegates to lda_: the frame shift reported is whatever lda_ reports.
BaseFloat OnlineIvectorFeature::FrameShiftInSeconds() const {
return lda_->FrameShiftInSeconds();
}
void OnlineIvectorFeature::UpdateFrameWeights(
const std::vector<std::pair<int32, BaseFloat> > &delta_weights) {
// add the elements to delta_weights_, which is a priority queue. The top

Просмотреть файл

@ -266,6 +266,7 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
virtual int32 Dim() const;
virtual bool IsLastFrame(int32 frame) const;
virtual int32 NumFramesReady() const;
virtual BaseFloat FrameShiftInSeconds() const;
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);
/// Set the adaptation state to a particular value, e.g. reflecting previous

Просмотреть файл

@ -28,7 +28,7 @@ SingleUtteranceNnet2Decoder::SingleUtteranceNnet2Decoder(
const TransitionModel &tmodel,
const nnet2::AmNnet &model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline):
OnlineFeatureInterface *feature_pipeline):
config_(config),
feature_pipeline_(feature_pipeline),
tmodel_(tmodel),

Просмотреть файл

@ -29,7 +29,7 @@
#include "util/common-utils.h"
#include "base/kaldi-error.h"
#include "nnet2/online-nnet2-decodable.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "itf/online-feature-itf.h"
#include "online2/online-endpoint.h"
#include "decoder/lattice-faster-online-decoder.h"
#include "hmm/transition-model.h"
@ -72,7 +72,7 @@ class SingleUtteranceNnet2Decoder {
const TransitionModel &tmodel,
const nnet2::AmNnet &model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline);
OnlineFeatureInterface *feature_pipeline);
/// advance the decoding as far as we can.
void AdvanceDecoding();
@ -111,7 +111,7 @@ class SingleUtteranceNnet2Decoder {
OnlineNnet2DecodingConfig config_;
OnlineNnet2FeaturePipeline *feature_pipeline_;
OnlineFeatureInterface *feature_pipeline_;
const TransitionModel &tmodel_;

Просмотреть файл

@ -29,7 +29,7 @@ SingleUtteranceNnet3Decoder::SingleUtteranceNnet3Decoder(
const TransitionModel &tmodel,
const nnet3::AmNnetSimple &am_model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline):
OnlineFeatureInterface *feature_pipeline):
config_(config),
feature_pipeline_(feature_pipeline),
tmodel_(tmodel),

Просмотреть файл

@ -30,7 +30,7 @@
#include "matrix/matrix-lib.h"
#include "util/common-utils.h"
#include "base/kaldi-error.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "itf/online-feature-itf.h"
#include "online2/online-endpoint.h"
#include "decoder/lattice-faster-online-decoder.h"
#include "hmm/transition-model.h"
@ -73,7 +73,7 @@ class SingleUtteranceNnet3Decoder {
const TransitionModel &tmodel,
const nnet3::AmNnetSimple &am_model,
const fst::Fst<fst::StdArc> &fst,
OnlineNnet2FeaturePipeline *feature_pipeline);
OnlineFeatureInterface *feature_pipeline);
/// advance the decoding as far as we can.
void AdvanceDecoding();
@ -112,7 +112,7 @@ class SingleUtteranceNnet3Decoder {
OnlineNnet3DecodingConfig config_;
OnlineNnet2FeaturePipeline *feature_pipeline_;
OnlineFeatureInterface *feature_pipeline_;
const TransitionModel &tmodel_;

Просмотреть файл

@ -19,6 +19,7 @@
#include "feat/wave-reader.h"
#include "online2/online-nnet2-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"

Просмотреть файл

@ -20,6 +20,7 @@
#include "feat/wave-reader.h"
#include "online2/online-nnet2-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
int main(int argc, char *argv[]) {

Просмотреть файл

@ -19,6 +19,7 @@
#include "feat/wave-reader.h"
#include "online2/online-nnet2-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"

Просмотреть файл

@ -19,6 +19,7 @@
#include "feat/wave-reader.h"
#include "online2/online-nnet2-decoding-threaded.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"

Просмотреть файл

@ -20,6 +20,7 @@
#include "feat/wave-reader.h"
#include "online2/online-nnet3-decoding.h"
#include "online2/online-nnet2-feature-pipeline.h"
#include "online2/onlinebin-util.h"
#include "online2/online-timing.h"
#include "online2/online-endpoint.h"