зеркало из https://github.com/mozilla/kaldi.git
Use OnlineFeatureInterface instead of hardcoded OnlineNnet2FeaturePipeline
in nnet2 and nnet3 decoders to allow more flexible feature pipelines.
This commit is contained in:
Родитель
8305c4c7cb
Коммит
47b7a6b5c7
|
@ -57,6 +57,9 @@ class OnlineGenericBaseFeature: public OnlineBaseFeature {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return input_finished_ && frame == NumFramesReady() - 1;
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return computer_.GetFrameOptions().frame_shift_ms * 1.0e-03;
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const { return features_.size(); }
|
||||
|
||||
|
@ -140,6 +143,10 @@ class OnlineMatrixFeature: public OnlineFeatureInterface {
|
|||
|
||||
virtual int32 Dim() const { return mat_.NumCols(); }
|
||||
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return 0.01f;
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const { return mat_.NumRows(); }
|
||||
|
||||
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) {
|
||||
|
@ -150,6 +157,7 @@ class OnlineMatrixFeature: public OnlineFeatureInterface {
|
|||
return (frame + 1 == mat_.NumRows());
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
const MatrixBase<BaseFloat> &mat_;
|
||||
};
|
||||
|
@ -291,13 +299,15 @@ class OnlineCmvn: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return src_->IsLastFrame(frame);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
// The online cmvn does not introduce any additional latency.
|
||||
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
|
||||
|
||||
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);
|
||||
|
||||
|
||||
//
|
||||
// Next, functions that are not in the interface.
|
||||
//
|
||||
|
@ -421,6 +431,9 @@ class OnlineSpliceFrames: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return src_->IsLastFrame(frame);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const;
|
||||
|
||||
|
@ -451,6 +464,9 @@ class OnlineTransform: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return src_->IsLastFrame(frame);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
|
||||
|
||||
|
@ -482,6 +498,9 @@ class OnlineDeltaFeature: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return src_->IsLastFrame(frame);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const;
|
||||
|
||||
|
@ -510,6 +529,9 @@ class OnlineCacheFeature: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return src_->IsLastFrame(frame);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const { return src_->NumFramesReady(); }
|
||||
|
||||
|
@ -541,6 +563,10 @@ class OnlineAppendFeature: public OnlineFeatureInterface {
|
|||
virtual bool IsLastFrame(int32 frame) const {
|
||||
return (src1_->IsLastFrame(frame) || src2_->IsLastFrame(frame));
|
||||
}
|
||||
// Hopefully sources have the same rate
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src1_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const {
|
||||
return std::min(src1_->NumFramesReady(), src2_->NumFramesReady());
|
||||
|
|
|
@ -576,6 +576,8 @@ class OnlinePitchFeatureImpl {
|
|||
explicit OnlinePitchFeatureImpl(const PitchExtractionOptions &opts);
|
||||
|
||||
int32 Dim() const { return 2; }
|
||||
|
||||
BaseFloat FrameShiftInSeconds() const;
|
||||
|
||||
int32 NumFramesReady() const;
|
||||
|
||||
|
@ -879,6 +881,10 @@ bool OnlinePitchFeatureImpl::IsLastFrame(int32 frame) const {
|
|||
return (input_finished_ && frame + 1 == T);
|
||||
}
|
||||
|
||||
BaseFloat OnlinePitchFeatureImpl::FrameShiftInSeconds() const {
|
||||
return opts_.frame_shift_ms * 1.0e-03;
|
||||
}
|
||||
|
||||
int32 OnlinePitchFeatureImpl::NumFramesReady() const {
|
||||
int32 num_frames = lag_nccf_.size(),
|
||||
latency = frames_latency_;
|
||||
|
@ -1171,6 +1177,10 @@ bool OnlinePitchFeature::IsLastFrame(int32 frame) const {
|
|||
return impl_->IsLastFrame(frame);
|
||||
}
|
||||
|
||||
BaseFloat OnlinePitchFeature::FrameShiftInSeconds() const {
|
||||
return impl_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
void OnlinePitchFeature::GetFrame(int32 frame, VectorBase<BaseFloat> *feat) {
|
||||
impl_->GetFrame(frame, feat);
|
||||
}
|
||||
|
|
|
@ -301,6 +301,8 @@ class OnlinePitchFeature: public OnlineBaseFeature {
|
|||
virtual int32 Dim() const { return 2; /* (NCCF, pitch) */ }
|
||||
|
||||
virtual int32 NumFramesReady() const;
|
||||
|
||||
virtual BaseFloat FrameShiftInSeconds() const;
|
||||
|
||||
virtual bool IsLastFrame(int32 frame) const;
|
||||
|
||||
|
@ -336,6 +338,9 @@ class OnlineProcessPitch: public OnlineFeatureInterface {
|
|||
else
|
||||
return src_->IsLastFrame(frame - opts_.delay);
|
||||
}
|
||||
virtual BaseFloat FrameShiftInSeconds() const {
|
||||
return src_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
virtual int32 NumFramesReady() const;
|
||||
|
||||
|
|
|
@ -74,10 +74,15 @@ class OnlineFeatureInterface {
|
|||
/// the class.
|
||||
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat) = 0;
|
||||
|
||||
// Returns frame shift in seconds. Helps to estimate duration from frame
|
||||
// counts.
|
||||
virtual BaseFloat FrameShiftInSeconds() const = 0;
|
||||
|
||||
/// Virtual destructor. Note: constructors that take another member of
|
||||
/// type OnlineFeatureInterface are not expected to take ownership of
|
||||
/// that pointer; the caller needs to keep track of that manually.
|
||||
virtual ~OnlineFeatureInterface() { }
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -152,6 +152,10 @@ int32 OnlineIvectorFeature::NumFramesReady() const {
|
|||
return lda_->NumFramesReady();
|
||||
}
|
||||
|
||||
BaseFloat OnlineIvectorFeature::FrameShiftInSeconds() const {
|
||||
return lda_->FrameShiftInSeconds();
|
||||
}
|
||||
|
||||
void OnlineIvectorFeature::UpdateFrameWeights(
|
||||
const std::vector<std::pair<int32, BaseFloat> > &delta_weights) {
|
||||
// add the elements to delta_weights_, which is a priority queue. The top
|
||||
|
|
|
@ -266,6 +266,7 @@ class OnlineIvectorFeature: public OnlineFeatureInterface {
|
|||
virtual int32 Dim() const;
|
||||
virtual bool IsLastFrame(int32 frame) const;
|
||||
virtual int32 NumFramesReady() const;
|
||||
virtual BaseFloat FrameShiftInSeconds() const;
|
||||
virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat);
|
||||
|
||||
/// Set the adaptation state to a particular value, e.g. reflecting previous
|
||||
|
|
|
@ -28,7 +28,7 @@ SingleUtteranceNnet2Decoder::SingleUtteranceNnet2Decoder(
|
|||
const TransitionModel &tmodel,
|
||||
const nnet2::AmNnet &model,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline):
|
||||
OnlineFeatureInterface *feature_pipeline):
|
||||
config_(config),
|
||||
feature_pipeline_(feature_pipeline),
|
||||
tmodel_(tmodel),
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#include "util/common-utils.h"
|
||||
#include "base/kaldi-error.h"
|
||||
#include "nnet2/online-nnet2-decodable.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "itf/online-feature-itf.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
#include "decoder/lattice-faster-online-decoder.h"
|
||||
#include "hmm/transition-model.h"
|
||||
|
@ -72,7 +72,7 @@ class SingleUtteranceNnet2Decoder {
|
|||
const TransitionModel &tmodel,
|
||||
const nnet2::AmNnet &model,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline);
|
||||
OnlineFeatureInterface *feature_pipeline);
|
||||
|
||||
/// advance the decoding as far as we can.
|
||||
void AdvanceDecoding();
|
||||
|
@ -111,7 +111,7 @@ class SingleUtteranceNnet2Decoder {
|
|||
|
||||
OnlineNnet2DecodingConfig config_;
|
||||
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline_;
|
||||
OnlineFeatureInterface *feature_pipeline_;
|
||||
|
||||
const TransitionModel &tmodel_;
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ SingleUtteranceNnet3Decoder::SingleUtteranceNnet3Decoder(
|
|||
const TransitionModel &tmodel,
|
||||
const nnet3::AmNnetSimple &am_model,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline):
|
||||
OnlineFeatureInterface *feature_pipeline):
|
||||
config_(config),
|
||||
feature_pipeline_(feature_pipeline),
|
||||
tmodel_(tmodel),
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "matrix/matrix-lib.h"
|
||||
#include "util/common-utils.h"
|
||||
#include "base/kaldi-error.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "itf/online-feature-itf.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
#include "decoder/lattice-faster-online-decoder.h"
|
||||
#include "hmm/transition-model.h"
|
||||
|
@ -73,7 +73,7 @@ class SingleUtteranceNnet3Decoder {
|
|||
const TransitionModel &tmodel,
|
||||
const nnet3::AmNnetSimple &am_model,
|
||||
const fst::Fst<fst::StdArc> &fst,
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline);
|
||||
OnlineFeatureInterface *feature_pipeline);
|
||||
|
||||
/// advance the decoding as far as we can.
|
||||
void AdvanceDecoding();
|
||||
|
@ -112,7 +112,7 @@ class SingleUtteranceNnet3Decoder {
|
|||
|
||||
OnlineNnet3DecodingConfig config_;
|
||||
|
||||
OnlineNnet2FeaturePipeline *feature_pipeline_;
|
||||
OnlineFeatureInterface *feature_pipeline_;
|
||||
|
||||
const TransitionModel &tmodel_;
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet2-decoding.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
#include "online2/online-timing.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet2-decoding.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet2-decoding.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
#include "online2/online-timing.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet2-decoding-threaded.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
#include "online2/online-timing.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "feat/wave-reader.h"
|
||||
#include "online2/online-nnet3-decoding.h"
|
||||
#include "online2/online-nnet2-feature-pipeline.h"
|
||||
#include "online2/onlinebin-util.h"
|
||||
#include "online2/online-timing.h"
|
||||
#include "online2/online-endpoint.h"
|
||||
|
|
Загрузка…
Ссылка в новой задаче