Add the ability to drop samples with unknown and/or multi-instance metadata

This commit is contained in:
Tilman Kamp 2020-02-17 11:41:55 +01:00
Родитель 3342067a0b
Коммит aacbda1676
1 изменённых файлов: 24 добавлений и 12 удалений

Просмотреть файл

@@ -116,6 +116,10 @@ def main(args):
help='Split each partition except "other" into train/dev/test sub-sets.')
parser.add_argument('--split-field', type=str,
help='Sample meta field that should be used for splitting (e.g. "speaker")')
parser.add_argument('--split-drop-multiple', action="store_true",
help='Drop all samples with multiple --split-field assignments.')
parser.add_argument('--split-drop-unknown', action="store_true",
help='Drop all samples with no --split-field assignment.')
for sub_set in SET_NAMES:
parser.add_argument('--assign-' + sub_set,
help='Comma separated list of --split-field values that are to be assigned to sub-set '
@@ -284,17 +288,20 @@ def main(args):
for fragment in progress(fragments, desc='Computing qualities'):
fragment['quality'] = eval(args.criteria, {'math': math}, fragment)
def get_meta(fragment, meta_field):
    """Return the first value listed under ``meta_field`` for a fragment.

    Looks at ``fragment['meta'][meta_field]`` and yields its first element.
    Falls back to ``UNKNOWN`` when the fragment has no ``'meta'`` entry, when
    the field is not present, or when the field holds no values.
    """
    if 'meta' not in fragment:
        return UNKNOWN
    fragment_meta = fragment['meta']
    if meta_field not in fragment_meta:
        return UNKNOWN
    # Take the first item of the iterable; an empty one yields UNKNOWN.
    return next(iter(fragment_meta[meta_field]), UNKNOWN)
def get_meta_field(f, meta_field):
    """Return the values assigned to ``meta_field`` in a fragment's meta data.

    Args:
        f: Fragment mapping; its optional ``'meta'`` entry is expected to be a
            dict that maps meta field names to their values.
        meta_field: Name of the meta field to look up.

    Returns:
        The values stored under ``meta_field``. An empty list is returned when
        the fragment has no ``'meta'`` entry, when that entry is not a dict,
        when the field is absent, or when the field's value is ``None``.
        A bare string value is wrapped into a one-element list.
    """
    if 'meta' not in f:
        return []
    meta_fields = f['meta']
    if not isinstance(meta_fields, dict) or meta_field not in meta_fields:
        return []
    values = meta_fields[meta_field]
    if values is None:
        # Callers take len() of the result; None would raise there.
        return []
    if isinstance(values, str):
        # Callers index the result and count its length; a bare string would
        # otherwise be treated as a sequence of single characters.
        return [values]
    return values
def get_first_meta(f, meta_field):
    """Return the first value of ``meta_field`` for fragment ``f``.

    Delegates the lookup to ``get_meta_field`` and returns ``UNKNOWN`` when
    that lookup yields nothing.
    """
    values = get_meta_field(f, meta_field)
    if not values:
        return UNKNOWN
    return values[0]
if args.debias is not None:
for debias in args.debias:
grouped = engroup(fragments, lambda f: get_meta(f, debias))
grouped = engroup(fragments, lambda f: get_first_meta(f, debias))
if UNKNOWN in grouped:
fragments = grouped[UNKNOWN]
del grouped[UNKNOWN]
@@ -342,15 +349,19 @@ def main(args):
if args.split_seed is not None:
random.seed(args.split_seed)
partitions = engroup(fragments, get_partition)
if args.split and args.split_field:
metas = engroup(fragments, lambda f: get_meta(f, args.split_field)).items()
if args.split_drop_multiple:
fragments = filter(lambda f: len(get_meta_field(f, args.split_field)) < 2, fragments)
if args.split_drop_unknown:
fragments = filter(lambda f: len(get_meta_field(f, args.split_field)) > 0, fragments)
fragments = list(fragments)
metas = engroup(fragments, lambda f: get_first_meta(f, args.split_field)).items()
metas = sorted(metas, key=lambda meta_frags: len(meta_frags[1]))
metas = list(map(lambda meta_frags: meta_frags[0], metas))
partitions = engroup(fragments, get_partition)
partitions = list(map(lambda part_frags: (part_frags[0],
get_sample_size(len(part_frags[1])),
engroup(part_frags[1], lambda pf: get_meta(pf, args.split_field)),
engroup(part_frags[1], lambda pf: get_first_meta(pf, args.split_field)),
[[], [], []]),
partitions.items()))
remaining_metas = []
@@ -379,6 +390,7 @@ def main(args):
for set_index, set_name in enumerate(SET_NAMES):
assign_fragments(sample_sets[set_index], partition + '-' + set_name)
else:
partitions = engroup(fragments, get_partition)
for partition, partition_fragments in partitions.items():
if args.split:
sample_size = get_sample_size(len(partition_fragments))