ensure model is in device after sparsifying

This commit is contained in:
Harsha Vardhan Simhadri 2019-07-30 19:38:36 +05:30
Родитель cd2a2fdb8e
Коммит c1a3eaeb19
2 изменённых файлов: 1 добавлений и 1 удалений

Просмотреть файл

@ -232,7 +232,6 @@ class FastGRNNCell(nn.Module):
self.U1_old.to(self.U1.device) self.U1_old.to(self.U1.device)
self.U2_old.to(self.U2.device) self.U2_old.to(self.U2.device)
def sparsify(self): def sparsify(self):
if self._wRank is None: if self._wRank is None:
self.W.data = utils.hardThreshold(self.W, self._wSparsity) self.W.data = utils.hardThreshold(self.W, self._wSparsity)

Просмотреть файл

@ -239,6 +239,7 @@ class KeywordSpotter(nn.Module):
self.sparsifyWithSupport() self.sparsifyWithSupport()
else: else:
self.sparsifyWithSupport() self.sparsifyWithSupport()
self.to(device) # sparsify routines might move param matrices to cpu
learning_rate = optimizer.param_groups[0]['lr'] learning_rate = optimizer.param_groups[0]['lr']
if detail: if detail: