Harsha Vardhan Simhadri 2019-08-01 19:18:43 +05:30
Parent 1329bfb282
Commit d373338570
2 changed files with 28 additions and 46 deletions

View file

@@ -8,18 +8,16 @@ import numpy as np
 import edgeml.pytorch.utils as utils
 
 
-def onnx_exportable_rnn(input, fargs, cell_name,
-                        output, hidden_size, wRank, uRank,
-                        gate_nonlinearity, update_nonlinearity):
+def onnx_exportable_rnn(input, fargs, cell, output):
     class RNNSymbolic(Function):
         @staticmethod
         def symbolic(g, *fargs):
            # NOTE: args/kwargs contain RNN parameters
-            return g.op(cell_name, *fargs,
-                        outputs=1, hidden_size_i=hidden_size,
-                        wRank_i=wRank, uRank_i=uRank,
-                        gate_nonlinearity_s=gate_nonlinearity,
-                        update_nonlinearity_s=update_nonlinearity)
+            return g.op(cell.name, *fargs,
+                        outputs=1, hidden_size_i=cell.state_size,
+                        wRank_i=cell.wRank, uRank_i=cell.uRank,
+                        gate_nonlinearity_s=cell.gate_nonlinearity,
+                        update_nonlinearity_s=cell.update_nonlinearity)
 
         @staticmethod
         def forward(ctx, *fargs):
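
Note on the refactor above: `onnx_exportable_rnn` now takes the whole cell and reads `name`, `state_size`, `wRank`, `uRank`, `gate_nonlinearity`, and `update_nonlinearity` off it inside `symbolic`, instead of threading seven arguments through every call site. A minimal sketch of a duck-typed object that would satisfy the new signature; the class name and attribute values are illustrative, not part of this commit:

    # Hypothetical stand-in: any cell exposing these attributes works,
    # since onnx_exportable_rnn only reads them inside symbolic().
    class DemoCell:
        name = "FastGRNN"              # ONNX op name passed to g.op
        state_size = 128               # emitted as hidden_size_i
        wRank = 32                     # rank of the W factorization (wRank_i)
        uRank = 32                     # rank of the U factorization (uRank_i)
        gate_nonlinearity = "sigmoid"  # emitted as gate_nonlinearity_s
        update_nonlinearity = "tanh"   # emitted as update_nonlinearity_s

    # out = onnx_exportable_rnn(input, weights, DemoCell(), output=out)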

View file

@@ -38,6 +38,8 @@ def get_model_class(inheritance_class=nn.Module):
             self.linear = linear
             self.batch_first = batch_first
             self.apply_softmax = apply_softmax
+            self.rnn_list = []
+
             if self.linear:
                 if not self.num_classes:
                     raise Exception("num_classes need to be specified if linear is True")
@@ -53,6 +55,7 @@ def get_model_class(inheritance_class=nn.Module):
                                  wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
                                  batch_first = self.batch_first)
             self.rnn2 = None
+            self.rnn_list.append(self.rnn1)
             last_output_size = self.hidden_units_list[0]
             if self.num_layers > 1:
                 self.rnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
@@ -62,6 +65,7 @@ def get_model_class(inheritance_class=nn.Module):
                                      wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
                                      batch_first = self.batch_first)
                 last_output_size = self.hidden_units_list[1]
+                self.rnn_list.append(self.rnn2)
             self.rnn3 = None
             if self.num_layers > 2:
                 self.rnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
@@ -71,6 +75,7 @@ def get_model_class(inheritance_class=nn.Module):
                                      wSparsity=self.wSparsity_list[2], uSparsity=self.uSparsity_list[2],
                                      batch_first = self.batch_first)
                 last_output_size = self.hidden_units_list[2]
+                self.rnn_list.append(self.rnn3)
 
             # The linear layer is a fully connected layer that maps from hidden state space
             # to number of expected keywords
@@ -79,26 +84,17 @@ def get_model_class(inheritance_class=nn.Module):
             self.init_hidden()
 
         def sparsify(self):
-            self.rnn1.cell.sparsify()
-            if self.num_layers > 1:
-                self.rnn2.cell.sparsify()
-            if self.num_layers > 2:
-                self.rnn3.cell.sparsify()
+            for rnn in self.rnn_list:
+                rnn.cell.sparsify()
 
         def sparsifyWithSupport(self):
-            self.rnn1.cell.sparsifyWithSupport()
-            if self.num_layers > 1:
-                self.rnn2.cell.sparsifyWithSupport()
-            if self.num_layers > 2:
-                self.rnn3.cell.sparsifyWithSupport()
+            for rnn in self.rnn_list:
+                rnn.cell.sparsifyWithSupport()
 
         def getModelSize(self):
-            total_size = self.rnn1.cell.getModelSize()
-            if self.num_layers > 1:
-                total_size += self.rnn2.cell.getModelSize()
-            if self.num_layers > 2:
-                total_size += self.rnn3.cell.getModelSize()
-            total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
+            total_size = 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
+            for rnn in self.rnn_list:
+                total_size += rnn.cell.getModelSize()
             return total_size
 
         def normalize(self, mean, std):
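
With `rnn_list` built up in `__init__`, the per-layer if-chains collapse into loops, and these methods no longer change when a layer is added. `getModelSize` also reorders its accumulation: it now seeds `total_size` with the final linear layer's cost and adds each cell inside the loop. An isolated sketch of the pattern; the function and parameter names are illustrative, and the factor of 4 assumes float32 weights as in the diff:

    # Sketch of the list-driven accounting above; works for any layer count.
    def model_size_bytes(cells, last_hidden, num_classes):
        total = 4 * last_hidden * num_classes   # classifier weights, 4 B each
        for cell in cells:                      # replaces the num_layers if-chain
            total += cell.getModelSize()
        return total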
@@ -147,42 +143,30 @@ def get_model_class(inheritance_class=nn.Module):
             if self.mean is not None:
                 input = (input - self.mean) / self.std
             rnn_out1 = self.rnn1(input, hiddenState=self.hidden1)
-            fastgrnn_output = rnn_out1
+            model_output = rnn_out1
             # we have to detach the hidden states because we may keep them longer than 1 iteration.
             self.hidden1 = rnn_out1.detach()[-1, :, :]
             if self.tracking:
                 weights = self.rnn1.getVars()
-                rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell.name,
-                                               output=rnn_out1, hidden_size=self.hidden_units_list[0],
-                                               wRank=self.wRank_list[0], uRank=self.uRank_list[0],
-                                               gate_nonlinearity=self.gate_nonlinearity,
-                                               update_nonlinearity=self.update_nonlinearity)
-                fastgrnn_output = rnn_out1
+                rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell, output=rnn_out1)
+                model_output = rnn_out1
             if self.rnn2 is not None:
                 rnn_out2 = self.rnn2(rnn_out1, hiddenState=self.hidden2)
                 self.hidden2 = rnn_out2.detach()[-1, :, :]
                 if self.tracking:
                     weights = self.rnn2.getVars()
-                    rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell.name,
-                                                   output=rnn_out2, hidden_size=self.hidden_units_list[1],
-                                                   wRank=self.wRank_list[1], uRank=self.uRank_list[1],
-                                                   gate_nonlinearity=self.gate_nonlinearity,
-                                                   update_nonlinearity=self.update_nonlinearity)
-                    fastgrnn_output = rnn_out2
+                    rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell, output=rnn_out2)
+                    model_output = rnn_out2
             if self.rnn3 is not None:
                 rnn_out3 = self.rnn3(rnn_out2, hiddenState=self.hidden3)
                 self.hidden3 = rnn_out3.detach()[-1, :, :]
                 if self.tracking:
                     weights = self.rnn3.getVars()
-                    rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell.name,
-                                                   output=rnn_out3, hidden_size=self.hidden_units_list[2],
-                                                   wRank=self.wRank_list[2], uRank=self.uRank_list[2],
-                                                   gate_nonlinearity=self.gate_nonlinearity,
-                                                   update_nonlinearity=self.update_nonlinearity)
-                    fastgrnn_output = rnn_out3
+                    rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell, output=rnn_out3)
+                    model_output = rnn_out3
             if self.linear:
-                fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :])
+                model_output = self.hidden2keyword(model_output[-1, :, :])
             if self.apply_softmax:
-                fastgrnn_output = F.log_softmax(fastgrnn_output, dim=1)
-            return fastgrnn_output
+                model_output = F.log_softmax(model_output, dim=1)
+            return model_output
         return RNNClassifierModel
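
For reference, a hedged sketch of how the simplified tracking path would typically be driven for export. `torch.onnx.export` is the standard PyTorch API; setting `tracking` directly and the dummy input shape are assumptions based only on what is visible in this diff:

    import torch

    # Assume `model` is an already-constructed RNNClassifierModel; its
    # constructor signature is not visible in this diff.
    model.tracking = True   # routes forward() through onnx_exportable_rnn
    model.eval()
    dummy = torch.randn(100, 1, 32)  # placeholder (seq_len, batch, input_dim)
    torch.onnx.export(model, dummy, "classifier.onnx")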