moving params appropriately for gpu

This commit is contained in:
Harsha Vardhan Simhadri 2019-08-16 19:31:28 +05:30
Родитель bf71343ba8
Коммит 8f3a558813
3 изменённых файлов: 22 добавлений и 8 удалений

Просмотреть файл

@ -126,7 +126,7 @@ class RNNCell(nn.Module):
def getVars(self):
raise NotImplementedError()
def getModelSize(self):
def get_model_size(self):
'''
Function to get aimed model size
'''
@ -136,11 +136,17 @@ class RNNCell(nn.Module):
totalnnz = 2 # For Zeta and Nu
for i in range(0, endW):
device = mats[i].device
totalnnz += utils.countNNZ(mats[i].cpu(), self._wSparsity)
mats[i].to(device)
for i in range(endW, endU):
device = mats[i].device
totalnnz += utils.countNNZ(mats[i].cpu(), self._uSparsity)
mats[i].to(device)
for i in range(endU, len(mats)):
device = mats[i].device
totalnnz += utils.countNNZ(mats[i].cpu(), False)
mats[i].to(device)
return totalnnz * 4
def copy_previous_UW(self):

Просмотреть файл

@ -91,10 +91,10 @@ def get_model_class(inheritance_class=nn.Module):
for rnn in self.rnn_list:
rnn.cell.sparsifyWithSupport()
def getModelSize(self):
def get_model_size(self):
total_size = 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
for rnn in self.rnn_list:
total_size += rnn.cell.getModelSize()
total_size += rnn.cell.get_model_size()
return total_size
def normalize(self, mean, std):
@ -104,17 +104,22 @@ def get_model_class(inheritance_class=nn.Module):
def name(self):
return "{} layer FastGRNN".format(self.num_layers)
def move_to(self, device):
for rnn in self.rnn_list:
rnn.to(device)
if hasattr(self, 'hidden2keyword'):
self.hidden2keyword.to(device)
def init_hidden_bag(self, hidden_bag_size, device):
self.hidden_bag_size = hidden_bag_size
self.device = device
self.hidden1_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[0]],
dtype=np.float32)).to(self.device)
dtype=np.float32)).to(device)
if self.num_layers >= 2:
self.hidden2_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[1]],
dtype=np.float32)).to(self.device)
dtype=np.float32)).to(device)
if self.num_layers == 3:
self.hidden3_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[2]],
dtype=np.float32)).to(self.device)
dtype=np.float32)).to(device)
def rolling_step(self):
shuffled_indices = list(range(self.hidden_bag_size))

Просмотреть файл

@ -189,6 +189,7 @@ class KeywordSpotter(nn.Module):
audio = audio.transpose(1, 0) # GRU wants seq,batch,feature
if device:
self.move_to(device)
audio = audio.to(device)
labels = labels.to(device)
@ -204,6 +205,8 @@ class KeywordSpotter(nn.Module):
else:
self.init_hidden()
self.to(device) # sparsify routines might move param matrices to cpu
# Before the backward pass, use the optimizer object to zero all of the
# gradients for the variables it will update (which are the learnable
# weights of the model). This is because by default, gradients are
@ -262,7 +265,7 @@ class KeywordSpotter(nn.Module):
end = time.time()
self.training = False
print("Trained in {:.2f} seconds".format(end - start))
print("Model size {}".format(self.getModelSize()))
print("Model size {}".format(self.get_model_size()))
return log
def evaluate(self, test_data, batch_size, device=None, outfile=None):