Mirror of https://github.com/microsoft/EdgeML.git
cleaned up FastGRNN code
This commit is contained in:
Parent a35778fe30
Commit bd7732aedd
@@ -201,14 +201,16 @@ class FastGRNNCell(nn.Module):
         else:
             self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
             self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
 
-        self.copy_previous_state()
-
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
+        self.oldmats = []
+        self.copy_previous_UW()
+
+
     @property
     def state_size(self):
         return self._hidden_size
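Reviewer note: the U1/U2 context lines above are the low-rank factors of the recurrent matrix U, which is why replacing per-name W_old/U_old snapshots with a generic oldmats list (one entry per factor) is cleaner. As a rough illustration of why the factorization exists, here is a minimal parameter-count sketch (not part of the diff):

    # Sketch: parameter count of the recurrent matrix U in FastGRNNCell.
    # Full-rank U is [hidden_size, hidden_size]; with uRank set, it is
    # factored into U1 [uRank, hidden_size] and U2 [hidden_size, uRank].
    def recurrent_params(hidden_size, uRank=None):
        if uRank is None:
            return hidden_size * hidden_size
        return 2 * uRank * hidden_size

    print(recurrent_params(100))            # 10000 parameters, full rank
    print(recurrent_params(100, uRank=10))  # 2000 parameters, low rank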
@@ -263,7 +265,6 @@ class FastGRNNCell(nn.Module):
                 torch.matmul(state, torch.transpose(self.U1, 0, 1)), torch.transpose(self.U2, 0, 1))
-
         pre_comp = wComp + uComp
 
         z = gen_nonlinearity(pre_comp + self.bias_gate,
                              self._gate_nonlinearity)
         c = gen_nonlinearity(pre_comp + self.bias_update,
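For context, pre_comp is the shared affine term W x_t + U h_{t-1} that feeds both the gate z and the candidate c. Per the FastGRNN paper, the state update that follows (not visible in this hunk) combines them as h_t = (zeta * (1 - z) + nu) * c + z * h_{t-1}. A minimal sketch, assuming sigmoid and tanh nonlinearities:

    import torch

    # Sketch of the FastGRNN state update; zeta and nu are trainable scalars.
    def fastgrnn_step(pre_comp, bias_gate, bias_update, state, zeta, nu):
        z = torch.sigmoid(pre_comp + bias_gate)   # gate
        c = torch.tanh(pre_comp + bias_update)    # candidate state
        return (zeta * (1.0 - z) + nu) * c + z * state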
@@ -293,22 +294,18 @@ class FastGRNNCell(nn.Module):
         '''
         Function to get aimed model size
         '''
-        totalnnz = 2 # For Zeta and Nu
-
         mats = self.getVars()
-
         endW = self._num_weight_matrices[0]
+        endU = endW + self._num_weight_matrices[1]
 
+        totalnnz = 2 # For Zeta and Nu
         for i in range(0, endW):
             totalnnz += utils.countNNZ(mats[i], self._wSparsity)
-
-        endU = endW + self._num_weight_matrices[1]
-
         for i in range(endW, endU):
             totalnnz += utils.countNNZ(mats[i], self._uSparsity)
-        for i in range(endU, mats.len()):
+        for i in range(endU, len(mats)):
             totalnnz += utils.countNNZ(mats[i], False)
-        return totalnnz
+        return totalnnz * 4
 
         #totalnnz += utils.countNNZ(self.bias_gate, False)
         #totalnnz += utils.countNNZ(self.bias_update, False)
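Two substantive changes hide in this hunk besides the reordering: mats.len() was not valid Python (lists have no .len() method; the builtin len(mats) is correct), and the return value now reports bytes rather than a raw nonzero count. The * 4 factor presumably assumes 4-byte float32 storage per surviving weight; the arithmetic under that assumption:

    # Hypothetical counts to illustrate getModelSize's new return value.
    nnz_W, nnz_U = 1200, 800          # nonzeros kept in the W and U factors
    totalnnz = 2 + nnz_W + nnz_U      # the 2 covers the scalars zeta and nu
    model_size_bytes = totalnnz * 4   # 4 bytes per float32 value -> 8008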
@@ -325,56 +322,79 @@ class FastGRNNCell(nn.Module):
         # totalnnz += utils.countNNZ(self.U2, self._uSparsity)
 
 
-    def copy_previous_state(self):
-        if self._wRank is None:
-            if self._wSparsity < 1.0:
-                self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy()))
-                self.W_old.to(self.W.device)
-        else:
-            if self._wSparsity < 1.0:
-                self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy()))
-                self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy()))
-                self.W1_old.to(self.W1.device)
-                self.W2_old.to(self.W2.device)
+    def copy_previous_UW(self):
+        mats = self.getVars()
+        num_mats = self._num_weight_matrices[0] + self._num_weight_matrices[1]
+        if len(self.oldmats) != num_mats:
+            for i in range(num_mats):
+                self.oldmats.append(torch.FloatTensor())
+        for i in range(num_mats):
+            self.oldmats[i] = torch.FloatTensor(np.copy(mats[i].data.cpu().detach().numpy()))
+            self.oldmats[i].to(mats[i].device)
 
-        if self._uRank is None:
-            if self._uSparsity < 1.0:
-                self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy()))
-                self.U_old.to(self.U.device)
-        else:
-            if self._uSparsity < 1.0:
-                self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy()))
-                self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
-                self.U1_old.to(self.U1.device)
-                self.U2_old.to(self.U2.device)
+        #if self._wRank is None:
+        # if self._wSparsity < 1.0:
+        # self.W_old = torch.FloatTensor(np.copy(self.W.data.cpu().detach().numpy()))
+        # self.W_old.to(self.W.device)
+        #else:
+        # if self._wSparsity < 1.0:
+        # self.W1_old = torch.FloatTensor(np.copy(self.W1.data.cpu().detach().numpy()))
+        # self.W2_old = torch.FloatTensor(np.copy(self.W2.data.cpu().detach().numpy()))
+        # self.W1_old.to(self.W1.device)
+        # self.W2_old.to(self.W2.device)
+
+        #if self._uRank is None:
+        # if self._uSparsity < 1.0:
+        # self.U_old = torch.FloatTensor(np.copy(self.U.data.cpu().detach().numpy()))
+        # self.U_old.to(self.U.device)
+        #else:
+        # if self._uSparsity < 1.0:
+        # self.U1_old = torch.FloatTensor(np.copy(self.U1.data.cpu().detach().numpy()))
+        # self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
+        # self.U1_old.to(self.U1.device)
+        # self.U2_old.to(self.U2.device)
 
     def sparsify(self):
-        if self._wRank is None:
-            self.W.data = utils.hardThreshold(self.W, self._wSparsity)
-        else:
-            self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
-            self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
+        mats = self.getVars()
+        endW = self._num_weight_matrices[0]
+        endU = endW + self._num_weight_matrices[1]
+        for i in range(0, endW):
+            mats[i] = utils.hardThreshold(mats[i], self._wSparsity)
+        for i in range(endW, endU):
+            mats[i] = utils.hardThreshold(mats[i], self._uSparsity)
+        self.copy_previous_UW()
 
-        if self._uRank is None:
-            self.U.data = utils.hardThreshold(self.U, self._uSparsity)
-        else:
-            self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
-            self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
-        self.copy_previous_state()
+        #if self._wRank is None:
+        # self.W.data = utils.hardThreshold(self.W, self._wSparsity)
+        #else:
+        # self.W1.data = utils.hardThreshold(self.W1, self._wSparsity)
+        # self.W2.data = utils.hardThreshold(self.W2, self._wSparsity)
+
+        #if self._uRank is None:
+        # self.U.data = utils.hardThreshold(self.U, self._uSparsity)
+        #else:
+        # self.U1.data = utils.hardThreshold(self.U1, self._uSparsity)
+        # self.U2.data = utils.hardThreshold(self.U2, self._uSparsity)
+        #self.copy_previous_UW()
 
     def sparsifyWithSupport(self):
-        if self._wRank is None:
-            self.W.data = utils.supportBasedThreshold(self.W, self.W_old)
-        else:
-            self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old)
-            self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old)
+        mats = self.getVars()
+        endU = self._num_weight_matrices[0] + self._num_weight_matrices[1]
+        for i in range(0, endU):
+            mats[i] = utils.supportBasedThreshold(mats[i], self.oldmats[i])
 
-        if self._uRank is None:
-            self.U.data = utils.supportBasedThreshold(self.U, self.U_old)
-        else:
-            self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old)
-            self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old)
-        #self.copy_previous_state()
+        #if self._wRank is None:
+        # self.W.data = utils.supportBasedThreshold(self.W, self.W_old)
+        #else:
+        # self.W1.data = utils.supportBasedThreshold(self.W1, self.W1_old)
+        # self.W2.data = utils.supportBasedThreshold(self.W2, self.W2_old)
+
+        #if self._uRank is None:
+        # self.U.data = utils.supportBasedThreshold(self.U, self.U_old)
+        #else:
+        # self.U1.data = utils.supportBasedThreshold(self.U1, self.U1_old)
+        # self.U2.data = utils.supportBasedThreshold(self.U2, self.U2_old)
+        #self.copy_previous_UW()
 
 
 class FastRNNCell(nn.Module):
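The rewritten trio implements iterative hard thresholding over the generic mats list: sparsify() projects each W and U factor onto its target sparsity and immediately snapshots the result via copy_previous_UW(), and sparsifyWithSupport() later re-projects drifted weights onto that frozen support. The utils implementations are not part of this diff, so the following is only a sketch of the intended semantics, with hypothetical bodies:

    import torch

    # Hypothetical sketches of the thresholding helpers referenced above;
    # the real versions live in EdgeML's utils module and may differ.
    def hardThreshold(mat, sparsity):
        # Keep the top `sparsity` fraction of entries by magnitude.
        w = mat.detach()
        k = int(sparsity * w.numel())
        if k >= w.numel():
            return w.clone()
        if k == 0:
            return torch.zeros_like(w)
        thresh = torch.topk(w.abs().flatten(), k).values[-1]
        return w * (w.abs() >= thresh)

    def supportBasedThreshold(mat, old_mat):
        # Zero every entry outside the support of the snapshot.
        return mat.detach() * (old_mat != 0)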
@@ -98,7 +98,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
             total_size += self.fastgrnn2.cell.getModelSize()
         if self.num_layers > 2:
             total_size += self.fastgrnn3.cell.getModelSize()
-        total_size += self.hidden_units_list[2] * self.num_classes
+        total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
         return total_size
 
     def normalize(self, mean, std):
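This one-liner fixes two inconsistencies in the model-size estimate: the classifier term now uses the hidden width of the last active layer instead of hardcoding index 2, and it is scaled by 4 so it counts bytes like the cell's getModelSize() now does. A worked example with assumed sizes:

    # Hypothetical configuration to illustrate the corrected term.
    hidden_units_list = [128, 96, 64]
    num_layers, num_classes = 2, 12
    old_term = hidden_units_list[2] * num_classes                   # 768 (wrong layer, no byte scaling)
    new_term = 4 * hidden_units_list[num_layers - 1] * num_classes  # 4608 bytes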
@@ -231,7 +231,7 @@ class KeywordSpotter(nn.Module):
                 optimizer.step()
 
                 if sparsify:
-                    if epoch > num_epochs/3:
+                    if epoch >= num_epochs/3:
                         if epoch < (2*num_epochs)/3:
                             if i_batch % trim_level == 0:
                                 self.sparsify()
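The >= fix makes the pruning window open exactly at the one-third mark of training. Only the middle phase is visible in this hunk; the overall schedule appears to be the usual three-phase one (dense warm-up, periodic hard thresholding, then fine-tuning on a frozen support), sketched below with the final-phase sparsifyWithSupport() call assumed rather than shown:

    # Sketch (assumption) of the three-phase sparse training schedule
    # implied by the epoch checks above.
    def maybe_sparsify(model, epoch, num_epochs, i_batch, trim_level):
        if epoch < num_epochs / 3:
            pass                            # phase 1: train dense
        elif epoch < (2 * num_epochs) / 3:
            if i_batch % trim_level == 0:   # phase 2: periodic hard threshold
                model.sparsify()
        else:
            model.sparsifyWithSupport()     # phase 3: freeze the support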