diff --git a/config.json b/config.json
index 3c43152..9ec31b4 100644
--- a/config.json
+++ b/config.json
@@ -12,9 +12,9 @@
     "text_cleaner": "english_cleaners",
     "epochs": 2000,
-    "lr": 0.0003,
+    "lr": 0.00001875,
     "warmup_steps": 4000,
-    "batch_size": 8,
+    "batch_size": 1,
     "r": 5,
     "griffin_lim_iters": 60,
@@ -25,5 +25,6 @@
     "checkpoint": false,
     "save_step": 69,
     "data_path": "/run/shm/erogol/LJSpeech-1.0",
+    "min_seq_len": 90,
     "output_path": "result"
 }
diff --git a/datasets/LJSpeech.py b/datasets/LJSpeech.py
index a42a626..3edd644 100644
--- a/datasets/LJSpeech.py
+++ b/datasets/LJSpeech.py
@@ -14,7 +14,8 @@ class LJSpeechDataset(Dataset):

     def __init__(self, csv_file, root_dir, outputs_per_step, sample_rate,
                  text_cleaner, num_mels, min_level_db, frame_shift_ms,
-                 frame_length_ms, preemphasis, ref_level_db, num_freq, power):
+                 frame_length_ms, preemphasis, ref_level_db, num_freq, power,
+                 min_seq_len=0):
         with open(csv_file, "r") as f:
             self.frames = [line.split('|') for line in f]
@@ -22,6 +23,7 @@ class LJSpeechDataset(Dataset):
         self.outputs_per_step = outputs_per_step
         self.sample_rate = sample_rate
         self.cleaners = text_cleaner
+        self.min_seq_len = min_seq_len
         self.ap = AudioProcessor(sample_rate, num_mels, min_level_db, frame_shift_ms,
                                  frame_length_ms, preemphasis, ref_level_db, num_freq, power)
         print(" > Reading LJSpeech from - {}".format(root_dir))
@@ -45,8 +47,14 @@ class LJSpeechDataset(Dataset):
         idxs = np.argsort(lengths)
-        new_frames = [None] * len(lengths)
-        for i, idx in enumerate(idxs):
-            new_frames[i] = self.frames[idx]
+        new_frames = []
+        ignored = []
+        for idx in idxs:
+            if lengths[idx] < self.min_seq_len:
+                ignored.append(idx)
+            else:
+                new_frames.append(self.frames[idx])
+        print(" | > {} instances are ignored by min_seq_len ({})".format(
+            len(ignored), self.min_seq_len))
         self.frames = new_frames

     def __len__(self):
diff --git a/train.py b/train.py
index 1b2c944..288dad2 100644
--- a/train.py
+++ b/train.py
@@ -302,7 +302,8 @@ def main(args):
                               c.preemphasis, c.ref_level_db,
                               c.num_freq,
-                              c.power
+                              c.power,
+                              min_seq_len=c.min_seq_len
                               )

     train_loader = DataLoader(train_dataset, batch_size=c.batch_size,