Mirror of https://github.com/microsoft/EdgeML.git
Cleaned up and reorganised Bonsai TF
This commit is contained in:
Parent
0ccc451f31
Commit
71135ec11a
|
@@ -1,215 +1,148 @@
|
|||
import utils
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
## Bonsai Class
|
||||
|
||||
class Bonsai:
|
||||
## Constructor
|
||||
def __init__(self, C, F, P, D, S, lW, lT, lV, lZ,
|
||||
sW, sT, sV, sZ, lr = None, W = None, T = None,
|
||||
V = None, Z = None, feats = None):
|
||||
def __init__(self, numClasses, dataDimension, projectionDimension,
|
||||
treeDepth, sigma, X,
|
||||
W=None, T=None, V=None, Z=None):
|
||||
'''
|
||||
Expected Dimensions:
|
||||
|
||||
self.dataDimension = F + 1
|
||||
self.projectionDimension = P
|
||||
|
||||
if (C > 2):
|
||||
self.numClasses = C
|
||||
elif (C == 2):
|
||||
self.numClasses = 1
|
||||
Bonsai Params // Optional
|
||||
W [numClasses*totalNodes, projectionDimension]
|
||||
V [numClasses*totalNodes, projectionDimension]
|
||||
Z [projectionDimension, dataDimension + 1]
|
||||
T [internalNodes, projectionDimension]
|
||||
|
||||
self.treeDepth = D
|
||||
self.sigma = S
|
||||
|
||||
## Regularizer coefficients
|
||||
self.lW = lW
|
||||
self.lV = lV
|
||||
self.lT = lT
|
||||
self.lZ = lZ
|
||||
|
||||
## Sparsity hyperparams
|
||||
self.sW = sW
|
||||
self.sV = sV
|
||||
self.sT = sT
|
||||
self.sZ = sZ
|
||||
internalNodes = 2**treeDepth - 1
|
||||
totalNodes = 2*internalNodes + 1
|
||||
|
||||
self.internalNodes = 2**self.treeDepth - 1
|
||||
self.totalNodes = 2*self.internalNodes + 1
|
||||
sigma - tanh nonlinearity
|
||||
sigmaI - Indicator function for node probs
|
||||
|
||||
## The Parameters of Bonsai
|
||||
self.W = self.initW(W)
|
||||
self.V = self.initV(V)
|
||||
self.T = self.initT(T)
|
||||
self.Z = self.initZ(Z)
|
||||
|
||||
## Placeholders for Features and labels
|
||||
## feats are to be fed when joint training is being done with Bonsai as end classifier
|
||||
if feats is None:
|
||||
self.x = tf.placeholder("float", [None, self.dataDimension])
|
||||
else:
|
||||
self.x = feats
|
||||
X is the Data Placeholder - Dims [_, dataDimension]
|
||||
'''
|
||||
|
||||
self.y = tf.placeholder("float", [None, self.numClasses])
|
||||
self.dataDimension = dataDimension
|
||||
self.projectionDimension = projectionDimension
|
||||
|
||||
## Placeholder for batch size, needed for Multiclass hinge loss
|
||||
self.batch_th = tf.placeholder(tf.int64, name='batch_th')
|
||||
self.numClasses = numClasses
|
||||
|
||||
self.sigmaI = 1.0
|
||||
if lr is not None:
|
||||
self.learningRate = lr
|
||||
else:
|
||||
self.learningRate = 0.01
|
||||
self.treeDepth = treeDepth
|
||||
self.sigma = sigma
|
||||
|
||||
## Functions to setup required graphs
|
||||
self.hardThrsd()
|
||||
self.sparseTraining()
|
||||
self.lossGraph()
|
||||
self.trainGraph()
|
||||
self.accuracyGraph()
|
||||
self.internalNodes = 2**self.treeDepth - 1
|
||||
self.totalNodes = 2 * self.internalNodes + 1
|
||||
|
||||
## Functions to initialise Params (Warm start possible with given numpy matrices)
|
||||
def initZ(self, Z):
|
||||
if Z is None:
|
||||
Z = tf.random_normal([self.projectionDimension, self.dataDimension])
|
||||
Z = tf.Variable(Z, name='Z', dtype=tf.float32)
|
||||
return Z
|
||||
self.W = self.initW(W)
|
||||
self.V = self.initV(V)
|
||||
self.T = self.initT(T)
|
||||
self.Z = self.initZ(Z)
|
||||
|
||||
def initW(self, W):
|
||||
if W is None:
|
||||
W = tf.random_normal([self.numClasses*self.totalNodes, self.projectionDimension])
|
||||
W = tf.Variable(W, name='W', dtype=tf.float32)
|
||||
return W
|
||||
self.sigmaI = 1.0
|
||||
|
||||
def initV(self, V):
|
||||
if V is None:
|
||||
V = tf.random_normal([self.numClasses*self.totalNodes, self.projectionDimension])
|
||||
V = tf.Variable(V, name='V', dtype=tf.float32)
|
||||
return V
|
||||
self.X = X
|
||||
|
||||
def initT(self, T):
|
||||
if T is None:
|
||||
T = tf.random_normal([self.internalNodes, self.projectionDimension])
|
||||
T = tf.Variable(T, name='T', dtype=tf.float32)
|
||||
return T
|
||||
self.assertInit()
|
||||
|
||||
## Function to get aimed model size
|
||||
def getModelSize(self):
|
||||
nnzZ = np.ceil(int(self.Z.shape[0]*self.Z.shape[1])*self.sZ)
|
||||
nnzW = np.ceil(int(self.W.shape[0]*self.W.shape[1])*self.sW)
|
||||
nnzV = np.ceil(int(self.V.shape[0]*self.V.shape[1])*self.sV)
|
||||
nnzT = np.ceil(int(self.T.shape[0]*self.T.shape[1])*self.sT)
|
||||
return int((nnzZ+nnzT+nnzV+nnzW)*8)
|
||||
self.score, self.X_ = self.bonsaiGraph(self.X)
|
||||
self.prediction = self.getPrediction(self.score)
|
||||
|
||||
## Function to build the Bonsai Tree graph
|
||||
def bonsaiGraph(self, X):
|
||||
X = tf.reshape(X, [-1,self.dataDimension])
|
||||
X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True), self.projectionDimension)
|
||||
|
||||
W_ = self.W[0:(self.numClasses)]
|
||||
V_ = self.V[0:(self.numClasses)]
|
||||
def initZ(self, Z):
|
||||
if Z is None:
|
||||
Z = tf.random_normal(
|
||||
[self.projectionDimension, self.dataDimension])
|
||||
Z = tf.Variable(Z, name='Z', dtype=tf.float32)
|
||||
return Z
|
||||
|
||||
self.nodeProb = []
|
||||
self.nodeProb.append(1)
|
||||
|
||||
score_ = self.nodeProb[0]*tf.multiply(tf.matmul(W_, X_), tf.tanh(self.sigma*tf.matmul(V_, X_)))
|
||||
for i in range(1, self.totalNodes):
|
||||
W_ = self.W[i*self.numClasses:((i+1)*self.numClasses)]
|
||||
V_ = self.V[i*self.numClasses:((i+1)*self.numClasses)]
|
||||
def initW(self, W):
|
||||
if W is None:
|
||||
W = tf.random_normal(
|
||||
[self.numClasses * self.totalNodes, self.projectionDimension])
|
||||
W = tf.Variable(W, name='W', dtype=tf.float32)
|
||||
return W
|
||||
|
||||
prob = (1+((-1)**(i+1))*tf.tanh(tf.multiply(self.sigmaI,
|
||||
tf.matmul(tf.reshape(self.T[int(np.ceil(i/2)-1)], [-1, self.projectionDimension]), X_))))
|
||||
def initV(self, V):
|
||||
if V is None:
|
||||
V = tf.random_normal(
|
||||
[self.numClasses * self.totalNodes, self.projectionDimension])
|
||||
V = tf.Variable(V, name='V', dtype=tf.float32)
|
||||
return V
|
||||
|
||||
prob = tf.divide(prob, 2)
|
||||
prob = self.nodeProb[int(np.ceil(i/2)-1)]*prob
|
||||
self.nodeProb.append(prob)
|
||||
score_ += self.nodeProb[i]*tf.multiply(tf.matmul(W_, X_), tf.tanh(self.sigma*tf.matmul(V_, X_)))
|
||||
|
||||
return score_, X_, self.T, self.W, self.V, self.Z
|
||||
def initT(self, T):
|
||||
if T is None:
|
||||
T = tf.random_normal(
|
||||
[self.internalNodes, self.projectionDimension])
|
||||
T = tf.Variable(T, name='T', dtype=tf.float32)
|
||||
return T
|
||||
|
||||
## Functions setting up graphs for IHT and Sparse Retraining
|
||||
def hardThrsd(self):
|
||||
## Placeholders for Hard Thresholding and Sparse Training
|
||||
self.Wth = tf.placeholder(tf.float32, name='Wth')
|
||||
self.Vth = tf.placeholder(tf.float32, name='Vth')
|
||||
self.Zth = tf.placeholder(tf.float32, name='Zth')
|
||||
self.Tth = tf.placeholder(tf.float32, name='Tth')
|
||||
def bonsaiGraph(self, X):
|
||||
'''
|
||||
Function to build the Bonsai Tree graph
|
||||
Expected Dimensions
|
||||
|
||||
self.Woph = self.W.assign(self.Wth)
|
||||
self.Voph = self.V.assign(self.Vth)
|
||||
self.Toph = self.T.assign(self.Tth)
|
||||
self.Zoph = self.Z.assign(self.Zth)
|
||||
|
||||
self.hardThresholdGroup = tf.group(self.Woph, self.Voph, self.Toph, self.Zoph)
|
||||
X is [_, self.dataDimension]
|
||||
'''
|
||||
|
||||
def runHardThrsd(self, sess):
|
||||
currW = self.Weval.eval()
|
||||
currV = self.Veval.eval()
|
||||
currZ = self.Zeval.eval()
|
||||
currT = self.Teval.eval()
|
||||
X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True),
|
||||
self.projectionDimension)
|
||||
|
||||
self.thrsdW = utils.hardThreshold(currW, self.sW)
|
||||
self.thrsdV = utils.hardThreshold(currV, self.sV)
|
||||
self.thrsdZ = utils.hardThreshold(currZ, self.sZ)
|
||||
self.thrsdT = utils.hardThreshold(currT, self.sT)
|
||||
W_ = self.W[0:(self.numClasses)]
|
||||
V_ = self.V[0:(self.numClasses)]
|
||||
|
||||
fd_thrsd = {self.Wth:self.thrsdW, self.Vth:self.thrsdV, self.Zth:self.thrsdZ, self.Tth:self.thrsdT}
|
||||
sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
|
||||
self.__nodeProb = []
|
||||
self.__nodeProb.append(1)
|
||||
|
||||
def sparseTraining(self):
|
||||
self.Wops = self.W.assign(self.Wth)
|
||||
self.Vops = self.V.assign(self.Vth)
|
||||
self.Zops = self.Z.assign(self.Zth)
|
||||
self.Tops = self.T.assign(self.Tth)
|
||||
self.sparseRetrainGroup = tf.group(self.Wops, self.Vops, self.Tops, self.Zops)
|
||||
score_ = self.__nodeProb[0] * tf.multiply(
|
||||
tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))
|
||||
for i in range(1, self.totalNodes):
|
||||
W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
|
||||
V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]
|
||||
|
||||
def runSparseTraining(self, sess):
|
||||
currW = self.Weval.eval()
|
||||
currV = self.Veval.eval()
|
||||
currZ = self.Zeval.eval()
|
||||
currT = self.Teval.eval()
|
||||
prob = (1 + ((-1)**(i + 1)) * tf.tanh(tf.multiply(self.sigmaI,
|
||||
tf.matmul(tf.reshape(self.T[int(np.ceil(i / 2) - 1)], [-1, self.projectionDimension]), X_))))
|
||||
|
||||
newW = utils.copySupport(self.thrsdW, currW)
|
||||
newV = utils.copySupport(self.thrsdV, currV)
|
||||
newZ = utils.copySupport(self.thrsdZ, currZ)
|
||||
newT = utils.copySupport(self.thrsdT, currT)
|
||||
prob = tf.divide(prob, 2)
|
||||
prob = self.__nodeProb[int(np.ceil(i / 2) - 1)] * prob
|
||||
self.__nodeProb.append(prob)
|
||||
score_ += self.__nodeProb[i] * tf.multiply(
|
||||
tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))
|
||||
|
||||
fd_st = {self.Wth:newW, self.Vth:newV, self.Zth:newZ, self.Tth:newT}
|
||||
sess.run(self.sparseRetrainGroup, feed_dict=fd_st)
|
||||
return score_, X_
|
||||
|
||||
## Function to build a Loss graph for Bonsai
|
||||
def lossGraph(self):
|
||||
self.score, self.Xeval, self.Teval, self.Weval, self.Veval, self.Zeval = self.bonsaiGraph(self.x)
|
||||
|
||||
self.regLoss = 0.5*(self.lZ*tf.square(tf.norm(self.Z)) + self.lW*tf.square(tf.norm(self.W)) +
|
||||
self.lV*tf.square(tf.norm(self.V)) + self.lT*tf.square(tf.norm(self.T)))
|
||||
|
||||
if (self.numClasses > 2):
|
||||
## need to give users an option to choose the loss
|
||||
if True:
|
||||
self.marginLoss = utils.multiClassHingeLoss(tf.transpose(self.score), tf.argmax(self.y,1), self.batch_th)
|
||||
else:
|
||||
self.marginLoss = utils.crossEntropyLoss(tf.transpose(self.score), self.y)
|
||||
self.loss = self.marginLoss + self.regLoss
|
||||
else:
|
||||
self.marginLoss = tf.reduce_mean(tf.nn.relu(1.0 - (2*self.y-1)*tf.transpose(self.score)))
|
||||
self.loss = self.marginLoss + self.regLoss
|
||||
|
||||
## Function to set up optimisation for Bonsai
|
||||
def trainGraph(self):
|
||||
self.trainStep = tf.train.AdamOptimizer(self.learningRate).minimize(self.loss)
|
||||
|
||||
## Function to run training step on Bonsai
|
||||
def runTraining(self, sess, _feed_dict):
|
||||
sess.run([self.trainStep], feed_dict=_feed_dict)
|
||||
|
||||
## Function to build a graph to compute accuracy of the current model
|
||||
def accuracyGraph(self):
|
||||
if (self.numClasses > 2):
|
||||
correctPrediction = tf.equal(tf.argmax(tf.transpose(self.score),1), tf.argmax(self.y,1))
|
||||
self.accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
|
||||
else:
|
||||
y_ = self.y*2-1
|
||||
correctPrediction = tf.multiply(tf.transpose(self.score), y_)
|
||||
correctPrediction = tf.nn.relu(correctPrediction)
|
||||
correctPrediction = tf.ceil(tf.tanh(correctPrediction))
|
||||
self.accuracy = tf.reduce_mean(tf.cast(correctPrediction, tf.float32))
|
||||
def getPrediction(self, score):
|
||||
'''
|
||||
Takes in a score tensor and outputs an integer class for each datapoint
|
||||
'''
|
||||
if self.numClasses > 2:
|
||||
prediction = tf.argmax(tf.transpose(score), 1)
|
||||
return prediction
|
||||
else:
|
||||
prediction = tf.argmax(
|
||||
tf.concat([tf.transpose(score),
|
||||
0 * tf.transpose(score)], 1), 1)
|
||||
return prediction
|
||||
|
||||
def assertInit(self):
|
||||
errmsg = "Dimension Mismatch, X is [_, self.dataDimension]"
|
||||
assert (len(self.X.shape) == 2 and int(
|
||||
self.X.shape[1]) == self.dataDimension), errmsg
|
||||
errRank = "All Parameters must have only two dimensions shape = [a, b]"
|
||||
assert len(self.W.shape) == len(self.Z.shape), errRank
|
||||
assert len(self.W.shape) == len(self.T.shape), errRank
|
||||
assert len(self.W.shape) == 2, errRank
|
||||
assert self.W.shape == self.V.shape, "W and V should be of same Dimensions"
|
||||
errW = "W and V are [numClasses*totalNodes, projectionDimension]"
|
||||
assert self.W.shape[0] == self.numClasses * self.totalNodes, errW
|
||||
assert self.W.shape[1] == self.projectionDimension, errW
|
||||
errZ = "Z is [projectionDimension, dataDimension]"
|
||||
assert self.Z.shape[0] == self.projectionDimension, errZ
|
||||
assert self.Z.shape[1] == self.dataDimension, errZ
|
||||
errT = "T is [internalNodes, projectionDimension]"
|
||||
assert self.T.shape[0] == self.internalNodes, errT
|
||||
assert self.T.shape[1] == self.projectionDimension, errT
|
||||
assert int(self.numClasses) > 0, "numClasses should be > 1"
|
||||
assert int(self.dataDimension) > 0, "# of features in data should be > 0"
|
||||
assert int(self.projectionDimension) > 0, "Projection should be > 0 dims"
|
||||
assert int(self.treeDepth) >= 0, "treeDepth should be >= 0"
|
||||
|
|
|
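For orientation, a minimal sketch of how the reorganised Bonsai graph class above is meant to be driven (hyperparameter values here are illustrative, not taken from this commit; the driver script further below pulls them from utils.getArgs()):

import tensorflow as tf
from bonsai import Bonsai

# Illustrative values only; dataDimension counts the appended bias column.
numClasses, dataDimension, projectionDimension = 10, 785, 10
treeDepth, sigma = 3, 1.0

X = tf.placeholder("float32", [None, dataDimension])
bonsaiObj = Bonsai(numClasses, dataDimension, projectionDimension,
                   treeDepth, sigma, X)
# After construction the graph exposes bonsaiObj.score, bonsaiObj.prediction
# and the projected features bonsaiObj.X_; training concerns (loss, IHT,
# sparse retraining) now live in BonsaiTrainer below.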
@@ -0,0 +1,331 @@
|
|||
import tensorflow as tf
|
||||
import utils
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
|
||||
class BonsaiTrainer:
|
||||
def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
|
||||
learningRate, Y, lossFlag=True):
|
||||
'''
|
||||
bonsaiObj - Initialised Bonsai Object and Graph
|
||||
lW, lT, lV and lZ are regularisers to Bonsai Params
|
||||
sW, sT, sV and sZ are sparsity factors to Bonsai Params
|
||||
learningRate - learningRate for the optimizer
|
||||
Y - Label placeholder for loss computation
|
||||
lossFlag - Choice between HingeLoss and CrossEntropy
|
||||
lossFlag - True - MultiClass - multiClassHingeLoss
|
||||
lossFlag - False - MultiClass - crossEntropyLoss
|
||||
'''
|
||||
|
||||
self.bonsaiObj = bonsaiObj
|
||||
|
||||
self.lW = lW
|
||||
self.lV = lV
|
||||
self.lT = lT
|
||||
self.lZ = lZ
|
||||
|
||||
self.sW = sW
|
||||
self.sV = sV
|
||||
self.sT = sT
|
||||
self.sZ = sZ
|
||||
|
||||
self.Y = Y
|
||||
self.X = self.bonsaiObj.X
|
||||
|
||||
self.lossFlag = lossFlag
|
||||
|
||||
self.learningRate = learningRate
|
||||
|
||||
self.assertInit()
|
||||
|
||||
self.score = self.bonsaiObj.score
|
||||
self.X_ = self.bonsaiObj.X_
|
||||
|
||||
self.loss, self.marginLoss, self.regLoss = self.lossGraph()
|
||||
|
||||
self.trainStep = self.trainGraph()
|
||||
self.accuracy = self.accuracyGraph()
|
||||
self.prediction = self.bonsaiObj.prediction
|
||||
|
||||
self.hardThrsd()
|
||||
self.sparseTraining()
|
||||
|
||||
def lossGraph(self):
|
||||
'''
|
||||
Loss Graph for given Bonsai Obj
|
||||
'''
|
||||
self.regLoss = 0.5 * (self.lZ * tf.square(tf.norm(self.bonsaiObj.Z)) +
|
||||
self.lW * tf.square(tf.norm(self.bonsaiObj.W)) +
|
||||
self.lV * tf.square(tf.norm(self.bonsaiObj.V)) +
|
||||
self.lT * tf.square(tf.norm(self.bonsaiObj.T)))
|
||||
|
||||
if (self.bonsaiObj.numClasses > 2):
|
||||
if self.lossFlag is True:
|
||||
self.batch_th = tf.placeholder(tf.int64, name='batch_th')
|
||||
self.marginLoss = utils.multiClassHingeLoss(
|
||||
tf.transpose(self.score), tf.argmax(self.Y, 1), self.batch_th)
|
||||
else:
|
||||
self.marginLoss = utils.crossEntropyLoss(
|
||||
tf.transpose(self.score), self.Y)
|
||||
self.loss = self.marginLoss + self.regLoss
|
||||
else:
|
||||
self.marginLoss = tf.reduce_mean(tf.nn.relu(
|
||||
1.0 - (2 * self.Y - 1) * tf.transpose(self.score)))
|
||||
self.loss = self.marginLoss + self.regLoss
|
||||
|
||||
return self.loss, self.marginLoss, self.regLoss
|
||||
|
||||
def trainGraph(self):
|
||||
'''
|
||||
Train Graph for the loss generated by Bonsai
|
||||
'''
|
||||
self.bonsaiObj.TrainStep = tf.train.AdamOptimizer(
|
||||
self.learningRate).minimize(self.loss)
|
||||
|
||||
return self.bonsaiObj.TrainStep
|
||||
|
||||
def accuracyGraph(self):
|
||||
'''
|
||||
Accuracy Graph to evaluate accuracy when needed
|
||||
'''
|
||||
if (self.bonsaiObj.numClasses > 2):
|
||||
correctPrediction = tf.equal(
|
||||
tf.argmax(tf.transpose(self.score), 1), tf.argmax(self.Y, 1))
|
||||
self.accuracy = tf.reduce_mean(
|
||||
tf.cast(correctPrediction, tf.float32))
|
||||
else:
|
||||
y_ = self.Y * 2 - 1
|
||||
correctPrediction = tf.multiply(tf.transpose(self.score), y_)
|
||||
correctPrediction = tf.nn.relu(correctPrediction)
|
||||
correctPrediction = tf.ceil(tf.tanh(correctPrediction))
|
||||
self.accuracy = tf.reduce_mean(
|
||||
tf.cast(correctPrediction, tf.float32))
|
||||
|
||||
return self.accuracy
|
||||
|
||||
def hardThrsd(self):
|
||||
'''
|
||||
Set up for hard Thresholding Functionality
|
||||
'''
|
||||
self.__Wth = tf.placeholder(tf.float32, name='Wth')
|
||||
self.__Vth = tf.placeholder(tf.float32, name='Vth')
|
||||
self.__Zth = tf.placeholder(tf.float32, name='Zth')
|
||||
self.__Tth = tf.placeholder(tf.float32, name='Tth')
|
||||
|
||||
self.__Woph = self.bonsaiObj.W.assign(self.__Wth)
|
||||
self.__Voph = self.bonsaiObj.V.assign(self.__Vth)
|
||||
self.__Toph = self.bonsaiObj.T.assign(self.__Tth)
|
||||
self.__Zoph = self.bonsaiObj.Z.assign(self.__Zth)
|
||||
|
||||
self.hardThresholdGroup = tf.group(
|
||||
self.__Woph, self.__Voph, self.__Toph, self.__Zoph)
|
||||
|
||||
def sparseTraining(self):
|
||||
'''
|
||||
Set up for Sparse Retraining Functionality
|
||||
'''
|
||||
self.__Wops = self.bonsaiObj.W.assign(self.__Wth)
|
||||
self.__Vops = self.bonsaiObj.V.assign(self.__Vth)
|
||||
self.__Zops = self.bonsaiObj.Z.assign(self.__Zth)
|
||||
self.__Tops = self.bonsaiObj.T.assign(self.__Tth)
|
||||
|
||||
self.sparseRetrainGroup = tf.group(
|
||||
self.__Wops, self.__Vops, self.__Tops, self.__Zops)
|
||||
|
||||
def runHardThrsd(self, sess):
|
||||
'''
|
||||
Function to run the IHT routine on Bonsai Obj
|
||||
'''
|
||||
currW = self.bonsaiObj.W.eval()
|
||||
currV = self.bonsaiObj.V.eval()
|
||||
currZ = self.bonsaiObj.Z.eval()
|
||||
currT = self.bonsaiObj.T.eval()
|
||||
|
||||
self.__thrsdW = utils.hardThreshold(currW, self.sW)
|
||||
self.__thrsdV = utils.hardThreshold(currV, self.sV)
|
||||
self.__thrsdZ = utils.hardThreshold(currZ, self.sZ)
|
||||
self.__thrsdT = utils.hardThreshold(currT, self.sT)
|
||||
|
||||
fd_thrsd = {self.__Wth: self.__thrsdW, self.__Vth: self.__thrsdV,
|
||||
self.__Zth: self.__thrsdZ, self.__Tth: self.__thrsdT}
|
||||
sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
|
||||
|
||||
def runSparseTraining(self, sess):
|
||||
'''
|
||||
Function to run the Sparse Retraining routine on Bonsai Obj
|
||||
'''
|
||||
currW = self.bonsaiObj.W.eval()
|
||||
currV = self.bonsaiObj.V.eval()
|
||||
currZ = self.bonsaiObj.Z.eval()
|
||||
currT = self.bonsaiObj.T.eval()
|
||||
|
||||
newW = utils.copySupport(self.__thrsdW, currW)
|
||||
newV = utils.copySupport(self.__thrsdV, currV)
|
||||
newZ = utils.copySupport(self.__thrsdZ, currZ)
|
||||
newT = utils.copySupport(self.__thrsdT, currT)
|
||||
|
||||
fd_st = {self.__Wth: newW, self.__Vth: newV,
|
||||
self.__Zth: newZ, self.__Tth: newT}
|
||||
sess.run(self.sparseRetrainGroup, feed_dict=fd_st)
|
||||
|
||||
def assertInit(self):
|
||||
err = "sparsity must be between 0 and 1"
|
||||
assert self.sW >= 0 and self.sW <= 1, "W " + err
|
||||
assert self.sV >= 0 and self.sV <= 1, "V " + err
|
||||
assert self.sZ >= 0 and self.sZ <= 1, "Z " + err
|
||||
assert self.sT >= 0 and self.sT <= 1, "T " + err
|
||||
errMsg = "Dimension Mismatch, Y is [_, numClasses]"
|
||||
assert (len(self.Y.shape) == 2 and
|
||||
self.Y.shape[1] == self.bonsaiObj.numClasses), errMsg
|
||||
|
||||
# Function to get aimed model size
|
||||
def getModelSize(self):
|
||||
nnzZ = np.ceil(
|
||||
int(self.bonsaiObj.Z.shape[0] * self.bonsaiObj.Z.shape[1]) * self.sZ)
|
||||
nnzW = np.ceil(
|
||||
int(self.bonsaiObj.W.shape[0] * self.bonsaiObj.W.shape[1]) * self.sW)
|
||||
nnzV = np.ceil(
|
||||
int(self.bonsaiObj.V.shape[0] * self.bonsaiObj.V.shape[1]) * self.sV)
|
||||
nnzT = np.ceil(
|
||||
int(self.bonsaiObj.T.shape[0] * self.bonsaiObj.T.shape[1]) * self.sT)
|
||||
return int((nnzZ + nnzT + nnzV + nnzW) * 8)
|
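As a worked example of the size estimate above (all numbers assumed, not from this commit): with numClasses = 10, dataDimension = 785, projectionDimension = 10 and treeDepth = 3, the tree has 7 internal and 15 total nodes, so Z is [10, 785], W and V are [150, 10] and T is [7, 10]. With sZ = sW = sV = sT = 0.2:

nnzZ = ceil(7850 * 0.2) = 1570
nnzW = nnzV = ceil(1500 * 0.2) = 300
nnzT = ceil(70 * 0.2) = 14
modelSize = (1570 + 300 + 300 + 14) * 8 = 17472 bytes, i.e. about 17 KB at the 8 bytes budgeted per stored non-zero.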
||||
|
||||
def train(self, batchSize, totalEpochs, sess,
|
||||
Xtrain, Xtest, Ytrain, Ytest):
|
||||
'''
|
||||
The Dense - IHT - Sparse Retrain Routine for Bonsai Training
|
||||
'''
|
||||
numIters = Xtrain.shape[0] // batchSize
|
||||
|
||||
totalBatches = numIters * totalEpochs
|
||||
|
||||
counter = 0
|
||||
if self.bonsaiObj.numClasses > 2:
|
||||
trimlevel = 15
|
||||
else:
|
||||
trimlevel = 5
|
||||
ihtDone = 0
|
||||
|
||||
for i in range(totalEpochs):
|
||||
print("\nEpoch Number: " + str(i))
|
||||
|
||||
trainAcc = 0.0
|
||||
for j in range(numIters):
|
||||
|
||||
if counter == 0:
|
||||
print(
|
||||
"\n******************** Dense Training Phase Started ********************\n")
|
||||
|
||||
# Updating the indicator sigma
|
||||
if ((counter == 0) or (counter == int(totalBatches / 3)) or
|
||||
(counter == int(2 * totalBatches / 3))):
|
||||
self.bonsaiObj.sigmaI = 1
|
||||
itersInPhase = 0
|
||||
|
||||
elif (itersInPhase % 100 == 0):
|
||||
indices = np.random.choice(Xtrain.shape[0], 100)
|
||||
batchX = Xtrain[indices, :]
|
||||
batchY = Ytrain[indices, :]
|
||||
batchY = np.reshape(
|
||||
batchY, [-1, self.bonsaiObj.numClasses])
|
||||
|
||||
_feed_dict = {self.X: batchX}
|
||||
Xcapeval = self.X_.eval(feed_dict=_feed_dict)
|
||||
Teval = self.bonsaiObj.T.eval()
|
||||
|
||||
sum_tr = 0.0
|
||||
for k in range(0, self.bonsaiObj.internalNodes):
|
||||
sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))
|
||||
|
||||
if(self.bonsaiObj.internalNodes > 0):
|
||||
sum_tr /= (100 * self.bonsaiObj.internalNodes)
|
||||
sum_tr = 0.1 / sum_tr
|
||||
else:
|
||||
sum_tr = 0.1
|
||||
sum_tr = min(
|
||||
1000, sum_tr * (2**(float(itersInPhase) /
|
||||
(float(totalBatches) / 30.0))))
|
||||
|
||||
self.bonsaiObj.sigmaI = sum_tr
|
||||
|
||||
itersInPhase += 1
|
||||
batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
|
||||
batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
|
||||
batchY = np.reshape(
|
||||
batchY, [-1, self.bonsaiObj.numClasses])
|
||||
|
||||
if self.bonsaiObj.numClasses > 2:
|
||||
if self.lossFlag is True:
|
||||
_feed_dict = {self.X: batchX, self.Y: batchY,
|
||||
self.batch_th: batchY.shape[0]}
|
||||
else:
|
||||
_feed_dict = {self.X: batchX, self.Y: batchY}
|
||||
else:
|
||||
_feed_dict = {self.X: batchX, self.Y: batchY}
|
||||
|
||||
# Mini-batch training
|
||||
_, batchLoss = sess.run(
|
||||
[self.trainStep, self.loss], feed_dict=_feed_dict)
|
||||
|
||||
batchAcc = self.accuracy.eval(feed_dict=_feed_dict)
|
||||
trainAcc += batchAcc
|
||||
|
||||
# Training routine involving IHT and sparse retraining
|
||||
if (counter >= int(totalBatches / 3) and
|
||||
(counter < int(2 * totalBatches / 3)) and
|
||||
counter % trimlevel == 0):
|
||||
self.runHardThrsd(sess)
|
||||
if ihtDone == 0:
|
||||
print(
|
||||
"\n******************** IHT Phase Started ********************\n")
|
||||
ihtDone = 1
|
||||
elif ((ihtDone == 1 and counter >= int(totalBatches / 3) and
|
||||
(counter < int(2 * totalBatches / 3)) and
|
||||
counter % trimlevel != 0) or
|
||||
(counter >= int(2 * totalBatches / 3))):
|
||||
self.runSparseTraining(sess)
|
||||
if counter == int(2 * totalBatches / 3):
|
||||
print(
|
||||
"\n******************** Sprase Retraining Phase Started ********************\n")
|
||||
counter += 1
|
||||
|
||||
print("Train accuracy " + str(trainAcc / numIters))
|
||||
|
||||
if self.bonsaiObj.numClasses > 2:
|
||||
if self.lossFlag is True:
|
||||
_feed_dict = {self.X: Xtest, self.Y: Ytest,
|
||||
self.batch_th: Ytest.shape[0]}
|
||||
else:
|
||||
_feed_dict = {self.X: Xtest, self.Y: Ytest}
|
||||
else:
|
||||
_feed_dict = {self.X: Xtest, self.Y: Ytest}
|
||||
|
||||
# This helps in direct testing instead of extracting the model out
|
||||
oldSigmaI = self.bonsaiObj.sigmaI
|
||||
self.bonsaiObj.sigmaI = 1e9
|
||||
|
||||
testAcc = self.accuracy.eval(feed_dict=_feed_dict)
|
||||
if ihtDone == 0:
|
||||
maxTestAcc = -10000
|
||||
maxTestAccEpoch = i
|
||||
else:
|
||||
if maxTestAcc <= testAcc:
|
||||
maxTestAccEpoch = i
|
||||
maxTestAcc = testAcc
|
||||
|
||||
print("Test accuracy %g" % testAcc)
|
||||
|
||||
testLoss = self.loss.eval(feed_dict=_feed_dict)
|
||||
regTestLoss = self.regLoss.eval(feed_dict=_feed_dict)
|
||||
print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) +
|
||||
" + " + str(regTestLoss) + " = " + str(testLoss) + "\n")
|
||||
|
||||
self.bonsaiObj.sigmaI = oldSigmaI
|
||||
sys.stdout.flush()
|
||||
|
||||
print("Maximum Test accuracy at compressed model size(including early stopping): " +
|
||||
str(maxTestAcc) + " at Epoch: " +
|
||||
str(maxTestAccEpoch) + "\nFinal Test Accuracy: " + str(testAcc))
|
||||
print("\nNon-Zeros: " + str(self.getModelSize()) + " Model Size: " +
|
||||
str(float(self.getModelSize()) / 1024.0) + " KB \n")
|
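To make the schedule in train() above concrete: every batch takes a normal Adam step, and counter (the global batch index) additionally selects one of three phases. A small sketch of the phase boundaries, with an assumed batch count purely for illustration:

# Illustrative reconstruction of the phase logic used in train() above.
totalBatches = 300   # numIters * totalEpochs for some assumed run
trimlevel = 15       # value used above for multi-class problems
for counter in range(totalBatches):
    if counter < totalBatches // 3:
        phase = "dense training"
    elif counter < 2 * totalBatches // 3:
        # IHT: hard-threshold every trimlevel-th batch, otherwise retrain
        # on the support chosen by the last thresholding step.
        phase = "IHT" if counter % trimlevel == 0 else "sparse retraining"
    else:
        phase = "sparse retraining"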
|
@@ -1,14 +1,14 @@
|
|||
import utils
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import sys
|
||||
from bonsaiTrainer import BonsaiTrainer
|
||||
from bonsai import Bonsai
|
||||
|
||||
## Fixing seeds for reproducibility
|
||||
# Fixing seeds for reproducibility
|
||||
tf.set_random_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
## Hyper Param pre-processing
|
||||
# Hyper Param pre-processing
|
||||
args = utils.getArgs()
|
||||
|
||||
sigma = args.sigma
|
||||
|
@@ -26,154 +26,56 @@ learningRate = args.learningRate
|
|||
|
||||
data_dir = args.data_dir
|
||||
|
||||
(dataDimension, numClasses,
|
||||
Xtrain, Ytrain, Xtest, Ytest) = utils.preProcessData(data_dir)
|
||||
(dataDimension, numClasses,
|
||||
Xtrain, Ytrain, Xtest, Ytest) = utils.preProcessData(data_dir)
|
||||
|
||||
sparZ = args.sZ
|
||||
|
||||
if numClasses == 2:
|
||||
sparW = 1
|
||||
sparV = 1
|
||||
sparT = 1
|
||||
if numClasses > 2:
|
||||
sparW = 0.2
|
||||
sparV = 0.2
|
||||
sparT = 0.2
|
||||
else:
|
||||
sparW = 0.2
|
||||
sparV = 0.2
|
||||
sparT = 0.2
|
||||
sparW = 1
|
||||
sparV = 1
|
||||
sparT = 1
|
||||
|
||||
if args.sW is not None:
|
||||
sparW = args.sW
|
||||
sparW = args.sW
|
||||
if args.sV is not None:
|
||||
sparV = args.sV
|
||||
sparV = args.sV
|
||||
if args.sT is not None:
|
||||
sparT = args.sT
|
||||
sparT = args.sT
|
||||
|
||||
if args.batchSize is None:
|
||||
batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
|
||||
batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
|
||||
else:
|
||||
batchSize = args.batchSize
|
||||
batchSize = args.batchSize
|
||||
|
||||
lossFlag = True
|
||||
|
||||
## Creation of Bonsai Object
|
||||
bonsaiObj = Bonsai(numClasses, dataDimension, projectionDimension, depth, sigma,
|
||||
regW, regT, regV, regZ, sparW, sparT, sparV, sparZ, lr = learningRate)
|
||||
X = tf.placeholder("float32", [None, dataDimension])
|
||||
Y = tf.placeholder("float32", [None, numClasses])
|
||||
|
||||
bonsaiObj = Bonsai(numClasses, dataDimension,
|
||||
projectionDimension, depth, sigma, X)
|
||||
|
||||
bonsaiTrainer = BonsaiTrainer(bonsaiObj,
|
||||
regW, regT, regV, regZ,
|
||||
sparW, sparT, sparV, sparZ,
|
||||
learningRate, Y, lossFlag)
|
||||
|
||||
sess = tf.InteractiveSession()
|
||||
sess.run(tf.group(tf.initialize_all_variables(), tf.initialize_variables(tf.local_variables())))
|
||||
saver = tf.train.Saver() ## Use it incase of saving the model
|
||||
sess.run(tf.group(tf.initialize_all_variables(),
|
||||
tf.initialize_variables(tf.local_variables())))
|
||||
saver = tf.train.Saver()
|
||||
|
||||
numIters = Xtrain.shape[0]/batchSize
|
||||
bonsaiTrainer.train(batchSize, totalEpochs, sess, Xtrain, Xtest, Ytrain, Ytest)
|
||||
|
||||
totalBatches = numIters*totalEpochs
|
||||
print(bonsaiTrainer.bonsaiObj.score.eval(feed_dict={X: Xtest}))
|
||||
print(bonsaiObj.score.eval(feed_dict={X: Xtest}))
|
||||
|
||||
counter = 0
|
||||
if bonsaiObj.numClasses > 2:
|
||||
trimlevel = 15
|
||||
else:
|
||||
trimlevel = 5
|
||||
ihtDone = 0
|
||||
|
||||
|
||||
for i in range(totalEpochs):
|
||||
print("\nEpoch Number: "+str(i))
|
||||
|
||||
trainAcc = 0.0
|
||||
for j in range(numIters):
|
||||
|
||||
if counter == 0:
|
||||
print("\n******************** Dense Training Phase Started ********************\n")
|
||||
|
||||
## Updating the indicator sigma
|
||||
if ((counter == 0) or (counter == int(totalBatches/3)) or (counter == int(2*totalBatches/3))):
|
||||
bonsaiObj.sigmaI = 1
|
||||
itersInPhase = 0
|
||||
|
||||
elif (itersInPhase%100 == 0):
|
||||
indices = np.random.choice(Xtrain.shape[0],100)
|
||||
batchX = Xtrain[indices,:]
|
||||
batchY = Ytrain[indices,:]
|
||||
batchY = np.reshape(batchY, [-1, bonsaiObj.numClasses])
|
||||
|
||||
_feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY}
|
||||
Xcapeval = bonsaiObj.Xeval.eval(feed_dict=_feed_dict)
|
||||
Teval = bonsaiObj.Teval.eval()
|
||||
|
||||
sum_tr = 0.0
|
||||
for k in range(0, bonsaiObj.internalNodes):
|
||||
sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))
|
||||
|
||||
if(bonsaiObj.internalNodes > 0):
|
||||
sum_tr /= (100*bonsaiObj.internalNodes)
|
||||
sum_tr = 0.1/sum_tr
|
||||
else:
|
||||
sum_tr = 0.1
|
||||
sum_tr = min(1000,sum_tr*(2**(float(itersInPhase)/(float(totalBatches)/30.0))))
|
||||
|
||||
bonsaiObj.sigmaI = sum_tr
|
||||
|
||||
itersInPhase += 1
|
||||
batchX = Xtrain[j*batchSize:(j+1)*batchSize]
|
||||
batchY = Ytrain[j*batchSize:(j+1)*batchSize]
|
||||
batchY = np.reshape(batchY, [-1, bonsaiObj.numClasses])
|
||||
|
||||
if bonsaiObj.numClasses > 2:
|
||||
_feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY, bonsaiObj.batch_th: batchY.shape[0]}
|
||||
else:
|
||||
_feed_dict = {bonsaiObj.x: batchX, bonsaiObj.y: batchY}
|
||||
|
||||
## Mini-batch training
|
||||
batchLoss = bonsaiObj.runTraining(sess, _feed_dict)
|
||||
|
||||
batchAcc = bonsaiObj.accuracy.eval(feed_dict=_feed_dict)
|
||||
trainAcc += batchAcc
|
||||
|
||||
## Training routine involving IHT and sparse retraining
|
||||
if (counter >= int(totalBatches/3) and (counter < int(2*totalBatches/3)) and counter%trimlevel == 0):
|
||||
bonsaiObj.runHardThrsd(sess)
|
||||
if ihtDone == 0:
|
||||
print("\n******************** IHT Phase Started ********************\n")
|
||||
ihtDone = 1
|
||||
elif ((ihtDone == 1 and counter >= int(totalBatches/3) and (counter < int(2*totalBatches/3))
|
||||
and counter%trimlevel != 0) or (counter >= int(2*totalBatches/3))):
|
||||
bonsaiObj.runSparseTraining(sess)
|
||||
if counter == int(2*totalBatches/3):
|
||||
print("\n******************** Sprase Retraining Phase Started ********************\n")
|
||||
counter += 1
|
||||
|
||||
print("Train accuracy "+str(trainAcc/numIters))
|
||||
|
||||
if bonsaiObj.numClasses > 2:
|
||||
_feed_dict = {bonsaiObj.x: Xtest, bonsaiObj.y: Ytest, bonsaiObj.batch_th: Ytest.shape[0]}
|
||||
else:
|
||||
_feed_dict = {bonsaiObj.x: Xtest, bonsaiObj.y: Ytest}
|
||||
|
||||
## this helps in direct testing instead of extracting the model out
|
||||
oldSigmaI = bonsaiObj.sigmaI
|
||||
bonsaiObj.sigmaI = 1e9
|
||||
|
||||
testAcc = bonsaiObj.accuracy.eval(feed_dict=_feed_dict)
|
||||
if ihtDone == 0:
|
||||
maxTestAcc = -10000
|
||||
maxTestAccEpoch = i
|
||||
else:
|
||||
if maxTestAcc <= testAcc:
|
||||
maxTestAccEpoch = i
|
||||
maxTestAcc = testAcc
|
||||
|
||||
print("Test accuracy %g"%testAcc)
|
||||
|
||||
testLoss = bonsaiObj.loss.eval(feed_dict=_feed_dict)
|
||||
regTestLoss = bonsaiObj.regLoss.eval(feed_dict=_feed_dict)
|
||||
print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) + " + " + str(regTestLoss) + " = " + str(testLoss) + "\n")
|
||||
|
||||
bonsaiObj.sigmaI = oldSigmaI
|
||||
sys.stdout.flush()
|
||||
|
||||
print("Maximum Test accuracy at compressed model size(including early stopping): "
|
||||
+ str(maxTestAcc) + " at Epoch: " + str(maxTestAccEpoch) + "\nFinal Test Accuracy: " + str(testAcc))
|
||||
print("\nNon-Zeros: " + str(bonsaiObj.getModelSize()) + " Model Size: " +
|
||||
str(float(bonsaiObj.getModelSize())/1024.0) + " KB \n")
|
||||
|
||||
np.save("W.npy", bonsaiObj.Weval.eval())
|
||||
np.save("V.npy", bonsaiObj.Veval.eval())
|
||||
np.save("Z.npy", bonsaiObj.Zeval.eval())
|
||||
np.save("T.npy", bonsaiObj.Teval.eval())
|
||||
np.save("W.npy", bonsaiTrainer.bonsaiObj.W.eval())
|
||||
np.save("V.npy", bonsaiTrainer.bonsaiObj.V.eval())
|
||||
np.save("Z.npy", bonsaiTrainer.bonsaiObj.Z.eval())
|
||||
np.save("T.npy", bonsaiTrainer.bonsaiObj.T.eval())
|
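Assuming the driver above is saved as bonsai_example.py (file name assumed here) and data_dir contains train.npy and test.npy in the format expected by utils.preProcessData below, a typical invocation with the flags defined in utils.getArgs would look like: python bonsai_example.py -dir ./data -d 2 -p 10 -e 42, with regularisers and sparsities (-rW, -sW, ...) overridden as needed.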
||||
|
|
|
@@ -2,138 +2,179 @@ import tensorflow as tf
|
|||
import numpy as np
|
||||
import argparse
|
||||
|
||||
## Functions to check sanity of input arguments
|
||||
# Functions to check sanity of input arguments
|
||||
|
||||
|
||||
def checkIntPos(value):
|
||||
ivalue = int(value)
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
|
||||
return ivalue
|
||||
ivalue = int(value)
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is an invalid positive int value" % value)
|
||||
return ivalue
|
||||
|
||||
|
||||
def checkIntNneg(value):
|
||||
ivalue = int(value)
|
||||
if ivalue < 0:
|
||||
raise argparse.ArgumentTypeError("%s is an invalid non-neg int value" % value)
|
||||
return ivalue
|
||||
ivalue = int(value)
|
||||
if ivalue < 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is an invalid non-neg int value" % value)
|
||||
return ivalue
|
||||
|
||||
|
||||
def checkFloatNneg(value):
|
||||
fvalue = float(value)
|
||||
if fvalue < 0:
|
||||
raise argparse.ArgumentTypeError("%s is an invalid non-neg float value" % value)
|
||||
return fvalue
|
||||
fvalue = float(value)
|
||||
if fvalue < 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is an invalid non-neg float value" % value)
|
||||
return fvalue
|
||||
|
||||
|
||||
def checkFloatPos(value):
|
||||
fvalue = float(value)
|
||||
if fvalue <= 0:
|
||||
raise argparse.ArgumentTypeError("%s is an invalid positive float value" % value)
|
||||
return fvalue
|
||||
fvalue = float(value)
|
||||
if fvalue <= 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is an invalid positive float value" % value)
|
||||
return fvalue
|
||||
|
||||
# Function to parse the input arguments
|
||||
|
||||
|
||||
## Function to parse the input arguments
|
||||
def getArgs():
|
||||
parser = argparse.ArgumentParser(description='HyperParams for Bonsai Algorithm')
|
||||
parser.add_argument('-dir', '--data_dir', required=True, help='Data directory containing train.npy and test.npy')
|
||||
|
||||
parser.add_argument('-d', '--depth', type=checkIntNneg, default=2, help='Depth of Bonsai Tree (default: 2 try: [0, 1, 3])')
|
||||
parser.add_argument('-p', '--projDim', type=checkIntPos, default=10, help='Projection Dimension (default: 20 try: [5, 10, 30])')
|
||||
parser.add_argument('-s', '--sigma', type=float, default=1.0, help='Parameter for sigmoid sharpness (default: 1.0 try: [3.0, 0.05, 0.1]')
|
||||
parser.add_argument('-e', '--epochs', type=checkIntPos, default=42, help='Total Epochs (default: 42 try:[100, 150, 60])')
|
||||
parser.add_argument('-b', '--batchSize', type=checkIntPos, help='Batch Size to be used (default: max(100, sqrt(train_samples)))')
|
||||
parser.add_argument('-lr', '--learningRate', type=checkFloatPos, default=0.01, help='Initial Learning rate for Adam Oprimizer (default: 0.01)')
|
||||
parser = argparse.ArgumentParser(
|
||||
description='HyperParams for Bonsai Algorithm')
|
||||
parser.add_argument('-dir', '--data_dir', required=True,
|
||||
help='Data directory containing train.npy and test.npy')
|
||||
|
||||
parser.add_argument('-rW', type=float, default=0.0001, help='Regularizer for predictor parameter W (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rV', type=float, default=0.0001, help='Regularizer for predictor parameter V (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rT', type=float, default=0.0001, help='Regularizer for branching parameter Theta (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rZ', type=float, default=0.00001, help='Regularizer for projection parameter Z (default: 0.00001 try: [0.001, 0.0001, 0.000001])')
|
||||
parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
|
||||
help='Depth of Bonsai Tree (default: 2 try: [0, 1, 3])')
|
||||
parser.add_argument('-p', '--projDim', type=checkIntPos, default=10,
|
||||
help='Projection Dimension (default: 20 try: [5, 10, 30])')
|
||||
parser.add_argument('-s', '--sigma', type=float, default=1.0,
|
||||
help='Parameter for sigmoid sharpness (default: 1.0 try: [3.0, 0.05, 0.1])')
|
||||
parser.add_argument('-e', '--epochs', type=checkIntPos, default=42,
|
||||
help='Total Epochs (default: 42 try:[100, 150, 60])')
|
||||
parser.add_argument('-b', '--batchSize', type=checkIntPos,
|
||||
help='Batch Size to be used (default: max(100, sqrt(train_samples)))')
|
||||
parser.add_argument('-lr', '--learningRate', type=checkFloatPos, default=0.01,
|
||||
help='Initial Learning rate for Adam Optimizer (default: 0.01)')
|
||||
|
||||
parser.add_argument('-sW', type=checkFloatPos, help='Sparsity for predictor parameter W (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sV', type=checkFloatPos, help='Sparsity for predictor parameter V (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sT', type=checkFloatPos, help='Sparsity for branching parameter Theta (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sZ', type=checkFloatPos, default=0.2, help='Sparsity for projection parameter Z (default: 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-rW', type=float, default=0.0001,
|
||||
help='Regularizer for predictor parameter W (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rV', type=float, default=0.0001,
|
||||
help='Regularizer for predictor parameter V (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rT', type=float, default=0.0001,
|
||||
help='Regularizer for branching parameter Theta (default: 0.0001 try: [0.01, 0.001, 0.00001])')
|
||||
parser.add_argument('-rZ', type=float, default=0.00001,
|
||||
help='Regularizer for projection parameter Z (default: 0.00001 try: [0.001, 0.0001, 0.000001])')
|
||||
|
||||
parser.add_argument('-sW', type=checkFloatPos,
|
||||
help='Sparsity for predictor parameter W (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sV', type=checkFloatPos,
|
||||
help='Sparsity for predictor parameter V (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sT', type=checkFloatPos,
|
||||
help='Sparsity for branching parameter Theta (default: For Binary classification 1.0 else 0.2 try: [0.1, 0.3, 0.5])')
|
||||
parser.add_argument('-sZ', type=checkFloatPos, default=0.2,
|
||||
help='Sparsity for projection parameter Z (default: 0.2 try: [0.1, 0.3, 0.5])')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
# Function for Multi Class Hinge Loss : TF has no internal implementation
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
## Function for Multi Class Hinge Loss : TF has no internal implementation
|
||||
def multiClassHingeLoss(logits, label, batch_th):
|
||||
flatLogits = tf.reshape(logits, [-1,])
|
||||
correctId = tf.range(0, batch_th) * logits.shape[1] + label
|
||||
correctLogit = tf.gather(flatLogits, correctId)
|
||||
flatLogits = tf.reshape(logits, [-1, ])
|
||||
correctId = tf.range(0, batch_th) * logits.shape[1] + label
|
||||
correctLogit = tf.gather(flatLogits, correctId)
|
||||
|
||||
maxLabel = tf.argmax(logits, 1)
|
||||
top2, _ = tf.nn.top_k(logits, k=2, sorted=True)
|
||||
maxLabel = tf.argmax(logits, 1)
|
||||
top2, _ = tf.nn.top_k(logits, k=2, sorted=True)
|
||||
|
||||
wrongMaxLogit = tf.where(tf.equal(maxLabel, label), top2[:,1], top2[:,0])
|
||||
wrongMaxLogit = tf.where(tf.equal(maxLabel, label), top2[:, 1], top2[:, 0])
|
||||
|
||||
return tf.reduce_mean(tf.nn.relu(1. + wrongMaxLogit - correctLogit))
|
||||
|
||||
# Function for cross entropy loss in multiclass case (for faster convergence in joint training)
|
||||
|
||||
return tf.reduce_mean(tf.nn.relu(1. + wrongMaxLogit - correctLogit))
|
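In words, the multi-class hinge loss above is mean(relu(1 + strongest wrong-class logit - correct-class logit)) over the batch. A quick NumPy check of the same quantity on made-up logits (illustrative only):

import numpy as np

logits = np.array([[2.0, 0.5, -1.0],   # label 0, margin 2.0 - 0.5 = 1.5
                   [0.2, 1.0,  0.9]])  # label 2, margin 0.9 - 1.0 = -0.1
labels = np.array([0, 2])

correct = logits[np.arange(len(labels)), labels]
wrongMax = np.where(np.argmax(logits, 1) == labels,
                    np.sort(logits, 1)[:, -2],   # runner-up when the max is the correct class
                    np.max(logits, 1))
loss = np.mean(np.maximum(0.0, 1.0 + wrongMax - correct))
# relu(1 + 0.5 - 2.0) = 0.0 and relu(1 + 1.0 - 0.9) = 1.1, so loss = 0.55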
||||
|
||||
## Function for cross entropy loss in multiclass case (for faster convergence in joint training)
|
||||
def crossEntropyLoss(logits, label):
|
||||
return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=label))
|
||||
return tf.reduce_mean(
|
||||
tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=label))
|
||||
|
||||
# Hard Thresholding Function
|
||||
|
||||
|
||||
## Hard Thresholding Function
|
||||
def hardThreshold(A, s):
|
||||
A_ = np.copy(A)
|
||||
A_ = A_.ravel()
|
||||
if len(A_) > 0:
|
||||
th = np.percentile(np.abs(A_), (1 - s)*100.0, interpolation='higher')
|
||||
A_[np.abs(A_)<th] = 0.0
|
||||
A_ = A_.reshape(A.shape)
|
||||
return A_
|
||||
A_ = np.copy(A)
|
||||
A_ = A_.ravel()
|
||||
if len(A_) > 0:
|
||||
th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
|
||||
A_[np.abs(A_) < th] = 0.0
|
||||
A_ = A_.reshape(A.shape)
|
||||
return A_
|
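For intuition, hardThreshold keeps roughly the s fraction of entries of A with largest magnitude and zeroes the rest. A tiny example with made-up numbers:

A = np.array([[0.9, -0.1],
              [0.05, -0.7]])
# With s = 0.5 the threshold is the 50th percentile (interpolation='higher')
# of |A|, which is 0.7, so only the two largest-magnitude entries survive:
# hardThreshold(A, 0.5) -> [[0.9, 0.0], [0.0, -0.7]]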
||||
|
||||
# Copy the support of src onto dest tensor
|
||||
|
||||
|
||||
## Copy the support of src onto dest tensor
|
||||
def copySupport(src, dest):
|
||||
support = np.nonzero(src)
|
||||
dest_ = dest
|
||||
dest = np.zeros(dest_.shape)
|
||||
dest[support] = dest_[support]
|
||||
return dest
|
||||
support = np.nonzero(src)
|
||||
dest_ = dest
|
||||
dest = np.zeros(dest_.shape)
|
||||
dest[support] = dest_[support]
|
||||
return dest
|
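copySupport is the retraining counterpart: it projects a freshly updated dense matrix back onto the sparsity pattern fixed by the last hard-thresholding step. Continuing the made-up example above:

src = np.array([[0.9, 0.0], [0.0, -0.7]])    # thresholded matrix: defines the support
dest = np.array([[1.1, 0.3], [0.2, -0.5]])   # matrix after further dense updates
# copySupport(src, dest) keeps dest only where src is non-zero:
# -> [[1.1, 0.0], [0.0, -0.5]]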
||||
|
||||
# Function to read and pre-process data
|
||||
|
||||
|
||||
## Function to read and pre-process data
|
||||
def preProcessData(data_dir):
|
||||
train = np.load(data_dir + '/train.npy')
|
||||
test = np.load(data_dir + '/test.npy')
|
||||
train = np.load(data_dir + '/train.npy')
|
||||
test = np.load(data_dir + '/test.npy')
|
||||
|
||||
dataDimension = int(train.shape[1]) - 1
|
||||
dataDimension = int(train.shape[1]) - 1
|
||||
|
||||
Xtrain = train[:, 1:dataDimension+1]
|
||||
Ytrain_ = train[:, 0]
|
||||
numClasses = max(Ytrain_) - min(Ytrain_) + 1
|
||||
Xtrain = train[:, 1:dataDimension + 1]
|
||||
Ytrain_ = train[:, 0]
|
||||
numClasses = max(Ytrain_) - min(Ytrain_) + 1
|
||||
|
||||
Xtest = test[:,1:dataDimension+1]
|
||||
Ytest_ = test[:,0]
|
||||
Xtest = test[:, 1:dataDimension + 1]
|
||||
Ytest_ = test[:, 0]
|
||||
|
||||
numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))
|
||||
numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))
|
||||
|
||||
# Mean Var Normalisation
|
||||
mean = np.mean(Xtrain,0)
|
||||
std = np.std(Xtrain,0)
|
||||
std[std[:]<0.000001] = 1
|
||||
Xtrain = (Xtrain-mean)/std
|
||||
# Mean Var Normalisation
|
||||
mean = np.mean(Xtrain, 0)
|
||||
std = np.std(Xtrain, 0)
|
||||
std[std[:] < 0.000001] = 1
|
||||
Xtrain = (Xtrain - mean) / std
|
||||
|
||||
Xtest = (Xtest-mean)/std
|
||||
# End Mean Var normalisation
|
||||
Xtest = (Xtest - mean) / std
|
||||
# End Mean Var normalisation
|
||||
|
||||
lab = Ytrain_.astype('uint8')
|
||||
lab = np.array(lab) - min(lab)
|
||||
lab = Ytrain_.astype('uint8')
|
||||
lab = np.array(lab) - min(lab)
|
||||
|
||||
lab_ = np.zeros((Xtrain.shape[0], numClasses))
|
||||
lab_[np.arange(Xtrain.shape[0]), lab] = 1
|
||||
if (numClasses == 2):
|
||||
Ytrain = np.reshape(lab, [-1, 1])
|
||||
else:
|
||||
Ytrain = lab_
|
||||
lab_ = np.zeros((Xtrain.shape[0], numClasses))
|
||||
lab_[np.arange(Xtrain.shape[0]), lab] = 1
|
||||
if (numClasses == 2):
|
||||
Ytrain = np.reshape(lab, [-1, 1])
|
||||
else:
|
||||
Ytrain = lab_
|
||||
|
||||
lab = Ytest_.astype('uint8')
|
||||
lab = np.array(lab)-min(lab)
|
||||
lab = Ytest_.astype('uint8')
|
||||
lab = np.array(lab) - min(lab)
|
||||
|
||||
lab_ = np.zeros((Xtest.shape[0], numClasses))
|
||||
lab_[np.arange(Xtest.shape[0]), lab] = 1
|
||||
if (numClasses == 2):
|
||||
Ytest = np.reshape(lab, [-1, 1])
|
||||
else:
|
||||
Ytest = lab_
|
||||
lab_ = np.zeros((Xtest.shape[0], numClasses))
|
||||
lab_[np.arange(Xtest.shape[0]), lab] = 1
|
||||
if (numClasses == 2):
|
||||
Ytest = np.reshape(lab, [-1, 1])
|
||||
else:
|
||||
Ytest = lab_
|
||||
|
||||
trainBias = np.ones([Xtrain.shape[0], 1]);
|
||||
Xtrain = np.append(Xtrain, trainBias, axis=1)
|
||||
testBias = np.ones([Xtest.shape[0], 1]);
|
||||
Xtest = np.append(Xtest, testBias, axis=1)
|
||||
trainBias = np.ones([Xtrain.shape[0], 1])
|
||||
Xtrain = np.append(Xtrain, trainBias, axis=1)
|
||||
testBias = np.ones([Xtest.shape[0], 1])
|
||||
Xtest = np.append(Xtest, testBias, axis=1)
|
||||
|
||||
return dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest
|
||||
if numClasses == 2:
|
||||
numClasses = 1
|
||||
|
||||
return dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest
|
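preProcessData above expects train.npy and test.npy to be 2-D arrays whose first column is the integer class label and whose remaining columns are raw features; it one-hot encodes labels (a single 0/1 column for binary problems), normalises features with the training mean and standard deviation, appends a bias column of ones, and reports the dimension including that bias. A minimal sketch of producing such files from synthetic data (illustrative only):

import numpy as np

features = np.random.randn(100, 20)                # 100 samples, 20 features
labels = np.random.randint(0, 3, size=(100, 1))    # 3 classes
np.save("train.npy", np.concatenate([labels, features], axis=1))
np.save("test.npy", np.concatenate([labels, features], axis=1))
# preProcessData would then return dataDimension = 21 (20 features + bias).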
||||
|
|