diff --git a/pytorch/examples/Bonsai/README.md b/pytorch/examples/Bonsai/README.md new file mode 100644 index 00000000..91cb0021 --- /dev/null +++ b/pytorch/examples/Bonsai/README.md @@ -0,0 +1,67 @@ +# EdgeML Bonsai on a sample public dataset + +This directory includes an example notebook and a general execution script for +Bonsai, developed as part of EdgeML. Also, we include a sample cleanup and +use-case on the USPS10 public dataset. + +`pytorch_edgeml.graph.bonsai` implements the Bonsai prediction graph in PyTorch. +The three-phase training routine for Bonsai is decoupled from the forward graph +to facilitate a plug and play behaviour wherein Bonsai can be combined with or +used as a final layer classifier for other architectures (RNNs, CNNs). + +Note that `bonsai_example.py` assumes that data is in a specific format. It is +assumed that train and test data are contained in two files, `train.npy` and +`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples, +numberOfFeatures + 1]`. The first column of each matrix is assumed to contain +label information. For an N-Class problem, we assume the labels are integers +from 0 through N-1. `bonsai_example.py` also supports univariate regression, +which can be accessed using the help options of the script. Multivariate regression +requires restructuring of the input data format and can further help in extending +bonsai to multi-label classification and multi-variate regression. Lastly, +the training data, `train.npy`, is assumed to be well shuffled +as the training routine doesn't shuffle internally. + +**Tested With:** Tensorflow >1.6 with Python 2 and Python 3 + +## Download and clean up sample dataset + +We will validate the code using the USPS dataset. 
+The download and cleanup of the dataset to match the above-mentioned format is +done by the scripts [fetch_usps.py](fetch_usps.py) and +[process_usps.py](process_usps.py): + +``` +python fetch_usps.py +python process_usps.py +``` + +## Sample command for Bonsai on USPS10 +The following sample run on usps10 should validate your library: + +```bash +python bonsai_example.py -dir usps10/ -d 3 -p 28 -rW 0.001 -rZ 0.0001 -rV 0.001 -rT 0.001 -sZ 0.2 -sW 0.3 -sV 0.3 -sT 0.62 -e 100 -s 1 +``` +This command should give you a final output screen which reads roughly similar to (might not be exact numbers due to various version mismatches): +``` +Maximum Test accuracy at compressed model size(including early stopping): 0.94369704 at Epoch: 66 +Final Test Accuracy: 0.93024415 + +Non-Zeros: 4156.0 Model Size: 31.703125 KB hasSparse: True +``` + +The usps10 directory will now have a consolidated results file called `pytorchBonsaiResults.txt` and a directory `pytorchBonsaiResults` containing the corresponding models for each run of the code on the usps10 dataset. + +## Byte Quantization (Q) for model compression +If you wish to quantize the generated model to use byte quantized integers, use `quantizeBonsaiModels.py`. Usage Instructions: + +``` +python quantizeBonsaiModels.py -h +``` + +This will generate quantized models with a prefix of `q` before every param, stored in a new directory `QuantizedTFBonsaiModel` inside the model directory. +One can use this model further on edge devices. + + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT license. diff --git a/pytorch/examples/Bonsai/bonsai_example.py b/pytorch/examples/Bonsai/bonsai_example.py new file mode 100644 index 00000000..0d808906 --- /dev/null +++ b/pytorch/examples/Bonsai/bonsai_example.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
import helpermethods
import numpy as np
import sys
from pytorch_edgeml.trainer.bonsaiTrainer import BonsaiTrainer
from pytorch_edgeml.graph.bonsai import Bonsai
import torch


