зеркало из https://github.com/microsoft/EdgeML.git
Cosmetic changes
This commit is contained in:
Родитель
395ee84765
Коммит
d9ecbd6cdf
|
@ -8,6 +8,7 @@ from pytorch_edgeml.trainer.bonsaiTrainer import BonsaiTrainer
|
|||
from pytorch_edgeml.graph.bonsai import Bonsai
|
||||
import torch
|
||||
|
||||
|
||||
def main():
|
||||
# change cuda:0 to cuda:gpuid for specific allocation
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
|
|
@ -46,7 +46,6 @@ def main():
|
|||
assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
|
||||
"Timesteps have to be integer"
|
||||
|
||||
|
||||
currDir = helpermethods.createTimeStampDir(dataDir, cell)
|
||||
|
||||
helpermethods.dumpCommand(sys.argv, currDir)
|
||||
|
@ -77,7 +76,8 @@ def main():
|
|||
sys.exit('Exiting: No Such Cell as ' + cell)
|
||||
|
||||
FastCellTrainer = FastTrainer(FastCell, numClasses, sW=sW, sU=sU,
|
||||
learningRate=learningRate, outFile=outFile, device=device)
|
||||
learningRate=learningRate, outFile=outFile,
|
||||
device=device)
|
||||
|
||||
FastCellTrainer.train(batchSize, totalEpochs,
|
||||
torch.from_numpy(Xtrain.astype(np.float32)),
|
||||
|
|
|
@ -53,13 +53,17 @@ class BaseRNN(nn.Module):
|
|||
|
||||
def forward(self, input):
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
|
||||
[input.shape[0], input.shape[1],
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
|
||||
hiddenState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
hiddenState = torch.zeros([input.shape[0],
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
if self.RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
|
||||
cellState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
[input.shape[0], input.shape[1],
|
||||
self.RNNCell.output_size]).to(self.device)
|
||||
cellState = torch.zeros(
|
||||
[input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState, cellState = self.RNNCell(
|
||||
input[:, i, :], (hiddenState, cellState))
|
||||
|
|
|
@ -69,7 +69,8 @@ class BonsaiTrainer:
|
|||
|
||||
if (self.bonsaiObj.numClasses > 2):
|
||||
if self.useMCHLoss is True:
|
||||
marginLoss = utils.multiClassHingeLoss(logits, labels, self.device)
|
||||
marginLoss = utils.multiClassHingeLoss(
|
||||
logits, labels, self.device)
|
||||
else:
|
||||
marginLoss = utils.crossEntropyLoss(logits, labels)
|
||||
loss = marginLoss + regLoss
|
||||
|
@ -116,10 +117,14 @@ class BonsaiTrainer:
|
|||
self.__thrsdZ = utils.hardThreshold(currZ.cpu(), self.sZ)
|
||||
self.__thrsdT = utils.hardThreshold(currT.cpu(), self.sT)
|
||||
|
||||
self.bonsaiObj.W.data = torch.FloatTensor(self.__thrsdW).to(self.device)
|
||||
self.bonsaiObj.V.data = torch.FloatTensor(self.__thrsdV).to(self.device)
|
||||
self.bonsaiObj.Z.data = torch.FloatTensor(self.__thrsdZ).to(self.device)
|
||||
self.bonsaiObj.T.data = torch.FloatTensor(self.__thrsdT).to(self.device)
|
||||
self.bonsaiObj.W.data = torch.FloatTensor(
|
||||
self.__thrsdW).to(self.device)
|
||||
self.bonsaiObj.V.data = torch.FloatTensor(
|
||||
self.__thrsdV).to(self.device)
|
||||
self.bonsaiObj.Z.data = torch.FloatTensor(
|
||||
self.__thrsdZ).to(self.device)
|
||||
self.bonsaiObj.T.data = torch.FloatTensor(
|
||||
self.__thrsdT).to(self.device)
|
||||
|
||||
def runSparseTraining(self):
|
||||
'''
|
||||
|
@ -283,7 +288,8 @@ class BonsaiTrainer:
|
|||
|
||||
sum_tr = 0.0
|
||||
for k in range(0, self.bonsaiObj.internalNodes):
|
||||
sum_tr += (np.sum(np.abs(np.dot(Teval[k].cpu(), Xcapeval.cpu()))))
|
||||
sum_tr += (
|
||||
np.sum(np.abs(np.dot(Teval[k].cpu(), Xcapeval.cpu()))))
|
||||
|
||||
if(self.bonsaiObj.internalNodes > 0):
|
||||
sum_tr /= (100 * self.bonsaiObj.internalNodes)
|
||||
|
@ -347,7 +353,8 @@ class BonsaiTrainer:
|
|||
oldSigmaI = self.sigmaI
|
||||
self.sigmaI = 1e9
|
||||
logits, _ = self.bonsaiObj(Xtest.to(self.device), self.sigmaI)
|
||||
testLoss, marginLoss, regLoss = self.loss(logits, Ytest.to(self.device))
|
||||
testLoss, marginLoss, regLoss = self.loss(
|
||||
logits, Ytest.to(self.device))
|
||||
testAcc = self.accuracy(logits, Ytest.to(self.device)).item()
|
||||
|
||||
if ihtDone == 0:
|
||||
|
|
|
@ -9,9 +9,11 @@ import pytorch_edgeml.utils as utils
|
|||
from pytorch_edgeml.graph.rnn import *
|
||||
import numpy as np
|
||||
|
||||
|
||||
class FastTrainer:
|
||||
|
||||
def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0, learningRate=0.01, outFile=None, device=None):
|
||||
def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0,
|
||||
learningRate=0.01, outFile=None, device=None):
|
||||
'''
|
||||
FastObj - Can be either FastRNN or FastGRNN or any of the RNN cells
|
||||
in graph.rnn with proper initialisations
|
||||
|
@ -52,7 +54,8 @@ class FastTrainer:
|
|||
|
||||
self.RNN = BaseRNN(self.FastObj, self.device).to(device)
|
||||
|
||||
self.FC = nn.Parameter(torch.randn([self.FastObj.output_size, self.numClasses])).to(device)
|
||||
self.FC = nn.Parameter(torch.randn(
|
||||
[self.FastObj.output_size, self.numClasses])).to(device)
|
||||
self.FCbias = nn.Parameter(torch.randn([self.numClasses])).to(device)
|
||||
|
||||
self.FastParams = self.FastObj.getVars()
|
||||
|
@ -120,7 +123,8 @@ class FastTrainer:
|
|||
self.thrsdParams.append(
|
||||
utils.hardThreshold(self.FastParams[i].data.cpu(), self.sU))
|
||||
for i in range(0, self.totalMatrices):
|
||||
self.FastParams[i].data = torch.FloatTensor(self.thrsdParams[i]).to(self.device)
|
||||
self.FastParams[i].data = torch.FloatTensor(
|
||||
self.thrsdParams[i]).to(self.device)
|
||||
|
||||
def runSparseTraining(self):
|
||||
'''
|
||||
|
@ -129,10 +133,11 @@ class FastTrainer:
|
|||
self.reTrainParams = []
|
||||
for i in range(0, self.totalMatrices):
|
||||
self.reTrainParams.append(
|
||||
utils.copySupport(self.thrsdParams[i].data, self.FastParams[i].data.cpu()))
|
||||
utils.copySupport(self.thrsdParams[i].data,
|
||||
self.FastParams[i].data.cpu()))
|
||||
for i in range(0, self.totalMatrices):
|
||||
self.FastParams[i].data = torch.FloatTensor(self.reTrainParams[i]).to(self.device)
|
||||
|
||||
self.FastParams[i].data = torch.FloatTensor(
|
||||
self.reTrainParams[i]).to(self.device)
|
||||
|
||||
def getModelSize(self):
|
||||
'''
|
||||
|
@ -176,7 +181,8 @@ class FastTrainer:
|
|||
Function to save Parameter matrices
|
||||
'''
|
||||
if self.numMatrices[0] == 1:
|
||||
np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].data.cpu())
|
||||
np.save(os.path.join(currDir, "W.npy"),
|
||||
self.FastParams[0].data.cpu())
|
||||
elif self.FastObj.wRank is None:
|
||||
if self.numMatrices[0] == 2:
|
||||
np.save(os.path.join(currDir, "W1.npy"),
|
||||
|
@ -235,7 +241,8 @@ class FastTrainer:
|
|||
|
||||
idx = self.numMatrices[0]
|
||||
if self.numMatrices[1] == 1:
|
||||
np.save(os.path.join(currDir, "U.npy"), self.FastParams[idx + 0].data.cpu())
|
||||
np.save(os.path.join(currDir, "U.npy"),
|
||||
self.FastParams[idx + 0].data.cpu())
|
||||
elif self.FastObj.uRank is None:
|
||||
if self.numMatrices[1] == 2:
|
||||
np.save(os.path.join(currDir, "U1.npy"),
|
||||
|
@ -333,7 +340,6 @@ class FastTrainer:
|
|||
np.save(os.path.join(currDir, "FC.npy"), self.FC.data.cpu())
|
||||
np.save(os.path.join(currDir, "FCbias.npy"), self.FCbias.data.cpu())
|
||||
|
||||
|
||||
def train(self, batchSize, totalEpochs, Xtrain, Xtest, Ytrain, Ytest,
|
||||
decayStep, decayRate, dataDir, currDir):
|
||||
'''
|
||||
|
@ -423,7 +429,6 @@ class FastTrainer:
|
|||
testLoss = self.loss(logits, Ytest.to(self.device)).item()
|
||||
testAcc = self.accuracy(logits, Ytest.to(self.device)).item()
|
||||
|
||||
|
||||
if ihtDone == 0:
|
||||
maxTestAcc = -10000
|
||||
maxTestAccEpoch = i
|
||||
|
@ -437,7 +442,6 @@ class FastTrainer:
|
|||
" Test Accuracy: " + str(testAcc), file=self.outFile)
|
||||
self.outFile.flush()
|
||||
|
||||
|
||||
print("\nMaximum Test accuracy at compressed" +
|
||||
" model size(including early stopping): " +
|
||||
str(maxTestAcc) + " at Epoch: " +
|
||||
|
@ -467,10 +471,3 @@ class FastTrainer:
|
|||
self.outFile.flush()
|
||||
if self.outFile is not sys.stdout:
|
||||
self.outFile.close()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче