зеркало из https://github.com/microsoft/EdgeML.git
support for sampling non keywords
This commit is contained in:
Родитель
ae2668e24b
Коммит
0176ebc92a
|
@ -22,6 +22,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torch.onnx
|
||||
import random
|
||||
|
||||
from torch.autograd import Variable, Function
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
@ -306,7 +307,7 @@ class AudioDataset(Dataset):
|
|||
mini-batch training.
|
||||
"""
|
||||
|
||||
def __init__(self, filename, config, keywords):
|
||||
def __init__(self, filename, config, keywords, training=False):
|
||||
""" Initialize the AudioDataset from the given *.npz file """
|
||||
self.dataset = np.load(filename)
|
||||
|
||||
|
@ -331,10 +332,18 @@ class AudioDataset(Dataset):
|
|||
else:
|
||||
self.mean = None
|
||||
self.std = None
|
||||
|
||||
self.label_names = self.dataset["labels"]
|
||||
self.keywords = keywords
|
||||
self.num_keywords = len(self.keywords)
|
||||
self.labels = self.to_long_vector()
|
||||
|
||||
self.keywords_idx = None
|
||||
self.non_keywords_idx = None
|
||||
if training and config.sample_non_kw is not None:
|
||||
self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw)
|
||||
self.sample_non_kw_probability = config.sample_non_kw_probability
|
||||
|
||||
msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
|
||||
print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size,
|
||||
self.window_size, self.shift))
|
||||
|
@ -342,23 +351,40 @@ class AudioDataset(Dataset):
|
|||
def get_data_loader(self, batch_size):
|
||||
""" Get a DataLoader that can enumerate shuffled batches of data in this dataset """
|
||||
return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True)
|
||||
|
||||
|
||||
def to_long_vector(self):
|
||||
""" convert the expected labels to a list of integer indexes into the array of keywords """
|
||||
indexer = [(0 if x == "<null>" else self.keywords.index(x)) for x in self.label_names]
|
||||
return np.array(indexer, dtype=np.longlong)
|
||||
|
||||
def get_keyword_idx(self, non_kw_label):
|
||||
""" find the keywords and store there index """
|
||||
indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label]
|
||||
non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label]
|
||||
return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong))
|
||||
|
||||
def __len__(self):
|
||||
""" Return the number of rows in this Dataset """
|
||||
return self.num_rows
|
||||
if self.non_keywords_idx is None:
|
||||
return self.num_rows
|
||||
else:
|
||||
return int(len(self.keywords_idx) / (1-self.sample_non_kw_probability))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
""" Return a single labelled sample here as a tuple """
|
||||
audio = self.features[idx] # batch index is second dimension
|
||||
label = self.labels[idx]
|
||||
if self.non_keywords_idx is None:
|
||||
updated_idx=idx
|
||||
else:
|
||||
if idx < len(self.keywords_idx):
|
||||
updated_idx=self.keywords_idx[idx]
|
||||
else:
|
||||
updated_idx=np.random.choice(self.non_keywords_idx)
|
||||
audio = self.features[updated_idx] # batch index is second dimension
|
||||
label = self.labels[updated_idx]
|
||||
sample = (audio, label)
|
||||
return sample
|
||||
|
||||
|
||||
|
||||
def create_model(model_config, input_size, num_keywords):
|
||||
ModelClass = get_model_class(KeywordSpotter)
|
||||
|
@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):
|
|||
log = None
|
||||
if not evaluate_only:
|
||||
print("Loading {}...".format(training_file))
|
||||
training_data = AudioDataset(training_file, config.dataset, keywords)
|
||||
training_data = AudioDataset(training_file, config.dataset, keywords, training=True)
|
||||
|
||||
print("Loading {}...".format(validation_file))
|
||||
validation_data = AudioDataset(validation_file, config.dataset, keywords)
|
||||
|
@ -556,6 +582,8 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true")
|
||||
parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training"
|
||||
" default is 100", type=int)
|
||||
parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_prob")
|
||||
parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from scl with this probability")
|
||||
|
||||
# arguments for fastgrnn
|
||||
parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
|
||||
|
@ -645,6 +673,15 @@ if __name__ == '__main__':
|
|||
config.dataset.categories = args.categories
|
||||
if args.dataset:
|
||||
config.dataset.path = args.dataset
|
||||
if args.sample_non_kw:
|
||||
config.dataset.sample_non_kw = args.sample_non_kw
|
||||
if args.sample_non_kw_probability is None:
|
||||
config.dataset.sample_non_kw_probability = 0.5
|
||||
else:
|
||||
config.dataset.sample_non_kw_probability = args.sample_non_kw_probability
|
||||
else:
|
||||
config.dataset.sample_non_kw = None
|
||||
|
||||
if args.wRank:
|
||||
config.model.wRank = args.wRank
|
||||
if args.uRank:
|
||||
|
|
Загрузка…
Ссылка в новой задаче