зеркало из https://github.com/microsoft/EdgeML.git
Add FastGRNNCUDA to FastCells and Fix torch.randn() Argument Errors
This commit is contained in:
Родитель
9020ed3115
Коммит
635c899b8b
|
@@ -57,6 +57,11 @@ def main():
|
|||
gate_nonlinearity=gate_non_linearity,
|
||||
update_nonlinearity=update_non_linearity,
|
||||
wRank=wRank, uRank=uRank)
|
||||
elif cell == "FastGRNNCUDA":
|
||||
FastCell = FastGRNNCUDACell(inputDims, hiddenDims,
|
||||
gate_nonlinearity=gate_non_linearity,
|
||||
update_nonlinearity=update_non_linearity,
|
||||
wRank=wRank, uRank=uRank)
|
||||
elif cell == "FastRNN":
|
||||
FastCell = FastRNNCell(inputDims, hiddenDims,
|
||||
update_nonlinearity=update_non_linearity,
|
||||
|
|
|
@@ -88,8 +88,8 @@ def getArgs():
|
|||
'train.npy and test.npy')
|
||||
|
||||
parser.add_argument('-c', '--cell', type=str, default="FastGRNN",
|
||||
help='Choose between [FastGRNN, FastRNN, UGRNN' +
|
||||
', GRU, LSTM], default: FastGRNN')
|
||||
help='Choose between [FastGRNN, FastGRNNCUDA, FastRNN,' +
|
||||
' UGRNN, GRU, LSTM], default: FastGRNN')
|
||||
|
||||
parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True,
|
||||
help='Input Dimension of RNN, each timestep will ' +
|
||||
|
|
|
@@ -8,8 +8,8 @@ if findCUDA() is not None:
|
|||
name='fastgrnn_cuda',
|
||||
ext_modules=[
|
||||
CUDAExtension('fastgrnn_cuda', [
|
||||
'edgeml_pytorch/cuda/fastgrnn_cuda.cpp',
|
||||
'edgeml_pytorch/cuda/fastgrnn_cuda_kernel.cu',
|
||||
'fastgrnn_cuda.cpp',
|
||||
'fastgrnn_cuda_kernel.cu',
|
||||
]),
|
||||
],
|
||||
cmdclass={
|
||||
|
|
|
@@ -13,6 +13,7 @@ try:
|
|||
if utils.findCUDA() is not None:
|
||||
import fastgrnn_cuda
|
||||
except:
|
||||
print("Running without FastGRNN CUDA")
|
||||
pass
|
||||
|
||||
|
||||
|
@@ -354,29 +355,29 @@ class FastGRNNCUDACell(RNNCell):
|
|||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], self.device))
|
||||
self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size], device=self.device))
|
||||
self.W1 = torch.empty(0)
|
||||
self.W2 = torch.empty(0)
|
||||
else:
|
||||
self.W = torch.empty(0)
|
||||
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], self.device))
|
||||
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], self.device))
|
||||
self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size], device=self.device))
|
||||
self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank], device=self.device))
|
||||
|
||||
if uRank is None:
|
||||
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], self.device))
|
||||
self.U = nn.Parameter(0.1 * torch.randn([hidden_size, hidden_size], device=self.device))
|
||||
self.U1 = torch.empty(0)
|
||||
self.U2 = torch.empty(0)
|
||||
else:
|
||||
self.U = torch.empty(0)
|
||||
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], self.device))
|
||||
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], self.device))
|
||||
self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size], device=self.device))
|
||||
self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank], device=self.device))
|
||||
|
||||
self._gate_non_linearity = NON_LINEARITY[gate_nonlinearity]
|
||||
|
||||
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], self.device))
|
||||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], self.device))
|
||||
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], self.device))
|
||||
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], self.device))
|
||||
self.bias_gate = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
|
||||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size], device=self.device))
|
||||
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1], device=self.device))
|
||||
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1], device=self.device))
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
|
|
Загрузка…
Ссылка в новой задаче