This commit is contained in:
Sergii Dymchenko 2019-11-20 21:54:17 -08:00
Родитель c4e92dfe1f
Коммит b92d5dedb6
1 изменённых файлов: 6 добавлений и 0 удалений

Просмотреть файл

@ -1,4 +1,6 @@
import argparse
import datetime
import sys
import torch
@ -22,6 +24,8 @@ model = Net()
model.to(device)
model.load_state_dict(torch.load(args.model_state_dict_path))
model.eval()
start_time = datetime.datetime.now()
with torch.no_grad():
for i_batch, sample_batched in enumerate(test_loader):
features = sample_batched["x"].to(device=device)
@ -29,3 +33,5 @@ with torch.no_grad():
indices = outputs.argmax(1)
letters = [chr(ord('a') + i) for i in indices]
print(''.join(letters))
end_time = datetime.datetime.now()
sys.stderr.write("Solving time (sec): %f\n" % (end_time - start_time).total_seconds())