UniVL/metrics.py

28 lines
796 B
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
import numpy as np
def compute_metrics(x):
    """Compute recall@K and median rank from a score matrix.

    Each row of ``x`` is sorted by descending score, and the position of the
    row's diagonal entry ``x[i, i]`` in that ordering is taken as the row's
    0-based rank. When other entries in the row tie with the diagonal score,
    the best (lowest) tied position is used, so every row contributes exactly
    one rank.

    Args:
        x: 2-D array-like of scores; the diagonal holds each row's target
            score (presumably a query-vs-candidate similarity matrix --
            confirm against the caller).

    Returns:
        dict with keys ``'R1'``, ``'R5'``, ``'R10'`` (fraction of rows whose
        rank is 0, below 5, below 10) and ``'MR'`` (median rank, 1-based).

    Raises:
        ZeroDivisionError: if no row yields a rank (e.g. empty input).
    """
    x = np.asarray(x)

    # Negating turns the ascending sort into a descending-score ordering.
    sorted_neg = np.sort(-x, axis=1)
    diag_neg = np.diag(-x)[:, np.newaxis]

    # Positions where the sorted row equals the row's diagonal score.
    rows, cols = np.where(sorted_neg - diag_neg == 0)

    # np.where scans row by row with columns ascending, so the first hit for
    # each row is its best rank. Keeping only that hit stops tied scores from
    # adding extra entries, which would inflate the denominator below.
    _, first_hit = np.unique(rows, return_index=True)
    ind = cols[first_hit]

    total = len(ind)
    metrics = {}
    metrics['R1'] = float(np.sum(ind == 0)) / total
    metrics['R5'] = float(np.sum(ind < 5)) / total
    metrics['R10'] = float(np.sum(ind < 10)) / total
    metrics['MR'] = np.median(ind) + 1
    return metrics
def print_computed_metrics(metrics):
    """Print the values under 'R1', 'R5', 'R10' and 'MR' on one line.

    Args:
        metrics: mapping that provides the keys ``'R1'``, ``'R5'``,
            ``'R10'`` and ``'MR'``. The first three are formatted with four
            decimal places; ``'MR'`` is printed with default formatting.
    """
    template = 'R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'
    # Look the keys up in display order so the values line up with the template.
    values = [metrics[key] for key in ('R1', 'R5', 'R10', 'MR')]
    print(template.format(*values))