mirror of https://github.com/mozilla/TTS.git
num_speakers larger than 1
Parent: d23e29ea1f
Commit: 6390c3b2e6
models/tacotron.py

@@ -29,7 +29,7 @@ class Tacotron(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
models/tacotron2.py

@@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(512)
models/tacotrongst.py

@@ -30,7 +30,7 @@ class TacotronGST(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
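All three model hunks apply the same guard: the speaker embedding table is only created when the dataset really contains more than one speaker, since a single-speaker corpus has num_speakers == 1 and the old "> 0" check built a one-row table that was never useful. Below is a minimal, self-contained sketch of that pattern; it is illustrative only, not the actual Tacotron forward pass, which this diff does not touch.

import torch
import torch.nn as nn

class SpeakerConditionedEmbedding(nn.Module):
    # Sketch: create the speaker table only for genuinely
    # multi-speaker data (num_speakers > 1).
    def __init__(self, num_chars, num_speakers, embed_dim=256):
        super().__init__()
        self.num_speakers = num_speakers
        self.embedding = nn.Embedding(num_chars, embed_dim)
        self.embedding.weight.data.normal_(0, 0.3)
        if num_speakers > 1:
            self.speaker_embedding = nn.Embedding(num_speakers, embed_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def forward(self, chars, speaker_ids=None):
        x = self.embedding(chars)                    # [B, T, D]
        if self.num_speakers > 1 and speaker_ids is not None:
            s = self.speaker_embedding(speaker_ids)  # [B, D]
            x = x + s.unsqueeze(1)                   # broadcast over time
        return x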
train.py (14 changed lines)
@@ -78,7 +78,7 @@ def setup_loader(is_val=False, verbose=False):
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(is_val=False, verbose=(epoch==0))
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.train()
     epoch_time = 0
@@ -102,7 +102,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())

-        if c.num_speakers > 0:
+        if c.num_speakers > 1:
             speaker_ids = []
             for speaker_name in speaker_names:
                 if speaker_name not in speaker_mapping:
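train() looks speaker names up in a mapping loaded from disk, and the truncated tail of this hunk suggests that names not yet in the mapping are added as training proceeds. The helpers themselves are not part of this diff; only their call signatures are visible, plus the "copying speakers.json" comment further down. A hedged sketch of what they plausibly do, where the JSON layout ({speaker_name: integer_id}) and the missing-file fallback are assumptions:

import json
import os

def _mapping_path(out_path):
    # File name taken from the "# copying speakers.json" comment below.
    return os.path.join(out_path, "speakers.json")

def load_speaker_mapping(out_path):
    # Assumed behavior: return an empty mapping when none exists yet.
    try:
        with open(_mapping_path(out_path)) as f:
            return json.load(f)
    except FileNotFoundError:
        return {}

def save_speaker_mapping(out_path, speaker_mapping):
    with open(_mapping_path(out_path), "w") as f:
        json.dump(speaker_mapping, f, indent=4)

def copy_speaker_mapping(prev_out_path, out_path):
    # Used on restore: carry the mapping over to the new output dir.
    save_speaker_mapping(out_path, load_speaker_mapping(prev_out_path))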
@@ -272,14 +272,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             tb_logger.tb_model_weights(model, current_step)

     # save speaker mapping
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         save_speaker_mapping(OUT_PATH, speaker_mapping)
     return avg_postnet_loss, current_step


 def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
     data_loader = setup_loader(is_val=True)
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
@@ -311,7 +311,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
             mel_lengths = data[5]
             stop_targets = data[6]

-            if c.num_speakers > 0:
+            if c.num_speakers > 1:
                 speaker_ids = [speaker_mapping[speaker_name]
                                for speaker_name in speaker_names]
                 speaker_ids = torch.LongTensor(speaker_ids)
@@ -415,7 +415,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.num_speakers > 0 else None
+        speaker_id = 0 if c.num_speakers > 1 else None
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
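At evaluation time the same guard drives two things: batch speaker ids are looked up through the mapping, and the test sentences fall back to speaker id 0 for multi-speaker models or None for single-speaker ones, where no speaker_embedding module exists to index. A small runnable illustration with made-up speaker names:

import torch

speaker_mapping = {"speaker_a": 0, "speaker_b": 1}  # illustrative values
speaker_names = ["speaker_b", "speaker_a", "speaker_b"]

# Same lookup as the evaluate() hunk above:
speaker_ids = torch.LongTensor(
    [speaker_mapping[name] for name in speaker_names])
print(speaker_ids)  # tensor([1, 0, 1])

# Same default as the test-sentence hunk:
num_speakers = len(speaker_mapping)
speaker_id = 0 if num_speakers > 1 else None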
@@ -484,7 +484,7 @@ def main(args):
         args.restore_step = checkpoint['step']
         # copying speakers.json
         prev_out_path = os.path.dirname(args.restore_path)
-        if c.num_speakers > 0:
+        if c.num_speakers > 1:
             copy_speaker_mapping(prev_out_path, OUT_PATH)
     else:
         args.restore_step = 0