зеркало из https://github.com/mozilla/DSAlign.git
Ability to drop samples with unknown and/or multi-instance meta data
This commit is contained in:
Родитель
3342067a0b
Коммит
aacbda1676
|
@ -116,6 +116,10 @@ def main(args):
|
|||
help='Split each partition except "other" into train/dev/test sub-sets.')
|
||||
parser.add_argument('--split-field', type=str,
|
||||
help='Sample meta field that should be used for splitting (e.g. "speaker")')
|
||||
parser.add_argument('--split-drop-multiple', action="store_true",
|
||||
help='Drop all samples with multiple --split-field assignments.')
|
||||
parser.add_argument('--split-drop-unknown', action="store_true",
|
||||
help='Drop all samples with no --split-field assignment.')
|
||||
for sub_set in SET_NAMES:
|
||||
parser.add_argument('--assign-' + sub_set,
|
||||
help='Comma separated list of --split-field values that are to be assigned to sub-set '
|
||||
|
@ -284,17 +288,20 @@ def main(args):
|
|||
for fragment in progress(fragments, desc='Computing qualities'):
|
||||
fragment['quality'] = eval(args.criteria, {'math': math}, fragment)
|
||||
|
||||
def get_meta(fragment, meta_field):
    """Return the first value stored under ``meta_field`` in a fragment's meta data.

    Args:
        fragment: Mapping describing one fragment; may carry a ``'meta'`` entry.
        meta_field: Name of the meta field to look up (e.g. ``"speaker"``).

    Returns:
        The first item obtained by iterating ``fragment['meta'][meta_field]``,
        or ``UNKNOWN`` when the fragment has no ``'meta'`` entry, the entry is
        not a dict, the field is absent, or the field holds no values.
    """
    if 'meta' not in fragment:
        return UNKNOWN
    meta = fragment['meta']
    # Guard against a non-dict 'meta' value: `in` on None raises TypeError and
    # on a string performs a substring test, neither of which is a field lookup.
    if not isinstance(meta, dict) or meta_field not in meta:
        return UNKNOWN
    # Only the first value is of interest; any further values are ignored.
    for value in meta[meta_field]:
        return value
    return UNKNOWN
|
||||
def get_meta_field(f, meta_field):
    """Return the values stored under ``meta_field`` in a fragment's meta data.

    The result is always safe to pass to ``len()`` and to index, which is how
    the split filters and ``get_first_meta`` consume it.

    Args:
        f: Mapping describing one fragment; may carry a ``'meta'`` entry.
        meta_field: Name of the meta field to look up (e.g. ``"speaker"``).

    Returns:
        The stored collection of values for ``meta_field``. An empty list is
        returned when the fragment has no ``'meta'`` entry, the entry is not a
        dict, the field is absent, or the field is ``None``. A bare string is
        wrapped into a one-element list.
    """
    if 'meta' not in f:
        return []
    meta_fields = f['meta']
    if not isinstance(meta_fields, dict):
        return []
    values = meta_fields.get(meta_field)
    if values is None:
        # Absent field or explicit None: report "no assignment" instead of
        # handing None to callers that call len() on the result.
        return []
    if isinstance(values, str):
        # A single bare string is one assignment, not a sequence of characters.
        return [values]
    return values
|
||||
|
||||
def get_first_meta(f, meta_field):
    """Return the first ``meta_field`` value of fragment ``f``.

    Looks the field up through ``get_meta_field`` and falls back to
    ``UNKNOWN`` when that lookup yields nothing.
    """
    values = get_meta_field(f, meta_field)
    if values:
        return values[0]
    return UNKNOWN
|
||||
|
||||
if args.debias is not None:
|
||||
for debias in args.debias:
|
||||
grouped = engroup(fragments, lambda f: get_meta(f, debias))
|
||||
grouped = engroup(fragments, lambda f: get_first_meta(f, debias))
|
||||
if UNKNOWN in grouped:
|
||||
fragments = grouped[UNKNOWN]
|
||||
del grouped[UNKNOWN]
|
||||
|
@ -342,15 +349,19 @@ def main(args):
|
|||
if args.split_seed is not None:
|
||||
random.seed(args.split_seed)
|
||||
|
||||
partitions = engroup(fragments, get_partition)
|
||||
|
||||
if args.split and args.split_field:
|
||||
metas = engroup(fragments, lambda f: get_meta(f, args.split_field)).items()
|
||||
if args.split_drop_multiple:
|
||||
fragments = filter(lambda f: len(get_meta_field(f, args.split_field)) < 2, fragments)
|
||||
if args.split_drop_unknown:
|
||||
fragments = filter(lambda f: len(get_meta_field(f, args.split_field)) > 0, fragments)
|
||||
fragments = list(fragments)
|
||||
metas = engroup(fragments, lambda f: get_first_meta(f, args.split_field)).items()
|
||||
metas = sorted(metas, key=lambda meta_frags: len(meta_frags[1]))
|
||||
metas = list(map(lambda meta_frags: meta_frags[0], metas))
|
||||
partitions = engroup(fragments, get_partition)
|
||||
partitions = list(map(lambda part_frags: (part_frags[0],
|
||||
get_sample_size(len(part_frags[1])),
|
||||
engroup(part_frags[1], lambda pf: get_meta(pf, args.split_field)),
|
||||
engroup(part_frags[1], lambda pf: get_first_meta(pf, args.split_field)),
|
||||
[[], [], []]),
|
||||
partitions.items()))
|
||||
remaining_metas = []
|
||||
|
@ -379,6 +390,7 @@ def main(args):
|
|||
for set_index, set_name in enumerate(SET_NAMES):
|
||||
assign_fragments(sample_sets[set_index], partition + '-' + set_name)
|
||||
else:
|
||||
partitions = engroup(fragments, get_partition)
|
||||
for partition, partition_fragments in partitions.items():
|
||||
if args.split:
|
||||
sample_size = get_sample_size(len(partition_fragments))
|
||||
|
|
Загрузка…
Ссылка в новой задаче