new option --total-symbols in learn-bpe

redefines "--symbols" to be the number of merge operations,
minus the character vocabulary size, so that "--symbols" becomes
an estimate of the final symbol vocabulary size.

thx @phikoehn
Rico Sennrich 2018-06-28 11:41:08 +01:00
Parent 71b22d1a99
Commit 61ad855cf0
1 changed file with 18 additions and 2 deletions
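
The arithmetic behind the new flag, as a minimal illustrative sketch (not part of the commit; the toy vocabulary and numbers are made up). It assumes the vocabulary representation used inside learn_bpe, i.e. words as tuples of characters with '</w>' appended to the final character:

# Toy vocabulary in learn_bpe's internal format: word tuple -> frequency.
vocab = {('l', 'o', 'w</w>'): 5, ('l', 'o', 'w', 'e', 'r</w>'): 2}

uniq_char_internal = set()
uniq_char_final = set()
for word in vocab:
    for char in word[:-1]:
        uniq_char_internal.add(char)
    uniq_char_final.add(word[-1])

num_symbols = 100  # value passed via --symbols
# With --total-symbols, only this many merge operations are performed:
num_merges = num_symbols - len(uniq_char_internal) - len(uniq_char_final)
print(len(uniq_char_internal), len(uniq_char_final), num_merges)  # 4 2 94

In other words, '--symbols 100 --total-symbols' performs 94 merges here, so that the learned merges plus the 6 distinct characters add up to roughly 100 symbols.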


@@ -55,6 +55,9 @@ def create_parser(subparsers=None):
         help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
     parser.add_argument('--dict-input', action="store_true",
         help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
+    parser.add_argument(
+        '--total-symbols', '-t', action="store_true",
+        help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
     parser.add_argument(
         '--verbose', '-v', action="store_true",
         help="verbose mode.")
@@ -197,7 +200,7 @@ def prune_stats(stats, big_stats, threshold):
                 big_stats[item] = freq


-def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False):
+def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
     """Learn num_symbols BPE operations from vocabulary, and write to outfile.
     """
@@ -211,6 +214,19 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
     stats, indices = get_pair_statistics(sorted_vocab)
     big_stats = copy.deepcopy(stats)

+    if total_symbols:
+        uniq_char_internal = set()
+        uniq_char_final = set()
+        for word in vocab:
+            for char in word[:-1]:
+                uniq_char_internal.add(char)
+            uniq_char_final.add(word[-1])
+        sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal)))
+        sys.stderr.write('Number of word-final characters: {0}\n'.format(len(uniq_char_final)))
+        sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final)))
+        num_symbols -= len(uniq_char_internal) + len(uniq_char_final)
+
     # threshold is inspired by Zipfian assumption, but should only affect speed
     threshold = max(stats.values()) / 10
     for i in range(num_symbols):
@@ -270,4 +286,4 @@ if __name__ == '__main__':
     if args.output.name != '<stdout>':
         args.output = codecs.open(args.output.name, 'w', encoding='utf-8')

-    learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input)
+    learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)
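
A usage sketch for the updated function (illustrative only; it assumes subword-nmt is installed with the subword_nmt package layout, and 'train.txt'/'codes.bpe' are placeholder paths):

import codecs
from subword_nmt.learn_bpe import learn_bpe

with codecs.open('train.txt', encoding='utf-8') as infile, \
     codecs.open('codes.bpe', 'w', encoding='utf-8') as outfile:
    # With total_symbols=True, 10000 is treated as the target size of the final
    # symbol vocabulary (characters + merges), not as the number of merge operations.
    learn_bpe(infile, outfile, 10000, min_frequency=2, verbose=False,
              is_dict=False, total_symbols=True)

On the command line, the equivalent is to pass '--total-symbols' (or '-t') alongside '--symbols 10000' to learn-bpe.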