зеркало из https://github.com/microsoft/EdgeML.git
added sparsity and hard thresholding support to train_classifier and Fast(G)RNN cells and models
This commit is contained in:
Родитель
a2f1ddb1df
Коммит
e2c8a8fa93
|
@ -1086,20 +1086,22 @@ class FastGRNN(nn.Module):
|
|||
|
||||
def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
|
||||
update_nonlinearity="tanh", wRank=None, uRank=None,
|
||||
zetaInit=1.0, nuInit=-4.0, batch_first=True):
|
||||
wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
|
||||
batch_first=True):
|
||||
super(FastGRNN, self).__init__()
|
||||
self._input_size = input_size
|
||||
self._hidden_size = hidden_size
|
||||
self._gate_nonlinearity = gate_nonlinearity
|
||||
self._update_nonlinearity = update_nonlinearity
|
||||
self._wRank = wRank
|
||||
self._uRank = uRank
|
||||
#self._input_size = input_size
|
||||
#self._hidden_size = hidden_size
|
||||
#self._gate_nonlinearity = gate_nonlinearity
|
||||
#self._update_nonlinearity = update_nonlinearity
|
||||
#self._wRank = wRank
|
||||
#self._uRank = uRank
|
||||
self.batch_first = batch_first
|
||||
|
||||
self.cell = FastGRNNCell(input_size, hidden_size,
|
||||
gate_nonlinearity=gate_nonlinearity,
|
||||
update_nonlinearity=update_nonlinearity,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wRank=wRank, uRank=uRank,
|
||||
wSparsity=wSparsity, uSparsity=uSparsity,
|
||||
zetaInit=zetaInit, nuInit=nuInit)
|
||||
self.unrollRNN = BaseRNN(self.cell, batch_first=self.batch_first)
|
||||
|
||||
|
|
|
@ -31,6 +31,8 @@ def fastgrnnmodel(inheritance_class=nn.Module):
|
|||
self.num_classes = num_classes
|
||||
self.wRank_list = wRank_list
|
||||
self.uRank_list = uRank_list
|
||||
self.wSparsity_list = wSparsity_list
|
||||
self.uSparsity_list = uSparsity_list
|
||||
self.gate_nonlinearity = gate_nonlinearity
|
||||
self.update_nonlinearity = update_nonlinearity
|
||||
self.linear = linear
|
||||
|
@ -48,7 +50,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
|
|||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[0], uRank=self.uRank_list[0],
|
||||
wSparisty=self.wSparisty_list[0], uSparisty=self.uSparisty_list[0],
|
||||
wSparsity=self.wSparsity_list[0], uSparsity=self.uSparsity_list[0],
|
||||
batch_first = self.batch_first)
|
||||
self.fastgrnn2 = None
|
||||
last_output_size = self.hidden_units_list[0]
|
||||
|
@ -57,7 +59,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
|
|||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[1], uRank=self.uRank_list[1],
|
||||
wSparisty=self.wSparisty_list[1], uSparisty=self.uSparisty_list[1],
|
||||
wSparsity=self.wSparsity_list[1], uSparsity=self.uSparsity_list[1],
|
||||
batch_first = self.batch_first)
|
||||
last_output_size = self.hidden_units_list[1]
|
||||
self.fastgrnn3 = None
|
||||
|
@ -66,7 +68,7 @@ def fastgrnnmodel(inheritance_class=nn.Module):
|
|||
gate_nonlinearity=self.gate_nonlinearity,
|
||||
update_nonlinearity=self.update_nonlinearity,
|
||||
wRank=self.wRank_list[2], uRank=self.uRank_list[2],
|
||||
wSparisty=self.wSparisty_list[2], uSparisty=self.uSparisty_list[2],
|
||||
wSparsity=self.wSparsity_list[2], uSparsity=self.uSparsity_list[2],
|
||||
batch_first = self.batch_first)
|
||||
last_output_size = self.hidden_units_list[2]
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ class KeywordSpotter(nn.Module):
|
|||
scheduler = ExponentialResettingLR(optimizer, gamma, reset)
|
||||
return scheduler
|
||||
|
||||
def fit(self, training_data, validation_data, options, sparsify=false, device=None, detail=False, run=None):
|
||||
def fit(self, training_data, validation_data, options, sparsify=False, device=None, detail=False, run=None):
|
||||
"""
|
||||
Perform the training. This is not called "train" because
|
||||
the base class already defines that method with a different meaning.
|
||||
|
@ -654,7 +654,7 @@ if __name__ == '__main__':
|
|||
config.model.uSparsity = args.uSparsity
|
||||
else:
|
||||
config.model.uSparsity = 1.0
|
||||
if config.model.uSparsity < 1.0 or config.model.wSparsity < 1.0
|
||||
if config.model.uSparsity < 1.0 or config.model.wSparsity < 1.0:
|
||||
config.model.sparsify = True
|
||||
else:
|
||||
config.model.sparsify = False
|
||||
|
|
Загрузка…
Ссылка в новой задаче