This commit is contained in:
stonebig 2016-01-26 18:33:10 +01:00
Родитель 278ae80c53
Коммит adc8410e5f
1 изменённых файлов: 15 добавлений и 15 удалений

Просмотреть файл

@ -38,33 +38,33 @@ def readBatch(src, outFmt):
return np.hstack((np.reshape(d['labels'], (len(d['labels']), 1)), feat)) return np.hstack((np.reshape(d['labels'], (len(d['labels']), 1)), feat))
def loadData(src, outFmt): def loadData(src, outFmt):
print 'Downloading ' + src print ('Downloading ' + src)
fname, h = urllib.urlretrieve(src, './delete.me') fname, h = urllib.urlretrieve(src, './delete.me')
print 'Done.' print ('Done.')
try: try:
print 'Extracting files...' print ('Extracting files...')
with tarfile.open(fname) as tar: with tarfile.open(fname) as tar:
tar.extractall() tar.extractall()
print 'Done.' print ('Done.')
print 'Preparing train set...' print ('Preparing train set...')
trn = np.empty((0, NumFeat + 1)) trn = np.empty((0, NumFeat + 1))
for i in range(5): for i in range(5):
batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1) batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
trn = np.vstack((trn, readBatch(batchName, outFmt))) trn = np.vstack((trn, readBatch(batchName, outFmt)))
print 'Done.' print ('Done.')
print 'Preparing test set...' print ('Preparing test set...')
tst = readBatch('./cifar-10-batches-py/test_batch', outFmt) tst = readBatch('./cifar-10-batches-py/test_batch', outFmt)
print 'Done.' print ('Done.')
finally: finally:
os.remove(fname) os.remove(fname)
return (trn, tst) return (trn, tst)
def usage(): def usage():
print 'Usage: CIFAR_convert.py [-f <format>] \n where format can be either cudnn or legacy. Default is cudnn.' print ('Usage: CIFAR_convert.py [-f <format>] \n where format can be either cudnn or legacy. Default is cudnn.')
def parseCmdOpt(argv): def parseCmdOpt(argv):
if len(argv) == 0: if len(argv) == 0:
print "Using cudnn output format." print ("Using cudnn output format.")
return "cudnn" return "cudnn"
try: try:
opts, args = getopt.getopt(argv, 'hf:', ['help', 'outFormat=']) opts, args = getopt.getopt(argv, 'hf:', ['help', 'outFormat='])
@ -78,7 +78,7 @@ def parseCmdOpt(argv):
elif opt in ('-f', '--outFormat'): elif opt in ('-f', '--outFormat'):
fmt = arg fmt = arg
if fmt != 'cudnn' and fmt != 'legacy': if fmt != 'cudnn' and fmt != 'legacy':
print 'Invalid output format option.' print ('Invalid output format option.')
usage() usage()
sys.exit(1) sys.exit(1)
return fmt return fmt
@ -86,9 +86,9 @@ def parseCmdOpt(argv):
if __name__ == "__main__": if __name__ == "__main__":
fmt = parseCmdOpt(sys.argv[1:]) fmt = parseCmdOpt(sys.argv[1:])
trn, tst = loadData('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fmt) trn, tst = loadData('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fmt)
print 'Writing train text file...' print ('Writing train text file...')
np.savetxt(r'./Train.txt', trn, fmt = '%u', delimiter='\t') np.savetxt(r'./Train.txt', trn, fmt = '%u', delimiter='\t')
print 'Done.' print ('Done.')
print 'Writing test text file...' print ('Writing test text file...')
np.savetxt(r'./Test.txt', tst, fmt = '%u', delimiter='\t') np.savetxt(r'./Test.txt', tst, fmt = '%u', delimiter='\t')
print 'Done.' print ('Done.')