Fixed overlapping classes issues in iNat and animals extension merge

This commit is contained in:
Ubuntu 2019-04-26 08:46:16 +00:00
Родитель e95317fa9c
Коммит 539aabf159
1 изменённых файлов: 61 добавлений и 26 удалений

Просмотреть файл

@ -1,30 +1,65 @@
import json
# Minival
# - Extension
js = json.load(open('./minival_animalsExtended2017.json'))
splitid = [idx for idx,im in enumerate(js['images']) if im['license'] == 9][0]
js['images'] = js['images'][splitid:]
js['annotations'] = js['annotations'][splitid:]
# - Original iNat
inat = json.load(open('./minival2017.json'))
import ipdb; ipdb.set_trace()
js['images'] = inat['images'] + js['images']
js['annotations'] = inat['annotations'] + js['annotations']
# Write out
json.dump(js, open('./minival_iNatAllExtended2017.json', 'wt'))
def merge(input1, input2, output, old_new_mapping = [{}, {}]):
print('mapping', len(old_new_mapping[0]), len(old_new_mapping[1]))
# input 1 is the inat json
# input 2 is the extension-only json
# {659: 1221, 216: 375}
input1_class_blacklist = [1221, 375]
# First json
with open(input1, 'rt') as fi:
js1 = json.load(fi)
# Second json
with open(input2, 'rt') as fi:
js2 = json.load(fi)
# Trainval
# - Extension
js = json.load(open('./trainval_animalsExtended2017.json'))
splitid = [idx for idx,im in enumerate(js['images']) if im['license'] == 9][0]
js['images'] = js['images'][splitid:]
js['annotations'] = js['annotations'][splitid:]
# - Original iNat
inat = json.load(open('./trainval2017.json'))
import ipdb; ipdb.set_trace()
js['images'] = inat['images'] + js['images']
js['annotations'] = inat['annotations'] + js['annotations']
# Write out
json.dump(js, open('./trainval_iNatAllExtended2017.json', 'wt'))
# Delete duplicate classes from input1
images_to_delete = set()
for ann_idx in range(len(js1['annotations'])):
if js1['annotations'][ann_idx]['category_id'] in input1_class_blacklist:
images_to_delete.add(js1['annotations'][ann_idx]['image_id'])
for k,v in js1.items():
print(k,len(v))
js1['images'] = [im for im in js1['images'] if im['id'] not in images_to_delete]
js1['categories'] = [cat for cat in js1['categories'] if cat['id'] not in input1_class_blacklist]
js1['annotations'] = [ann for ann in js1['annotations'] if ann['category_id'] not in input1_class_blacklist]
for k,v in js1.items():
print(k,len(v))
# Renumber classes in input1
max_class_id = -1
for new_id, old_id in enumerate(list(set([ann['category_id'] for ann in js1['annotations']]))):
if old_id not in old_new_mapping[0].keys():
old_new_mapping[0][old_id] = max_class_id + 1
max_class_id += 1
for cat_idx in range(len(js1['categories'])):
js1['categories'][cat_idx]['id'] = old_new_mapping[0][js1['categories'][cat_idx]['id']]
for ann_idx in range(len(js2['annotations'])):
js1['annotations'][ann_idx]['category_id'] = old_new_mapping[0][js1['annotations'][ann_idx]['category_id']]
# Renumber classes in input2
max_class_id = max([cat['id'] for cat in js1['categories']])
for new_id, old_id in enumerate(list(set([ann['category_id'] for ann in js2['annotations']]))):
if old_id not in old_new_mapping[1].keys():
old_new_mapping[1][old_id] = max_class_id + 1
max_class_id += 1
#assert len(set([cat['id'] for cat in js1['categories']]) & set(old_new_mapping[1].values())) == 0
import ipdb; ipdb.set_trace()
js2['categories'] = [cat for cat in js2['categories'] if cat['id'] in old_new_mapping[1].keys()]
for cat_idx in range(len(js2['categories'])):
js2['categories'][cat_idx]['id'] = old_new_mapping[1][js2['categories'][cat_idx]['id']]
for ann_idx in range(len(js2['annotations'])):
js2['annotations'][ann_idx]['category_id'] = old_new_mapping[1][js2['annotations'][ann_idx]['category_id']]
import ipdb; ipdb.set_trace()
js1['images'] += js2['images']
js1['annotations'] += js2['annotations']
js1['categories'] += js2['categories']
# Write out
with open(output, 'wt') as fi:
json.dump(js1, open(output, 'wt'))
print('mapping', len(old_new_mapping[0]), len(old_new_mapping[1]))
return old_new_mapping
old_new_mapping = merge('trainval2017.json', 'trainval_animalsExtended2017_extensionOnly.json', 'trainval_iNatAllExtended2017.json')
merge('minival2017.json', 'minival_animalsExtended2017_extensionOnly.json', 'minival_iNatAllExtended2017.json', old_new_mapping)