diff --git a/examples/pytorch/FastCells/train_classifier.py b/examples/pytorch/FastCells/train_classifier.py index 69c5979b..1f0588bb 100644 --- a/examples/pytorch/FastCells/train_classifier.py +++ b/examples/pytorch/FastCells/train_classifier.py @@ -22,6 +22,7 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torch.onnx +import random from torch.autograd import Variable, Function from torch.utils.data import Dataset, DataLoader @@ -306,7 +307,7 @@ class AudioDataset(Dataset): mini-batch training. """ - def __init__(self, filename, config, keywords): + def __init__(self, filename, config, keywords, training=False): """ Initialize the AudioDataset from the given *.npz file """ self.dataset = np.load(filename) @@ -331,10 +332,18 @@ class AudioDataset(Dataset): else: self.mean = None self.std = None + self.label_names = self.dataset["labels"] self.keywords = keywords self.num_keywords = len(self.keywords) self.labels = self.to_long_vector() + + self.keywords_idx = None + self.non_keywords_idx = None + if training and config.sample_non_kw is not None: + self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw) + self.sample_non_kw_probability = config.sample_non_kw_probability + msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}" print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size, self.window_size, self.shift)) @@ -342,23 +351,40 @@ class AudioDataset(Dataset): def get_data_loader(self, batch_size): """ Get a DataLoader that can enumerate shuffled batches of data in this dataset """ return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True) - + def to_long_vector(self): """ convert the expected labels to a list of integer indexes into the array of keywords """ indexer = [(0 if x == "" else self.keywords.index(x)) for x in self.label_names] return np.array(indexer, 
dtype=np.longlong) + def get_keyword_idx(self, non_kw_label): + """ find the keyword and non-keyword samples and return their indices """ + indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label] + non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label] + return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong)) + def __len__(self): """ Return the number of rows in this Dataset """ - return self.num_rows + if self.non_keywords_idx is None: + return self.num_rows + else: + return int(len(self.keywords_idx) / (1-self.sample_non_kw_probability)) def __getitem__(self, idx): """ Return a single labelled sample here as a tuple """ - audio = self.features[idx] # batch index is second dimension - label = self.labels[idx] + if self.non_keywords_idx is None: + updated_idx=idx + else: + if idx < len(self.keywords_idx): + updated_idx=self.keywords_idx[idx] + else: + updated_idx=np.random.choice(self.non_keywords_idx) + audio = self.features[updated_idx] # batch index is second dimension + label = self.labels[updated_idx] sample = (audio, label) return sample + def create_model(model_config, input_size, num_keywords): ModelClass = get_model_class(KeywordSpotter) @@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False): log = None if not evaluate_only: print("Loading {}...".format(training_file)) - training_data = AudioDataset(training_file, config.dataset, keywords) + training_data = AudioDataset(training_file, config.dataset, keywords, training=True) print("Loading {}...".format(validation_file)) validation_data = AudioDataset(validation_file, config.dataset, keywords) @@ -556,6 +582,8 @@ if __name__ == '__main__': parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true") parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training" " default is 100", type=int) + 
parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_prob") + parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from scl with this probability") # arguments for fastgrnn parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int) @@ -645,6 +673,15 @@ if __name__ == '__main__': config.dataset.categories = args.categories if args.dataset: config.dataset.path = args.dataset + if args.sample_non_kw: + config.dataset.sample_non_kw = args.sample_non_kw + if args.sample_non_kw_probability is None: + config.dataset.sample_non_kw_probability = 0.5 + else: + config.dataset.sample_non_kw_probability = args.sample_non_kw_probability + else: + config.dataset.sample_non_kw = None + if args.wRank: config.model.wRank = args.wRank if args.uRank: