28 строки
796 B
Python
28 строки
796 B
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import unicode_literals
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
|
|
def compute_metrics(x):
    """Compute recall@K and median rank from a query/candidate score matrix.

    Row ``i`` of ``x`` holds the scores of query ``i`` against every
    candidate, and the correct candidate for query ``i`` is column ``i``
    (the diagonal). Higher scores rank higher.

    Ties are broken optimistically: a correct candidate that ties with other
    candidates is placed at the first position among them, so every query
    contributes exactly one rank.

    Args:
        x: 2-D array-like of shape ``(n_queries, n_candidates)``; must have
            ``n_queries <= n_candidates`` so every query has a diagonal entry.

    Returns:
        dict with keys ``'R1'``, ``'R5'``, ``'R10'`` (fraction of queries whose
        correct candidate is ranked within the top 1, 5 and 10) and ``'MR'``
        (median 1-based rank of the correct candidate).

    Raises:
        ZeroDivisionError: if ``x`` has no rows.
    """
    x = np.asarray(x)
    diag = np.diag(x)
    # Zero-based rank = number of candidates scoring strictly higher than the
    # correct one. This equals the position of the first occurrence of the
    # correct score in a descending sort, yields one rank per query even when
    # scores tie, and needs no sort.
    ranks = np.sum(x > diag[:, np.newaxis], axis=1)
    # NaN never compares greater, so a NaN correct score would otherwise look
    # like a perfect hit; rank it last instead.
    ranks = np.where(np.isnan(diag), x.shape[1] - 1, ranks)
    n_queries = len(ranks)

    metrics = {}
    for k in (1, 5, 10):
        metrics['R{}'.format(k)] = float(np.sum(ranks < k)) / n_queries
    metrics['MR'] = np.median(ranks) + 1
    return metrics
|
|
|
|
def print_computed_metrics(metrics):
    """Print recall@K scores and the median rank on a single line.

    Args:
        metrics: mapping with keys 'R1', 'R5', 'R10' (formatted to four
            decimal places) and 'MR' (printed with its default formatting).

    Raises:
        KeyError: if any of the four keys is missing.
    """
    template = 'R@1: {:.4f} - R@5: {:.4f} - R@10: {:.4f} - Median R: {}'
    # Look the keys up in the same order they appear in the output line.
    values = [metrics[key] for key in ('R1', 'R5', 'R10', 'MR')]
    print(template.format(*values))
|