зеркало из https://github.com/microsoft/EdgeML.git
cleaning up fastmodel.py
This commit is contained in:
Родитель
1329bfb282
Коммит
d373338570
|
@ -8,18 +8,16 @@ import numpy as np
|
|||
|
||||
import edgeml.pytorch.utils as utils
|
||||
|
||||
def onnx_exportable_rnn(input, fargs, cell_name,
|
||||
output, hidden_size, wRank, uRank,
|
||||
gate_nonlinearity, update_nonlinearity):
|
||||
def onnx_exportable_rnn(input, fargs, cell, output):
|
||||
class RNNSymbolic(Function):
|
||||
@staticmethod
|
||||
def symbolic(g, *fargs):
|
||||
# NOTE: args/kwargs contain RNN parameters
|
||||
return g.op(cell_name, *fargs,
|
||||
outputs=1, hidden_size_i=hidden_size,
|
||||
wRank_i=wRank, uRank_i=uRank,
|
||||
gate_nonlinearity_s=gate_nonlinearity,
|
||||
update_nonlinearity_s=update_nonlinearity)
|
||||
return g.op(cell.name, *fargs,
|
||||
outputs=1, hidden_size_i=cell.state_size,
|
||||
wRank_i=cell.wRank, uRank_i=cell.uRank,
|
||||
gate_nonlinearity_s=cell.gate_nonlinearity,
|
||||
update_nonlinearity_s=cell.update_nonlinearity)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *fargs):
|
||||
|
|
|
@ -38,6 +38,8 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
self.linear = linear
|
||||
self.batch_first = batch_first
|
||||
self.apply_softmax = apply_softmax
|
||||
self.rnn_list = []
|
||||
|
||||
if self.linear:
|
||||
if not self.num_classes:
|
||||
raise Exception("num_classes need to be specified if linear is True")
|
||||
|
@ -53,6 +55,7 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
|
||||
batch_first = self.batch_first)
|
||||
self.rnn2 = None
|
||||
self.rnn_list.append(self.rnn1)
|
||||
last_output_size = self.hidden_units_list[0]
|
||||
if self.num_layers > 1:
|
||||
self.rnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
|
||||
|
@ -62,6 +65,7 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
|
||||
batch_first = self.batch_first)
|
||||
last_output_size = self.hidden_units_list[1]
|
||||
self.rnn_list.append(self.rnn2)
|
||||
self.rnn3 = None
|
||||
if self.num_layers > 2:
|
||||
self.rnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
|
||||
|
@ -71,6 +75,7 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
wSparsity=self.wSparsity_list[2], uSparsity=self.uSparsity_list[2],
|
||||
batch_first = self.batch_first)
|
||||
last_output_size = self.hidden_units_list[2]
|
||||
self.rnn_list.append(self.rnn3)
|
||||
|
||||
# The linear layer is a fully connected layer that maps from hidden state space
|
||||
# to number of expected keywords
|
||||
|
@ -79,26 +84,17 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
self.init_hidden()
|
||||
|
||||
def sparsify(self):
|
||||
self.rnn1.cell.sparsify()
|
||||
if self.num_layers > 1:
|
||||
self.rnn2.cell.sparsify()
|
||||
if self.num_layers > 2:
|
||||
self.rnn3.cell.sparsify()
|
||||
for rnn in self.rnn_list:
|
||||
rnn.cell.sparsify()
|
||||
|
||||
def sparsifyWithSupport(self):
|
||||
self.rnn1.cell.sparsifyWithSupport()
|
||||
if self.num_layers > 1:
|
||||
self.rnn2.cell.sparsifyWithSupport()
|
||||
if self.num_layers > 2:
|
||||
self.rnn3.cell.sparsifyWithSupport()
|
||||
for rnn in self.rnn_list:
|
||||
rnn.cell.sparsifyWithSupport()
|
||||
|
||||
def getModelSize(self):
|
||||
total_size = self.rnn1.cell.getModelSize()
|
||||
if self.num_layers > 1:
|
||||
total_size += self.rnn2.cell.getModelSize()
|
||||
if self.num_layers > 2:
|
||||
total_size += self.rnn3.cell.getModelSize()
|
||||
total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
|
||||
total_size = 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
|
||||
for rnn in self.rnn_list:
|
||||
total_size += rnn.cell.getModelSize()
|
||||
return total_size
|
||||
|
||||
def normalize(self, mean, std):
|
||||
|
@ -147,42 +143,30 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
if self.mean is not None:
|
||||
input = (input - self.mean) / self.std
|
||||
rnn_out1 = self.rnn1(input, hiddenState=self.hidden1)
|
||||
fastgrnn_output = rnn_out1
|
||||
model_output = rnn_out1
|
||||
# we have to detach the hidden states because we may keep them longer than 1 iteration.
|
||||
self.hidden1 = rnn_out1.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.rnn1.getVars()
|
||||
rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell.name,
|
||||
output=rnn_out1, hidden_size=self.hidden_units_list[0],
|
||||
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out1
|
||||
rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell, output=rnn_out1)
|
||||
model_output = rnn_out1
|
||||
if self.rnn2 is not None:
|
||||
rnn_out2 = self.rnn2(rnn_out1, hiddenState=self.hidden2)
|
||||
self.hidden2 = rnn_out2.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.rnn2.getVars()
|
||||
rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell.name,
|
||||
output=rnn_out2, hidden_size=self.hidden_units_list[1],
|
||||
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out2
|
||||
rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell, output=rnn_out2)
|
||||
model_output = rnn_out2
|
||||
if self.rnn3 is not None:
|
||||
rnn_out3 = self.rnn3(rnn_out2, hiddenState=self.hidden3)
|
||||
self.hidden3 = rnn_out3.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.rnn3.getVars()
|
||||
rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell.name,
|
||||
output=rnn_out3, hidden_size=self.hidden_units_list[2],
|
||||
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out3
|
||||
rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell, output=rnn_out3)
|
||||
model_output = rnn_out3
|
||||
if self.linear:
|
||||
fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :])
|
||||
model_output = self.hidden2keyword(model_output[-1, :, :])
|
||||
if self.apply_softmax:
|
||||
fastgrnn_output = F.log_softmax(fastgrnn_output, dim=1)
|
||||
return fastgrnn_output
|
||||
model_output = F.log_softmax(model_output, dim=1)
|
||||
return model_output
|
||||
return RNNClassifierModel
|
Загрузка…
Ссылка в новой задаче