This commit is contained in:
Harsha Vardhan Simhadri 2019-07-27 21:31:07 +05:30
Родитель 8de1b8ce71
Коммит e35aee8f59
3 изменённых файлов: 11 добавлений и 7 удалений

Просмотреть файл

@ -60,7 +60,7 @@ class BaseRNN(nn.Module):
'''
Generic equivalent of static_rnn in tf
Used to unroll all the cell written in this file
We assume data to be batch_first by default ie.,
We assume input to be batch_first by default ie.,
[batchSize, timeSteps, inputDims] else
[timeSteps, batchSize, inputDims]
'''
@ -436,7 +436,7 @@ class LSTMLRCell(nn.Module):
LSTM architecture and compression techniques are found in
LSTM paper
Basic architecture is like:
Basic architecture:
f_t = gate_nl(W1x_t + U1h_{t-1} + B_f)
i_t = gate_nl(W2x_t + U2h_{t-1} + B_i)

Просмотреть файл

@ -10,9 +10,13 @@ from edgeml.pytorch.graph.rnn import *
def fastgrnnmodel(inheritance_class=nn.Module):
class FastGRNNModel(inheritance_class):
"""This class is a PyTorch Module that implements a 1, 2 or 3 layer GRU based audio classifier"""
"""This class is a PyTorch Module that implements a 1, 2 or 3 layer
RNN-based classifier
"""
def __init__(self, input_dim, num_layers, hidden_units_list, wRank_list, uRank_list, gate_nonlinearity, update_nonlinearity, num_classes=None, linear=True, batch_first=False, apply_softmax=True):
def __init__(self, input_dim, num_layers, hidden_units_list,
wRank_list, uRank_list, gate_nonlinearity, update_nonlinearity,
num_classes=None, linear=True, batch_first=False, apply_softmax=True):
"""
Initialize the KeywordSpotter with the following parameters:
input_dim - the size of the input audio frame in # samples.

Просмотреть файл

@ -111,9 +111,9 @@ class KeywordSpotter(nn.Module):
def fit(self, training_data, validation_data, options, model, device=None, detail=False, run=None):
"""
Perform the training. This is not called "train" because the base class already defines
that method with a different meaning. The base class "train" method puts the Module into
"training mode".
Perform the training. This is not called "train" because
the base class already defines that method with a different meaning.
The base class "train" method puts the Module into "training mode".
"""
print("Training {} using {} rows of featurized training input...".format(self.name(), training_data.num_rows))