ensure model is in device after sparsifying

This commit is contained in:
Harsha Vardhan Simhadri 2019-07-30 19:38:36 +05:30
Родитель cd2a2fdb8e
Коммит c1a3eaeb19
2 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -231,7 +231,6 @@ class FastGRNNCell(nn.Module):
self.U2_old = torch.FloatTensor(np.copy(self.U2.data.cpu().detach().numpy()))
self.U1_old.to(self.U1.device)
self.U2_old.to(self.U2.device)
def sparsify(self):
if self._wRank is None:

Просмотреть файл

@ -239,6 +239,7 @@ class KeywordSpotter(nn.Module):
self.sparsifyWithSupport()
else:
self.sparsifyWithSupport()
self.to(device) # sparsify routines might move param matrices to cpu
learning_rate = optimizer.param_groups[0]['lr']
if detail: