diff --git a/tf/docs/EMI-RNN.md b/docs/EMI-RNN.md similarity index 100% rename from tf/docs/EMI-RNN.md rename to docs/EMI-RNN.md diff --git a/tf/docs/FastCells.md b/docs/FastCells.md similarity index 100% rename from tf/docs/FastCells.md rename to docs/FastCells.md diff --git a/tf/docs/img/3PartsGraph.png b/docs/img/3PartsGraph.png old mode 100755 new mode 100644 similarity index 100% rename from tf/docs/img/3PartsGraph.png rename to docs/img/3PartsGraph.png diff --git a/tf/docs/img/FastGRNN.png b/docs/img/FastGRNN.png similarity index 100% rename from tf/docs/img/FastGRNN.png rename to docs/img/FastGRNN.png diff --git a/tf/docs/img/FastGRNN_eq.png b/docs/img/FastGRNN_eq.png similarity index 100% rename from tf/docs/img/FastGRNN_eq.png rename to docs/img/FastGRNN_eq.png diff --git a/tf/docs/img/FastRNN.png b/docs/img/FastRNN.png similarity index 100% rename from tf/docs/img/FastRNN.png rename to docs/img/FastRNN.png diff --git a/tf/docs/img/FastRNN_eq.png b/docs/img/FastRNN_eq.png similarity index 100% rename from tf/docs/img/FastRNN_eq.png rename to docs/img/FastRNN_eq.png diff --git a/tf/docs/img/MIML_illustration.png b/docs/img/MIML_illustration.png old mode 100755 new mode 100644 similarity index 100% rename from tf/docs/img/MIML_illustration.png rename to docs/img/MIML_illustration.png diff --git a/tf/README.md b/edgeml/README.md similarity index 100% rename from tf/README.md rename to edgeml/README.md diff --git a/edgeml/__init__.py b/edgeml/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/edgeml/pytorch/__init__.py b/edgeml/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/edgeml/pytorch/graph/__init__.py b/edgeml/pytorch/graph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch/pytorch_edgeml/graph/bonsai.py b/edgeml/pytorch/graph/bonsai.py similarity index 100% rename from pytorch/pytorch_edgeml/graph/bonsai.py rename to edgeml/pytorch/graph/bonsai.py diff --git a/pytorch/pytorch_edgeml/graph/protoNN.py b/edgeml/pytorch/graph/protoNN.py similarity index 100% rename from pytorch/pytorch_edgeml/graph/protoNN.py rename to edgeml/pytorch/graph/protoNN.py diff --git a/pytorch/pytorch_edgeml/graph/rnn.py b/edgeml/pytorch/graph/rnn.py similarity index 80% rename from pytorch/pytorch_edgeml/graph/rnn.py rename to edgeml/pytorch/graph/rnn.py index 5dc17cca..d01d589e 100644 --- a/pytorch/pytorch_edgeml/graph/rnn.py +++ b/edgeml/pytorch/graph/rnn.py @@ -3,36 +3,57 @@ import torch import torch.nn as nn +from torch.autograd import Function import numpy as np -def gen_non_linearity(A, non_linearity): +def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity): + class RNNSymbolic(Function): + @staticmethod + def symbolic(g, *fargs): + # NOTE: args/kwargs contain RNN parameters + return g.op("FastGRNN", *fargs, outputs=1, + hidden_size_i=hidden_size, wRank_i=wRank, uRank_i=uRank, + gate_nonlinearity_s=gate_nonlinearity, update_nonlinearity_s=update_nonlinearity) + + @staticmethod + def forward(ctx, *fargs): + return output + + @staticmethod + def backward(ctx, *gargs, **gkwargs): + raise RuntimeError("FIXME: Traced RNNs don't support backward") + + output_temp = RNNSymbolic.apply(input, *fargs) + return output_temp + +def gen_nonlinearity(A, nonlinearity): ''' Returns required activation for a tensor based on the inputs - non_linearity is either a callable or a value in + nonlinearity is either a callable or a value in ['tanh', 'sigmoid', 'relu', 'quantTanh', 'quantSigm', 'quantSigm4'] ''' - if non_linearity == "tanh": + if nonlinearity == "tanh": return torch.tanh(A) - elif non_linearity == "sigmoid": + elif nonlinearity == "sigmoid": return torch.sigmoid(A) - elif non_linearity == "relu": + elif nonlinearity == "relu": return torch.relu(A, 0.0) - elif non_linearity == "quantTanh": + elif nonlinearity == "quantTanh": return torch.max(torch.min(A, torch.ones_like(A)), -1.0 * torch.ones_like(A)) - elif non_linearity == "quantSigm": + elif nonlinearity == "quantSigm": A = (A + 1.0) / 2.0 return torch.max(torch.min(A, torch.ones_like(A)), torch.zeros_like(A)) - elif non_linearity == "quantSigm4": + elif nonlinearity == "quantSigm4": A = (A + 2.0) / 4.0 return torch.max(torch.min(A, torch.ones_like(A)), torch.zeros_like(A)) else: - # non_linearity is a user specified function - if not callable(non_linearity): - raise ValueError("non_linearity is either a callable or a value " + + # nonlinearity is a user specified function + if not callable(nonlinearity): + raise ValueError("nonlinearity is either a callable or a value " + + "['tanh', 'sigmoid', 'relu', 'quantTanh', " + "'quantSigm'") - return non_linearity(A) + return nonlinearity(A) class BaseRNN(nn.Module): @@ -49,6 +70,9 @@ class BaseRNN(nn.Module): self.RNNCell = RNNCell self.batch_first = batch_first + def getVars(self): + return self.RNNCell.getVars() + def forward(self, input, hiddenState=None, cellState=None): if self.batch_first is True: @@ -111,9 +135,9 @@ class FastGRNNCell(nn.Module): Has multiple activation functions for the gates hidden_size = # hidden units - gate_non_linearity = nonlinearity for the gate can be chosen from + gate_nonlinearity = nonlinearity for the gate can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] - update_non_linearity = nonlinearity for final rnn update + update_nonlinearity = nonlinearity for final rnn update can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] wRank = rank of W matrix (creates two matrices if not None) @@ -134,15 +158,15 @@ class FastGRNNCell(nn.Module): W = matmul(W_1, W_2) and U = matmul(U_1, U_2) ''' - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, name="FastGRNN"): super(FastGRNNCell, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._num_weight_matrices = [1, 1] self._wRank = wRank self._uRank = uRank @@ -155,17 +179,17 @@ class FastGRNNCell(nn.Module): self._name = name if wRank is None: - self.W = nn.Parameter(0.1 * torch.randn([input_size, hidden_size])) + self.W = nn.Parameter(0.1 * torch.randn([hidden_size, input_size])) else: - self.W1 = nn.Parameter(0.1 * torch.randn([input_size, wRank])) - self.W2 = nn.Parameter(0.1 * torch.randn([wRank, hidden_size])) + self.W1 = nn.Parameter(0.1 * torch.randn([wRank, input_size])) + self.W2 = nn.Parameter(0.1 * torch.randn([hidden_size, wRank])) if uRank is None: self.U = nn.Parameter( 0.1 * torch.randn([hidden_size, hidden_size])) else: - self.U1 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank])) - self.U2 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size])) + self.U1 = nn.Parameter(0.1 * torch.randn([uRank, hidden_size])) + self.U2 = nn.Parameter(0.1 * torch.randn([hidden_size, uRank])) self.bias_gate = nn.Parameter(torch.ones([1, hidden_size])) self.bias_update = nn.Parameter(torch.ones([1, hidden_size])) @@ -185,12 +209,12 @@ class FastGRNNCell(nn.Module): return self._hidden_size @property - def gate_non_linearity(self): - return self._gate_non_linearity + def gate_nonlinearity(self): + return self._gate_nonlinearity @property - def update_non_linearity(self): - return self._update_non_linearity + def update_nonlinearity(self): + return self._update_nonlinearity @property def wRank(self): @@ -214,23 +238,23 @@ class FastGRNNCell(nn.Module): def forward(self, input, state): if self._wRank is None: - wComp = torch.matmul(input, self.W) + wComp = torch.matmul(input, torch.transpose(self.W, 0, 1)) else: wComp = torch.matmul( - torch.matmul(input, self.W1), self.W2) + torch.matmul(input, torch.transpose(self.W1, 0, 1)), torch.transpose(self.W2, 0, 1)) if self._uRank is None: - uComp = torch.matmul(state, self.U) + uComp = torch.matmul(state, torch.transpose(self.U, 0, 1)) else: uComp = torch.matmul( - torch.matmul(state, self.U1), self.U2) + torch.matmul(state, torch.transpose(self.U1, 0, 1)), torch.transpose(self.U2, 0, 1)) pre_comp = wComp + uComp - z = gen_non_linearity(pre_comp + self.bias_gate, - self._gate_non_linearity) - c = gen_non_linearity(pre_comp + self.bias_update, - self._update_non_linearity) + z = gen_nonlinearity(pre_comp + self.bias_gate, + self._gate_nonlinearity) + c = gen_nonlinearity(pre_comp + self.bias_update, + self._update_nonlinearity) new_h = z * state + (torch.sigmoid(self.zeta) * (1.0 - z) + torch.sigmoid(self.nu)) * c @@ -260,7 +284,7 @@ class FastRNNCell(nn.Module): Has multiple activation functions for the gates hidden_size = # hidden units - update_non_linearity = nonlinearity for final rnn update + update_nonlinearity = nonlinearity for final rnn update can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] wRank = rank of W matrix (creates two matrices if not None) @@ -281,13 +305,13 @@ class FastRNNCell(nn.Module): ''' def __init__(self, input_size, hidden_size, - update_non_linearity="tanh", wRank=None, uRank=None, + update_nonlinearity="tanh", wRank=None, uRank=None, alphaInit=-3.0, betaInit=3.0, name="FastRNN"): super(FastRNNCell, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._update_non_linearity = update_non_linearity + self._update_nonlinearity = update_nonlinearity self._num_weight_matrices = [1, 1] self._wRank = wRank self._uRank = uRank @@ -329,8 +353,8 @@ class FastRNNCell(nn.Module): return self._hidden_size @property - def update_non_linearity(self): - return self._update_non_linearity + def update_nonlinearity(self): + return self._update_nonlinearity @property def wRank(self): @@ -367,8 +391,8 @@ class FastRNNCell(nn.Module): pre_comp = wComp + uComp - c = gen_non_linearity(pre_comp + self.bias_update, - self._update_non_linearity) + c = gen_nonlinearity(pre_comp + self.bias_update, + self._update_nonlinearity) new_h = torch.sigmoid(self.beta) * state + \ torch.sigmoid(self.alpha) * c @@ -399,9 +423,9 @@ class LSTMLRCell(nn.Module): Has multiple activation functions for the gates hidden_size = # hidden units - gate_non_linearity = nonlinearity for the gate can be chosen from + gate_nonlinearity = nonlinearity for the gate can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] - update_non_linearity = nonlinearity for final rnn update + update_nonlinearity = nonlinearity for final rnn update can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] wRank = rank of all W matrices @@ -425,15 +449,15 @@ class LSTMLRCell(nn.Module): Wi = matmul(W, W_i) and Ui = matmul(U, U_i) ''' - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, name="LSTMLR"): super(LSTMLRCell, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._num_weight_matrices = [4, 4] self._wRank = wRank self._uRank = uRank @@ -493,12 +517,12 @@ class LSTMLRCell(nn.Module): return self._hidden_size @property - def gate_non_linearity(self): - return self._gate_non_linearity + def gate_nonlinearity(self): + return self._gate_nonlinearity @property - def update_non_linearity(self): - return self._update_non_linearity + def update_nonlinearity(self): + return self._update_nonlinearity @property def wRank(self): @@ -557,18 +581,18 @@ class LSTMLRCell(nn.Module): pre_comp3 = wComp3 + uComp3 pre_comp4 = wComp4 + uComp4 - i = gen_non_linearity(pre_comp1 + self.bias_i, - self._gate_non_linearity) - f = gen_non_linearity(pre_comp2 + self.bias_f, - self._gate_non_linearity) - o = gen_non_linearity(pre_comp4 + self.bias_o, - self._gate_non_linearity) + i = gen_nonlinearity(pre_comp1 + self.bias_i, + self._gate_nonlinearity) + f = gen_nonlinearity(pre_comp2 + self.bias_f, + self._gate_nonlinearity) + o = gen_nonlinearity(pre_comp4 + self.bias_o, + self._gate_nonlinearity) - c_ = gen_non_linearity(pre_comp3 + self.bias_c, - self._update_non_linearity) + c_ = gen_nonlinearity(pre_comp3 + self.bias_c, + self._update_nonlinearity) new_c = f * c + i * c_ - new_h = o * gen_non_linearity(new_c, self._update_non_linearity) + new_h = o * gen_nonlinearity(new_c, self._update_nonlinearity) return new_h, new_c def getVars(self): @@ -594,9 +618,9 @@ class GRULRCell(nn.Module): Has multiple activation functions for the gates hidden_size = # hidden units - gate_non_linearity = nonlinearity for the gate can be chosen from + gate_nonlinearity = nonlinearity for the gate can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] - update_non_linearity = nonlinearity for final rnn update + update_nonlinearity = nonlinearity for final rnn update can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] wRank = rank of W matrix @@ -618,15 +642,15 @@ class GRULRCell(nn.Module): Wi = matmul(W, W_i) and Ui = matmul(U, U_i) ''' - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, name="GRULR"): super(GRULRCell, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._num_weight_matrices = [3, 3] self._wRank = wRank self._uRank = uRank @@ -680,12 +704,12 @@ class GRULRCell(nn.Module): return self._hidden_size @property - def gate_non_linearity(self): - return self._gate_non_linearity + def gate_nonlinearity(self): + return self._gate_nonlinearity @property - def update_non_linearity(self): - return self._update_non_linearity + def update_nonlinearity(self): + return self._update_nonlinearity @property def wRank(self): @@ -732,10 +756,10 @@ class GRULRCell(nn.Module): pre_comp1 = wComp1 + uComp1 pre_comp2 = wComp2 + uComp2 - r = gen_non_linearity(pre_comp1 + self.bias_r, - self._gate_non_linearity) - z = gen_non_linearity(pre_comp2 + self.bias_gate, - self._gate_non_linearity) + r = gen_nonlinearity(pre_comp1 + self.bias_r, + self._gate_nonlinearity) + z = gen_nonlinearity(pre_comp2 + self.bias_gate, + self._gate_nonlinearity) if self._uRank is None: pre_comp3 = wComp3 + torch.matmul(r * state, self.U3) @@ -743,8 +767,8 @@ class GRULRCell(nn.Module): pre_comp3 = wComp3 + \ torch.matmul(torch.matmul(r * state, self.U), self.U3) - c = gen_non_linearity(pre_comp3 + self.bias_update, - self._update_non_linearity) + c = gen_nonlinearity(pre_comp3 + self.bias_update, + self._update_nonlinearity) new_h = z * state + (1.0 - z) * c return new_h @@ -772,9 +796,9 @@ class UGRNNLRCell(nn.Module): Has multiple activation functions for the gates hidden_size = # hidden units - gate_non_linearity = nonlinearity for the gate can be chosen from + gate_nonlinearity = nonlinearity for the gate can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] - update_non_linearity = nonlinearity for final rnn update + update_nonlinearity = nonlinearity for final rnn update can be chosen from [tanh, sigmoid, relu, quantTanh, quantSigm] wRank = rank of W matrix @@ -795,15 +819,15 @@ class UGRNNLRCell(nn.Module): Wi = matmul(W, W_i) and Ui = matmul(U, U_i) ''' - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, name="UGRNNLR"): super(UGRNNLRCell, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._num_weight_matrices = [2, 2] self._wRank = wRank self._uRank = uRank @@ -850,12 +874,12 @@ class UGRNNLRCell(nn.Module): return self._hidden_size @property - def gate_non_linearity(self): - return self._gate_non_linearity + def gate_nonlinearity(self): + return self._gate_nonlinearity @property - def update_non_linearity(self): - return self._update_non_linearity + def update_nonlinearity(self): + return self._update_nonlinearity @property def wRank(self): @@ -899,10 +923,10 @@ class UGRNNLRCell(nn.Module): pre_comp1 = wComp1 + uComp1 pre_comp2 = wComp2 + uComp2 - z = gen_non_linearity(pre_comp1 + self.bias_gate, - self._gate_non_linearity) - c = gen_non_linearity(pre_comp2 + self.bias_update, - self._update_non_linearity) + z = gen_nonlinearity(pre_comp1 + self.bias_gate, + self._gate_nonlinearity) + c = gen_nonlinearity(pre_comp2 + self.bias_update, + self._update_nonlinearity) new_h = z * state + (1.0 - z) * c return new_h @@ -927,20 +951,20 @@ class UGRNNLRCell(nn.Module): class LSTM(nn.Module): """Equivalent to nn.LSTM using LSTMLRCell""" - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True): + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True): super(LSTM, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._wRank = wRank self._uRank = uRank self.batch_first = batch_first self.cell = LSTMLRCell(input_size, hidden_size, - gate_non_linearity=gate_non_linearity, - update_non_linearity=update_non_linearity, + gate_nonlinearity=gate_nonlinearity, + update_nonlinearity=update_nonlinearity, wRank=wRank, uRank=uRank) self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first) @@ -951,20 +975,20 @@ class LSTM(nn.Module): class GRU(nn.Module): """Equivalent to nn.GRU using GRULRCell""" - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True): + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True): super(GRU, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._wRank = wRank self._uRank = uRank self.batch_first = batch_first self.cell = GRULRCell(input_size, hidden_size, - gate_non_linearity=gate_non_linearity, - update_non_linearity=update_non_linearity, + gate_nonlinearity=gate_nonlinearity, + update_nonlinearity=update_nonlinearity, wRank=wRank, uRank=uRank) self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first) @@ -975,20 +999,20 @@ class GRU(nn.Module): class UGRNN(nn.Module): """Equivalent to nn.UGRNN using UGRNNLRCell""" - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, batch_first=True): + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, batch_first=True): super(UGRNN, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._wRank = wRank self._uRank = uRank self.batch_first = batch_first self.cell = UGRNNLRCell(input_size, hidden_size, - gate_non_linearity=gate_non_linearity, - update_non_linearity=update_non_linearity, + gate_nonlinearity=gate_nonlinearity, + update_nonlinearity=update_nonlinearity, wRank=wRank, uRank=uRank) self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first) @@ -999,21 +1023,21 @@ class UGRNN(nn.Module): class FastRNN(nn.Module): """Equivalent to nn.FastRNN using FastRNNCell""" - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, alphaInit=-3.0, betaInit=3.0, batch_first=True): super(FastRNN, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._wRank = wRank self._uRank = uRank self.batch_first = batch_first self.cell = FastRNNCell(input_size, hidden_size, - gate_non_linearity=gate_non_linearity, - update_non_linearity=update_non_linearity, + gate_nonlinearity=gate_nonlinearity, + update_nonlinearity=update_nonlinearity, wRank=wRank, uRank=uRank, alphaInit=alphaInit, betaInit=betaInit) self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first) @@ -1025,25 +1049,28 @@ class FastRNN(nn.Module): class FastGRNN(nn.Module): """Equivalent to nn.FastGRNN using FastGRNNCell""" - def __init__(self, input_size, hidden_size, gate_non_linearity="sigmoid", - update_non_linearity="tanh", wRank=None, uRank=None, + def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid", + update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, batch_first=True): super(FastGRNN, self).__init__() self._input_size = input_size self._hidden_size = hidden_size - self._gate_non_linearity = gate_non_linearity - self._update_non_linearity = update_non_linearity + self._gate_nonlinearity = gate_nonlinearity + self._update_nonlinearity = update_nonlinearity self._wRank = wRank self._uRank = uRank self.batch_first = batch_first self.cell = FastGRNNCell(input_size, hidden_size, - gate_non_linearity=gate_non_linearity, - update_non_linearity=update_non_linearity, + gate_nonlinearity=gate_nonlinearity, + update_nonlinearity=update_nonlinearity, wRank=wRank, uRank=uRank, zetaInit=zetaInit, nuInit=nuInit) self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first) + def getVars(self): + return self.unrollRNN.getVars() + def forward(self, input, hiddenState=None, cellState=None): return self.unrollRNN(input, hiddenState, cellState) diff --git a/edgeml/pytorch/trainer/__init__.py b/edgeml/pytorch/trainer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch/pytorch_edgeml/trainer/bonsaiTrainer.py b/edgeml/pytorch/trainer/bonsaiTrainer.py similarity index 99% rename from pytorch/pytorch_edgeml/trainer/bonsaiTrainer.py rename to edgeml/pytorch/trainer/bonsaiTrainer.py index 169b9669..76c6cc85 100644 --- a/pytorch/pytorch_edgeml/trainer/bonsaiTrainer.py +++ b/edgeml/pytorch/trainer/bonsaiTrainer.py @@ -5,7 +5,7 @@ import torch import numpy as np import os import sys -import pytorch_edgeml.utils as utils +import edgeml.pytorch.utils as utils class BonsaiTrainer: diff --git a/pytorch/pytorch_edgeml/trainer/fastTrainer.py b/edgeml/pytorch/trainer/fastTrainer.py similarity index 99% rename from pytorch/pytorch_edgeml/trainer/fastTrainer.py rename to edgeml/pytorch/trainer/fastTrainer.py index b44dd19b..efe6698b 100644 --- a/pytorch/pytorch_edgeml/trainer/fastTrainer.py +++ b/edgeml/pytorch/trainer/fastTrainer.py @@ -5,8 +5,8 @@ import os import sys import torch import torch.nn as nn -import pytorch_edgeml.utils as utils -from pytorch_edgeml.graph.rnn import * +import edgeml.pytorch.utils as utils +from edgeml.pytorch.graph.rnn import * import numpy as np diff --git a/edgeml/pytorch/trainer/fastmodel.py b/edgeml/pytorch/trainer/fastmodel.py new file mode 100644 index 00000000..c6fff580 --- /dev/null +++ b/edgeml/pytorch/trainer/fastmodel.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import sys +from edgeml.pytorch.graph.rnn import * + +def fastgrnnmodel(inheritance_class=nn.Module): + class FastGRNNModel(inheritance_class): + """This class is a PyTorch Module that implements a 1, 2 or 3 layer GRU based audio classifier""" + + def __init__(self, input_dim, num_layers, hidden_units_list, wRank_list, uRank_list, gate_nonlinearity, update_nonlinearity, num_classes=None, linear=True, batch_first=False, apply_softmax=True): + """ + Initialize the KeywordSpotter with the following parameters: + input_dim - the size of the input audio frame in # samples. + hidden_units - the size of the hidden state of the FastGrnn nodes + num_keywords - the number of predictions to come out of the model. + num_layers - the number of FastGrnn layers to use (1, 2 or 3) + """ + self.input_dim = input_dim + self.hidden_units_list = hidden_units_list + self.num_layers = num_layers + self.num_classes = num_classes + self.wRank_list = wRank_list + self.uRank_list = uRank_list + self.gate_nonlinearity = gate_nonlinearity + self.update_nonlinearity = update_nonlinearity + self.linear = linear + self.batch_first = batch_first + self.apply_softmax = apply_softmax + if self.linear: + if not self.num_classes: + raise Exception("num_classes need to be specified if linear is True") + + super(FastGRNNModel, self).__init__() + + # The FastGRNN takes audio sequences as input, and outputs hidden states + # with dimensionality hidden_units. + self.fastgrnn1 = FastGRNN(self.input_dim, self.hidden_units_list[0], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity, + wRank=self.wRank_list[0], uRank=self.uRank_list[0], + batch_first = self.batch_first) + self.fastgrnn2 = None + last_output_size = self.hidden_units_list[0] + if self.num_layers > 1: + self.fastgrnn2 = FastGRNN(self.hidden_units_list[0], self.hidden_units_list[1], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity, + wRank=self.wRank_list[1], uRank=self.uRank_list[1], + batch_first = self.batch_first) + last_output_size = self.hidden_units_list[1] + self.fastgrnn3 = None + if self.num_layers > 2: + self.fastgrnn3 = FastGRNN(self.hidden_units_list[1], self.hidden_units_list[2], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity, + wRank=self.wRank_list[2], uRank=self.uRank_list[2], + batch_first = self.batch_first) + last_output_size = self.hidden_units_list[2] + + # The linear layer is a fully connected layer that maps from hidden state space + # to number of expected keywords + if self.linear: + self.hidden2keyword = nn.Linear(last_output_size, num_classes) + self.init_hidden() + + def normalize(self, mean, std): + self.mean = mean + self.std = std + + def name(self): + return "{} layer FastGRNN".format(self.num_layers) + + def init_hidden_bag(self, hidden_bag_size, device): + self.hidden_bag_size = hidden_bag_size + self.device = device + self.hidden1_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[0]], + dtype=np.float32)).to(self.device) + if self.num_layers >= 2: + self.hidden2_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[1]], + dtype=np.float32)).to(self.device) + if self.num_layers == 3: + self.hidden3_bag = torch.from_numpy(np.zeros([self.hidden_bag_size, self.hidden_units_list[2]], + dtype=np.float32)).to(self.device) + + def rolling_step(self): + shuffled_indices = list(range(self.hidden_bag_size)) + np.random.shuffle(shuffled_indices) + if self.hidden1 is not None: + batch_size = self.hidden1.shape[0] + temp_indices = shuffled_indices[:batch_size] + self.hidden1_bag[temp_indices, :] = self.hidden1 + self.hidden1 = self.hidden1_bag[0:batch_size, :] + if self.num_layers >= 2: + self.hidden2_bag[temp_indices, :] = self.hidden2 + self.hidden2 = self.hidden2_bag[0:batch_size, :] + if self.num_layers == 3: + self.hidden3_bag[temp_indices, :] = self.hidden3 + self.hidden3 = self.hidden3_bag[0:batch_size, :] + + def init_hidden(self): + """ Clear the hidden state for the GRU nodes """ + self.hidden1 = None + self.hidden2 = None + self.hidden3 = None + + def forward(self, input): + """ Perform the forward processing of the given input and return the prediction """ + # input is shape: [seq,batch,feature] + if self.mean is not None: + input = (input - self.mean) / self.std + fastgrnn_out1 = self.fastgrnn1(input, hiddenState=self.hidden1) + fastgrnn_output = fastgrnn_out1 + # we have to detach the hidden states because we may keep them longer than 1 iteration. + self.hidden1 = fastgrnn_out1.detach()[-1, :, :] + if self.tracking: + weights = self.fastgrnn1.getVars() + fastgrnn_out1 = onnx_exportable_fastgrnn(input, weights, + output=fastgrnn_out1, hidden_size=self.hidden_units_list[0], + wRank=self.wRank_list[0], uRank=self.uRank_list[0], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity) + fastgrnn_output = fastgrnn_out1 + if self.fastgrnn2 is not None: + fastgrnn_out2 = self.fastgrnn2(fastgrnn_out1, hiddenState=self.hidden2) + self.hidden2 = fastgrnn_out2.detach()[-1, :, :] + if self.tracking: + weights = self.fastgrnn2.getVars() + fastgrnn_out2 = onnx_exportable_fastgrnn(fastgrnn_out1, weights, + output=fastgrnn_out2, hidden_size=self.hidden_units_list[1], + wRank=self.wRank_list[1], uRank=self.uRank_list[1], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity) + fastgrnn_output = fastgrnn_out2 + if self.fastgrnn3 is not None: + fastgrnn_out3 = self.fastgrnn3(fastgrnn_out2, hiddenState=self.hidden3) + self.hidden3 = fastgrnn_out3.detach()[-1, :, :] + if self.tracking: + weights = self.fastgrnn3.getVars() + fastgrnn_out3 = onnx_exportable_fastgrnn(fastgrnn_out2, weights, + output=fastgrnn_out3, hidden_size=self.hidden_units_list[2], + wRank=self.wRank_list[2], uRank=self.uRank_list[2], + gate_nonlinearity=self.gate_nonlinearity, + update_nonlinearity=self.update_nonlinearity) + fastgrnn_output = fastgrnn_out3 + if self.linear: + fastgrnn_output = self.hidden2keyword(fastgrnn_output[-1, :, :]) + if self.apply_softmax: + fastgrnn_output = F.log_softmax(fastgrnn_output, dim=1) + return fastgrnn_output + return FastGRNNModel \ No newline at end of file diff --git a/pytorch/pytorch_edgeml/trainer/protoNNTrainer.py b/edgeml/pytorch/trainer/protoNNTrainer.py similarity index 99% rename from pytorch/pytorch_edgeml/trainer/protoNNTrainer.py rename to edgeml/pytorch/trainer/protoNNTrainer.py index 80e52cf6..e49fa633 100644 --- a/pytorch/pytorch_edgeml/trainer/protoNNTrainer.py +++ b/edgeml/pytorch/trainer/protoNNTrainer.py @@ -5,7 +5,7 @@ import torch import numpy as np import os import sys -import pytorch_edgeml.utils as utils +import edgeml.pytorch.utils as utils class ProtoNNTrainer: diff --git a/pytorch/pytorch_edgeml/trainer/srnnTrainer.py b/edgeml/pytorch/trainer/srnnTrainer.py similarity index 99% rename from pytorch/pytorch_edgeml/trainer/srnnTrainer.py rename to edgeml/pytorch/trainer/srnnTrainer.py index 2e96ac0b..06b2c914 100644 --- a/pytorch/pytorch_edgeml/trainer/srnnTrainer.py +++ b/edgeml/pytorch/trainer/srnnTrainer.py @@ -5,7 +5,7 @@ import torch import numpy as np import os import sys -import pytorch_edgeml.utils as utils +import edgeml.pytorch.utils as utils class SRNNTrainer: diff --git a/pytorch/pytorch_edgeml/utils.py b/edgeml/pytorch/utils.py similarity index 100% rename from pytorch/pytorch_edgeml/utils.py rename to edgeml/pytorch/utils.py diff --git a/tf/edgeml/__init__.py b/edgeml/tf/__init__.py similarity index 100% rename from tf/edgeml/__init__.py rename to edgeml/tf/__init__.py diff --git a/tf/edgeml/graph/__init__.py b/edgeml/tf/graph/__init__.py similarity index 100% rename from tf/edgeml/graph/__init__.py rename to edgeml/tf/graph/__init__.py diff --git a/tf/edgeml/graph/bonsai.py b/edgeml/tf/graph/bonsai.py similarity index 100% rename from tf/edgeml/graph/bonsai.py rename to edgeml/tf/graph/bonsai.py diff --git a/tf/edgeml/graph/protoNN.py b/edgeml/tf/graph/protoNN.py similarity index 100% rename from tf/edgeml/graph/protoNN.py rename to edgeml/tf/graph/protoNN.py diff --git a/tf/edgeml/graph/rnn.py b/edgeml/tf/graph/rnn.py similarity index 100% rename from tf/edgeml/graph/rnn.py rename to edgeml/tf/graph/rnn.py diff --git a/tf/edgeml/trainer/__init__.py b/edgeml/tf/trainer/__init__.py similarity index 100% rename from tf/edgeml/trainer/__init__.py rename to edgeml/tf/trainer/__init__.py diff --git a/tf/edgeml/trainer/bonsaiTrainer.py b/edgeml/tf/trainer/bonsaiTrainer.py similarity index 99% rename from tf/edgeml/trainer/bonsaiTrainer.py rename to edgeml/tf/trainer/bonsaiTrainer.py index 93d28760..1d32a38f 100644 --- a/tf/edgeml/trainer/bonsaiTrainer.py +++ b/edgeml/tf/trainer/bonsaiTrainer.py @@ -3,7 +3,7 @@ from __future__ import print_function import tensorflow as tf -import edgeml.utils as utils +import edgeml.tf.utils as utils import numpy as np import os import sys diff --git a/tf/edgeml/trainer/emirnnTrainer.py b/edgeml/tf/trainer/emirnnTrainer.py similarity index 99% rename from tf/edgeml/trainer/emirnnTrainer.py rename to edgeml/tf/trainer/emirnnTrainer.py index 446fe336..6ea59b03 100644 --- a/tf/edgeml/trainer/emirnnTrainer.py +++ b/edgeml/tf/trainer/emirnnTrainer.py @@ -5,7 +5,7 @@ from __future__ import print_function import tensorflow as tf import numpy as np import sys -import edgeml.utils as utils +import edgeml.tf.utils as utils import pandas as pd class EMI_Trainer: diff --git a/tf/edgeml/trainer/fastTrainer.py b/edgeml/tf/trainer/fastTrainer.py similarity index 99% rename from tf/edgeml/trainer/fastTrainer.py rename to edgeml/tf/trainer/fastTrainer.py index 37561b9e..f1615c64 100644 --- a/tf/edgeml/trainer/fastTrainer.py +++ b/edgeml/tf/trainer/fastTrainer.py @@ -5,7 +5,7 @@ from __future__ import print_function import os import sys import tensorflow as tf -import edgeml.utils as utils +import edgeml.tf.utils as utils import numpy as np diff --git a/tf/edgeml/trainer/protoNNTrainer.py b/edgeml/tf/trainer/protoNNTrainer.py similarity index 99% rename from tf/edgeml/trainer/protoNNTrainer.py rename to edgeml/tf/trainer/protoNNTrainer.py index bf23121a..9881f82b 100644 --- a/tf/edgeml/trainer/protoNNTrainer.py +++ b/edgeml/tf/trainer/protoNNTrainer.py @@ -5,7 +5,7 @@ from __future__ import print_function import tensorflow as tf import numpy as np import sys -import edgeml.utils as utils +import edgeml.tf.utils as utils class ProtoNNTrainer: diff --git a/tf/edgeml/utils.py b/edgeml/tf/utils.py similarity index 100% rename from tf/edgeml/utils.py rename to edgeml/tf/utils.py diff --git a/pytorch/examples/Bonsai/README.md b/examples/pytorch/Bonsai/README.md similarity index 97% rename from pytorch/examples/Bonsai/README.md rename to examples/pytorch/Bonsai/README.md index afc46994..c388178a 100644 --- a/pytorch/examples/Bonsai/README.md +++ b/examples/pytorch/Bonsai/README.md @@ -4,7 +4,7 @@ This directory includes, example notebook and general execution script of Bonsai developed as part of EdgeML. Also, we include a sample cleanup and use-case on the USPS10 public dataset. -`pytorch_edgeml.graph.bonsai` implements the Bonsai prediction graph in pytorch. +`edgeml.pytorch.graph.bonsai` implements the Bonsai prediction graph in pytorch. The three-phase training routine for Bonsai is decoupled from the forward graph to facilitate a plug and play behaviour wherein Bonsai can be combined with or used as a final layer classifier for other architectures (RNNs, CNNs). diff --git a/pytorch/examples/Bonsai/bonsai_example.py b/examples/pytorch/Bonsai/bonsai_example.py similarity index 95% rename from pytorch/examples/Bonsai/bonsai_example.py rename to examples/pytorch/Bonsai/bonsai_example.py index e7425265..33d98285 100644 --- a/pytorch/examples/Bonsai/bonsai_example.py +++ b/examples/pytorch/Bonsai/bonsai_example.py @@ -4,8 +4,8 @@ import helpermethods import numpy as np import sys -from pytorch_edgeml.trainer.bonsaiTrainer import BonsaiTrainer -from pytorch_edgeml.graph.bonsai import Bonsai +from edgeml.pytorch.trainer.bonsaiTrainer import BonsaiTrainer +from edgeml.pytorch.graph.bonsai import Bonsai import torch diff --git a/pytorch/examples/Bonsai/fetch_usps.py b/examples/pytorch/Bonsai/fetch_usps.py similarity index 100% rename from pytorch/examples/Bonsai/fetch_usps.py rename to examples/pytorch/Bonsai/fetch_usps.py diff --git a/pytorch/examples/Bonsai/helpermethods.py b/examples/pytorch/Bonsai/helpermethods.py similarity index 100% rename from pytorch/examples/Bonsai/helpermethods.py rename to examples/pytorch/Bonsai/helpermethods.py diff --git a/pytorch/examples/Bonsai/process_usps.py b/examples/pytorch/Bonsai/process_usps.py similarity index 100% rename from pytorch/examples/Bonsai/process_usps.py rename to examples/pytorch/Bonsai/process_usps.py diff --git a/pytorch/examples/Bonsai/quantizeBonsaiModels.py b/examples/pytorch/Bonsai/quantizeBonsaiModels.py similarity index 100% rename from pytorch/examples/Bonsai/quantizeBonsaiModels.py rename to examples/pytorch/Bonsai/quantizeBonsaiModels.py diff --git a/pytorch/examples/FastCells/README.md b/examples/pytorch/FastCells/README.md similarity index 92% rename from pytorch/examples/FastCells/README.md rename to examples/pytorch/FastCells/README.md index 08875f25..99c39d60 100644 --- a/pytorch/examples/FastCells/README.md +++ b/examples/pytorch/FastCells/README.md @@ -5,21 +5,21 @@ FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified UGRNN, GRU and LSTM to support the LSQ training routine. Also, we include a sample cleanup and use-case on the USPS10 public dataset. -`pytorch_edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../pytorch_edgeml/graph/rnn.py#L226)) and **FastGRNN** ([`FastGRNNCell`](../../pytorch_edgeml/graph/rnn.py#L80)) with +`edgeml.pytorch.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../pytorch_edgeml/graph/rnn.py#L226)) and **FastGRNN** ([`FastGRNNCell`](../../pytorch_edgeml/graph/rnn.py#L80)) with multiple additional features like Low-Rank parameterisation, custom non-linearities etc., Similar to Bonsai and ProtoNN, the three-phase training routine for FastRNN and FastGRNN is decoupled from the custom cells to facilitate a plug and play behaviour of the custom RNN cells in other architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell` etc., -`pytorch_edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../pytorch_edgeml/graph/rnn.py#L742)), +`edgeml.pytorch.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../pytorch_edgeml/graph/rnn.py#L742)), **GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L565)) and **LSTM** ([`LSTMLRCell`](../../pytorch_edgeml/graph/rnn.py#L369)). These cells also can be substituted for FastCells where ever feasible. -`pytorch_edgeml.graph.rnn` also contains fully wrapped RNNs which are equivalent to `nn.LSTM` and `nn.GRU`. Implemented cells: -**FastRNN** ([`FastRNN`](../../pytorch_edgeml/graph/rnn.py#L968)), **FastGRNN** ([`FastGRNN`](../../pytorch_edgeml/graph/rnn.py#L993)), **UGRNN** ([`UGRNN`](../../pytorch_edgeml/graph/rnn.py#L945)), **GRU** ([`GRU`](../../edgeml/graph/rnn.py#L922)) and **LSTM** ([`LSTM`](../../pytorch_edgeml/graph/rnn.py#L899)). +`edgeml.pytorch.graph.rnn` also contains fully wrapped RNNs which are equivalent to `nn.LSTM` and `nn.GRU`. Implemented cells: +**FastRNN** ([`FastRNN`](../../pytorch_edgeml/graph/rnn.py#L968)), **FastGRNN** ([`FastGRNN`](../../pytorch_edgeml/graph/rnn.py#L993)), **UGRNN** ([`UGRNN`](../../edgeml.pytorch/graph/rnn.py#L945)), **GRU** ([`GRU`](../../edgeml/graph/rnn.py#L922)) and **LSTM** ([`LSTM`](../../pytorch_edgeml/graph/rnn.py#L899)). -Note that all the cells and wrappers (when used independently from `fastcell_example.py` or `pytorch_edgeml.trainer.fastTrainer`) take in data in a batch first format ie., [batchSize, timeSteps, inputDims] by default but it can also support [timeSteps, batchSize, inputDims] format by setting `batch_first` argument to False when used. `fast_example.py` automatically takes care it while assuming the standard format between tf, c++ and pytorch. +Note that all the cells and wrappers (when used independently from `fastcell_example.py` or `edgeml.pytorch.trainer.fastTrainer`) take in data in a batch first format ie., [batchSize, timeSteps, inputDims] by default but it can also support [timeSteps, batchSize, inputDims] format by setting `batch_first` argument to False when used. `fast_example.py` automatically takes care it while assuming the standard format between tf, c++ and pytorch. -For training FastCells, `pytorch_edgeml.trainer.fastTrainer` implements the three-phase +For training FastCells, `edgeml.pytorch.trainer.fastTrainer` implements the three-phase FastCell training routine in PyTorch. A simple example, `examples/fastcell_example.py` is provided to illustrate its usage. diff --git a/pytorch/examples/FastCells/fastcell_example.py b/examples/pytorch/FastCells/fastcell_example.py similarity index 96% rename from pytorch/examples/FastCells/fastcell_example.py rename to examples/pytorch/FastCells/fastcell_example.py index 1ccb93f0..a3357346 100644 --- a/pytorch/examples/FastCells/fastcell_example.py +++ b/examples/pytorch/FastCells/fastcell_example.py @@ -5,8 +5,8 @@ import helpermethods import torch import numpy as np import sys -from pytorch_edgeml.graph.rnn import * -from pytorch_edgeml.trainer.fastTrainer import FastTrainer +from edgeml.pytorch.graph.rnn import * +from edgeml.pytorch.trainer.fastTrainer import FastTrainer def main(): diff --git a/pytorch/examples/FastCells/fetch_usps.py b/examples/pytorch/FastCells/fetch_usps.py similarity index 100% rename from pytorch/examples/FastCells/fetch_usps.py rename to examples/pytorch/FastCells/fetch_usps.py diff --git a/pytorch/examples/FastCells/helpermethods.py b/examples/pytorch/FastCells/helpermethods.py similarity index 100% rename from pytorch/examples/FastCells/helpermethods.py rename to examples/pytorch/FastCells/helpermethods.py diff --git a/pytorch/examples/FastCells/process_usps.py b/examples/pytorch/FastCells/process_usps.py similarity index 100% rename from pytorch/examples/FastCells/process_usps.py rename to examples/pytorch/FastCells/process_usps.py diff --git a/pytorch/examples/FastCells/quantizeFastModels.py b/examples/pytorch/FastCells/quantizeFastModels.py similarity index 100% rename from pytorch/examples/FastCells/quantizeFastModels.py rename to examples/pytorch/FastCells/quantizeFastModels.py diff --git a/pytorch/examples/ProtoNN/README.md b/examples/pytorch/ProtoNN/README.md similarity index 93% rename from pytorch/examples/ProtoNN/README.md rename to examples/pytorch/ProtoNN/README.md index 75a3fdad..f1f1b11c 100644 --- a/pytorch/examples/ProtoNN/README.md +++ b/examples/pytorch/ProtoNN/README.md @@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a command line execution script of ProtoNN developed as part of EdgeML. The example is based on the USPS dataset. -`pytorch_edgeml.graph.protoNN` implements the ProtoNN prediction functions. +`edgeml.pytorch.graph.protoNN` implements the ProtoNN prediction functions. The training routine for ProtoNN is decoupled from the forward graph to facilitate a plug and play behaviour wherein ProtoNN can be combined with or used as a final layer classifier for other architectures (RNNs, CNNs). The -training routine is implemented in `pytorch_edgeml.trainer.protoNNTrainer`. +training routine is implemented in `edgeml.pytorch.trainer.protoNNTrainer`. (This is also an artifact of consistency requirements with Tensorflow implementation). diff --git a/pytorch/examples/ProtoNN/fetch_usps.py b/examples/pytorch/ProtoNN/fetch_usps.py similarity index 100% rename from pytorch/examples/ProtoNN/fetch_usps.py rename to examples/pytorch/ProtoNN/fetch_usps.py diff --git a/pytorch/examples/ProtoNN/helpermethods.py b/examples/pytorch/ProtoNN/helpermethods.py similarity index 99% rename from pytorch/examples/ProtoNN/helpermethods.py rename to examples/pytorch/ProtoNN/helpermethods.py index 0ba74556..49483e9a 100644 --- a/pytorch/examples/ProtoNN/helpermethods.py +++ b/examples/pytorch/ProtoNN/helpermethods.py @@ -5,7 +5,7 @@ from __future__ import print_function import sys import os import numpy as np -import pytorch_edgeml.utils as utils +import edgeml.pytorch.utils as utils import argparse diff --git a/pytorch/examples/ProtoNN/process_usps.py b/examples/pytorch/ProtoNN/process_usps.py similarity index 100% rename from pytorch/examples/ProtoNN/process_usps.py rename to examples/pytorch/ProtoNN/process_usps.py diff --git a/pytorch/examples/ProtoNN/protoNN_example.ipynb b/examples/pytorch/ProtoNN/protoNN_example.ipynb similarity index 99% rename from pytorch/examples/ProtoNN/protoNN_example.ipynb rename to examples/pytorch/ProtoNN/protoNN_example.ipynb index cf3192f0..b9aa583b 100644 --- a/pytorch/examples/ProtoNN/protoNN_example.ipynb +++ b/examples/pytorch/ProtoNN/protoNN_example.ipynb @@ -17,9 +17,9 @@ "import numpy as np\n", "import torch\n", "\n", - "from pytorch_edgeml.graph.protoNN import ProtoNN\n", - "from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n", - "import pytorch_edgeml.utils as utils\n", + "from edgeml.pytorch.graph.protoNN import ProtoNN\n", + "from edgeml.pytorch.trainer.protoNNTrainer import ProtoNNTrainer\n", + "import edgeml.pytorch.utils as utils\n", "import helpermethods as helper" ] }, diff --git a/pytorch/examples/ProtoNN/protoNN_example.py b/examples/pytorch/ProtoNN/protoNN_example.py similarity index 95% rename from pytorch/examples/ProtoNN/protoNN_example.py rename to examples/pytorch/ProtoNN/protoNN_example.py index bcd65056..87330d22 100644 --- a/pytorch/examples/ProtoNN/protoNN_example.py +++ b/examples/pytorch/ProtoNN/protoNN_example.py @@ -5,9 +5,9 @@ from __future__ import print_function import sys import os import numpy as np -from pytorch_edgeml.trainer.protoNNTrainer import ProtoNNTrainer -from pytorch_edgeml.graph.protoNN import ProtoNN -import pytorch_edgeml.utils as utils +from edgeml.pytorch.trainer.protoNNTrainer import ProtoNNTrainer +from edgeml.pytorch.graph.protoNN import ProtoNN +import edgeml.pytorch.utils as utils import helpermethods as helper import torch diff --git a/pytorch/examples/SRNN/README.md b/examples/pytorch/SRNN/README.md similarity index 58% rename from pytorch/examples/SRNN/README.md rename to examples/pytorch/SRNN/README.md index 28f36c15..eefe57cd 100644 --- a/pytorch/examples/SRNN/README.md +++ b/examples/pytorch/SRNN/README.md @@ -1,12 +1,13 @@ # Pytorch Shallow RNN Examples -This directory includes an example [notebook](SRNN_Example.ipynb) of how to use -SRNN on the [Google Speech Commands +This directory includes an example [notebook](SRNN_Example.ipynb) and a +[python script](SRNN_Example.py) that explains the basic setup of SRNN by +training a simple model on the [Google Speech Commands Dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html). -`pytorch_edgeml.graph.rnn.SRNN2` implements a 2 layer SRNN network. We will use +`edgeml.pytorch.graph.rnn.SRNN2` implements a 2 layer SRNN network. We will use this with an LSTM cell on this dataset. The training routine for SRNN is -implemented in `pytorch_edgeml.trainer.srnnTrainer` and will be used as part of +implemented in `edgeml.pytorch.trainer.srnnTrainer` and will be used as part of this example. **Tested With:** pytorch > 1.1.0 with Python 2 and Python 3 @@ -23,8 +24,17 @@ dataset and write numpy files that confirm to the required format. ./fetch_google.py python process_google.py -With the provided configuration, you can expect a validation accuracy of about -92%. +## Training the Model + +A sample [notebook](SRNN_Example.ipynb) and a corresponding command line script +is provided for training. To run the command line script, please use: + +``` +python SRNN_Example.py --data-dir ./GoogleSpeech/Extracted --brick-size 11 +``` + +With the provided default configuration, you can expect a validation accuracy +of about 92%. Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license. diff --git a/pytorch/examples/SRNN/SRNN_Example.ipynb b/examples/pytorch/SRNN/SRNN_Example.ipynb similarity index 55% rename from pytorch/examples/SRNN/SRNN_Example.ipynb rename to examples/pytorch/SRNN/SRNN_Example.ipynb index 109ce951..a0db3a7d 100644 --- a/pytorch/examples/SRNN/SRNN_Example.ipynb +++ b/examples/pytorch/SRNN/SRNN_Example.ipynb @@ -4,9 +4,10 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# SRNN on Speech Commands Dataset\n", + "## SRNN on Speech Commands Dataset\n", "\n", - "Please use `fetch_google.sh` to download the Google Speech Commands Dataset and `python process_google.py` to create feature extracted data." + "\n", + "Please use `fetch_google.sh` to download the Google Speech Commands Dataset and python `process_google.py` to create feature extracted data." ] }, { @@ -14,8 +15,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:52:51.914361Z", - "start_time": "2019-07-14T12:52:51.667856Z" + "end_time": "2019-07-17T17:19:40.265370Z", + "start_time": "2019-07-17T17:19:39.980681Z" } }, "outputs": [], @@ -24,12 +25,7 @@ "import sys\n", "import os\n", "import numpy as np\n", - "import torch\n", - "\n", - "from pytorch_edgeml.graph.rnn import SRNN2\n", - "from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n", - "import pytorch_edgeml.utils as utils" - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + "import torch" ] }, { @@ -37,23 +33,15 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:52:56.040100Z", - "start_time": "2019-07-14T12:52:51.916533Z" + "end_time": "2019-07-17T17:19:40.273918Z", + "start_time": "2019-07-17T17:19:40.267871Z" } }, "outputs": [], "source": [ - "DATA_DIR = './GoogleSpeech/Extracted/'\n", - "x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))\n", - "x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))\n", - "x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))\n", - "# Mean-var normalize\n", - "mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n", - "std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n", - "std[std[:] < 0.000001] = 1\n", - "x_train_ = (x_train_ - mean) / std\n", - "x_val_ = (x_val_ - mean) / std\n", - "x_test_ = (x_test_ - mean) / std" + "from edgeml.pytorch.graph.rnn import SRNN2\n", + "from edgeml.pytorch.trainer.srnnTrainer import SRNNTrainer\n", + "import edgeml.pytorch.utils as utils" ] }, { @@ -61,8 +49,23 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:52:56.047992Z", - "start_time": "2019-07-14T12:52:56.042445Z" + "end_time": "2019-07-17T17:19:40.340949Z", + "start_time": "2019-07-17T17:19:40.275892Z" + } + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "DATA_DIR = './GoogleSpeech/Extracted/'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2019-07-17T17:19:45.603764Z", + "start_time": "2019-07-17T17:19:40.343295Z" } }, "outputs": [ @@ -77,6 +80,17 @@ } ], "source": [ + "x_train_, y_train = np.load(DATA_DIR + 'x_train.npy'), np.load(DATA_DIR + 'y_train.npy')\n", + "x_val_, y_val = np.load(DATA_DIR + 'x_val.npy'), np.load(DATA_DIR + 'y_val.npy')\n", + "x_test_, y_test = np.load(DATA_DIR + 'x_test.npy'), np.load(DATA_DIR + 'y_test.npy')\n", + "# Mean-var normalize\n", + "mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n", + "std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n", + "std[std[:] < 0.000001] = 1\n", + "x_train_ = (x_train_ - mean) / std\n", + "x_val_ = (x_val_ - mean) / std\n", + "x_test_ = (x_test_ - mean) / std\n", + "\n", "x_train = np.swapaxes(x_train_, 0, 1)\n", "x_val = np.swapaxes(x_val_, 0, 1)\n", "x_test = np.swapaxes(x_test_, 0, 1)\n", @@ -87,20 +101,21 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:52:56.068329Z", - "start_time": "2019-07-14T12:52:56.049725Z" + "end_time": "2019-07-17T17:19:45.610070Z", + "start_time": "2019-07-17T17:19:45.606282Z" } }, "outputs": [], "source": [ "numTimeSteps = x_train.shape[0]\n", "numInput = x_train.shape[-1]\n", - "brickSize = 11\n", "numClasses = y_train.shape[1]\n", "\n", + "# Network Parameters\n", + "brickSize = 11\n", "hiddenDim0 = 64\n", "hiddenDim1 = 32\n", "cellType = 'LSTM'\n", @@ -111,11 +126,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:52:56.088534Z", - "start_time": "2019-07-14T12:52:56.070114Z" + "end_time": "2019-07-17T17:19:49.576076Z", + "start_time": "2019-07-17T17:19:45.612629Z" } }, "outputs": [ @@ -128,17 +143,17 @@ } ], "source": [ - "srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n", + "srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device) \n", "trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2019-07-14T12:59:52.893161Z", - "start_time": "2019-07-14T12:52:56.090327Z" + "end_time": "2019-07-17T17:20:28.680161Z", + "start_time": "2019-07-17T17:19:49.578246Z" } }, "outputs": [ @@ -146,28 +161,28 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0 batch 0 loss 2.799139 acc 0.031250\n", - "Epoch 0 batch 200 loss 0.784248 acc 0.750000\n", - "Epoch 1 batch 0 loss 0.379059 acc 0.875000\n", - "Epoch 1 batch 200 loss 0.544366 acc 0.820312\n", - "Epoch 2 batch 0 loss 0.272113 acc 0.914062\n", - "Epoch 2 batch 200 loss 0.400919 acc 0.867188\n", - "Epoch 3 batch 0 loss 0.200825 acc 0.953125\n", - "Epoch 3 batch 200 loss 0.248952 acc 0.906250\n", - "Epoch 4 batch 0 loss 0.161245 acc 0.960938\n", - "Epoch 4 batch 200 loss 0.294340 acc 0.875000\n", - "Validation accuracy: 0.913063\n", - "Epoch 5 batch 0 loss 0.159573 acc 0.953125\n", - "Epoch 5 batch 200 loss 0.233308 acc 0.937500\n", - "Epoch 6 batch 0 loss 0.068345 acc 0.984375\n", - "Epoch 6 batch 200 loss 0.225371 acc 0.937500\n", - "Epoch 7 batch 0 loss 0.112335 acc 0.968750\n", - "Epoch 7 batch 200 loss 0.170626 acc 0.945312\n", - "Epoch 8 batch 0 loss 0.168985 acc 0.945312\n", - "Epoch 8 batch 200 loss 0.160869 acc 0.937500\n", - "Epoch 9 batch 0 loss 0.123516 acc 0.953125\n", - "Epoch 9 batch 200 loss 0.172936 acc 0.937500\n", - "Validation accuracy: 0.908208\n" + "Epoch 0 batch 0 loss 4.295151 acc 0.031250\n", + "Epoch 0 batch 200 loss 1.002617 acc 0.718750\n", + "Epoch 1 batch 0 loss 0.647069 acc 0.796875\n", + "Epoch 1 batch 200 loss 0.469229 acc 0.835938\n", + "Epoch 2 batch 0 loss 0.388671 acc 0.882812\n", + "Epoch 2 batch 200 loss 0.396696 acc 0.859375\n", + "Epoch 3 batch 0 loss 0.266433 acc 0.921875\n", + "Epoch 3 batch 200 loss 0.281694 acc 0.867188\n", + "Epoch 4 batch 0 loss 0.302240 acc 0.906250\n", + "Epoch 4 batch 200 loss 0.245797 acc 0.929688\n", + "Validation accuracy: 0.911003\n", + "Epoch 5 batch 0 loss 0.202542 acc 0.945312\n", + "Epoch 5 batch 200 loss 0.192004 acc 0.929688\n", + "Epoch 6 batch 0 loss 0.256735 acc 0.921875\n", + "Epoch 6 batch 200 loss 0.279066 acc 0.921875\n", + "Epoch 7 batch 0 loss 0.228837 acc 0.945312\n", + "Epoch 7 batch 200 loss 0.222357 acc 0.937500\n", + "Epoch 8 batch 0 loss 0.164639 acc 0.960938\n", + "Epoch 8 batch 200 loss 0.160117 acc 0.945312\n", + "Epoch 9 batch 0 loss 0.173849 acc 0.953125\n", + "Epoch 9 batch 200 loss 0.201694 acc 0.929688\n", + "Validation accuracy: 0.912474\n" ] } ], diff --git a/pytorch/examples/SRNN/SRNN_Example.py b/examples/pytorch/SRNN/SRNN_Example.py similarity index 52% rename from pytorch/examples/SRNN/SRNN_Example.py rename to examples/pytorch/SRNN/SRNN_Example.py index 1a8eff67..737d7168 100644 --- a/pytorch/examples/SRNN/SRNN_Example.py +++ b/examples/pytorch/SRNN/SRNN_Example.py @@ -1,19 +1,28 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + from __future__ import print_function import sys import os import numpy as np import torch -from pytorch_edgeml.graph.rnn import SRNN2 -from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer -import pytorch_edgeml.utils as utils +from edgeml.pytorch.graph.rnn import SRNN2 +from edgeml.pytorch.trainer.srnnTrainer import SRNNTrainer +import edgeml.pytorch.utils as utils +import helpermethods as helper +config = helper.getSRNN2Args() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -DATA_DIR = '/datadrive/data/SRNN/GoogleSpeech/Extracted/' -x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy')) -x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy')) -x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy')) +DATA_DIR = config.data_dir +x_train_ = np.load(DATA_DIR + 'x_train.npy') +y_train = np.load(DATA_DIR + 'y_train.npy') +x_val_ = np.load(DATA_DIR + 'x_val.npy') +y_val = np.load(DATA_DIR + 'y_val.npy') +x_test_ = np.load(DATA_DIR + 'x_test.npy') +y_test = np.load(DATA_DIR + 'y_test.npy') + # Mean-var normalize mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0) std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0) @@ -31,17 +40,20 @@ print("Test shape", x_test.shape, y_test.shape) numTimeSteps = x_train.shape[0] numInput = x_train.shape[-1] -brickSize = 11 numClasses = y_train.shape[1] -hiddenDim0 = 64 -hiddenDim1 = 32 -cellType = 'LSTM' -learningRate = 0.01 -batchSize = 128 -epochs = 10 +hiddenDim0 = config.hidden_dim0 +brickSize = config.brick_size +hiddenDim1 = config.hidden_dim1 +cellType = config.cell_type +learningRate = config.learning_rate +batchSize = config.batch_size +epochs = config.epochs +printStep = config.print_step +valStep = config.val_step -srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device) +srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device) trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device) -trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5) \ No newline at end of file +trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, + printStep=printStep, valStep=valStep) diff --git a/pytorch/examples/SRNN/fetch_google.sh b/examples/pytorch/SRNN/fetch_google.sh old mode 100755 new mode 100644 similarity index 100% rename from pytorch/examples/SRNN/fetch_google.sh rename to examples/pytorch/SRNN/fetch_google.sh diff --git a/examples/pytorch/SRNN/helpermethods.py b/examples/pytorch/SRNN/helpermethods.py new file mode 100644 index 00000000..a0ac3a72 --- /dev/null +++ b/examples/pytorch/SRNN/helpermethods.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import argparse + +def getSRNN2Args(): + def checkIntPos(value): + ivalue = int(value) + if ivalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive int value" % value) + return ivalue + + def checkIntNneg(value): + ivalue = int(value) + if ivalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg int value" % value) + return ivalue + + def checkFloatNneg(value): + fvalue = float(value) + if fvalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg float value" % value) + return fvalue + + def checkFloatPos(value): + fvalue = float(value) + if fvalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive float value" % value) + return fvalue + + parser = argparse.ArgumentParser( + description='Hyperparameters for 2 layer SRNN Algorithm') + + parser.add_argument('-d', '--data-dir', required=True, + help='Directory containing processed data.') + parser.add_argument('-h0', '--hidden-dim0', type=checkIntPos, default=64, + help='Hidden dimension of lower layer RNN cell.') + parser.add_argument('-h1', '--hidden-dim1', type=checkIntPos, default=32, + help='Hidden dimension of upper layer RNN cell.') + parser.add_argument('-bz', '--brick-size', type=checkIntPos, required=True, + help='Brick size to be used at the lower layer.') + parser.add_argument('-c', '--cell-type', default='LSTM', + help='Type of RNN cell to use among [LSTM, FastRNN, ' + + 'FastGRNN') + + parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20, + help='Number of prototypes.') + parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None, + help='Gamma for Gaussian kernel. If not provided, ' + + 'median heuristic will be used to estimate gamma.') + + parser.add_argument('-e', '--epochs', type=checkIntPos, default=10, + help='Total training epochs.') + parser.add_argument('-b', '--batch-size', type=checkIntPos, default=128, + help='Batch size for each pass.') + parser.add_argument('-r', '--learning-rate', type=checkFloatPos, + default=0.01, + help='Learning rate for ADAM Optimizer.') + + parser.add_argument('-pS', '--print-step', type=int, default=200, + help='The number of update steps between print ' + + 'calls to console.') + parser.add_argument('-vS', '--val-step', type=int, default=5, + help='The number of epochs between validation' + + 'performance evaluation') + return parser.parse_args() diff --git a/pytorch/examples/SRNN/process_google.py b/examples/pytorch/SRNN/process_google.py similarity index 100% rename from pytorch/examples/SRNN/process_google.py rename to examples/pytorch/SRNN/process_google.py diff --git a/tf/examples/Bonsai/README.md b/examples/tf/Bonsai/README.md similarity index 97% rename from tf/examples/Bonsai/README.md rename to examples/tf/Bonsai/README.md index 91cb0021..9468d0df 100644 --- a/tf/examples/Bonsai/README.md +++ b/examples/tf/Bonsai/README.md @@ -4,7 +4,7 @@ This directory includes, example notebook and general execution script of Bonsai developed as part of EdgeML. Also, we include a sample cleanup and use-case on the USPS10 public dataset. -`edgeml.graph.bonsai` implements the Bonsai prediction graph in tensorflow. +`edgeml.tf.graph.bonsai` implements the Bonsai prediction graph in tensorflow. The three-phase training routine for Bonsai is decoupled from the forward graph to facilitate a plug and play behaviour wherein Bonsai can be combined with or used as a final layer classifier for other architectures (RNNs, CNNs). diff --git a/tf/examples/Bonsai/bonsai_example.ipynb b/examples/tf/Bonsai/bonsai_example.ipynb similarity index 99% rename from tf/examples/Bonsai/bonsai_example.ipynb rename to examples/tf/Bonsai/bonsai_example.ipynb index 1935fd2b..03b7f101 100644 --- a/tf/examples/Bonsai/bonsai_example.ipynb +++ b/examples/tf/Bonsai/bonsai_example.ipynb @@ -33,8 +33,8 @@ "os.environ['CUDA_VISIBLE_DEVICES'] =''\n", "\n", "#Bonsai imports\n", - "from edgeml.trainer.bonsaiTrainer import BonsaiTrainer\n", - "from edgeml.graph.bonsai import Bonsai\n", + "from edgeml.tf.trainer.bonsaiTrainer import BonsaiTrainer\n", + "from edgeml.tf.graph.bonsai import Bonsai\n", "\n", "# Fixing seeds for reproducibility\n", "tf.set_random_seed(42)\n", diff --git a/tf/examples/Bonsai/bonsai_example.py b/examples/tf/Bonsai/bonsai_example.py similarity index 95% rename from tf/examples/Bonsai/bonsai_example.py rename to examples/tf/Bonsai/bonsai_example.py index a1b96e8a..61a542e5 100644 --- a/tf/examples/Bonsai/bonsai_example.py +++ b/examples/tf/Bonsai/bonsai_example.py @@ -5,8 +5,8 @@ import helpermethods import tensorflow as tf import numpy as np import sys -from edgeml.trainer.bonsaiTrainer import BonsaiTrainer -from edgeml.graph.bonsai import Bonsai +from edgeml.tf.trainer.bonsaiTrainer import BonsaiTrainer +from edgeml.tf.graph.bonsai import Bonsai def main(): diff --git a/tf/examples/Bonsai/fetch_usps.py b/examples/tf/Bonsai/fetch_usps.py similarity index 100% rename from tf/examples/Bonsai/fetch_usps.py rename to examples/tf/Bonsai/fetch_usps.py diff --git a/tf/examples/Bonsai/helpermethods.py b/examples/tf/Bonsai/helpermethods.py similarity index 100% rename from tf/examples/Bonsai/helpermethods.py rename to examples/tf/Bonsai/helpermethods.py diff --git a/tf/examples/Bonsai/process_usps.py b/examples/tf/Bonsai/process_usps.py similarity index 100% rename from tf/examples/Bonsai/process_usps.py rename to examples/tf/Bonsai/process_usps.py diff --git a/tf/examples/Bonsai/quantizeBonsaiModels.py b/examples/tf/Bonsai/quantizeBonsaiModels.py similarity index 100% rename from tf/examples/Bonsai/quantizeBonsaiModels.py rename to examples/tf/Bonsai/quantizeBonsaiModels.py diff --git a/tf/examples/EMI-RNN/00_emi_lstm_example.ipynb b/examples/tf/EMI-RNN/00_emi_lstm_example.ipynb similarity index 99% rename from tf/examples/EMI-RNN/00_emi_lstm_example.ipynb rename to examples/tf/EMI-RNN/00_emi_lstm_example.ipynb index fd957f1e..d979845a 100644 --- a/tf/examples/EMI-RNN/00_emi_lstm_example.ipynb +++ b/examples/tf/EMI-RNN/00_emi_lstm_example.ipynb @@ -40,10 +40,10 @@ "tf.set_random_seed(42)\n", "\n", "# MI-RNN and EMI-RNN imports\n", - "from edgeml.graph.rnn import EMI_DataPipeline\n", - "from edgeml.graph.rnn import EMI_BasicLSTM\n", - "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", - "import edgeml.utils" + "from edgeml.tf.graph.rnn import EMI_DataPipeline\n", + "from edgeml.tf.graph.rnn import EMI_BasicLSTM\n", + "from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", + "import edgeml.tf.utils" ] }, { diff --git a/tf/examples/EMI-RNN/01_emi_fastgrnn_example.ipynb b/examples/tf/EMI-RNN/01_emi_fastgrnn_example.ipynb similarity index 99% rename from tf/examples/EMI-RNN/01_emi_fastgrnn_example.ipynb rename to examples/tf/EMI-RNN/01_emi_fastgrnn_example.ipynb index 1eaa6d13..e3d04720 100644 --- a/tf/examples/EMI-RNN/01_emi_fastgrnn_example.ipynb +++ b/examples/tf/EMI-RNN/01_emi_fastgrnn_example.ipynb @@ -34,11 +34,11 @@ "os.environ['CUDA_VISIBLE_DEVICES'] ='1'\n", "\n", "# FastGRNN and FastRNN imports\n", - "from edgeml.graph.rnn import EMI_DataPipeline\n", - "from edgeml.graph.rnn import EMI_FastGRNN\n", - "from edgeml.graph.rnn import EMI_FastRNN\n", - "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", - "import edgeml.utils" + "from edgeml.tf.graph.rnn import EMI_DataPipeline\n", + "from edgeml.tf.graph.rnn import EMI_FastGRNN\n", + "from edgeml.tf.graph.rnn import EMI_FastRNN\n", + "from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", + "import edgeml.tf.utils" ] }, { diff --git a/tf/examples/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb b/examples/tf/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb similarity index 98% rename from tf/examples/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb rename to examples/tf/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb index d9b414c2..5bb73052 100644 --- a/tf/examples/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb +++ b/examples/tf/EMI-RNN/02_emi_lstm_initialization_and_restoring.ipynb @@ -35,10 +35,10 @@ "os.environ['CUDA_VISIBLE_DEVICES'] ='0'\n", "\n", "# MI-RNN and EMI-RNN imports\n", - "from edgeml.graph.rnn import EMI_DataPipeline\n", - "from edgeml.graph.rnn import EMI_BasicLSTM\n", - "from edgeml.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", - "import edgeml.utils" + "from edgeml.tf.graph.rnn import EMI_DataPipeline\n", + "from edgeml.tf.graph.rnn import EMI_BasicLSTM\n", + "from edgeml.tf.trainer.emirnnTrainer import EMI_Trainer, EMI_Driver\n", + "import edgeml.tf.utils" ] }, { diff --git a/tf/examples/EMI-RNN/README.md b/examples/tf/EMI-RNN/README.md similarity index 92% rename from tf/examples/EMI-RNN/README.md rename to examples/tf/EMI-RNN/README.md index f4261991..c263acae 100644 --- a/tf/examples/EMI-RNN/README.md +++ b/examples/tf/EMI-RNN/README.md @@ -3,7 +3,7 @@ This directory includes example notebooks EMI-RNN developed as part of EdgeML. The example is based on the UCI Human Activity Recognition dataset. -Please refer to `tf/docs/EMI-RNN.md` for detailed documentation of EMI-RNN. +Please refer to `docs/EMI-RNN.md` for detailed documentation of EMI-RNN. Please refer to `00_emi_lstm_example.ipynb` for a quick and dirty getting started guide. diff --git a/tf/examples/EMI-RNN/fetch_har.py b/examples/tf/EMI-RNN/fetch_har.py similarity index 100% rename from tf/examples/EMI-RNN/fetch_har.py rename to examples/tf/EMI-RNN/fetch_har.py diff --git a/tf/examples/EMI-RNN/helpermethods.py b/examples/tf/EMI-RNN/helpermethods.py similarity index 100% rename from tf/examples/EMI-RNN/helpermethods.py rename to examples/tf/EMI-RNN/helpermethods.py diff --git a/tf/examples/EMI-RNN/img/3PartsGraph.png b/examples/tf/EMI-RNN/img/3PartsGraph.png old mode 100755 new mode 100644 similarity index 100% rename from tf/examples/EMI-RNN/img/3PartsGraph.png rename to examples/tf/EMI-RNN/img/3PartsGraph.png diff --git a/tf/examples/EMI-RNN/img/MIML_illustration.png b/examples/tf/EMI-RNN/img/MIML_illustration.png old mode 100755 new mode 100644 similarity index 100% rename from tf/examples/EMI-RNN/img/MIML_illustration.png rename to examples/tf/EMI-RNN/img/MIML_illustration.png diff --git a/tf/examples/EMI-RNN/process_har.py b/examples/tf/EMI-RNN/process_har.py similarity index 100% rename from tf/examples/EMI-RNN/process_har.py rename to examples/tf/EMI-RNN/process_har.py diff --git a/tf/examples/FastCells/README.md b/examples/tf/FastCells/README.md similarity index 91% rename from tf/examples/FastCells/README.md rename to examples/tf/FastCells/README.md index c6b6abe6..6c136b0c 100644 --- a/tf/examples/FastCells/README.md +++ b/examples/tf/FastCells/README.md @@ -5,7 +5,7 @@ FastCells (FastRNN & FastGRNN) developed as part of EdgeML along with modified UGRNN, GRU and LSTM to support the LSQ training routine. Also, we include a sample cleanup and use-case on the USPS10 public dataset. -`edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with +`edgeml.tf.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with multiple additional features like Low-Rank parameterisation, custom non-linearities etc., Similar to Bonsai and ProtoNN, the three-phase training routine for FastRNN and FastGRNN is decoupled from the custom cells to @@ -14,9 +14,9 @@ architectures (NMT, Encoder-Decoder etc.,) in place of the inbuilt `RNNCell`, `G `edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)), **GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells also can be substituted for FastCells where ever feasible. -For training FastCells, `edgeml.trainer.fastTrainer` implements the three-phase +For training FastCells, `edgeml.tf.trainer.fastTrainer` implements the three-phase FastCell training routine in Tensorflow. A simple example, -`examples/fastcell_example.py` is provided to illustrate its usage. +`examples/tf/fastcell_example.py` is provided to illustrate its usage. Note that `fastcell_example.py` assumes that data is in a specific format. It is assumed that train and test data is contained in two files, `train.npy` and diff --git a/tf/examples/FastCells/fastcell_example.ipynb b/examples/tf/FastCells/fastcell_example.ipynb similarity index 99% rename from tf/examples/FastCells/fastcell_example.ipynb rename to examples/tf/FastCells/fastcell_example.ipynb index d1d59ee8..82d6c597 100644 --- a/tf/examples/FastCells/fastcell_example.ipynb +++ b/examples/tf/FastCells/fastcell_example.ipynb @@ -28,12 +28,12 @@ "os.environ['CUDA_VISIBLE_DEVICES'] =''\n", "\n", "#FastRNN and FastGRNN imports\n", - "from edgeml.trainer.fastTrainer import FastTrainer\n", - "from edgeml.graph.rnn import FastGRNNCell\n", - "from edgeml.graph.rnn import FastRNNCell\n", - "from edgeml.graph.rnn import UGRNNLRCell\n", - "from edgeml.graph.rnn import GRULRCell\n", - "from edgeml.graph.rnn import LSTMLRCell\n", + "from edgeml.tf.trainer.fastTrainer import FastTrainer\n", + "from edgeml.tf.graph.rnn import FastGRNNCell\n", + "from edgeml.tf.graph.rnn import FastRNNCell\n", + "from edgeml.tf.graph.rnn import UGRNNLRCell\n", + "from edgeml.tf.graph.rnn import GRULRCell\n", + "from edgeml.tf.graph.rnn import LSTMLRCell\n", "\n", "# Fixing seeds for reproducibility\n", "tf.set_random_seed(42)\n", diff --git a/tf/examples/FastCells/fastcell_example.py b/examples/tf/FastCells/fastcell_example.py similarity index 91% rename from tf/examples/FastCells/fastcell_example.py rename to examples/tf/FastCells/fastcell_example.py index 8d3ef5ce..10e6c0ce 100644 --- a/tf/examples/FastCells/fastcell_example.py +++ b/examples/tf/FastCells/fastcell_example.py @@ -6,12 +6,12 @@ import tensorflow as tf import numpy as np import sys -from edgeml.trainer.fastTrainer import FastTrainer -from edgeml.graph.rnn import FastGRNNCell -from edgeml.graph.rnn import FastRNNCell -from edgeml.graph.rnn import UGRNNLRCell -from edgeml.graph.rnn import GRULRCell -from edgeml.graph.rnn import LSTMLRCell +from edgeml.tf.trainer.fastTrainer import FastTrainer +from edgeml.tf.graph.rnn import FastGRNNCell +from edgeml.tf.graph.rnn import FastRNNCell +from edgeml.tf.graph.rnn import UGRNNLRCell +from edgeml.tf.graph.rnn import GRULRCell +from edgeml.tf.graph.rnn import LSTMLRCell def main(): diff --git a/tf/examples/FastCells/fetch_usps.py b/examples/tf/FastCells/fetch_usps.py similarity index 100% rename from tf/examples/FastCells/fetch_usps.py rename to examples/tf/FastCells/fetch_usps.py diff --git a/tf/examples/FastCells/helpermethods.py b/examples/tf/FastCells/helpermethods.py similarity index 100% rename from tf/examples/FastCells/helpermethods.py rename to examples/tf/FastCells/helpermethods.py diff --git a/tf/examples/FastCells/process_usps.py b/examples/tf/FastCells/process_usps.py similarity index 100% rename from tf/examples/FastCells/process_usps.py rename to examples/tf/FastCells/process_usps.py diff --git a/tf/examples/FastCells/quantizeFastModels.py b/examples/tf/FastCells/quantizeFastModels.py similarity index 100% rename from tf/examples/FastCells/quantizeFastModels.py rename to examples/tf/FastCells/quantizeFastModels.py diff --git a/tf/examples/ProtoNN/README.md b/examples/tf/ProtoNN/README.md similarity index 92% rename from tf/examples/ProtoNN/README.md rename to examples/tf/ProtoNN/README.md index d0137ac4..8f9f78cb 100644 --- a/tf/examples/ProtoNN/README.md +++ b/examples/tf/ProtoNN/README.md @@ -4,11 +4,11 @@ This directory includes an example [notebook](protoNN_example.ipynb) and a command line execution script of ProtoNN developed as part of EdgeML. The example is based on the USPS dataset. -`edgeml.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow. +`edgeml.tf.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow. The training routine for ProtoNN is decoupled from the forward graph to facilitate a plug and play behaviour wherein ProtoNN can be combined with or used as a final layer classifier for other architectures (RNNs, CNNs). The -training routine is implemented in `edgeml.trainer.protoNNTrainer`. +training routine is implemented in `edgeml.tf.trainer.protoNNTrainer`. Note that, `protoNN_example.py` assumes the data to be in a specific format. It is assumed that train and test data is contained in two files, `train.npy` and diff --git a/tf/examples/ProtoNN/fetch_usps.py b/examples/tf/ProtoNN/fetch_usps.py similarity index 100% rename from tf/examples/ProtoNN/fetch_usps.py rename to examples/tf/ProtoNN/fetch_usps.py diff --git a/tf/examples/ProtoNN/helpermethods.py b/examples/tf/ProtoNN/helpermethods.py similarity index 99% rename from tf/examples/ProtoNN/helpermethods.py rename to examples/tf/ProtoNN/helpermethods.py index 1bd38282..1eed9f97 100644 --- a/tf/examples/ProtoNN/helpermethods.py +++ b/examples/tf/ProtoNN/helpermethods.py @@ -6,7 +6,7 @@ import sys import os import numpy as np import tensorflow as tf -import edgeml.utils as utils +import edgeml.tf.utils as utils import argparse diff --git a/tf/examples/ProtoNN/process_usps.py b/examples/tf/ProtoNN/process_usps.py similarity index 100% rename from tf/examples/ProtoNN/process_usps.py rename to examples/tf/ProtoNN/process_usps.py diff --git a/tf/examples/ProtoNN/protoNN_example.ipynb b/examples/tf/ProtoNN/protoNN_example.ipynb similarity index 99% rename from tf/examples/ProtoNN/protoNN_example.ipynb rename to examples/tf/ProtoNN/protoNN_example.ipynb index 9581b97e..26f1d6ae 100644 --- a/tf/examples/ProtoNN/protoNN_example.ipynb +++ b/examples/tf/ProtoNN/protoNN_example.ipynb @@ -29,9 +29,9 @@ "import numpy as np\n", "import tensorflow as tf\n", "\n", - "from edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n", - "from edgeml.graph.protoNN import ProtoNN\n", - "import edgeml.utils as utils\n", + "from edgeml.tf.trainer.protoNNTrainer import ProtoNNTrainer\n", + "from edgeml.tf.graph.protoNN import ProtoNN\n", + "import edgeml.tf.utils as utils\n", "import helpermethods as helper" ] }, diff --git a/tf/examples/ProtoNN/protoNN_example.py b/examples/tf/ProtoNN/protoNN_example.py similarity index 95% rename from tf/examples/ProtoNN/protoNN_example.py rename to examples/tf/ProtoNN/protoNN_example.py index c98cce29..ed5a14fa 100644 --- a/tf/examples/ProtoNN/protoNN_example.py +++ b/examples/tf/ProtoNN/protoNN_example.py @@ -6,9 +6,9 @@ import sys import os import numpy as np import tensorflow as tf -from edgeml.trainer.protoNNTrainer import ProtoNNTrainer -from edgeml.graph.protoNN import ProtoNN -import edgeml.utils as utils +from edgeml.tf.trainer.protoNNTrainer import ProtoNNTrainer +from edgeml.tf.graph.protoNN import ProtoNN +import edgeml.tf.utils as utils import helpermethods as helper diff --git a/pytorch/README.md b/pytorch/README.md deleted file mode 100644 index 06c3c5b1..00000000 --- a/pytorch/README.md +++ /dev/null @@ -1,46 +0,0 @@ -## Edge Machine Learning: PyTorch Library - -This directory includes, PyTorch implementations of various techniques and -algorithms developed as part of EdgeML. Currently, the following algorithms are -available in PyTorch: - -1. [Bonsai](../docs/publications/Bonsai.pdf) -2. [FastRNN & FastGRNN](../docs/publications/FastGRNN.pdf) - -The PyTorch compute graphs for these algoriths are packaged as -`pytorch_edgeml.graph`. Trainers for these algorithms are in `pytorch_edgeml.trainer`. Usage -directions and examples for these algorithms are provided in `examples` -directory. To get started with any of the provided algorithms, please follow -the notebooks in the the `examples` directory. - -## Installation - -Use pip and the provided requirements file to first install required -dependencies before installing the `pytorch_edgeml` library. Details for installation provided below. - -It is highly recommended that EdgeML be installed in a virtual environment. Please create -a new virtual environment using your environment manager ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)). -Make sure the new environment is active before running the below mentioned commands. - -### CPU - -``` -pip install -r requirements-cpu.txt -pip install -e . -``` - -Tested on Python 3.6 with PyTorch 1.1. - -### GPU - -Install appropriate CUDA and cuDNN [Tested with >= CUDA 9.0 and cuDNN >= 7.0] - -``` -pip install -r requirements-gpu.txt -pip install -e . -``` - -Note: If the above commands don't go through for PyTorch installation on CPU and GPU, please follow this [link](https://pytorch.org/get-started/locally/). - -Copyright (c) Microsoft Corporation. All rights reserved. -Licensed under the MIT license. diff --git a/pytorch/setup.py b/pytorch/setup.py deleted file mode 100644 index ccfb4b93..00000000 --- a/pytorch/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -from distutils.core import setup - -setup( - name='pytorch_edgeml', - version='0.2', - packages=['pytorch_edgeml', ], - license='MIT License', - long_description=open('../License.txt').read(), -) diff --git a/pytorch/requirements-cpu.txt b/requirements-cpu-pytorch.txt similarity index 100% rename from pytorch/requirements-cpu.txt rename to requirements-cpu-pytorch.txt diff --git a/tf/requirements-cpu.txt b/requirements-cpu-tf.txt similarity index 100% rename from tf/requirements-cpu.txt rename to requirements-cpu-tf.txt diff --git a/pytorch/requirements-gpu.txt b/requirements-gpu-pytorch.txt similarity index 100% rename from pytorch/requirements-gpu.txt rename to requirements-gpu-pytorch.txt diff --git a/tf/requirements-gpu.txt b/requirements-gpu-tf.txt similarity index 100% rename from tf/requirements-gpu.txt rename to requirements-gpu-tf.txt diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..2eebec73 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +import setuptools #enables develop + +setuptools.setup( + name='edgeml', + version='0.2', + description='machine learning algorithms for edge devices developed at Microsoft Research India.', + packages=setuptools.find_packages(), + license='MIT License', + long_description=open('License.txt').read(), + url='https://github.com/Microsoft/EdgeML', +) diff --git a/tf/setup.py b/tf/setup.py deleted file mode 100644 index dfb6fac4..00000000 --- a/tf/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -from distutils.core import setup - -setup( - name='edgeml', - version='0.2', - packages=['edgeml', ], - license='MIT License', - long_description=open('../License.txt').read(), -)