Mirror of https://github.com/microsoft/EdgeML.git
added sparsity and hard thresholding support to train_classifier and Fast(G)RNN cells and models
This commit is contained in:
Parent
c90b28e13e
Commit
b8b4dd70a8
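
The central operation in this change is magnitude-based hard thresholding: keep only the largest-magnitude fraction s of a matrix's entries and zero the rest, so s = 1.0 leaves the matrix fully dense. Below is a minimal standalone sketch of that rule (the function name here is illustrative; the committed version, edgeml.pytorch.utils.hardThreshold in the diff that follows, applies the same rule in place to an nn.Parameter):

import numpy as np

def hard_threshold_sketch(A, s):
    # Keep roughly a fraction s of A's largest-magnitude entries; zero the rest.
    A_ = np.copy(A).ravel()
    if len(A_) > 0:
        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
        A_[np.abs(A_) < th] = 0.0
    return A_.reshape(A.shape)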
@@ -6,6 +6,8 @@ import torch.nn as nn
 from torch.autograd import Function
+import numpy as np
 
+import edgeml.pytorch.utils as utils
 
 def onnx_exportable_fastgrnn(input, fargs, output, hidden_size, wRank, uRank, gate_nonlinearity, update_nonlinearity):
     class RNNSymbolic(Function):
         @staticmethod
@@ -196,6 +198,19 @@ class FastGRNNCell(nn.Module):
         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))
 
+    def sparsify(self, wsp, usp):
+        if self._wRank is None:
+            utils.hardThreshold(self.W, wsp)
+        else:
+            utils.hardThreshold(self.W1, wsp)
+            utils.hardThreshold(self.W2, wsp)
+
+        if self._uRank is None:
+            utils.hardThreshold(self.U, usp)
+        else:
+            utils.hardThreshold(self.U1, usp)
+            utils.hardThreshold(self.U2, usp)
+
     @property
     def state_size(self):
         return self._hidden_size
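
For orientation, the new per-cell method can also be exercised directly. This is a hedged sketch, assuming the edgeml.pytorch.graph.rnn import path and a FastGRNNCell(input_size, hidden_size, wRank=..., uRank=...) constructor; only sparsify(wsp, usp) and the low-rank W1/W2, U1/U2 parameters are taken from the diff above, everything else is illustrative:

import torch
from edgeml.pytorch.graph.rnn import FastGRNNCell  # import path assumed

# Assumed constructor arguments: 32-dim inputs, 100-dim hidden state, low-rank factors.
cell = FastGRNNCell(32, 100, wRank=16, uRank=25)

# Keep ~20% of each factor's entries by magnitude; W1/W2 and U1/U2 exist
# because wRank/uRank were set, so the else-branch of sparsify is taken.
cell.sparsify(wsp=0.2, usp=0.2)

with torch.no_grad():
    kept = (cell.W1 != 0).float().mean().item()
    print("nonzero fraction of W1 after sparsify: %.2f" % kept)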
@@ -340,6 +355,20 @@ class FastRNNCell(nn.Module):
         self.alpha = nn.Parameter(self._alphaInit * torch.ones([1, 1]))
         self.beta = nn.Parameter(self._betaInit * torch.ones([1, 1]))
 
+    def sparsify(self, wsp, usp):
+        if self._wRank is None:
+            utils.hardThreshold(self.W, wsp)
+        else:
+            utils.hardThreshold(self.W1, wsp)
+            utils.hardThreshold(self.W2, wsp)
+
+        if self._uRank is None:
+            utils.hardThreshold(self.U, usp)
+        else:
+            utils.hardThreshold(self.U1, usp)
+            utils.hardThreshold(self.U2, usp)
+
+
     @property
     def state_size(self):
         return self._hidden_size
@@ -72,6 +72,14 @@ def fastgrnnmodel(inheritance_class=nn.Module):
             self.hidden2keyword = nn.Linear(last_output_size, num_classes)
             self.init_hidden()
 
+        def sparsify(self, wsp, usp):
+            self.fastgrnn1.cell.sparsify(wsp, usp)
+            if self.num_layers > 1:
+                self.fastgrnn2.cell.sparsify(wsp, usp)
+            if self.num_layers > 2:
+                self.fastgrnn3.cell.sparsify(wsp, usp)
+
+
         def normalize(self, mean, std):
             self.mean = mean
             self.std = std
@@ -41,17 +41,21 @@ def binaryHingeLoss(logits, labels):
     return torch.mean(F.relu(1.0 - (2 * labels - 1) * logits))
 
 
-def hardThreshold(A, s):
+def hardThreshold(A: torch.nn.Parameter, s):
     '''
-    Hard thresholding function on Tensor A with sparsity s
+    Hard thresholds and modifies in-place nn.Parameter A with sparsity s
     '''
-    A_ = np.copy(A)
+    # PyTorch disallows numpy access/copy to tensors in graph.
+    # .detach() creates a new tensor not attached to the graph.
+    A_ = A.detach().numpy()
+
     A_ = A_.ravel()
     if len(A_) > 0:
         th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
         A_[np.abs(A_) < th] = 0.0
     A_ = A_.reshape(A.shape)
-    return A_
+
+    A.data = torch.tensor(A_, requires_grad=True)
 
 
 def copySupport(src, dest):
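
As a quick check of the rewritten helper above, hardThreshold now prunes the nn.Parameter in place instead of returning a numpy array. A minimal sketch, with the parameter shape and sparsity level chosen arbitrarily:

import torch
import torch.nn as nn
import edgeml.pytorch.utils as utils

W = nn.Parameter(torch.randn(100, 32))
utils.hardThreshold(W, 0.1)            # keep ~10% of entries by magnitude
print((W != 0).float().mean().item())  # prints roughly 0.10; W was modified in place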
@@ -136,7 +136,7 @@ class KeywordSpotter(nn.Module):
         scheduler = ExponentialResettingLR(optimizer, gamma, reset)
         return scheduler
 
-    def fit(self, training_data, validation_data, options, model, device=None, detail=False, run=None):
+    def fit(self, training_data, validation_data, training_options, model_options, device=None, detail=False, run=None):
         """
         Perform the training. This is not called "train" because
         the base class already defines that method with a different meaning.
@@ -156,29 +156,29 @@ class KeywordSpotter(nn.Module):
         self.training = True
         start = time.time()
         loss_function = nn.NLLLoss()
-        optimizer = self.configure_optimizer(options)
+        optimizer = self.configure_optimizer(training_options)
         print(optimizer)
 
-        num_epochs = options.max_epochs
-        batch_size = options.batch_size
+        num_epochs = training_options.max_epochs
+        batch_size = training_options.batch_size
 
         ticks = training_data.num_rows / batch_size  # iterations per epoch
         total_iterations = ticks * num_epochs
-        scheduler = self.configure_lr(options, optimizer, ticks, total_iterations)
+        scheduler = self.configure_lr(training_options, optimizer, ticks, total_iterations)
 
         # optimizer = optim.Adam(model.parameters(), lr=0.0001)
         log = []
-        if options.rolling:
+        if training_options.rolling:
             rolling_length = 2
             max_rolling_length = int(ticks)
-            if max_rolling_length > options.max_rolling_length:
-                max_rolling_length = options.max_rolling_length
+            if max_rolling_length > training_options.max_rolling_length:
+                max_rolling_length = training_options.max_rolling_length
             bag_count = 100
             hidden_bag_size = batch_size * bag_count
 
         for epoch in range(num_epochs):
             self.train()
-            if options.rolling:
+            if training_options.rolling:
                 rolling_length += 1
                 if rolling_length < max_rolling_length:
                     self.init_hidden_bag(hidden_bag_size, device)
@@ -192,7 +192,7 @@ class KeywordSpotter(nn.Module):
 
                 # Also, we need to clear out the hidden state,
                 # detaching it from its history on the last instance.
-                if options.rolling:
+                if training_options.rolling:
                     if rolling_length < max_rolling_length:
                         if (i_batch + 1) % rolling_length == 0:
                             self.init_hidden()
@@ -228,6 +228,9 @@ class KeywordSpotter(nn.Module):
                 # applying the gradients we computed during back propagation
                 optimizer.step()
 
+                if model_options.wSparsity < 1.0 or model_options.uSparsity < 1.0:
+                    self.sparsify(model_options.wSparsity, model_options.uSparsity)
+
                 learning_rate = optimizer.param_groups[0]['lr']
                 if detail:
                     learning_rate = optimizer.param_groups[0]['lr']
@@ -543,6 +546,8 @@ if __name__ == '__main__':
     parser.add_argument("--uRank2", "-ur2", help="Rank of U in 2nd layer of FastGRNN default is None", type=int)
     parser.add_argument("--wRank3", "-wr3", help="Rank of W in 3rd layer of FastGRNN default is None", type=int)
     parser.add_argument("--uRank3", "-ur3", help="Rank of U in 3rd layer of FastGRNN default is None", type=int)
+    parser.add_argument("--wSparsity", "-wsp", help="Sparsity of W matrices", type=float)
+    parser.add_argument("--uSparsity", "-usp", help="Sparsity of U matrices", type=float)
     parser.add_argument("--gate_nonlinearity", "-gnl", help="Gate Non-Linearity in FastGRNN default is sigmoid"
                         " use between [sigmoid, quantSigmoid, tanh, quantTanh]")
     parser.add_argument("--update_nonlinearity", "-unl", help="Update Non-Linearity in FastGRNN default is Tanh"
@@ -637,6 +642,14 @@ if __name__ == '__main__':
         config.model.wRank3 = args.wRank3
     if args.uRank3:
         config.model.uRank3 = args.uRank3
+    if args.wSparsity:
+        config.model.wSparsity = args.wSparsity
+    else:
+        config.model.wSparsity = 1.0
+    if args.uSparsity:
+        config.model.uSparsity = args.uSparsity
+    else:
+        config.model.uSparsity = 1.0
     if args.gate_nonlinearity:
         config.model.gate_nonlinearity = args.gate_nonlinearity
     if args.update_nonlinearity:
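
Taken together with the new arguments above, sparsity can now be requested from the command line. A hypothetical invocation of train_classifier (script name per the commit message; all other flags omitted), with -wsp/-usp falling back to 1.0, i.e. fully dense matrices, when not given:

python train_classifier.py --wSparsity 0.3 --uSparsity 0.3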