trying to generlize ONNX exporter

This commit is contained in:
Harsha Vardhan Simhadri 2019-08-01 17:47:59 +05:30
Родитель 92bf361480
Коммит 1329bfb282
2 изменённых файлов: 56 добавлений и 53 удалений

Просмотреть файл

@ -8,15 +8,18 @@ import numpy as np
import edgeml.pytorch.utils as utils
def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank,
gate_nonlinearity, update_nonlinearity):
def onnx_exportable_rnn(input, fargs, cell_name,
output, hidden_size, wRank, uRank,
gate_nonlinearity, update_nonlinearity):
class RNNSymbolic(Function):
@staticmethod
def symbolic(g, *fargs):
# NOTE: args/kwargs contain RNN parameters
return g.op("FastGRNN", *fargs, outputs=1,
hidden_size_i=hidden_size, wRank_i=wRank, uRank_i=uRank,
gate_nonlinearity_s=gate_nonlinearity, update_nonlinearity_s=update_nonlinearity)
return g.op(cell_name, *fargs,
outputs=1, hidden_size_i=hidden_size,
wRank_i=wRank, uRank_i=uRank,
gate_nonlinearity_s=gate_nonlinearity,
update_nonlinearity_s=update_nonlinearity)
@staticmethod
def forward(ctx, *fargs):
@ -107,10 +110,6 @@ class RNNCell(nn.Module):
def uRank(self):
return self._uRank
#@property
#def num_weight_matrices(self):
# return self._num_weight_matrices
@property
def num_W_matrices(self):
return self._num_W_matrices
@ -119,6 +118,10 @@ class RNNCell(nn.Module):
def num_U_matrices(self):
return self._num_U_matrices
@property
def name(self):
raise NotImplementedError()
def forward(self, input, state):
raise NotImplementedError()

Просмотреть файл

@ -46,25 +46,25 @@ def get_model_class(inheritance_class=nn.Module):
# The FastGRNN takes audio sequences as input, and outputs hidden states
# with dimensionality hidden_units.
self.fastgrnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0],
self.rnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity,
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
batch_first = self.batch_first)
self.fastgrnn2 = None
self.rnn2 = None
last_output_size = self.hidden_units_list[0]
if self.num_layers > 1:
self.fastgrnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
self.rnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity,
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
batch_first = self.batch_first)
last_output_size = self.hidden_units_list[1]
self.fastgrnn3 = None
self.rnn3 = None
if self.num_layers > 2:
self.fastgrnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
self.rnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity,
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
@ -79,25 +79,25 @@ def get_model_class(inheritance_class=nn.Module):
self.init_hidden()
def sparsify(self):
self.fastgrnn1.cell.sparsify()
self.rnn1.cell.sparsify()
if self.num_layers > 1:
self.fastgrnn2.cell.sparsify()
self.rnn2.cell.sparsify()
if self.num_layers > 2:
self.fastgrnn3.cell.sparsify()
self.rnn3.cell.sparsify()
def sparsifyWithSupport(self):
self.fastgrnn1.cell.sparsifyWithSupport()
self.rnn1.cell.sparsifyWithSupport()
if self.num_layers > 1:
self.fastgrnn2.cell.sparsifyWithSupport()
self.rnn2.cell.sparsifyWithSupport()
if self.num_layers > 2:
self.fastgrnn3.cell.sparsifyWithSupport()
self.rnn3.cell.sparsifyWithSupport()
def getModelSize(self):
total_size = self.fastgrnn1.cell.getModelSize()
total_size = self.rnn1.cell.getModelSize()
if self.num_layers > 1:
total_size += self.fastgrnn2.cell.getModelSize()
total_size += self.rnn2.cell.getModelSize()
if self.num_layers > 2:
total_size += self.fastgrnn3.cell.getModelSize()
total_size += self.rnn3.cell.getModelSize()
total_size += 4 * self.hidden_units_list[self.num_layers-1] * self.num_classes
return total_size
@ -146,40 +146,40 @@ def get_model_class(inheritance_class=nn.Module):
# input is shape: [seq,batch,feature]
if self.mean is not None:
input = (input - self.mean) / self.std
fastgrnn_out1 = self.fastgrnn1(input, hiddenState=self.hidden1)
fastgrnn_output = fastgrnn_out1
rnn_out1 = self.rnn1(input, hiddenState=self.hidden1)
fastgrnn_output = rnn_out1
# we have to detach the hidden states because we may keep them longer than 1 iteration.
self.hidden1 = fastgrnn_out1.detach()[-1, :, :]
self.hidden1 = rnn_out1.detach()[-1, :, :]
if self.tracking:
weights = self.fastgrnn1.getVars()
fastgrnn_out1 = onnx_exportable_fastgrnn(input, weights,
output=fastgrnn_out1, hidden_size=self.hidden_units_list[0],
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = fastgrnn_out1
if self.fastgrnn2 is not None:
fastgrnn_out2 = self.fastgrnn2(fastgrnn_out1, hiddenState=self.hidden2)
self.hidden2 = fastgrnn_out2.detach()[-1, :, :]
weights = self.rnn1.getVars()
rnn_out1 = onnx_exportable_rnn(input, weights, self.rnn1.cell.name,
output=rnn_out1, hidden_size=self.hidden_units_list[0],
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = rnn_out1
if self.rnn2 is not None:
rnn_out2 = self.rnn2(rnn_out1, hiddenState=self.hidden2)
self.hidden2 = rnn_out2.detach()[-1, :, :]
if self.tracking:
weights = self.fastgrnn2.getVars()
fastgrnn_out2 = onnx_exportable_fastgrnn(fastgrnn_out1, weights,
output=fastgrnn_out2, hidden_size=self.hidden_units_list[1],
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = fastgrnn_out2
if self.fastgrnn3 is not None:
fastgrnn_out3 = self.fastgrnn3(fastgrnn_out2, hiddenState=self.hidden3)
self.hidden3 = fastgrnn_out3.detach()[-1, :, :]
weights = self.rnn2.getVars()
rnn_out2 = onnx_exportable_rnn(rnn_out1, weights, self.rnn2.cell.name,
output=rnn_out2, hidden_size=self.hidden_units_list[1],
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = rnn_out2
if self.rnn3 is not None:
rnn_out3 = self.rnn3(rnn_out2, hiddenState=self.hidden3)
self.hidden3 = rnn_out3.detach()[-1, :, :]
if self.tracking:
weights = self.fastgrnn3.getVars()
fastgrnn_out3 = onnx_exportable_fastgrnn(fastgrnn_out2, weights,
output=fastgrnn_out3, hidden_size=self.hidden_units_list[2],
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = fastgrnn_out3
weights = self.rnn3.getVars()
rnn_out3 = onnx_exportable_rnn(rnn_out2, weights, self.rnn3.cell.name,
output=rnn_out3, hidden_size=self.hidden_units_list[2],
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
gate_nonlinearity=self.gate_nonlinearity,
update_nonlinearity=self.update_nonlinearity)
fastgrnn_output = rnn_out3
if self.linear:
fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :])
if self.apply_softmax: