зеркало из https://github.com/microsoft/EdgeML.git
renaming num_weight_matrices to separate vars for num of U and W mats
This commit is contained in:
Родитель
eadae5b38f
Коммит
92bf361480
|
@ -60,13 +60,18 @@ def gen_nonlinearity(A, nonlinearity):
|
|||
|
||||
class RNNCell(nn.Module):
|
||||
def __init__(self, input_size, hidden_size,
|
||||
update_nonlinearity, num_weight_matrices,
|
||||
wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0):
|
||||
gate_nonlinearity, update_nonlinearity,
|
||||
num_W_matrices, num_U_matrices,
|
||||
wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0):
|
||||
super(RNNCell, self).__init__()
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._num_weight_matrices = num_weight_matrices
|
||||
#self._num_weight_matrices = num_weight_matrices
|
||||
self._num_W_matrices = num_W_matrices
|
||||
self._num_U_matrices = num_U_matrices
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
self._wSparsity = wSparsity
|
||||
|
@ -86,6 +91,10 @@ class RNNCell(nn.Module):
|
|||
def output_size(self):
|
||||
return self._hidden_size
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def update_nonlinearity(self):
|
||||
return self._update_nonlinearity
|
||||
|
@ -98,9 +107,17 @@ class RNNCell(nn.Module):
|
|||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
#@property
|
||||
#def num_weight_matrices(self):
|
||||
# return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def num_weight_matrices(self):
|
||||
return self._num_weight_matrices
|
||||
def num_W_matrices(self):
|
||||
return self._num_W_matrices
|
||||
|
||||
@property
|
||||
def num_U_matrices(self):
|
||||
return self._num_U_matrices
|
||||
|
||||
def forward(self, input, state):
|
||||
raise NotImplementedError()
|
||||
|
@ -113,8 +130,8 @@ class RNNCell(nn.Module):
|
|||
Function to get aimed model size
|
||||
'''
|
||||
mats = self.getVars()
|
||||
endW = self._num_weight_matrices[0]
|
||||
endU = endW + self._num_weight_matrices[1]
|
||||
endW = self._num_W_matrices
|
||||
endU = endW + self._num_U_matrices
|
||||
|
||||
totalnnz = 2 # For Zeta and Nu
|
||||
for i in range(0, endW):
|
||||
|
@ -127,7 +144,7 @@ class RNNCell(nn.Module):
|
|||
|
||||
def copy_previous_UW(self):
|
||||
mats = self.getVars()
|
||||
num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
|
||||
num_mats = self._num_W_matrices + self._num_U_matrices
|
||||
if len(self.oldmats) != num_mats:
|
||||
for i in range(num_mats):
|
||||
self.oldmats.append(torch.FloatTensor())
|
||||
|
@ -137,8 +154,8 @@ class RNNCell(nn.Module):
|
|||
|
||||
def sparsify(self):
|
||||
mats = self.getVars()
|
||||
endW = self._num_weight_matrices[0]
|
||||
endU = endW + self._num_weight_matrices[1]
|
||||
endW = self._num_W_matrices
|
||||
endU = endW + self._num_U_matrices
|
||||
for i in range(0, endW):
|
||||
mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
|
||||
for i in range(endW, endU):
|
||||
|
@ -147,7 +164,7 @@ class RNNCell(nn.Module):
|
|||
|
||||
def sparsifyWithSupport(self):
|
||||
mats = self.getVars()
|
||||
endU = self._num_weight_matrices[0] + self._num_weight_matrices[1]
|
||||
endU = self._num_W_matrices + self._num_U_matrices
|
||||
for i in range(0, endU):
|
||||
mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
|
||||
|
||||
|
@ -191,15 +208,15 @@ class FastGRNNCell(RNNCell):
|
|||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
|
||||
name="FastGRNN"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[1, 1], wRank, uRank, wSparsity, uSparsity)
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size,
|
||||
gate_nonlinearity, update_nonlinearity,
|
||||
1, 1, wRank, uRank, wSparsity, uSparsity)
|
||||
self._zetaInit = zetaInit
|
||||
self._nuInit = nuInit
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
self._num_W_matrices += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._num_U_matrices += 1
|
||||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
|
@ -221,10 +238,6 @@ class FastGRNNCell(RNNCell):
|
|||
|
||||
self.copy_previous_UW()
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -258,12 +271,12 @@ class FastGRNNCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 1:
|
||||
if self._num_W_matrices == 1:
|
||||
Vars.append(self.W)
|
||||
else:
|
||||
Vars.extend([self.W1, self.W2])
|
||||
|
||||
if self._num_weight_matrices[1] == 1:
|
||||
if self._num_U_matrices == 1:
|
||||
Vars.append(self.U)
|
||||
else:
|
||||
Vars.extend([self.U1, self.U2])
|
||||
|
@ -310,15 +323,16 @@ class FastRNNCell(RNNCell):
|
|||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
|
||||
name="FastRNN"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[1, 1], wRank, uRank, wSparsity, uSparsity)
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size,
|
||||
None, update_nonlinearity,
|
||||
1, 1, wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._alphaInit = alphaInit
|
||||
self._betaInit = betaInit
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
self._num_W_matrices += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._num_U_matrices += 1
|
||||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
|
@ -370,12 +384,12 @@ class FastRNNCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 1:
|
||||
if self._num_W_matrices == 1:
|
||||
Vars.append(self.W)
|
||||
else:
|
||||
Vars.extend([self.W1, self.W2])
|
||||
|
||||
if self._num_weight_matrices[1] == 1:
|
||||
if self._num_U_matrices == 1:
|
||||
Vars.append(self.U)
|
||||
else:
|
||||
Vars.extend([self.U1, self.U2])
|
||||
|
@ -422,14 +436,14 @@ class LSTMLRCell(RNNCell):
|
|||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, name="LSTMLR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[4, 4], wRank, uRank, wSparsity, uSparsity)
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size,
|
||||
gate_nonlinearity, update_nonlinearity,
|
||||
4, 4, wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
self._num_W_matrices += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._num_U_matrices += 1
|
||||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
|
@ -534,12 +548,12 @@ class LSTMLRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 4:
|
||||
if self._num_W_matrices == 4:
|
||||
Vars.extend([self.W1, self.W2, self.W3, self.W4])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2, self.W3, self.W4])
|
||||
|
||||
if self._num_weight_matrices[1] == 4:
|
||||
if self._num_U_matrices == 4:
|
||||
Vars.extend([self.U1, self.U2, self.U3, self.U4])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2, self.U3, self.U4])
|
||||
|
@ -582,14 +596,14 @@ class GRULRCell(RNNCell):
|
|||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, name="GRULR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[3, 3], wRank, uRank, wSparsity, uSparsity)
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size,
|
||||
gate_nonlinearity, update_nonlinearity,
|
||||
3, 3, wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
self._num_W_matrices += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._num_U_matrices += 1
|
||||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
|
@ -623,10 +637,6 @@ class GRULRCell(RNNCell):
|
|||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
self._device = self.bias_update.device
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -679,12 +689,12 @@ class GRULRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 3:
|
||||
if self._num_W_matrices == 3:
|
||||
Vars.extend([self.W1, self.W2, self.W3])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2, self.W3])
|
||||
|
||||
if self._num_weight_matrices[1] == 3:
|
||||
if self._num_U_matrices == 3:
|
||||
Vars.extend([self.U1, self.U2, self.U3])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2, self.U3])
|
||||
|
@ -726,14 +736,14 @@ class UGRNNLRCell(RNNCell):
|
|||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
wSparsity=1.0, uSparsity=1.0, name="UGRNNLR"):
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
|
||||
[2, 2], wRank, uRank, wSparsity, uSparsity)
|
||||
super(FastGRNNCell, self).__init__(input_size, hidden_size,
|
||||
gate_nonlinearity, update_nonlinearity,
|
||||
2, 2, wRank, uRank, wSparsity, uSparsity)
|
||||
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
if wRank is not None:
|
||||
self._num_weight_matrices[0] += 1
|
||||
self._num_W_matrices += 1
|
||||
if uRank is not None:
|
||||
self._num_weight_matrices[1] += 1
|
||||
self._num_U_matrices += 1
|
||||
self._name = name
|
||||
|
||||
if wRank is None:
|
||||
|
@ -760,10 +770,6 @@ class UGRNNLRCell(RNNCell):
|
|||
self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
|
||||
self._device = self.bias_update.device
|
||||
|
||||
@property
|
||||
def gate_nonlinearity(self):
|
||||
return self._gate_nonlinearity
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
@ -804,12 +810,12 @@ class UGRNNLRCell(RNNCell):
|
|||
|
||||
def getVars(self):
|
||||
Vars = []
|
||||
if self._num_weight_matrices[0] == 2:
|
||||
if self._num_W_matrices == 2:
|
||||
Vars.extend([self.W1, self.W2])
|
||||
else:
|
||||
Vars.extend([self.W, self.W1, self.W2])
|
||||
|
||||
if self._num_weight_matrices[1] == 2:
|
||||
if self._num_U_matrices == 2:
|
||||
Vars.extend([self.U1, self.U2])
|
||||
else:
|
||||
Vars.extend([self.U, self.U1, self.U2])
|
||||
|
|
Загрузка…
Ссылка в новой задаче