This commit is contained in:
Harsha Vardhan Simhadri 2019-07-31 23:40:05 +05:30
Родитель cd35fa9796
Коммит eadae5b38f
1 изменённых файлов: 3 добавлений и 3 удалений

Просмотреть файл

@ -8,7 +8,8 @@ import numpy as np
import edgeml.pytorch.utils as utils
def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity):
def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank,
gate_nonlinearity, update_nonlinearity):
class RNNSymbolic(Function):
@staticmethod
def symbolic(g, *fargs):
@ -25,8 +26,7 @@ def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, ga
def backward(ctx, *gargs, **gkwargs):
raise RuntimeError("FIXME: Traced RNNs don't support backward")
output_temp = RNNSymbolic.apply(input, *fargs)
return output_temp
return RNNSymbolic.apply(input, *fargs)
def gen_nonlinearity(A, nonlinearity):
'''