Well tested gpu mode of Fastcells and Bonsai

This commit is contained in:
Aditya Kusupati 2019-06-22 00:40:24 +00:00
Родитель 7ec84702a6
Коммит 395ee84765
3 изменённых файлов: 7 добавлений и 10 удалений

Просмотреть файл

@ -46,7 +46,6 @@ def main():
assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
"Timesteps have to be integer"
timeSteps = int(dataDimension/inputDims)
currDir = helpermethods.createTimeStampDir(dataDir, cell)
@ -77,7 +76,7 @@ def main():
else:
sys.exit('Exiting: No Such Cell as ' + cell)
FastCellTrainer = FastTrainer(FastCell, timeSteps, numClasses, sW=sW, sU=sU,
FastCellTrainer = FastTrainer(FastCell, numClasses, sW=sW, sU=sU,
learningRate=learningRate, outFile=outFile, device=device)
FastCellTrainer.train(batchSize, totalEpochs,

Просмотреть файл

@ -43,23 +43,22 @@ class BaseRNN(nn.Module):
[batchSize, timeSteps, inputDims]
'''
def __init__(self, RNNCell, timeSteps, device=None):
def __init__(self, RNNCell, device=None):
super(BaseRNN, self).__init__()
self.RNNCell = RNNCell
if device is None:
self.device = "cpu"
else:
self.device = device
self.timeSteps = timeSteps
def forward(self, input):
hiddenStates = torch.zeros(
[input.shape[0], self.timeSteps, self.RNNCell.output_size]).to(self.device)
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
hiddenState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
if self.RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], self.timeSteps, self.RNNCell.output_size]).to(self.device)
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
cellState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[1]):
hiddenState, cellState = self.RNNCell(

Просмотреть файл

@ -11,7 +11,7 @@ import numpy as np
class FastTrainer:
def __init__(self, FastObj, timeSteps, numClasses, sW=1.0, sU=1.0, learningRate=0.01, outFile=None, device=None):
def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0, learningRate=0.01, outFile=None, device=None):
'''
FastObj - Can be either FastRNN or FastGRNN or any of the RNN cells
in graph.rnn with proper initialisations
@ -44,14 +44,13 @@ class FastTrainer:
else:
self.isDenseTraining = False
self.timeSteps = timeSteps
self.assertInit()
self.numMatrices = self.FastObj.num_weight_matrices
self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]
self.optimizer = self.optimizer()
self.RNN = BaseRNN(self.FastObj, self.timeSteps, self.device).to(device)
self.RNN = BaseRNN(self.FastObj, self.device).to(device)
self.FC = nn.Parameter(torch.randn([self.FastObj.output_size, self.numClasses])).to(device)
self.FCbias = nn.Parameter(torch.randn([self.numClasses])).to(device)
@ -353,7 +352,7 @@ class FastTrainer:
ihtDone = 1
maxTestAcc = -10000
header = '*' * 20
self.timeSteps = int(Xtest.shape[1]/self.inputDims)
Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims))
Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims))