renaming num_weight_matrices to separate vars for num of U and W mats

This commit is contained in:
Harsha Vardhan Simhadri 2019-08-01 10:54:36 +05:30
Родитель eadae5b38f
Коммит 92bf361480
1 изменённых файлов: 63 добавлений и 57 удалений

Просмотреть файл

@ -60,13 +60,18 @@ def gen_nonlinearity(A, nonlinearity):
class RNNCell(nn.Module):
def __init__(self, input_size, hidden_size,
update_nonlinearity, num_weight_matrices,
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0):
gate_nonlinearity, update_nonlinearity,
num_W_matrices, num_U_matrices,
wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0):
super(RNNCell, self).__init__()
self._input_size = input_size
self._hidden_size = hidden_size
self._gate_nonlinearity = gate_nonlinearity
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = num_weight_matrices
#self._num_weight_matrices = num_weight_matrices
self._num_W_matrices = num_W_matrices
self._num_U_matrices = num_U_matrices
self._wRank = wRank
self._uRank = uRank
self._wSparsity = wSparsity
@ -86,6 +91,10 @@ class RNNCell(nn.Module):
def output_size(self):
return self._hidden_size
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def update_nonlinearity(self):
return self._update_nonlinearity
@ -98,9 +107,17 @@ class RNNCell(nn.Module):
def uRank(self):
return self._uRank
#@property
#def num_weight_matrices(self):
# return self._num_weight_matrices
@property
def num_weight_matrices(self):
return self._num_weight_matrices
def num_W_matrices(self):
return self._num_W_matrices
@property
def num_U_matrices(self):
return self._num_U_matrices
def forward(self, input, state):
raise NotImplementedError()
@ -113,8 +130,8 @@ class RNNCell(nn.Module):
Function to get aimed model size
'''
mats = self.getVars()
endW = self._num_weight_matrices[0]
endU = endW + self._num_weight_matrices[1]
endW = self._num_W_matrices
endU = endW + self._num_U_matrices
totalnnz = 2 # For Zeta and Nu
for i in range(0, endW):
@ -127,7 +144,7 @@ class RNNCell(nn.Module):
def copy_previous_UW(self):
mats = self.getVars()
num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
num_mats = self._num_W_matrices + self._num_U_matrices
if len(self.oldmats) != num_mats:
for i in range(num_mats):
self.oldmats.append(torch.FloatTensor())
@ -137,8 +154,8 @@ class RNNCell(nn.Module):
def sparsify(self):
mats = self.getVars()
endW = self._num_weight_matrices[0]
endU = endW + self._num_weight_matrices[1]
endW = self._num_W_matrices
endU = endW + self._num_U_matrices
for i in range(0, endW):
mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
for i in range(endW, endU):
@ -147,7 +164,7 @@ class RNNCell(nn.Module):
def sparsifyWithSupport(self):
mats = self.getVars()
endU = self._num_weight_matrices[0] + self._num_weight_matrices[1]
endU = self._num_W_matrices + self._num_U_matrices
for i in range(0, endU):
mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
@ -191,15 +208,15 @@ class FastGRNNCell(RNNCell):
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
name="FastGRNN"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[1, 1], wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
super(FastGRNNCell, self).__init__(input_size, hidden_size,
gate_nonlinearity, update_nonlinearity,
1, 1, wRank, uRank, wSparsity, uSparsity)
self._zetaInit = zetaInit
self._nuInit = nuInit
if wRank is not None:
self._num_weight_matrices[0] += 1
self._num_W_matrices += 1
if uRank is not None:
self._num_weight_matrices[1] += 1
self._num_U_matrices += 1
self._name = name
if wRank is None:
@ -221,10 +238,6 @@ class FastGRNNCell(RNNCell):
self.copy_previous_UW()
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def name(self):
return self._name
@ -258,12 +271,12 @@ class FastGRNNCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 1:
if self._num_W_matrices == 1:
Vars.append(self.W)
else:
Vars.extend([self.W1, self.W2])
if self._num_weight_matrices[1] == 1:
if self._num_U_matrices == 1:
Vars.append(self.U)
else:
Vars.extend([self.U1, self.U2])
@ -310,15 +323,16 @@ class FastRNNCell(RNNCell):
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
name="FastRNN"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[1, 1], wRank, uRank, wSparsity, uSparsity)
super(FastGRNNCell, self).__init__(input_size, hidden_size,
None, update_nonlinearity,
1, 1, wRank, uRank, wSparsity, uSparsity)
self._alphaInit = alphaInit
self._betaInit = betaInit
if wRank is not None:
self._num_weight_matrices[0] += 1
self._num_W_matrices += 1
if uRank is not None:
self._num_weight_matrices[1] += 1
self._num_U_matrices += 1
self._name = name
if wRank is None:
@ -370,12 +384,12 @@ class FastRNNCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 1:
if self._num_W_matrices == 1:
Vars.append(self.W)
else:
Vars.extend([self.W1, self.W2])
if self._num_weight_matrices[1] == 1:
if self._num_U_matrices == 1:
Vars.append(self.U)
else:
Vars.extend([self.U1, self.U2])
@ -422,14 +436,14 @@ class LSTMLRCell(RNNCell):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, name="LSTMLR"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[4, 4], wRank, uRank, wSparsity, uSparsity)
super(FastGRNNCell, self).__init__(input_size, hidden_size,
gate_nonlinearity, update_nonlinearity,
4, 4, wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
if wRank is not None:
self._num_weight_matrices[0] += 1
self._num_W_matrices += 1
if uRank is not None:
self._num_weight_matrices[1] += 1
self._num_U_matrices += 1
self._name = name
if wRank is None:
@ -534,12 +548,12 @@ class LSTMLRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 4:
if self._num_W_matrices == 4:
Vars.extend([self.W1, self.W2, self.W3, self.W4])
else:
Vars.extend([self.W, self.W1, self.W2, self.W3, self.W4])
if self._num_weight_matrices[1] == 4:
if self._num_U_matrices == 4:
Vars.extend([self.U1, self.U2, self.U3, self.U4])
else:
Vars.extend([self.U, self.U1, self.U2, self.U3, self.U4])
@ -582,14 +596,14 @@ class GRULRCell(RNNCell):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, name="GRULR"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[3, 3], wRank, uRank, wSparsity, uSparsity)
super(FastGRNNCell, self).__init__(input_size, hidden_size,
gate_nonlinearity, update_nonlinearity,
3, 3, wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
if wRank is not None:
self._num_weight_matrices[0] += 1
self._num_W_matrices += 1
if uRank is not None:
self._num_weight_matrices[1] += 1
self._num_U_matrices += 1
self._name = name
if wRank is None:
@ -623,10 +637,6 @@ class GRULRCell(RNNCell):
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
self._device = self.bias_update.device
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def name(self):
return self._name
@ -679,12 +689,12 @@ class GRULRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 3:
if self._num_W_matrices == 3:
Vars.extend([self.W1, self.W2, self.W3])
else:
Vars.extend([self.W, self.W1, self.W2, self.W3])
if self._num_weight_matrices[1] == 3:
if self._num_U_matrices == 3:
Vars.extend([self.U1, self.U2, self.U3])
else:
Vars.extend([self.U, self.U1, self.U2, self.U3])
@ -726,14 +736,14 @@ class UGRNNLRCell(RNNCell):
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, name="UGRNNLR"):
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
[2, 2], wRank, uRank, wSparsity, uSparsity)
super(FastGRNNCell, self).__init__(input_size, hidden_size,
gate_nonlinearity, update_nonlinearity,
2, 2, wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
if wRank is not None:
self._num_weight_matrices[0] += 1
self._num_W_matrices += 1
if uRank is not None:
self._num_weight_matrices[1] += 1
self._num_U_matrices += 1
self._name = name
if wRank is None:
@ -760,10 +770,6 @@ class UGRNNLRCell(RNNCell):
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
self._device = self.bias_update.device
@property
def gate_nonlinearity(self):
return self._gate_nonlinearity
@property
def name(self):
return self._name
@ -804,12 +810,12 @@ class UGRNNLRCell(RNNCell):
def getVars(self):
Vars = []
if self._num_weight_matrices[0] == 2:
if self._num_W_matrices == 2:
Vars.extend([self.W1, self.W2])
else:
Vars.extend([self.W, self.W1, self.W2])
if self._num_weight_matrices[1] == 2:
if self._num_U_matrices == 2:
Vars.extend([self.U1, self.U2])
else:
Vars.extend([self.U, self.U1, self.U2])