This commit is contained in:
Harsha Vardhan Simhadri 2019-07-31 12:23:05 +05:30
Родитель bd7732aedd
Коммит 27ef35ea64
1 изменённых файлов: 97 добавлений и 151 удалений

Просмотреть файл

@ -130,8 +130,102 @@ class BaseRNN(nn.Module):
hiddenStates[i, :, :] = hiddenState
return hiddenStates
class RNNCell(nn.Module):
    '''
    Base class that records the configuration of an RNN cell.

    The constructor only stores its arguments on the instance; it does not
    create any parameters itself.
    '''

    def __init__(self, input_size, hidden_size, update_nonlinearity="tanh",
                 wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0,
                 name="None"):
        '''
        input_size, hidden_size: stored as _input_size and _hidden_size.
        update_nonlinearity: stored unchanged (default "tanh").
        wRank, uRank: stored unchanged (default None).
        wSparsity, uSparsity: stored unchanged (default 1.0).
        name: stored as _name. Note that the default is the string "None",
            not the None object.
        '''
        super(RNNCell, self).__init__()
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._update_nonlinearity = update_nonlinearity
        # Two counts; the size/sparsity helpers in this file treat entry 0
        # as the number of W matrices and entry 1 as the number of U matrices.
        self._num_weight_matrices = [1, 1]
        self._wRank = wRank
        self._uRank = uRank
        self._wSparsity = wSparsity
        self._uSparsity = uSparsity
        # Previously `name` was accepted but dropped, leaving `_name` unset
        # unless a subclass assigned it on its own.
        self._name = name
        # Starts empty; no snapshot of the weight matrices has been taken yet.
        self.oldmats = []
class FastGRNNCell(nn.Module):
@property
def state_size(self):
    '''Read-only view of _hidden_size.'''
    return self._hidden_size


@property
def input_size(self):
    '''Read-only view of _input_size.'''
    return self._input_size


@property
def output_size(self):
    '''Read-only view of _hidden_size (same value as state_size).'''
    return self._hidden_size


@property
def update_nonlinearity(self):
    '''Read-only view of _update_nonlinearity.'''
    return self._update_nonlinearity


@property
def wRank(self):
    '''Read-only view of _wRank.'''
    return self._wRank


@property
def uRank(self):
    '''Read-only view of _uRank.'''
    return self._uRank


@property
def num_weight_matrices(self):
    '''Read-only view of _num_weight_matrices (the list itself, not a copy).'''
    return self._num_weight_matrices


@property
def name(self):
    '''Read-only view of _name.'''
    return self._name
def getVars(self):
    '''
    Placeholder that always raises NotImplementedError.

    The other helpers in this file index the returned sequence, so an
    implementation is expected to return an indexable collection.
    '''
    raise NotImplementedError()
def getModelSize(self):
    '''
    Function to get aimed model size.

    Walks getVars() in three consecutive index ranges and accumulates
    utils.countNNZ over them:
      * the first _num_weight_matrices[0] entries, with _wSparsity;
      * the next _num_weight_matrices[1] entries, with _uSparsity;
      * every remaining entry, with False passed in place of a sparsity.

    Returns the accumulated count multiplied by 4 (presumably the size of
    one stored value in bytes - confirm).
    '''
    variables = self.getVars()
    w_stop = self._num_weight_matrices[0]
    u_stop = w_stop + self._num_weight_matrices[1]

    # (first index, end index, second argument for countNNZ) per range
    groups = (
        (0, w_stop, self._wSparsity),
        (w_stop, u_stop, self._uSparsity),
        (u_stop, len(variables), False),
    )

    nnz = 2  # For Zeta and Nu
    for first, stop, sparsity in groups:
        for idx in range(first, stop):
            nnz += utils.countNNZ(variables[idx], sparsity)
    return nnz * 4
def copy_previous_UW(self):
    '''
    Snapshot the current weight matrices into self.oldmats.

    The first _num_weight_matrices[0] + _num_weight_matrices[1] entries of
    getVars() are copied. Each snapshot is a detached float32 tensor with
    its own storage, placed on the same device as the matrix it was taken
    from. Any previous content of self.oldmats is replaced, so the list
    always ends up with exactly one entry per copied matrix.
    '''
    mats = self.getVars()
    num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
    # detach().clone() copies on the source device. The earlier code built a
    # CPU copy and then called Tensor.to(device) without keeping the result;
    # Tensor.to is not in-place, so the snapshot never left the CPU.
    snapshots = [mats[i].detach().clone().float() for i in range(num_mats)]
    # Slice assignment keeps the same list object while dropping stale
    # entries (a list of the wrong length used to be extended, not rebuilt).
    self.oldmats[:] = snapshots
