зеркало из https://github.com/mozilla/subword-nmt.git
140 строки
4.4 KiB
Python
Executable File
140 строки
4.4 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
# Author: Rico Sennrich
|
|
|
|
"""Compute chrF3 for machine translation evaluation
|
|
|
|
Reference:
|
|
Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translation, pages 392–395, Lisbon, Portugal.
|
|
"""
|
|
|
|
from __future__ import print_function, unicode_literals, division
|
|
|
|
import sys
|
|
import codecs
|
|
import io
|
|
import argparse
|
|
|
|
from collections import defaultdict
|
|
|
|
# hack for python2/3 compatibility:
# rebind this module's `open` to io.open and install it as argparse's `open`.
# Presumably intended so that argparse.FileType (which resolves `open` from the
# argparse module namespace) opens files via io.open, yielding text objects on
# both Python 2 and Python 3.
from io import open
argparse.open = open
|
|
|
|
def create_parser():
    """Build the command-line parser for the chrF scorer.

    Returns:
        argparse.ArgumentParser accepting ``--ref`` (required), ``--hyp``
        (defaults to stdin), ``--beta``, ``--ngram``, ``--space``,
        ``--precision`` and ``--recall``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # The description used to read "learn BPE-based word segmentation"
        # (copied from learn_bpe.py), which misdescribed this tool.
        description="compute chrF (character n-gram F-score) for machine translation evaluation")

    parser.add_argument(
        '--ref', '-r', type=argparse.FileType('r'), required=True,
        metavar='PATH',
        help="Reference file")
    parser.add_argument(
        '--hyp', type=argparse.FileType('r'), metavar='PATH',
        # Evaluated when create_parser() is called, i.e. after any
        # re-wrapping of sys.stdin done by the caller.
        default=sys.stdin,
        help="Hypothesis file (default: stdin).")
    parser.add_argument(
        '--beta', '-b', type=float, default=3,
        metavar='FLOAT',
        help="beta parameter (default: '%(default)s')")
    parser.add_argument(
        '--ngram', '-n', type=int, default=6,
        metavar='INT',
        help="ngram order (default: '%(default)s')")
    parser.add_argument(
        '--space', '-s', action='store_true',
        help="take spaces into account (default: '%(default)s')")
    parser.add_argument(
        '--precision', action='store_true',
        help="report precision (default: '%(default)s')")
    parser.add_argument(
        '--recall', action='store_true',
        help="report recall (default: '%(default)s')")

    return parser
|
|
|
|
def extract_ngrams(words, max_length=4, spaces=False):
    """Count the character n-grams of one sentence, grouped by n-gram order.

    Args:
        words: the sentence as a string.
        max_length: highest n-gram order to extract. Orders 1..max_length are
            stored under keys 0..max_length-1.
        spaces: if False, every whitespace character is removed before
            counting; if True, only leading/trailing whitespace is stripped and
            inner whitespace counts as characters.

    Returns:
        A nested defaultdict: order index -> {tuple of characters: count}.
        An order index is present only if at least one n-gram of that order
        was found.
    """
    text = words.strip() if spaces else ''.join(words.split())

    counts = defaultdict(lambda: defaultdict(int))
    for order in range(max_length):
        width = order + 1
        # Valid start offsets are those whose window ends within the text.
        for begin in range(len(text) - order):
            counts[order][tuple(text[begin:begin + width])] += 1
    return counts
|
|
|
|
|
|
def get_correct(ngrams_ref, ngrams_test, correct, total):
    """Accumulate clipped n-gram matches and hypothesis n-gram totals.

    For every order present in ``ngrams_test``, adds the hypothesis n-gram
    counts to ``total[order]``. It also adds the matches to
    ``correct[order]``, clipping each n-gram's count to its count in
    ``ngrams_ref``.

    ``correct`` and ``total`` are updated in place and also returned.
    """
    for order in ngrams_test:
        hyp_counts = ngrams_test[order]
        for gram in hyp_counts:
            n_hyp = hyp_counts[gram]
            total[order] += n_hyp
            # Looked up per n-gram; if ngrams_ref is a defaultdict (as built by
            # extract_ngrams) this inserts an empty entry for a missing order.
            ref_counts = ngrams_ref[order]
            if gram in ref_counts:
                correct[order] += min(n_hyp, ref_counts[gram])
    return correct, total
|
|
|
|
|
|
def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0):
    """Compute the character n-gram F-beta score with its precision and recall.

    Precision and recall are computed per n-gram order and then averaged over
    ``max_length`` orders. An order whose hypothesis or reference total (plus
    ``smooth``) is zero contributes 0 to both averages.

    Args:
        correct: per-order counts of matching n-grams.
        total_hyp: per-order counts of hypothesis n-grams.
        total_ref: per-order counts of reference n-grams.
        max_length: number of n-gram orders to average over.
        beta: weight of recall relative to precision (3 gives chrF3).
        smooth: additive smoothing applied to numerators and denominators.

    Returns:
        Tuple ``(fscore, precision, recall)``. ``fscore`` is 0.0 when the
        F-score denominator is zero (e.g. no character n-gram matches).
    """
    precision = 0
    recall = 0

    for i in range(max_length):
        if total_hyp[i] + smooth and total_ref[i] + smooth:
            precision += (correct[i] + smooth) / (total_hyp[i] + smooth)
            recall += (correct[i] + smooth) / (total_ref[i] + smooth)

    precision /= max_length
    recall /= max_length

    denominator = (beta**2 * precision) + recall
    if not denominator:
        # Nothing matched (or the hypothesis was empty): this used to raise
        # ZeroDivisionError; a zero F-score is the conventional result.
        return 0.0, precision, recall

    return (1 + beta**2) * (precision * recall) / denominator, precision, recall
|
|
|
|
def main(args):
    """Score a hypothesis against a reference with chrF and print the result.

    Reads ``args.ref`` and ``args.hyp`` line by line in parallel and
    accumulates character n-gram statistics over all lines. It then prints
    the corpus-level F-score, plus precision and/or recall when
    ``args.precision`` / ``args.recall`` are set.

    Args:
        args: namespace produced by ``create_parser().parse_args()``.
    """
    correct = [0] * args.ngram
    total = [0] * args.ngram
    total_ref = [0] * args.ngram
    hyp_exhausted = False

    for line in args.ref:
        line2 = args.hyp.readline()
        if not line2:
            # readline() returns '' only at end of file: the hypothesis has
            # fewer lines than the reference, so the rest is scored as empty.
            hyp_exhausted = True

        ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space)
        ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space)

        get_correct(ngrams_ref, ngrams_test, correct, total)

        for rank in ngrams_ref:
            for chain in ngrams_ref[rank]:
                total_ref[rank] += ngrams_ref[rank][chain]

    # Misaligned inputs silently produce a misleading score, so report it.
    if hyp_exhausted or args.hyp.readline():
        sys.stderr.write('Warning: reference and hypothesis have a different number of lines\n')

    chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta)

    # Label the score with the beta actually used ('g' prints 3 and 3.0 as
    # "3"), so the default output is still "chrF3: ...".
    print('chrF{0:g}: {1:.4f}'.format(args.beta, chrf))
    if args.precision:
        print('chrPrec: {0:.4f}'.format(precision))
    if args.recall:
        print('chrRec: {0:.4f}'.format(recall))
|
|
|
|
if __name__ == '__main__':

    # python 2/3 compatibility: force UTF-8 on the standard streams
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)

    parser = create_parser()
    args = parser.parse_args()

    # argparse.FileType('r') opens named files with the locale's default
    # encoding. Reopen them as UTF-8 to match the stdin handling above.
    # io.open is used instead of codecs.open because codecs readers also split
    # lines on Unicode separators (e.g. U+2028), which could misalign the
    # reference and hypothesis.
    if args.ref.name != '<stdin>':
        ref_path = args.ref.name
        args.ref.close()
        args.ref = io.open(ref_path, encoding='utf-8')
    if args.hyp.name != '<stdin>':
        hyp_path = args.hyp.name
        args.hyp.close()
        args.hyp = io.open(hyp_path, encoding='utf-8')

    main(args)
|