added sparsity and hard thresholding support to train_classifier and Fast(G)RNN cells and models

Harsha Vardhan Simhadri 2019-07-29 19:52:12 +05:30
Parent a2f1ddb1df
Commit e2c8a8fa93
3 changed files with 17 additions and 13 deletions

View file

@@ -1086,20 +1086,22 @@ class FastGRNN(nn.Module):
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 zetaInit=1.0, nuInit=-4.0, batch_first=True):
+                 wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
+                 batch_first=True):
         super(FastGRNN, self).__init__()
-        self._input_size = input_size
-        self._hidden_size = hidden_size
-        self._gate_nonlinearity = gate_nonlinearity
-        self._update_nonlinearity = update_nonlinearity
-        self._wRank = wRank
-        self._uRank = uRank
+        #self._input_size = input_size
+        #self._hidden_size = hidden_size
+        #self._gate_nonlinearity = gate_nonlinearity
+        #self._update_nonlinearity = update_nonlinearity
+        #self._wRank = wRank
+        #self._uRank = uRank
         self.batch_first = batch_first
         self.cell = FastGRNNCell(input_size, hidden_size,
                                  gate_nonlinearity=gate_nonlinearity,
                                  update_nonlinearity=update_nonlinearity,
-                                 wRank=wRank, uRank=uRank,
+                                 wRank=wRank, uRank=uRank,
+                                 wSparsity=wSparsity, uSparsity=uSparsity,
                                  zetaInit=zetaInit, nuInit=nuInit)
         self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
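
Note: this hunk only threads wSparsity and uSparsity through to FastGRNNCell; the thresholding itself is not visible in the diff. As the commit message indicates, hard thresholding of this kind keeps only the largest s-fraction of a weight matrix by magnitude and zeroes the rest. A minimal PyTorch sketch of such a projection step (the helper name hard_threshold is illustrative, not taken from this commit):

    import torch

    def hard_threshold(A, s):
        # Keep the s-fraction of entries of A with the largest magnitude
        # and zero the rest; s = 1.0 (the default above) leaves A dense.
        if s >= 1.0:
            return A
        k = max(1, int(s * A.numel()))              # number of entries to keep
        cutoff = A.abs().flatten().kthvalue(A.numel() - k + 1).values
        return A * (A.abs() >= cutoff)

With wSparsity=0.1, for instance, only about 10% of the cell's W entries survive each projection.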

View file

@@ -31,6 +31,8 @@ def fastgrnnmodel(inheritance_class=nn.Module):
             self.num_classes = num_classes
             self.wRank_list = wRank_list
             self.uRank_list = uRank_list
+            self.wSparsity_list = wSparsity_list
+            self.uSparsity_list = uSparsity_list
             self.gate_nonlinearity = gate_nonlinearity
             self.update_nonlinearity = update_nonlinearity
             self.linear = linear
@@ -48,7 +50,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
                 gate_nonlinearity=self.gate_nonlinearity,
                 update_nonlinearity=self.update_nonlinearity,
                 wRank=self.wRank_list[0], uRank=self.uRank_list[0],
-                wSparisty=self.wSparisty_list[0], uSparisty=self.uSparisty_list[0],
+                wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
                 batch_first = self.batch_first)
             self.fastgrnn2 = None
             last_output_size = self.hidden_units_list[0]
@@ -57,7 +59,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
                 gate_nonlinearity=self.gate_nonlinearity,
                 update_nonlinearity=self.update_nonlinearity,
                 wRank=self.wRank_list[1], uRank=self.uRank_list[1],
-                wSparisty=self.wSparisty_list[1], uSparisty=self.uSparisty_list[1],
+                wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
                 batch_first = self.batch_first)
             last_output_size = self.hidden_units_list[1]
             self.fastgrnn3 = None
@@ -66,7 +68,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
                 gate_nonlinearity=self.gate_nonlinearity,
                 update_nonlinearity=self.update_nonlinearity,
                 wRank=self.wRank_list[2], uRank=self.uRank_list[2],
-                wSparisty=self.wSparisty_list[2], uSparisty=self.uSparisty_list[2],
+                wSparsity=self.wSparsity_list[2], uSparsity=self.uSparsity_list[2],
                 batch_first = self.batch_first)
             last_output_size = self.hidden_units_list[2]
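
Note: after these three hunks each stacked FastGRNN layer receives its own sparsity pair, so the layers can be sparsified independently. A hedged illustration of the pairing (the list values are made up; the generated class's full constructor signature is not shown in this diff):

    wSparsity_list = [0.2, 0.2, 1.0]   # fastgrnn1/2 keep 20% of W; fastgrnn3 stays dense
    uSparsity_list = [0.3, 0.3, 1.0]   # fastgrnn1/2 keep 30% of U; fastgrnn3 stays dense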

View file

@@ -136,7 +136,7 @@ class KeywordSpotter(nn.Module):
         scheduler = ExponentialResettingLR(optimizer, gamma, reset)
         return scheduler
 
-    def fit(self, training_data, validation_data, options, sparsify=false, device=None, detail=False, run=None):
+    def fit(self, training_data, validation_data, options, sparsify=False, device=None, detail=False, run=None):
         """
         Perform the training. This is not called "train" because
         the base class already defines that method with a different meaning.
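
Note: this hunk is a one-character bug fix; lowercase false is not a Python literal, so the old default raised a NameError as soon as the class body was executed. The diff does not show how fit() consumes the flag; a common pattern for such a switch is iterative hard thresholding, which re-projects the weights after every optimizer step. A hedged sketch reusing the hard_threshold helper above (cell.W, cell.U and the self.*Sparsity fields are assumed names, not taken from this diff):

    optimizer.step()
    if sparsify:
        with torch.no_grad():
            # Dense gradient updates repopulate zeroed entries, so the
            # sparse support is re-imposed after each step.
            cell = self.fastgrnn1.cell
            cell.W.data = hard_threshold(cell.W.data, self.wSparsity)
            cell.U.data = hard_threshold(cell.U.data, self.uSparsity)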
@@ -654,7 +654,7 @@ if __name__ == '__main__':
         config.model.uSparsity = args.uSparsity
     else:
         config.model.uSparsity = 1.0
-    if config.model.uSparsity < 1.0 or config.model.wSparsity < 1.0
+    if config.model.uSparsity < 1.0 or config.model.wSparsity < 1.0:
         config.model.sparsify = True
     else:
         config.model.sparsify = False
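
Note: under this convention a sparsity of 1.0 means "keep 100% of the entries", i.e. fully dense, so sparsification is enabled only when either value drops below 1.0 (the old line was also missing its trailing colon, a syntax error this hunk fixes). The added branch is equivalent to the one-liner:

    config.model.sparsify = (config.model.uSparsity < 1.0
                             or config.model.wSparsity < 1.0)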