def sparsify(self):
    '''
    Apply utils.hardThreshold to the weight matrices, then refresh the
    snapshot via copy_previous_UW().

    The first _num_weight_matrices[0] entries of getVars() are thresholded
    with _wSparsity and the next _num_weight_matrices[1] entries with
    _uSparsity.

    NOTE(review): the thresholded value is only stored back into the list
    returned by getVars(); this relies on utils.hardThreshold updating the
    underlying matrix in place - confirm.
    '''
    variables = self.getVars()
    w_stop = self._num_weight_matrices[0]
    u_stop = w_stop + self._num_weight_matrices[1]

    # (first index, end index, sparsity) for the W range and the U range
    groups = (
        (0, w_stop, self._wSparsity),
        (w_stop, u_stop, self._uSparsity),
    )
    for first, stop, sparsity in groups:
        for idx in range(first, stop):
            variables[idx] = utils.hardThreshold(variables[idx], sparsity)

    self.copy_previous_UW()
def sparsifyWithSupport(self):
    '''
    Apply utils.supportBasedThreshold to each weight matrix, pairing entry
    i of getVars() with the snapshot self.oldmats[i].

    Covers the first _num_weight_matrices[0] + _num_weight_matrices[1]
    entries.

    NOTE(review): the result is only stored back into the list returned by
    getVars(); this relies on utils.supportBasedThreshold updating the
    underlying matrix in place - confirm.
    '''
    variables = self.getVars()
    w_count = self._num_weight_matrices[0]
    u_count = self._num_weight_matrices[1]
    for idx in range(0, w_count + u_count):
        variables[idx] = utils.supportBasedThreshold(
            variables[idx], self.oldmats[idx])
