diff --git a/edgeml/pytorch/graph/rnn.py b/edgeml/pytorch/graph/rnn.py
index a71da761..7654b128 100644
--- a/edgeml/pytorch/graph/rnn.py
+++ b/edgeml/pytorch/graph/rnn.py
@@ -201,14 +201,16 @@ class FastGRNNCell(nn.Module):
         else:
             self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
             self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
-
-        self.copy_previous_state()
-
+
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
+        self.oldmats = []
+        self.copy_previous_UW()
+
+
     @property
     def state_size(self):
         return self._hidden_size
@@ -263,7 +265,6 @@ class FastGRNNCell(nn.Module):
                              torch.matmul(state, torch.transpose(self.U1, 0, 1)),
                              torch.transpose(self.U2, 0, 1))
         pre_comp = wComp + uComp
-
         z = gen_nonlinearity(pre_comp + self.bias_gate,
                              self._gate_nonlinearity)
         c = gen_nonlinearity(pre_comp + self.bias_update,
@@ -293,22 +294,18 @@ class FastGRNNCell(nn.Module):
         '''
        Function to get aimed model size
         '''
-        totalnnz = 2 # For Zeta and Nu
-
         mats = self.getVars()
-
         endW = self._num_weight_matrices[0]
+        endU = endW + self._num_weight_matrices[1]
+
+        totalnnz = 2 # For Zeta and Nu
         for i in range(0, endW):
             totalnnz += utils.countNNZ(mats[i], self._wSparsity)
-
-        endU = endW + self._num_weight_matrices[1]
         for i in range(endW, endU):
             totalnnz += utils.countNNZ(mats[i], self._uSparsity)
-
-        for i in range(endU, mats.len()):
+        for i in range(endU, len(mats)):
             totalnnz += utils.countNNZ(mats[i], False)
-
-        return totalnnz
+        return totalnnz * 4
 
         #totalnnz += utils.countNNZ(self.bias_gate, False)
         #totalnnz += utils.countNNZ(self.bias_update, False)
@@ -325,56 +322,79 @@ class FastGRNNCell(nn.Module):
         # totalnnz += utils.countNNZ(self.U2, self._uSparsity)
 
-    def copy_previous_state(self):
-        if self._wRank is None:
-            if self._wSparsity < 1.0:
-                self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy()))
-                self.W_old.to(self.W.device)
-        else:
-            if self._wSparsity < 1.0:
-                self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy()))
-                self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy()))
-                self.W1_old.to(self.W1.device)
-                self.W2_old.to(self.W2.device)
+    def copy_previous_UW(self):
+        mats = self.getVars()
+        num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = torch.FloatTensor(np.copy(mats[i].data.cpu().detach().numpy()))
+            self.oldmats[i].to(mats[i].device)
 
-        if self._uRank is None:
-            if self._uSparsity < 1.0:
-                self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy()))
-                self.U_old.to(self.U.device)
-        else:
-            if self._uSparsity < 1.0:
-                self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy()))
-                self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
-                self.U1_old.to(self.U1.device)
-                self.U2_old.to(self.U2.device)
+        #if self._wRank is None:
+        #    if self._wSparsity < 1.0:
+        #        self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy()))
+        #        self.W_old.to(self.W.device)
+        #else:
+        #    if self._wSparsity < 1.0:
+        #        self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy()))
+        #        self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy()))
+        #        self.W1_old.to(self.W1.device)
+        #        self.W2_old.to(self.W2.device)
+
+        #if self._uRank is None:
+        #    if self._uSparsity < 1.0:
+        #        self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy()))
+        #        self.U_old.to(self.U.device)
+        #else:
+        #    if self._uSparsity < 1.0:
+        #        self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy()))
+        #        self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
+        #        self.U1_old.to(self.U1.device)
+        #        self.U2_old.to(self.U2.device)
 
 
     def sparsify(self):
-        if self._wRank is None:
-            self.W.data = utils.hardThreshold(self.W, self._wSparsity)
-        else:
-            self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
-            self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
+        mats = self.getVars()
+        endW = self._num_weight_matrices[0]
+        endU = endW + self._num_weight_matrices[1]
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()
 
-        if self._uRank is None:
-            self.U.data = utils.hardThreshold(self.U, self._uSparsity)
-        else:
-            self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
-            self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
-        self.copy_previous_state()
+        #if self._wRank is None:
+        #    self.W.data = utils.hardThreshold(self.W, self._wSparsity)
+        #else:
+        #    self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
+        #    self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
+
+        #if self._uRank is None:
+        #    self.U.data = utils.hardThreshold(self.U, self._uSparsity)
+        #else:
+        #    self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
+        #    self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
+        #self.copy_previous_UW()
 
     def sparsifyWithSupport(self):
-        if self._wRank is None:
-            self.W.data = utils.supportBasedThreshold(self.W, self.W_old)
-        else:
-            self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old)
-            self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old)
+        mats = self.getVars()
+        endU = self._num_weight_matrices[0] + self._num_weight_matrices[1]
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
 
-        if self._uRank is None:
-            self.U.data = utils.supportBasedThreshold(self.U, self.U_old)
-        else:
-            self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old)
-            self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old)
-        #self.copy_previous_state()
+        #if self._wRank is None:
+        #    self.W.data = utils.supportBasedThreshold(self.W, self.W_old)
+        #else:
+        #    self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old)
+        #    self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old)
+
+        #if self._uRank is None:
+        #    self.U.data = utils.supportBasedThreshold(self.U, self.U_old)
+        #else:
+        #    self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old)
+        #    self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old)
+        #self.copy_previous_UW()
 
 
 class FastRNNCell(nn.Module):
diff --git a/edgeml/pytorch/trainer/fastmodel.py b/edgeml/pytorch/trainer/fastmodel.py
index 44fa87a7..58e190b6 100644
--- a/edgeml/pytorch/trainer/fastmodel.py
+++ b/edgeml/pytorch/trainer/fastmodel.py
@@ -98,7 +98,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
                 total_size += self.fastgrnn2.cell.getModelSize()
             if self.num_layers > 2:
                 total_size += self.fastgrnn3.cell.getModelSize()
-            total_size += self.hidden_units_list[2] * self.num_classes
+            total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
             return total_size
 
         def normalize(self, mean, std):
diff --git a/examples/pytorch/FastCells/train_classifier.py b/examples/pytorch/FastCells/train_classifier.py
index eecf6490..c3d84923 100644
--- a/examples/pytorch/FastCells/train_classifier.py
+++ b/examples/pytorch/FastCells/train_classifier.py
@@ -231,7 +231,7 @@ class KeywordSpotter(nn.Module):
                 optimizer.step()
 
                 if sparsify:
-                    if epoch > num_epochs/3:
+                    if epoch >= num_epochs/3:
                         if epoch < (2*num_epochs)/3:
                             if i_batch % trim_level == 0:
                                 self.sparsify()
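
Note on the size accounting: getModelSize() now returns totalnnz * 4, and fastmodel.py applies the same 4 * factor to the final linear layer, so both report a size in bytes under the assumption of 4-byte float32 weights rather than a raw nonzero count. For example, a cell with 25,000 surviving nonzeros is reported as 100,000 bytes. The fastmodel.py change also makes the classifier-layer term use hidden_units_list[self.num_layers-1], the width of whichever layer is actually last, instead of the hard-coded index 2, which was only correct for three-layer models.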
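
For reviewers unfamiliar with the two threshold helpers that the refactored sparsify()/sparsifyWithSupport() loops call: below is a minimal, self-contained sketch of their semantics, assuming utils.hardThreshold(mat, s) keeps roughly the fraction s of entries with the largest magnitudes and utils.supportBasedThreshold(mat, old) zeroes entries outside the nonzero support of old. This illustrates the intended behavior only; it is not the edgeml.pytorch.utils implementation.

import torch

def hard_threshold(mat: torch.Tensor, s: float) -> torch.Tensor:
    # Keep the s-fraction of entries with the largest magnitude; zero the rest.
    if s >= 1.0:
        return mat.clone()          # s = 1.0 means fully dense, nothing to trim
    k = int(s * mat.numel())        # number of nonzeros to retain
    if k == 0:
        return torch.zeros_like(mat)
    # Threshold at the k-th largest magnitude (ties may keep a few extra entries).
    thresh = torch.topk(mat.abs().flatten(), k).values[-1]
    return mat * (mat.abs() >= thresh).to(mat.dtype)

def support_based_threshold(mat: torch.Tensor, old: torch.Tensor) -> torch.Tensor:
    # Project mat onto the support of old: entries that were zero stay zero.
    return mat * (old != 0).to(mat.dtype)

This is why copy_previous_UW() snapshots the post-trim W and U matrices into oldmats: after later optimizer steps repopulate zeroed entries, supportBasedThreshold can re-zero everything outside the frozen support.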
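
The > to >= fix in train_classifier.py makes the trimming phase begin exactly at epoch num_epochs/3; with the old strict comparison, an integer epoch counter skipped the boundary epoch whenever num_epochs was divisible by 3, shortening the trimming phase by one epoch. The surrounding schedule, condensed below, trains dense for the first third, hard-thresholds every trim_level batches during the middle third, and holds the support fixed for the final third. In this sketch, train_sparse, the (x, y) batch unpacking, and the cross-entropy call are hypothetical scaffolding, and the final-third sparsifyWithSupport() call (a branch not shown in this hunk) is an assumption consistent with the methods this diff touches.

import torch.nn.functional as F

def train_sparse(model, loader, optimizer, num_epochs, trim_level):
    # Phase 1, epochs [0, n/3):    dense training.
    # Phase 2, epochs [n/3, 2n/3): periodic hard-threshold trims.
    # Phase 3, epochs [2n/3, n):   fine-tune on the frozen support.
    for epoch in range(num_epochs):
        for i_batch, (x, y) in enumerate(loader):
            optimizer.zero_grad()
            loss = F.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()
            if epoch >= num_epochs / 3:
                if epoch < (2 * num_epochs) / 3:
                    if i_batch % trim_level == 0:
                        model.sparsify()
                else:
                    model.sparsifyWithSupport()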