From 27ef35ea64566ac0793d9e9343de62f680c67fcd Mon Sep 17 00:00:00 2001 From: Harsha Vardhan Simhadri Date: Wed, 31 Jul 2019 12:23:05 +0530 Subject: [PATCH] cleaned up FastGRNN code --- edgeml/pytorch/graph/rnn.py | 248 ++++++++++++++---------------------- 1 file changed, 97 insertions(+), 151 deletions(-) diff --git a/edgeml/pytorch/graph/rnn.py b/edgeml/pytorch/graph/rnn.py index 7654b128..75eaf9b2 100644 --- a/edgeml/pytorch/graph/rnn.py +++ b/edgeml/pytorch/graph/rnn.py @@ -130,8 +130,102 @@ class BaseRNN(nn.Module): hiddenStates[i, :, :] = hiddenState return hiddenStates +class RNNCell(nn.Module): + def __init__(self, input_size, hidden_size, update_nonlinearity="tanh", + wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0, + name="None"): + super(RNNCell, self).__init__() + self._input_size = input_size + self._hidden_size = hidden_size + self._update_nonlinearity = update_nonlinearity + self._num_weight_matrices = [1, 1] + self._wRank = wRank + self._uRank = uRank + self._wSparsity = wSparsity + self._uSparsity = uSparsity + self.oldmats = [] -class FastGRNNCell(nn.Module): + + @property + def state_size(self): + return self._hidden_size + + @property + def input_size(self): + return self._input_size + + @property + def output_size(self): + return self._hidden_size + + @property + def update_nonlinearity(self): + return self._update_nonlinearity + + @property + def wRank(self): + return self._wRank + + @property + def uRank(self): + return self._uRank + + @property + def num_weight_matrices(self): + return self._num_weight_matrices + + @property + def name(self): + return self._name + + def getVars(self): + raise NotImplementedError() + + def getModelSize(self): + ''' + Function to get aimed model size + ''' + mats = self.getVars() + endW = self._num_weight_matrices[0] + endU = endW + self._num_weight_matrices[1] + + totalnnz = 2 # For Zeta and Nu + for i in range(0, endW): + totalnnz += utils.countNNZ(mats[i], self._wSparsity) + for i in range(endW, endU): + totalnnz += utils.countNNZ(mats[i], self._uSparsity) + for i in range(endU, len(mats)): + totalnnz += utils.countNNZ(mats[i], False) + return totalnnz * 4 + + def copy_previous_UW(self): + mats = self.getVars() + num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1] + if len(self.oldmats) != num_mats: + for i in range(num_mats): + self.oldmats.append(torch.FloatTensor()) + for i in range(num_mats): + self.oldmats[i] = torch.FloatTensor(np.copy(mats[i].data.cpu().detach().numpy())) + self.oldmats[i].to(mats[i].device) + + def sparsify(self): + mats = self.getVars() + endW = self._num_weight_matrices[0] + endU = endW + self._num_weight_matrices[1] + for i in range(0, endW): + mats[i] = utils.hardThreshold(mats[i], self._wSparsity) + for i in range(endW, endU): + mats[i] = utils.hardThreshold(mats[i], self._uSparsity) + self.copy_previous_UW() + + def sparsifyWithSupport(self): + mats = self.getVars() + endU = self._num_weight_matrices[0] + self._num_weight_matrices[1] + for i in range(0, endU): + mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i]) + + +class FastGRNNCell(RNNCell): ''' FastGRNN Cell with Both Full Rank and Low Rank Formulations Has multiple activation functions for the gates @@ -171,17 +265,9 @@ class FastGRNNCell(nn.Module): update_nonlinearity="tanh", wRank=None, uRank=None, wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0, name="FastGRNN"): - super(FastGRNNCell, self).__init__() - - self._input_size = input_size - self._hidden_size = hidden_size + super(FastGRNNCell, self).__init__(input_size, hidden_size, update_nonlinearity, + wRank, uRank, wSparsity, uSparsity) self._gate_nonlinearity = gate_nonlinearity - self._update_nonlinearity = update_nonlinearity - self._num_weight_matrices = [1, 1] - self._wRank = wRank - self._uRank = uRank - self._wSparsity = wSparsity - self._uSparsity = uSparsity self._zetaInit = zetaInit self._nuInit = nuInit if wRank is not None: @@ -207,46 +293,12 @@ class FastGRNNCell(nn.Module): self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1])) self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1])) - self.oldmats = [] self.copy_previous_UW() - - @property - def state_size(self): - return self._hidden_size - - @property - def input_size(self): - return self._input_size - - @property - def output_size(self): - return self._hidden_size - @property def gate_nonlinearity(self): return self._gate_nonlinearity - @property - def update_nonlinearity(self): - return self._update_nonlinearity - - @property - def wRank(self): - return self._wRank - - @property - def uRank(self): - return self._uRank - - @property - def num_weight_matrices(self): - return self._num_weight_matrices - - @property - def name(self): - return self._name - @property def cellType(self): return "FastGRNN" @@ -290,112 +342,6 @@ class FastGRNNCell(nn.Module): Vars.extend([self.zeta, self.nu]) return Vars - def getModelSize(self): - ''' - Function to get aimed model size - ''' - mats = self.getVars() - endW = self._num_weight_matrices[0] - endU = endW + self._num_weight_matrices[1] - - totalnnz = 2 # For Zeta and Nu - for i in range(0, endW): - totalnnz += utils.countNNZ(mats[i], self._wSparsity) - for i in range(endW, endU): - totalnnz += utils.countNNZ(mats[i], self._uSparsity) - for i in range(endU, len(mats)): - totalnnz += utils.countNNZ(mats[i], False) - return totalnnz * 4 - - #totalnnz += utils.countNNZ(self.bias_gate, False) - #totalnnz += utils.countNNZ(self.bias_update, False) - #if self._wRank is None: - # totalnnz += utils.countNNZ(self.W, self._wSparsity) - #else: - # totalnnz += utils.countNNZ(self.W1, self._wSparsity) - # totalnnz += utils.countNNZ(self.W2, self._wSparsity) - - #if self._uRank is None: - # totalnnz += utils.countNNZ(self.U, self._uSparsity) - #else: - # totalnnz += utils.countNNZ(self.U1, self._uSparsity) - # totalnnz += utils.countNNZ(self.U2, self._uSparsity) - - - def copy_previous_UW(self): - mats = self.getVars() - num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1] - if len(self.oldmats) != num_mats: - for i in range(num_mats): - self.oldmats.append(torch.FloatTensor()) - for i in range(num_mats): - self.oldmats[i] = torch.FloatTensor(np.copy(mats[i].data.cpu().detach().numpy())) - self.oldmats[i].to(mats[i].device) - - #if self._wRank is None: - # if self._wSparsity < 1.0: - # self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy())) - # self.W_old.to(self.W.device) - #else: - # if self._wSparsity < 1.0: - # self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy())) - # self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy())) - # self.W1_old.to(self.W1.device) - # self.W2_old.to(self.W2.device) - - #if self._uRank is None: - # if self._uSparsity < 1.0: - # self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy())) - # self.U_old.to(self.U.device) - #else: - # if self._uSparsity < 1.0: - # self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy())) - # self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy())) - # self.U1_old.to(self.U1.device) - # self.U2_old.to(self.U2.device) - - def sparsify(self): - mats = self.getVars() - endW = self._num_weight_matrices[0] - endU = endW + self._num_weight_matrices[1] - for i in range(0, endW): - mats[i] = utils.hardThreshold(mats[i], self._wSparsity) - for i in range(endW, endU): - mats[i] = utils.hardThreshold(mats[i], self._uSparsity) - self.copy_previous_UW() - - #if self._wRank is None: - # self.W.data = utils.hardThreshold(self.W, self._wSparsity) - #else: - # self.W1.data = utils.hardThreshold(self.W1, self._wSparsity) - # self.W2.data = utils.hardThreshold(self.W2, self._wSparsity) - - #if self._uRank is None: - # self.U.data = utils.hardThreshold(self.U, self._uSparsity) - #else: - # self.U1.data = utils.hardThreshold(self.U1, self._uSparsity) - # self.U2.data = utils.hardThreshold(self.U2, self._uSparsity) - #self.copy_previous_UW() - - def sparsifyWithSupport(self): - mats = self.getVars() - endU = self._num_weight_matrices[0] + self._num_weight_matrices[1] - for i in range(0, endU): - mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i]) - - #if self._wRank is None: - # self.W.data = utils.supportBasedThreshold(self.W, self.W_old) - #else: - # self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old) - # self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old) - - #if self._uRank is None: - # self.U.data = utils.supportBasedThreshold(self.U, self.U_old) - #else: - # self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old) - # self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old) - #self.copy_previous_UW() - class FastRNNCell(nn.Module): '''