Harsha Vardhan Simhadri 2019-07-31 14:01:11 +05:30
Parent a4deef91c2
Commit c2ac41c868
1 changed file with 97 additions and 263 deletions


@@ -58,87 +58,15 @@ def gen_nonlinearity(A, nonlinearity):
return nonlinearity(A)
class BaseRNN(nn.Module):
'''
Generic equivalent of static_rnn in TensorFlow.
Used to unroll all the cells written in this file.
Input is assumed to be batch-first by default, i.e.,
[batchSize, timeSteps, inputDims]; otherwise
[timeSteps, batchSize, inputDims].
'''
def __init__(self, RNNCell, batch_first=True):
super(BaseRNN, self).__init__()
self._RNNCell = RNNCell
self._batch_first = batch_first
def getVars(self):
return self._RNNCell.getVars()
def forward(self, input, hiddenState=None,
cellState=None):
if self._batch_first is True:
self.device = input.device
hiddenStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if hiddenState is None:
hiddenState = torch.zeros([input.shape[0],
self._RNNCell.output_size]).to(self.device)
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[0], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[1]):
hiddenState, cellState = self._RNNCell(
input[:, i, :], (hiddenState, cellState))
hiddenStates[:, i, :] = hiddenState
cellStates[:, i, :] = cellState
return hiddenStates, cellStates
else:
for i in range(0, input.shape[1]):
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
hiddenStates[:, i, :] = hiddenState
return hiddenStates
else:
self.device = input.device
hiddenStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if hiddenState is None:
hiddenState = torch.zeros([input.shape[1],
self._RNNCell.output_size]).to(self.device)
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[1], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[0]):
hiddenState, cellState = self._RNNCell(
input[i, :, :], (hiddenState, cellState))
hiddenStates[i, :, :] = hiddenState
cellStates[i, :, :] = cellState
return hiddenStates, cellStates
else:
for i in range(0, input.shape[0]):
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
hiddenStates[i, :, :] = hiddenState
return hiddenStates
class RNNCell(nn.Module):
def __init__(self, input_size, hidden_size, update_nonlinearity="tanh",
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0,
name="None"):
def __init__(self, input_size, hidden_size,
update_nonlinearity, num_weight_matrices,
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0):
super(RNNCell, self).__init__()
self._input_size = input_size
self._hidden_size = hidden_size
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [1, 1]
self._num_weight_matrices = num_weight_matrices
self._wRank = wRank
self._uRank = uRank
self._wSparsity = wSparsity
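For orientation: wRank and uRank request low-rank factorizations of the input and recurrent weight matrices, and num_weight_matrices tracks how many input-side and hidden-side matrices a cell owns (each factorization adds one, as the "+= 1" lines in the cell constructors below show). A minimal sketch of the parameter saving, assuming the matmul order these cells use; the sizes are hypothetical:

# Illustrative sketch (not part of the diff): low-rank factorization of W.
import torch

input_size, hidden_size, wRank = 32, 64, 8   # hypothetical sizes
x = torch.randn(16, input_size)              # a batch of 16 inputs

# Full rank: one [input_size, hidden_size] matrix.
W = torch.randn(input_size, hidden_size)
full = torch.matmul(x, W)

# Low rank: W is replaced by W1 @ W2 with inner dimension wRank, cutting
# parameters from input_size * hidden_size = 2048 down to
# (input_size + hidden_size) * wRank = 768.
W1 = torch.randn(input_size, wRank)
W2 = torch.randn(wRank, hidden_size)
low = torch.matmul(torch.matmul(x, W1), W2)

assert full.shape == low.shape == (16, hidden_size)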
@@ -174,10 +102,6 @@ class RNNCell(nn.Module):
def num_weight_matrices(self):
return self._num_weight_matrices
@property
def name(self):
return self._name
def forward(self, input, state):
raise NotImplementedError()
@@ -227,7 +151,6 @@ class RNNCell(nn.Module):
for i in range(0, endU):
mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
class FastGRNNCell(RNNCell):
'''
FastGRNN Cell with Both Full Rank and Low Rank Formulations
@@ -269,7 +192,7 @@ class FastGRNNCell(RNNCell):
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
name="FastGRNN"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
wRank, uRank, wSparsity, uSparsity)
[1, 1], wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
self._zetaInit = zetaInit
self._nuInit = nuInit
@@ -302,6 +225,10 @@ class FastGRNNCell(RNNCell):
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def name(self):
return self._name
@property
def cellType(self):
return "FastGRNN"
@@ -346,7 +273,7 @@ class FastGRNNCell(RNNCell):
return Vars
class FastRNNCell(nn.Module):
class FastRNNCell(RNNCell):
'''
FastRNN Cell with Both Full Rank and Low Rank Formulations
Has multiple activation functions for the gates
@@ -383,16 +310,9 @@ class FastRNNCell(nn.Module):
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
name="FastRNN"):
super(FastRNNCell, self).__init__()
super(FastRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[1, 1], wRank, uRank, wSparsity, uSparsity)
self._input_size = input_size
self._hidden_size = hidden_size
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [1, 1]
self._wRank = wRank
self._uRank = uRank
self._wSparsity = wSparsity
self._uSparsity = uSparsity
self._alphaInit = alphaInit
self._betaInit = betaInit
if wRank is not None:
@@ -418,48 +338,6 @@ class FastRNNCell(nn.Module):
self.alpha = nn.Parameter(self._alphaInit * torch.ones([1, 1]))
self.beta = nn.Parameter(self._betaInit * torch.ones([1, 1]))
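For reference, the recurrence this cell computes is a single candidate update blended with the previous state through the trainable scalars alpha and beta; a full-rank sketch:

# Illustrative sketch (not part of the diff): one FastRNN step, full rank.
import torch

def fastrnn_step(x, h, W, U, b, alpha, beta):
    h_tilde = torch.tanh(x @ W + h @ U + b)  # candidate update
    # With alphaInit=-3.0 and betaInit=3.0, sigmoid(alpha) ~ 0.05 and
    # sigmoid(beta) ~ 0.95, so the cell starts close to an identity map.
    return torch.sigmoid(alpha) * h_tilde + torch.sigmoid(beta) * h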
def sparsify(self):
if self._wRank is None:
self.W.data = utils.hardThreshold(self.W, self._wSparsity)
else:
self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
if self._uRank is None:
self.U.data = utils.hardThreshold(self.U, self._uSparsity)
else:
self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
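utils.hardThreshold is the repo's helper; a plausible standalone equivalent, assuming it keeps the largest-magnitude sparsity fraction of entries and zeroes the rest (an assumption about its semantics, not a copy of it):

# Illustrative sketch (not part of the diff): assumed hardThreshold semantics.
import math
import torch

def hard_threshold(mat: torch.Tensor, sparsity: float) -> torch.Tensor:
    # Keep the top `sparsity` fraction of entries by magnitude;
    # sparsity=1.0 (the default above) keeps the matrix dense.
    if sparsity >= 1.0:
        return mat.clone()
    k = max(int(math.ceil(sparsity * mat.numel())), 1)
    cutoff = torch.topk(mat.abs().flatten(), k).values[-1]
    return torch.where(mat.abs() >= cutoff, mat, torch.zeros_like(mat))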
@property
def state_size(self):
return self._hidden_size
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._hidden_size
@property
def update_nonlinearity(self):
return self._update_nonlinearity
@property
def wRank(self):
return self._wRank
@property
def uRank(self):
return self._uRank
@property
def num_weight_matrices(self):
return self._num_weight_matrices
@property
def name(self):
return self._name
@@ -507,26 +385,8 @@ class FastRNNCell(nn.Module):
return Vars
def getModelSize(self):
'''
Function to get the target model size: the number of nonzero parameters after sparsification
'''
totalnnz = 2 # For \alpha and \beta
totalnnz += utils.countNNZ(self.bias_update, False)
if self._wRank is None:
totalnnz += utils.countNNZ(self.W, self._wSparsity)
else:
totalnnz += utils.countNNZ(self.W1, self._wSparsity)
totalnnz += utils.countNNZ(self.W2, self._wSparsity)
if self._uRank is None:
totalnnz += utils.countNNZ(self.U, self._uSparsity)
else:
totalnnz += utils.countNNZ(self.U1, self._uSparsity)
totalnnz += utils.countNNZ(self.U2, self._uSparsity)
return totalnnz
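The count above is easy to sanity-check by hand, assuming countNNZ with a sparsity argument reports roughly sparsity * numel surviving entries after hard thresholding (hypothetical sizes):

# Illustrative arithmetic (not part of the diff) for getModelSize.
input_size, hidden_size = 32, 64
wSparsity = uSparsity = 0.1
nnz = (2                                              # alpha and beta
       + hidden_size                                  # bias_update, kept dense
       + int(wSparsity * input_size * hidden_size)    # W: 204 of 2048
       + int(uSparsity * hidden_size * hidden_size))  # U: 409 of 4096
print(nnz)  # 2 + 64 + 204 + 409 = 679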
class LSTMLRCell(nn.Module):
class LSTMLRCell(RNNCell):
'''
LR - Low Rank
LSTM LR Cell with Both Full Rank and Low Rank Formulations
@@ -561,16 +421,11 @@ class LSTMLRCell(nn.Module):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
name="LSTMLR"):
super(LSTMLRCell, self).__init__()
wSparsity=1.0, uSparsity=1.0, name="LSTMLR"):
super(LSTMLRCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[4, 4], wRank, uRank, wSparsity, uSparsity)
self._input_size = input_size
self._hidden_size = hidden_size
self._gate_nonlinearity = gate_nonlinearity
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [4, 4]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 1
if uRank is not None:
@@ -614,38 +469,10 @@ class LSTMLRCell(nn.Module):
self.bias_c = nn.Parameter(torch.ones([1, hidden_size]))
self.bias_o = nn.Parameter(torch.ones([1, hidden_size]))
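These four weight/bias groups implement the standard LSTM recurrence; with wRank set, each x @ W becomes (x @ W1) @ W_gate, as in the low-rank sketch earlier. A full-rank reference step (the dict keys are illustrative):

# Illustrative sketch (not part of the diff): one LSTM step, full rank.
import torch

def lstm_step(x, h, c, Ws, Us, bs):
    f = torch.sigmoid(x @ Ws["f"] + h @ Us["f"] + bs["f"])   # forget gate
    i = torch.sigmoid(x @ Ws["i"] + h @ Us["i"] + bs["i"])   # input gate
    c_hat = torch.tanh(x @ Ws["c"] + h @ Us["c"] + bs["c"])  # candidate cell
    o = torch.sigmoid(x @ Ws["o"] + h @ Us["o"] + bs["o"])   # output gate
    c_new = f * c + i * c_hat
    return o * torch.tanh(c_new), c_new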
@property
def state_size(self):
return 2 * self._hidden_size
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._hidden_size
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def update_nonlinearity(self):
return self._update_nonlinearity
@property
def wRank(self):
return self._wRank
@property
def uRank(self):
return self._uRank
@property
def num_weight_matrices(self):
return self._num_weight_matrices
@property
def name(self):
return self._name
@@ -722,7 +549,7 @@ class LSTMLRCell(nn.Module):
return Vars
class GRULRCell(nn.Module):
class GRULRCell(RNNCell):
'''
GRU LR Cell with Both Full Rank and Low Rank Formulations
Has multiple activation functions for the gates
@@ -754,16 +581,11 @@ class GRULRCell(nn.Module):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
name="GRULR"):
super(GRULRCell, self).__init__()
wSparsity=1.0, uSparsity=1.0, name="GRULR"):
super(GRULRCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[3, 3], wRank, uRank, wSparsity, uSparsity)
self._input_size = input_size
self._hidden_size = hidden_size
self._gate_nonlinearity = gate_nonlinearity
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [3, 3]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 1
if uRank is not None:
@@ -801,38 +623,10 @@ class GRULRCell(nn.Module):
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
self._device = self.bias_update.device
@property
def state_size(self):
return self._hidden_size
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._hidden_size
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def update_nonlinearity(self):
return self._update_nonlinearity
@property
def wRank(self):
return self._wRank
@property
def uRank(self):
return self._uRank
@property
def num_weight_matrices(self):
return self._num_weight_matrices
@property
def name(self):
return self._name
@@ -900,7 +694,7 @@ class GRULRCell(nn.Module):
return Vars
class UGRNNLRCell(nn.Module):
class UGRNNLRCell(RNNCell):
'''
UGRNN LR Cell with Both Full Rank and Low Rank Formulations
Has multiple activation functions for the gates
@@ -931,16 +725,11 @@ class UGRNNLRCell(nn.Module):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
name="UGRNNLR"):
super(UGRNNLRCell, self).__init__()
wSparsity=1.0, uSparsity=1.0, name="UGRNNLR"):
super(UGRNNLRCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[2, 2], wRank, uRank, wSparsity, uSparsity)
self._input_size = input_size
self._hidden_size = hidden_size
self._gate_nonlinearity = gate_nonlinearity
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [2, 2]
self._wRank = wRank
self._uRank = uRank
if wRank is not None:
self._num_weight_matrices[0] += 1
if uRank is not None:
@@ -971,38 +760,10 @@ class UGRNNLRCell(nn.Module):
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
self._device = self.bias_update.device
@property
def state_size(self):
return self._hidden_size
@property
def input_size(self):
return self._input_size
@property
def output_size(self):
return self._hidden_size
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def update_nonlinearity(self):
return self._update_nonlinearity
@property
def wRank(self):
return self._wRank
@property
def uRank(self):
return self._uRank
@property
def num_weight_matrices(self):
return self._num_weight_matrices
@property
def name(self):
return self._name
@@ -1058,6 +819,79 @@ class UGRNNLRCell(nn.Module):
return Vars
class BaseRNN(nn.Module):
'''
Generic equivalent of static_rnn in TensorFlow.
Used to unroll all the cells written in this file.
Input is assumed to be batch-first by default, i.e.,
[batchSize, timeSteps, inputDims]; otherwise
[timeSteps, batchSize, inputDims].
'''
def __init__(self, cell: RNNCell, batch_first=True):
super(BaseRNN, self).__init__()
self._RNNCell = cell
self._batch_first = batch_first
def getVars(self):
return self._RNNCell.getVars()
def forward(self, input, hiddenState=None,
cellState=None):
if self._batch_first is True:
self.device = input.device
hiddenStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if hiddenState is None:
hiddenState = torch.zeros([input.shape[0],
self._RNNCell.output_size]).to(self.device)
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[0], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[1]):
hiddenState, cellState = self._RNNCell(
input[:, i, :], (hiddenState, cellState))
hiddenStates[:, i, :] = hiddenState
cellStates[:, i, :] = cellState
return hiddenStates, cellStates
else:
for i in range(0, input.shape[1]):
hiddenState = self._RNNCell(input[:, i, :], hiddenState)
hiddenStates[:, i, :] = hiddenState
return hiddenStates
else:
self.device = input.device
hiddenStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if hiddenState is None:
hiddenState = torch.zeros([input.shape[1],
self._RNNCell.output_size]).to(self.device)
if self._RNNCell.cellType == "LSTMLR":
cellStates = torch.zeros(
[input.shape[0], input.shape[1],
self._RNNCell.output_size]).to(self.device)
if cellState is None:
cellState = torch.zeros(
[input.shape[1], self._RNNCell.output_size]).to(self.device)
for i in range(0, input.shape[0]):
hiddenState, cellState = self._RNNCell(
input[i, :, :], (hiddenState, cellState))
hiddenStates[i, :, :] = hiddenState
cellStates[i, :, :] = cellState
return hiddenStates, cellStates
else:
for i in range(0, input.shape[0]):
hiddenState = self._RNNCell(input[i, :, :], hiddenState)
hiddenStates[i, :, :] = hiddenState
return hiddenStates
class LSTM(nn.Module):
"""Equivalent to nn.LSTM using LSTMLRCell"""