class FastGRNNCell(RNNCell):
'''
FastGRNN Cell with Both Full Rank and Low Rank Formulations
Has multiple activation functions for the gates
@ -171,17 +265,9 @@ class FastGRNNCell(nn.Module):
update_nonlinearity="tanh", wRank=None, uRank=None,
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
name="FastGRNN"):
super(FastGRNNCell, self).__init__()
self._input_size = input_size
self._hidden_size = hidden_size
super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity,
wRank, uRank, wSparsity, uSparsity)
self._gate_nonlinearity = gate_nonlinearity
self._update_nonlinearity = update_nonlinearity
self._num_weight_matrices = [1, 1]
self._wRank = wRank
self._uRank = uRank
self._wSparsity = wSparsity
self._uSparsity = uSparsity
self._zetaInit = zetaInit
self._nuInit = nuInit
if wRank is not None:
@ -207,46 +293,12 @@ class FastGRNNCell(nn.Module):
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
self.oldmats = []
self.copy_previous_UW()
@property
def state_size(self):
    '''Read-only view of _hidden_size.'''
    return self._hidden_size


@property
def input_size(self):
    '''Read-only view of _input_size.'''
    return self._input_size


@property
def output_size(self):
    '''Read-only view of _hidden_size (same value as state_size).'''
    return self._hidden_size


@property
def gate_nonlinearity(self):
    '''Read-only view of _gate_nonlinearity.'''
    return self._gate_nonlinearity


@property
def update_nonlinearity(self):
    '''Read-only view of _update_nonlinearity.'''
    return self._update_nonlinearity


@property
def wRank(self):
    '''Read-only view of _wRank.'''
    return self._wRank


@property
def uRank(self):
    '''Read-only view of _uRank.'''
    return self._uRank


@property
def num_weight_matrices(self):
    '''Read-only view of _num_weight_matrices (the list itself, not a copy).'''
    return self._num_weight_matrices


@property
def name(self):
    '''Read-only view of _name.'''
    return self._name


@property
def cellType(self):
    '''Constant identifier of this cell type: always "FastGRNN".'''
    return "FastGRNN"
@ -290,112 +342,6 @@ class FastGRNNCell(nn.Module):
Vars.extend([self.zeta, self.nu])
return Vars
def getModelSize(self):
    '''
    Function to get aimed model size.

    Sums utils.countNNZ over getVars(): entries before index
    _num_weight_matrices[0] use _wSparsity, the following
    _num_weight_matrices[1] entries use _uSparsity, and all remaining
    entries are counted with False in place of a sparsity value.

    Returns that sum (seeded with 2) multiplied by 4 - presumably the
    per-value storage size in bytes; confirm.
    '''
    variables = self.getVars()
    w_stop = self._num_weight_matrices[0]
    u_stop = w_stop + self._num_weight_matrices[1]

    def count_range(first, stop, sparsity):
        # Non-zero count contributed by variables[first:stop], visited by index
        subtotal = 0
        for idx in range(first, stop):
            subtotal += utils.countNNZ(variables[idx], sparsity)
        return subtotal

    nnz = 2  # For Zeta and Nu
    nnz += count_range(0, w_stop, self._wSparsity)
    nnz += count_range(w_stop, u_stop, self._uSparsity)
    nnz += count_range(u_stop, len(variables), False)
    return nnz * 4
#totalnnz += utils.countNNZ(self.bias_gate, False)
#totalnnz += utils.countNNZ(self.bias_update, False)
#if self._wRank is None:
# totalnnz += utils.countNNZ(self.W, self._wSparsity)
#else:
# totalnnz += utils.countNNZ(self.W1, self._wSparsity)
# totalnnz += utils.countNNZ(self.W2, self._wSparsity)
#if self._uRank is None:
# totalnnz += utils.countNNZ(self.U, self._uSparsity)
#else:
# totalnnz += utils.countNNZ(self.U1, self._uSparsity)
# totalnnz += utils.countNNZ(self.U2, self._uSparsity)
def copy_previous_UW(self):
    '''
    Store a copy of the current weight matrices in self.oldmats.

    Copies the first _num_weight_matrices[0] + _num_weight_matrices[1]
    entries of getVars(). Every copy is detached, float32, backed by its own
    storage and kept on the device of the matrix it came from. Whatever
    self.oldmats held before is discarded.
    '''
    mats = self.getVars()
    num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
    # Tensor.to() returns a tensor rather than moving one in place, so the
    # former "copy on CPU, then .to(device)" sequence left the copy on the
    # CPU. Cloning the detached matrix keeps it on its own device.
    snapshots = []
    for i in range(num_mats):
        snapshots.append(mats[i].detach().clone().float())
    # Replace the contents in place: same list object, no leftover entries
    # when the old list had a different length.
    self.oldmats[:] = snapshots
#if self._wRank is None:
# if self._wSparsity < 1.0:
# self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy()))
# self.W_old.to(self.W.device)
#else:
# if self._wSparsity < 1.0:
# self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy()))
# self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy()))
# self.W1_old.to(self.W1.device)
# self.W2_old.to(self.W2.device)
#if self._uRank is None:
# if self._uSparsity < 1.0:
# self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy()))
# self.U_old.to(self.U.device)
#else:
# if self._uSparsity < 1.0:
# self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy()))
# self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
# self.U1_old.to(self.U1.device)
# self.U2_old.to(self.U2.device)
def sparsify(self):
    '''
    Hard-threshold the weight matrices and then re-snapshot them.

    Entries of getVars() before index _num_weight_matrices[0] go through
    utils.hardThreshold with _wSparsity; the next _num_weight_matrices[1]
    entries use _uSparsity. copy_previous_UW() is called afterwards.

    NOTE(review): results are written only into the list returned by
    getVars(); presumably utils.hardThreshold modifies the matrix in
    place - confirm.
    '''
    variables = self.getVars()
    w_stop = self._num_weight_matrices[0]
    u_stop = w_stop + self._num_weight_matrices[1]

    def threshold_range(first, stop, sparsity):
        # Threshold variables[first:stop], visited by index
        for idx in range(first, stop):
            variables[idx] = utils.hardThreshold(variables[idx], sparsity)

    threshold_range(0, w_stop, self._wSparsity)
    threshold_range(w_stop, u_stop, self._uSparsity)
    self.copy_previous_UW()
#if self._wRank is None:
# self.W.data = utils.hardThreshold(self.W, self._wSparsity)
#else:
# self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
# self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
#if self._uRank is None:
# self.U.data = utils.hardThreshold(self.U, self._uSparsity)
#else:
# self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
# self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
#self.copy_previous_UW()
def sparsifyWithSupport(self):
    '''
    Run utils.supportBasedThreshold on each weight matrix together with
    its snapshot in self.oldmats (matched by index).

    The first _num_weight_matrices[0] + _num_weight_matrices[1] entries of
    getVars() are processed.

    NOTE(review): results are written only into the list returned by
    getVars(); presumably utils.supportBasedThreshold modifies the matrix
    in place - confirm.
    '''
    variables = self.getVars()
    total = self._num_weight_matrices[0] + self._num_weight_matrices[1]
    idx = 0
    while idx < total:
        variables[idx] = utils.supportBasedThreshold(
            variables[idx], self.oldmats[idx])
        idx += 1
#if self._wRank is None:
# self.W.data = utils.supportBasedThreshold(self.W, self.W_old)
#else:
# self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old)
# self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old)
#if self._uRank is None:
# self.U.data = utils.supportBasedThreshold(self.U, self.U_old)
#else:
# self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old)
# self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old)
#self.copy_previous_UW()
class FastRNNCell(nn.Module):
'''