Daisy Deng 2020-02-24 21:25:02 +00:00
Parent 3325bdb21a
Commit 6597c93148
1 changed file with 7 additions and 4 deletions

View file

@@ -224,6 +224,7 @@ class Transformer:
         # init training
         tr_loss = 0.0
         accum_loss = 0
+        train_size = 0
         self.model.train()
         self.model.zero_grad()
@@ -254,7 +255,7 @@ class Transformer:
                 tr_loss += loss.item()
                 accum_loss += loss.item()
+                train_size += list(inputs.values())[0].size()[0]
                 if (step + 1) % gradient_accumulation_steps == 0:
                     global_step += 1
@@ -274,13 +275,14 @@ class Transformer:
                         endtime_string = datetime.datetime.fromtimestamp(end).strftime(
                             "%d/%m/%Y %H:%M:%S"
                         )
-                        log_line = """timestamp: {0:s}, loss: {1:.6f}, time duration: {2:f},
-                            number of examples in current step: {3:.0f}, step {4:.0f}
+                        log_line = """timestamp: {0:s}, average loss: {1:.6f}, time duration: {2:f},
+                            number of examples in current reporting: {3:.0f}, step {4:.0f}
                             out of total {5:.0f}""".format(
                             endtime_string,
                             accum_loss / report_every,
                             end - start,
-                            list(inputs.values())[0].size()[0],
+                            #list(inputs.values())[0].size()[0],
+                            train_size,
                             global_step,
                             max_steps,
@@ -288,6 +290,7 @@ class Transformer:
                         logger.info(log_line)
                         print(log_line)
                         accum_loss = 0
+                        train_size = 0
                         start = end
                     if type(optimizer) == list:
                         for o in optimizer:
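
Net effect of the change: the periodic log line now reports the average loss over the reporting window rather than a single step's loss, and counts every example seen since the last report (train_size) instead of only the last batch's size. Below is a minimal, self-contained sketch of that pattern; the toy model, synthetic batches, and training harness are hypothetical stand-ins, and only report_every, gradient_accumulation_steps, accum_loss, train_size, and the log format mirror the diff.

import datetime
import time

import torch

# Hypothetical toy setup; the real code fine-tunes a transformer model.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

max_steps = 20
report_every = 5
gradient_accumulation_steps = 2

global_step = 0
accum_loss = 0.0  # loss summed since the last report
train_size = 0    # examples seen since the last report (what this commit adds)
start = time.time()

for step in range(max_steps * gradient_accumulation_steps):
    inputs = {"x": torch.randn(8, 4)}  # stand-in for a tokenized batch dict
    targets = torch.randn(8, 1)

    loss = loss_fn(model(inputs["x"]), targets) / gradient_accumulation_steps
    loss.backward()

    accum_loss += loss.item()
    # Batch size read off the first input tensor, as in the diff:
    train_size += list(inputs.values())[0].size()[0]

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

        if global_step % report_every == 0:
            end = time.time()
            endtime_string = datetime.datetime.fromtimestamp(end).strftime(
                "%d/%m/%Y %H:%M:%S"
            )
            print(
                "timestamp: {0:s}, average loss: {1:.6f}, time duration: {2:f}, "
                "number of examples in current reporting: {3:.0f}, step {4:.0f} "
                "out of total {5:.0f}".format(
                    endtime_string,
                    accum_loss / report_every,  # window average, not one step's loss
                    end - start,
                    train_size,  # all examples in the window, not just the last batch
                    global_step,
                    max_steps,
                )
            )
            accum_loss = 0.0  # reset the window, exactly as the commit does
            train_size = 0
            start = end

Resetting train_size together with accum_loss and start keeps all three quantities scoped to the same reporting window; before this change, the example count came from the current batch alone, which under-reported throughput whenever report_every covered more than one step.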