support for sampling non-keywords

Author: Harsha Vardhan Simhadri 2019-09-24 22:27:43 +05:30
Parent: ae2668e24b
Commit: 0176ebc92a
1 changed file with 43 additions and 6 deletions


@@ -22,6 +22,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
import random
from torch.autograd import Variable, Function
from torch.utils.data import Dataset, DataLoader
@@ -306,7 +307,7 @@ class AudioDataset(Dataset):
mini-batch training.
"""
def __init__(self, filename, config, keywords):
def __init__(self, filename, config, keywords, training=False):
""" Initialize the AudioDataset from the given *.npz file """
self.dataset = np.load(filename)
@@ -331,10 +332,18 @@
else:
self.mean = None
self.std = None
self.label_names = self.dataset["labels"]
self.keywords = keywords
self.num_keywords = len(self.keywords)
self.labels = self.to_long_vector()
self.keywords_idx = None
self.non_keywords_idx = None
if training and config.sample_non_kw is not None:
self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw)
self.sample_non_kw_probability = config.sample_non_kw_probability
msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size,
self.window_size, self.shift))
@@ -342,23 +351,40 @@
def get_data_loader(self, batch_size):
""" Get a DataLoader that can enumerate shuffled batches of data in this dataset """
return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True)
def to_long_vector(self):
""" convert the expected labels to a list of integer indexes into the array of keywords """
indexer = [(0 if x == "<null>" else self.keywords.index(x)) for x in self.label_names]
return np.array(indexer, dtype=np.longlong)
def get_keyword_idx(self, non_kw_label):
""" find the keywords and store there index """
indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label]
non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label]
return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong))
def __len__(self):
""" Return the number of rows in this Dataset """
return self.num_rows
        if self.non_keywords_idx is None:
            return self.num_rows
        else:
            return int(len(self.keywords_idx) / (1 - self.sample_non_kw_probability))
def __getitem__(self, idx):
""" Return a single labelled sample here as a tuple """
audio = self.features[idx] # batch index is second dimension
label = self.labels[idx]
        if self.non_keywords_idx is None:
            updated_idx = idx
        else:
            if idx < len(self.keywords_idx):
                updated_idx = self.keywords_idx[idx]
            else:
                updated_idx = np.random.choice(self.non_keywords_idx)
audio = self.features[updated_idx] # batch index is second dimension
label = self.labels[updated_idx]
sample = (audio, label)
return sample
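Taken together, the new __len__ and __getitem__ rebalance each epoch: every keyword row is visited exactly once (indices below len(keywords_idx)), and every index beyond that draws a random non-keyword row, so non-keywords make up roughly sample_non_kw_probability of the samples. Because get_data_loader shuffles over this inflated length, keyword coverage stays exhaustive while non-keyword rows are resampled independently each epoch. A minimal standalone sketch of the scheme (toy labels and probability, not the actual AudioDataset class):

import numpy as np

labels = ["up", "<null>", "down", "<null>", "<null>", "up"]
non_kw_label = "<null>"
p = 0.5  # desired fraction of non-keyword samples per epoch

kw_idx = [i for i, l in enumerate(labels) if l != non_kw_label]
non_kw_idx = [i for i, l in enumerate(labels) if l == non_kw_label]

# Mirrors the new __len__: keyword rows fill (1 - p) of the epoch,
# so the epoch is inflated to len(kw_idx) / (1 - p) samples.
epoch_len = int(len(kw_idx) / (1 - p))  # 3 keywords / 0.5 -> 6 samples

for idx in range(epoch_len):
    if idx < len(kw_idx):
        row = kw_idx[idx]                   # each keyword row, exactly once
    else:
        row = np.random.choice(non_kw_idx)  # random non-keyword draw
    print(idx, labels[row])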
def create_model(model_config, input_size, num_keywords):
ModelClass = get_model_class(KeywordSpotter)
@@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):
log = None
if not evaluate_only:
print("Loading {}...".format(training_file))
training_data = AudioDataset(training_file, config.dataset, keywords)
training_data = AudioDataset(training_file, config.dataset, keywords, training=True)
print("Loading {}...".format(validation_file))
validation_data = AudioDataset(validation_file, config.dataset, keywords)
@@ -556,6 +582,8 @@ if __name__ == '__main__':
parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true")
parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training"
" default is 100", type=int)
parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_prob")
parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from scl with this probability")
# arguments for fastgrnn
parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
@@ -645,6 +673,15 @@ if __name__ == '__main__':
config.dataset.categories = args.categories
if args.dataset:
config.dataset.path = args.dataset
if args.sample_non_kw:
config.dataset.sample_non_kw = args.sample_non_kw
if args.sample_non_kw_probability is None:
config.dataset.sample_non_kw_probability = 0.5
else:
config.dataset.sample_non_kw_probability = args.sample_non_kw_probability
else:
config.dataset.sample_non_kw = None
if args.wRank:
config.model.wRank = args.wRank
if args.uRank:
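With this wiring, a hypothetical invocation might look like `python train_classifier.py --sample_non_kw "<null>" --sample_non_kw_probability 0.3` (the train_classifier.py script name is an assumption, and `<null>` is inferred from the filler label used in to_long_vector; the flag names come from the diff). Passing --sample_non_kw alone falls back to a probability of 0.5, and omitting it leaves config.dataset.sample_non_kw as None, so the dataset visits every row exactly once as before.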