def main():
    '''
    Train and evaluate a Bonsai model from the command line.

    Parses the hyper-parameters with helpermethods.getArgs(), loads and
    normalises <data-dir>/train.npy and <data-dir>/test.npy, creates a
    time-stamped results directory, records the command line and the
    normalisation statistics there, and runs BonsaiTrainer.train().

    Exits with an error message when the results directory cannot be
    created.
    '''
    # Fixing seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Hyper Param pre-processing
    args = helpermethods.getArgs()

    sigma = args.sigma
    depth = args.depth
    projectionDimension = args.proj_dim

    regZ = args.rZ
    regT = args.rT
    regW = args.rW
    regV = args.rV

    totalEpochs = args.epochs
    learningRate = args.learning_rate
    dataDir = args.data_dir
    outFile = args.output_file

    (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
     mean, std) = helpermethods.preProcessData(dataDir)

    # W, V and T default to dense (1) for binary problems and to 0.2 for
    # multi-class ones; an explicit -sW/-sV/-sT always wins.
    defaultSparsity = 0.2 if numClasses > 2 else 1
    sparW = args.sW if args.sW is not None else defaultSparsity
    sparV = args.sV if args.sV is not None else defaultSparsity
    sparT = args.sT if args.sT is not None else defaultSparsity
    sparZ = args.sZ

    if args.batch_size is None:
        # Plain int (not a numpy scalar) so it can be used for slicing and
        # arithmetic everywhere downstream.
        batchSize = max(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
    else:
        batchSize = args.batch_size

    useMCHLoss = True

    # numClasses = 1 for binary case
    if numClasses == 2:
        numClasses = 1

    currDir = helpermethods.createTimeStampDir(dataDir)
    if currDir is None:
        # createTimeStampDir signals failure with None; stop here instead of
        # failing later with a TypeError on string concatenation.
        sys.exit("Could not create a results directory inside %s" % dataDir)

    helpermethods.dumpCommand(sys.argv, currDir)
    helpermethods.saveMeanStd(mean, std, currDir)

    bonsaiObj = Bonsai(numClasses, dataDimension,
                       projectionDimension, depth, sigma)

    bonsaiTrainer = BonsaiTrainer(bonsaiObj,
                                  regW, regT, regV, regZ,
                                  sparW, sparT, sparV, sparZ,
                                  learningRate, useMCHLoss, outFile)

    bonsaiTrainer.train(batchSize, totalEpochs,
                        torch.from_numpy(Xtrain.astype(np.float32)),
                        torch.from_numpy(Xtest.astype(np.float32)),
                        torch.from_numpy(Ytrain.astype(np.float32)),
                        torch.from_numpy(Ytest.astype(np.float32)),
                        dataDir, currDir)

    # Release the trainer's log file if it opened one. sys.stdout itself is
    # only flushed: closing it would break any caller (e.g. a notebook) that
    # prints after main() returns.
    outStream = bonsaiTrainer.outFile
    if outStream is not sys.stdout:
        outStream.close()
    elif not getattr(sys.stdout, 'closed', False):
        sys.stdout.flush()


if __name__ == '__main__':
    main()
# ---------------------------------------------------------------------------
# File: pytorch/examples/Bonsai/fetch_usps.py
# ---------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import subprocess
import os
import numpy as np
import sys


def downloadData(workingDir, downloadDir, linkTrain, linkTest):
    '''
    Download and unpack the USPS train/test files.

    Creates <workingDir>/<downloadDir>, fetches both archives with `wget`,
    decompresses them with `bzip2` and renames the results to `train.txt`
    and `test.txt`.

    Returns True on success. Returns False (after printing the reason) when
    the directory cannot be created - it must not exist yet - or when any
    of the external commands fails or is not installed.
    '''
    path = os.path.abspath(workingDir + '/' + downloadDir)
    try:
        os.mkdir(path)
    except OSError:
        print("Could not create %s. Make sure the path does" % path)
        print("not already exist and you have permisions to create it.")
        return False

    def runcommand(command):
        # Run inside `path` via cwd= rather than os.chdir(), so the caller's
        # working directory is never changed, even when a command fails.
        subprocess.run(command, cwd=path, stdout=subprocess.PIPE, check=True)

    try:
        print("Downloading data")
        runcommand(['wget', linkTrain])
        runcommand(['wget', linkTest])
        print("Extracting data")
        runcommand(['bzip2', '-d', 'usps.bz2'])
        runcommand(['bzip2', '-d', 'usps.t.bz2'])
        runcommand(['mv', 'usps', 'train.txt'])
        runcommand(['mv', 'usps.t', 'test.txt'])
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers a missing wget/bzip2 binary; CalledProcessError a
        # non-zero exit status.
        print('Command failed: %s' % err)
        return False
    return True


if __name__ == '__main__':
    workingDir = './'
    downloadDir = 'usps10'
    linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
    linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
    # The first placeholder is labelled "Test", the second "Train".
    failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt` and,
\t5. Rename `usps.t` to `test.txt`
''' % (linkTest, linkTrain)

    if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
        sys.exit(failureMsg)
    print("Done")


# ---------------------------------------------------------------------------
# File: pytorch/examples/Bonsai/helpermethods.py
# ---------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
import datetime
import os
import numpy as np

# Functions to check sanity of input arguments
# for the example script.


def checkIntPos(value):
    '''
    argparse `type=` callback: parse `value` as an int that must be > 0.
    Raises argparse.ArgumentTypeError otherwise.
    '''
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid positive int value" % value)
    return ivalue


def checkIntNneg(value):
    '''
    argparse `type=` callback: parse `value` as an int that must be >= 0.
    Raises argparse.ArgumentTypeError otherwise.
    '''
    ivalue = int(value)
    if ivalue < 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid non-neg int value" % value)
    return ivalue


def checkFloatNneg(value):
    '''
    argparse `type=` callback: parse `value` as a float that must be >= 0.
    Raises argparse.ArgumentTypeError otherwise.
    '''
    fvalue = float(value)
    if fvalue < 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid non-neg float value" % value)
    return fvalue


def checkFloatPos(value):
    '''
    argparse `type=` callback: parse `value` as a float that must be > 0.
    Raises argparse.ArgumentTypeError otherwise.
    '''
    fvalue = float(value)
    if fvalue <= 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid positive float value" % value)
    return fvalue


def str2bool(v):
    '''
    Map a yes/no style string (case-insensitive) to a bool.
    Raises argparse.ArgumentTypeError for anything unrecognised.
    '''
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def getArgs(argv=None):
    '''
    Function to parse arguments for Bonsai Algorithm

    argv - optional list of argument strings; when None (the default) the
           process command line is parsed, exactly as before.
    Returns the argparse.Namespace of parsed hyper-parameters.
    '''
    parser = argparse.ArgumentParser(
        description='HyperParams for Bonsai Algorithm')
    parser.add_argument('-dir', '--data-dir', required=True,
                        help='Data directory containing ' +
                        'train.npy and test.npy')

    parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
                        help='Depth of Bonsai Tree ' +
                        '(default: 2 try: [0, 1, 3])')
    parser.add_argument('-p', '--proj-dim', type=checkIntPos, default=10,
                        help='Projection Dimension ' +
                        '(default: 10 try: [5, 10, 30])')
    parser.add_argument('-s', '--sigma', type=float, default=1.0,
                        help='Parameter for sigmoid sharpness ' +
                        '(default: 1.0 try: [3.0, 0.05, 0.1])')
    parser.add_argument('-e', '--epochs', type=checkIntPos, default=42,
                        help='Total Epochs (default: 42 try:[100, 150, 60])')
    parser.add_argument('-b', '--batch-size', type=checkIntPos,
                        help='Batch Size to be used ' +
                        '(default: max(100, sqrt(train_samples)))')
    parser.add_argument('-lr', '--learning-rate', type=checkFloatPos,
                        default=0.01, help='Initial Learning rate for ' +
                        'Adam Optimizer (default: 0.01)')

    parser.add_argument('-rW', type=float, default=0.0001,
                        help='Regularizer for predictor parameter W ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rV', type=float, default=0.0001,
                        help='Regularizer for predictor parameter V ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rT', type=float, default=0.0001,
                        help='Regularizer for branching parameter Theta ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rZ', type=float, default=0.00001,
                        help='Regularizer for projection parameter Z ' +
                        '(default: 0.00001 try: [0.001, 0.0001, 0.000001])')

    # -sW/-sV/-sT have no default here on purpose: the example script picks
    # one from the number of classes when they are left unset (None).
    parser.add_argument('-sW', type=checkFloatPos,
                        help='Sparsity for predictor parameter W ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sV', type=checkFloatPos,
                        help='Sparsity for predictor parameter V ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sT', type=checkFloatPos,
                        help='Sparsity for branching parameter Theta ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sZ', type=checkFloatPos, default=0.2,
                        help='Sparsity for projection parameter Z ' +
                        '(default: 0.2 try: [0.1, 0.3, 0.5])')
    parser.add_argument('-oF', '--output-file', default=None,
                        help='Output file for dumping the program output, ' +
                        '(default: stdout)')

    return parser.parse_args(argv)


def getQuantArgs(argv=None):
    '''
    Function to parse arguments for Model Quantisation

    argv - optional list of argument strings; when None (the default) the
           process command line is parsed, exactly as before.
    '''
    parser = argparse.ArgumentParser(
        description='Arguments for quantizing Fast models. ' +
        'Works only for piece-wise linear non-linearities, ' +
        'like relu, quantTanh, quantSigm (check rnn.py for the definitions)')
    parser.add_argument('-dir', '--model-dir', required=True,
                        help='model directory containing ' +
                        '*.npy weight files dumped from the trained model')
    parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127,
                        help='this represents the maximum possible value ' +
                        'in model, essentially the byte complexity, ' +
                        '127=> 1 byte is default')

    return parser.parse_args(argv)


def createTimeStampDir(dataDir):
    '''
    Creates a directory with the current timestamp as its name inside
    <dataDir>/pytorchBonsaiResults and returns its path.

    Returns None (after printing a message) when the directory cannot be
    created. A directory that already exists - e.g. two runs started within
    the same second - is reused rather than reported as a failure.
    '''
    currDir = dataDir + '/pytorchBonsaiResults/' + \
        datetime.datetime.now().strftime("%H_%M_%S_%d_%m_%y")
    try:
        # Also creates the pytorchBonsaiResults parent when it is missing.
        os.makedirs(currDir, exist_ok=True)
    except OSError:
        print("Creation of the directory %s failed" % currDir)
        return None
    return currDir


def preProcessData(dataDir):
    '''
    Function to pre-process input data
    Expects a .npy file of form [lbl feats] for each datapoint
    Outputs a train and test set datapoints appended with 1 for Bias induction
    dataDimension, numClasses are inferred directly

    Returns (dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest,
    mean, std). Features are normalised with the train-set mean/std; mean
    and std get a trailing 0 / 1 entry for the appended bias column. Labels
    are one-hot encoded, except in the two-class case where a single column
    of 0/1 labels is returned.
    '''
    train = np.load(dataDir + '/train.npy')
    test = np.load(dataDir + '/test.npy')

    dataDimension = int(train.shape[1]) - 1

    Xtrain = train[:, 1:dataDimension + 1]
    Xtest = test[:, 1:dataDimension + 1]

    # Mean Var Normalisation
    mean = np.mean(Xtrain, 0)
    std = np.std(Xtrain, 0)
    # Constant features would otherwise divide by (almost) zero.
    std[std[:] < 0.000001] = 1
    Xtrain = (Xtrain - mean) / std
    Xtest = (Xtest - mean) / std
    # End Mean Var normalisation

    # Classification.
    # Both splits are shifted by the SAME offset so that a class keeps the
    # same index in train and test even when one split lacks the smallest
    # label. int64 (instead of uint8) avoids wrap-around above 255 classes.
    trainLabels = train[:, 0].astype(np.int64)
    testLabels = test[:, 0].astype(np.int64)
    base = min(trainLabels.min(), testLabels.min())
    numClasses = int(max(trainLabels.max(), testLabels.max()) - base + 1)

    def encodeLabels(labels):
        shifted = labels - base
        if numClasses == 2:
            return np.reshape(shifted, [-1, 1])
        oneHot = np.zeros((labels.shape[0], numClasses))
        oneHot[np.arange(labels.shape[0]), shifted] = 1
        return oneHot

    Ytrain = encodeLabels(trainLabels)
    Ytest = encodeLabels(testLabels)

    # Constant-one column acting as the bias feature.
    trainBias = np.ones([Xtrain.shape[0], 1])
    Xtrain = np.append(Xtrain, trainBias, axis=1)
    testBias = np.ones([Xtest.shape[0], 1])
    Xtest = np.append(Xtest, testBias, axis=1)

    mean = np.append(mean, np.array([0]))
    std = np.append(std, np.array([1]))

    return dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std


def dumpCommand(list, currDir):
    '''
    Dumps the current command to a file for further use

    list - the argument vector (e.g. sys.argv); written to
           <currDir>/command.txt prefixed with "python".
    '''
    command = "python" + " " + ' '.join(list)
    # `with` closes the file even if the write fails.
    with open(currDir + '/command.txt', 'w') as commandFile:
        commandFile.write(command)


def saveMeanStd(mean, std, currDir):
    '''
    Function to save Mean and Std vectors
    Writes <currDir>/mean.npy and <currDir>/std.npy plus the tab-separated
    text copies under <currDir>/SeeDot.
    '''
    np.save(currDir + '/mean.npy', mean)
    np.save(currDir + '/std.npy', std)
    saveMeanStdSeeDot(mean, std, currDir + "/SeeDot")


def saveMeanStdSeeDot(mean, std, seeDotDir):
    '''
    Function to save Mean and Std vectors as tab-delimited text files
    `Mean` and `Std` inside seeDotDir (created when missing).
    Raises OSError when the directory cannot be created or written.
    '''
    os.makedirs(seeDotDir, exist_ok=True)
    np.savetxt(seeDotDir + '/Mean', mean, delimiter="\t")
    np.savetxt(seeDotDir + '/Std', std, delimiter="\t")


# ---------------------------------------------------------------------------
# File: pytorch/examples/Bonsai/process_usps.py
# ---------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Processing the USPS Data. It is assumed that the data is already
# downloaded.

import subprocess
import os
import numpy as np
import sys


def processData(workingDir, downloadDir):
    '''
    Convert <workingDir>/<downloadDir>/{train,test}.txt (libsvm format) into
    train.npy / test.npy in the same directory.

    Each saved matrix is [numberOfExamples, numberOfFeatures + 1] with the
    label in column 0. Labels of both splits are shifted by one common
    offset so that the smallest label becomes 0.
    '''
    # Imported here: scikit-learn is needed only for this conversion step.
    from sklearn.datasets import load_svmlight_files

    def toLabelledMatrix(features, labels):
        retMat = np.zeros([features.shape[0], features.shape[1] + 1])
        retMat[:, 0] = labels
        retMat[:, 1:] = features.toarray()
        return retMat

    path = os.path.abspath(workingDir + '/' + downloadDir)
    trf = path + '/train.txt'
    tsf = path + '/test.txt'
    assert os.path.isfile(trf), 'File not found: %s' % trf
    assert os.path.isfile(tsf), 'File not found: %s' % tsf

    # Loading both files in one call gives them the same number of feature
    # columns even if the highest-numbered feature never appears in one of
    # them.
    trainX, trainY, testX, testY = load_svmlight_files([trf, tsf])
    train = toLabelledMatrix(trainX, trainY)
    test = toLabelledMatrix(testX, testY)

    # Convert the labels from 0 to numClasses-1, using one shared offset so
    # that a class maps to the same index in both splits.
    trainLabels = train[:, 0].astype(np.int64)
    testLabels = test[:, 0].astype(np.int64)
    base = min(trainLabels.min(), testLabels.min())
    train[:, 0] = trainLabels - base
    test[:, 0] = testLabels - base

    np.save(path + '/train.npy', train)
    np.save(path + '/test.npy', test)


if __name__ == '__main__':
    # Configuration
    workingDir = './'
    downloadDir = 'usps10'
    # End config
    print("Processing data")
    processData(workingDir, downloadDir)
    print("Done")
# File: pytorch/examples/Bonsai/quantizeBonsaiModels.py
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import os
import numpy as np


def min_max(A, name):
    '''
    Print the maximum and minimum of array A (labelled with `name`) and
    return the largest absolute value found in A.
    '''
    maxVal = np.max(A)
    minVal = np.min(A)
    print(name + " has max: " + str(maxVal) + " min: " + str(minVal))
    return np.max([np.abs(maxVal), np.abs(minVal)])


def quantizeBonsaiModels(modelDir, maxValue=127, scalarScaleFactor=1000):
    '''
    Quantize every parameter matrix stored as *.npy in modelDir.

    Files whose names start with `mean`, `std` or `hyperParam` are skipped.
    One scale factor, round((2 * maxValue + 1) / (2 * largest |weight|)),
    is shared by all parameters; scaled weights are rounded and clipped to
    [-(maxValue + 1), maxValue] and stored as int8 / int16 / int32 depending
    on maxValue. Results are written to <modelDir>/QuantizedTFBonsaiModel as
    `q<original name>` together with `paramScaleFactor.npy`. The directory
    is created when missing and reused when it already exists.

    scalarScaleFactor is accepted for interface compatibility and is not
    used.

    Raises ValueError when modelDir holds no parameter files or all weights
    are zero, and OSError when the output directory cannot be created.
    '''
    skippedPrefixes = ("mean", "std", "hyperParam")
    paramNameList = []
    paramWeightList = []
    paramLimitList = []

    for fileName in os.listdir(modelDir):
        if not fileName.endswith("npy") or fileName.startswith(skippedPrefixes):
            continue
        weights = np.load(modelDir + "/" + fileName)
        paramNameList.append(fileName)
        paramWeightList.append(weights)
        paramLimitList.append(min_max(weights, fileName))

    if not paramNameList:
        raise ValueError("No parameter .npy files found in %s" % modelDir)

    paramLimit = np.max(paramLimitList)
    if paramLimit == 0:
        # The scale factor would be a division by zero.
        raise ValueError("All parameters in %s are zero" % modelDir)

    paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))

    if maxValue <= 127:
        quantType = 'int8'
    elif maxValue <= 32767:
        quantType = 'int16'
    else:
        quantType = 'int32'

    # The range is asymmetric on purpose: the most negative code is
    # -(maxValue + 1), matching two's-complement storage.
    quantParamWeights = [
        np.clip(np.round(paramScaleFactor * param),
                -(maxValue + 1), maxValue).astype(quantType)
        for param in paramWeightList
    ]

    quantModelDir = modelDir + '/QuantizedTFBonsaiModel'
    # exist_ok: quantizing the same model a second time must not fail.
    os.makedirs(quantModelDir, exist_ok=True)

    np.save(quantModelDir + "/paramScaleFactor.npy",
            paramScaleFactor.astype('int32'))

    for paramName, quantWeights in zip(paramNameList, quantParamWeights):
        np.save(quantModelDir + "/q" + paramName, quantWeights)

    print("\n\nQuantized Model Dir: " + quantModelDir)


def main():
    '''
    Command-line entry point: quantize the model directory given by -dir.
    '''
    # Imported here because only the command-line entry point needs the
    # argument parser; quantizeBonsaiModels() stays importable on its own.
    import helpermethods

    args = helpermethods.getQuantArgs()
    quantizeBonsaiModels(args.model_dir, int(args.max_val))
# Entry-point guard of quantizeBonsaiModels.py (tail of the previous file in
# this patch).
if __name__ == '__main__':
    main()


# File: pytorch/pytorch_edgeml/graph/bonsai.py
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import numpy as np


class Bonsai(nn.Module):
    '''
    Bonsai prediction graph.

    Expected Dimensions:

    Bonsai Params // Optional
    W [numClasses*totalNodes, projectionDimension]
    V [numClasses*totalNodes, projectionDimension]
    Z [projectionDimension, dataDimension]
    T [internalNodes, projectionDimension]

    internalNodes = 2**treeDepth - 1
    totalNodes = 2*internalNodes + 1

    sigma - tanh non-linearity
    sigmaI - Indicator function for node probabilities
    sigmaI - has to be set to infinity(1e9 for practice)
    while doing testing/inference
    numClasses will be reset to 1 in binary case

    W, V, T and Z may be given as numpy arrays or torch tensors holding
    pre-trained values; they are copied into float32 nn.Parameter objects.
    Parameters left as None are drawn from a standard normal distribution.
    '''

    def __init__(self, numClasses, dataDimension, projectionDimension,
                 treeDepth, sigma, W=None, T=None, V=None, Z=None):
        super(Bonsai, self).__init__()

        self.dataDimension = dataDimension
        self.projectionDimension = projectionDimension

        if numClasses == 2:
            self.numClasses = 1
        else:
            self.numClasses = numClasses

        self.treeDepth = treeDepth
        self.sigma = sigma

        self.internalNodes = 2**self.treeDepth - 1
        self.totalNodes = 2 * self.internalNodes + 1

        # Order W, V, T, Z is kept so that seeded random initialisation
        # draws the same values as before.
        self.W = self.initW(W)
        self.V = self.initV(V)
        self.T = self.initT(T)
        self.Z = self.initZ(Z)

        self.assertInit()

    @staticmethod
    def _makeParam(value, shape):
        '''
        Build a float32 nn.Parameter: random normal of `shape` when value is
        None, otherwise a copy of the supplied array/tensor.
        '''
        if value is None:
            data = torch.randn(shape)
        elif isinstance(value, torch.Tensor):
            data = value.detach().clone().float()
        else:
            # astype() copies, so the caller's array is never aliased.
            data = torch.from_numpy(np.asarray(value).astype(np.float32))
        return nn.Parameter(data)

    def initZ(self, Z):
        '''Projection matrix Z [projectionDimension, dataDimension].'''
        return self._makeParam(
            Z, [self.projectionDimension, self.dataDimension])

    def initW(self, W):
        '''Predictor matrix W [numClasses*totalNodes, projectionDimension].'''
        return self._makeParam(
            W, [self.numClasses * self.totalNodes, self.projectionDimension])

    def initV(self, V):
        '''Predictor matrix V [numClasses*totalNodes, projectionDimension].'''
        return self._makeParam(
            V, [self.numClasses * self.totalNodes, self.projectionDimension])

    def initT(self, T):
        '''Branching matrix T [internalNodes, projectionDimension].'''
        return self._makeParam(
            T, [self.internalNodes, self.projectionDimension])

    def forward(self, X, sigmaI):
        '''
        Function to build/execute the Bonsai Tree graph
        Expected Dimensions

        X is [batchSize, self.dataDimension]
        sigmaI is constant

        Returns (scores [batchSize, numClasses],
                 projected input [projectionDimension, batchSize]).
        '''
        X_ = torch.matmul(self.Z, torch.t(X)) / self.projectionDimension

        def nodeScore(node):
            # Rows of W and V belonging to `node`.
            rows = slice(node * self.numClasses,
                         (node + 1) * self.numClasses)
            return (torch.matmul(self.W[rows], X_) *
                    torch.tanh(self.sigma * torch.matmul(self.V[rows], X_)))

        # The root is always reached, so its path probability is 1.
        self.__nodeProb = [1]
        score_ = self.__nodeProb[0] * nodeScore(0)

        for i in range(1, self.totalNodes):
            parent = (i - 1) // 2
            T_ = torch.reshape(self.T[parent],
                               [-1, self.projectionDimension])
            # Odd children take the +tanh branch, even children the -tanh
            # branch, so the two children of a node share its probability.
            prob = (1 + ((-1)**(i + 1)) *
                    torch.tanh(sigmaI * torch.matmul(T_, X_)))
            prob = prob / 2.0
            prob = self.__nodeProb[parent] * prob
            self.__nodeProb.append(prob)
            score_ = score_ + self.__nodeProb[i] * nodeScore(i)

        self.score = score_
        self.X_ = X_
        return torch.t(self.score), self.X_

    def assertInit(self):
        '''
        Check the hyper-parameters and the shapes of W, V, Z and T;
        raises AssertionError on the first violation.
        '''
        errRank = "All Parameters must has only two dimensions shape = [a, b]"
        assert len(self.W.shape) == len(self.Z.shape), errRank
        assert len(self.W.shape) == len(self.T.shape), errRank
        assert len(self.W.shape) == 2, errRank
        msg = "W and V should be of same Dimensions"
        assert self.W.shape == self.V.shape, msg
        errW = "W and V are [numClasses*totalNodes, projectionDimension]"
        assert self.W.shape[0] == self.numClasses * self.totalNodes, errW
        assert self.W.shape[1] == self.projectionDimension, errW
        errZ = "Z is [projectionDimension, dataDimension]"
        assert self.Z.shape[0] == self.projectionDimension, errZ
        assert self.Z.shape[1] == self.dataDimension, errZ
        errT = "T is [internalNodes, projectionDimension]"
        assert self.T.shape[0] == self.internalNodes, errT
        assert self.T.shape[1] == self.projectionDimension, errT
        # numClasses is 1 in the binary case, so the bound is > 0.
        assert int(self.numClasses) > 0, "numClasses should be > 0"
        msg = "# of features in data should be > 0"
        assert int(self.dataDimension) > 0, msg
        msg = "Projection should be > 0 dims"
        assert int(self.projectionDimension) > 0, msg
        msg = "treeDepth should be >= 0"
        assert int(self.treeDepth) >= 0, msg


# File: pytorch/pytorch_edgeml/trainer/bonsaiTrainer.py
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
+ +import torch +import numpy as np +import os +import sys +import pytorch_edgeml.utils as utils + + +class BonsaiTrainer: + + def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ, + learningRate, useMCHLoss=False, outFile=None): + ''' + bonsaiObj - Initialised Bonsai Object and Graph + lW, lT, lV and lZ are regularisers to Bonsai Params + sW, sT, sV and sZ are sparsity factors to Bonsai Params + learningRate - learningRate for optimizer + useMCHLoss - For choice between HingeLoss vs CrossEntropy + useMCHLoss - True - MultiClass - multiClassHingeLoss + useMCHLoss - False - MultiClass - crossEntropyLoss + ''' + + self.bonsaiObj = bonsaiObj + + self.lW = lW + self.lV = lV + self.lT = lT + self.lZ = lZ + + self.sW = sW + self.sV = sV + self.sT = sT + self.sZ = sZ + + self.useMCHLoss = useMCHLoss + + if outFile is not None: + print("Outfile : ", outFile) + self.outFile = open(outFile, 'w') + else: + self.outFile = sys.stdout + + self.learningRate = learningRate + + self.assertInit() + + self.optimizer = self.optimizer() + + if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99: + self.isDenseTraining = True + else: + self.isDenseTraining = False + + def loss(self, logits, labels): + ''' + Loss function for given Bonsai Obj + ''' + regLoss = 0.5 * (self.lZ * (torch.norm(self.bonsaiObj.Z)**2) + + self.lW * (torch.norm(self.bonsaiObj.W)**2) + + self.lV * (torch.norm(self.bonsaiObj.V)**2) + + self.lT * (torch.norm(self.bonsaiObj.T))**2) + + if (self.bonsaiObj.numClasses > 2): + if self.useMCHLoss is True: + marginLoss = utils.multiClassHingeLoss(logits, labels) + else: + marginLoss = utils.crossEntropyLoss(logits, labels) + loss = marginLoss + regLoss + else: + marginLoss = utils.binaryHingeLoss(logits, labels) + loss = marginLoss + regLoss + + return loss, marginLoss, regLoss + + def optimizer(self): + ''' + Optimizer for Bonsai Params + ''' + optimizer = torch.optim.Adam( + self.bonsaiObj.parameters(), lr=self.learningRate) + + return 
optimizer + + def accuracy(self, logits, labels): + ''' + Accuracy fucntion to evaluate accuracy when needed + ''' + if (self.bonsaiObj.numClasses > 2): + correctPredictions = (logits.argmax(dim=1) == labels.argmax(dim=1)) + accuracy = torch.mean(correctPredictions.float()) + else: + pred = (torch.cat((torch.zeros(logits.shape), + logits), 1)).argmax(dim=1) + accuracy = torch.mean((labels.view(-1).long() == pred).float()) + + return accuracy + + def runHardThrsd(self): + ''' + Function to run the IHT routine on Bonsai Obj + ''' + currW = self.bonsaiObj.W.data + currV = self.bonsaiObj.V.data + currZ = self.bonsaiObj.Z.data + currT = self.bonsaiObj.T.data + + self.__thrsdW = utils.hardThreshold(currW, self.sW) + self.__thrsdV = utils.hardThreshold(currV, self.sV) + self.__thrsdZ = utils.hardThreshold(currZ, self.sZ) + self.__thrsdT = utils.hardThreshold(currT, self.sT) + + self.bonsaiObj.W.data = torch.FloatTensor(self.__thrsdW) + self.bonsaiObj.V.data = torch.FloatTensor(self.__thrsdV) + self.bonsaiObj.Z.data = torch.FloatTensor(self.__thrsdZ) + self.bonsaiObj.T.data = torch.FloatTensor(self.__thrsdT) + + def runSparseTraining(self): + ''' + Function to run the Sparse Retraining routine on Bonsai Obj + ''' + currW = self.bonsaiObj.W.data + currV = self.bonsaiObj.V.data + currZ = self.bonsaiObj.Z.data + currT = self.bonsaiObj.T.data + + newW = utils.copySupport(self.__thrsdW, currW) + newV = utils.copySupport(self.__thrsdV, currV) + newZ = utils.copySupport(self.__thrsdZ, currZ) + newT = utils.copySupport(self.__thrsdT, currT) + + self.bonsaiObj.W.data = torch.FloatTensor(newW) + self.bonsaiObj.V.data = torch.FloatTensor(newV) + self.bonsaiObj.Z.data = torch.FloatTensor(newZ) + self.bonsaiObj.T.data = torch.FloatTensor(newT) + + def assertInit(self): + err = "sparsity must be between 0 and 1" + assert self.sW >= 0 and self.sW <= 1, "W " + err + assert self.sV >= 0 and self.sV <= 1, "V " + err + assert self.sZ >= 0 and self.sZ <= 1, "Z " + err + assert self.sT >= 0 
and self.sT <= 1, "T " + err + + def saveParams(self, currDir): + ''' + Function to save Parameter matrices into a given folder + ''' + paramDir = currDir + '/' + np.save(paramDir + "W.npy", self.bonsaiObj.W.data) + np.save(paramDir + "V.npy", self.bonsaiObj.V.data) + np.save(paramDir + "T.npy", self.bonsaiObj.T.data) + np.save(paramDir + "Z.npy", self.bonsaiObj.Z.data) + hyperParamDict = {'dataDim': self.bonsaiObj.dataDimension, + 'projDim': self.bonsaiObj.projectionDimension, + 'numClasses': self.bonsaiObj.numClasses, + 'depth': self.bonsaiObj.treeDepth, + 'sigma': self.bonsaiObj.sigma} + hyperParamFile = paramDir + 'hyperParam.npy' + np.save(hyperParamFile, hyperParamDict) + + def saveParamsForSeeDot(self, currDir): + ''' + Function to save Parameter matrices into a given folder for SeeDot compiler + ''' + seeDotDir = currDir + '/SeeDot/' + + if os.path.isdir(seeDotDir) is False: + try: + os.mkdir(seeDotDir) + except OSError: + print("Creation of the directory %s failed" % + seeDotDir) + + np.savetxt(seeDotDir + "W", + utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.W.data, + self.bonsaiObj.numClasses, + self.bonsaiObj.totalNodes), + delimiter="\t") + np.savetxt(seeDotDir + "V", + utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.V.data, + self.bonsaiObj.numClasses, + self.bonsaiObj.totalNodes), + delimiter="\t") + np.savetxt(seeDotDir + "T", self.bonsaiObj.T.data, delimiter="\t") + np.savetxt(seeDotDir + "Z", self.bonsaiObj.Z.data, delimiter="\t") + np.savetxt(seeDotDir + "Sigma", + np.array([self.bonsaiObj.sigma]), delimiter="\t") + + def loadModel(self, currDir): + ''' + Load the Saved model and load it to the model using constructor + Returns two dict one for params and other for hyperParams + ''' + paramDir = currDir + '/' + paramDict = {} + paramDict['W'] = np.load(paramDir + "W.npy") + paramDict['V'] = np.load(paramDir + "V.npy") + paramDict['T'] = np.load(paramDir + "T.npy") + paramDict['Z'] = np.load(paramDir + "Z.npy") + hyperParamDict = 
def getModelSize(self):
    '''
    Function to get aimed model size

    Combines utils.countnnZ over the Z, W, V and T parameters of
    self.bonsaiObj, each evaluated at its own sparsity (sZ, sW, sV, sT).

    Returns (totalnnZ, totalSize, hasSparse): the summed non-zero
    counts, the summed sizes in bytes, and whether any of the four
    parameters was counted in sparse form.
    '''
    nnzZ, sizeZ, sparseZ = utils.countnnZ(self.bonsaiObj.Z, self.sZ)
    nnzW, sizeW, sparseW = utils.countnnZ(self.bonsaiObj.W, self.sW)
    nnzV, sizeV, sparseV = utils.countnnZ(self.bonsaiObj.V, self.sV)
    nnzT, sizeT, sparseT = utils.countnnZ(self.bonsaiObj.T, self.sT)

    totalnnZ = (nnzZ + nnzT + nnzV + nnzW)
    totalSize = (sizeZ + sizeW + sizeV + sizeT)
    hasSparse = (sparseW or sparseV or sparseT or sparseZ)
    return totalnnZ, totalSize, hasSparse


def train(self, batchSize, totalEpochs,
          Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
    '''
    The Dense - IHT - Sparse Retrain Routine for Bonsai Training

    The total number of batches is split into thirds: dense training,
    iterative hard thresholding (runHardThrsd every `trimlevel` batches,
    runSparseTraining on the others) and sparse retraining
    (runSparseTraining on every batch). When self.isDenseTraining is
    True only dense training is performed.

    batchSize   -- number of rows of Xtrain/Ytrain consumed per step.
    totalEpochs -- number of passes over the training data.
    Xtrain, Ytrain, Xtest, Ytest -- train / test features and labels.
                   NOTE(review): presumably torch tensors with labels
                   one-hot of width numClasses -- confirm with caller.
    dataDir     -- directory whose 'pytorchBonsaiResults.txt' receives
                   one appended summary line.
    currDir     -- directory handed to saveParams/saveParamsForSeeDot
                   whenever the test accuracy reaches a new maximum.

    Progress is printed to self.outFile, which is flushed and, unless
    it is sys.stdout, closed once training completes.
    '''
    # Opened up front so an unusable dataDir fails before any training
    # work is done; the try/finally guarantees the handle is released
    # even when training raises part-way through.
    resultFile = open(dataDir + '/pytorchBonsaiResults.txt', 'a+')
    try:
        numIters = Xtrain.shape[0] / batchSize

        # The schedule is derived from the fractional batches-per-epoch
        # value, i.e. before numIters is truncated for the batch loop.
        totalBatches = numIters * totalEpochs
        numIters = int(numIters)

        # Batch counters at which the IHT and sparse phases begin.
        ihtStart = int(totalBatches / 3.0)
        sparseStart = int(2 * totalBatches / 3.0)

        self.sigmaI = 1

        counter = 0
        if self.bonsaiObj.numClasses > 2:
            trimlevel = 15
        else:
            trimlevel = 5
        ihtDone = 0

        maxTestAcc = -10000
        if self.isDenseTraining is True:
            ihtDone = 1
            self.sigmaI = 1
            itersInPhase = 0

        header = '*' * 20
        for i in range(totalEpochs):
            print("\nEpoch Number: " + str(i), file=self.outFile)

            # trainAcc -> For Classification, it is 'Accuracy'.
            trainAcc = 0.0
            trainLoss = 0.0

            for j in range(numIters):

                if counter == 0:
                    msg = " Dense Training Phase Started "
                    print("\n%s%s%s\n" %
                          (header, msg, header), file=self.outFile)

                # Updating the indicator sigma
                if ((counter == 0 or counter == ihtStart or
                     counter == sparseStart) and
                        self.isDenseTraining is False):
                    # A new phase starts: reset sigma and the phase clock.
                    self.sigmaI = 1
                    itersInPhase = 0

                elif itersInPhase % 100 == 0:
                    # Every 100 batches re-estimate sigma from the mean
                    # |T . (Z x / projectionDimension)| over 100 randomly
                    # drawn training rows, then scale it up with the
                    # time spent in the current phase (capped at 1000).
                    indices = np.random.choice(Xtrain.shape[0], 100)
                    batchX = Xtrain[indices, :]

                    Teval = self.bonsaiObj.T.data
                    Xcapeval = (torch.matmul(self.bonsaiObj.Z, torch.t(
                        batchX)) / self.bonsaiObj.projectionDimension).data

                    sum_tr = 0.0
                    for k in range(0, self.bonsaiObj.internalNodes):
                        sum_tr += (
                            np.sum(np.abs(np.dot(Teval[k], Xcapeval))))

                    if self.bonsaiObj.internalNodes > 0:
                        sum_tr /= (100 * self.bonsaiObj.internalNodes)
                        sum_tr = 0.1 / sum_tr
                    else:
                        sum_tr = 0.1
                    sum_tr = min(
                        1000, sum_tr * (2**(float(itersInPhase) /
                                            (float(totalBatches) / 30.0))))

                    self.sigmaI = sum_tr

                itersInPhase += 1
                batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
                batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
                batchY = np.reshape(
                    batchY, [-1, self.bonsaiObj.numClasses])

                self.optimizer.zero_grad()
                logits, _ = self.bonsaiObj(batchX, self.sigmaI)
                batchLoss, _, _ = self.loss(logits, batchY)
                batchAcc = self.accuracy(logits, batchY)

                batchLoss.backward()
                self.optimizer.step()

                # Classification.
                trainAcc += batchAcc.item()
                trainLoss += batchLoss.item()

                # Training routine involving IHT and sparse retraining
                if (counter >= ihtStart and
                        counter < sparseStart and
                        counter % trimlevel == 0 and
                        self.isDenseTraining is False):
                    self.runHardThrsd()
                    if ihtDone == 0:
                        msg = " IHT Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                    ihtDone = 1
                elif ((ihtDone == 1 and counter >= ihtStart and
                       counter < sparseStart and
                       counter % trimlevel != 0 and
                       self.isDenseTraining is False) or
                      (counter >= sparseStart and
                       self.isDenseTraining is False)):
                    self.runSparseTraining()
                    if counter == sparseStart:
                        msg = " Sparse Retraining Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                counter += 1

            print("\nClassification Train Loss: " +
                  str(trainLoss / numIters) +
                  "\nTraining accuracy (Classification): " +
                  str(trainAcc / numIters),
                  file=self.outFile)

            # Evaluate with a very large sigma, then restore the training
            # value. Nothing is back-propagated from the test pass, so
            # autograd recording is switched off for it.
            oldSigmaI = self.sigmaI
            self.sigmaI = 1e9
            with torch.no_grad():
                logits, _ = self.bonsaiObj(Xtest, self.sigmaI)
                testLoss, marginLoss, regLoss = self.loss(logits, Ytest)
                testAcc = self.accuracy(logits, Ytest).item()

            if ihtDone == 0:
                # Accuracies seen before IHT has started are not tracked
                # as the maximum.
                maxTestAcc = -10000
                maxTestAccEpoch = i
            else:
                if maxTestAcc <= testAcc:
                    maxTestAccEpoch = i
                    maxTestAcc = testAcc
                    self.saveParams(currDir)
                    self.saveParamsForSeeDot(currDir)

            print("Test accuracy %g" % testAcc, file=self.outFile)

            print("MarginLoss + RegLoss: " + str(marginLoss.item()) +
                  " + " + str(regLoss.item()) + " = " +
                  str(testLoss.item()) + "\n",
                  file=self.outFile)
            self.outFile.flush()

            self.sigmaI = oldSigmaI

        # sigmaI has to be set to infinity to ensure
        # only a single path is used in inference
        self.sigmaI = 1e9

        # Query the model size once and reuse it for both reports.
        totalnnZ, totalSize, hasSparse = self.getModelSize()
        modelSizeKB = float(totalSize) / 1024.0

        print("\nNon-Zero : " + str(totalnnZ) + " Model Size: " +
              str(modelSizeKB) + " KB hasSparse: " +
              str(hasSparse) + "\n", file=self.outFile)

        print("For Classification, Maximum Test accuracy at compressed" +
              " model size(including early stopping): " +
              str(maxTestAcc) + " at Epoch: " +
              str(maxTestAccEpoch + 1) + "\nFinal Test" +
              " Accuracy: " + str(testAcc), file=self.outFile)

        resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
                         " at Epoch(totalEpochs): " +
                         str(maxTestAccEpoch + 1) +
                         "(" + str(totalEpochs) + ")" + " ModelSize: " +
                         str(modelSizeKB) +
                         " KB hasSparse: " + str(hasSparse) +
                         " Param Directory: " +
                         str(os.path.abspath(currDir)) + "\n")
        print("The Model Directory: " + currDir + "\n")
    finally:
        resultFile.close()

    self.outFile.flush()

    if self.outFile is not sys.stdout:
        self.outFile.close()
def crossEntropyLoss(logits, labels):
    '''
    Cross Entropy loss for MultiClass case in joint training for
    faster convergence

    `labels` is reduced to class indices with argmax over dim 1 before
    being passed to F.cross_entropy together with `logits`.
    '''
    targetIds = labels.argmax(dim=1)
    return F.cross_entropy(logits, targetIds)


def binaryHingeLoss(logits, labels):
    '''
    BinaryHingeLoss to match C++ Version - No pytorch internal version

    Labels are mapped to signs via (2 * labels - 1); the result is the
    mean of relu(1 - sign * logit).
    '''
    signs = 2 * labels - 1
    return torch.mean(F.relu(1.0 - signs * logits))


def hardThreshold(A, s):
    '''
    Hard thresholding function on Tensor A with sparsity s

    Works on a copy of A (A itself is left untouched). Every entry whose
    magnitude lies below the (1 - s) * 100 percentile of |A| is set to
    zero and the result is returned in A's shape. An empty A comes back
    as an empty copy.
    '''
    flat = np.copy(A).ravel()
    if len(flat) > 0:
        magnitudes = np.abs(flat)
        q = (1 - s) * 100.0
        try:
            # NumPy 1.22 renamed the keyword `interpolation` to `method`.
            th = np.percentile(magnitudes, q, method='higher')
        except TypeError:
            # Older NumPy releases only accept the original keyword.
            th = np.percentile(magnitudes, q, interpolation='higher')
        flat[magnitudes < th] = 0.0
    return flat.reshape(A.shape)


def copySupport(src, dest):
    '''
    copy support of src tensor to dest tensor

    Returns a new zero-initialised array of dest's shape holding dest's
    values at the positions where src is non-zero.
    '''
    support = np.nonzero(src)
    masked = np.zeros(dest.shape)
    masked[support] = dest[support]
    return masked


def countnnZ(A, s, bytesPerVar=4):
    '''
    Returns # of non-zeros and representative size of the tensor
    Uses dense for s >= 0.5 - 4 byte
    Else uses sparse - 8 byte

    The return value is (nnZ, sizeInBytes, hasSparse). In the sparse
    case each kept entry is costed at 2 * bytesPerVar.
    '''
    params = 1
    for dim in A.shape:
        params *= int(dim)

    if s < 0.5:
        nnZ = np.ceil(params * s)
        return nnZ, nnZ * 2 * bytesPerVar, True

    return params, params * bytesPerVar, False


def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
    '''
    Restructures a matrix from [nNodes*nClasses, Proj] to
    [nClasses*nNodes, Proj] for SeeDot

    Row (node * nClasses + cls) of A becomes row (cls * nNodes + node)
    of the returned array, which is allocated with np.zeros(A.shape).
    '''
    tempMatrix = np.zeros(A.shape)

    for cls in range(0, nClasses):
        for node in range(0, nNodes):
            tempMatrix[cls * nNodes + node] = A[node * nClasses + cls]

    return tempMatrix
license='MIT License', + long_description=open('../License.txt').read(), +)