This commit is contained in:
Tilman Kamp 2019-09-17 15:38:36 +02:00
Родитель 72b8cc45f4
Коммит 36d3672a19
1 изменённых файлов: 42 добавлений и 28 удалений

Просмотреть файл

@ -33,7 +33,7 @@ def engroup(lst, get_key):
return groups
def get_set_sizes(population_size):
def get_sample_size(population_size):
margin_of_error = 0.01
fraction_picking = 0.50
z_score = 2.58 # Corresponds to confidence level 99%
@ -48,7 +48,7 @@ def get_set_sizes(population_size):
sample_size = int(numerator / denominator)
if 2 * sample_size + train_size <= population_size:
break
return population_size - 2 * sample_size, sample_size
return sample_size
def load_segment(audio_path):
@ -253,40 +253,54 @@ def main(args):
if path.exists(path.join(target_dir, p)):
fail('"{}" already existing - use --force to ignore'.format(p))
def assign_fragments(frags, name):
ensure_list(name)
for f in frags:
f['list-name'] = name
logging.info('Built set "{}" - samples: {}'.format(name, len(frags)))
if args.split_seed is not None:
random.seed(args.split_seed)
partitions = engroup(fragments, get_partition)
for partition, partition_fragments in partitions.items():
logging.info('Partition "{}":'.format(partition))
if not args.split or partition == 'other':
ensure_list(partition)
for fragment in partition_fragments:
fragment['list-name'] = partition
logging.info(' - samples: {}'.format(len(partition_fragments)))
else:
train_size, sample_size = get_set_sizes(len(partition_fragments))
if args.split_field:
portions = list(engroup(partition_fragments, lambda f: get_meta(f, args.split_field)).values())
portions.sort(key=lambda p: len(p))
train_set, dev_set, test_set = [], [], []
for offset, sample_set in [(0, dev_set), (1, test_set)]:
for portion in portions[offset::2]:
if len(sample_set) < sample_size:
sample_set.extend(portion)
else:
train_set.extend(portion)
else:
if args.split and args.split_field:
metas = engroup(fragments, lambda f: get_meta(f, args.split_field)).items()
metas = sorted(metas, key=lambda meta_frags: len(meta_frags[1]))
metas = list(map(lambda meta_frags: meta_frags[0], metas))
partitions = list(map(lambda part_frags: (part_frags[0],
get_sample_size(len(part_frags[1])),
engroup(part_frags[1], lambda pf: get_meta(pf, args.split_field)),
[[], []]),
partitions.items()))
for partition, sample_size, _, sample_sets in partitions:
while len(metas) > 0 and (len(sample_sets[0]) < sample_size or len(sample_sets[1]) < sample_size):
for sample_set_index, sample_set in enumerate(sample_sets):
if len(metas) > 0 and sample_size > len(sample_set):
meta = metas.pop(0)
for _, _, partition_portions, other_sample_sets in partitions:
if meta in partition_portions:
other_sample_sets[sample_set_index].extend(partition_portions[meta])
del partition_portions[meta]
for partition, sample_size, partition_portions, sample_sets in partitions:
train_set = []
for portion in partition_portions.values():
train_set.extend(portion)
for set_name, set_fragments in [('train', train_set), ('dev', sample_sets[0]), ('test', sample_sets[1])]:
assign_fragments(set_fragments, partition + '-' + set_name)
else:
for partition, partition_fragments in partitions.items():
if args.split:
sample_size = get_sample_size(len(partition_fragments))
random.shuffle(partition_fragments)
test_set = partition_fragments[:sample_size]
partition_fragments = partition_fragments[sample_size:]
dev_set = partition_fragments[:sample_size]
train_set = partition_fragments[sample_size:]
for set_name, set_fragments in [('train', train_set), ('dev', dev_set), ('test', test_set)]:
list_name = partition + '-' + set_name
ensure_list(list_name)
for fragment in set_fragments:
fragment['list-name'] = list_name
logging.info(' - sub-set "{}" - samples: {}'.format(set_name, len(set_fragments)))
for set_name, set_fragments in [('train', train_set), ('dev', dev_set), ('test', test_set)]:
assign_fragments(set_fragments, partition + '-' + set_name)
else:
assign_fragments(partition_fragments, partition)
for list_name in lists.keys():
dir_name = path.join(target_dir, list_name)