python 3 compatibility fix
This commit is contained in:
Родитель
278ae80c53
Коммит
adc8410e5f
|
@ -38,33 +38,33 @@ def readBatch(src, outFmt):
|
||||||
return np.hstack((np.reshape(d['labels'], (len(d['labels']), 1)), feat))
|
return np.hstack((np.reshape(d['labels'], (len(d['labels']), 1)), feat))
|
||||||
|
|
||||||
def loadData(src, outFmt):
|
def loadData(src, outFmt):
|
||||||
print 'Downloading ' + src
|
print ('Downloading ' + src)
|
||||||
fname, h = urllib.urlretrieve(src, './delete.me')
|
fname, h = urllib.urlretrieve(src, './delete.me')
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
try:
|
try:
|
||||||
print 'Extracting files...'
|
print ('Extracting files...')
|
||||||
with tarfile.open(fname) as tar:
|
with tarfile.open(fname) as tar:
|
||||||
tar.extractall()
|
tar.extractall()
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
print 'Preparing train set...'
|
print ('Preparing train set...')
|
||||||
trn = np.empty((0, NumFeat + 1))
|
trn = np.empty((0, NumFeat + 1))
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
|
batchName = './cifar-10-batches-py/data_batch_{0}'.format(i + 1)
|
||||||
trn = np.vstack((trn, readBatch(batchName, outFmt)))
|
trn = np.vstack((trn, readBatch(batchName, outFmt)))
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
print 'Preparing test set...'
|
print ('Preparing test set...')
|
||||||
tst = readBatch('./cifar-10-batches-py/test_batch', outFmt)
|
tst = readBatch('./cifar-10-batches-py/test_batch', outFmt)
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
finally:
|
finally:
|
||||||
os.remove(fname)
|
os.remove(fname)
|
||||||
return (trn, tst)
|
return (trn, tst)
|
||||||
|
|
||||||
def usage():
|
def usage():
|
||||||
print 'Usage: CIFAR_convert.py [-f <format>] \n where format can be either cudnn or legacy. Default is cudnn.'
|
print ('Usage: CIFAR_convert.py [-f <format>] \n where format can be either cudnn or legacy. Default is cudnn.')
|
||||||
|
|
||||||
def parseCmdOpt(argv):
|
def parseCmdOpt(argv):
|
||||||
if len(argv) == 0:
|
if len(argv) == 0:
|
||||||
print "Using cudnn output format."
|
print ("Using cudnn output format.")
|
||||||
return "cudnn"
|
return "cudnn"
|
||||||
try:
|
try:
|
||||||
opts, args = getopt.getopt(argv, 'hf:', ['help', 'outFormat='])
|
opts, args = getopt.getopt(argv, 'hf:', ['help', 'outFormat='])
|
||||||
|
@ -78,7 +78,7 @@ def parseCmdOpt(argv):
|
||||||
elif opt in ('-f', '--outFormat'):
|
elif opt in ('-f', '--outFormat'):
|
||||||
fmt = arg
|
fmt = arg
|
||||||
if fmt != 'cudnn' and fmt != 'legacy':
|
if fmt != 'cudnn' and fmt != 'legacy':
|
||||||
print 'Invalid output format option.'
|
print ('Invalid output format option.')
|
||||||
usage()
|
usage()
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
return fmt
|
return fmt
|
||||||
|
@ -86,9 +86,9 @@ def parseCmdOpt(argv):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
fmt = parseCmdOpt(sys.argv[1:])
|
fmt = parseCmdOpt(sys.argv[1:])
|
||||||
trn, tst = loadData('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fmt)
|
trn, tst = loadData('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', fmt)
|
||||||
print 'Writing train text file...'
|
print ('Writing train text file...')
|
||||||
np.savetxt(r'./Train.txt', trn, fmt = '%u', delimiter='\t')
|
np.savetxt(r'./Train.txt', trn, fmt = '%u', delimiter='\t')
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
print 'Writing test text file...'
|
print ('Writing test text file...')
|
||||||
np.savetxt(r'./Test.txt', tst, fmt = '%u', delimiter='\t')
|
np.savetxt(r'./Test.txt', tst, fmt = '%u', delimiter='\t')
|
||||||
print 'Done.'
|
print ('Done.')
|
||||||
|
|
Загрузка…
Ссылка в новой задаче