# petridishnn/petridish/data/inat.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import os
import cv2
import json
import multiprocessing
from tensorpack.dataflow import DataFlow, PrefetchDataZMQ, \
LMDBData, LMDBDataPoint, PrefetchData, \
MapDataComponent, AugmentImageComponent, BatchData
from tensorpack.dataflow import imgaug, LocallyShuffleData
from tensorpack.utils import logger
from tensorpack.dataflow.serialize import LMDBSerializer
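# Newer tensorpack versions replaced the old dump_dataflow_to_lmdb helper with
# LMDBSerializer.save; keep the old name as a thin compatibility alias below.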
dump_dataflow_to_lmdb = LMDBSerializer.save
class INatRawData(DataFlow):
def __init__(
self,
inat_dir, anno_json_fn,
part=None, part_include=True,
shuffle=True,
allowed_labels=None):
"""
        Some notes about the iNaturalist (iNat) data (apply to both the 2017
        and 2018 sets):
        1. Train and val share the same data directory structure and are
           co-located. Since there is no public test set, the val set here is
           treated as the blind set and no parameters should be tuned on it.
        2. Labels are in the original data path for both.
        3. The image list is already shuffled. Labels are 0-based.
Args:
inat_dir (str) : dirname for the root of inat raw data.
It should have the following structure:
<inat_dir>
<train.json>
<val.json>
<train_val_imagedir>/
Actinopterygii/
...
anno_json_fn (str) : Filename for the train/val json file
            part (None or int) : Index of the cross-validation partition to
                select; None disables partitioning (use all images).
            part_include (bool) : Whether the selected partition is included
                (True) or excluded (False).
shuffle (bool) : whether to shuffle data every epoch.
allowed_labels (dict of int to int) :
map from the original label index to new label index.
None means identity mapping for all labels.
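                Illustrative example: {12: 0, 57: 1} keeps only original
                categories 12 and 57 and relabels them to 0 and 1.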
"""
assert os.path.exists(inat_dir), inat_dir
json_fn = os.path.join(inat_dir, anno_json_fn)
with open(json_fn) as fin:
jd = json.load(fin)
self.inat_dir = inat_dir
self.shuffle = shuffle
self.meta_info = []
self.allowed_labels = allowed_labels
tot_images = len(jd['images'])
assert len(jd['annotations']) == tot_images
cv_K = 10
        part_unit = tot_images // cv_K
        # `part` selects a cross-validation fold: abs(part) is the fold index,
        # and `part_include` decides whether that fold is kept (True, e.g. for
        # a cv-val set) or excluded (False, e.g. for a cv-train set).
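        # Illustrative example: with 100k images and cv_K=10, part=3 with
        # part_include=True keeps only images 30000..39999 of the
        # deterministically shuffled list; with part_include=False exactly
        # those images are dropped.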
cv = 0
if part is not None:
cv = abs(part)
assert cv <= cv_K
#part_unit = 150000
cv_start = cv * part_unit
cv_end = min(tot_images, (cv+1) * part_unit)
            logger.info('cv split range: [{}, {})'.format(cv_start, cv_end))
#need to shuffle before creating the parts
#This needs to be a deterministic shuffle for the cv splits to be meaningful
merged_data = list(zip(jd['images'], jd['annotations']))
if self.shuffle:
np.random.seed(0)
np.random.shuffle(merged_data)
# create meta data info
for idx, (info, anno) in enumerate(merged_data):
if part is not None:
#Include only cv-split part (used for cv-val set)
if part_include:
if idx < cv_start:
continue
if idx >= cv_end:
break
#Exclude cv-split part (used for cv-train set)
else:
if (idx >= cv_start) and (idx < cv_end):
continue
fn = info['file_name']
label = anno['category_id']
            if self.allowed_labels is None:
                self.meta_info.append((fn, label))
            elif label in self.allowed_labels:
                self.meta_info.append((fn, self.allowed_labels[label]))
        logger.info('INatRawData: {} images selected'.format(len(self.meta_info)))
def get_data(self):
if self.shuffle:
np.random.shuffle(self.meta_info)
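        # Yield the still-encoded JPEG bytes (as a uint8 array) together with
        # the (possibly remapped) label; image decoding is left to downstream
        # dataflow stages.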
for info in self.meta_info:
fn, label = info
full_fn = os.path.join(self.inat_dir, fn)
with open(full_fn, 'rb') as fin:
jpeg = fin.read()
jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
yield [jpeg, label]
def size(self):
return len(self.meta_info)
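
# Per-channel ImageNet mean/std in [0, 1], reversed from RGB to BGR order to
# match the cv2-decoded (BGR) images used in this pipeline.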
image_mean = [0.485, 0.456, 0.406][::-1]
image_std = [0.229, 0.224, 0.225][::-1]
def get_inat_augmented_data(
subset, options,
lmdb_dir=None,
year='2018',
do_multiprocess=True,
do_validation=False,
is_train=None,
shuffle=None,
n_allow=None):
input_size = options.input_size if options.input_size else 224
isTrain = is_train if is_train is not None else (subset == 'train' and do_multiprocess)
shuffle = shuffle if shuffle is not None else isTrain
postfix = "" if n_allow is None else "_allow_{}".format(n_allow)
    # TODO: Parameterize the cv split to be considered.
    # Currently hardcoded to 1.
cv = 1
# When do_validation is True it will expect *cv_train and *cv_val lmdbs
# Currently the cv_train split is always used
if isTrain:
postfix += '_cv_train_{}'.format(cv)
elif do_validation:
subset = 'train'
postfix += '_cv_val_{}'.format(cv)
    if lmdb_dir is None:
lmdb_path = os.path.join(
options.data_dir, 'inat_lmdb', 'inat2018_{}{}.lmdb'.format(subset, postfix))
else:
lmdb_path = os.path.join(
options.data_dir, lmdb_dir, 'inat{}_{}{}.lmdb'.format(year,subset, postfix))
ds = LMDBData(lmdb_path, shuffle=False)
if shuffle:
        ds = LocallyShuffleData(ds, 1024*80)  # shuffle buffer of ~80k encoded images (roughly 64-80 GB in memory)
        ds = PrefetchData(ds, 1024*8, 1)  # prefetch buffer of ~8k datapoints (around 8 GB)
ds = LMDBDataPoint(ds)
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0) # BGR uint8 data
if isTrain:
class Resize(imgaug.ImageAugmentor):
"""
crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google.
"""
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
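                # Sample a target area in [8%, 100%] of the original area with
                # a random aspect ratio; retry up to 10 times, otherwise fall
                # back to the plain resize below.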
for _ in range(10):
targetArea = self.rng.uniform(0.08, 1.0) * area
aspectR = self.rng.uniform(0.75, 1.333)
ww = int(np.sqrt(targetArea * aspectR))
hh = int(np.sqrt(targetArea / aspectR))
if self.rng.uniform() < 0.5:
ww, hh = hh, ww
if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww]
out = cv2.resize(
out, (input_size, input_size), interpolation=cv2.INTER_CUBIC)
return out
out = cv2.resize(
img, (input_size, input_size), interpolation=cv2.INTER_CUBIC)
return out
augmentors = [
Resize(),
imgaug.RandomOrderAug(
[imgaug.Brightness(30, clip=False),
imgaug.Contrast((0.8, 1.2), clip=False),
imgaug.Saturation(0.4),
                 # eigval/eigvec below are reversed for RGB -> BGR order
imgaug.Lighting(0.1,
eigval=[0.2175, 0.0188, 0.0045][::-1],
eigvec=np.array(
[[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203]],
dtype='float32')[::-1, ::-1]
)]),
imgaug.Clip(),
imgaug.Flip(horiz=True),
imgaug.ToUint8()
]
else:
augmentors = [
imgaug.ResizeShortestEdge(256),
imgaug.CenterCrop((input_size, input_size)),
imgaug.ToUint8()
]
ds = AugmentImageComponent(ds, augmentors, copy=False)
if do_multiprocess:
ds = PrefetchDataZMQ(ds, min(24, multiprocessing.cpu_count()))
ds = BatchData(ds, options.batch_size // options.nr_gpu, remainder=not isTrain)
return ds
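

# Illustrative usage of get_inat_augmented_data (not executed here): `options`
# is assumed to be an argparse-style namespace carrying the attributes read
# above (input_size, data_dir, batch_size, nr_gpu); the values are hypothetical.
#
#   from argparse import Namespace
#   opts = Namespace(input_size=224, data_dir='/path/to/data',
#                    batch_size=256, nr_gpu=4)
#   train_ds = get_inat_augmented_data('train', opts)
#   val_ds = get_inat_augmented_data('val', opts, is_train=False, shuffle=False)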
def inat_raw_to_lmdb(
inat_dir=None,
inat_lmdb_dir=None,
allowed_labels=None,
year='2018',
do_crossval=False,
splits=['train', 'val']):
if inat_dir is None or inat_lmdb_dir is None:
try:
data_root = os.environ['GLOBAL_DATA_DIR']
if inat_dir is None:
inat_dir = os.path.join(data_root, 'inat')
if inat_lmdb_dir is None:
inat_lmdb_dir = os.path.join(data_root, 'inat_lmdb')
        except KeyError:
            logger.info('GLOBAL_DATA_DIR is not set as an env variable')
raise
if not os.path.exists(inat_dir):
        raise ValueError('inat dir does not exist: {}'.format(inat_dir))
if not os.path.exists(inat_lmdb_dir):
os.makedirs(inat_lmdb_dir)
if allowed_labels is not None:
label_map_fn = os.path.join(
inat_lmdb_dir, 'label_map_{}.npz'.format(len(allowed_labels)))
np.savez(label_map_fn, allowed_labels=allowed_labels)
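        # Note: np.savez stores the dict as a 0-d object array; with recent
        # NumPy, reloading it requires np.load(label_map_fn, allow_pickle=True).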
for split in splits:
is_train = split == 'train'
postfix = ''
if allowed_labels:
postfix += "_allow_{}".format(len(allowed_labels))
anno_fn = '{}{}.json'.format(split, year)
if is_train and do_crossval:
cv_K = 10
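            # For each of the cv_K folds, write two LMDBs: a cv-train split
            # with the fold excluded and a cv-val split containing only that
            # fold. get_inat_augmented_data expects the matching
            # *_cv_train_<k>/*_cv_val_<k> filename suffixes.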
for cv in range(cv_K):
postfix_cv_train = postfix + '_cv_train_{}'.format(cv)
postfix_cv_val = postfix + '_cv_val_{}'.format(cv)
#Train part of cv
ds0 = INatRawData(
inat_dir,
anno_fn,
part=cv,
part_include=False,
allowed_labels=allowed_labels,
shuffle=True)
ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
dump_dataflow_to_lmdb(
ds1,
os.path.join(
inat_lmdb_dir,
'inat{}_{}{}.lmdb'.format(year, split, postfix_cv_train)
)
)
#Val part of cv
ds2 = INatRawData(
inat_dir,
anno_fn,
part=cv,
part_include=True,
allowed_labels=allowed_labels,
shuffle=False)
ds3 = PrefetchDataZMQ(ds2, nr_proc=1)
dump_dataflow_to_lmdb(
ds3,
os.path.join(
inat_lmdb_dir,
'inat{}_{}{}.lmdb'.format(year, split, postfix_cv_val)
)
)
else:
ds0 = INatRawData(
inat_dir,
anno_fn,
part=None,
allowed_labels=allowed_labels,
shuffle=is_train)
ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
dump_dataflow_to_lmdb(
ds1,
os.path.join(
inat_lmdb_dir,
'inat{}_{}{}.lmdb'.format(year, split, postfix)
)
)
if __name__ == '__main__':
n_categories = 5089
n_allowed = 1000 # 100
#inat_raw_to_lmdb()
np.random.seed(19921102)
allowed_labels = list(np.random.choice(range(n_categories), n_allowed, replace=False))
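    # Remap the sampled original category ids to contiguous labels
    # 0..n_allowed-1, which is the allowed_labels format INatRawData expects.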
orig_to_new = dict()
for idx, label in enumerate(allowed_labels):
orig_to_new[label] = idx
    inat_raw_to_lmdb(
        inat_dir='/media/data/saurajim/inat2017_data',
        inat_lmdb_dir='/media/data/saurajim/inat2017_data/lmdb',
        allowed_labels=orig_to_new,
        year='2017',
        do_crossval=True)