зеркало из https://github.com/microsoft/EdgeML.git
cleaned up FastGRNN code
This commit is contained in:
Родитель
a4deef91c2
Коммит
c2ac41c868
|
@ -58,87 +58,15 @@ def gen_nonlinearity(A, nonlinearity):
|
|||
return nonlinearity(A)
|
||||
|
||||
|
||||
class BaseRNN(nn.Module):
|
||||
'''
|
||||
Generic equivalent of static_rnn in tf
|
||||
Used to unroll all the cell written in this file
|
||||
We assume input to be batch_first by default ie.,
|
||||
[batchSize, timeSteps, inputDims] else
|
||||
[timeSteps, batchSize, inputDims]
|
||||
'''
|
||||
|
||||
def __init__(self, RNNCell, batch_first=True):
|
||||
super(BaseRNN, self).__init__()
|
||||
self._RNNCell = RNNCell
|
||||
self._batch_first = batch_first
|
||||
|
||||
def getVars(self):
|
||||
return self._RNNCell.getVars()
|
||||
|
||||
def forward(self, input, hiddenState=None,
|
||||
cellState=None):
|
||||
if self._batch_first is True:
|
||||
self.device = input.device
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if hiddenState is None:
|
||||
hiddenState = torch.zeros([input.shape[0],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[0], self._RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[:, i, :], (hiddenState, cellState))
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
cellStates[:, i, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
return hiddenStates
|
||||
else:
|
||||
self.device = input.device
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if hiddenState is None:
|
||||
hiddenState = torch.zeros([input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[1], self._RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[i, :, :], (hiddenState, cellState))
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
cellStates[i, :, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
return hiddenStates
|
||||
|
||||
class RNNCell(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, update_nonlinearity="tanh",
|
||||
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0,
|
||||
name="None"):
|
||||
def __init__(self, input_size, hidden_size,
|
||||
update_nonlinearity, num_weight_matrices,
|
||||
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0):
|
||||
super(RNNCell, self).__init__()
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = [1, 1]
|
||||
self._num_weight_matrices = num_weight_matrices
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
self._wSparsity = wSparsity
|
||||
|
@ -174,10 +102,6 @@ class RNNCell(nn.Module):
|
|||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def forward(self, input, state):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -227,7 +151,6 @@ class RNNCell(nn.Module):
|
|||
for i in range(0, endU):
|
||||
mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
|
||||
|
||||
|
||||
class FastGRNNCell(RNNCell):
|
||||
'''
|
||||
FastGRNN Cell with Both Full Rank and Low Rank Formulations
|
||||
|
@ -269,7 +192,7 @@ class FastGRNNCell(RNNCell):
|
|||
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
|
||||
name="FastGRNN"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
wRank, uRank, wSparsity, uSparsity)
|
||||
[1, 1], wRank, uRank, wSparsity, uSparsity)
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._zetaInit = zetaInit
|
||||
self._nuInit = nuInit
|
||||
|
@ -302,6 +225,10 @@ class FastGRNNCell(RNNCell):
|
|||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def cellType(self):
|
||||
return "FastGRNN"
|
||||
|
@ -346,7 +273,7 @@ class FastGRNNCell(RNNCell):
|
|||
return Vars
|
||||
|
||||
|
||||
class FastRNNCell(nn.Module):
|
||||
class FastRNNCell(RNNCell):
|
||||
'''
|
||||
FastRNN Cell with Both Full Rank and Low Rank Formulations
|
||||
Has multiple activation functions for the gates
|
||||
|
@ -383,16 +310,9 @@ class FastRNNCell(nn.Module):
|
|||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
|
||||
name="FastRNN"):
|
||||
super(FastRNNCell, self).__init__()
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[1, 1], wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = [1, 1]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
self._wSparsity = wSparsity
|
||||
self._uSparsity = uSparsity
|
||||
self._alphaInit = alphaInit
|
||||
self._betaInit = betaInit
|
||||
if wRank is not None:
|
||||
|
@ -418,48 +338,6 @@ class FastRNNCell(nn.Module):
|
|||
self.alpha = nn.Parameter(self._alphaInit * torch.ones([1, 1]))
|
||||
self.beta = nn.Parameter(self._betaInit * torch.ones([1, 1]))
|
||||
|
||||
def sparsify(self):
|
||||
if self._wRank is None:
|
||||
self.W.data = utils.hardThreshold(self.W, self._wSparsity)
|
||||
else:
|
||||
self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
|
||||
self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
|
||||
|
||||
if self._uRank is None:
|
||||
self.U.data = utils.hardThreshold(self.U, self._uSparsity)
|
||||
else:
|
||||
self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
|
||||
self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
|
||||
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def input_size(self):
|
||||
return self._input_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def update_nonlinearity(self):
|
||||
return self._update_nonlinearity
|
||||
|
||||
@property
|
||||
def wRank(self):
|
||||
return self._wRank
|
||||
|
||||
@property
|
||||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
@property
|
||||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -507,26 +385,8 @@ class FastRNNCell(nn.Module):
|
|||
|
||||
return Vars
|
||||
|
||||
def getModelSize(self):
|
||||
'''
|
||||
Function to get aimed model size
|
||||
'''
|
||||
totalnnz = 2 # For \alpha and \beta
|
||||
totalnnz += utils.countNNZ(self.bias_update, False)
|
||||
if self._wRank is None:
|
||||
totalnnz += utils.countNNZ(self.W, self._wSparsity)
|
||||
else:
|
||||
totalnnz += utils.countNNZ(self.W1, self._wSparsity)
|
||||
totalnnz += utils.countNNZ(self.W2, self._wSparsity)
|
||||
|
||||
if self._uRank is None:
|
||||
totalnnz += utils.countNNZ(self.U, self._uSparsity)
|
||||
else:
|
||||
totalnnz += utils.countNNZ(self.U1, self._uSparsity)
|
||||
totalnnz += utils.countNNZ(self.U2, self._uSparsity)
|
||||
return totalnnz
|
||||
|
||||
class LSTMLRCell(nn.Module):
|
||||
class LSTMLRCell(RNNCell):
|
||||
'''
|
||||
LR - Low Rank
|
||||
LSTM LR Cell with Both Full Rank and Low Rank Formulations
|
||||
|
@ -561,16 +421,11 @@ class LSTMLRCell(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
name="LSTMLR"):
|
||||
super(LSTMLRCell, self).__init__()
|
||||
wSparsity=1.0, uSparsity=1.0, name="LSTMLR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[4, 4], wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = [4, 4]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
|
@ -614,38 +469,10 @@ class LSTMLRCell(nn.Module):
|
|||
self.bias_c = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
self.bias_o = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return 2 * self._hidden_size
|
||||
|
||||
@property
|
||||
def input_size(self):
|
||||
return self._input_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def update_nonlinearity(self):
|
||||
return self._update_nonlinearity
|
||||
|
||||
@property
|
||||
def wRank(self):
|
||||
return self._wRank
|
||||
|
||||
@property
|
||||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
@property
|
||||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -722,7 +549,7 @@ class LSTMLRCell(nn.Module):
|
|||
return Vars
|
||||
|
||||
|
||||
class GRULRCell(nn.Module):
|
||||
class GRULRCell(RNNCell):
|
||||
'''
|
||||
GRU LR Cell with Both Full Rank and Low Rank Formulations
|
||||
Has multiple activation functions for the gates
|
||||
|
@ -754,16 +581,11 @@ class GRULRCell(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
name="GRULR"):
|
||||
super(GRULRCell, self).__init__()
|
||||
wSparsity=1.0, uSparsity=1.0, name="GRULR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[3, 3], wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = [3, 3]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
|
@ -801,38 +623,10 @@ class GRULRCell(nn.Module):
|
|||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
self._device = self.bias_update.device
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def input_size(self):
|
||||
return self._input_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def update_nonlinearity(self):
|
||||
return self._update_nonlinearity
|
||||
|
||||
@property
|
||||
def wRank(self):
|
||||
return self._wRank
|
||||
|
||||
@property
|
||||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
@property
|
||||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -900,7 +694,7 @@ class GRULRCell(nn.Module):
|
|||
return Vars
|
||||
|
||||
|
||||
class UGRNNLRCell(nn.Module):
|
||||
class UGRNNLRCell(RNNCell):
|
||||
'''
|
||||
UGRNN LR Cell with Both Full Rank and Low Rank Formulations
|
||||
Has multiple activation functions for the gates
|
||||
|
@ -931,16 +725,11 @@ class UGRNNLRCell(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
name="UGRNNLR"):
|
||||
super(UGRNNLRCell, self).__init__()
|
||||
wSparsity=1.0, uSparsity=1.0, name="UGRNNLR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[2, 2], wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = [2, 2]
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
if uRank is not None:
|
||||
|
@ -971,38 +760,10 @@ class UGRNNLRCell(nn.Module):
|
|||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
self._device = self.bias_update.device
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def input_size(self):
|
||||
return self._input_size
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def update_nonlinearity(self):
|
||||
return self._update_nonlinearity
|
||||
|
||||
@property
|
||||
def wRank(self):
|
||||
return self._wRank
|
||||
|
||||
@property
|
||||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
@property
|
||||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -1058,6 +819,79 @@ class UGRNNLRCell(nn.Module):
|
|||
return Vars
|
||||
|
||||
|
||||
class BaseRNN(nn.Module):
|
||||
'''
|
||||
Generic equivalent of static_rnn in tf
|
||||
Used to unroll all the cell written in this file
|
||||
We assume input to be batch_first by default ie.,
|
||||
[batchSize, timeSteps, inputDims] else
|
||||
[timeSteps, batchSize, inputDims]
|
||||
'''
|
||||
|
||||
def __init__(self, cell: RNNCell, batch_first=True):
|
||||
super(BaseRNN, self).__init__()
|
||||
self._RNNCell = cell
|
||||
self._batch_first = batch_first
|
||||
|
||||
def getVars(self):
|
||||
return self._RNNCell.getVars()
|
||||
|
||||
def forward(self, input, hiddenState=None,
|
||||
cellState=None):
|
||||
if self._batch_first is True:
|
||||
self.device = input.device
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if hiddenState is None:
|
||||
hiddenState = torch.zeros([input.shape[0],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[0], self._RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[:, i, :], (hiddenState, cellState))
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
cellStates[:, i, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
for i in range(0, input.shape[1]):
|
||||
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
|
||||
hiddenStates[:, i, :] = hiddenState
|
||||
return hiddenStates
|
||||
else:
|
||||
self.device = input.device
|
||||
hiddenStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if hiddenState is None:
|
||||
hiddenState = torch.zeros([input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if self._RNNCell.cellType == "LSTMLR":
|
||||
cellStates = torch.zeros(
|
||||
[input.shape[0], input.shape[1],
|
||||
self._RNNCell.output_size]).to(self.device)
|
||||
if cellState is None:
|
||||
cellState = torch.zeros(
|
||||
[input.shape[1], self._RNNCell.output_size]).to(self.device)
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState, cellState = self._RNNCell(
|
||||
input[i, :, :], (hiddenState, cellState))
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
cellStates[i, :, :] = cellState
|
||||
return hiddenStates, cellStates
|
||||
else:
|
||||
for i in range(0, input.shape[0]):
|
||||
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
|
||||
hiddenStates[i, :, :] = hiddenState
|
||||
return hiddenStates
|
||||
|
||||
|
||||
class LSTM(nn.Module):
|
||||
"""Equivalent to nn.LSTM using LSTMLRCell"""
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче