File structure changed; multi-layer FastGRNN model class added, with rolling support.
[Six image files moved; width, height, and size unchanged before and after: 25 KiB, 13 KiB, 10 KiB, 11 KiB, 4.6 KiB, 23 KiB.]
@@ -3,36 +3,57 @@
 import torch
 import torch.nn as nn
 from torch.autograd import Function
 import numpy as np
 
+def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity):
+    class RNNSymbolic(Function):
+        @staticmethod
+        def symbolic(g, *fargs):
+            # NOTE: args/kwargs contain RNN parameters
+            return g.op("FastGRNN", *fargs, outputs=1,
+                        hidden_size_i=hidden_size, wRank_i=wRank, uRank_i=uRank,
+                        gate_nonlinearity_s=gate_nonlinearity, update_nonlinearity_s=update_nonlinearity)
+
+        @staticmethod
+        def forward(ctx, *fargs):
+            return output
+
+        @staticmethod
+        def backward(ctx, *gargs, **gkwargs):
+            raise RuntimeError("FIXME: Traced RNNs don't support backward")
+
+    output_temp = RNNSymbolic.apply(input, *fargs)
+    return output_temp
+
-def gen_non_linearity(A, non_linearity):
+def gen_nonlinearity(A, nonlinearity):
     '''
     Returns required activation for a tensor based on the inputs
 
-    non_linearity is either a callable or a value in
+    nonlinearity is either a callable or a value in
     ['tanh', 'sigmoid', 'relu', 'quantTanh', 'quantSigm', 'quantSigm4']
     '''
-    if non_linearity == "tanh":
+    if nonlinearity == "tanh":
         return torch.tanh(A)
-    elif non_linearity == "sigmoid":
+    elif nonlinearity == "sigmoid":
         return torch.sigmoid(A)
-    elif non_linearity == "relu":
+    elif nonlinearity == "relu":
         return torch.relu(A)
-    elif non_linearity == "quantTanh":
+    elif nonlinearity == "quantTanh":
         return torch.max(torch.min(A, torch.ones_like(A)), -1.0 * torch.ones_like(A))
-    elif non_linearity == "quantSigm":
+    elif nonlinearity == "quantSigm":
         A = (A + 1.0) / 2.0
         return torch.max(torch.min(A, torch.ones_like(A)), torch.zeros_like(A))
-    elif non_linearity == "quantSigm4":
+    elif nonlinearity == "quantSigm4":
         A = (A + 2.0) / 4.0
         return torch.max(torch.min(A, torch.ones_like(A)), torch.zeros_like(A))
     else:
-        # non_linearity is a user specified function
-        if not callable(non_linearity):
-            raise ValueError("non_linearity is either a callable or a value " +
+        # nonlinearity is a user specified function
+        if not callable(nonlinearity):
+            raise ValueError("nonlinearity is either a callable or a value " +
                              "['tanh', 'sigmoid', 'relu', 'quantTanh', " +
                              "'quantSigm'")
-        return non_linearity(A)
+        return nonlinearity(A)
 
 
 class BaseRNN(nn.Module):
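Note: `gen_nonlinearity` is the dispatch point used by every cell below. As a quick reference for the quantized activations (a standalone sketch, not part of the diff): `quantTanh` clips to [-1, 1] and `quantSigm` maps to [0, 1], which keeps the gate arithmetic fixed-point friendly.

```python
import torch

A = torch.linspace(-3, 3, 7)  # [-3, -2, -1, 0, 1, 2, 3]
quant_tanh = torch.max(torch.min(A, torch.ones_like(A)), -1.0 * torch.ones_like(A))
quant_sigm = torch.clamp((A + 1.0) / 2.0, 0.0, 1.0)  # equivalent to the min/max form
print(quant_tanh)  # tensor([-1., -1., -1.,  0.,  1.,  1.,  1.])
print(quant_sigm)  # tensor([0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0])
```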
@@ -49,6 +70,9 @@ class BaseRNN(nn.Module):
         self.RNNCell = RNNCell
         self.batch_first = batch_first
 
+    def getVars(self):
+        return self.RNNCell.getVars()
+
     def forward(self, input, hiddenState=None,
                 cellState=None):
         if self.batch_first is True:
@@ -111,9 +135,9 @@ class FastGRNNCell(nn.Module):
     Has multiple activation functions for the gates
     hidden_size = # hidden units
 
-    gate_non_linearity = nonlinearity for the gate can be chosen from
+    gate_nonlinearity = nonlinearity for the gate can be chosen from
     [tanh, sigmoid, relu, quantTanh, quantSigm]
-    update_non_linearity = nonlinearity for final rnn update
+    update_nonlinearity = nonlinearity for final rnn update
     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
 
     wRank = rank of W matrix (creates two matrices if not None)
@@ -134,15 +158,15 @@ class FastGRNNCell(nn.Module):
     W = matmul(W_1, W_2) and U = matmul(U_1, U_2)
     '''
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  zetaInit=1.0, nuInit=-4.0, name="FastGRNN"):
         super(FastGRNNCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._num_weight_matrices = [1, 1]
         self._wRank = wRank
         self._uRank = uRank
@@ -155,17 +179,17 @@ class FastGRNNCell(nn.Module):
         self._name = name
 
         if wRank is None:
-            self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size]))
+            self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size]))
         else:
-            self.W1 = nn.Parameter(0.1 * torch.randn([input_size, wRank]))
-            self.W2 = nn.Parameter(0.1 * torch.randn([wRank, hidden_size]))
+            self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size]))
+            self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank]))
 
         if uRank is None:
             self.U = nn.Parameter(
                 0.1 * torch.randn([hidden_size, hidden_size]))
         else:
-            self.U1 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
-            self.U2 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size]))
+            self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank]))
 
         self.bias_gate = nn.Parameter(torch.ones([1, hidden_size]))
         self.bias_update = nn.Parameter(torch.ones([1, hidden_size]))
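Note: the transposed layout above stores `W` as two factors, `W2` of shape `[hidden_size, wRank]` and `W1` of shape `[wRank, input_size]`, so a dense `hidden_size x input_size` matrix is replaced by `wRank * (hidden_size + input_size)` parameters. A standalone sketch of the equivalence, with hypothetical sizes:

```python
import torch

input_size, hidden_size, wRank = 32, 100, 8          # hypothetical sizes
dense_params = input_size * hidden_size              # 3200 parameters
lowrank_params = wRank * (input_size + hidden_size)  # 1056 parameters

x = torch.randn(4, input_size)
W1 = 0.1 * torch.randn(wRank, input_size)
W2 = 0.1 * torch.randn(hidden_size, wRank)
# The factored product x @ W1^T @ W2^T equals the dense product x @ (W2 @ W1)^T:
assert torch.allclose(x @ W1.t() @ W2.t(), x @ (W2 @ W1).t(), atol=1e-5)
```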
@@ -185,12 +209,12 @@ class FastGRNNCell(nn.Module):
         return self._hidden_size
 
     @property
-    def gate_non_linearity(self):
-        return self._gate_non_linearity
+    def gate_nonlinearity(self):
+        return self._gate_nonlinearity
 
     @property
-    def update_non_linearity(self):
-        return self._update_non_linearity
+    def update_nonlinearity(self):
+        return self._update_nonlinearity
 
     @property
     def wRank(self):
@@ -214,23 +238,23 @@ class FastGRNNCell(nn.Module):
 
     def forward(self, input, state):
         if self._wRank is None:
-            wComp = torch.matmul(input, self.W)
+            wComp = torch.matmul(input, torch.transpose(self.W, 0, 1))
         else:
             wComp = torch.matmul(
-                torch.matmul(input, self.W1), self.W2)
+                torch.matmul(input, torch.transpose(self.W1, 0, 1)), torch.transpose(self.W2, 0, 1))
 
         if self._uRank is None:
-            uComp = torch.matmul(state, self.U)
+            uComp = torch.matmul(state, torch.transpose(self.U, 0, 1))
         else:
             uComp = torch.matmul(
-                torch.matmul(state, self.U1), self.U2)
+                torch.matmul(state, torch.transpose(self.U1, 0, 1)), torch.transpose(self.U2, 0, 1))
 
         pre_comp = wComp + uComp
 
-        z = gen_non_linearity(pre_comp + self.bias_gate,
-                              self._gate_non_linearity)
-        c = gen_non_linearity(pre_comp + self.bias_update,
-                              self._update_non_linearity)
+        z = gen_nonlinearity(pre_comp + self.bias_gate,
+                             self._gate_nonlinearity)
+        c = gen_nonlinearity(pre_comp + self.bias_update,
+                             self._update_nonlinearity)
         new_h = z * state + (torch.sigmoid(self.zeta) *
                              (1.0 - z) + torch.sigmoid(self.nu)) * c
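Note: the forward pass above is the FastGRNN update: z_t = gate(W x_t + U h_{t-1} + b_z), c_t = update(W x_t + U h_{t-1} + b_h), and h_t = (sigmoid(zeta) * (1 - z_t) + sigmoid(nu)) * c_t + z_t * h_{t-1}, where zeta and nu are trainable scalars. A standalone sketch of one dense step, with hypothetical sizes:

```python
import torch

batch, d_in, d_hid = 4, 32, 100
x, h = torch.randn(batch, d_in), torch.zeros(batch, d_hid)
W, U = 0.1 * torch.randn(d_hid, d_in), 0.1 * torch.randn(d_hid, d_hid)
b_z, b_h = torch.ones(1, d_hid), torch.ones(1, d_hid)
zeta, nu = torch.tensor(1.0), torch.tensor(-4.0)  # zetaInit, nuInit defaults

pre = x @ W.t() + h @ U.t()    # shared pre-activation for gate and update
z = torch.sigmoid(pre + b_z)   # gate
c = torch.tanh(pre + b_h)      # candidate update
h_new = z * h + (torch.sigmoid(zeta) * (1.0 - z) + torch.sigmoid(nu)) * c
```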
@@ -260,7 +284,7 @@ class FastRNNCell(nn.Module):
     Has multiple activation functions for the gates
     hidden_size = # hidden units
 
-    update_non_linearity = nonlinearity for final rnn update
+    update_nonlinearity = nonlinearity for final rnn update
     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
 
     wRank = rank of W matrix (creates two matrices if not None)
@@ -281,13 +305,13 @@ class FastRNNCell(nn.Module):
     '''
 
     def __init__(self, input_size, hidden_size,
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  alphaInit=-3.0, betaInit=3.0, name="FastRNN"):
         super(FastRNNCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._update_non_linearity = update_non_linearity
+        self._update_nonlinearity = update_nonlinearity
         self._num_weight_matrices = [1, 1]
         self._wRank = wRank
         self._uRank = uRank
@@ -329,8 +353,8 @@ class FastRNNCell(nn.Module):
         return self._hidden_size
 
     @property
-    def update_non_linearity(self):
-        return self._update_non_linearity
+    def update_nonlinearity(self):
+        return self._update_nonlinearity
 
     @property
     def wRank(self):
@@ -367,8 +391,8 @@ class FastRNNCell(nn.Module):
 
         pre_comp = wComp + uComp
 
-        c = gen_non_linearity(pre_comp + self.bias_update,
-                              self._update_non_linearity)
+        c = gen_nonlinearity(pre_comp + self.bias_update,
+                             self._update_nonlinearity)
         new_h = torch.sigmoid(self.beta) * state + \
             torch.sigmoid(self.alpha) * c
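Note: FastRNN's update above is a weighted residual step, h_t = sigmoid(beta) * h_{t-1} + sigmoid(alpha) * update(W x_t + U h_{t-1} + b). With the defaults alphaInit=-3.0 and betaInit=3.0 the cell starts close to an identity map, which is what stabilizes training on long sequences. A quick check of those defaults:

```python
import torch
alpha, beta = torch.tensor(-3.0), torch.tensor(3.0)            # alphaInit, betaInit defaults
print(torch.sigmoid(alpha).item(), torch.sigmoid(beta).item()) # ~0.047 and ~0.953
```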
@@ -399,9 +423,9 @@ class LSTMLRCell(nn.Module):
     Has multiple activation functions for the gates
     hidden_size = # hidden units
 
-    gate_non_linearity = nonlinearity for the gate can be chosen from
+    gate_nonlinearity = nonlinearity for the gate can be chosen from
     [tanh, sigmoid, relu, quantTanh, quantSigm]
-    update_non_linearity = nonlinearity for final rnn update
+    update_nonlinearity = nonlinearity for final rnn update
     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
 
     wRank = rank of all W matrices
@@ -425,15 +449,15 @@ class LSTMLRCell(nn.Module):
     Wi = matmul(W, W_i) and Ui = matmul(U, U_i)
     '''
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  name="LSTMLR"):
         super(LSTMLRCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._num_weight_matrices = [4, 4]
         self._wRank = wRank
         self._uRank = uRank
@@ -493,12 +517,12 @@ class LSTMLRCell(nn.Module):
         return self._hidden_size
 
     @property
-    def gate_non_linearity(self):
-        return self._gate_non_linearity
+    def gate_nonlinearity(self):
+        return self._gate_nonlinearity
 
     @property
-    def update_non_linearity(self):
-        return self._update_non_linearity
+    def update_nonlinearity(self):
+        return self._update_nonlinearity
 
     @property
     def wRank(self):
@@ -557,18 +581,18 @@ class LSTMLRCell(nn.Module):
         pre_comp3 = wComp3 + uComp3
         pre_comp4 = wComp4 + uComp4
 
-        i = gen_non_linearity(pre_comp1 + self.bias_i,
-                              self._gate_non_linearity)
-        f = gen_non_linearity(pre_comp2 + self.bias_f,
-                              self._gate_non_linearity)
-        o = gen_non_linearity(pre_comp4 + self.bias_o,
-                              self._gate_non_linearity)
+        i = gen_nonlinearity(pre_comp1 + self.bias_i,
+                             self._gate_nonlinearity)
+        f = gen_nonlinearity(pre_comp2 + self.bias_f,
+                             self._gate_nonlinearity)
+        o = gen_nonlinearity(pre_comp4 + self.bias_o,
+                             self._gate_nonlinearity)
 
-        c_ = gen_non_linearity(pre_comp3 + self.bias_c,
-                               self._update_non_linearity)
+        c_ = gen_nonlinearity(pre_comp3 + self.bias_c,
+                              self._update_nonlinearity)
 
         new_c = f * c + i * c_
-        new_h = o * gen_non_linearity(new_c, self._update_non_linearity)
+        new_h = o * gen_nonlinearity(new_c, self._update_nonlinearity)
         return new_h, new_c
 
     def getVars(self):
@@ -594,9 +618,9 @@ class GRULRCell(nn.Module):
     Has multiple activation functions for the gates
     hidden_size = # hidden units
 
-    gate_non_linearity = nonlinearity for the gate can be chosen from
+    gate_nonlinearity = nonlinearity for the gate can be chosen from
     [tanh, sigmoid, relu, quantTanh, quantSigm]
-    update_non_linearity = nonlinearity for final rnn update
+    update_nonlinearity = nonlinearity for final rnn update
     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
 
     wRank = rank of W matrix
@@ -618,15 +642,15 @@ class GRULRCell(nn.Module):
     Wi = matmul(W, W_i) and Ui = matmul(U, U_i)
     '''
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  name="GRULR"):
         super(GRULRCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._num_weight_matrices = [3, 3]
         self._wRank = wRank
         self._uRank = uRank
@@ -680,12 +704,12 @@ class GRULRCell(nn.Module):
         return self._hidden_size
 
     @property
-    def gate_non_linearity(self):
-        return self._gate_non_linearity
+    def gate_nonlinearity(self):
+        return self._gate_nonlinearity
 
     @property
-    def update_non_linearity(self):
-        return self._update_non_linearity
+    def update_nonlinearity(self):
+        return self._update_nonlinearity
 
     @property
     def wRank(self):
@@ -732,10 +756,10 @@ class GRULRCell(nn.Module):
         pre_comp1 = wComp1 + uComp1
         pre_comp2 = wComp2 + uComp2
 
-        r = gen_non_linearity(pre_comp1 + self.bias_r,
-                              self._gate_non_linearity)
-        z = gen_non_linearity(pre_comp2 + self.bias_gate,
-                              self._gate_non_linearity)
+        r = gen_nonlinearity(pre_comp1 + self.bias_r,
+                             self._gate_nonlinearity)
+        z = gen_nonlinearity(pre_comp2 + self.bias_gate,
+                             self._gate_nonlinearity)
 
         if self._uRank is None:
             pre_comp3 = wComp3 + torch.matmul(r * state, self.U3)
@@ -743,8 +767,8 @@ class GRULRCell(nn.Module):
             pre_comp3 = wComp3 + \
                 torch.matmul(torch.matmul(r * state, self.U), self.U3)
 
-        c = gen_non_linearity(pre_comp3 + self.bias_update,
-                              self._update_non_linearity)
+        c = gen_nonlinearity(pre_comp3 + self.bias_update,
+                             self._update_nonlinearity)
 
         new_h = z * state + (1.0 - z) * c
         return new_h
@@ -772,9 +796,9 @@ class UGRNNLRCell(nn.Module):
     Has multiple activation functions for the gates
     hidden_size = # hidden units
 
-    gate_non_linearity = nonlinearity for the gate can be chosen from
+    gate_nonlinearity = nonlinearity for the gate can be chosen from
     [tanh, sigmoid, relu, quantTanh, quantSigm]
-    update_non_linearity = nonlinearity for final rnn update
+    update_nonlinearity = nonlinearity for final rnn update
     can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm]
 
     wRank = rank of W matrix
@@ -795,15 +819,15 @@ class UGRNNLRCell(nn.Module):
     Wi = matmul(W, W_i) and Ui = matmul(U, U_i)
     '''
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  name="UGRNNLR"):
         super(UGRNNLRCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._num_weight_matrices = [2, 2]
         self._wRank = wRank
         self._uRank = uRank
@@ -850,12 +874,12 @@ class UGRNNLRCell(nn.Module):
         return self._hidden_size
 
     @property
-    def gate_non_linearity(self):
-        return self._gate_non_linearity
+    def gate_nonlinearity(self):
+        return self._gate_nonlinearity
 
     @property
-    def update_non_linearity(self):
-        return self._update_non_linearity
+    def update_nonlinearity(self):
+        return self._update_nonlinearity
 
     @property
     def wRank(self):
@@ -899,10 +923,10 @@ class UGRNNLRCell(nn.Module):
         pre_comp1 = wComp1 + uComp1
         pre_comp2 = wComp2 + uComp2
 
-        z = gen_non_linearity(pre_comp1 + self.bias_gate,
-                              self._gate_non_linearity)
-        c = gen_non_linearity(pre_comp2 + self.bias_update,
-                              self._update_non_linearity)
+        z = gen_nonlinearity(pre_comp1 + self.bias_gate,
+                             self._gate_nonlinearity)
+        c = gen_nonlinearity(pre_comp2 + self.bias_update,
+                             self._update_nonlinearity)
 
         new_h = z * state + (1.0 - z) * c
         return new_h
@@ -927,20 +951,20 @@ class UGRNNLRCell(nn.Module):
 class LSTM(nn.Module):
     """Equivalent to nn.LSTM using LSTMLRCell"""
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True):
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True):
         super(LSTM, self).__init__()
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._wRank = wRank
         self._uRank = uRank
         self.batch_first = batch_first
 
         self.cell = LSTMLRCell(input_size, hidden_size,
-                               gate_non_linearity=gate_non_linearity,
-                               update_non_linearity=update_non_linearity,
+                               gate_nonlinearity=gate_nonlinearity,
+                               update_nonlinearity=update_nonlinearity,
                                wRank=wRank, uRank=uRank)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
@@ -951,20 +975,20 @@ class LSTM(nn.Module):
 class GRU(nn.Module):
     """Equivalent to nn.GRU using GRULRCell"""
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True):
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True):
         super(GRU, self).__init__()
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._wRank = wRank
         self._uRank = uRank
         self.batch_first = batch_first
 
         self.cell = GRULRCell(input_size, hidden_size,
-                              gate_non_linearity=gate_non_linearity,
-                              update_non_linearity=update_non_linearity,
+                              gate_nonlinearity=gate_nonlinearity,
+                              update_nonlinearity=update_nonlinearity,
                               wRank=wRank, uRank=uRank)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
@@ -975,20 +999,20 @@ class GRU(nn.Module):
 class UGRNN(nn.Module):
     """Equivalent to nn.UGRNN using UGRNNLRCell"""
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True):
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True):
         super(UGRNN, self).__init__()
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._wRank = wRank
         self._uRank = uRank
         self.batch_first = batch_first
 
         self.cell = UGRNNLRCell(input_size, hidden_size,
-                                gate_non_linearity=gate_non_linearity,
-                                update_non_linearity=update_non_linearity,
+                                gate_nonlinearity=gate_nonlinearity,
+                                update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
@@ -999,21 +1023,21 @@ class UGRNN(nn.Module):
 class FastRNN(nn.Module):
     """Equivalent to nn.FastRNN using FastRNNCell"""
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  alphaInit=-3.0, betaInit=3.0, batch_first=True):
         super(FastRNN, self).__init__()
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._wRank = wRank
         self._uRank = uRank
         self.batch_first = batch_first
 
         self.cell = FastRNNCell(input_size, hidden_size,
-                                gate_non_linearity=gate_non_linearity,
-                                update_non_linearity=update_non_linearity,
+                                gate_nonlinearity=gate_nonlinearity,
+                                update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank,
                                 alphaInit=alphaInit, betaInit=betaInit)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
@@ -1025,25 +1049,28 @@ class FastRNN(nn.Module):
 class FastGRNN(nn.Module):
     """Equivalent to nn.FastGRNN using FastGRNNCell"""
 
-    def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid",
-                 update_non_linearity="tanh", wRank=None, uRank=None,
+    def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
+                 update_nonlinearity="tanh", wRank=None, uRank=None,
                  zetaInit=1.0, nuInit=-4.0, batch_first=True):
         super(FastGRNN, self).__init__()
         self._input_size = input_size
         self._hidden_size = hidden_size
-        self._gate_non_linearity = gate_non_linearity
-        self._update_non_linearity = update_non_linearity
+        self._gate_nonlinearity = gate_nonlinearity
+        self._update_nonlinearity = update_nonlinearity
         self._wRank = wRank
         self._uRank = uRank
         self.batch_first = batch_first
 
         self.cell = FastGRNNCell(input_size, hidden_size,
-                                 gate_non_linearity=gate_non_linearity,
-                                 update_non_linearity=update_non_linearity,
+                                 gate_nonlinearity=gate_nonlinearity,
+                                 update_nonlinearity=update_nonlinearity,
                                  wRank=wRank, uRank=uRank,
                                  zetaInit=zetaInit, nuInit=nuInit)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
 
+    def getVars(self):
+        return self.unrollRNN.getVars()
+
     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
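Note: a minimal usage sketch of the `FastGRNN` wrapper with the renamed keyword arguments (hypothetical sizes; assumes the new `edgeml.pytorch` package layout introduced by this change):

```python
import torch
from edgeml.pytorch.graph.rnn import FastGRNN

rnn = FastGRNN(input_size=32, hidden_size=100,
               gate_nonlinearity="sigmoid", update_nonlinearity="tanh",
               wRank=16, uRank=25, batch_first=False)
x = torch.randn(99, 8, 32)   # [timeSteps, batchSize, inputDims]
out = rnn(x)                 # hidden states unrolled over all time steps
```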
@@ -5,7 +5,7 @@ import torch
 import numpy as np
 import os
 import sys
-import pytorch_edgeml.utils as utils
+import edgeml.pytorch.utils as utils
 
 
 class BonsaiTrainer:
@@ -5,8 +5,8 @@ import os
 import sys
 import torch
 import torch.nn as nn
-import pytorch_edgeml.utils as utils
-from pytorch_edgeml.graph.rnn import *
+import edgeml.pytorch.utils as utils
+from edgeml.pytorch.graph.rnn import *
 import numpy as np
 
 
@@ -0,0 +1,155 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import sys
+from edgeml.pytorch.graph.rnn import *
+
+def fastgrnnmodel(inheritance_class=nn.Module):
+    class FastGRNNModel(inheritance_class):
+        """This class is a PyTorch Module that implements a 1, 2 or 3 layer FastGRNN based audio classifier"""
+
+        def __init__(self, input_dim, num_layers, hidden_units_list, wRank_list, uRank_list, gate_nonlinearity, update_nonlinearity, num_classes=None, linear=True, batch_first=False, apply_softmax=True):
+            """
+            Initialize the FastGRNNModel with the following parameters:
+            input_dim - the size of the input audio frame in # samples
+            num_layers - the number of FastGRNN layers to use (1, 2 or 3)
+            hidden_units_list - the size of the hidden state of each FastGRNN layer
+            num_classes - the number of predictions to come out of the model
+            """
+            self.input_dim = input_dim
+            self.hidden_units_list = hidden_units_list
+            self.num_layers = num_layers
+            self.num_classes = num_classes
+            self.wRank_list = wRank_list
+            self.uRank_list = uRank_list
+            self.gate_nonlinearity = gate_nonlinearity
+            self.update_nonlinearity = update_nonlinearity
+            self.linear = linear
+            self.batch_first = batch_first
+            self.apply_softmax = apply_softmax
+            self.tracking = False  # set to True before ONNX export tracing
+            self.mean = None       # set via normalize()
+            self.std = None
+            if self.linear:
+                if not self.num_classes:
+                    raise Exception("num_classes need to be specified if linear is True")
+
+            super(FastGRNNModel, self).__init__()
+
+            # The FastGRNN takes audio sequences as input, and outputs hidden states
+            # with dimensionality hidden_units.
+            self.fastgrnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0],
+                                      gate_nonlinearity=self.gate_nonlinearity,
+                                      update_nonlinearity=self.update_nonlinearity,
+                                      wRank=self.wRank_list[0], uRank=self.uRank_list[0],
+                                      batch_first=self.batch_first)
+            self.fastgrnn2 = None
+            last_output_size = self.hidden_units_list[0]
+            if self.num_layers > 1:
+                self.fastgrnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1],
+                                          gate_nonlinearity=self.gate_nonlinearity,
+                                          update_nonlinearity=self.update_nonlinearity,
+                                          wRank=self.wRank_list[1], uRank=self.uRank_list[1],
+                                          batch_first=self.batch_first)
+                last_output_size = self.hidden_units_list[1]
+            self.fastgrnn3 = None
+            if self.num_layers > 2:
+                self.fastgrnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2],
+                                          gate_nonlinearity=self.gate_nonlinearity,
+                                          update_nonlinearity=self.update_nonlinearity,
+                                          wRank=self.wRank_list[2], uRank=self.uRank_list[2],
+                                          batch_first=self.batch_first)
+                last_output_size = self.hidden_units_list[2]
+
+            # The linear layer is a fully connected layer that maps from hidden state space
+            # to the number of expected keywords
+            if self.linear:
+                self.hidden2keyword = nn.Linear(last_output_size, num_classes)
+            self.init_hidden()
+
+        def normalize(self, mean, std):
+            self.mean = mean
+            self.std = std
+
+        def name(self):
+            return "{} layer FastGRNN".format(self.num_layers)
+
+        def init_hidden_bag(self, hidden_bag_size, device):
+            self.hidden_bag_size = hidden_bag_size
+            self.device = device
+            self.hidden1_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[0]],
+                                                         dtype=np.float32)).to(self.device)
+            if self.num_layers >= 2:
+                self.hidden2_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[1]],
+                                                             dtype=np.float32)).to(self.device)
+            if self.num_layers == 3:
+                self.hidden3_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[2]],
+                                                             dtype=np.float32)).to(self.device)
+
+        def rolling_step(self):
+            shuffled_indices = list(range(self.hidden_bag_size))
+            np.random.shuffle(shuffled_indices)
+            if self.hidden1 is not None:
+                batch_size = self.hidden1.shape[0]
+                temp_indices = shuffled_indices[:batch_size]
+                self.hidden1_bag[temp_indices, :] = self.hidden1
+                self.hidden1 = self.hidden1_bag[0:batch_size, :]
+                if self.num_layers >= 2:
+                    self.hidden2_bag[temp_indices, :] = self.hidden2
+                    self.hidden2 = self.hidden2_bag[0:batch_size, :]
+                if self.num_layers == 3:
+                    self.hidden3_bag[temp_indices, :] = self.hidden3
+                    self.hidden3 = self.hidden3_bag[0:batch_size, :]
+
+        def init_hidden(self):
+            """ Clear the hidden state for the FastGRNN nodes """
+            self.hidden1 = None
+            self.hidden2 = None
+            self.hidden3 = None
+
+        def forward(self, input):
+            """ Perform the forward processing of the given input and return the prediction """
+            # input is shape: [seq, batch, feature]
+            if self.mean is not None:
+                input = (input - self.mean) / self.std
+            fastgrnn_out1 = self.fastgrnn1(input, hiddenState=self.hidden1)
+            fastgrnn_output = fastgrnn_out1
+            # we have to detach the hidden states because we may keep them longer than 1 iteration
+            self.hidden1 = fastgrnn_out1.detach()[-1, :, :]
+            if self.tracking:
+                weights = self.fastgrnn1.getVars()
+                fastgrnn_out1 = onnx_exportable_fastgrnn(input, weights,
+                                                         output=fastgrnn_out1, hidden_size=self.hidden_units_list[0],
+                                                         wRank=self.wRank_list[0], uRank=self.uRank_list[0],
+                                                         gate_nonlinearity=self.gate_nonlinearity,
+                                                         update_nonlinearity=self.update_nonlinearity)
+                fastgrnn_output = fastgrnn_out1
+            if self.fastgrnn2 is not None:
+                fastgrnn_out2 = self.fastgrnn2(fastgrnn_out1, hiddenState=self.hidden2)
+                self.hidden2 = fastgrnn_out2.detach()[-1, :, :]
+                if self.tracking:
+                    weights = self.fastgrnn2.getVars()
+                    fastgrnn_out2 = onnx_exportable_fastgrnn(fastgrnn_out1, weights,
+                                                             output=fastgrnn_out2, hidden_size=self.hidden_units_list[1],
+                                                             wRank=self.wRank_list[1], uRank=self.uRank_list[1],
+                                                             gate_nonlinearity=self.gate_nonlinearity,
+                                                             update_nonlinearity=self.update_nonlinearity)
+                fastgrnn_output = fastgrnn_out2
+            if self.fastgrnn3 is not None:
+                fastgrnn_out3 = self.fastgrnn3(fastgrnn_out2, hiddenState=self.hidden3)
+                self.hidden3 = fastgrnn_out3.detach()[-1, :, :]
+                if self.tracking:
+                    weights = self.fastgrnn3.getVars()
+                    fastgrnn_out3 = onnx_exportable_fastgrnn(fastgrnn_out2, weights,
+                                                             output=fastgrnn_out3, hidden_size=self.hidden_units_list[2],
+                                                             wRank=self.wRank_list[2], uRank=self.uRank_list[2],
+                                                             gate_nonlinearity=self.gate_nonlinearity,
+                                                             update_nonlinearity=self.update_nonlinearity)
+                fastgrnn_output = fastgrnn_out3
+            if self.linear:
+                fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :])
+            if self.apply_softmax:
+                fastgrnn_output = F.log_softmax(fastgrnn_output, dim=1)
+            return fastgrnn_output
+    return FastGRNNModel
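Note: a minimal usage sketch of the factory above. The import path is an assumption (the rendered diff does not show where the new file lives), and the hyperparameters are hypothetical; `init_hidden_bag`/`rolling_step` are only needed when hidden state is carried across minibatches (rolling training).

```python
import torch
import torch.nn as nn
from edgeml.pytorch.trainer.fastmodel import fastgrnnmodel  # path is an assumption

FastGRNNModel = fastgrnnmodel(nn.Module)
model = FastGRNNModel(input_dim=32, num_layers=2,
                      hidden_units_list=[100, 100],
                      wRank_list=[16, 25], uRank_list=[25, 25],
                      gate_nonlinearity="sigmoid",
                      update_nonlinearity="tanh",
                      num_classes=35)
x = torch.randn(99, 8, 32)   # [seq, batch, feature]
log_probs = model(x)         # [batch, num_classes] log-probabilities
```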
@@ -5,7 +5,7 @@ import torch
 import numpy as np
 import os
 import sys
-import pytorch_edgeml.utils as utils
+import edgeml.pytorch.utils as utils
 
 
 class ProtoNNTrainer:
@@ -5,7 +5,7 @@ import torch
 import numpy as np
 import os
 import sys
-import pytorch_edgeml.utils as utils
+import edgeml.pytorch.utils as utils
 
 
 class SRNNTrainer:
@@ -3,7 +3,7 @@
 
 from __future__ import print_function
 import tensorflow as tf
-import edgeml.utils as utils
+import edgeml.tf.utils as utils
 import numpy as np
 import os
 import sys
@@ -5,7 +5,7 @@ from __future__ import print_function
 import tensorflow as tf
 import numpy as np
 import sys
-import edgeml.utils as utils
+import edgeml.tf.utils as utils
 import pandas as pd
 
 class EMI_Trainer:
@@ -5,7 +5,7 @@ from __future__ import print_function
 import os
 import sys
 import tensorflow as tf
-import edgeml.utils as utils
+import edgeml.tf.utils as utils
 import numpy as np
 
 
@@ -5,7 +5,7 @@ from __future__ import print_function
 import tensorflow as tf
 import numpy as np
 import sys
-import edgeml.utils as utils
+import edgeml.tf.utils as utils
 
 
 class ProtoNNTrainer:
@@ -4,7 +4,7 @@ This directory includes, example notebook and general execution script of
 Bonsai developed as part of EdgeML. Also, we include a sample cleanup and
 use-case on the USPS10 public dataset.
 
-`pytorch_edgeml.graph.bonsai` implements the Bonsai prediction graph in pytorch.
+`edgeml.pytorch.graph.bonsai` implements the Bonsai prediction graph in pytorch.
 The three-phase training routine for Bonsai is decoupled from the forward graph
 to facilitate a plug and play behaviour wherein Bonsai can be combined with or
 used as a final layer classifier for other architectures (RNNs, CNNs).
@@ -4,8 +4,8 @@
 import helpermethods
 import numpy as np
 import sys
-from pytorch_edgeml.trainer.bonsaiTrainer import BonsaiTrainer
-from pytorch_edgeml.graph.bonsai import Bonsai
+from edgeml.pytorch.trainer.bonsaiTrainer import BonsaiTrainer
+from edgeml.pytorch.graph.bonsai import Bonsai
 import torch
 
 
@@ -5,21 +5,21 @@ FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified
 UGRNN, GRU and LSTM to support the LSQ training routine.
 Also, we include a sample cleanup and use-case on the USPS10 public dataset.
 
-`pytorch_edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../pytorch_edgeml/graph/rnn.py#L226)) and **FastGRNN** ([`FastGRNNCell`](../../pytorch_edgeml/graph/rnn.py#L80)) with
+`edgeml.pytorch.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../pytorch_edgeml/graph/rnn.py#L226)) and **FastGRNN** ([`FastGRNNCell`](../../pytorch_edgeml/graph/rnn.py#L80)) with
 multiple additional features like Low-Rank parameterisation, custom
 non-linearities etc., Similar to Bonsai and ProtoNN, the three-phase training
 routine for FastRNN and FastGRNN is decoupled from the custom cells to
 facilitate a plug and play behaviour of the custom RNN cells in other
 architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell` etc.,
-`pytorch_edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../pytorch_edgeml/graph/rnn.py#L742)),
+`edgeml.pytorch.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../pytorch_edgeml/graph/rnn.py#L742)),
 **GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L565)) and **LSTM** ([`LSTMLRCell`](../../pytorch_edgeml/graph/rnn.py#L369)). These cells also can be substituted for FastCells where ever feasible.
 
-`pytorch_edgeml.graph.rnn` also contains fully wrapped RNNs which are equivalent to `nn.LSTM` and `nn.GRU`. Implemented cells:
-**FastRNN** ([`FastRNN`](../../pytorch_edgeml/graph/rnn.py#L968)), **FastGRNN** ([`FastGRNN`](../../pytorch_edgeml/graph/rnn.py#L993)), **UGRNN** ([`UGRNN`](../../pytorch_edgeml/graph/rnn.py#L945)), **GRU** ([`GRU`](../../edgeml/graph/rnn.py#L922)) and **LSTM** ([`LSTM`](../../pytorch_edgeml/graph/rnn.py#L899)).
+`edgeml.pytorch.graph.rnn` also contains fully wrapped RNNs which are equivalent to `nn.LSTM` and `nn.GRU`. Implemented cells:
+**FastRNN** ([`FastRNN`](../../pytorch_edgeml/graph/rnn.py#L968)), **FastGRNN** ([`FastGRNN`](../../pytorch_edgeml/graph/rnn.py#L993)), **UGRNN** ([`UGRNN`](../../edgeml.pytorch/graph/rnn.py#L945)), **GRU** ([`GRU`](../../edgeml/graph/rnn.py#L922)) and **LSTM** ([`LSTM`](../../pytorch_edgeml/graph/rnn.py#L899)).
 
-Note that all the cells and wrappers (when used independently from `fastcell_example.py` or `pytorch_edgeml.trainer.fastTrainer`) take in data in a batch first format ie., [batchSize, timeSteps, inputDims] by default but it can also support [timeSteps, batchSize, inputDims] format by setting `batch_first` argument to False when used. `fast_example.py` automatically takes care it while assuming the standard format between tf, c++ and pytorch.
+Note that all the cells and wrappers (when used independently from `fastcell_example.py` or `edgeml.pytorch.trainer.fastTrainer`) take in data in a batch first format ie., [batchSize, timeSteps, inputDims] by default but they can also support the [timeSteps, batchSize, inputDims] format by setting the `batch_first` argument to False when used. `fast_example.py` automatically takes care of this while assuming the standard format between tf, c++ and pytorch.
 
-For training FastCells, `pytorch_edgeml.trainer.fastTrainer` implements the three-phase
+For training FastCells, `edgeml.pytorch.trainer.fastTrainer` implements the three-phase
 FastCell training routine in PyTorch. A simple example,
 `examples/fastcell_example.py` is provided to illustrate its usage.
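To make the `batch_first` convention above concrete, a short sketch of the two accepted layouts (hypothetical sizes, assuming the renamed package):

```python
import torch
from edgeml.pytorch.graph.rnn import FastGRNN

x_batch_first = torch.randn(8, 99, 32)  # [batchSize, timeSteps, inputDims] (default)
x_time_first = torch.randn(99, 8, 32)   # [timeSteps, batchSize, inputDims]
out1 = FastGRNN(32, 100)(x_batch_first)
out2 = FastGRNN(32, 100, batch_first=False)(x_time_first)
```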
@@ -5,8 +5,8 @@ import helpermethods
 import torch
 import numpy as np
 import sys
-from pytorch_edgeml.graph.rnn import *
-from pytorch_edgeml.trainer.fastTrainer import FastTrainer
+from edgeml.pytorch.graph.rnn import *
+from edgeml.pytorch.trainer.fastTrainer import FastTrainer
 
 
 def main():
@@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a
 command line execution script of ProtoNN developed as part of EdgeML. The
 example is based on the USPS dataset.
 
-`pytorch_edgeml.graph.protoNN` implements the ProtoNN prediction functions.
+`edgeml.pytorch.graph.protoNN` implements the ProtoNN prediction functions.
 The training routine for ProtoNN is decoupled from the forward graph to
 facilitate a plug and play behaviour wherein ProtoNN can be combined with or
 used as a final layer classifier for other architectures (RNNs, CNNs). The
-training routine is implemented in `pytorch_edgeml.trainer.protoNNTrainer`.
+training routine is implemented in `edgeml.pytorch.trainer.protoNNTrainer`.
 (This is also an artifact of consistency requirements with Tensorflow
 implementation).
@@ -5,7 +5,7 @@ from __future__ import print_function
 import sys
 import os
 import numpy as np
-import pytorch_edgeml.utils as utils
+import edgeml.pytorch.utils as utils
 import argparse
 
 
@@ -17,9 +17,9 @@
     "import numpy as np\n",
     "import torch\n",
     "\n",
-    "from pytorch_edgeml.graph.protoNN import ProtoNN\n",
-    "from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n",
-    "import pytorch_edgeml.utils as utils\n",
+    "from edgeml.pytorch.graph.protoNN import ProtoNN\n",
+    "from edgeml.pytorch.trainer.protoNNTrainer import ProtoNNTrainer\n",
+    "import edgeml.pytorch.utils as utils\n",
     "import helpermethods as helper"
    ]
   },
@@ -5,9 +5,9 @@ from __future__ import print_function
 import sys
 import os
 import numpy as np
-from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer
-from pytorch_edgeml.graph.protoNN import ProtoNN
-import pytorch_edgeml.utils as utils
+from edgeml.pytorch.trainer.protoNNTrainer import ProtoNNTrainer
+from edgeml.pytorch.graph.protoNN import ProtoNN
+import edgeml.pytorch.utils as utils
 import helpermethods as helper
 import torch
 
@@ -1,12 +1,13 @@
 # Pytorch Shallow RNN Examples
 
-This directory includes an example [notebook](SRNN_Example.ipynb) of how to use
-SRNN on the [Google Speech Commands
+This directory includes an example [notebook](SRNN_Example.ipynb) and a
+[python script](SRNN_Example.py) that explains the basic setup of SRNN by
+training a simple model on the [Google Speech Commands
 Dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).
 
-`pytorch_edgeml.graph.rnn.SRNN2` implements a 2 layer SRNN network. We will use
+`edgeml.pytorch.graph.rnn.SRNN2` implements a 2 layer SRNN network. We will use
 this with an LSTM cell on this dataset. The training routine for SRNN is
-implemented in `pytorch_edgeml.trainer.srnnTrainer` and will be used as part of
+implemented in `edgeml.pytorch.trainer.srnnTrainer` and will be used as part of
 this example.
 
 **Tested With:** pytorch > 1.1.0 with Python 2 and Python 3
@@ -23,8 +24,17 @@ dataset and write numpy files that confirm to the required format.
 ./fetch_google.py
 python process_google.py
 
-With the provided configuration, you can expect a validation accuracy of about
-92%.
+## Training the Model
+
+A sample [notebook](SRNN_Example.ipynb) and a corresponding command line script
+are provided for training. To run the command line script, please use:
+
+```
+python SRNN_Example.py --data-dir ./GoogleSpeech/Extracted --brick-size 11
+```
+
+With the provided default configuration, you can expect a validation accuracy
+of about 92%.
 
 Copyright (c) Microsoft Corporation. All rights reserved.
 Licensed under the MIT license.
@@ -4,9 +4,10 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# SRNN on Speech Commands Dataset\n",
+    "## SRNN on Speech Commands Dataset\n",
     "\n",
-    "Please use `fetch_google.sh` to download the Google Speech Commands Dataset and `python process_google.py` to create feature extracted data."
+    "\n",
+    "Please use `fetch_google.sh` to download the Google Speech Commands Dataset and python `process_google.py` to create feature extracted data."
    ]
   },
   {
@@ -14,8 +15,8 @@
    "execution_count": 1,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2019-07-14T12:52:51.914361Z",
-     "start_time": "2019-07-14T12:52:51.667856Z"
+     "end_time": "2019-07-17T17:19:40.265370Z",
+     "start_time": "2019-07-17T17:19:39.980681Z"
     }
   },
   "outputs": [],
@@ -24,12 +25,7 @@
     "import sys\n",
     "import os\n",
     "import numpy as np\n",
-    "import torch\n",
-    "\n",
-    "from pytorch_edgeml.graph.rnn import SRNN2\n",
-    "from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
-    "import pytorch_edgeml.utils as utils"
-    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
+    "import torch"
    ]
   },
   {
@@ -37,23 +33,15 @@
    "execution_count": 2,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2019-07-14T12:52:56.040100Z",
-     "start_time": "2019-07-14T12:52:51.916533Z"
+     "end_time": "2019-07-17T17:19:40.273918Z",
+     "start_time": "2019-07-17T17:19:40.267871Z"
     }
   },
   "outputs": [],
   "source": [
-    "DATA_DIR = './GoogleSpeech/Extracted/'\n",
-    "x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))\n",
-    "x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))\n",
-    "x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))\n",
-    "# Mean-var normalize\n",
-    "mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
-    "std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
-    "std[std[:] < 0.000001] = 1\n",
-    "x_train_ = (x_train_ - mean) / std\n",
-    "x_val_ = (x_val_ - mean) / std\n",
-    "x_test_ = (x_test_ - mean) / std"
+    "from edgeml.pytorch.graph.rnn import SRNN2\n",
+    "from edgeml.pytorch.trainer.srnnTrainer import SRNNTrainer\n",
+    "import edgeml.pytorch.utils as utils"
    ]
   },
   {
@@ -61,8 +49,23 @@
    "execution_count": 3,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2019-07-14T12:52:56.047992Z",
-     "start_time": "2019-07-14T12:52:56.042445Z"
+     "end_time": "2019-07-17T17:19:40.340949Z",
+     "start_time": "2019-07-17T17:19:40.275892Z"
     }
   },
   "outputs": [],
   "source": [
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
+    "DATA_DIR = './GoogleSpeech/Extracted/'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2019-07-17T17:19:45.603764Z",
+     "start_time": "2019-07-17T17:19:40.343295Z"
+    }
+   },
+   "outputs": [
@@ -77,6 +80,17 @@
     }
    ],
    "source": [
+    "x_train_, y_train = np.load(DATA_DIR + 'x_train.npy'), np.load(DATA_DIR + 'y_train.npy')\n",
+    "x_val_, y_val = np.load(DATA_DIR + 'x_val.npy'), np.load(DATA_DIR + 'y_val.npy')\n",
+    "x_test_, y_test = np.load(DATA_DIR + 'x_test.npy'), np.load(DATA_DIR + 'y_test.npy')\n",
+    "# Mean-var normalize\n",
+    "mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
+    "std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
+    "std[std[:] < 0.000001] = 1\n",
+    "x_train_ = (x_train_ - mean) / std\n",
+    "x_val_ = (x_val_ - mean) / std\n",
+    "x_test_ = (x_test_ - mean) / std\n",
+    "\n",
     "x_train = np.swapaxes(x_train_, 0, 1)\n",
     "x_val = np.swapaxes(x_val_, 0, 1)\n",
     "x_test = np.swapaxes(x_test_, 0, 1)\n",
@@ -87,20 +101,21 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2019-07-14T12:52:56.068329Z",
-     "start_time": "2019-07-14T12:52:56.049725Z"
+     "end_time": "2019-07-17T17:19:45.610070Z",
+     "start_time": "2019-07-17T17:19:45.606282Z"
     }
   },
   "outputs": [],
   "source": [
    "numTimeSteps = x_train.shape[0]\n",
    "numInput = x_train.shape[-1]\n",
-    "brickSize = 11\n",
    "numClasses = y_train.shape[1]\n",
    "\n",
    "# Network Parameters\n",
+    "brickSize = 11\n",
    "hiddenDim0 = 64\n",
    "hiddenDim1 = 32\n",
    "cellType = 'LSTM'\n",
@@ -111,11 +126,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2019-07-14T12:52:56.088534Z",
-     "start_time": "2019-07-14T12:52:56.070114Z"
+     "end_time": "2019-07-17T17:19:49.576076Z",
+     "start_time": "2019-07-17T17:19:45.612629Z"
     }
   },
   "outputs": [
@@ -128,17 +143,17 @@
    }
   ],
   "source": [
-    "srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n",
+    "srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device) \n",
    "trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
-     "end_time": "2019-07-14T12:59:52.893161Z",
-     "start_time": "2019-07-14T12:52:56.090327Z"
+     "end_time": "2019-07-17T17:20:28.680161Z",
+     "start_time": "2019-07-17T17:19:49.578246Z"
    }
   },
   "outputs": [
@@ -146,28 +161,28 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-     "Epoch 0 batch 0 loss 2.799139 acc 0.031250\n",
-     "Epoch 0 batch 200 loss 0.784248 acc 0.750000\n",
-     "Epoch 1 batch 0 loss 0.379059 acc 0.875000\n",
-     "Epoch 1 batch 200 loss 0.544366 acc 0.820312\n",
-     "Epoch 2 batch 0 loss 0.272113 acc 0.914062\n",
-     "Epoch 2 batch 200 loss 0.400919 acc 0.867188\n",
-     "Epoch 3 batch 0 loss 0.200825 acc 0.953125\n",
-     "Epoch 3 batch 200 loss 0.248952 acc 0.906250\n",
-     "Epoch 4 batch 0 loss 0.161245 acc 0.960938\n",
-     "Epoch 4 batch 200 loss 0.294340 acc 0.875000\n",
-     "Validation accuracy: 0.913063\n",
-     "Epoch 5 batch 0 loss 0.159573 acc 0.953125\n",
-     "Epoch 5 batch 200 loss 0.233308 acc 0.937500\n",
-     "Epoch 6 batch 0 loss 0.068345 acc 0.984375\n",
-     "Epoch 6 batch 200 loss 0.225371 acc 0.937500\n",
-     "Epoch 7 batch 0 loss 0.112335 acc 0.968750\n",
-     "Epoch 7 batch 200 loss 0.170626 acc 0.945312\n",
-     "Epoch 8 batch 0 loss 0.168985 acc 0.945312\n",
-     "Epoch 8 batch 200 loss 0.160869 acc 0.937500\n",
-     "Epoch 9 batch 0 loss 0.123516 acc 0.953125\n",
-     "Epoch 9 batch 200 loss 0.172936 acc 0.937500\n",
-     "Validation accuracy: 0.908208\n"
+     "Epoch 0 batch 0 loss 4.295151 acc 0.031250\n",
+     "Epoch 0 batch 200 loss 1.002617 acc 0.718750\n",
+     "Epoch 1 batch 0 loss 0.647069 acc 0.796875\n",
+     "Epoch 1 batch 200 loss 0.469229 acc 0.835938\n",
+     "Epoch 2 batch 0 loss 0.388671 acc 0.882812\n",
+     "Epoch 2 batch 200 loss 0.396696 acc 0.859375\n",
+     "Epoch 3 batch 0 loss 0.266433 acc 0.921875\n",
+     "Epoch 3 batch 200 loss 0.281694 acc 0.867188\n",
+     "Epoch 4 batch 0 loss 0.302240 acc 0.906250\n",
+     "Epoch 4 batch 200 loss 0.245797 acc 0.929688\n",
+     "Validation accuracy: 0.911003\n",
+     "Epoch 5 batch 0 loss 0.202542 acc 0.945312\n",
+     "Epoch 5 batch 200 loss 0.192004 acc 0.929688\n",
+     "Epoch 6 batch 0 loss 0.256735 acc 0.921875\n",
+     "Epoch 6 batch 200 loss 0.279066 acc 0.921875\n",
+     "Epoch 7 batch 0 loss 0.228837 acc 0.945312\n",
+     "Epoch 7 batch 200 loss 0.222357 acc 0.937500\n",
+     "Epoch 8 batch 0 loss 0.164639 acc 0.960938\n",
+     "Epoch 8 batch 200 loss 0.160117 acc 0.945312\n",
+     "Epoch 9 batch 0 loss 0.173849 acc 0.953125\n",
+     "Epoch 9 batch 200 loss 0.201694 acc 0.929688\n",
+     "Validation accuracy: 0.912474\n"
    ]
   }
  ],
@@ -1,19 +1,28 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT license.
 
 from __future__ import print_function
 import sys
 import os
 import numpy as np
 import torch
 
-from pytorch_edgeml.graph.rnn import SRNN2
-from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer
-import pytorch_edgeml.utils as utils
+from edgeml.pytorch.graph.rnn import SRNN2
+from edgeml.pytorch.trainer.srnnTrainer import SRNNTrainer
+import edgeml.pytorch.utils as utils
 import helpermethods as helper
 
+config = helper.getSRNN2Args()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-DATA_DIR = '/datadrive/data/SRNN/GoogleSpeech/Extracted/'
-x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))
-x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))
-x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))
+DATA_DIR = config.data_dir
+x_train_ = np.load(DATA_DIR + 'x_train.npy')
+y_train = np.load(DATA_DIR + 'y_train.npy')
+x_val_ = np.load(DATA_DIR + 'x_val.npy')
+y_val = np.load(DATA_DIR + 'y_val.npy')
+x_test_ = np.load(DATA_DIR + 'x_test.npy')
+y_test = np.load(DATA_DIR + 'y_test.npy')
 
 # Mean-var normalize
 mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)
 std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)
@@ -31,17 +40,20 @@ print("Test shape", x_test.shape, y_test.shape)
 
 numTimeSteps = x_train.shape[0]
 numInput = x_train.shape[-1]
-brickSize = 11
 numClasses = y_train.shape[1]
 
-hiddenDim0 = 64
-hiddenDim1 = 32
-cellType = 'LSTM'
-learningRate = 0.01
-batchSize = 128
-epochs = 10
+hiddenDim0 = config.hidden_dim0
+brickSize = config.brick_size
+hiddenDim1 = config.hidden_dim1
+cellType = config.cell_type
+learningRate = config.learning_rate
+batchSize = config.batch_size
+epochs = config.epochs
+printStep = config.print_step
+valStep = config.val_step
 
-srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)
+srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)
 trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)
 
-trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5)
+trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val,
+              printStep=printStep, valStep=valStep)
pytorch/examples/SRNN/fetch_google.sh → examples/pytorch/SRNN/fetch_google.sh (Executable file → Normal file)
@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse


def getSRNN2Args():
    def checkIntPos(value):
        ivalue = int(value)
        if ivalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive int value" % value)
        return ivalue

    def checkIntNneg(value):
        ivalue = int(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg int value" % value)
        return ivalue

    def checkFloatNneg(value):
        fvalue = float(value)
        if fvalue < 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid non-neg float value" % value)
        return fvalue

    def checkFloatPos(value):
        fvalue = float(value)
        if fvalue <= 0:
            raise argparse.ArgumentTypeError(
                "%s is an invalid positive float value" % value)
        return fvalue

    parser = argparse.ArgumentParser(
        description='Hyperparameters for 2 layer SRNN Algorithm')

    parser.add_argument('-d', '--data-dir', required=True,
                        help='Directory containing processed data.')
    parser.add_argument('-h0', '--hidden-dim0', type=checkIntPos, default=64,
                        help='Hidden dimension of lower layer RNN cell.')
    parser.add_argument('-h1', '--hidden-dim1', type=checkIntPos, default=32,
                        help='Hidden dimension of upper layer RNN cell.')
    parser.add_argument('-bz', '--brick-size', type=checkIntPos, required=True,
                        help='Brick size to be used at the lower layer.')
    parser.add_argument('-c', '--cell-type', default='LSTM',
                        help='Type of RNN cell to use among [LSTM, FastRNN, ' +
                             'FastGRNN]')

    parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20,
                        help='Number of prototypes.')
    parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None,
                        help='Gamma for Gaussian kernel. If not provided, ' +
                             'median heuristic will be used to estimate gamma.')

    parser.add_argument('-e', '--epochs', type=checkIntPos, default=10,
                        help='Total training epochs.')
    parser.add_argument('-b', '--batch-size', type=checkIntPos, default=128,
                        help='Batch size for each pass.')
    parser.add_argument('-r', '--learning-rate', type=checkFloatPos,
                        default=0.01,
                        help='Learning rate for ADAM Optimizer.')

    parser.add_argument('-pS', '--print-step', type=int, default=200,
                        help='The number of update steps between print ' +
                             'calls to console.')
    parser.add_argument('-vS', '--val-step', type=int, default=5,
                        help='The number of epochs between validation ' +
                             'performance evaluation.')
    return parser.parse_args()
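The SRNN driver changes earlier in this diff read hyperparameters off a `config` object (`config.hidden_dim0`, `config.brick_size`, and so on). A plausible glue line, assumed rather than shown in the diff, connecting that object to this parser:

```
# Assumed wiring in the example script: parse the flags defined above and
# expose them as the `config` object the training code reads from.
config = getSRNN2Args()
DATA_DIR = config.data_dir  # '--data-dir' and '--brick-size' are the required flags
```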
@ -4,7 +4,7 @@ This directory includes an example notebook and general execution script of
Bonsai developed as part of EdgeML. Also, we include a sample cleanup and
use-case on the USPS10 public dataset.

`edgeml.graph.bonsai` implements the Bonsai prediction graph in tensorflow.
`edgeml.tf.graph.bonsai` implements the Bonsai prediction graph in tensorflow.
The three-phase training routine for Bonsai is decoupled from the forward graph
to facilitate a plug and play behaviour wherein Bonsai can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs).
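The recurring change in this diff is mechanical: TensorFlow code moves from `edgeml.*` to `edgeml.tf.*`, while the PyTorch implementation gets its own package. For the Bonsai example, the migration amounts to:

```
# Old layout (before this restructuring)
# from edgeml.trainer.bonsaiTrainer import BonsaiTrainer
# from edgeml.graph.bonsai import Bonsai

# New layout: TensorFlow implementations live under edgeml.tf
from edgeml.tf.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml.tf.graph.bonsai import Bonsai
```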
@ -33,8 +33,8 @@
"os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
"\n",
"#Bonsai imports\n",
"from edgeml.trainer.bonsaiTrainer import BonsaiTrainer\n",
"from edgeml.graph.bonsai import Bonsai\n",
"from edgeml.tf.trainer.bonsaiTrainer import BonsaiTrainer\n",
"from edgeml.tf.graph.bonsai import Bonsai\n",
"\n",
"# Fixing seeds for reproducibility\n",
"tf.set_random_seed(42)\n",
@ -5,8 +5,8 @@ import helpermethods
import tensorflow as tf
import numpy as np
import sys
from edgeml.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml.graph.bonsai import Bonsai
from edgeml.tf.trainer.bonsaiTrainer import BonsaiTrainer
from edgeml.tf.graph.bonsai import Bonsai


def main():
@ -40,10 +40,10 @@
"tf.set_random_seed(42)\n",
"\n",
"# MI-RNN and EMI-RNN imports\n",
"from edgeml.graph.rnn import EMI_DataPipeline\n",
"from edgeml.graph.rnn import EMI_BasicLSTM\n",
"from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.utils"
"from edgeml.tf.graph.rnn import EMI_DataPipeline\n",
"from edgeml.tf.graph.rnn import EMI_BasicLSTM\n",
"from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.tf.utils"
]
},
{
@ -34,11 +34,11 @@
"os.environ['CUDA_VISIBLE_DEVICES'] ='1'\n",
"\n",
"# FastGRNN and FastRNN imports\n",
"from edgeml.graph.rnn import EMI_DataPipeline\n",
"from edgeml.graph.rnn import EMI_FastGRNN\n",
"from edgeml.graph.rnn import EMI_FastRNN\n",
"from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.utils"
"from edgeml.tf.graph.rnn import EMI_DataPipeline\n",
"from edgeml.tf.graph.rnn import EMI_FastGRNN\n",
"from edgeml.tf.graph.rnn import EMI_FastRNN\n",
"from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.tf.utils"
]
},
{
@ -35,10 +35,10 @@
"os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n",
"\n",
"# MI-RNN and EMI-RNN imports\n",
"from edgeml.graph.rnn import EMI_DataPipeline\n",
"from edgeml.graph.rnn import EMI_BasicLSTM\n",
"from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.utils"
"from edgeml.tf.graph.rnn import EMI_DataPipeline\n",
"from edgeml.tf.graph.rnn import EMI_BasicLSTM\n",
"from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n",
"import edgeml.tf.utils"
]
},
{
@ -3,7 +3,7 @@
This directory includes example notebooks for EMI-RNN, developed as part of EdgeML.
The example is based on the UCI Human Activity Recognition dataset.

Please refer to `tf/docs/EMI-RNN.md` for detailed documentation of EMI-RNN.
Please refer to `docs/EMI-RNN.md` for detailed documentation of EMI-RNN.

Please refer to `00_emi_lstm_example.ipynb` for a quick and dirty getting
started guide.
tf/examples/EMI-RNN/img/3PartsGraph.png → examples/tf/EMI-RNN/img/3PartsGraph.png
Executable file → Normal file
tf/examples/EMI-RNN/img/MIML_illustration.png → examples/tf/EMI-RNN/img/MIML_illustration.png
Executable file → Normal file
@ -5,7 +5,7 @@ FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified
UGRNN, GRU and LSTM to support the LSQ training routine.
Also, we include a sample cleanup and use-case on the USPS10 public dataset.

`edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
`edgeml.tf.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
multiple additional features like Low-Rank parameterisation, custom
non-linearities etc. Similar to Bonsai and ProtoNN, the three-phase training
routine for FastRNN and FastGRNN is decoupled from the custom cells to
@ -14,9 +14,9 @@ architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `G
`edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)),
**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells can also be substituted for FastCells wherever feasible.

For training FastCells, `edgeml.trainer.fastTrainer` implements the three-phase
For training FastCells, `edgeml.tf.trainer.fastTrainer` implements the three-phase
FastCell training routine in Tensorflow. A simple example,
`examples/fastcell_example.py` is provided to illustrate its usage.
`examples/tf/fastcell_example.py` is provided to illustrate its usage.

Note that `fastcell_example.py` assumes that data is in a specific format. It
is assumed that train and test data is contained in two files, `train.npy` and
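As a concrete reading of the plug-and-play claim above, here is a minimal sketch of dropping `FastGRNNCell` into a plain TF1 recurrent pipeline. The shapes are illustrative, and the low-rank constructor arguments (`wRank`, `uRank`) are assumptions based on the attribute names this codebase uses for its low-rank parameterisation:

```
import tensorflow as tf
from edgeml.tf.graph.rnn import FastGRNNCell

num_timesteps, num_input, hidden_dim = 128, 9, 32  # illustrative shapes

x = tf.placeholder(tf.float32, [None, num_timesteps, num_input])
x_list = tf.unstack(x, axis=1)  # static_rnn expects a per-timestep list

# FastGRNNCell behaves like a standard RNNCell, so it can replace the
# inbuilt cells in static_rnn/dynamic_rnn based architectures.
cell = FastGRNNCell(hidden_dim, wRank=16, uRank=16)
outputs, state = tf.nn.static_rnn(cell, x_list, dtype=tf.float32)
```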
@ -28,12 +28,12 @@
"os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
"\n",
"#FastRNN and FastGRNN imports\n",
"from edgeml.trainer.fastTrainer import FastTrainer\n",
"from edgeml.graph.rnn import FastGRNNCell\n",
"from edgeml.graph.rnn import FastRNNCell\n",
"from edgeml.graph.rnn import UGRNNLRCell\n",
"from edgeml.graph.rnn import GRULRCell\n",
"from edgeml.graph.rnn import LSTMLRCell\n",
"from edgeml.tf.trainer.fastTrainer import FastTrainer\n",
"from edgeml.tf.graph.rnn import FastGRNNCell\n",
"from edgeml.tf.graph.rnn import FastRNNCell\n",
"from edgeml.tf.graph.rnn import UGRNNLRCell\n",
"from edgeml.tf.graph.rnn import GRULRCell\n",
"from edgeml.tf.graph.rnn import LSTMLRCell\n",
"\n",
"# Fixing seeds for reproducibility\n",
"tf.set_random_seed(42)\n",
@ -6,12 +6,12 @@ import tensorflow as tf
import numpy as np
import sys

from edgeml.trainer.fastTrainer import FastTrainer
from edgeml.graph.rnn import FastGRNNCell
from edgeml.graph.rnn import FastRNNCell
from edgeml.graph.rnn import UGRNNLRCell
from edgeml.graph.rnn import GRULRCell
from edgeml.graph.rnn import LSTMLRCell
from edgeml.tf.trainer.fastTrainer import FastTrainer
from edgeml.tf.graph.rnn import FastGRNNCell
from edgeml.tf.graph.rnn import FastRNNCell
from edgeml.tf.graph.rnn import UGRNNLRCell
from edgeml.tf.graph.rnn import GRULRCell
from edgeml.tf.graph.rnn import LSTMLRCell


def main():
@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a
command line execution script of ProtoNN developed as part of EdgeML. The
example is based on the USPS dataset.

`edgeml.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow.
`edgeml.tf.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow.
The training routine for ProtoNN is decoupled from the forward graph to
facilitate a plug and play behaviour wherein ProtoNN can be combined with or
used as a final layer classifier for other architectures (RNNs, CNNs). The
training routine is implemented in `edgeml.trainer.protoNNTrainer`.
training routine is implemented in `edgeml.tf.trainer.protoNNTrainer`.

Note that `protoNN_example.py` assumes the data to be in a specific format. It
is assumed that train and test data is contained in two files, `train.npy` and
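ProtoNN scores points against learned prototypes with a Gaussian kernel, and the argument parser earlier in this diff notes that a median heuristic estimates the kernel's gamma when none is supplied. A generic numpy sketch of that heuristic, illustrative rather than the library's exact implementation:

```
import numpy as np

def median_heuristic_gamma(X, max_samples=1000):
    # Estimate gamma for k(a, b) = exp(-gamma * ||a - b||^2) from the
    # median pairwise distance of a random subsample of the data.
    idx = np.random.choice(len(X), min(max_samples, len(X)), replace=False)
    S = X[idx]
    d = np.sqrt(((S[:, None, :] - S[None, :, :]) ** 2).sum(axis=-1))
    sigma = np.median(d[d > 0])  # median non-zero distance as bandwidth
    return 1.0 / (2.0 * sigma ** 2)
```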
@ -6,7 +6,7 @@ import sys
import os
import numpy as np
import tensorflow as tf
import edgeml.utils as utils
import edgeml.tf.utils as utils
import argparse
@ -29,9 +29,9 @@
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"from edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n",
"from edgeml.graph.protoNN import ProtoNN\n",
"import edgeml.utils as utils\n",
"from edgeml.tf.trainer.protoNNTrainer import ProtoNNTrainer\n",
"from edgeml.tf.graph.protoNN import ProtoNN\n",
"import edgeml.tf.utils as utils\n",
"import helpermethods as helper"
]
},
@ -6,9 +6,9 @@ import sys
import os
import numpy as np
import tensorflow as tf
from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml.graph.protoNN import ProtoNN
import edgeml.utils as utils
from edgeml.tf.trainer.protoNNTrainer import ProtoNNTrainer
from edgeml.tf.graph.protoNN import ProtoNN
import edgeml.tf.utils as utils
import helpermethods as helper
@ -1,46 +0,0 @@
## Edge Machine Learning: PyTorch Library

This directory includes PyTorch implementations of various techniques and
algorithms developed as part of EdgeML. Currently, the following algorithms are
available in PyTorch:

1. [Bonsai](../docs/publications/Bonsai.pdf)
2. [FastRNN & FastGRNN](../docs/publications/FastGRNN.pdf)

The PyTorch compute graphs for these algorithms are packaged as
`pytorch_edgeml.graph`. Trainers for these algorithms are in `pytorch_edgeml.trainer`. Usage
directions and examples for these algorithms are provided in the `examples`
directory. To get started with any of the provided algorithms, please follow
the notebooks in the `examples` directory.

## Installation

Use pip and the provided requirements file to first install required
dependencies before installing the `pytorch_edgeml` library. Details for installation are provided below.

It is highly recommended that EdgeML be installed in a virtual environment. Please create
a new virtual environment using your environment manager ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)).
Make sure the new environment is active before running the commands below.

### CPU

```
pip install -r requirements-cpu.txt
pip install -e .
```

Tested on Python 3.6 with PyTorch 1.1.

### GPU

Install appropriate CUDA and cuDNN [Tested with >= CUDA 9.0 and cuDNN >= 7.0]

```
pip install -r requirements-gpu.txt
pip install -e .
```

Note: If the above commands don't go through for PyTorch installation on CPU and GPU, please follow this [link](https://pytorch.org/get-started/locally/).

Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT license.
@ -1,9 +0,0 @@
from distutils.core import setup

setup(
    name='pytorch_edgeml',
    version='0.2',
    packages=['pytorch_edgeml', ],
    license='MIT License',
    long_description=open('../License.txt').read(),
)
@ -0,0 +1,11 @@
import setuptools  # enables develop mode

setuptools.setup(
    name='edgeml',
    version='0.2',
    description='machine learning algorithms for edge devices developed at Microsoft Research India.',
    packages=setuptools.find_packages(),
    license='MIT License',
    long_description=open('License.txt').read(),
    url='https://github.com/Microsoft/EdgeML',
)
@ -1,9 +0,0 @@
from distutils.core import setup

setup(
    name='edgeml',
    version='0.2',
    packages=['edgeml', ],
    license='MIT License',
    long_description=open('../License.txt').read(),
)