added sparsity and hard thresholding support to train_classifier and Fast(G)RNN cells and models

This commit is contained in:
Harsha Vardhan Simhadri 2019-07-29 17:23:26 +05:30
Parent c90b28e13e
Commit b8b4dd70a8
4 changed files with 68 additions and 14 deletions

View file

@@ -6,6 +6,8 @@ import torch.nn as nn
from torch.autograd import Function
import numpy as np
import edgeml.pytorch.utils as utils
def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity):
class RNNSymbolic(Function):
@staticmethod
@@ -196,6 +198,19 @@ class FastGRNNCell(nn.Module):
self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
def sparsify(self, wsp, usp):
if self._wRank is None:
utils.hardThreshold(self.W, wsp)
else:
utils.hardThreshold(self.W1, wsp)
utils.hardThreshold(self.W2, wsp)
if self._uRank is None:
utils.hardThreshold(self.U, usp)
else:
utils.hardThreshold(self.U1, usp)
utils.hardThreshold(self.U2, usp)
@property
def state_size(self):
return self._hidden_size
@@ -340,6 +355,20 @@ class FastRNNCell(nn.Module):
self.alpha = nn.Parameter(self._alphaInit * torch.ones([1, 1]))
self.beta = nn.Parameter(self._betaInit * torch.ones([1, 1]))
def sparsify(self, wsp, usp):
if self._wRank is None:
utils.hardThreshold(self.W, wsp)
else:
utils.hardThreshold(self.W1, wsp)
utils.hardThreshold(self.W2, wsp)
if self._uRank is None:
utils.hardThreshold(self.U, usp)
else:
utils.hardThreshold(self.U1, usp)
utils.hardThreshold(self.U2, usp)
@property
def state_size(self):
return self._hidden_size
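For orientation, a hedged sketch of how the new per-cell hook would be called. Only sparsify(wsp, usp) and the _wRank/_uRank branching come from this diff; the module path and constructor arguments below are assumptions. wsp and usp are the fractions of entries retained (1.0 keeps the matrices dense), and when a rank is set both low-rank factors (W1/W2, U1/U2) are thresholded.

import torch
# Module path and constructor signature are assumptions, not part of this diff.
from edgeml.pytorch.graph.rnn import FastGRNNCell

cell = FastGRNNCell(input_size=32, hidden_size=64, wRank=8, uRank=8)

# Keep roughly the top 25% of entries of each factor of W and U, by magnitude.
cell.sparsify(wsp=0.25, usp=0.25)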

View file

@@ -72,6 +72,14 @@ def fastgrnnmodel(inheritance_class=nn.Module):
self.hidden2keyword = nn.Linear(last_output_size, num_classes)
self.init_hidden()
def sparsify(self, wsp, usp):
self.fastgrnn1.cell.sparsify(wsp, usp)
if self.num_layers > 1:
self.fastgrnn2.cell.sparsify(wsp, usp)
if self.num_layers > 2:
self.fastgrnn3.cell.sparsify(wsp, usp)
def normalize(self, mean, std):
self.mean = mean
self.std = std

View file

@@ -41,17 +41,21 @@ def binaryHingeLoss(logits, labels):
return torch.mean(F.relu(1.0 - (2 * labels - 1) * logits))
def hardThreshold(A, s):
def hardThreshold(A: torch.nn.Parameter, s):
'''
Hard thresholding function on Tensor A with sparsity s
Hard thresholds and modifies in-place nn.Parameter A with sparsity s
'''
A_ = np.copy(A)
#PyTorch disallows numpy access/copy to tensors in graph.
#.detach() creates a new tensor not attached to the graph.
A_ = A.detach().numpy()
A_ = A_.ravel()
if len(A_) > 0:
th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
A_[np.abs(A_) < th] = 0.0
A_ = A_.reshape(A.shape)
return A_
A.data = torch.tensor(A_, requires_grad=True)
def copySupport(src, dest):
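To make the percentile step concrete, a small worked example with toy values (run on CPU, since .detach().numpy() needs a CPU tensor): with s = 0.25 over four entries, the 75th percentile of the absolute values under interpolation='higher' is 2.0, so every entry with magnitude below 2.0 is zeroed and the result is written back through A.data.

import torch
import torch.nn as nn
from edgeml.pytorch.utils import hardThreshold  # same module imported as utils in rnn.py above

p = nn.Parameter(torch.tensor([[0.1, -0.5], [2.0, -0.05]]))
hardThreshold(p, 0.25)   # keep ~25% of the entries, i.e. only the largest-magnitude one
print(p.data)            # tensor([[0., 0.], [2., 0.]])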

View file

@@ -136,7 +136,7 @@ class KeywordSpotter(nn.Module):
scheduler = ExponentialResettingLR(optimizer, gamma, reset)
return scheduler
def fit(self, training_data, validation_data, options, model, device=None, detail=False, run=None):
def fit(self, training_data, validation_data, training_options, model_options, device=None, detail=False, run=None):
"""
Perform the training. This is not called "train" because
the base class already defines that method with a different meaning.
@@ -156,29 +156,29 @@ class KeywordSpotter(nn.Module):
self.training = True
start = time.time()
loss_function = nn.NLLLoss()
optimizer = self.configure_optimizer(options)
optimizer = self.configure_optimizer(training_options)
print(optimizer)
num_epochs = options.max_epochs
batch_size = options.batch_size
num_epochs = training_options.max_epochs
batch_size = training_options.batch_size
ticks = training_data.num_rows / batch_size # iterations per epoch
total_iterations = ticks * num_epochs
scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
scheduler = self.configure_lr(training_options, optimizer, ticks, total_iterations)
# optimizer = optim.Adam(model.parameters(), lr=0.0001)
log = []
if options.rolling:
if training_options.rolling:
rolling_length = 2
max_rolling_length = int(ticks)
if max_rolling_length > options.max_rolling_length:
max_rolling_length = options.max_rolling_length
if max_rolling_length > training_options.max_rolling_length:
max_rolling_length = training_options.max_rolling_length
bag_count = 100
hidden_bag_size = batch_size * bag_count
for epoch in range(num_epochs):
self.train()
if options.rolling:
if training_options.rolling:
rolling_length += 1
if rolling_length < max_rolling_length:
self.init_hidden_bag(hidden_bag_size, device)
@@ -192,7 +192,7 @@ class KeywordSpotter(nn.Module):
# Also, we need to clear out the hidden state,
# detaching it from its history on the last instance.
if options.rolling:
if training_options.rolling:
if rolling_length < max_rolling_length:
if (i_batch + 1) % rolling_length == 0:
self.init_hidden()
@@ -228,6 +228,9 @@ class KeywordSpotter(nn.Module):
# applying the gradients we computed during back propagation
optimizer.step()
if model_options.wSparsity < 1.0 or model_options.uSparsity < 1.0:
self.sparsify(model_options.wSparsity, model_options.uSparsity)
learning_rate = optimizer.param_groups[0]['lr']
if detail:
learning_rate = optimizer.param_groups[0]['lr']
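The new check turns training into a projected-gradient (iterative hard thresholding) loop: a dense optimizer step followed by re-projection onto a sparse support whenever either sparsity is below 1.0. A minimal self-contained sketch of that pattern, with a toy linear model standing in for the keyword-spotting model (names and values below are illustrative, not from this repo):

import torch
import torch.nn as nn
from edgeml.pytorch.utils import hardThreshold

class ToyModel(nn.Module):
    # Stand-in with a sparsify() hook analogous to the one added to fastgrnnmodel.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def forward(self, x):
        return self.linear(x)

    def sparsify(self, wsp, usp):
        hardThreshold(self.linear.weight, wsp)

model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_function = nn.CrossEntropyLoss()
wSparsity, uSparsity = 0.3, 1.0                        # illustrative values
x, y = torch.randn(16, 8), torch.randint(0, 2, (16,))

for _ in range(5):
    optimizer.zero_grad()
    loss_function(model(x), y).backward()
    optimizer.step()
    # Re-project onto a sparse support after every dense step, as in fit() above.
    if wSparsity < 1.0 or uSparsity < 1.0:
        model.sparsify(wSparsity, uSparsity)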
@@ -543,6 +546,8 @@ if __name__ == '__main__':
parser.add_argument("--uRank2", "-ur2", help="Rank of U in 2nd layer of FastGRNN default is None", type=int)
parser.add_argument("--wRank3", "-wr3", help="Rank of W in 3rd layer of FastGRNN default is None", type=int)
parser.add_argument("--uRank3", "-ur3", help="Rank of U in 3rd layer of FastGRNN default is None", type=int)
parser.add_argument("--wSparsity", "-wsp", help="Sparsity of W matrices", type=float)
parser.add_argument("--uSparsity", "-usp", help="Sparsity of U matrices", type=float)
parser.add_argument("--gate_nonlinearity", "-gnl", help="Gate Non-Linearity in FastGRNN default is sigmoid"
" use between [sigmoid, quantSigmoid, tanh, quantTanh]")
parser.add_argument("--update_nonlinearity", "-unl", help="Update Non-Linearity in FastGRNN default is Tanh"
@@ -637,6 +642,14 @@ if __name__ == '__main__':
config.model.wRank3 = args.wRank3
if args.uRank3:
config.model.uRank3 = args.uRank3
if args.wSparsity:
config.model.wSparsity = args.wSparsity
else:
config.model.wSparsity = 1.0
if args.uSparsity:
config.model.uSparsity = args.uSparsity
else:
config.model.uSparsity = 1.0
if args.gate_nonlinearity:
config.model.gate_nonlinearity = args.gate_nonlinearity
if args.update_nonlinearity: