Mirror of https://github.com/microsoft/EdgeML.git

Working and tested pytorch bonsai

Parent: c683fb1132
Commit: 3246f9bec7
@ -0,0 +1,67 @@
# EdgeML Bonsai on a sample public dataset

This directory includes an example notebook and a general execution script for
Bonsai, developed as part of EdgeML. We also include a sample cleanup and
use-case on the USPS10 public dataset.

`pytorch_edgeml.graph.bonsai` implements the Bonsai prediction graph in PyTorch.
The three-phase training routine for Bonsai is decoupled from the forward graph
to facilitate a plug-and-play behaviour wherein Bonsai can be combined with, or
used as a final-layer classifier for, other architectures (RNNs, CNNs).

Note that `bonsai_example.py` assumes that data is in a specific format. It is
assumed that train and test data are contained in two files, `train.npy` and
`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
numberOfFeatures + 1]`. The first column of each matrix is assumed to contain
label information. For an N-class problem, the labels are assumed to be
integers from 0 through N-1. `bonsai_example.py` also supports univariate
regression, which can be accessed through the help options of the script.
Multivariate regression requires restructuring of the input data format and
can further help in extending Bonsai to multi-label classification and
multivariate regression. Lastly, the training data in `train.npy` is assumed
to be well shuffled, as the training routine does not shuffle it internally.
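
For instance, a minimal sketch that writes synthetic data in this format (the
feature and class counts here are illustrative assumptions, not values required
by the script):

```python
# Sketch: create toy train.npy / test.npy in the layout bonsai_example.py expects.
import numpy as np

numTrain, numTest, numFeatures, numClasses = 600, 200, 8, 3

def makeSplit(numExamples):
    X = np.random.randn(numExamples, numFeatures)
    y = np.random.randint(0, numClasses, size=(numExamples, 1))
    return np.hstack([y, X])  # first column carries the integer label

np.save('train.npy', makeSplit(numTrain))
np.save('test.npy', makeSplit(numTest))
```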

**Tested With:** PyTorch with Python 2 and Python 3

## Download and clean up sample dataset

We validate the code using the USPS dataset. The download and cleanup of the
dataset to match the above-mentioned format is done by the scripts
[fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py):

```
python fetch_usps.py
python process_usps.py
```

## Sample command for Bonsai on USPS10
The following sample run on usps10 should validate your library:

```bash
python bonsai_example.py -dir usps10/ -d 3 -p 28 -rW 0.001 -rZ 0.0001 -rV 0.001 -rT 0.001 -sZ 0.2 -sW 0.3 -sV 0.3 -sT 0.62 -e 100 -s 1
```
This command should give you a final output screen that reads roughly as follows (the numbers may not match exactly due to version mismatches):
```
Maximum Test accuracy at compressed model size(including early stopping): 0.94369704 at Epoch: 66
Final Test Accuracy: 0.93024415

Non-Zeros: 4156.0 Model Size: 31.703125 KB hasSparse: True
```

The usps10 directory will now have a consolidated results file called
`pytorchBonsaiResults.txt` and a directory `pytorchBonsaiResults` containing
the corresponding models from each run of the code on the usps10 dataset.

## Byte Quantization (Q) for model compression
If you wish to quantize the generated model to byte-quantized integers, use
`quantizeBonsaiModels.py`. Usage instructions:

```
python quantizeBonsaiModels.py -h
```

This will generate quantized models, with a `q` prefix before every param name,
stored in a new directory `QuantizedTFBonsaiModel` inside the model directory.
One can use this model further on edge devices.
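
For reference, a rough sketch of the symmetric scaling the quantizer applies;
this is a simplified single-matrix illustration of the logic in
`quantizeBonsaiModels.py`, not the script itself:

```python
# Simplified byte quantization of one trained parameter matrix.
import numpy as np

W = np.load('W.npy')          # e.g. a trained Bonsai parameter
maxValue = 127                # 1 byte => values in [-128, 127]
paramLimit = np.abs(W).max()  # the script uses the max across all params
scale = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))

qW = np.clip(np.round(scale * W), -(maxValue + 1), maxValue).astype('int8')
# On-device, W is approximately recovered as qW / scale.
```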

Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT license.
@ -0,0 +1,93 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import numpy as np
import sys
from pytorch_edgeml.trainer.bonsaiTrainer import BonsaiTrainer
from pytorch_edgeml.graph.bonsai import Bonsai
import torch


def main():
    # Fixing seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Hyper Param pre-processing
    args = helpermethods.getArgs()

    sigma = args.sigma
    depth = args.depth

    projectionDimension = args.proj_dim
    regZ = args.rZ
    regT = args.rT
    regW = args.rW
    regV = args.rV

    totalEpochs = args.epochs

    learningRate = args.learning_rate

    dataDir = args.data_dir

    outFile = args.output_file

    (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest,
     mean, std) = helpermethods.preProcessData(dataDir)

    sparZ = args.sZ

    if numClasses > 2:
        sparW = 0.2
        sparV = 0.2
        sparT = 0.2
    else:
        sparW = 1
        sparV = 1
        sparT = 1

    if args.sW is not None:
        sparW = args.sW
    if args.sV is not None:
        sparV = args.sV
    if args.sT is not None:
        sparT = args.sT

    if args.batch_size is None:
        batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))
    else:
        batchSize = args.batch_size

    useMCHLoss = True

    if numClasses == 2:
        numClasses = 1

    currDir = helpermethods.createTimeStampDir(dataDir)

    helpermethods.dumpCommand(sys.argv, currDir)
    helpermethods.saveMeanStd(mean, std, currDir)

    # numClasses = 1 for binary case
    bonsaiObj = Bonsai(numClasses, dataDimension,
                       projectionDimension, depth, sigma)

    bonsaiTrainer = BonsaiTrainer(bonsaiObj,
                                  regW, regT, regV, regZ,
                                  sparW, sparT, sparV, sparZ,
                                  learningRate, useMCHLoss, outFile)

    bonsaiTrainer.train(batchSize, totalEpochs,
                        torch.from_numpy(Xtrain.astype(np.float32)),
                        torch.from_numpy(Xtest.astype(np.float32)),
                        torch.from_numpy(Ytrain.astype(np.float32)),
                        torch.from_numpy(Ytest.astype(np.float32)),
                        dataDir, currDir)

    sys.stdout.close()


if __name__ == '__main__':
    main()
@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Setting up the USPS Data.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys


def downloadData(workingDir, downloadDir, linkTrain, linkTest):
    def runcommand(command):
        p = subprocess.Popen(command.split(), stdout=subprocess.PIPE)
        output, error = p.communicate()
        assert p.returncode == 0, 'Command failed: %s' % command

    path = workingDir + '/' + downloadDir
    path = os.path.abspath(path)
    try:
        os.mkdir(path)
    except OSError:
        print("Could not create %s. Make sure the path does" % path)
        print("not already exist and you have permissions to create it.")
        return False
    cwd = os.getcwd()
    os.chdir(path)
    print("Downloading data")
    command = 'wget %s' % linkTrain
    runcommand(command)
    command = 'wget %s' % linkTest
    runcommand(command)
    print("Extracting data")
    command = 'bzip2 -d usps.bz2'
    runcommand(command)
    command = 'bzip2 -d usps.t.bz2'
    runcommand(command)
    command = 'mv usps train.txt'
    runcommand(command)
    command = 'mv usps.t test.txt'
    runcommand(command)
    os.chdir(cwd)
    return True


if __name__ == '__main__':
    workingDir = './'
    downloadDir = 'usps10'
    linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
    linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
    failureMsg = '''
Download Failed!
To manually perform the download
\t1. Create a new empty directory named `usps10`.
\t2. Download the data from the following links into the usps10 directory.
\t\tTest: %s
\t\tTrain: %s
\t3. Extract the downloaded files.
\t4. Rename `usps` to `train.txt` and,
\t5. Rename `usps.t` to `test.txt`.
''' % (linkTrain, linkTest)

    if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
        exit(failureMsg)
    print("Done")
@ -0,0 +1,258 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
import datetime
import os
import numpy as np

'''
Functions to check sanity of input arguments
for the example script.
'''


def checkIntPos(value):
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid positive int value" % value)
    return ivalue


def checkIntNneg(value):
    ivalue = int(value)
    if ivalue < 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid non-neg int value" % value)
    return ivalue


def checkFloatNneg(value):
    fvalue = float(value)
    if fvalue < 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid non-neg float value" % value)
    return fvalue


def checkFloatPos(value):
    fvalue = float(value)
    if fvalue <= 0:
        raise argparse.ArgumentTypeError(
            "%s is an invalid positive float value" % value)
    return fvalue


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def getArgs():
    '''
    Function to parse arguments for the Bonsai algorithm
    '''
    parser = argparse.ArgumentParser(
        description='HyperParams for Bonsai Algorithm')
    parser.add_argument('-dir', '--data-dir', required=True,
                        help='Data directory containing ' +
                        'train.npy and test.npy')

    parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
                        help='Depth of Bonsai Tree ' +
                        '(default: 2 try: [0, 1, 3])')
    parser.add_argument('-p', '--proj-dim', type=checkIntPos, default=10,
                        help='Projection Dimension ' +
                        '(default: 10 try: [5, 10, 30])')
    parser.add_argument('-s', '--sigma', type=float, default=1.0,
                        help='Parameter for sigmoid sharpness ' +
                        '(default: 1.0 try: [3.0, 0.05, 0.1])')
    parser.add_argument('-e', '--epochs', type=checkIntPos, default=42,
                        help='Total Epochs (default: 42 try: [100, 150, 60])')
    parser.add_argument('-b', '--batch-size', type=checkIntPos,
                        help='Batch Size to be used ' +
                        '(default: max(100, sqrt(train_samples)))')
    parser.add_argument('-lr', '--learning-rate', type=checkFloatPos,
                        default=0.01, help='Initial Learning rate for ' +
                        'Adam Optimizer (default: 0.01)')

    parser.add_argument('-rW', type=float, default=0.0001,
                        help='Regularizer for predictor parameter W ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rV', type=float, default=0.0001,
                        help='Regularizer for predictor parameter V ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rT', type=float, default=0.0001,
                        help='Regularizer for branching parameter Theta ' +
                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
    parser.add_argument('-rZ', type=float, default=0.00001,
                        help='Regularizer for projection parameter Z ' +
                        '(default: 0.00001 try: [0.001, 0.0001, 0.000001])')

    parser.add_argument('-sW', type=checkFloatPos,
                        help='Sparsity for predictor parameter W ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sV', type=checkFloatPos,
                        help='Sparsity for predictor parameter V ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sT', type=checkFloatPos,
                        help='Sparsity for branching parameter Theta ' +
                        '(default: For Binary classification 1.0 else 0.2 ' +
                        'try: [0.1, 0.3, 0.5])')
    parser.add_argument('-sZ', type=checkFloatPos, default=0.2,
                        help='Sparsity for projection parameter Z ' +
                        '(default: 0.2 try: [0.1, 0.3, 0.5])')
    parser.add_argument('-oF', '--output-file', default=None,
                        help='Output file for dumping the program output ' +
                        '(default: stdout)')

    return parser.parse_args()


def getQuantArgs():
    '''
    Function to parse arguments for model quantisation
    '''
    parser = argparse.ArgumentParser(
        description='Arguments for quantizing Fast models. ' +
        'Works only for piece-wise linear non-linearities, ' +
        'like relu, quantTanh, quantSigm (check rnn.py for the definitions)')
    parser.add_argument('-dir', '--model-dir', required=True,
                        help='model directory containing ' +
                        '*.npy weight files dumped from the trained model')
    parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127,
                        help='this represents the maximum possible value ' +
                        'in the model, essentially the byte complexity; ' +
                        '127 => 1 byte is the default')

    return parser.parse_args()


def createTimeStampDir(dataDir):
    '''
    Creates a directory with the timestamp as its name
    '''
    if os.path.isdir(dataDir + '/pytorchBonsaiResults') is False:
        try:
            os.mkdir(dataDir + '/pytorchBonsaiResults')
        except OSError:
            print("Creation of the directory %s failed" %
                  (dataDir + '/pytorchBonsaiResults'))

    currDir = 'pytorchBonsaiResults/' + \
        datetime.datetime.now().strftime("%H_%M_%S_%d_%m_%y")
    if os.path.isdir(dataDir + '/' + currDir) is False:
        try:
            os.mkdir(dataDir + '/' + currDir)
        except OSError:
            print("Creation of the directory %s failed" %
                  (dataDir + '/' + currDir))
        else:
            return (dataDir + '/' + currDir)
    return None


def preProcessData(dataDir):
    '''
    Function to pre-process the input data.
    Expects a .npy file of the form [lbl feats] for each datapoint.
    Outputs train and test sets with datapoints appended with 1 for bias
    induction. dataDimension and numClasses are inferred directly.
    '''
    train = np.load(dataDir + '/train.npy')
    test = np.load(dataDir + '/test.npy')

    dataDimension = int(train.shape[1]) - 1

    Xtrain = train[:, 1:dataDimension + 1]
    Ytrain_ = train[:, 0]

    Xtest = test[:, 1:dataDimension + 1]
    Ytest_ = test[:, 0]

    # Mean-var normalisation
    mean = np.mean(Xtrain, 0)
    std = np.std(Xtrain, 0)
    std[std[:] < 0.000001] = 1
    Xtrain = (Xtrain - mean) / std
    Xtest = (Xtest - mean) / std
    # End mean-var normalisation

    # Classification.

    numClasses = max(Ytrain_) - min(Ytrain_) + 1
    numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))

    # One-hot encoding of the labels (shifted to start at 0)
    lab = Ytrain_.astype('uint8')
    lab = np.array(lab) - min(lab)

    lab_ = np.zeros((Xtrain.shape[0], numClasses))
    lab_[np.arange(Xtrain.shape[0]), lab] = 1
    if (numClasses == 2):
        Ytrain = np.reshape(lab, [-1, 1])
    else:
        Ytrain = lab_

    lab = Ytest_.astype('uint8')
    lab = np.array(lab) - min(lab)

    lab_ = np.zeros((Xtest.shape[0], numClasses))
    lab_[np.arange(Xtest.shape[0]), lab] = 1
    if (numClasses == 2):
        Ytest = np.reshape(lab, [-1, 1])
    else:
        Ytest = lab_

    # Append a constant 1 column for the bias term
    trainBias = np.ones([Xtrain.shape[0], 1])
    Xtrain = np.append(Xtrain, trainBias, axis=1)
    testBias = np.ones([Xtest.shape[0], 1])
    Xtest = np.append(Xtest, testBias, axis=1)

    mean = np.append(mean, np.array([0]))
    std = np.append(std, np.array([1]))

    return dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std


def dumpCommand(argList, currDir):
    '''
    Dumps the current command to a file for further use
    '''
    commandFile = open(currDir + '/command.txt', 'w')
    command = "python"

    command = command + " " + ' '.join(argList)
    commandFile.write(command)

    commandFile.flush()
    commandFile.close()


def saveMeanStd(mean, std, currDir):
    '''
    Function to save the mean and std vectors
    '''
    np.save(currDir + '/mean.npy', mean)
    np.save(currDir + '/std.npy', std)
    saveMeanStdSeeDot(mean, std, currDir + "/SeeDot")


def saveMeanStdSeeDot(mean, std, seeDotDir):
    '''
    Function to save the mean and std vectors for use with SeeDot
    '''
    if os.path.isdir(seeDotDir) is False:
        try:
            os.mkdir(seeDotDir)
        except OSError:
            print("Creation of the directory %s failed" % seeDotDir)
    np.savetxt(seeDotDir + '/Mean', mean, delimiter="\t")
    np.savetxt(seeDotDir + '/Std', std, delimiter="\t")
@ -0,0 +1,54 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
#
# Processing the USPS Data. It is assumed that the data is already
# downloaded.

import subprocess
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import sys


def processData(workingDir, downloadDir):
    def loadLibSVMFile(file):
        data = load_svmlight_file(file)
        features = data[0]
        labels = data[1]
        retMat = np.zeros([features.shape[0], features.shape[1] + 1])
        retMat[:, 0] = labels
        retMat[:, 1:] = features.todense()
        return retMat

    path = workingDir + '/' + downloadDir
    path = os.path.abspath(path)
    trf = path + '/train.txt'
    tsf = path + '/test.txt'
    assert os.path.isfile(trf), 'File not found: %s' % trf
    assert os.path.isfile(tsf), 'File not found: %s' % tsf
    train = loadLibSVMFile(trf)
    test = loadLibSVMFile(tsf)

    # Shift the labels into the range 0 to numClasses - 1
    y_train = train[:, 0]
    y_test = test[:, 0]

    lab = y_train.astype('uint8')
    lab = np.array(lab) - min(lab)
    train[:, 0] = lab

    lab = y_test.astype('uint8')
    lab = np.array(lab) - min(lab)
    test[:, 0] = lab

    np.save(path + '/train.npy', train)
    np.save(path + '/test.npy', test)


if __name__ == '__main__':
    # Configuration
    workingDir = './'
    downloadDir = 'usps10'
    # End config
    print("Processing data")
    processData(workingDir, downloadDir)
    print("Done")
@ -0,0 +1,72 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import helpermethods
import os
import numpy as np


def min_max(A, name):
    print(name + " has max: " + str(np.max(A)) + " min: " + str(np.min(A)))
    return np.max([np.abs(np.max(A)), np.abs(np.min(A))])


def quantizeBonsaiModels(modelDir, maxValue=127, scalarScaleFactor=1000):
    ls = os.listdir(modelDir)
    paramNameList = []
    paramWeightList = []
    paramLimitList = []

    for file in ls:
        if file.endswith("npy"):
            if file.startswith("mean") or file.startswith("std") or \
                    file.startswith("hyperParam"):
                continue
            else:
                paramNameList.append(file)
                temp = np.load(modelDir + "/" + file)
                paramWeightList.append(temp)
                paramLimitList.append(min_max(temp, file))

    paramLimit = np.max(paramLimitList)

    paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))

    quantParamWeights = []
    for param in paramWeightList:
        temp = np.round(paramScaleFactor * param)
        temp[temp[:] > maxValue] = maxValue
        temp[temp[:] < -maxValue] = -1 * (maxValue + 1)

        if maxValue <= 127:
            temp = temp.astype('int8')
        elif maxValue <= 32767:
            temp = temp.astype('int16')
        else:
            temp = temp.astype('int32')

        quantParamWeights.append(temp)

    # quantModelDir is set unconditionally so that re-runs over an
    # already-existing directory do not fail
    quantModelDir = modelDir + '/QuantizedTFBonsaiModel'
    if os.path.isdir(quantModelDir) is False:
        try:
            os.mkdir(quantModelDir)
        except OSError:
            print("Creation of the directory %s failed" % quantModelDir)

    np.save(quantModelDir + "/paramScaleFactor.npy",
            paramScaleFactor.astype('int32'))

    for i in range(len(paramNameList)):
        np.save(quantModelDir + "/q" + paramNameList[i], quantParamWeights[i])

    print("\n\nQuantized Model Dir: " + quantModelDir)


def main():
    args = helpermethods.getQuantArgs()
    quantizeBonsaiModels(args.model_dir, int(args.max_val))


if __name__ == '__main__':
    main()
@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import numpy as np


class Bonsai(nn.Module):

    def __init__(self, numClasses, dataDimension, projectionDimension,
                 treeDepth, sigma, W=None, T=None, V=None, Z=None):
        super(Bonsai, self).__init__()
        '''
        Expected Dimensions:

        Bonsai Params // Optional
        W [numClasses*totalNodes, projectionDimension]
        V [numClasses*totalNodes, projectionDimension]
        Z [projectionDimension, dataDimension]
        T [internalNodes, projectionDimension]

        internalNodes = 2**treeDepth - 1
        totalNodes = 2*internalNodes + 1
        (dataDimension includes the bias column appended to the data)

        sigma - tanh non-linearity
        sigmaI - indicator function for node probabilities;
        sigmaI has to be set to infinity (1e9 in practice)
        while doing testing/inference
        numClasses will be reset to 1 in the binary case
        '''

        self.dataDimension = dataDimension
        self.projectionDimension = projectionDimension

        if numClasses == 2:
            self.numClasses = 1
        else:
            self.numClasses = numClasses

        self.treeDepth = treeDepth
        self.sigma = sigma

        self.internalNodes = 2**self.treeDepth - 1
        self.totalNodes = 2 * self.internalNodes + 1

        self.W = self.initW(W)
        self.V = self.initV(V)
        self.T = self.initT(T)
        self.Z = self.initZ(Z)

        self.assertInit()

    def initZ(self, Z):
        if Z is None:
            Z = torch.randn([self.projectionDimension, self.dataDimension])
        else:
            Z = torch.from_numpy(Z.astype(np.float32))
        Z = nn.Parameter(Z)
        return Z

    def initW(self, W):
        if W is None:
            W = torch.randn(
                [self.numClasses * self.totalNodes, self.projectionDimension])
        else:
            W = torch.from_numpy(W.astype(np.float32))
        W = nn.Parameter(W)
        return W

    def initV(self, V):
        if V is None:
            V = torch.randn(
                [self.numClasses * self.totalNodes, self.projectionDimension])
        else:
            V = torch.from_numpy(V.astype(np.float32))
        V = nn.Parameter(V)
        return V

    def initT(self, T):
        if T is None:
            T = torch.randn([self.internalNodes, self.projectionDimension])
        else:
            T = torch.from_numpy(T.astype(np.float32))
        T = nn.Parameter(T)
        return T

    def forward(self, X, sigmaI):
        '''
        Function to build/execute the Bonsai Tree graph
        Expected Dimensions:

        X is [batchSize, self.dataDimension]
        sigmaI is constant
        '''
        X_ = torch.matmul(self.Z, torch.t(X)) / self.projectionDimension
        W_ = self.W[0:(self.numClasses)]
        V_ = self.V[0:(self.numClasses)]
        self.__nodeProb = []
        self.__nodeProb.append(1)
        score_ = self.__nodeProb[0] * (torch.matmul(W_, X_) *
                                       torch.tanh(self.sigma *
                                                  torch.matmul(V_, X_)))
        for i in range(1, self.totalNodes):
            W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
            V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]

            # The parent node's branching parameter gates this node's score
            T_ = torch.reshape(self.T[int(np.ceil(i / 2.0) - 1.0)],
                               [-1, self.projectionDimension])
            prob = (1 + ((-1)**(i + 1)) *
                    torch.tanh(sigmaI * torch.matmul(T_, X_)))

            prob = prob / 2.0
            prob = self.__nodeProb[int(np.ceil(i / 2.0) - 1.0)] * prob
            self.__nodeProb.append(prob)
            score_ += self.__nodeProb[i] * (torch.matmul(W_, X_) *
                                            torch.tanh(self.sigma *
                                                       torch.matmul(V_, X_)))

        self.score = score_
        self.X_ = X_
        return torch.t(self.score), self.X_

    def assertInit(self):
        errRank = "All Parameters must have only two dimensions: shape = [a, b]"
        assert len(self.W.shape) == len(self.Z.shape), errRank
        assert len(self.W.shape) == len(self.T.shape), errRank
        assert len(self.W.shape) == 2, errRank
        msg = "W and V should be of same Dimensions"
        assert self.W.shape == self.V.shape, msg
        errW = "W and V are [numClasses*totalNodes, projectionDimension]"
        assert self.W.shape[0] == self.numClasses * self.totalNodes, errW
        assert self.W.shape[1] == self.projectionDimension, errW
        errZ = "Z is [projectionDimension, dataDimension]"
        assert self.Z.shape[0] == self.projectionDimension, errZ
        assert self.Z.shape[1] == self.dataDimension, errZ
        errT = "T is [internalNodes, projectionDimension]"
        assert self.T.shape[0] == self.internalNodes, errT
        assert self.T.shape[1] == self.projectionDimension, errT
        assert int(self.numClasses) > 0, "numClasses should be > 0"
        msg = "# of features in data should be > 0"
        assert int(self.dataDimension) > 0, msg
        msg = "Projection should be > 0 dims"
        assert int(self.projectionDimension) > 0, msg
        msg = "treeDepth should be >= 0"
        assert int(self.treeDepth) >= 0, msg
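
As a quick illustration of the module above, a hedged usage sketch (shapes
follow the docstring; the `+ 1` stands for the bias column that
`preProcessData` appends, and the sizes are arbitrary):

```python
# Sketch: instantiating Bonsai and running one forward pass.
import torch
from pytorch_edgeml.graph.bonsai import Bonsai

numClasses, numFeatures, projDim, depth, sigma = 3, 16, 10, 2, 1.0
model = Bonsai(numClasses, numFeatures + 1, projDim, depth, sigma)

X = torch.randn(32, numFeatures + 1)    # batch of 32, last slot = bias column
logits, projected = model(X, sigmaI=1)  # small sigmaI: soft paths for training
# For inference, sigmaI should be large (e.g. 1e9) so that node probabilities
# become hard indicators and only a single root-to-leaf path is used.
logits, _ = model(X, sigmaI=1e9)
```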
@ -0,0 +1,397 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import numpy as np
import os
import sys
import pytorch_edgeml.utils as utils


class BonsaiTrainer:

    def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
                 learningRate, useMCHLoss=False, outFile=None):
        '''
        bonsaiObj - initialised Bonsai object and graph
        lW, lT, lV and lZ are regularisers for the Bonsai params
        sW, sT, sV and sZ are sparsity factors for the Bonsai params
        learningRate - learning rate for the optimizer
        useMCHLoss - choice between hinge loss and cross entropy:
        useMCHLoss - True - MultiClass - multiClassHingeLoss
        useMCHLoss - False - MultiClass - crossEntropyLoss
        '''

        self.bonsaiObj = bonsaiObj

        self.lW = lW
        self.lV = lV
        self.lT = lT
        self.lZ = lZ

        self.sW = sW
        self.sV = sV
        self.sT = sT
        self.sZ = sZ

        self.useMCHLoss = useMCHLoss

        if outFile is not None:
            print("Outfile : ", outFile)
            self.outFile = open(outFile, 'w')
        else:
            self.outFile = sys.stdout

        self.learningRate = learningRate

        self.assertInit()

        self.optimizer = self.optimizer()

        if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99:
            self.isDenseTraining = True
        else:
            self.isDenseTraining = False

    def loss(self, logits, labels):
        '''
        Loss function for the given Bonsai object
        '''
        regLoss = 0.5 * (self.lZ * (torch.norm(self.bonsaiObj.Z)**2) +
                         self.lW * (torch.norm(self.bonsaiObj.W)**2) +
                         self.lV * (torch.norm(self.bonsaiObj.V)**2) +
                         self.lT * (torch.norm(self.bonsaiObj.T)**2))

        if (self.bonsaiObj.numClasses > 2):
            if self.useMCHLoss is True:
                marginLoss = utils.multiClassHingeLoss(logits, labels)
            else:
                marginLoss = utils.crossEntropyLoss(logits, labels)
            loss = marginLoss + regLoss
        else:
            marginLoss = utils.binaryHingeLoss(logits, labels)
            loss = marginLoss + regLoss

        return loss, marginLoss, regLoss

    def optimizer(self):
        '''
        Optimizer for the Bonsai params
        '''
        optimizer = torch.optim.Adam(
            self.bonsaiObj.parameters(), lr=self.learningRate)

        return optimizer

    def accuracy(self, logits, labels):
        '''
        Accuracy function to evaluate accuracy when needed
        '''
        if (self.bonsaiObj.numClasses > 2):
            correctPredictions = (logits.argmax(dim=1) == labels.argmax(dim=1))
            accuracy = torch.mean(correctPredictions.float())
        else:
            pred = (torch.cat((torch.zeros(logits.shape),
                               logits), 1)).argmax(dim=1)
            accuracy = torch.mean((labels.view(-1).long() == pred).float())

        return accuracy

    def runHardThrsd(self):
        '''
        Function to run the IHT routine on the Bonsai object
        '''
        currW = self.bonsaiObj.W.data
        currV = self.bonsaiObj.V.data
        currZ = self.bonsaiObj.Z.data
        currT = self.bonsaiObj.T.data

        self.__thrsdW = utils.hardThreshold(currW, self.sW)
        self.__thrsdV = utils.hardThreshold(currV, self.sV)
        self.__thrsdZ = utils.hardThreshold(currZ, self.sZ)
        self.__thrsdT = utils.hardThreshold(currT, self.sT)

        self.bonsaiObj.W.data = torch.FloatTensor(self.__thrsdW)
        self.bonsaiObj.V.data = torch.FloatTensor(self.__thrsdV)
        self.bonsaiObj.Z.data = torch.FloatTensor(self.__thrsdZ)
        self.bonsaiObj.T.data = torch.FloatTensor(self.__thrsdT)

    def runSparseTraining(self):
        '''
        Function to run the sparse retraining routine on the Bonsai object
        '''
        currW = self.bonsaiObj.W.data
        currV = self.bonsaiObj.V.data
        currZ = self.bonsaiObj.Z.data
        currT = self.bonsaiObj.T.data

        newW = utils.copySupport(self.__thrsdW, currW)
        newV = utils.copySupport(self.__thrsdV, currV)
        newZ = utils.copySupport(self.__thrsdZ, currZ)
        newT = utils.copySupport(self.__thrsdT, currT)

        self.bonsaiObj.W.data = torch.FloatTensor(newW)
        self.bonsaiObj.V.data = torch.FloatTensor(newV)
        self.bonsaiObj.Z.data = torch.FloatTensor(newZ)
        self.bonsaiObj.T.data = torch.FloatTensor(newT)

    def assertInit(self):
        err = "sparsity must be between 0 and 1"
        assert self.sW >= 0 and self.sW <= 1, "W " + err
        assert self.sV >= 0 and self.sV <= 1, "V " + err
        assert self.sZ >= 0 and self.sZ <= 1, "Z " + err
        assert self.sT >= 0 and self.sT <= 1, "T " + err

    def saveParams(self, currDir):
        '''
        Function to save the parameter matrices into a given folder
        '''
        paramDir = currDir + '/'
        np.save(paramDir + "W.npy", self.bonsaiObj.W.data)
        np.save(paramDir + "V.npy", self.bonsaiObj.V.data)
        np.save(paramDir + "T.npy", self.bonsaiObj.T.data)
        np.save(paramDir + "Z.npy", self.bonsaiObj.Z.data)
        hyperParamDict = {'dataDim': self.bonsaiObj.dataDimension,
                          'projDim': self.bonsaiObj.projectionDimension,
                          'numClasses': self.bonsaiObj.numClasses,
                          'depth': self.bonsaiObj.treeDepth,
                          'sigma': self.bonsaiObj.sigma}
        hyperParamFile = paramDir + 'hyperParam.npy'
        np.save(hyperParamFile, hyperParamDict)

    def saveParamsForSeeDot(self, currDir):
        '''
        Function to save the parameter matrices into a given folder
        for the SeeDot compiler
        '''
        seeDotDir = currDir + '/SeeDot/'

        if os.path.isdir(seeDotDir) is False:
            try:
                os.mkdir(seeDotDir)
            except OSError:
                print("Creation of the directory %s failed" % seeDotDir)

        np.savetxt(seeDotDir + "W",
                   utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.W.data,
                                                      self.bonsaiObj.numClasses,
                                                      self.bonsaiObj.totalNodes),
                   delimiter="\t")
        np.savetxt(seeDotDir + "V",
                   utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.V.data,
                                                      self.bonsaiObj.numClasses,
                                                      self.bonsaiObj.totalNodes),
                   delimiter="\t")
        np.savetxt(seeDotDir + "T", self.bonsaiObj.T.data, delimiter="\t")
        np.savetxt(seeDotDir + "Z", self.bonsaiObj.Z.data, delimiter="\t")
        np.savetxt(seeDotDir + "Sigma",
                   np.array([self.bonsaiObj.sigma]), delimiter="\t")

    def loadModel(self, currDir):
        '''
        Loads a saved model into the model via the constructor.
        Returns two dicts, one for params and the other for hyperParams
        '''
        paramDir = currDir + '/'
        paramDict = {}
        paramDict['W'] = np.load(paramDir + "W.npy")
        paramDict['V'] = np.load(paramDir + "V.npy")
        paramDict['T'] = np.load(paramDir + "T.npy")
        paramDict['Z'] = np.load(paramDir + "Z.npy")
        hyperParamDict = np.load(paramDir + "hyperParam.npy").item()
        return paramDict, hyperParamDict

    def getModelSize(self):
        '''
        Function to get the aimed model size
        '''
        nnzZ, sizeZ, sparseZ = utils.countnnZ(self.bonsaiObj.Z, self.sZ)
        nnzW, sizeW, sparseW = utils.countnnZ(self.bonsaiObj.W, self.sW)
        nnzV, sizeV, sparseV = utils.countnnZ(self.bonsaiObj.V, self.sV)
        nnzT, sizeT, sparseT = utils.countnnZ(self.bonsaiObj.T, self.sT)

        totalnnZ = (nnzZ + nnzT + nnzV + nnzW)
        totalSize = (sizeZ + sizeW + sizeV + sizeT)
        hasSparse = (sparseW or sparseV or sparseT or sparseZ)
        return totalnnZ, totalSize, hasSparse

    def train(self, batchSize, totalEpochs,
              Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
        '''
        The Dense - IHT - Sparse Retrain routine for Bonsai training
        '''
        resultFile = open(dataDir + '/pytorchBonsaiResults.txt', 'a+')
        numIters = Xtrain.shape[0] / batchSize

        totalBatches = numIters * totalEpochs

        self.sigmaI = 1

        counter = 0
        if self.bonsaiObj.numClasses > 2:
            trimlevel = 15
        else:
            trimlevel = 5
        ihtDone = 0

        maxTestAcc = -10000
        if self.isDenseTraining is True:
            ihtDone = 1
            self.sigmaI = 1
        itersInPhase = 0

        header = '*' * 20
        for i in range(totalEpochs):
            print("\nEpoch Number: " + str(i), file=self.outFile)

            # trainAcc -> for classification, this is 'Accuracy'
            trainAcc = 0.0
            trainLoss = 0.0

            numIters = int(numIters)
            for j in range(numIters):

                if counter == 0:
                    msg = " Dense Training Phase Started "
                    print("\n%s%s%s\n" %
                          (header, msg, header), file=self.outFile)

                # Updating the indicator sigma
                if ((counter == 0) or (counter == int(totalBatches / 3.0)) or
                        (counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False):
                    self.sigmaI = 1
                    itersInPhase = 0

                elif (itersInPhase % 100 == 0):
                    indices = np.random.choice(Xtrain.shape[0], 100)
                    batchX = Xtrain[indices, :]
                    batchY = Ytrain[indices, :]
                    batchY = np.reshape(
                        batchY, [-1, self.bonsaiObj.numClasses])

                    Teval = self.bonsaiObj.T.data
                    Xcapeval = (torch.matmul(self.bonsaiObj.Z, torch.t(
                        batchX)) / self.bonsaiObj.projectionDimension).data

                    sum_tr = 0.0
                    for k in range(0, self.bonsaiObj.internalNodes):
                        sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval))))

                    if (self.bonsaiObj.internalNodes > 0):
                        sum_tr /= (100 * self.bonsaiObj.internalNodes)
                        sum_tr = 0.1 / sum_tr
                    else:
                        sum_tr = 0.1
                    sum_tr = min(
                        1000, sum_tr * (2**(float(itersInPhase) /
                                            (float(totalBatches) / 30.0))))

                    self.sigmaI = sum_tr

                itersInPhase += 1
                batchX = Xtrain[j * batchSize:(j + 1) * batchSize]
                batchY = Ytrain[j * batchSize:(j + 1) * batchSize]
                batchY = np.reshape(
                    batchY, [-1, self.bonsaiObj.numClasses])

                self.optimizer.zero_grad()
                logits, _ = self.bonsaiObj(batchX, self.sigmaI)
                batchLoss, _, _ = self.loss(logits, batchY)
                batchAcc = self.accuracy(logits, batchY)

                batchLoss.backward()
                self.optimizer.step()

                # Classification.

                trainAcc += batchAcc.item()
                trainLoss += batchLoss.item()

                # Training routine involving IHT and sparse retraining
                if (counter >= int(totalBatches / 3.0) and
                        (counter < int(2 * totalBatches / 3.0)) and
                        counter % trimlevel == 0 and
                        self.isDenseTraining is False):
                    self.runHardThrsd()
                    if ihtDone == 0:
                        msg = " IHT Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                    ihtDone = 1
                elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
                       (counter < int(2 * totalBatches / 3.0)) and
                       counter % trimlevel != 0 and
                       self.isDenseTraining is False) or
                        (counter >= int(2 * totalBatches / 3.0) and
                         self.isDenseTraining is False)):
                    self.runSparseTraining()
                    if counter == int(2 * totalBatches / 3.0):
                        msg = " Sparse Retraining Phase Started "
                        print("\n%s%s%s\n" %
                              (header, msg, header), file=self.outFile)
                counter += 1

            print("\nClassification Train Loss: " + str(trainLoss / numIters) +
                  "\nTraining accuracy (Classification): " +
                  str(trainAcc / numIters),
                  file=self.outFile)

            oldSigmaI = self.sigmaI
            self.sigmaI = 1e9
            logits, _ = self.bonsaiObj(Xtest, self.sigmaI)
            testLoss, marginLoss, regLoss = self.loss(logits, Ytest)
            testAcc = self.accuracy(logits, Ytest).item()

            if ihtDone == 0:
                maxTestAcc = -10000
                maxTestAccEpoch = i
            else:
                if maxTestAcc <= testAcc:
                    maxTestAccEpoch = i
                    maxTestAcc = testAcc
                    self.saveParams(currDir)
                    self.saveParamsForSeeDot(currDir)

            print("Test accuracy %g" % testAcc, file=self.outFile)

            print("MarginLoss + RegLoss: " + str(marginLoss.item()) + " + " +
                  str(regLoss.item()) + " = " + str(testLoss.item()) + "\n",
                  file=self.outFile)
            self.outFile.flush()

            self.sigmaI = oldSigmaI

        # sigmaI has to be set to infinity to ensure
        # only a single path is used in inference
        self.sigmaI = 1e9
        print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
              str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
              str(self.getModelSize()[2]) + "\n", file=self.outFile)

        print("For Classification, Maximum Test accuracy at compressed" +
              " model size(including early stopping): " +
              str(maxTestAcc) + " at Epoch: " +
              str(maxTestAccEpoch + 1) + "\nFinal Test" +
              " Accuracy: " + str(testAcc), file=self.outFile)

        resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
                         " at Epoch(totalEpochs): " +
                         str(maxTestAccEpoch + 1) +
                         "(" + str(totalEpochs) + ")" + " ModelSize: " +
                         str(float(self.getModelSize()[1]) / 1024.0) +
                         " KB hasSparse: " + str(self.getModelSize()[2]) +
                         " Param Directory: " +
                         str(os.path.abspath(currDir)) + "\n")
        print("The Model Directory: " + currDir + "\n")

        resultFile.close()
        self.outFile.flush()

        if self.outFile is not sys.stdout:
            self.outFile.close()
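
To summarize the three-phase routine implemented in `train` above, here is a
small illustrative helper (hypothetical, not part of the library) mapping the
batch counter to its training phase under the same thirds-of-`totalBatches`
boundaries:

```python
# Sketch: the batch-counter schedule used by BonsaiTrainer.train.
def phaseFor(counter, totalBatches, trimlevel):
    if counter < totalBatches // 3:
        return "dense"                      # plain gradient updates
    elif counter < 2 * (totalBatches // 3):
        # IHT phase: hard-threshold every `trimlevel` batches,
        # otherwise keep training restricted to the current support
        return "iht" if counter % trimlevel == 0 else "sparse-on-support"
    else:
        return "sparse-retrain"             # support stays frozen to the end
```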
@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

from __future__ import print_function
import numpy as np
import torch
import torch.nn.functional as F


def multiClassHingeLoss(logits, labels):
    '''
    MultiClassHingeLoss to match the C++ version - no pytorch internal version
    '''
    flatLogits = torch.reshape(logits, [-1, ])
    labels_ = labels.argmax(dim=1)

    correctId = torch.arange(labels.shape[0]) * labels.shape[1] + labels_
    correctLogit = torch.gather(flatLogits, 0, correctId)

    maxLabel = logits.argmax(dim=1)
    top2, _ = torch.topk(logits, k=2, sorted=True)

    wrongMaxLogit = torch.where((maxLabel == labels_), top2[:, 1], top2[:, 0])

    return torch.mean(F.relu(1. + wrongMaxLogit - correctLogit))


def crossEntropyLoss(logits, labels):
    '''
    Cross entropy loss for the multiclass case in joint training,
    for faster convergence
    '''
    return F.cross_entropy(logits, labels.argmax(dim=1))


def binaryHingeLoss(logits, labels):
    '''
    BinaryHingeLoss to match the C++ version - no pytorch internal version
    '''
    return torch.mean(F.relu(1.0 - (2 * labels - 1) * logits))


def hardThreshold(A, s):
    '''
    Hard thresholding function on tensor A with sparsity s
    '''
    A_ = np.copy(A)
    A_ = A_.ravel()
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    A_ = A_.reshape(A.shape)
    return A_


def copySupport(src, dest):
    '''
    Copy the support of the src tensor to the dest tensor
    '''
    support = np.nonzero(src)
    dest_ = dest
    dest = np.zeros(dest_.shape)
    dest[support] = dest_[support]
    return dest


def countnnZ(A, s, bytesPerVar=4):
    '''
    Returns the number of non-zeros and the representative size of the tensor.
    Uses dense storage (bytesPerVar per value) for s >= 0.5,
    else sparse storage (2 * bytesPerVar per value).
    '''
    params = 1
    hasSparse = False
    for i in range(0, len(A.shape)):
        params *= int(A.shape[i])
    if s < 0.5:
        nnZ = np.ceil(params * s)
        hasSparse = True
        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
    else:
        nnZ = params
        return nnZ, nnZ * bytesPerVar, hasSparse


def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
    '''
    Restructures a matrix from [nNodes*nClasses, Proj] to
    [nClasses*nNodes, Proj] for SeeDot
    '''
    tempMatrix = np.zeros(A.shape)
    rowIndex = 0

    for i in range(0, nClasses):
        for j in range(0, nNodes):
            tempMatrix[rowIndex] = A[j * nClasses + i]
            rowIndex += 1

    return tempMatrix
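
A small sketch of how `hardThreshold` and `copySupport` interact in the IHT
routine (the matrix values are illustrative only):

```python
# Sketch: hard thresholding keeps the largest-magnitude fraction s of entries;
# copySupport then restricts a later update to that support.
import numpy as np
from pytorch_edgeml.utils import hardThreshold, copySupport

A = np.array([[0.9, -0.1], [0.05, -2.0]])
sparseA = hardThreshold(A, s=0.5)   # keeps 0.9 and -2.0, zeroes the rest
# sparseA == [[0.9, 0.0], [0.0, -2.0]]

updated = A + 0.1                    # stand-in for a gradient update
retrained = copySupport(sparseA, updated)
# retrained is non-zero only where sparseA was non-zero
```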
@ -0,0 +1,9 @@
from distutils.core import setup

setup(
    name='pytorch_edgeml',
    version='0.2',
    packages=['pytorch_edgeml', ],
    license='MIT License',
    long_description=open('../License.txt').read(),
)