зеркало из https://github.com/microsoft/EdgeML.git
trying to generlize ONNX exporter
This commit is contained in:
Родитель
92bf361480
Коммит
1329bfb282
|
@ -8,15 +8,18 @@ import numpy as np
|
|||
|
||||
import edgeml.pytorch.utils as utils
|
||||
|
||||
def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank,
|
||||
gate_nonlinearity, update_nonlinearity):
|
||||
def onnx_exportable_rnn(input, fargs, cell_name,
|
||||
output, hidden_size, wRank, uRank,
|
||||
gate_nonlinearity, update_nonlinearity):
|
||||
class RNNSymbolic(Function):
|
||||
@staticmethod
|
||||
def symbolic(g, *fargs):
|
||||
# NOTE: args/kwargs contain RNN parameters
|
||||
return g.op("FastGRNN", *fargs, outputs=1,
|
||||
hidden_size_i=hidden_size, wRank_i=wRank, uRank_i=uRank,
|
||||
gate_nonlinearity_s=gate_nonlinearity, update_nonlinearity_s=update_nonlinearity)
|
||||
return g.op(cell_name, *fargs,
|
||||
outputs=1, hidden_size_i=hidden_size,
|
||||
wRank_i=wRank, uRank_i=uRank,
|
||||
gate_nonlinearity_s=gate_nonlinearity,
|
||||
update_nonlinearity_s=update_nonlinearity)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *fargs):
|
||||
|
@ -107,10 +110,6 @@ class RNNCell(nn.Module):
|
|||
def uRank(self):
|
||||
return self._uRank
|
||||
|
||||
#@property
|
||||
#def num_weight_matrices(self):
|
||||
# return self._num_weight_matrices
|
||||
|
||||
@property
|
||||
def num_W_matrices(self):
|
||||
return self._num_W_matrices
|
||||
|
@ -119,6 +118,10 @@ class RNNCell(nn.Module):
|
|||
def num_U_matrices(self):
|
||||
return self._num_U_matrices
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, input, state):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -46,25 +46,25 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
|
||||
# The FastGRNN takes audio sequences as input, and outputs hidden states
|
||||
# with dimensionality hidden_units.
|
||||
self.fastgrnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0],
|
||||
self.rnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
|
||||
wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
|
||||
batch_first = self.batch_first)
|
||||
self.fastgrnn2 = None
|
||||
self.rnn2 = None
|
||||
last_output_size = self.hidden_units_list[0]
|
||||
if self.num_layers > 1:
|
||||
self.fastgrnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
|
||||
self.rnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
|
||||
wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
|
||||
batch_first = self.batch_first)
|
||||
last_output_size = self.hidden_units_list[1]
|
||||
self.fastgrnn3 = None
|
||||
self.rnn3 = None
|
||||
if self.num_layers > 2:
|
||||
self.fastgrnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
|
||||
self.rnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
|
||||
|
@ -79,25 +79,25 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
self.init_hidden()
|
||||
|
||||
def sparsify(self):
|
||||
self.fastgrnn1.cell.sparsify()
|
||||
self.rnn1.cell.sparsify()
|
||||
if self.num_layers > 1:
|
||||
self.fastgrnn2.cell.sparsify()
|
||||
self.rnn2.cell.sparsify()
|
||||
if self.num_layers > 2:
|
||||
self.fastgrnn3.cell.sparsify()
|
||||
self.rnn3.cell.sparsify()
|
||||
|
||||
def sparsifyWithSupport(self):
|
||||
self.fastgrnn1.cell.sparsifyWithSupport()
|
||||
self.rnn1.cell.sparsifyWithSupport()
|
||||
if self.num_layers > 1:
|
||||
self.fastgrnn2.cell.sparsifyWithSupport()
|
||||
self.rnn2.cell.sparsifyWithSupport()
|
||||
if self.num_layers > 2:
|
||||
self.fastgrnn3.cell.sparsifyWithSupport()
|
||||
self.rnn3.cell.sparsifyWithSupport()
|
||||
|
||||
def getModelSize(self):
|
||||
total_size = self.fastgrnn1.cell.getModelSize()
|
||||
total_size = self.rnn1.cell.getModelSize()
|
||||
if self.num_layers > 1:
|
||||
total_size += self.fastgrnn2.cell.getModelSize()
|
||||
total_size += self.rnn2.cell.getModelSize()
|
||||
if self.num_layers > 2:
|
||||
total_size += self.fastgrnn3.cell.getModelSize()
|
||||
total_size += self.rnn3.cell.getModelSize()
|
||||
total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
|
||||
return total_size
|
||||
|
||||
|
@ -146,40 +146,40 @@ def get_model_class(inheritance_class=nn.Module):
|
|||
# input is shape: [seq,batch,feature]
|
||||
if self.mean is not None:
|
||||
input = (input - self.mean) / self.std
|
||||
fastgrnn_out1 = self.fastgrnn1(input, hiddenState=self.hidden1)
|
||||
fastgrnn_output = fastgrnn_out1
|
||||
rnn_out1 = self.rnn1(input, hiddenState=self.hidden1)
|
||||
fastgrnn_output = rnn_out1
|
||||
# we have to detach the hidden states because we may keep them longer than 1 iteration.
|
||||
self.hidden1 = fastgrnn_out1.detach()[-1, :, :]
|
||||
self.hidden1 = rnn_out1.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.fastgrnn1.getVars()
|
||||
fastgrnn_out1 = onnx_exportable_fastgrnn(input, weights,
|
||||
output=fastgrnn_out1, hidden_size=self.hidden_units_list[0],
|
||||
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = fastgrnn_out1
|
||||
if self.fastgrnn2 is not None:
|
||||
fastgrnn_out2 = self.fastgrnn2(fastgrnn_out1, hiddenState=self.hidden2)
|
||||
self.hidden2 = fastgrnn_out2.detach()[-1, :, :]
|
||||
weights = self.rnn1.getVars()
|
||||
rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell.name,
|
||||
output=rnn_out1, hidden_size=self.hidden_units_list[0],
|
||||
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out1
|
||||
if self.rnn2 is not None:
|
||||
rnn_out2 = self.rnn2(rnn_out1, hiddenState=self.hidden2)
|
||||
self.hidden2 = rnn_out2.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.fastgrnn2.getVars()
|
||||
fastgrnn_out2 = onnx_exportable_fastgrnn(fastgrnn_out1, weights,
|
||||
output=fastgrnn_out2, hidden_size=self.hidden_units_list[1],
|
||||
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = fastgrnn_out2
|
||||
if self.fastgrnn3 is not None:
|
||||
fastgrnn_out3 = self.fastgrnn3(fastgrnn_out2, hiddenState=self.hidden3)
|
||||
self.hidden3 = fastgrnn_out3.detach()[-1, :, :]
|
||||
weights = self.rnn2.getVars()
|
||||
rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell.name,
|
||||
output=rnn_out2, hidden_size=self.hidden_units_list[1],
|
||||
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out2
|
||||
if self.rnn3 is not None:
|
||||
rnn_out3 = self.rnn3(rnn_out2, hiddenState=self.hidden3)
|
||||
self.hidden3 = rnn_out3.detach()[-1, :, :]
|
||||
if self.tracking:
|
||||
weights = self.fastgrnn3.getVars()
|
||||
fastgrnn_out3 = onnx_exportable_fastgrnn(fastgrnn_out2, weights,
|
||||
output=fastgrnn_out3, hidden_size=self.hidden_units_list[2],
|
||||
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = fastgrnn_out3
|
||||
weights = self.rnn3.getVars()
|
||||
rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell.name,
|
||||
output=rnn_out3, hidden_size=self.hidden_units_list[2],
|
||||
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
|
||||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity)
|
||||
fastgrnn_output = rnn_out3
|
||||
if self.linear:
|
||||
fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :])
|
||||
if self.apply_softmax:
|
||||
|
|
Загрузка…
Ссылка в новой задаче