diff --git a/edgeml/pytorch/graph/rnn.py b/edgeml/pytorch/graph/rnn.py
index 1fb452b2..945ce808 100644
--- a/edgeml/pytorch/graph/rnn.py
+++ b/edgeml/pytorch/graph/rnn.py
@@ -6,6 +6,8 @@
 import torch.nn as nn
 from torch.autograd import Function
 import numpy as np
+import edgeml.pytorch.utils as utils
+
 def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity):
     class RNNSymbolic(Function):
         @staticmethod
@@ -196,6 +198,19 @@ class FastGRNNCell(nn.Module):
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
+    def sparsify(self, wsp, usp):
+        if self._wRank is None:
+            utils.hardThreshold(self.W, wsp)
+        else:
+            utils.hardThreshold(self.W1, wsp)
+            utils.hardThreshold(self.W2, wsp)
+
+        if self._uRank is None:
+            utils.hardThreshold(self.U, usp)
+        else:
+            utils.hardThreshold(self.U1, usp)
+            utils.hardThreshold(self.U2, usp)
+
     @property
     def state_size(self):
         return self._hidden_size
@@ -340,6 +355,20 @@ class FastRNNCell(nn.Module):
         self.alpha = nn.Parameter(self._alphaInit * torch.ones([1, 1]))
         self.beta = nn.Parameter(self._betaInit * torch.ones([1, 1]))
 
+    def sparsify(self, wsp, usp):
+        if self._wRank is None:
+            utils.hardThreshold(self.W, wsp)
+        else:
+            utils.hardThreshold(self.W1, wsp)
+            utils.hardThreshold(self.W2, wsp)
+
+        if self._uRank is None:
+            utils.hardThreshold(self.U, usp)
+        else:
+            utils.hardThreshold(self.U1, usp)
+            utils.hardThreshold(self.U2, usp)
+
+
     @property
     def state_size(self):
         return self._hidden_size
diff --git a/edgeml/pytorch/trainer/fastmodel.py b/edgeml/pytorch/trainer/fastmodel.py
index 83720347..cb6923bb 100644
--- a/edgeml/pytorch/trainer/fastmodel.py
+++ b/edgeml/pytorch/trainer/fastmodel.py
@@ -72,6 +72,14 @@ def fastgrnnmodel(inheritance_class=nn.Module):
             self.hidden2keyword = nn.Linear(last_output_size, num_classes)
             self.init_hidden()
 
+        def sparsify(self, wsp, usp):
+            self.fastgrnn1.cell.sparsify(wsp, usp)
+            if self.num_layers > 1:
+                self.fastgrnn2.cell.sparsify(wsp, usp)
+            if self.num_layers > 2:
+                self.fastgrnn3.cell.sparsify(wsp, usp)
+
+
         def normalize(self, mean, std):
             self.mean = mean
             self.std = std
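The `sparsify` methods added above hard-threshold whichever matrices the cell actually owns: the full `W`/`U` when no rank is set, or the low-rank factors `W1`/`W2` and `U1`/`U2` otherwise. A minimal sketch of what a caller might do with the new API, assuming `FastGRNNCell` accepts `wRank`/`uRank` keyword arguments as elsewhere in `rnn.py` (sizes and sparsity values below are illustrative, not from this PR):

```python
import torch
from edgeml.pytorch.graph.rnn import FastGRNNCell

# Low-rank cell, so sparsify() should threshold W1/W2 and U1/U2.
cell = FastGRNNCell(32, 64, wRank=8, uRank=8)
cell.sparsify(wsp=0.2, usp=0.2)   # keep roughly 20% of the entries in each factor

for name in ("W1", "W2", "U1", "U2"):
    p = getattr(cell, name).data
    print(name, "non-zero fraction:", float((p != 0).sum()) / p.numel())
```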
diff --git a/edgeml/pytorch/utils.py b/edgeml/pytorch/utils.py
index ce9ac2ad..2a60282c 100644
--- a/edgeml/pytorch/utils.py
+++ b/edgeml/pytorch/utils.py
@@ -41,17 +41,21 @@ def binaryHingeLoss(logits, labels):
     return torch.mean(F.relu(1.0 - (2 * labels - 1) * logits))
 
 
-def hardThreshold(A, s):
+def hardThreshold(A: torch.nn.Parameter, s):
     '''
-    Hard thresholding function on Tensor A with sparsity s
+    Hard thresholds and modifies in-place nn.Parameter A with sparsity s
     '''
-    A_ = np.copy(A)
+    # PyTorch disallows numpy access to (or copies of) tensors in the graph.
+    # .detach() creates a new tensor that is not attached to the graph.
+    A_ = A.detach().numpy()
+    A_ = A_.ravel()
     if len(A_) > 0:
         th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
         A_[np.abs(A_) < th] = 0.0
     A_ = A_.reshape(A.shape)
-    return A_
+
+    A.data = torch.tensor(A_, requires_grad=True)
 
 
 def copySupport(src, dest):
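With this change `hardThreshold` no longer returns a new array: it keeps roughly a fraction `s` of the parameter's entries (the largest by magnitude) and zeroes the rest in place through `A.data`. A small illustration of that behavior, using made-up values and the `edgeml.pytorch.utils` import path shown in the diff:

```python
import torch
from edgeml.pytorch.utils import hardThreshold

p = torch.nn.Parameter(torch.tensor([[0.10, -0.50, 0.05, 2.00]]))
hardThreshold(p, 0.5)   # keep ~50% of entries, i.e. the two largest magnitudes
print(p.data)           # expected: tensor([[ 0.0000, -0.5000,  0.0000,  2.0000]])
```

Because the update goes through `A.data`, the parameter object itself, and any optimizer state keyed on it, is preserved across thresholding.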
diff --git a/examples/pytorch/FastCells/train_classifier.py b/examples/pytorch/FastCells/train_classifier.py
index c88618fe..414e107d 100644
--- a/examples/pytorch/FastCells/train_classifier.py
+++ b/examples/pytorch/FastCells/train_classifier.py
@@ -136,7 +136,7 @@ class KeywordSpotter(nn.Module):
         scheduler = ExponentialResettingLR(optimizer, gamma, reset)
         return scheduler
 
-    def fit(self, training_data, validation_data, options, model, device=None, detail=False, run=None):
+    def fit(self, training_data, validation_data, training_options, model_options, device=None, detail=False, run=None):
         """
         Perform the training. This is not called "train" because the base class
         already defines that method with a different meaning.
@@ -156,29 +156,29 @@ class KeywordSpotter(nn.Module):
         self.training = True
         start = time.time()
         loss_function = nn.NLLLoss()
-        optimizer = self.configure_optimizer(options)
+        optimizer = self.configure_optimizer(training_options)
         print(optimizer)
-        num_epochs = options.max_epochs
-        batch_size = options.batch_size
+        num_epochs = training_options.max_epochs
+        batch_size = training_options.batch_size
 
         ticks = training_data.num_rows / batch_size  # iterations per epoch
         total_iterations = ticks * num_epochs
-        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
+        scheduler = self.configure_lr(training_options, optimizer, ticks, total_iterations)
 
         # optimizer = optim.Adam(model.parameters(), lr=0.0001)
         log = []
-        if options.rolling:
+        if training_options.rolling:
             rolling_length = 2
             max_rolling_length = int(ticks)
-            if max_rolling_length > options.max_rolling_length:
-                max_rolling_length = options.max_rolling_length
+            if max_rolling_length > training_options.max_rolling_length:
+                max_rolling_length = training_options.max_rolling_length
             bag_count = 100
             hidden_bag_size = batch_size * bag_count
 
         for epoch in range(num_epochs):
             self.train()
-            if options.rolling:
+            if training_options.rolling:
                 rolling_length += 1
                 if rolling_length < max_rolling_length:
                     self.init_hidden_bag(hidden_bag_size, device)
@@ -192,7 +192,7 @@ class KeywordSpotter(nn.Module):
                 # Also, we need to clear out the hidden state,
                 # detaching it from its history on the last instance.
-                if options.rolling:
+                if training_options.rolling:
                     if rolling_length < max_rolling_length:
                         if (i_batch + 1) % rolling_length == 0:
                             self.init_hidden()
@@ -228,6 +228,9 @@ class KeywordSpotter(nn.Module):
                 # applying the gradients we computed during back propagation
                 optimizer.step()
 
+                if model_options.wSparsity < 1.0 or model_options.uSparsity < 1.0:
+                    self.sparsify(model_options.wSparsity, model_options.uSparsity)
+
                 learning_rate = optimizer.param_groups[0]['lr']
                 if detail:
                     learning_rate = optimizer.param_groups[0]['lr']
@@ -543,6 +546,8 @@ if __name__ == '__main__':
     parser.add_argument("--uRank2", "-ur2", help="Rank of U in 2nd layer of FastGRNN default is None", type=int)
     parser.add_argument("--wRank3", "-wr3", help="Rank of W in 3rd layer of FastGRNN default is None", type=int)
     parser.add_argument("--uRank3", "-ur3", help="Rank of U in 3rd layer of FastGRNN default is None", type=int)
+    parser.add_argument("--wSparsity", "-wsp", help="Sparsity of W matrices", type=float)
+    parser.add_argument("--uSparsity", "-usp", help="Sparsity of U matrices", type=float)
     parser.add_argument("--gate_nonlinearity", "-gnl", help="Gate Non-Linearity in FastGRNN default is sigmoid"
                         " use between [sigmoid, quantSigmoid, tanh, quantTanh]")
     parser.add_argument("--update_nonlinearity", "-unl", help="Update Non-Linearity in FastGRNN default is Tanh"
@@ -637,6 +642,14 @@ if __name__ == '__main__':
         config.model.wRank3 = args.wRank3
     if args.uRank3:
         config.model.uRank3 = args.wRank3
+    if args.wSparsity:
+        config.model.wSparsity = args.wSparsity
+    else:
+        config.model.wSparsity = 1.0
+    if args.uSparsity:
+        config.model.uSparsity = args.uSparsity
+    else:
+        config.model.uSparsity = 1.0
     if args.gate_nonlinearity:
         config.model.gate_nonlinearity = args.gate_nonlinearity
     if args.update_nonlinearity:
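In `fit`, hard thresholding is now re-applied after every `optimizer.step()` whenever either sparsity is below 1.0, so dense gradient updates alternate with projection back onto a sparse support. A rough, self-contained sketch of that pattern on a single cell, assuming `FastGRNNCell(input_size, hidden_size)` and a `forward(input, state)` call as in the rest of `rnn.py`; the data, sizes, and loss below are placeholders rather than the script's real training setup:

```python
import torch
import torch.nn as nn
from edgeml.pytorch.graph.rnn import FastGRNNCell

cell = FastGRNNCell(8, 16)                      # dense cell: sparsify() hits W and U
optimizer = torch.optim.Adam(cell.parameters(), lr=0.01)
wSparsity, uSparsity = 0.3, 0.3                 # keep ~30% of W and U entries

x = torch.randn(5, 4, 8)                        # (batch, time, features) dummy input
target = torch.randn(5, 16)                     # dummy regression target

for _ in range(10):
    optimizer.zero_grad()
    h = torch.zeros(5, 16)
    for t in range(x.shape[1]):                 # unroll the cell over time
        h = cell(x[:, t, :], h)
    loss = nn.functional.mse_loss(h, target)
    loss.backward()
    optimizer.step()
    if wSparsity < 1.0 or uSparsity < 1.0:      # mirror the hook added to fit()
        cell.sparsify(wSparsity, uSparsity)
```

With the defaults wired in at the bottom of train_classifier.py (`wSparsity = uSparsity = 1.0`), the guard is never taken and the weights are left untouched.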