Add FastGRNNCUDA to FastCells and Fix torch.randn() Argument Errors

This commit is contained in:
ShikharJ 2020-12-18 13:45:08 +05:30
Родитель 9020ed3115
Коммит 635c899b8b
4 изменённых файлов: 20 добавлений и 14 удалений

Просмотреть файл

@@ -57,6 +57,11 @@ def main():
gate_nonlinearity=gate_non_linearity,
update_nonlinearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "FastGRNNCUDA":
FastCell = FastGRNNCUDACell(inputDims, hiddenDims,
gate_nonlinearity=gate_non_linearity,
update_nonlinearity=update_non_linearity,
wRank=wRank, uRank=uRank)
elif cell == "FastRNN":
FastCell = FastRNNCell(inputDims, hiddenDims,
update_nonlinearity=update_non_linearity,

Просмотреть файл

@@ -88,8 +88,8 @@ def getArgs():
'train.npy and test.npy')
parser.add_argument('-c', '--cell', type=str, default="FastGRNN",
help='Choose between [FastGRNN, FastRNN, UGRNN' +
', GRU, LSTM], default: FastGRNN')
help='Choose between [FastGRNN, FastGRNNCUDA, FastRNN,' +
' UGRNN, GRU, LSTM], default: FastGRNN')
parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True,
help='Input Dimension of RNN, each timestep will ' +

Просмотреть файл

@@ -8,8 +8,8 @@ if findCUDA() is not None:
name='fastgrnn_cuda',
ext_modules=[
CUDAExtension('fastgrnn_cuda', [
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
'fastgrnn_cuda.cpp',
'fastgrnn_cuda_kernel.cu',
]),
],
cmdclass={

Просмотреть файл

@@ -13,6 +13,7 @@ try:
if utils.findCUDA() is not None:
import fastgrnn_cuda
except:
print("Running without FastGRNN CUDA")
pass
@@ -354,29 +355,29 @@ class FastGRNNCUDACell(RNNCell):
self._name = name
if wRank is None:
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device))
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
self.W1 = torch.empty(0)
self.W2 = torch.empty(0)
else:
self.W = torch.empty(0)
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device))
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device))
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
if uRank is None:
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device))
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
self.U1 = torch.empty(0)
self.U2 = torch.empty(0)
else:
self.U = torch.empty(0)
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device))
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device))
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device))
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
@property
def name(self):