subword-nmt/subword_nmt/chrF.py

140 lines
4.4 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Rico Sennrich
"""Compute chrF3 for machine translation evaluation
Reference:
Maja Popović (2015). chrF: character n-gram F-score for automatic MT evaluation. In Proceedings of the Tenth Workshop on Statistical Machine Translation, pages 392–395, Lisbon, Portugal.
"""
from __future__ import print_function, unicode_literals, division
import sys
import codecs
import io
import argparse
from collections import defaultdict
# hack for python2/3 compatibility
from io import open
argparse.open = open
def create_parser():
    """Build the command-line argument parser for the chrF scorer.

    Options defined:
        --ref / -r    reference file (required).
        --hyp         hypothesis file; defaults to standard input.
        --beta / -b   beta parameter of the F-score (default 3).
        --ngram / -n  character n-gram order (default 6).
        --space / -s  take spaces into account.
        --precision   additionally report precision.
        --recall      additionally report recall.

    Returns:
        The configured ``argparse.ArgumentParser`` instance.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="compute chrF (character n-gram F-score) for machine translation evaluation")

    parser.add_argument(
        '--ref', '-r', type=argparse.FileType('r'), required=True,
        metavar='PATH',
        help="Reference file")
    parser.add_argument(
        '--hyp', type=argparse.FileType('r'), metavar='PATH',
        # sys.stdin is looked up when the parser is built, not at parse time.
        default=sys.stdin,
        help="Hypothesis file (default: stdin).")
    parser.add_argument(
        '--beta', '-b', type=float, default=3,
        metavar='FLOAT',
        help="beta parameter (default: '%(default)s')")
    parser.add_argument(
        '--ngram', '-n', type=int, default=6,
        metavar='INT',
        help="ngram order (default: '%(default)s')")
    parser.add_argument(
        '--space', '-s', action='store_true',
        help="take spaces into account (default: '%(default)s')")
    parser.add_argument(
        '--precision', action='store_true',
        help="report precision (default: '%(default)s')")
    parser.add_argument(
        '--recall', action='store_true',
        help="report recall (default: '%(default)s')")

    return parser
def extract_ngrams(words, max_length=4, spaces=False):
    """Count the character n-grams of a line for every order up to max_length.

    Args:
        words: the input line (a string).
        max_length: highest n-gram order to extract.
        spaces: if False, all whitespace is removed before counting;
            if True, only leading/trailing whitespace is stripped.

    Returns:
        A defaultdict mapping ``order - 1`` (0 for unigrams, 1 for bigrams,
        ...) to a defaultdict(int) from n-gram (a tuple of characters) to
        its number of occurrences. Orders longer than the text get no entry.
    """
    text = words.strip() if spaces else ''.join(words.split())
    text_len = len(text)

    counts = defaultdict(lambda: defaultdict(int))
    for order in range(max_length):
        width = order + 1
        # The range is empty once the window is wider than the text, so no
        # entry is created for orders that cannot occur.
        for begin in range(text_len - width + 1):
            counts[order][tuple(text[begin:begin + width])] += 1
    return counts
def get_correct(ngrams_ref, ngrams_test, correct, total):
    """Accumulate matched and total hypothesis n-gram counts per order.

    Args:
        ngrams_ref: reference n-gram counts, indexed by order then n-gram.
        ngrams_test: hypothesis n-gram counts, same layout.
        correct: list updated in place; for each order, the clipped number
            of hypothesis n-grams also found in the reference is added.
        total: list updated in place; for each order, the number of
            hypothesis n-grams is added.

    Returns:
        The same ``correct`` and ``total`` lists, as a tuple.
    """
    for order, hyp_counts in ngrams_test.items():
        for gram, hyp_count in hyp_counts.items():
            total[order] += hyp_count
            ref_counts = ngrams_ref[order]
            if gram in ref_counts:
                # Clip: a hypothesis n-gram cannot match more often than it
                # occurs in the reference.
                correct[order] += min(hyp_count, ref_counts[gram])
    return correct, total
def f1(correct, total_hyp, total_ref, max_length, beta=3, smooth=0):
    """Compute the F-beta score from per-order n-gram counts.

    Precision and recall are computed for each n-gram order and averaged
    over ``max_length`` orders; orders whose (smoothed) hypothesis or
    reference count is zero contribute nothing to the sums.

    Args:
        correct: per-order count of matched n-grams.
        total_hyp: per-order count of hypothesis n-grams.
        total_ref: per-order count of reference n-grams.
        max_length: number of n-gram orders to average over (must be > 0).
        beta: weight of recall relative to precision.
        smooth: constant added to every count.

    Returns:
        A tuple ``(score, precision, recall)``. The score is 0.0 when
        precision and recall are both zero.
    """
    precision = 0.0
    recall = 0.0

    for i in range(max_length):
        hyp_count = total_hyp[i] + smooth
        ref_count = total_ref[i] + smooth
        if hyp_count and ref_count:
            matched = correct[i] + smooth
            precision += matched / hyp_count
            recall += matched / ref_count

    precision /= max_length
    recall /= max_length

    denominator = (beta**2 * precision) + recall
    if not denominator:
        # No n-gram matched at all (or the input was empty): the F-score is
        # defined as 0 here instead of dividing by zero.
        return 0.0, precision, recall

    return (1 + beta**2) * (precision * recall) / denominator, precision, recall
def main(args):
    """Score a hypothesis against a reference and print the result.

    Reads ``args.ref`` line by line, pairing each reference line with the
    next line of ``args.hyp``, accumulates n-gram statistics over the whole
    input, and prints the score to stdout (plus precision and/or recall when
    ``args.precision`` / ``args.recall`` are set).

    If the hypothesis runs out of lines before the reference does, the
    missing lines are scored as empty and a warning is written to stderr.
    Hypothesis lines beyond the end of the reference are not read.

    Args:
        args: parsed command-line namespace providing ``ref``, ``hyp``,
            ``ngram``, ``space``, ``beta``, ``precision`` and ``recall``.
    """
    correct = [0] * args.ngram
    total = [0] * args.ngram
    total_ref = [0] * args.ngram
    missing_hyp_lines = 0

    for line in args.ref:
        line2 = args.hyp.readline()
        if not line2:
            # readline() returns '' only at end of file; a blank line is
            # returned as '\n'.
            missing_hyp_lines += 1

        ngrams_ref = extract_ngrams(line, max_length=args.ngram, spaces=args.space)
        ngrams_test = extract_ngrams(line2, max_length=args.ngram, spaces=args.space)

        get_correct(ngrams_ref, ngrams_test, correct, total)

        for rank in ngrams_ref:
            for chain in ngrams_ref[rank]:
                total_ref[rank] += ngrams_ref[rank][chain]

    if missing_hyp_lines:
        print(
            'warning: hypothesis has {0} fewer line(s) than the reference; '
            'missing lines were scored as empty'.format(missing_hyp_lines),
            file=sys.stderr)

    chrf, precision, recall = f1(correct, total, total_ref, args.ngram, args.beta)

    # The label is fixed to 'chrF3' regardless of the beta actually used.
    print('chrF3: {0:.4f}'.format(chrf))
    if args.precision:
        print('chrPrec: {0:.4f}'.format(precision))
    if args.recall:
        print('chrRec: {0:.4f}'.format(recall))
if __name__ == '__main__':

    # Force UTF-8 on the standard streams under both Python 2 and Python 3.
    if sys.version_info >= (3, 0):
        # Python 3: re-wrap the underlying binary buffers as UTF-8 text.
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True)
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')
        sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
    else:
        # Python 2: wrap the byte streams with UTF-8 codecs.
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)

    # Build the parser only after the streams are re-wrapped: the --hyp
    # default picks up sys.stdin at parser construction time.
    parser = create_parser()
    args = parser.parse_args()

    main(args)