From c43458e671f0734e98cc878bb0f10a9dbcac7ca6 Mon Sep 17 00:00:00 2001 From: Jean-Marc Valin Date: Thu, 1 Jun 2023 04:09:57 -0400 Subject: [PATCH] Add reset functions that don't clear the model --- include/lpcnet.h | 3 +++ src/lpcnet.c | 15 +++++++++++---- src/lpcnet_plc.c | 16 ++++++++++++---- src/lpcnet_private.h | 18 +++++++++++------- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/include/lpcnet.h b/include/lpcnet.h index 067a432..c094aec 100644 --- a/include/lpcnet.h +++ b/include/lpcnet.h @@ -75,6 +75,8 @@ LPCNET_EXPORT int lpcnet_decoder_get_size(void); */ LPCNET_EXPORT int lpcnet_decoder_init(LPCNetDecState *st); +LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet); + /** Allocates and initializes a decoder state. * @returns The newly created state */ @@ -186,6 +188,7 @@ LPCNET_EXPORT void lpcnet_synthesize(LPCNetState *st, const float *features, sho LPCNET_EXPORT int lpcnet_plc_get_size(void); LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options); +LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st); LPCNET_EXPORT LPCNetPLCState *lpcnet_plc_create(int options); diff --git a/src/lpcnet.c b/src/lpcnet.c index 455a5ed..31fe06d 100644 --- a/src/lpcnet.c +++ b/src/lpcnet.c @@ -171,23 +171,30 @@ LPCNET_EXPORT int lpcnet_get_size() return sizeof(LPCNetState); } +LPCNET_EXPORT void lpcnet_reset(LPCNetState *lpcnet) +{ + const char* rng_string="LPCNet"; + RNN_CLEAR((char*)&lpcnet->LPCNET_RESET_START, + sizeof(LPCNetState)- + ((char*)&lpcnet->LPCNET_RESET_START - (char*)lpcnet)); + lpcnet->last_exc = lin2ulaw(0.f); + kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string)); +} + LPCNET_EXPORT int lpcnet_init(LPCNetState *lpcnet) { int i; int ret; - const char* rng_string="LPCNet"; - memset(lpcnet, 0, lpcnet_get_size()); - lpcnet->last_exc = lin2ulaw(0.f); for (i=0;i<256;i++) { float prob = .025f+.95f*i/255.f; lpcnet->sampling_logit_table[i] = -log((1-prob)/prob); } - kiss99_srand(&lpcnet->rng, (const unsigned char *)rng_string, strlen(rng_string)); #ifndef USE_WEIGHTS_FILE ret = init_lpcnet_model(&lpcnet->model, lpcnet_arrays); #else ret = 0; #endif + lpcnet_reset(lpcnet); celt_assert(ret == 0); return ret; } diff --git a/src/lpcnet_plc.c b/src/lpcnet_plc.c index 1e27fb0..a104c1d 100644 --- a/src/lpcnet_plc.c +++ b/src/lpcnet_plc.c @@ -43,10 +43,11 @@ LPCNET_EXPORT int lpcnet_plc_get_size() { return sizeof(LPCNetPLCState); } -LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { - int ret; - RNN_CLEAR(st, 1); - lpcnet_init(&st->lpcnet); +LPCNET_EXPORT void lpcnet_plc_reset(LPCNetPLCState *st) { + RNN_CLEAR((char*)&st->LPCNET_PLC_RESET_START, + sizeof(LPCNetPLCState)- + ((char*)&st->LPCNET_PLC_RESET_START - (char*)st)); + lpcnet_reset(&st->lpcnet); lpcnet_encoder_init(&st->enc); RNN_CLEAR(st->pcm, PLC_BUF_SIZE); st->pcm_fill = PLC_BUF_SIZE; @@ -55,6 +56,12 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { st->loss_count = 0; st->dc_mem = 0; st->queued_update = 0; +} + +LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { + int ret; + lpcnet_init(&st->lpcnet); + lpcnet_encoder_init(&st->enc); if ((options&0x3) == LPCNET_PLC_CAUSAL) { st->enable_blending = 1; st->non_causal = 0; @@ -74,6 +81,7 @@ LPCNET_EXPORT int lpcnet_plc_init(LPCNetPLCState *st, int options) { ret = 0; #endif celt_assert(ret == 0); + lpcnet_plc_reset(st); return ret; } diff --git a/src/lpcnet_private.h b/src/lpcnet_private.h index 5db3c63..39cfbf7 100644 --- a/src/lpcnet_private.h +++ b/src/lpcnet_private.h @@ -26,6 +26,11 @@ #define MAX_FEATURE_BUFFER_SIZE 4 struct LPCNetState { + LPCNetModel model; + float sampling_logit_table[256]; + kiss99_ctx rng; + +#define LPCNET_RESET_START nnet NNetState nnet; int last_exc; float last_sig[LPC_ORDER]; @@ -35,14 +40,11 @@ struct LPCNetState { #if FEATURES_DELAY>0 float old_lpc[FEATURES_DELAY][LPC_ORDER]; #endif - float sampling_logit_table[256]; float gru_a_condition[3*GRU_A_STATE_SIZE]; float gru_b_condition[3*GRU_B_STATE_SIZE]; int frame_count; float deemph_mem; float lpc[LPC_ORDER]; - kiss99_ctx rng; - LPCNetModel model; }; struct LPCNetDecState { @@ -74,8 +76,14 @@ struct LPCNetEncState{ #define PLC_BUF_SIZE (FEATURES_DELAY*FRAME_SIZE + TRAINING_OFFSET) struct LPCNetPLCState { + PLCModel model; LPCNetState lpcnet; LPCNetEncState enc; + int enable_blending; + int non_causal; + int remove_dc; + +#define LPCNET_PLC_RESET_START fec float fec[PLC_MAX_FEC][NB_FEATURES]; int fec_keep_pos; int fec_read_pos; @@ -89,16 +97,12 @@ struct LPCNetPLCState { int loss_count; PLCNetState plc_net; PLCNetState plc_copy[FEATURES_DELAY+1]; - int enable_blending; - int non_causal; double dc_mem; double syn_dc; - int remove_dc; short dc_buf[TRAINING_OFFSET]; int queued_update; short queued_samples[FRAME_SIZE]; - PLCModel model; }; extern float ceps_codebook1[];