Better key figure handling and additional score

This commit is contained in:
Tilman Kamp 2019-08-02 15:03:01 +02:00
Parent 31da832843
Commit c6efe188a2
2 changed files with 89 additions and 83 deletions

View file

@@ -9,7 +9,7 @@ import numpy as np
import wavTranscriber
from collections import Counter
from search import FuzzySearch
from text import Alphabet, TextCleaner, levenshtein
from text import Alphabet, TextCleaner, levenshtein, similarity
def main(args):
@@ -86,24 +86,27 @@ def main(args):
help='Writes clean aligned original transcripts to result file')
output_group.add_argument('--output-aligned-raw', action="store_true",
help='Writes raw aligned original transcripts to result file')
output_group.add_argument('--output-wer', action="store_true",
help='Writes word error rate (WER) to output')
output_group.add_argument('--output-cer', action="store_true",
help='Writes character error rate (CER) to output')
output_group.add_argument('--output-min-length', type=int, required=False, default=1,
help='Minimum phrase length (default: 1)')
output_group.add_argument('--output-max-length', type=int, required=False,
help='Maximum phrase length (default: no limit)')
output_group.add_argument('--output-min-score', type=float, required=False, default=2.0,
help='Minimum matching score (default: 2.0)')
output_group.add_argument('--output-max-score', type=float, required=False,
help='Maximum matching score (default: no limit)')
for b in ['Min', 'Max']:
for r in ['CER', 'WER']:
output_group.add_argument('--output-' + b.lower() + '-' + r.lower(), type=float, required=False,
help=b + 'imum ' + ('character' if r == 'CER' else 'word') +
' error rate (' + r + ') the STT transcript of the audio ' +
'has to have when compared with the original text')
named_numbers = {
'tlen': ('transcript length', int, None),
'mlen': ('match length', int, None),
'SWS': ('Smith-Waterman score', float, 'From 0.0 (not equal at all) to 100.0+ (pretty equal)'),
'WNG': ('weighted N-gram intersection', float, 'From 0.0 (not equal at all) to 100.0 (totally equal)'),
'CER': ('character error rate', float, 'From 0.0 (no wrong characters) to 100.0+ (total miss)'),
'WER': ('word error rate', float, 'From 0.0 (no different words) to 100.0+ (total miss)')
}
for short in named_numbers.keys():
long, atype, desc = named_numbers[short]
desc = (' - value range: ' + desc) if desc else ''
output_group.add_argument('--output-' + short.lower(), action="store_true",
help='Writes {} ({}) to output'.format(long, short))
for extreme in ['Min', 'Max']:
output_group.add_argument('--output-' + extreme.lower() + '-' + short.lower(), type=atype, required=False,
help='{}imum {} ({}) the STT transcript of the audio '
'has to have when compared with the original text{}'
.format(extreme, long, short, desc))
args = parser.parse_args()
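Editor's note (not part of the diff): for a single key such as 'SWS', the generator loop above registers one boolean output flag plus a min/max filter pair, roughly equivalent to the hand-written calls it replaces. A sketch of that expansion, using the help texts from the named_numbers entry:

# Illustration only: what the loop effectively adds for the 'SWS' entry.
output_group.add_argument('--output-sws', action="store_true",
                          help='Writes Smith-Waterman score (SWS) to output')
output_group.add_argument('--output-min-sws', type=float, required=False,
                          help='Minimum Smith-Waterman score (SWS) the STT transcript of the audio '
                               'has to have when compared with the original text'
                               ' - value range: From 0.0 (not equal at all) to 100.0+ (pretty equal)')
output_group.add_argument('--output-max-sws', type=float, required=False,
                          help='Maximum Smith-Waterman score (SWS) the STT transcript of the audio '
                               'has to have when compared with the original text'
                               ' - value range: From 0.0 (not equal at all) to 100.0+ (pretty equal)')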
@@ -229,81 +232,84 @@ def main(args):
result_fragments = []
substitutions = Counter()
statistics = Counter()
def skip(message, index):
logging.info('Fragment %d: %s' % (index, message))
statistics[message] += 1
end_fragments = (args.start + args.num_samples) if args.num_samples else len(fragments)
fragments = fragments[args.start:end_fragments]
def skip(index, reason):
logging.info('Fragment {}: {}'.format(index, reason))
statistics[reason] += 1
def should_skip(number_key, index, show, get_value):
kl = number_key.lower()
should_output = getattr(args, 'output_' + kl)
min_val, max_val = getattr(args, 'output_min_' + kl), getattr(args, 'output_max_' + kl)
if should_output or min_val or max_val:
val = get_value()
if len(number_key) == 3:
show.insert(0, '{}: {:.2f}'.format(number_key, val))
if should_output:
fragments[index][kl] = val
reason_base = '{} ({})'.format(named_numbers[number_key][0], number_key)
reason = None
if min_val and val < min_val:
reason = reason_base + ' too low'
elif max_val and val > max_val:
reason = reason_base + ' too high'
if reason:
skip(index, reason)
return True
return False
for index, fragment in enumerate(fragments):
time_start = fragment['time-start']
time_length = fragment['time-length']
fragment_transcript = fragment['transcript']
if args.align_min_length and len(fragment_transcript) < args.align_min_length:
skip('Transcript too short for alignment', index)
continue
if args.align_max_length and len(fragment_transcript) > args.align_max_length:
skip('Transcript too long for alignment', index)
continue
match, match_distance, match_substitutions = search.find_best(fragment_transcript)
if match is None:
skip('No match for transcript', index)
continue
substitutions += match_substitutions
fragment_matched = tc.clean_text[match.start:match.end]
if args.output_min_length and len(fragment_matched) < args.output_min_length:
skip('Match too short', index)
continue
if args.output_max_length and len(fragment_matched) > args.output_max_length:
skip('Match too long', index)
continue
score = match_distance/max(len(fragment_matched), len(fragment_transcript))
sample_numbers = ['Score %.2f' % score]
if args.output_min_score and score < args.output_min_score:
skip('Matching score too low', index)
continue
if args.output_max_score and score > args.output_max_score:
skip('Matching score too high', index)
continue
original_start = tc.get_original_offset(match.start)
original_end = tc.get_original_offset(match.end)
result_fragment = {
'time-start': time_start,
'time-length': time_length,
'text-start': original_start,
'text-length': original_end-original_start,
'score': score
'time-length': time_length
}
if args.output_cer or args.output_min_cer or args.output_max_cer:
cer = levenshtein(fragment_transcript, fragment_matched)/len(fragment_matched)
sample_numbers.insert(0, 'CER: %.2f' % cer * 100)
if args.output_cer:
result_fragment['cer'] = cer
if args.output_min_cer and score < args.output_min_cer:
skip('Character error rate (CER) too low', index)
continue
if args.output_max_cer and score > args.output_max_cer:
skip('Character error rate (CER) too high', index)
continue
if args.output_wer or args.output_min_wer or args.output_max_wer:
wer = levenshtein(fragment_transcript.split(), fragment_matched.split())/len(fragment_matched.split())
sample_numbers.insert(0, 'WER: %.2f' % wer * 100)
if args.output_wer:
result_fragment['wer'] = wer
if args.output_min_wer and score < args.output_min_wer:
skip('Word error rate (WER) too low', index)
continue
if args.output_max_wer and score > args.output_max_wer:
skip('Word error rate (WER) too high', index)
sample_numbers = []
if should_skip('tlen', index, sample_numbers, lambda: len(fragment_transcript)):
continue
if args.output_stt:
result_fragment['stt'] = fragment_transcript
if args.output_aligned:
result_fragment['aligned'] = fragment_matched
match, match_distance, match_substitutions = search.find_best(fragment_transcript)
if match is None:
skip(index, 'No match for transcript')
continue
original_start = tc.get_original_offset(match.start)
original_end = tc.get_original_offset(match.end)
result_fragment['text-start'] = original_start
result_fragment['text-length'] = original_end - original_start
if args.output_aligned_raw:
result_fragment['aligned-raw'] = original_transcript[original_start:original_end]
substitutions += match_substitutions
fragment_matched = tc.clean_text[match.start:match.end]
if should_skip('mlen', index, sample_numbers, lambda: len(fragment_matched)):
continue
if args.output_aligned:
result_fragment['aligned'] = fragment_matched
if should_skip('SWS', index, sample_numbers,
lambda: match_distance / max(len(fragment_matched), len(fragment_transcript))):
continue
if should_skip('WNG', index, sample_numbers,
lambda: 100 * similarity(fragment_matched, fragment_transcript)):
continue
if should_skip('CER', index, sample_numbers,
lambda: 100 * levenshtein(fragment_transcript, fragment_matched) / len(fragment_matched)):
continue
if should_skip('WER', index, sample_numbers,
lambda: 100 * levenshtein(fragment_transcript.split(),
fragment_matched.split())/len(fragment_matched.split())):
continue
result_fragments.append(result_fragment)
logging.debug('Fragment %d aligned with %s' % (index, ' '.join(sample_numbers)))
logging.debug('- T: ' + args.text_context * ' ' + '"%s"' % fragment_transcript)
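Editor's note (not part of the diff): a small sketch of how the per-fragment key figures above are computed, mirroring the lambdas passed to should_skip(). It assumes, as the import change suggests, that levenshtein(a, b) returns an edit distance and similarity(a, b) returns a value in [0.0, 1.0]; the sample strings are made up.

# Sketch only; mirrors the expressions in the hunk above.
fragment_transcript = 'the quick brown fox'   # STT output for the audio fragment
fragment_matched = 'the quick brown box'      # best match found in the original text
cer = 100 * levenshtein(fragment_transcript, fragment_matched) / len(fragment_matched)              # -> ~5.26
wer = 100 * levenshtein(fragment_transcript.split(),
                        fragment_matched.split()) / len(fragment_matched.split())                   # -> 25.0
wng = 100 * similarity(fragment_matched, fragment_transcript)
# SWS additionally needs the Smith-Waterman distance returned by search.find_best():
# sws = match_distance / max(len(fragment_matched), len(fragment_transcript))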

View file

@@ -12,7 +12,7 @@ class FuzzySearch(object):
match_score=100,
mismatch_score=-100,
gap_score=-100,
letter_similarities=None):
char_similarities=None):
self.text = text
self.max_candidates = max_candidates
self.candidate_threshold = candidate_threshold
@@ -21,7 +21,7 @@ class FuzzySearch(object):
self.match_score = match_score
self.mismatch_score = mismatch_score
self.gap_score = gap_score
self.letter_similarities = letter_similarities
self.char_similarities = char_similarities
self.ngrams = {}
for i, ngram in enumerate(ngrams(' ' + text + ' ', 3)):
if ngram in self.ngrams:
@@ -38,8 +38,8 @@ class FuzzySearch(object):
def char_similarity(self, a, b):
key = FuzzySearch.char_pair(a, b)
if self.letter_similarities and key in self.letter_similarities:
return self.letter_similarities[key]
if self.char_similarities and key in self.char_similarities:
return self.char_similarities[key]
return self.match_score if a == b else self.mismatch_score
def sw_align(self, a, start, end):
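Editor's note (not part of the diff): the renamed char_similarities parameter lets callers override the flat match/mismatch scoring for specific character pairs. A hypothetical usage sketch; the key layout must follow whatever FuzzySearch.char_pair(a, b) returns, which this diff does not show, and clean_text stands in for the text being searched.

# Illustrative only: score accented/unaccented pairs as near-matches (default
# match_score is 100, mismatch_score is -100 in the constructor above).
similar_pairs = {
    FuzzySearch.char_pair('a', 'á'): 80,
    FuzzySearch.char_pair('o', 'ó'): 80,
}
search = FuzzySearch(clean_text, char_similarities=similar_pairs)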