* SRNN: fix default cell parameters

* SRNN: add model saving to example script
This commit is contained in:
Moksh Jain 2019-11-19 01:01:40 +05:30 коммит произвёл Harsha Vardhan Simhadri
Родитель 0d27f3ebb6
Коммит 3a56912caa
2 изменённых файлов: 10 добавлений и 1 удалений

Просмотреть файл

@ -30,6 +30,8 @@ std[std[:] < 0.000001] = 1
x_train_ = (x_train_ - mean) / std
x_val_ = (x_val_ - mean) / std
x_test_ = (x_test_ - mean) / std
np.save('mean.npy', mean)
np.save('std.npy', std)
x_train = np.swapaxes(x_train_, 0, 1)
x_val = np.swapaxes(x_val_, 0, 1)
@ -73,3 +75,6 @@ trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)
trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val,
printStep=printStep, valStep=valStep)
print('Saving trained model:')
torch.save(srnn2.state_dict(), 'model_srnn.pt')

Просмотреть файл

@ -1273,7 +1273,11 @@ class SRNN2(nn.Module):
assert 0 < dropoutProbability0 <= 1.0
if dropoutProbability1 != None:
assert 0 < dropoutProbability1 <= 1.0
self.cellArgs = {}
# Setting batch_first = False to ensure compatibility of parameters across nn.LSTM and the
# other low-rank implementations
self.cellArgs = {
'batch_first': False
}
self.cellArgs.update(cellArgs)
supportedCells = ['LSTM', 'FastRNNCell', 'FastGRNNCell', 'GRULRCell']
assert cellType in supportedCells, 'Currently supported cells: %r' % supportedCells