зеркало из https://github.com/microsoft/EdgeML.git
Fix SRNN Training (#144)
* SRNN: fix default cell parameters * SRNN: add model saving to example script
This commit is contained in:
Родитель
0d27f3ebb6
Коммит
3a56912caa
|
@ -30,6 +30,8 @@ std[std[:] < 0.000001] = 1
|
|||
x_train_ = (x_train_ - mean) / std
|
||||
x_val_ = (x_val_ - mean) / std
|
||||
x_test_ = (x_test_ - mean) / std
|
||||
np.save('mean.npy', mean)
|
||||
np.save('std.npy', std)
|
||||
|
||||
x_train = np.swapaxes(x_train_, 0, 1)
|
||||
x_val = np.swapaxes(x_val_, 0, 1)
|
||||
|
@ -73,3 +75,6 @@ trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)
|
|||
|
||||
trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val,
|
||||
printStep=printStep, valStep=valStep)
|
||||
|
||||
print('Saving trained model:')
|
||||
torch.save(srnn2.state_dict(), 'model_srnn.pt')
|
||||
|
|
|
@ -1273,7 +1273,11 @@ class SRNN2(nn.Module):
|
|||
assert 0 < dropoutProbability0 <= 1.0
|
||||
if dropoutProbability1 != None:
|
||||
assert 0 < dropoutProbability1 <= 1.0
|
||||
self.cellArgs = {}
|
||||
# Setting batch_first = False to ensure compatibility of parameters across nn.LSTM and the
|
||||
# other low-rank implementations
|
||||
self.cellArgs = {
|
||||
'batch_first': False
|
||||
}
|
||||
self.cellArgs.update(cellArgs)
|
||||
supportedCells = ['LSTM', 'FastRNNCell', 'FastGRNNCell', 'GRULRCell']
|
||||
assert cellType in supportedCells, 'Currently supported cells: %r' % supportedCells
|
||||
|
|
Загрузка…
Ссылка в новой задаче