141 строка
4.4 KiB
Python
141 строка
4.4 KiB
Python
from __future__ import print_function
|
|
try:
|
|
from urllib.request import urlretrieve
|
|
except ImportError:
|
|
from urllib import urlretrieve
|
|
import sys
|
|
import tarfile
|
|
import os
|
|
import numpy as np
|
|
import pickle as cp
|
|
from PIL import Image
|
|
import xml.etree.cElementTree as et
|
|
import xml.dom.minidom
|
|
from itertools import product, count
|
|
import fire
|
|
|
|
|
|
IMGSIZE = 32
|
|
NUMBER_OF_TRAINING_BATCHES = 5
|
|
CIFAR_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
|
|
DATA_DIR = 'data'
|
|
|
|
|
|
def download_data(src):
|
|
print('Downloading ' + src)
|
|
fname, h = urlretrieve(src, './delete.me')
|
|
print('Done.')
|
|
return fname
|
|
|
|
|
|
def extract(fname):
|
|
try:
|
|
print('Extracting files...')
|
|
with tarfile.open(fname) as tar:
|
|
tar.extractall()
|
|
print('Done.')
|
|
finally:
|
|
os.remove(fname)
|
|
|
|
|
|
def _pad_image(pixData, pad):
|
|
return np.pad(pixData, ((0, 0), (pad, pad), (pad, pad)), mode='constant',
|
|
constant_values=128) # can also use mode='edge'
|
|
|
|
|
|
def saveMean(fname, data):
|
|
root = et.Element('opencv_storage')
|
|
et.SubElement(root, 'Channel').text = '3'
|
|
et.SubElement(root, 'Row').text = str(IMGSIZE)
|
|
et.SubElement(root, 'Col').text = str(IMGSIZE)
|
|
meanImg = et.SubElement(root, 'MeanImg', type_id='opencv-matrix')
|
|
et.SubElement(meanImg, 'rows').text = '1'
|
|
et.SubElement(meanImg, 'cols').text = str(IMGSIZE * IMGSIZE * 3)
|
|
et.SubElement(meanImg, 'dt').text = 'f'
|
|
et.SubElement(meanImg, 'data').text = ' '.join(['%e' % n for n in np.reshape(data, (IMGSIZE * IMGSIZE * 3))])
|
|
|
|
tree = et.ElementTree(root)
|
|
tree.write(fname)
|
|
x = xml.dom.minidom.parse(fname)
|
|
with open(fname, 'w') as f:
|
|
f.write(x.toprettyxml(indent=' '))
|
|
|
|
|
|
def saveImage(fname, pixData, pad):
|
|
if pad > 0:
|
|
pixData = _pad_image(pixData, pad)
|
|
|
|
img = Image.new('RGB', (IMGSIZE + 2 * pad, IMGSIZE + 2 * pad))
|
|
pixels = img.load()
|
|
for x, y in product(range(img.size[0]), range(img.size[1])):
|
|
pixels[x, y] = (pixData[0][y][x], pixData[1][y][x], pixData[2][y][x])
|
|
img.save(fname)
|
|
|
|
|
|
def load_data_file(f):
|
|
if sys.version_info[0] < 3: # python 3
|
|
data = cp.load(f)
|
|
else:
|
|
data = cp.load(f, encoding='latin1')
|
|
return data['labels'], data['data']
|
|
|
|
|
|
def read_train_batch(frompath, batch_index):
|
|
return read_batch(os.path.join(frompath, "data_batch_{}".format(batch_index)))
|
|
|
|
|
|
def read_test_batch(frompath):
|
|
return read_batch(os.path.join(frompath, "test_batch"))
|
|
|
|
|
|
def read_batch(filename):
|
|
with open(filename, 'rb') as f:
|
|
labels, data = load_data_file(f)
|
|
for i in range(len(labels)):
|
|
yield labels[i], data[i, :].reshape((3, IMGSIZE, IMGSIZE))
|
|
|
|
|
|
def saveTrainImages(topath, map_filename='train_map.txt', mean_filename='CIFAR-10_mean.xml',
|
|
frompath='cifar-10-batches-py'):
|
|
if not os.path.exists(topath):
|
|
os.makedirs(topath)
|
|
|
|
file_num_generator = count(start=0, step=1)
|
|
dataSum = np.zeros((3, IMGSIZE, IMGSIZE)) # mean is in CHW format.
|
|
with open(map_filename, 'w') as mapFile:
|
|
for ifile in range(1, NUMBER_OF_TRAINING_BATCHES + 1): # Loop through batches
|
|
for label, data in read_train_batch(frompath, ifile):
|
|
fname = '%05d.png' % next(file_num_generator)
|
|
saveImage(os.path.join(topath, fname), data, 4)
|
|
mapFile.write("%s\t%d\n" % (fname, label))
|
|
dataSum += data
|
|
saveMean(mean_filename, dataSum / next(file_num_generator))
|
|
|
|
|
|
def saveTestImages(topath, filename='test_map.txt', frompath='cifar-10-batches-py'):
|
|
if not os.path.exists(topath):
|
|
os.makedirs(topath)
|
|
|
|
file_num_generator = count(start=0, step=1)
|
|
with open(filename, 'w') as mapFile:
|
|
for label, data in read_test_batch(frompath):
|
|
fname = '%05d.png' % next(file_num_generator)
|
|
saveImage(os.path.join(topath, fname), data, 4)
|
|
mapFile.write("%s\t%d\n" % (fname, label))
|
|
|
|
|
|
def main(data_dir=DATA_DIR):
|
|
fname = download_data(CIFAR_URL)
|
|
extract(fname)
|
|
train_path = os.path.join(data_dir, 'train')
|
|
test_path = os.path.join(data_dir, 'test')
|
|
saveTrainImages(train_path,
|
|
map_filename=os.path.join(data_dir, 'train_map.txt'),
|
|
mean_filename=os.path.join(data_dir, 'CIFAR-10_mean.xml'))
|
|
saveTestImages(test_path, os.path.join(data_dir, 'test_map.txt'))
|
|
|
|
|
|
if __name__=='__main__':
|
|
fire.Fire(main)
|
|
|