зеркало из https://github.com/mozilla/subword-nmt.git
new option --total-symbols in learn-bpe
redefines "--symbols" to be the number of merge operations, minus the character vocabulary size, so that "--symbols" becomes an estimate of the final symbol vocabulary size. thx @phikoehn
This commit is contained in:
Родитель
71b22d1a99
Коммит
61ad855cf0
|
@ -55,6 +55,9 @@ def create_parser(subparsers=None):
|
|||
help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))')
|
||||
parser.add_argument('--dict-input', action="store_true",
|
||||
help="If set, input file is interpreted as a dictionary where each line contains a word-count pair")
|
||||
parser.add_argument(
|
||||
'--total-symbols', '-t', action="store_true",
|
||||
help="subtract number of characters from the symbols to be generated (so that '--symbols' becomes an estimate for the total number of symbols needed to encode text).")
|
||||
parser.add_argument(
|
||||
'--verbose', '-v', action="store_true",
|
||||
help="verbose mode.")
|
||||
|
@ -197,7 +200,7 @@ def prune_stats(stats, big_stats, threshold):
|
|||
big_stats[item] = freq
|
||||
|
||||
|
||||
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False):
|
||||
def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_dict=False, total_symbols=False):
|
||||
"""Learn num_symbols BPE operations from vocabulary, and write to outfile.
|
||||
"""
|
||||
|
||||
|
@ -211,6 +214,19 @@ def learn_bpe(infile, outfile, num_symbols, min_frequency=2, verbose=False, is_d
|
|||
|
||||
stats, indices = get_pair_statistics(sorted_vocab)
|
||||
big_stats = copy.deepcopy(stats)
|
||||
|
||||
if total_symbols:
|
||||
uniq_char_internal = set()
|
||||
uniq_char_final = set()
|
||||
for word in vocab:
|
||||
for char in word[:-1]:
|
||||
uniq_char_internal.add(char)
|
||||
uniq_char_final.add(char[-1])
|
||||
sys.stderr.write('Number of word-internal characters: {0}\n'.format(len(uniq_char_internal)))
|
||||
sys.stderr.write('Number of word-final characters: {0}\n'.format(len(uniq_char_final)))
|
||||
sys.stderr.write('Reducing number of merge operations by {0}\n'.format(len(uniq_char_internal) + len(uniq_char_final)))
|
||||
num_symbols -= len(uniq_char_internal) + len(uniq_char_final)
|
||||
|
||||
# threshold is inspired by Zipfian assumption, but should only affect speed
|
||||
threshold = max(stats.values()) / 10
|
||||
for i in range(num_symbols):
|
||||
|
@ -270,4 +286,4 @@ if __name__ == '__main__':
|
|||
if args.output.name != '<stdout>':
|
||||
args.output = codecs.open(args.output.name, 'w', encoding='utf-8')
|
||||
|
||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input)
|
||||
learn_bpe(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input, total_symbols=args.total_symbols)
|
||||
|
|
Загрузка…
Ссылка в новой задаче