Bonsai Tf Code, cleaned and commented

This commit is contained in:
Aditya Kusupati 2018-02-09 20:53:03 +05:30
Parent c5b5482b1f
Commit 7f52c63944
4 changed files: 525 additions and 0 deletions

tf/bonsai/bonsai.py (new file, 210 lines added)

@@ -0,0 +1,210 @@
import utils
import tensorflow as tf
import numpy as np

## Bonsai Class
class Bonsai:
    ## Constructor
    def __init__(self, C, F, P, D, S, lW, lT, lV, lZ,
                 sW, sT, sV, sZ, lr=None, W=None, T=None,
                 V=None, Z=None, feats=None):
        self.dataDimension = F + 1
        self.projectionDimension = P
        if (C > 2):
            self.numClasses = C
        elif (C == 2):
            self.numClasses = 1

        self.treeDepth = D
        self.sigma = S

        ## Regularizer coefficients
        self.lW = lW
        self.lV = lV
        self.lT = lT
        self.lZ = lZ

        ## Sparsity hyperparameters
        self.sW = sW
        self.sV = sV
        self.sT = sT
        self.sZ = sZ

        self.internalNodes = 2**self.treeDepth - 1
        self.totalNodes = 2*self.internalNodes + 1

        ## The parameters of Bonsai
        self.W = self.initW(W)
        self.V = self.initV(V)
        self.T = self.initT(T)
        self.Z = self.initZ(Z)

        ## Placeholders for hard thresholding and sparse training
        self.Wth = tf.placeholder(tf.float32, name='Wth')
        self.Vth = tf.placeholder(tf.float32, name='Vth')
        self.Zth = tf.placeholder(tf.float32, name='Zth')
        self.Tth = tf.placeholder(tf.float32, name='Tth')

        ## Placeholders for features and labels
        ## feats is fed when Bonsai is trained jointly as the end classifier
        if feats is None:
            self.x = tf.placeholder("float", [None, self.dataDimension])
        else:
            self.x = feats
        self.y = tf.placeholder("float", [None, self.numClasses])

        ## Placeholder for batch size, needed for multiclass hinge loss
        self.batch_th = tf.placeholder(tf.int64, name='batch_th')

        self.sigmaI = 1.0
        if lr is not None:
            self.learningRate = lr
        else:
            self.learningRate = 0.01

        ## Functions to set up the required graphs
        self.hardThrsd()
        self.sparseTraining()
        self.lossGraph()
        self.trainGraph()
        self.accuracyGraph()

    ## Functions to initialise params (warm start possible with given numpy matrices)
    def initZ(self, Z):
        if Z is None:
            Z = tf.random_normal([self.projectionDimension, self.dataDimension])
        Z = tf.Variable(Z, name='Z', dtype=tf.float32)
        return Z

    def initW(self, W):
        if W is None:
            W = tf.random_normal([self.numClasses*self.totalNodes, self.projectionDimension])
        W = tf.Variable(W, name='W', dtype=tf.float32)
        return W

    def initV(self, V):
        if V is None:
            V = tf.random_normal([self.numClasses*self.totalNodes, self.projectionDimension])
        V = tf.Variable(V, name='V', dtype=tf.float32)
        return V

    def initT(self, T):
        if T is None:
            T = tf.random_normal([self.internalNodes, self.projectionDimension])
        T = tf.Variable(T, name='T', dtype=tf.float32)
        return T

    ## Function to get the expected model size after hard thresholding (8 bytes per non-zero entry)
    def getModelSize(self):
        nnzZ = np.ceil(int(self.Z.shape[0]*self.Z.shape[1])*self.sZ)
        nnzW = np.ceil(int(self.W.shape[0]*self.W.shape[1])*self.sW)
        nnzV = np.ceil(int(self.V.shape[0]*self.V.shape[1])*self.sV)
        nnzT = np.ceil(int(self.T.shape[0]*self.T.shape[1])*self.sT)
        return int((nnzZ + nnzT + nnzV + nnzW)*8)

    ## Function to build the Bonsai tree graph
    def bonsaiGraph(self, X):
        X = tf.reshape(X, [-1, self.dataDimension])
        X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True), self.projectionDimension)

        ## Root node: path probability 1
        W_ = self.W[0:(self.numClasses)]
        V_ = self.V[0:(self.numClasses)]
        self.nodeProb = []
        self.nodeProb.append(1)
        score_ = self.nodeProb[0]*tf.multiply(tf.matmul(W_, X_), tf.tanh(self.sigma*tf.matmul(V_, X_)))

        for i in range(1, self.totalNodes):
            W_ = self.W[i*self.numClasses:((i + 1)*self.numClasses)]
            V_ = self.V[i*self.numClasses:((i + 1)*self.numClasses)]

            ## Path probability of node i = parent's probability times the soft branching indicator
            prob = (1 + ((-1)**(i + 1))*tf.tanh(tf.multiply(self.sigmaI,
                tf.matmul(tf.reshape(self.T[int(np.ceil(i/2.0) - 1)], [-1, self.projectionDimension]), X_))))
            prob = tf.divide(prob, 2)
            prob = self.nodeProb[int(np.ceil(i/2.0) - 1)]*prob
            self.nodeProb.append(prob)
            score_ += self.nodeProb[i]*tf.multiply(tf.matmul(W_, X_), tf.tanh(self.sigma*tf.matmul(V_, X_)))

        return score_, X_, self.T, self.W, self.V, self.Z

    ## Functions setting up graphs for IHT and sparse retraining
    def hardThrsd(self):
        self.Woph = self.W.assign(self.Wth)
        self.Voph = self.V.assign(self.Vth)
        self.Toph = self.T.assign(self.Tth)
        self.Zoph = self.Z.assign(self.Zth)
        self.hardThresholdGroup = tf.group(self.Woph, self.Voph, self.Toph, self.Zoph)

    def runHardThrsd(self, sess):
        currW = self.Weval.eval()
        currV = self.Veval.eval()
        currZ = self.Zeval.eval()
        currT = self.Teval.eval()

        self.thrsdW = utils.hardThreshold(currW, self.sW)
        self.thrsdV = utils.hardThreshold(currV, self.sV)
        self.thrsdZ = utils.hardThreshold(currZ, self.sZ)
        self.thrsdT = utils.hardThreshold(currT, self.sT)

        fd_thrsd = {self.Wth: self.thrsdW, self.Vth: self.thrsdV, self.Zth: self.thrsdZ, self.Tth: self.thrsdT}
        sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)

    def sparseTraining(self):
        self.Wops = self.W.assign(self.Wth)
        self.Vops = self.V.assign(self.Vth)
        self.Zops = self.Z.assign(self.Zth)
        self.Tops = self.T.assign(self.Tth)
        self.sparseRetrainGroup = tf.group(self.Wops, self.Vops, self.Tops, self.Zops)

    def runSparseTraining(self, sess):
        currW = self.Weval.eval()
        currV = self.Veval.eval()
        currZ = self.Zeval.eval()
        currT = self.Teval.eval()

        newW = utils.copySupport(self.thrsdW, currW)
        newV = utils.copySupport(self.thrsdV, currV)
        newZ = utils.copySupport(self.thrsdZ, currZ)
        newT = utils.copySupport(self.thrsdT, currT)

        fd_st = {self.Wth: newW, self.Vth: newV, self.Zth: newZ, self.Tth: newT}
        sess.run(self.sparseRetrainGroup, feed_dict=fd_st)

    ## Function to build the loss graph for Bonsai
    def lossGraph(self):
        self.score, self.Xeval, self.Teval, self.Weval, self.Veval, self.Zeval = self.bonsaiGraph(self.x)

        self.regLoss = 0.5*(self.lZ*tf.square(tf.norm(self.Z)) + self.lW*tf.square(tf.norm(self.W)) +
                            self.lV*tf.square(tf.norm(self.V)) + self.lT*tf.square(tf.norm(self.T)))

        if (self.numClasses > 2):
            self.marginLoss = utils.multiClassHingeLoss(tf.transpose(self.score), tf.argmax(self.y, 1), self.batch_th)
            self.loss = self.marginLoss + self.regLoss
        else:
            self.marginLoss = tf.reduce_mean(tf.nn.relu(1.0 - (2*self.y - 1)*tf.transpose(self.score)))
            self.loss = self.marginLoss + self.regLoss

    ## Function to set up optimisation for Bonsai
    def trainGraph(self):
        self.trainStep = tf.train.AdamOptimizer(self.learningRate).minimize(self.loss)

    ## Function to run a training step on Bonsai
    def runTraining(self, sess, _feed_dict):
        sess.run([self.trainStep], feed_dict=_feed_dict)

    ## Function to build a graph to compute accuracy of the current model
    def accuracyGraph(self):
        if (self.numClasses > 2):
            correctPrediction = tf.equal(tf.argmax(tf.transpose(self.score), 1), tf.argmax(self.y, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
        else:
            y_ = self.y*2 - 1
            correctPrediction = tf.multiply(tf.transpose(self.score), y_)
            correctPrediction = tf.nn.relu(correctPrediction)
            correctPrediction = tf.ceil(tf.tanh(correctPrediction))
            self.accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
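For reference, a minimal sketch of how this class is meant to be constructed; it mirrors the call in train.py below, and every hyperparameter value here is a placeholder rather than a recommendation:

    ## Sketch only: C classes, F input features (the bias is added internally), P-dim projection,
    ## depth-D tree, sigma S, then regularizers (lW, lT, lV, lZ) and sparsity fractions (sW, sT, sV, sZ).
    bonsaiObj = Bonsai(C=10, F=256, P=10, D=3, S=1.0,
                       lW=1e-4, lT=1e-4, lV=1e-4, lZ=1e-5,
                       sW=0.2, sT=0.2, sV=0.2, sZ=0.2, lr=0.01)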

tf/bonsai/readme.txt (new file, 3 lines added)

@@ -0,0 +1,3 @@
python train.py --help
This gives you the usage instructions.
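A minimal example invocation (the directory name is a placeholder; it must contain train.npy and test.npy laid out as utils.preProcessData expects, with the label in column 0):

python train.py -dir ./my_data -d 2 -p 10 -s 1.0 -e 42 -lr 0.01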

tf/bonsai/train.py (new file, 177 lines added)

@@ -0,0 +1,177 @@
import utils
import tensorflow as tf
import numpy as np
import sys
from bonsai import Bonsai

## Fixing seeds for reproducibility
tf.set_random_seed(42)
np.random.seed(42)

## Hyperparameter pre-processing
args = utils.getArgs()

sigma = args.sigma
depth = args.depth
projectionDimension = args.projDim
regZ = args.rZ
regT = args.rT
regW = args.rW
regV = args.rV
totalEpochs = args.epochs
learningRate = args.learningRate
data_dir = args.data_dir

(dataDimension, numClasses,
 Xtrain, Ytrain, Xtest, Ytest) = utils.preProcessData(data_dir)

sparZ = args.sZ

if numClasses == 2:
    sparW = 1
    sparV = 1
    sparT = 1
else:
    sparW = 0.2
    sparV = 0.2
    sparT = 0.2

if args.sW is not None:
    sparW = args.sW
if args.sV is not None:
    sparV = args.sV
if args.sT is not None:
    sparT = args.sT

if args.batchSize is None:
    batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
else:
    batchSize = args.batchSize

## Creation of the Bonsai object
bonsaiObj = Bonsai(numClasses, dataDimension, projectionDimension, depth, sigma,
                   regW, regT, regV, regZ, sparW, sparT, sparV, sparZ, lr=learningRate)

sess = tf.InteractiveSession()
sess.run(tf.group(tf.initialize_all_variables(), tf.initialize_variables(tf.local_variables())))
saver = tf.train.Saver()  ## Use it in case you want to save the model

numIters = int(Xtrain.shape[0]/batchSize)
totalBatches = numIters*totalEpochs

counter = 0
if bonsaiObj.numClasses > 2:
    trimlevel = 15
else:
    trimlevel = 5
ihtDone = 0

for i in range(totalEpochs):
    print("\nEpoch Number: " + str(i))
    trainAcc = 0.0
    for j in range(numIters):
        if counter == 0:
            print("\n******************** Dense Training Phase Started ********************\n")

        ## Updating the indicator sigma
        if ((counter == 0) or (counter == int(totalBatches/3)) or (counter == int(2*totalBatches/3))):
            bonsaiObj.sigmaI = 1
            itersInPhase = 0
        elif (itersInPhase % 100 == 0):
            indices = np.random.choice(Xtrain.shape[0], 100)
            batchX = Xtrain[indices, :]
            batchY = Ytrain[indices, :]
            batchY = np.reshape(batchY, [-1, bonsaiObj.numClasses])
            _feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY}
            Xcapeval = bonsaiObj.Xeval.eval(feed_dict=_feed_dict)
            Teval = bonsaiObj.Teval.eval()

            sum_tr = 0.0
            for k in range(0, bonsaiObj.internalNodes):
                sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))

            if (bonsaiObj.internalNodes > 0):
                sum_tr /= (100*bonsaiObj.internalNodes)
                sum_tr = 0.1/sum_tr
            else:
                sum_tr = 0.1
            sum_tr = min(1000, sum_tr*(2**(float(itersInPhase)/(float(totalBatches)/30.0))))
            bonsaiObj.sigmaI = sum_tr

        itersInPhase += 1
        batchX = Xtrain[j*batchSize:(j + 1)*batchSize]
        batchY = Ytrain[j*batchSize:(j + 1)*batchSize]
        batchY = np.reshape(batchY, [-1, bonsaiObj.numClasses])

        if bonsaiObj.numClasses > 2:
            _feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY, bonsaiObj.batch_th: batchY.shape[0]}
        else:
            _feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY}

        ## Mini-batch training
        bonsaiObj.runTraining(sess, _feed_dict)
        batchAcc = bonsaiObj.accuracy.eval(feed_dict=_feed_dict)
        trainAcc += batchAcc

        ## Training routine involving IHT and sparse retraining
        if (counter >= int(totalBatches/3) and (counter < int(2*totalBatches/3)) and counter % trimlevel == 0):
            bonsaiObj.runHardThrsd(sess)
            if ihtDone == 0:
                print("\n******************** IHT Phase Started ********************\n")
            ihtDone = 1
        elif ((ihtDone == 1 and counter >= int(totalBatches/3) and (counter < int(2*totalBatches/3))
               and counter % trimlevel != 0) or (counter >= int(2*totalBatches/3))):
            bonsaiObj.runSparseTraining(sess)
            if counter == int(2*totalBatches/3):
                print("\n******************** Sparse Retraining Phase Started ********************\n")
        counter += 1

    print("Train accuracy " + str(trainAcc/numIters))

    if bonsaiObj.numClasses > 2:
        _feed_dict = {bonsaiObj.x: Xtest, bonsaiObj.y: Ytest, bonsaiObj.batch_th: Ytest.shape[0]}
    else:
        _feed_dict = {bonsaiObj.x: Xtest, bonsaiObj.y: Ytest}

    ## This helps in direct testing instead of extracting the model out
    oldSigmaI = bonsaiObj.sigmaI
    bonsaiObj.sigmaI = 1e9

    testAcc = bonsaiObj.accuracy.eval(feed_dict=_feed_dict)
    if ihtDone == 0:
        maxTestAcc = -10000
        maxTestAccEpoch = i
    else:
        if maxTestAcc <= testAcc:
            maxTestAccEpoch = i
            maxTestAcc = testAcc
    print("Test accuracy %g" % testAcc)

    testLoss = bonsaiObj.loss.eval(feed_dict=_feed_dict)
    regTestLoss = bonsaiObj.regLoss.eval(feed_dict=_feed_dict)
    print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) + " + " + str(regTestLoss) + " = " + str(testLoss) + "\n")

    bonsaiObj.sigmaI = oldSigmaI
    sys.stdout.flush()

print("Maximum Test accuracy at compressed model size (including early stopping): "
      + str(maxTestAcc) + " at Epoch: " + str(maxTestAccEpoch) + "\nFinal Test Accuracy: " + str(testAcc))
print("\nNon-Zeros: " + str(bonsaiObj.getModelSize()) + " Model Size: " +
      str(float(bonsaiObj.getModelSize())/1024.0) + " KB \n")

np.save("W.npy", bonsaiObj.Weval.eval())
np.save("V.npy", bonsaiObj.Veval.eval())
np.save("Z.npy", bonsaiObj.Zeval.eval())
np.save("T.npy", bonsaiObj.Teval.eval())

tf/bonsai/utils.py (new file, 135 lines added)

@@ -0,0 +1,135 @@
import tensorflow as tf
import numpy as np
import argparse

## Functions to check sanity of input arguments
def checkIntPos(value):
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
    return ivalue

def checkIntNneg(value):
    ivalue = int(value)
    if ivalue < 0:
        raise argparse.ArgumentTypeError("%s is an invalid non-negative int value" % value)
    return ivalue

def checkFloatNneg(value):
    fvalue = float(value)
    if fvalue < 0:
        raise argparse.ArgumentTypeError("%s is an invalid non-negative float value" % value)
    return fvalue

def checkFloatPos(value):
    fvalue = float(value)
    if fvalue <= 0:
        raise argparse.ArgumentTypeError("%s is an invalid positive float value" % value)
    return fvalue

## Function to parse the input arguments
def getArgs():
    parser = argparse.ArgumentParser(description='HyperParams for Bonsai Algorithm')
    parser.add_argument('-dir', '--data_dir', required=True, help='Data directory containing train.npy and test.npy')
    parser.add_argument('-d', '--depth', type=checkIntNneg, default=2, help='Depth of Bonsai Tree (default: 2 try: [0, 1, 3])')
    parser.add_argument('-p', '--projDim', type=checkIntPos, default=10, help='Projection Dimension (default: 10 try: [5, 20, 30])')
    parser.add_argument('-s', '--sigma', type=float, default=1.0, help='Parameter for sigmoid sharpness (default: 1.0 try: [3.0, 0.05, 0.1])')
    parser.add_argument('-e', '--epochs', type=checkIntPos, default=42, help='Total Epochs (default: 42 try: [100, 150, 60])')
    parser.add_argument('-b', '--batchSize', type=checkIntPos, help='Batch Size to be used (default: max(100, sqrt(train_samples)))')
    parser.add_argument('-lr', '--learningRate', type=checkFloatPos, default=0.01, help='Initial Learning rate for Adam Optimizer (default: 0.01)')
    parser.add_argument('-rW', type=float, default=0.0001, help='Regularizer for predictor parameter W (default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rV', type=float, default=0.0001, help='Regularizer for predictor parameter V (default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rT', type=float, default=0.0001, help='Regularizer for branching parameter Theta (default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rZ', type=float, default=0.00001, help='Regularizer for projection parameter Z (default: 0.00001 try: [0.001, 0.0001, 0.000001])')
    parser.add_argument('-sW', type=checkFloatPos, help='Sparsity for predictor parameter W (default: 1.0 for binary classification, else 0.2 try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sV', type=checkFloatPos, help='Sparsity for predictor parameter V (default: 1.0 for binary classification, else 0.2 try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sT', type=checkFloatPos, help='Sparsity for branching parameter Theta (default: 1.0 for binary classification, else 0.2 try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sZ', type=checkFloatPos, default=0.2, help='Sparsity for projection parameter Z (default: 0.2 try: [0.1, 0.3, 0.5])')
    return parser.parse_args()

## Function for multiclass hinge loss: TF has no internal implementation
def multiClassHingeLoss(logits, label, batch_th):
    flat_logits = tf.reshape(logits, [-1, ])
    correct_id = tf.range(0, batch_th) * logits.shape[1] + label
    correct_logit = tf.gather(flat_logits, correct_id)

    max_label = tf.argmax(logits, 1)
    top2, _ = tf.nn.top_k(logits, k=2, sorted=True)

    wrong_max_logit = tf.where(tf.equal(max_label, label), top2[:, 1], top2[:, 0])
    return tf.reduce_mean(tf.nn.relu(1. + wrong_max_logit - correct_logit))

## Hard thresholding function: zeroes out all but the top s fraction of entries of A (by magnitude)
def hardThreshold(A, s):
    A_ = np.copy(A)
    A_ = A_.ravel()
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s)*100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    A_ = A_.reshape(A.shape)
    return A_

## Copy the support of the src tensor onto the dest tensor
def copySupport(src, dest):
    support = np.nonzero(src)
    dest_ = dest
    dest = np.zeros(dest_.shape)
    dest[support] = dest_[support]
    return dest

## Function to read and pre-process data
def preProcessData(data_dir):
    train = np.load(data_dir + '/train.npy')
    test = np.load(data_dir + '/test.npy')

    dataDimension = int(train.shape[1]) - 1

    Xtrain = train[:, 1:dataDimension + 1]
    Ytrain_ = train[:, 0]
    numClasses = max(Ytrain_) - min(Ytrain_) + 1

    Xtest = test[:, 1:dataDimension + 1]
    Ytest_ = test[:, 0]

    numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))

    ## Mean-var normalisation
    mean = np.mean(Xtrain, 0)
    std = np.std(Xtrain, 0)
    std[std[:] < 0.000001] = 1
    Xtrain = (Xtrain - mean)/std
    Xtest = (Xtest - mean)/std
    ## End mean-var normalisation

    ## One-hot encoding of labels (kept as a single column for binary classification)
    lab = Ytrain_.astype('uint8')
    lab = np.array(lab) - min(lab)

    lab_ = np.zeros((Xtrain.shape[0], numClasses))
    lab_[np.arange(Xtrain.shape[0]), lab] = 1
    if (numClasses == 2):
        Ytrain = np.reshape(lab, [-1, 1])
    else:
        Ytrain = lab_

    lab = Ytest_.astype('uint8')
    lab = np.array(lab) - min(lab)

    lab_ = np.zeros((Xtest.shape[0], numClasses))
    lab_[np.arange(Xtest.shape[0]), lab] = 1
    if (numClasses == 2):
        Ytest = np.reshape(lab, [-1, 1])
    else:
        Ytest = lab_

    ## Append a constant 1 to every point to act as the bias term after projection
    trainBias = np.ones([Xtrain.shape[0], 1])
    Xtrain = np.append(Xtrain, trainBias, axis=1)
    testBias = np.ones([Xtest.shape[0], 1])
    Xtest = np.append(Xtest, testBias, axis=1)

    return dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest
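
As a note on the data layout preProcessData expects, here is a small sketch (file names and sizes are placeholders) that writes a compatible train.npy: labels go in column 0 and raw features in the remaining columns; the bias column is appended by preProcessData itself.

    ## Sketch only: builds a train.npy in the layout preProcessData expects.
    import numpy as np

    numPoints, numFeats = 1000, 16                         # placeholder sizes
    X = np.random.randn(numPoints, numFeats)               # placeholder features
    y = np.random.randint(0, 3, (numPoints, 1))            # placeholder labels in {0, 1, 2}

    train = np.concatenate([y.astype(float), X], axis=1)   # column 0 holds the label
    np.save('./my_data/train.npy', train)                  # './my_data' is a placeholder path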