зеркало из https://github.com/microsoft/EdgeML.git
Well tested gpu mode of Fastcells and Bonsai
This commit is contained in:
Родитель
7ec84702a6
Коммит
395ee84765
|
@ -46,7 +46,6 @@ def main():
|
|||
assert dataDimension % inputDims == 0, "Infeasible per step input, " + \
|
||||
"Timesteps have to be integer"
|
||||
|
||||
timeSteps = int(dataDimension/inputDims)
|
||||
|
||||
currDir = helpermethods.createTimeStampDir(dataDir, cell)
|
||||
|
||||
|
@ -77,7 +76,7 @@ def main():
|
|||
else:
|
||||
sys.exit('Exiting: No Such Cell as ' + cell)
|
||||
|
||||
FastCellTrainer = FastTrainer(FastCell, timeSteps, numClasses, sW=sW, sU=sU,
|
||||
FastCellTrainer = FastTrainer(FastCell, numClasses, sW=sW, sU=sU,
|
||||
learningRate=learningRate, outFile=outFile, device=device)
|
||||
|
||||
FastCellTrainer.train(batchSize, totalEpochs,
|
||||
|
|
|
@ -43,23 +43,22 @@ class BaseRNN(nn.Module):
|
|||
[batchSize, timeSteps, inputDims]
|
||||
'''
|
||||
|
||||
def __init__(self, RNNCell, timeSteps, device=None):
|
||||
def __init__(self, RNNCell, device=None):
|
||||
super(BaseRNN, self).__init__()
|
||||
self.RNNCell = RNNCell
|
||||
if device is None:
|
||||
self.device = "cpu"
|
||||
else:
|
||||
self.device = device
|
||||
self.timeSteps = timeSteps
|
||||
|
||||
def forward(self, input):
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], self.timeSteps, self.RNNCell.output_size]).to(self.device)
|
||||
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
|
||||
|
||||
hiddenState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
if self.RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], self.timeSteps, self.RNNCell.output_size]).to(self.device)
|
||||
[input.shape[0], input.shape[1], self.RNNCell.output_size]).to(self.device)
|
||||
cellState = torch.zeros([input.shape[0], self.RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState, cellState = self.RNNCell(
|
||||
|
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
|
||||
class FastTrainer:
|
||||
|
||||
def __init__(self, FastObj, timeSteps, numClasses, sW=1.0, sU=1.0, learningRate=0.01, outFile=None, device=None):
|
||||
def __init__(self, FastObj, numClasses, sW=1.0, sU=1.0, learningRate=0.01, outFile=None, device=None):
|
||||
'''
|
||||
FastObj - Can be either FastRNN or FastGRNN or any of the RNN cells
|
||||
in graph.rnn with proper initialisations
|
||||
|
@ -44,14 +44,13 @@ class FastTrainer:
|
|||
else:
|
||||
self.isDenseTraining = False
|
||||
|
||||
self.timeSteps = timeSteps
|
||||
self.assertInit()
|
||||
self.numMatrices = self.FastObj.num_weight_matrices
|
||||
self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]
|
||||
|
||||
self.optimizer = self.optimizer()
|
||||
|
||||
self.RNN = BaseRNN(self.FastObj, self.timeSteps, self.device).to(device)
|
||||
self.RNN = BaseRNN(self.FastObj, self.device).to(device)
|
||||
|
||||
self.FC = nn.Parameter(torch.randn([self.FastObj.output_size, self.numClasses])).to(device)
|
||||
self.FCbias = nn.Parameter(torch.randn([self.numClasses])).to(device)
|
||||
|
@ -353,7 +352,7 @@ class FastTrainer:
|
|||
ihtDone = 1
|
||||
maxTestAcc = -10000
|
||||
header = '*' * 20
|
||||
|
||||
self.timeSteps = int(Xtest.shape[1]/self.inputDims)
|
||||
Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims))
|
||||
Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims))
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче