update
This commit is contained in:
Родитель
1f3f589ba3
Коммит
962eb07554
|
@ -1,3 +1,2 @@
|
||||||
.git
|
.git
|
||||||
.idea
|
.idea
|
||||||
READMEINS.md
|
|
|
@ -439,10 +439,10 @@ def eval_epoch(args, model, test_dataloader, device, n_gpu):
|
||||||
sim_matrix = []
|
sim_matrix = []
|
||||||
for idx in range(len(parallel_outputs)):
|
for idx in range(len(parallel_outputs)):
|
||||||
sim_matrix += parallel_outputs[idx]
|
sim_matrix += parallel_outputs[idx]
|
||||||
|
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
||||||
else:
|
else:
|
||||||
sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
|
sim_matrix = _run_on_single_gpu(model, batch_list, batch_list, batch_sequence_output_list, batch_visual_output_list)
|
||||||
|
|
||||||
sim_matrix = np.concatenate(tuple(sim_matrix), axis=0)
|
|
||||||
metrics = compute_metrics(sim_matrix)
|
metrics = compute_metrics(sim_matrix)
|
||||||
logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))
|
logger.info('\t Length-T: {}, Length-V:{}'.format(len(sim_matrix), len(sim_matrix[0])))
|
||||||
logger.info('\t>>> R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.
|
logger.info('\t>>> R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'.
|
||||||
|
|
Загрузка…
Ссылка в новой задаче