зеркало из https://github.com/mozilla/DSAlign.git
Better set splitting
This commit is contained in:
Родитель
72b8cc45f4
Коммит
36d3672a19
|
@ -33,7 +33,7 @@ def engroup(lst, get_key):
|
|||
return groups
|
||||
|
||||
|
||||
def get_set_sizes(population_size):
|
||||
def get_sample_size(population_size):
|
||||
margin_of_error = 0.01
|
||||
fraction_picking = 0.50
|
||||
z_score = 2.58 # Corresponds to confidence level 99%
|
||||
|
@ -48,7 +48,7 @@ def get_set_sizes(population_size):
|
|||
sample_size = int(numerator / denominator)
|
||||
if 2 * sample_size + train_size <= population_size:
|
||||
break
|
||||
return population_size - 2 * sample_size, sample_size
|
||||
return sample_size
|
||||
|
||||
|
||||
def load_segment(audio_path):
|
||||
|
@ -253,40 +253,54 @@ def main(args):
|
|||
if path.exists(path.join(target_dir, p)):
|
||||
fail('"{}" already existing - use --force to ignore'.format(p))
|
||||
|
||||
def assign_fragments(frags, name):
|
||||
ensure_list(name)
|
||||
for f in frags:
|
||||
f['list-name'] = name
|
||||
logging.info('Built set "{}" - samples: {}'.format(name, len(frags)))
|
||||
|
||||
if args.split_seed is not None:
|
||||
random.seed(args.split_seed)
|
||||
|
||||
partitions = engroup(fragments, get_partition)
|
||||
for partition, partition_fragments in partitions.items():
|
||||
logging.info('Partition "{}":'.format(partition))
|
||||
if not args.split or partition == 'other':
|
||||
ensure_list(partition)
|
||||
for fragment in partition_fragments:
|
||||
fragment['list-name'] = partition
|
||||
logging.info(' - samples: {}'.format(len(partition_fragments)))
|
||||
else:
|
||||
train_size, sample_size = get_set_sizes(len(partition_fragments))
|
||||
if args.split_field:
|
||||
portions = list(engroup(partition_fragments, lambda f: get_meta(f, args.split_field)).values())
|
||||
portions.sort(key=lambda p: len(p))
|
||||
train_set, dev_set, test_set = [], [], []
|
||||
for offset, sample_set in [(0, dev_set), (1, test_set)]:
|
||||
for portion in portions[offset::2]:
|
||||
if len(sample_set) < sample_size:
|
||||
sample_set.extend(portion)
|
||||
else:
|
||||
train_set.extend(portion)
|
||||
else:
|
||||
|
||||
if args.split and args.split_field:
|
||||
metas = engroup(fragments, lambda f: get_meta(f, args.split_field)).items()
|
||||
metas = sorted(metas, key=lambda meta_frags: len(meta_frags[1]))
|
||||
metas = list(map(lambda meta_frags: meta_frags[0], metas))
|
||||
partitions = list(map(lambda part_frags: (part_frags[0],
|
||||
get_sample_size(len(part_frags[1])),
|
||||
engroup(part_frags[1], lambda pf: get_meta(pf, args.split_field)),
|
||||
[[], []]),
|
||||
partitions.items()))
|
||||
for partition, sample_size, _, sample_sets in partitions:
|
||||
while len(metas) > 0 and (len(sample_sets[0]) < sample_size or len(sample_sets[1]) < sample_size):
|
||||
for sample_set_index, sample_set in enumerate(sample_sets):
|
||||
if len(metas) > 0 and sample_size > len(sample_set):
|
||||
meta = metas.pop(0)
|
||||
for _, _, partition_portions, other_sample_sets in partitions:
|
||||
if meta in partition_portions:
|
||||
other_sample_sets[sample_set_index].extend(partition_portions[meta])
|
||||
del partition_portions[meta]
|
||||
for partition, sample_size, partition_portions, sample_sets in partitions:
|
||||
train_set = []
|
||||
for portion in partition_portions.values():
|
||||
train_set.extend(portion)
|
||||
for set_name, set_fragments in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
|
||||
assign_fragments(set_fragments, partition + '-' + set_name)
|
||||
else:
|
||||
for partition, partition_fragments in partitions.items():
|
||||
if args.split:
|
||||
sample_size = get_sample_size(len(partition_fragments))
|
||||
random.shuffle(partition_fragments)
|
||||
test_set = partition_fragments[:sample_size]
|
||||
partition_fragments = partition_fragments[sample_size:]
|
||||
dev_set = partition_fragments[:sample_size]
|
||||
train_set = partition_fragments[sample_size:]
|
||||
for set_name, set_fragments in [('train', train_set), ('dev', dev_set), ('test', test_set)]:
|
||||
list_name = partition + '-' + set_name
|
||||
ensure_list(list_name)
|
||||
for fragment in set_fragments:
|
||||
fragment['list-name'] = list_name
|
||||
logging.info(' - sub-set "{}" - samples: {}'.format(set_name, len(set_fragments)))
|
||||
for set_name, set_fragments in [('train', train_set), ('dev', dev_set), ('test', test_set)]:
|
||||
assign_fragments(set_fragments, partition + '-' + set_name)
|
||||
else:
|
||||
assign_fragments(partition_fragments, partition)
|
||||
|
||||
for list_name in lists.keys():
|
||||
dir_name = path.join(target_dir, list_name)
|
||||
|
|
Загрузка…
Ссылка в новой задаче