This commit is contained in:
Anthony Cintron Roman 2021-06-21 20:32:11 -07:00
Родитель 345dc8b752
Коммит ec8713e9ea
8 изменённых файлов: 2593 добавлений и 18 удалений

310
CNN_CNN.ipynb Normal file
Просмотреть файл

@ -0,0 +1,310 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"import torch.utils.data as data_utils\n",
"import os\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelBinarizer, LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"import csv\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import gc\n",
"import h5py\n",
"import pickle as pk\n",
"\n",
"from utils import log_results, SaveBestModel, train_seq, test_seq\n",
"from utils import normalize_mel_sp_slides\n",
"\n",
"from models import cnn_cnn"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set directories"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"dataDir = 'processed'\n",
"resultsDir = 'Results'\n",
"tempDir = 'temp'\n",
"\n",
"if not os.path.exists(resultsDir):\n",
" os.makedirs(resultsDir)\n",
"if not os.path.exists(tempDir):\n",
" os.makedirs(tempDir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"fname = 'birds_cornell_spectr_slide_100_species_sr_32000_len_7_sec_500_250_New.h5'\n",
"fileLoc = os.path.join(dataDir,fname) # 19707 samples per class\n",
"hf = h5py.File(fileLoc, 'r')\n",
"mel_sp = hf.get('mel_spectr')[()]\n",
"metadata_total = pd.read_hdf(fileLoc, 'info')\n",
"hf.close()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Species: 100\n"
]
}
],
"source": [
"original_label = list(metadata_total['ebird_code'])\n",
"lb_bin = LabelBinarizer()\n",
"lb_enc = LabelEncoder()\n",
"labels_one_hot = lb_bin.fit_transform(original_label)\n",
"labels_multi_lbl = lb_enc.fit_transform(original_label)\n",
"\n",
"number_of_sample_classes = len(lb_enc.classes_)\n",
"print(\"Number of Species: \", number_of_sample_classes)\n",
"species_id_class_dict_tp = dict()\n",
"for (class_label, species_id) in enumerate(lb_bin.classes_):\n",
" species_id_class_dict_tp[species_id] = class_label"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transform data"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"mel_sp_normalized = []\n",
"for i in range(len(mel_sp)):\n",
" xx_ = normalize_mel_sp_slides(mel_sp[i]).astype('float32')\n",
" mel_sp_normalized += [np.expand_dims(xx_, axis=-3)]\n",
"mel_sp_normalized = np.array(mel_sp_normalized)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## configs used in the current paper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other configs can be generated following a similar pattern."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"cfg_cnn = [[32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M'],\n",
" [32, 64, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M'],\n",
" [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']]\n",
"\n",
"cfg_ts = [[64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M'],\n",
" [64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M'],\n",
" [64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M']]\n",
"nunits = [256*4, 256*4, 512*4]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 16*2\n",
"num_classes=100\n",
"shuffleBatches=True\n",
"num_epoch = 50"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run all configs of CNN+CNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exp_no = 0\n",
"for ii, cfg1 in enumerate(cfg_cnn):\n",
" for jj, cfg2 in enumerate(cfg_ts):\n",
" exp_no += 1\n",
" \n",
" skf = StratifiedKFold(n_splits=5, random_state=42)\n",
"\n",
" log_file_name = f'100_species_spectr_cnn_p_cnn_7sec_{exp_no}.p'\n",
" store_ = log_results(file_name=log_file_name, results_dir = resultsDir)\n",
"\n",
" exp_ind = 0\n",
" for train_ind, test_ind in skf.split(mel_sp_normalized, labels_multi_lbl): #5-fold resampling\n",
"\n",
" PATH_curr = os.path.join(tempDir, f'currentModel_cnn_p_cnn_{exp_no}_{exp_ind}.pt')\n",
" saveModel = SaveBestModel(PATH=PATH_curr, monitor=-np.inf, verbose=True)\n",
"\n",
" X_train, X_test_p_valid = mel_sp_normalized[train_ind,:], mel_sp_normalized[test_ind,:]\n",
"\n",
" y_train, y_test_p_valid = labels_one_hot[train_ind], labels_one_hot[test_ind]\n",
" y_train_mlbl, y_test_p_valid_mlbl = labels_multi_lbl[train_ind], labels_multi_lbl[test_ind]\n",
" X_valid, X_test, \\\n",
" y_valid, y_test = train_test_split(X_test_p_valid, y_test_p_valid,\n",
" test_size=0.5,\n",
" stratify=y_test_p_valid_mlbl,\n",
" random_state=42)\n",
"\n",
" print('X_train shape: ', X_train.shape)\n",
" print('X_valid shape: ', X_valid.shape)\n",
" print('X_test shape: ', X_test.shape)\n",
"\n",
" X_train, X_valid = torch.from_numpy(X_train).float(), torch.from_numpy(X_valid).float()\n",
" y_train, y_valid = torch.from_numpy(y_train), torch.from_numpy(y_valid)\n",
"\n",
" y_train, y_valid = y_train.float(), y_valid.float()\n",
" train_use = data_utils.TensorDataset(X_train, y_train)\n",
" train_loader = data_utils.DataLoader(train_use, batch_size=batch_size, shuffle=shuffleBatches)\n",
"\n",
" val_use = data_utils.TensorDataset(X_valid, y_valid)\n",
" val_loader = data_utils.DataLoader(val_use, batch_size=32, shuffle=False)\n",
"\n",
" model = cnn_cnn(cfg1, \n",
" cfg2, \n",
" nunits[jj], \n",
" num_classes=100)\n",
" model.to(device)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001, weight_decay=1e-7)\n",
"\n",
" val_acc_epochs = []\n",
" val_loss_epochs = []\n",
" for epoch in range(1, num_epoch+1):\n",
" train_loss = train_seq(model, train_loader, optimizer, epoch, \n",
" device,\n",
" verbose=1, loss_fn = 'bceLogit')\n",
" val_loss, val_acc = test_seq(model, val_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" val_acc_epochs.append(val_acc)\n",
" val_loss_epochs.append(val_loss)\n",
" print('val loss = %f, val acc = %f'%(val_loss, val_acc))\n",
" saveModel.check(model, val_acc, comp='max')\n",
"\n",
" # loading best validated model\n",
" model = cnn_cnn(cfg1, \n",
" cfg2, \n",
" nunits[jj], \n",
" num_classes=100)\n",
" model.to(device)\n",
" model.load_state_dict(torch.load(PATH_curr))\n",
"\n",
" X_test, y_test = torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float()\n",
"\n",
" test_use = data_utils.TensorDataset(X_test, y_test)\n",
" test_loader = data_utils.DataLoader(test_use, batch_size=32, shuffle=False)\n",
" test_loss, test_acc = test_seq(model, test_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" print('test loss = %f, test acc = %f'%(test_loss, test_acc))\n",
"\n",
" log_ = dict(\n",
" exp_ind = exp_ind,\n",
" epochs = num_epoch,\n",
" validation_accuracy = val_acc_epochs,\n",
" validation_loss = val_loss_epochs,\n",
" test_loss = test_loss,\n",
" test_accuracy = test_acc,\n",
" X_train_shape = X_train.shape,\n",
" X_valid_shape = X_valid.shape,\n",
" batch_size =batch_size,\n",
" )\n",
" store_.update(log_)\n",
" exp_ind += 1 \n",
" print(f'COMPLETED for configs {ii, jj}...')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

342
CNN_RNN.ipynb Normal file
Просмотреть файл

@ -0,0 +1,342 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"import torch.utils.data as data_utils\n",
"import os\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelBinarizer, LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"import csv\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import gc\n",
"import h5py\n",
"import pickle as pk\n",
"\n",
"from utils import log_results, SaveBestModel, train_seq, test_seq\n",
"from utils import normalize_mel_sp_slides\n",
"\n",
"from models import cnn_rnn"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set directories"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"dataDir = 'processed'\n",
"resultsDir = 'Results'\n",
"tempDir = 'temp'\n",
"\n",
"if not os.path.exists(resultsDir):\n",
" os.makedirs(resultsDir)\n",
"if not os.path.exists(tempDir):\n",
" os.makedirs(tempDir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"fname = 'birds_cornell_spectr_slide_100_species_sr_32000_len_7_sec_500_250_New.h5'\n",
"fileLoc = os.path.join(dataDir,fname) # 19707 samples per class\n",
"hf = h5py.File(fileLoc, 'r')\n",
"mel_sp = hf.get('mel_spectr')[()]\n",
"metadata_total = pd.read_hdf(fileLoc, 'info')\n",
"hf.close()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Species: 100\n"
]
}
],
"source": [
"original_label = list(metadata_total['ebird_code'])\n",
"lb_bin = LabelBinarizer()\n",
"lb_enc = LabelEncoder()\n",
"labels_one_hot = lb_bin.fit_transform(original_label)\n",
"labels_multi_lbl = lb_enc.fit_transform(original_label)\n",
"\n",
"number_of_sample_classes = len(lb_enc.classes_)\n",
"print(\"Number of Species: \", number_of_sample_classes)\n",
"species_id_class_dict_tp = dict()\n",
"for (class_label, species_id) in enumerate(lb_bin.classes_):\n",
" species_id_class_dict_tp[species_id] = class_label"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"mel_sp_normalized = []\n",
"for i in range(len(mel_sp)):\n",
" xx_ = normalize_mel_sp_slides(mel_sp[i]).astype('float32')\n",
" mel_sp_normalized += [np.expand_dims(xx_, axis=-3)]\n",
"mel_sp_normalized = np.array(mel_sp_normalized)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 16*2\n",
"shuffleBatches=True\n",
"num_epoch = 50"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## CNN configs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg_cnn = [32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M'] # CNN1\n",
"# n_units = 128*2\n",
"\n",
"cfg_cnn2 = [32, 64, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M']\n",
"# n_units = 256*2\n",
"\n",
"cfg_cnn3 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] # CNN3\n",
"n_units = 512*2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RNN configs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For RNN, a list of configs could be provided for testing multiple configurations.\n",
"\n",
"Each configuration element is a dictionary whose keys are the 'ordered' names of the required RNNs. For example, to have 2 layers of GRUs, use 'GRU_0', 'GRU_1'; similarly, for 1 GRU followed by 1 LMU, use 'GRU_0', 'LMU_1'; conversely, to use an LMU and then a GRU, use 'LMU_0', 'GRU_1'. Currently supported RNN cells are LSTM, GRU, and LMU.\n",
"\n",
"Each key has value as another dictionary with entries:\n",
"input_size-> input dimension of this RNN cell\n",
"h_states_ctr-> number of inner states in the RNN cell. For LSTM it is 2, GRU has 1, LMU has 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hidden_size = 512\n",
"rnnConfigs = [\n",
" {'LSTM_0':{'input_size':n_units, 'h_states_ctr':2},\n",
" 'LSTM_1':{'input_size':hidden_size, 'h_states_ctr':2} # 2 layers of LSTM cell\n",
" },\n",
" {'LMU_0':{'input_size':n_units, 'h_states_ctr':2},\n",
" 'LMU_1':{'input_size':hidden_size, 'h_states_ctr':2}, # 2 layers of LMU cell\n",
" },\n",
" {'GRU_0':{'input_size':n_units, 'h_states_ctr':1},\n",
" 'GRU_1':{'input_size':hidden_size, 'h_states_ctr':1}, # 2 layers of GRU cell\n",
" },\n",
" {'GRU_0':{'input_size':n_units, 'h_states_ctr':1},\n",
" 'LMU_1':{'input_size':hidden_size, 'h_states_ctr':2}, # 1 GRU cell and then 1 LMU cell\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make sure to assign a different exp_no to each experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exp_no_base = 0\n",
"exp_ctr = 0\n",
"for ii, cfg in enumerate(rnnConfigs):\n",
" exp_ctr += 1\n",
"\n",
" exp_no = exp_no_base + exp_ctr\n",
" log_file_name = f'100_species_spectr_cnn_rnn_7sec_h_{hidden_size}_nl_{ii+1}_{exp_no}.p'\n",
" store_ = log_results(file_name=log_file_name, results_dir = resultsDir)\n",
" PATH_curr = os.path.join(tempDir, f'currentModel_cnn_rnn_{exp_no}.pt')\n",
" saveModel = SaveBestModel(PATH=PATH_curr, monitor=-np.inf, verbose=True)\n",
"\n",
" exp_ind = 0\n",
" skf = StratifiedKFold(n_splits=5, random_state=42)\n",
" for train_ind, test_ind in skf.split(mel_sp_normalized, labels_multi_lbl):\n",
"\n",
" PATH_curr = os.path.join(tempDir, f'currentModel_cnn_rnn_{exp_no}_{exp_ind}.pt')\n",
" saveModel = SaveBestModel(PATH=PATH_curr, monitor=-np.inf, verbose=True)\n",
"\n",
" X_train, X_test_p_valid = mel_sp_normalized[train_ind,:], mel_sp_normalized[test_ind,:]\n",
"\n",
" y_train, y_test_p_valid = labels_one_hot[train_ind], labels_one_hot[test_ind]\n",
" y_train_mlbl, y_test_p_valid_mlbl = labels_multi_lbl[train_ind], labels_multi_lbl[test_ind]\n",
" X_valid, X_test, y_valid, y_test = train_test_split(X_test_p_valid, y_test_p_valid,\n",
" test_size=0.5,\n",
" stratify=y_test_p_valid_mlbl,\n",
" random_state=42)\n",
"\n",
" print('X_train shape: ', X_train.shape)\n",
" print('X_valid shape: ', X_valid.shape)\n",
" print('X_test shape: ', X_test.shape)\n",
"\n",
" X_train, X_valid = torch.from_numpy(X_train).float(), torch.from_numpy(X_valid).float()\n",
" y_train, y_valid = torch.from_numpy(y_train), torch.from_numpy(y_valid)\n",
"\n",
" y_train, y_valid = y_train.float(), y_valid.float()\n",
" train_use = data_utils.TensorDataset(X_train, y_train)\n",
" train_loader = data_utils.DataLoader(train_use, batch_size=batch_size, shuffle=shuffleBatches)\n",
"\n",
" val_use = data_utils.TensorDataset(X_valid, y_valid)\n",
" val_loader = data_utils.DataLoader(val_use, batch_size=32, shuffle=False)\n",
"\n",
" model = cnn_rnn(cnnConfig = cfg_cnn3, \n",
" rnnConfig = cfg, \n",
" hidden_size=hidden_size, \n",
" order=order,\n",
" theta=theta,\n",
" num_classes=100)\n",
" model.to(device)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001, weight_decay=1e-7)\n",
"\n",
" val_acc_epochs = []\n",
" val_loss_epochs = []\n",
" for epoch in range(1, num_epoch+1):\n",
" train_loss = train_seq(model, train_loader, optimizer, epoch, \n",
" device,\n",
" verbose=1, loss_fn = 'bceLogit')\n",
" val_loss, val_acc = test_seq(model, val_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" val_acc_epochs.append(val_acc)\n",
" val_loss_epochs.append(val_loss)\n",
" print('val loss = %f, val acc = %f'%(val_loss, val_acc))\n",
" saveModel.check(model, val_acc, comp='max')\n",
"\n",
" # loading best validated model\n",
" model = cnn_rnn(cnnConfig = cfg_cnn3, \n",
" rnnConfig = cfg, \n",
" hidden_size=hidden_size, \n",
" order=order,\n",
" theta=theta,\n",
" num_classes=100)\n",
" model.to(device)\n",
" model.load_state_dict(torch.load(PATH_curr))\n",
"\n",
" X_test, y_test = torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float()\n",
"\n",
" test_use = data_utils.TensorDataset(X_test, y_test)\n",
" test_loader = data_utils.DataLoader(test_use, batch_size=32, shuffle=False)\n",
" test_loss, test_acc = test_seq(model, test_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" print('test loss = %f, test acc = %f'%(test_loss, test_acc))\n",
"\n",
" log_ = dict(\n",
" exp_ind = exp_ind,\n",
" epochs = num_epoch,\n",
" validation_accuracy = val_acc_epochs,\n",
" validation_loss = val_loss_epochs,\n",
" test_loss = test_loss,\n",
" test_accuracy = test_acc,\n",
" X_train_shape = X_train.shape,\n",
" X_valid_shape = X_valid.shape,\n",
" batch_size =batch_size,\n",
" )\n",
" store_.update(log_)\n",
" exp_ind += 1 "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

Просмотреть файл

@ -1,28 +1,23 @@
# Project
# Recurrent Convolutional Neural Networks for sound classification
> This repo has been populated by an initial template to help get you started. Please
> make sure to update the content to build a great experience for community-building.
We present a deep learning approach towards the large-scale prediction and analysis of bird acoustics from 100 different bird species. We use spectrograms constructed on bird audio recordings from the Cornell Bird Challenge (CBC) 2020 dataset, which includes recordings of multiple and potentially overlapping bird vocalizations per audio and recordings with background noise. Our experiments show that a hybrid modeling approach that involves a Convolutional Neural Network (CNN) for learning the representation for a slice of the spectrogram, and a Recurrent Neural Network (RNN) for the temporal component to combine across time-points leads to the most accurate model on this dataset. The code has models ranging from stand-alone CNNs to hybrid models of various types obtained by combining CNNs with CNNs or RNNs of the following types: Long Short-Term Memory (LSTM) networks, Gated Recurrent Units (GRU) and Legendre Memory Units (LMU).
As the maintainer of this project, please make a few updates:
## Setup
- Improving this README.MD file to provide a great experience
- Updating SUPPORT.MD with content about this project's support experience
- Understanding the security reporting process in SECURITY.MD
- Remove this section from the README
### Requirements
The code package is developed using Python 3.6 and PyTorch 1.2 with CUDA 10.0. To run the experiments, first install the required packages listed in 'requirements.txt'.
## Contributing
## Experiments
The data for bird sound classification is downloaded from the Kaggle competition [Cornell birdcall Identification](https://www.kaggle.com/c/birdsong-recognition).
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
For running the experiments, a data preprocessing pipeline is demonstrated in the process_data.ipynb notebook.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
After preprocessing the data, the RCNN models with various combinations of representation/temporal models can be run as follows:
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
### CNN + CNN
An example is shown in CNN_CNN.ipynb notebook for the CNN and TCNN configs taken in the paper. In a similar way, a different set of configs could be supplied to the cnn+cnn model.
### CNN + RNN
An example for CNN+GRU, CNN+LMU, and CNN+LSTM is shown in the CNN_RNN.ipynb notebook. Other variants of RCNNs with a different set of parameters can be set as explained in the notebook.
## Trademarks

715
models.py Normal file
Просмотреть файл

@ -0,0 +1,715 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data_utils
import numpy as np
import math
from typing import List, Tuple
from scipy.special import legendre
from nengolib.signal import Identity, cont2discrete
from nengolib.synapses import LegendreDelay
# VGG pytorch model is taken from:
# https://pytorch.org/vision/stable/_modules/torchvision/models/vgg.html
cfg_vgg16 = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']


class VGG16_pool(nn.Module):
    """VGG-style classifier: conv stack, 7x7 adaptive average pool, three dense layers.

    Args:
        cfg: layer specification. An int adds a 3x3 convolution with that many
            output channels followed by BatchNorm2d and ReLU; the string 'M'
            adds a 2x2 max-pool with stride 2.
        num_classes: width of the final dense layer.
        init_weights: when True, run ``_initialize_weights`` after construction.
    """

    def __init__(self, cfg=cfg_vgg16, num_classes=10, init_weights=True):
        super(VGG16_pool, self).__init__()
        self.convBlock = self.make_layers(cfg)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Size the first dense layer from the last convolution in ``cfg`` rather
        # than hard-coding 512, so configs with a different final width work.
        # With no convolution at all the 3 input channels pass straight through.
        conv_widths = [layer for layer in cfg if layer != 'M']
        feature_channels = conv_widths[-1] if conv_widths else 3
        self.Dense1 = nn.Linear(feature_channels * 7 * 7, 4096)
        self.Dense2 = nn.Linear(4096, 4096)
        self.Dense3 = nn.Linear(4096, num_classes)
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return the ``num_classes`` raw scores (no softmax/sigmoid) for ``x``."""
        x = self.convBlock(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = F.relu(self.Dense1(x))
        x = self.dropout1(x)
        x = F.relu(self.Dense2(x))
        # dropout2 was constructed but never applied; use it after the second
        # dense layer, mirroring dropout1 after the first.
        x = self.dropout2(x)
        x = self.Dense3(x)
        return x

    def make_layers(self, cfg):
        """Translate ``cfg`` into an ``nn.Sequential`` conv stack (3 input channels)."""
        layers = []
        in_channels = 3
        for layer in cfg:
            if layer == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, layer, kernel_size=3, padding=1)
                layers += [conv2d, nn.BatchNorm2d(layer), nn.ReLU(inplace=True)]
                in_channels = layer
        return nn.Sequential(*layers)
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    The stride is applied at the 3x3 convolution (``conv2``). The upstream
    torchvision comments describe this as the "ResNet V1.5" variant, as opposed
    to placing the stride at the first 1x1 convolution as in
    "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385). See also
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    # Output channels of the block are ``planes * expansion``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super().__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        mid_channels = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        # conv2 and the optional ``downsample`` module are the two places that
        # reduce spatial resolution when stride != 1.
        self.conv1 = conv1x1(inplanes, mid_channels)
        self.bn1 = make_norm(mid_channels)
        self.conv2 = conv3x3(mid_channels, mid_channels, stride, groups, dilation)
        self.bn2 = make_norm(mid_channels)
        self.conv3 = conv1x1(mid_channels, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the three conv stages, add the shortcut, and finish with ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Project the input only when a downsample module was supplied;
        # otherwise the input itself is the shortcut.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut = self.downsample(x)

        out += shortcut
        return self.relu(out)
# Resnet pytorch model is taken from:
# https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html
class ResNet(nn.Module):
    """ResNet backbone and classifier, adapted from torchvision.

    Args:
        block: residual block class; it must expose an ``expansion`` attribute
            and accept the constructor arguments used in ``_make_layer``.
        layers: four ints, the number of blocks in each of the four stages.
        num_classes: width of the final fully connected layer.
        zero_init_residual: when True, zero the weight of the last norm layer
            in every residual block so each block starts as an identity.
        groups: group count forwarded to each block.
        width_per_group: base width forwarded to each block.
        replace_stride_with_dilation: ``None`` or a 3-element sequence of
            booleans; a True entry makes the corresponding stage (2-4) use
            dilation instead of stride.
        norm_layer: normalisation layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` is given and does not
            have exactly three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # ``inplanes`` and ``dilation`` are running state mutated by
        # ``_make_layer``; the stages below must be built in order.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            # ``BasicBlock`` is not defined in the visible part of this module.
            # Referencing it directly raised NameError on the first module
            # visited (the ResNet itself is not a Bottleneck), so look it up
            # lazily and skip that branch when the class does not exist.
            basic_block_cls = globals().get('BasicBlock')
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif basic_block_cls is not None and isinstance(m, basic_block_cls):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation: keep resolution, widen the receptive field.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # The shortcut needs a projection whenever the shape changes.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # Only the first block of a stage carries the stride / projection.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Stem -> four residual stages -> global average pool -> linear classifier."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Return the ``num_classes`` raw scores for ``x``."""
        return self._forward_impl(x)
def resnet50(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths [3, 4, 6, 3].

    All keyword arguments are forwarded to the ``ResNet`` constructor.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, **kwargs)
def resnet18(**kwargs):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths [2, 2, 2, 2].

    All keyword arguments are forwarded to the ``ResNet`` constructor.

    NOTE(review): despite the name, this stacks ``Bottleneck`` blocks; the
    torchvision ResNet-18 uses ``BasicBlock`` instead. Confirm this is intended.
    """
    stage_depths = [2, 2, 2, 2]
    return ResNet(Bottleneck, stage_depths, **kwargs)
class cnn(nn.Module):
    """Convolutional feature extractor for single-channel inputs.

    ``cfg`` is a list where an int adds a 3x3 convolution (followed by
    BatchNorm2d and ReLU) with that many output channels and 'M' adds a 2x2
    max-pool. The conv stack is followed by a (2, 1) adaptive average pool and
    the result is flattened to one feature vector per sample.
    """

    def __init__(self, cfg, init_weights=True):
        super().__init__()
        self.convBlock = self.make_layers(cfg)
        self.avgpool = nn.AdaptiveAvgPool2d((2, 1))
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg):
        """Translate ``cfg`` into an ``nn.Sequential`` starting from 1 input channel."""
        modules = []
        prev_channels = 1
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            conv = nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1)
            modules.extend([conv, nn.BatchNorm2d(spec), nn.ReLU(inplace=True)])
            prev_channels = spec
        return nn.Sequential(*modules)

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the pooled conv features flattened to (batch, features)."""
        pooled = self.avgpool(self.convBlock(x))
        n_samples = pooled.size(0)
        return pooled.view(n_samples, -1)
class cnn_ts(nn.Module):
    """Convolutional stack with asymmetric (2, 3) max-pooling for single-channel inputs.

    ``cfg_ts`` follows the same convention as the other conv configs in this
    module: an int adds a 3x3 convolution (followed by BatchNorm2d and ReLU)
    and 'M' adds a max-pool, here with kernel and stride (2, 3). The output is
    reduced by a (1, 4) adaptive average pool and flattened per sample.
    """

    def __init__(self, cfg_ts, init_weights=True):
        super().__init__()
        self.conv = self.make_layers(cfg_ts)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 4))
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg):
        """Translate ``cfg`` into an ``nn.Sequential`` starting from 1 input channel."""
        modules = []
        prev_channels = 1
        for spec in cfg:
            if spec == 'M':
                modules.append(nn.MaxPool2d(kernel_size=(2, 3), stride=(2, 3)))
                continue
            conv = nn.Conv2d(prev_channels, spec, kernel_size=3, padding=1)
            modules.extend([conv, nn.BatchNorm2d(spec), nn.ReLU(inplace=True)])
            prev_channels = spec
        return nn.Sequential(*modules)

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the pooled conv features flattened to (batch, features)."""
        pooled = self.avgpool(self.conv(x))
        n_samples = pooled.size(0)
        return pooled.view(n_samples, -1)
class cnn_cnn(nn.Module):
    """Two-stage CNN: a per-slice ``cnn`` followed by a ``cnn_ts`` over the slice sequence.

    Args:
        cfg: layer config for the per-slice ``cnn``.
        cfg_ts: layer config for the ``cnn_ts`` applied across slices.
        n_units_ts: number of features produced by the ``cnn_ts`` stage, i.e.
            the input width of the first dense layer.
        num_classes: width of the output layer.
    """

    def __init__(self,
                 cfg,
                 cfg_ts,
                 n_units_ts,
                 num_classes=100):
        super().__init__()
        self.cnn = cnn(cfg)
        self.cnn1 = cnn_ts(cfg_ts)
        self.linear1 = nn.Linear(n_units_ts, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x, hidden_state=None):
        """Return raw class scores for a 5-D input (batch, timesteps, C, H, W).

        ``hidden_state`` is accepted but not used by this model.
        """
        n_batch, n_steps, channels, height, width = x.size()

        # Fold the time axis into the batch so every slice goes through the
        # same per-slice CNN, then unfold back to (batch, timesteps, features).
        slices = x.view(n_batch * n_steps, channels, height, width)
        slice_features = self.cnn(slices)
        sequence = slice_features.view(n_batch, n_steps, -1)

        # Add a singleton channel axis so the feature sequence can be treated
        # as a one-channel 2-D map by the second CNN.
        sequence = sequence.unsqueeze(1)
        summary = self.cnn1(sequence)

        hidden = F.relu(self.linear1(summary))
        hidden = self.dropout1(hidden)
        return self.linear2(hidden)

    def init_hidden(self, batch_size):
        """Return a zero tensor of shape (1, batch_size, 1)."""
        placeholder_shape = (1, batch_size, 1)
        return torch.zeros(*placeholder_shape)
class cnn_lstm(nn.Module):
    """Per-slice ``cnn`` feature extractor followed by a batch-first LSTM and a dense head.

    Args:
        n_units: number of features the ``cnn`` produces per slice (LSTM input size).
        cfg: layer config for the per-slice ``cnn``.
        num_layers: number of stacked LSTM layers.
        num_classes: width of the output layer.
    """

    def __init__(self, n_units, cfg, num_layers=2, num_classes=100):
        super().__init__()
        self.cnn = cnn(cfg)
        self.hidden_size = 512
        self.num_layers = num_layers
        self.rnn1 = nn.LSTM(
            input_size=n_units,
            hidden_size=self.hidden_size,
            batch_first=True,
            num_layers=self.num_layers,
        )
        self.linear1 = nn.Linear(self.hidden_size, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x, hidden_state=None):
        """Return raw class scores for a 5-D input (batch, timesteps, C, H, W)."""
        n_batch, n_steps, channels, height, width = x.size()

        # Fold time into the batch for the CNN, then restore the sequence axis.
        slices = x.view(n_batch * n_steps, channels, height, width)
        slice_features = self.cnn(slices)
        sequence = slice_features.view(n_batch, n_steps, -1)

        rnn_out, _ = self.rnn1(sequence, hidden_state)

        # Aggregate the LSTM outputs by summing over the time axis. (An optional
        # alternative is to feed only the last step, rnn_out[:, -1, :], to linear1.)
        pooled = rnn_out.sum(dim=1)

        hidden = F.relu(self.linear1(pooled))
        hidden = self.dropout1(hidden)
        return self.linear2(hidden)

    def init_hidden(self, batch_size):
        """Return a pair of zero tensors, each (num_layers, batch_size, hidden_size)."""
        state_shape = (self.num_layers, batch_size, self.hidden_size)
        return (torch.zeros(*state_shape),
                torch.zeros(*state_shape))
def Legendre(shape):
    """Return a matrix whose row ``i`` samples the degree-``i`` Legendre polynomial.

    Each polynomial is evaluated at ``shape[1]`` evenly spaced points on
    ``[-1, 1]``; the result has ``shape[0]`` rows.

    NOTE(review): ``legendre`` is assumed to be ``scipy.special.legendre``;
    its import is outside this part of the file.

    Raises:
        ValueError: if ``shape`` does not have exactly two entries.
    """
    if len(shape) != 2:
        raise ValueError("Legendre initializer assumes shape is 2D; "
                         "but shape=%s" % (shape,))
    n_rows, n_cols = shape[0], shape[1]
    grid = np.linspace(-1, 1, n_cols)
    rows = []
    for degree in range(n_rows):
        rows.append(legendre(degree)(grid))
    return np.asarray(rows)
# LMU cell taken from: https://github.com/nengo/keras-lmu
# and converted to pytorch
class LMUCell(nn.Module):
    """One time step of a Legendre Memory Unit (LMU) cell.

    The cell carries a hidden state ``h`` and a memory state ``m``. Each step
    projects the input, ``h`` and ``m`` onto one value ``u`` per sample (the
    ``*_encoders`` weights), advances the memory with ``m += m @ AT + u @ BT``
    and computes the new ``h`` from the input, the previous ``h`` and the new
    ``m`` (the ``*_kernel`` weights).

    Shapes implied by the matrix products in ``forward``: inputs are
    ``(batch, input_dim)``, ``h`` is ``(batch, units)`` and ``m`` is
    ``(batch, order)``.

    ``A`` and ``B`` come from discretising (``cont2discrete`` with ``dt=1``)
    the realization returned by ``realizer(factory(theta=theta, order=order))``.
    Each ``trainable_*`` flag controls ``requires_grad`` of the matching
    parameter.

    NOTE(review): ``memory_encoders_initializer`` and
    ``memory_encoders_initial_val`` are accepted but never read; the memory
    encoders are always created as constant zeros. Confirm this is intended.
    NOTE(review): ``input_dim`` is not stored on the instance and ``**kwargs``
    are accepted but ignored.
    """

    def __init__(self,
                 input_dim,
                 units,
                 order,
                 theta, # relative to dt=1
                 method='zoh',
                 realizer=Identity(), # TODO: Deprecate?
                 factory=LegendreDelay, # TODO: Deprecate?
                 trainable_input_encoders=True,
                 trainable_hidden_encoders=True,
                 trainable_memory_encoders=True,
                 trainable_input_kernel=True,
                 trainable_hidden_kernel=True,
                 trainable_memory_kernel=True,
                 trainable_A=False,
                 trainable_B=False,
                 input_encoders_initializer='lecun_uniform',
                 input_encoders_initial_val = 0,
                 hidden_encoders_initializer='lecun_uniform',
                 hidden_encoders_initial_val = 0,
                 memory_encoders_initializer='Constant', # 'lecun_uniform',
                 memory_encoders_initial_val = 0,
                 input_kernel_initializer='glorot_normal',
                 input_kernel_initial_val = 0,
                 hidden_kernel_initializer='glorot_normal',
                 hidden_kernel_initial_val = 0,
                 memory_kernel_initializer='glorot_normal',
                 memory_kernel_initial_val = 0,
                 hidden_activation='tanh',
                 **kwargs):
        super(LMUCell,self).__init__()
        self.units = units
        self.order = order
        self.theta = theta
        self.method = method
        self.realizer = realizer
        self.factory = factory
        self.trainable_input_encoders = trainable_input_encoders
        self.trainable_hidden_encoders = trainable_hidden_encoders
        self.trainable_memory_encoders = trainable_memory_encoders
        self.trainable_input_kernel = trainable_input_kernel
        self.trainable_hidden_kernel = trainable_hidden_kernel
        self.trainable_memory_kernel = trainable_memory_kernel
        self.trainable_A = trainable_A
        self.trainable_B = trainable_B
        self.hidden_activation = hidden_activation

        # Build the continuous-time system and discretise it with a unit time step.
        self._realizer_result = realizer(
            factory(theta=theta, order=self.order))
        self._ss = cont2discrete(
            self._realizer_result.realization, dt=1., method=method)
        self._A = self._ss.A - np.eye(order) # puts into form: x += Ax
        self._B = self._ss.B
        self._C = self._ss.C
        assert np.allclose(self._ss.D, 0) # proper LTI

        self.state_size = (self.units, self.order)
        self.output_size = self.units

        def weight_mod(input_dim, output_dim, initialization,
                       constant_val = 0):
            """Create an ``(input_dim, output_dim)`` float tensor and initialise it.

            ``initialization`` selects the scheme by name; ``constant_val`` is
            only used by ``'Constant'`` and may be a scalar or an array that
            becomes the tensor's data.
            """
            w = torch.FloatTensor(input_dim, output_dim)
            w.requires_grad = True
            if initialization == 'lecun_uniform':
                # NOTE: the 'lecun_uniform' name is served by Kaiming-uniform init here.
                torch.nn.init.kaiming_uniform_(w)
            elif initialization == 'glorot_normal':
                torch.nn.init.xavier_normal_(w)
            elif initialization == 'Constant':
                if np.size(constant_val) == 1:
                    torch.nn.init.constant_(w, constant_val)
                else:
                    # Array-valued constant: replaces the data wholesale, so the
                    # resulting shape is that of constant_val.
                    w.data = torch.from_numpy(constant_val).float()
            elif initialization == 'Legendre':
                w.data = torch.from_numpy(Legendre((input_dim, output_dim))).float()
            elif initialization == 'uniform':
                # Bounds depend on the number of units, not on this tensor's shape.
                stdv = 1.0 / math.sqrt(self.state_size[0])
                torch.nn.init.uniform_(w, -stdv, stdv)
            # NOTE(review): there is no else branch, so an unrecognised name
            # returns the tensor exactly as allocated above.
            return w

        # Encoders: project input / hidden state / memory onto one value per sample.
        self.input_encoders = nn.Parameter(weight_mod(input_dim, 1,
                                           initialization=input_encoders_initializer,
                                           constant_val = input_encoders_initial_val)
                                           )
        if not self.trainable_input_encoders:
            self.input_encoders.requires_grad = False
        self.hidden_encoders = nn.Parameter(weight_mod(self.units, 1,
                                            initialization=hidden_encoders_initializer,
                                            constant_val = hidden_encoders_initial_val)
                                            )
        if not self.trainable_hidden_encoders:
            self.hidden_encoders.requires_grad = False
        # NOTE(review): hard-coded to constant zeros; the memory_encoders_*
        # constructor arguments are not consulted here.
        self.memory_encoders = nn.Parameter(weight_mod(self.order, 1,
                                            initialization='Constant',
                                            constant_val=0)
                                            )
        if not self.trainable_memory_encoders:
            self.memory_encoders.requires_grad = False

        # Kernels: map input / hidden state / memory to the new hidden state.
        self.input_kernel = nn.Parameter(weight_mod(input_dim, self.units,
                                         initialization=input_kernel_initializer,
                                         constant_val = input_kernel_initial_val)
                                         )
        if not self.trainable_input_kernel:
            self.input_kernel.requires_grad = False
        self.hidden_kernel = nn.Parameter(weight_mod(self.units, self.units,
                                          initialization=hidden_kernel_initializer,
                                          constant_val = hidden_kernel_initial_val)
                                          )
        if not self.trainable_hidden_kernel:
            self.hidden_kernel.requires_grad = False
        self.memory_kernel = nn.Parameter(weight_mod(self.order, self.units,
                                          initialization=memory_kernel_initializer,
                                          constant_val = memory_kernel_initial_val)
                                          )
        if not self.trainable_memory_kernel:
            self.memory_kernel.requires_grad = False

        # State-space matrices, stored transposed so forward can right-multiply.
        self.AT = nn.Parameter(weight_mod(self.order, self.order,
                               initialization='Constant',
                               constant_val=self._A.T) # transposed
                               )
        if not self.trainable_A:
            self.AT.requires_grad = False
        self.BT = nn.Parameter(weight_mod(1, self.order,
                               initialization='Constant',
                               constant_val=self._B.T) # transposed
                               )
        if not self.trainable_B:
            self.BT.requires_grad = False

    def forward(self, inputs, states):
        """Advance the cell by one time step.

        Args:
            inputs: 2-D input for this step.
            states: pair ``(h, m)`` of the previous hidden and memory states.

        Returns:
            ``(h, (h, m))``: the new hidden state as the step output, and the
            updated state pair.
        """
        h, m = states
        # One value per sample, written into the memory.
        u = torch.mm(inputs, self.input_encoders) \
            + torch.mm(h, self.hidden_encoders) \
            + torch.mm(m, self.memory_encoders)
        # Memory update in the "x += Ax + Bu" form.
        m = m + torch.mm(m, self.AT) + torch.mm(u, self.BT)
        if self.hidden_activation == 'tanh':
            h = torch.tanh(
                torch.mm(inputs, self.input_kernel) +
                torch.mm(h, self.hidden_kernel) +
                torch.mm(m, self.memory_kernel)
            )
        elif self.hidden_activation == 'linear':
            h = torch.mm(inputs, self.input_kernel) \
                + torch.mm(h, self.hidden_kernel) \
                + torch.mm(m, self.memory_kernel)
        # NOTE(review): any other hidden_activation leaves h unchanged.
        return h, (h, m)
# using https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
# for custom LSTMs
class LMU(nn.Module):
    """Recurrent layer that unrolls an ``LMUCell`` along axis 1 of its input.

    Args:
        inp_size: feature width of each time step (``input_dim`` of the cell).
        order: ``order`` of the cell.
        theta: ``theta`` of the cell.
        output_dims: number of cell units, i.e. the width of each output step.

    Every initializer argument of the cell is requested as ``'uniform'``.
    """

    def __init__(self, inp_size = 1, order = 100, theta=100, output_dims = 1):
        super().__init__()
        self.units = output_dims
        self.order = order
        self.output_dims = output_dims
        initializer_args = (
            'input_encoders_initializer',
            'hidden_encoders_initializer',
            'memory_encoders_initializer',
            'input_kernel_initializer',
            'hidden_kernel_initializer',
            'memory_kernel_initializer',
        )
        self.lmu_cell = LMUCell(
            input_dim=inp_size,
            units=output_dims,
            order=order,
            theta=theta,
            **{arg_name: 'uniform' for arg_name in initializer_args}
        )

    def forward(self, x, state):
        """Apply the cell to each slice of ``x`` along axis 1.

        Returns:
            ``(outputs, state)`` where ``outputs`` stacks the per-step cell
            outputs with the step axis second (B, seq_len, num_outputs) and
            ``state`` is the cell state after the last step.
        """
        outputs = torch.jit.annotate(List[Tensor], [])
        for step_input in x.unbind(1):
            step_output, state = self.lmu_cell(step_input, state)
            outputs.append(step_output)
        stacked = torch.stack(outputs)
        # Steps were stacked on axis 0; swap so the batch axis comes first.
        return stacked.permute(1, 0, 2), state

    def init_hidden(self, batch_size):
        """Return zero ``(h, m)`` states of shapes ``(batch, units)`` and ``(batch, order)``."""
        hidden = torch.zeros(batch_size, self.units)
        memory = torch.zeros(batch_size, self.order)
        return (hidden, memory)
# using https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
# for custom LSTMs
class LMUGate(nn.Module):
    """Recurrent layer that unrolls an ``LMUCellGate`` along axis 1 of its input.

    Args:
        inp_size: feature width of each time step (``input_dim`` of the cell).
        order: ``order`` of the cell.
        theta: ``theta`` of the cell.
        output_dims: number of cell units.

    NOTE(review): ``LMUCellGate`` is not defined in this part of the file;
    confirm it exists before instantiating this class.
    """

    def __init__(self, inp_size = 1, order = 100, theta=100, output_dims = 1):
        super().__init__()
        self.units = output_dims
        self.order = order
        self.output_dims = output_dims
        self.lmu_cell = LMUCellGate(input_dim=inp_size, units=output_dims,
                                    order=order, theta=theta)

    def forward(self, x, state):
        """Apply the cell to each slice of ``x`` along axis 1.

        Returns:
            ``(outputs, state)`` where ``outputs`` stacks the per-step cell
            outputs with the step axis second (B, seq_len, num_outputs) and
            ``state`` is the cell state after the last step.
        """
        outputs = torch.jit.annotate(List[Tensor], [])
        for step_input in x.unbind(1):
            step_output, state = self.lmu_cell(step_input, state)
            outputs.append(step_output)
        stacked = torch.stack(outputs)
        # Steps were stacked on axis 0; swap so the batch axis comes first.
        return stacked.permute(1, 0, 2), state

    def init_hidden(self, batch_size):
        """Return zero states of shapes ``(batch, units)`` and ``(batch, order)``."""
        hidden = torch.zeros(batch_size, self.units)
        memory = torch.zeros(batch_size, self.order)
        return (hidden, memory)
class cnn_rnn(nn.Module):
    """CNN frame encoder followed by a configurable stack of recurrent layers.

    Args:
        theta: ``theta`` forwarded to every ``LMU`` layer.
        num_classes: width of the output layer.
        order: ``order`` forwarded to every ``LMU`` layer.
        hidden_size: output width of every recurrent layer.
        cnnConfig: layer configuration forwarded to ``cnn``.
        rnnConfig: ordered mapping ``"<TYPE>_<index>" -> settings``. The text
            before the first underscore selects the layer type (``LMU``,
            ``GRU`` or ``LSTM``). ``settings['input_size']`` is the layer's
            input width and ``settings['h_states_ctr']`` is the number of
            consecutive entries of the ``hidden_state`` tuple the layer
            consumes. It must match what ``init_hidden`` emits: 1 for GRU,
            2 for LMU and LSTM. ``None`` (the default) selects a single LMU
            layer with ``input_size=1024``.

    Example ``rnnConfig`` (a GRU feeding an LMU)::

        {'GRU_0': {'input_size': n_units, 'h_states_ctr': 1},
         'LMU_1': {'input_size': hidden_size, 'h_states_ctr': 2}}

    Raises:
        ValueError: if an ``rnnConfig`` key does not start with a supported
            layer type.
    """

    def __init__(self,
                 theta=500,
                 num_classes=100,
                 order=100,
                 hidden_size = 512,
                 cnnConfig = cfg_vgg16,
                 rnnConfig = None,
                 ):
        super(cnn_rnn, self).__init__()
        if rnnConfig is None:
            # Built per instance so no dict is shared between models. An LMU
            # layer carries two state tensors (h, m), hence h_states_ctr=2;
            # with 1 the cell would receive a lone tensor and fail to unpack it.
            rnnConfig = {
                'LMU_0':
                    {'input_size': 1024, 'h_states_ctr': 2}
            }
        self.cnn = cnn(cnnConfig)
        self.hidden_size = hidden_size
        self.order = order
        self.theta = theta
        self.rnnConfig = rnnConfig
        self.rnnBlock = self.make_rnn()
        self.linear1 = nn.Linear(self.hidden_size, 512)
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(512, num_classes)

    def forward(self, x, hidden_state):
        """Return class scores for a 5-D batch ``(batch, timesteps, C, H, W)``.

        Args:
            x: input batch.
            hidden_state: flat tuple of initial states for all recurrent
                layers in ``rnnConfig`` order, as produced by ``init_hidden``.
        """
        batch_size, timesteps, C, H, W = x.size()
        # Fold time into the batch axis so the frame encoder sees 4-D input.
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.cnn(c_in)
        x = c_out.view(batch_size, timesteps, -1)
        ctr = 0  # index of the next unused entry of hidden_state
        for config_, rnn in zip(self.rnnConfig.values(), self.rnnBlock):
            n_states = config_['h_states_ctr']
            if n_states == 1:
                h_s = hidden_state[ctr]
            else:
                h_s = tuple(hidden_state[ctr + i] for i in range(n_states))
            ctr += n_states
            x, _ = rnn(x, h_s)
        # Sum the recurrent outputs over the time axis before the classifier.
        x = x.sum(dim=1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.linear2(x)
        return x

    def make_rnn(self):
        """Instantiate one recurrent layer per ``rnnConfig`` entry, in order."""
        layers = []
        for rnn_name, config_ in self.rnnConfig.items():
            kind = rnn_name.split('_')[0]
            if kind == 'LMU':
                layers.append(LMU(inp_size=config_['input_size'],
                                  order=self.order, theta=self.theta,
                                  output_dims=self.hidden_size))
            elif kind == 'GRU':
                layers.append(nn.GRU(input_size=config_['input_size'],
                                     hidden_size=self.hidden_size, batch_first=True))
            elif kind == 'LSTM':
                layers.append(nn.LSTM(input_size=config_['input_size'],
                                      hidden_size=self.hidden_size, batch_first=True))
            else:
                # Skipping the entry would misalign the config/layer pairing
                # in forward(), so fail at construction time instead.
                raise ValueError(
                    "Unsupported RNN layer %r in rnnConfig; the name must "
                    "start with 'LMU', 'GRU' or 'LSTM'" % (rnn_name,))
        return nn.ModuleList(layers)

    def init_hidden(self, batch_size):
        """Return a flat tuple of zero initial states for all recurrent layers.

        LMU contributes ``(batch, hidden_size)`` and ``(batch, order)``; GRU
        contributes one ``(1, batch, hidden_size)`` tensor; LSTM contributes
        two of them.
        """
        h_s = []
        for rnn_name in self.rnnConfig:
            kind = rnn_name.split('_')[0]
            if kind == 'LMU':
                h_s.append(torch.zeros(batch_size, self.hidden_size))
                h_s.append(torch.zeros(batch_size, self.order))
            elif kind == 'GRU':
                h_s.append(torch.zeros(1, batch_size, self.hidden_size))
            elif kind == 'LSTM':
                h_s.append(torch.zeros(1, batch_size, self.hidden_size))
                h_s.append(torch.zeros(1, batch_size, self.hidden_size))
        return tuple(h_s)

626
process_data.ipynb Normal file

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

9
requirements.txt Normal file
Просмотреть файл

@ -0,0 +1,9 @@
pytorch == 1.2.0
librosa == 0.8.0
pandas == 1.1.1
nengolib == 0.5.2
nengo == 2.8.0
h5py == 2.10.0
fastprogress == 1.0.0
audioread == 2.1.8
scikit-learn == 0.21.3

253
spectr_vgg16.ipynb Normal file
Просмотреть файл

@ -0,0 +1,253 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"import torch.utils.data as data_utils\n",
"import os\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelBinarizer, LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"import csv\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import gc\n",
"import h5py\n",
"import pickle as pk\n",
"\n",
"from utils import log_results, SaveBestModel, train, test\n",
"from utils import mel_sp_to_image\n",
"\n",
"from models import VGG16_pool"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set directories"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"dataDir = 'processed'\n",
"resultsDir = 'Results'\n",
"tempDir = 'temp'\n",
"\n",
"if not os.path.exists(resultsDir):\n",
" os.makedirs(resultsDir)\n",
"if not os.path.exists(tempDir):\n",
" os.makedirs(tempDir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"fname = 'birds_cornell_spectr_100_species_sr_32000_len_7_sec_New.h5'\n",
"fileLoc = os.path.join(dataDir,fname) # 19707 samples per class\n",
"hf = h5py.File(fileLoc, 'r')\n",
"mel_sp = hf.get('mel_spectr')[()]\n",
"metadata_total = pd.read_hdf(fileLoc, 'info')\n",
"hf.close()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"original_label = list(metadata_total['ebird_code'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lb_bin = LabelBinarizer()\n",
"lb_enc = LabelEncoder()\n",
"labels_one_hot = lb_bin.fit_transform(original_label)\n",
"labels_multi_lbl = lb_enc.fit_transform(original_label)\n",
"\n",
"number_of_sample_classes = len(lb_enc.classes_)\n",
"print(\"Number of Species: \", number_of_sample_classes)\n",
"species_id_class_dict_tp = dict()\n",
"for (class_label, species_id) in enumerate(lb_bin.classes_):\n",
" species_id_class_dict_tp[species_id] = class_label"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"mel_sp_normalized = []\n",
"for i in range(len(mel_sp)):\n",
" xx_ = mel_sp_to_image(mel_sp[i]).astype('float32')\n",
" mel_sp_normalized += [np.rollaxis(xx_, 2, 0)]\n",
"mel_sp_normalized = np.array(mel_sp_normalized)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 16*2\n",
"num_classes=100\n",
"shuffleBatches=True\n",
"num_epoch = 50"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"skf = StratifiedKFold(n_splits=5, random_state=42)\n",
"exp_no = 0\n",
"log_file_name = f'100_species_spectr_vgg16_7sec_{exp_no}.p'\n",
"store_ = log_results(file_name=log_file_name, results_dir = resultsDir)\n",
"\n",
"exp_ind = 0\n",
"for train_ind, test_ind in skf.split(mel_sp_normalized, labels_multi_lbl):\n",
" \n",
" PATH_curr = os.path.join(tempDir, f'currentModel_vgg16_{exp_no}_{exp_ind}.pt')\n",
" saveModel = SaveBestModel(PATH=PATH_curr, monitor=-np.inf, verbose=True)\n",
"\n",
" X_train, X_test_p_valid = mel_sp_normalized[train_ind,:], mel_sp_normalized[test_ind,:]\n",
" \n",
" y_train, y_test_p_valid = labels_one_hot[train_ind], labels_one_hot[test_ind]\n",
" y_train_mlbl, y_test_p_valid_mlbl = labels_multi_lbl[train_ind], labels_multi_lbl[test_ind]\n",
" X_valid, X_test, y_valid, y_test = train_test_split(X_test_p_valid, y_test_p_valid,\n",
" test_size=0.5,\n",
" stratify=y_test_p_valid_mlbl,\n",
" random_state=42)\n",
"\n",
" print('X_train shape: ', X_train.shape)\n",
" print('X_valid shape: ', X_valid.shape)\n",
" print('X_test shape: ', X_test.shape)\n",
"\n",
" X_train, X_valid = torch.from_numpy(X_train).float(), torch.from_numpy(X_valid).float()\n",
" y_train, y_valid = torch.from_numpy(y_train), torch.from_numpy(y_valid)\n",
" \n",
" y_train, y_valid = y_train.float(), y_valid.float()\n",
" train_use = data_utils.TensorDataset(X_train, y_train)\n",
" train_loader = data_utils.DataLoader(train_use, batch_size=batch_size, shuffle=shuffleBatches)\n",
"\n",
" val_use = data_utils.TensorDataset(X_valid, y_valid)\n",
" val_loader = data_utils.DataLoader(val_use, batch_size=32, shuffle=False)\n",
" \n",
" model = VGG16_pool(num_classes=100)\n",
" model.to(device)\n",
" optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001, weight_decay=1e-7)\n",
"\n",
" val_acc_epochs = []\n",
" val_loss_epochs = []\n",
" for epoch in range(1, num_epoch+1):\n",
" train_loss = train(model, train_loader, optimizer, epoch, \n",
" device,\n",
" verbose=1, loss_fn = 'bceLogit')\n",
" val_loss, val_acc = test(model, val_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" val_acc_epochs.append(val_acc)\n",
" val_loss_epochs.append(val_loss)\n",
" print('val loss = %f, val acc = %f'%(val_loss, val_acc))\n",
" saveModel.check(model, val_acc, comp='max')\n",
" \n",
" # loading best validated model\n",
" model = VGG16_pool(num_classes=100)\n",
" model.to(device)\n",
" model.load_state_dict(torch.load(PATH_curr))\n",
"\n",
" X_test, y_test = torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float()\n",
"\n",
" test_use = data_utils.TensorDataset(X_test, y_test)\n",
" test_loader = data_utils.DataLoader(test_use, batch_size=32, shuffle=False)\n",
" test_loss, test_acc = test(model, test_loader,\n",
" device,\n",
" loss_fn = 'bceLogit')\n",
" print('test loss = %f, test acc = %f'%(test_loss, test_acc))\n",
" \n",
" log_ = dict(\n",
" exp_ind = exp_ind,\n",
" epochs = num_epoch,\n",
" validation_accuracy = val_acc_epochs,\n",
" validation_loss = val_loss_epochs,\n",
" test_loss = test_loss,\n",
" test_accuracy = test_acc,\n",
" X_train_shape = X_train.shape,\n",
" X_valid_shape = X_valid.shape,\n",
" batch_size =batch_size,\n",
" )\n",
" store_.update(log_)\n",
" exp_ind += 1 "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

325
utils.py Normal file
Просмотреть файл

@ -0,0 +1,325 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import numpy as np
import librosa
from joblib import Parallel, delayed
import multiprocessing
import h5py
import pickle as pk
import cv2
import os
import matplotlib.pyplot as plt
from fastprogress.fastprogress import progress_bar
class log_results(object):
    """Append experiment logs (dicts) as pickle records to a file.

    Args:
        file_name: name of the log file inside ``results_dir``.
        results_dir: directory that holds the log file; created if missing.
    """

    def __init__(self, file_name = 'log', results_dir = 'Results'):
        self.results_dir = results_dir
        self.fname = file_name
        # exist_ok avoids the race between an existence check and the creation.
        os.makedirs(self.results_dir, exist_ok=True)

    def update(self, log):
        """Append ``log`` to the log file as one pickle record.

        Every call adds a separate pickle to the end of the file, so reading
        the file back takes one ``pickle.load`` per stored record. Anything
        that is not a dict is rejected with a printed message and nothing is
        written.
        """
        if not isinstance(log, dict):
            print('log has to be in dictionary format')
            return
        file_path = os.path.join(self.results_dir, self.fname)
        # The context manager flushes and closes the handle even if dump fails.
        with open(file_path, 'ab') as log_file:
            pk.dump(log, log_file)
class SaveBestModel(object):
    """Checkpoint a model whenever a monitored value improves.

    Args:
        monitor: best value seen so far. Start with ``np.inf`` when lower is
            better and ``-np.inf`` when higher is better.
        PATH: file the model's ``state_dict`` is written to.
        verbose: print a message each time a checkpoint is saved.
    """

    def __init__(self, monitor = np.inf, PATH = './currTorchModel.pt',
                 verbose=False):
        self.monitor = monitor
        self.PATH = PATH
        self.verbose = verbose

    def check(self, model, currVal, comp='min'):
        """Save ``model`` if ``currVal`` beats the best value seen so far.

        Args:
            model: module whose ``state_dict`` is saved on improvement.
            currVal: current value of the monitored quantity.
            comp: ``'min'`` if smaller is better, ``'max'`` if larger is better.

        Raises:
            ValueError: if ``comp`` is neither ``'min'`` nor ``'max'``.
        """
        # Compare by value: identity ("is") only works for strings that happen
        # to be interned and is a SyntaxWarning on newer Python versions.
        if comp == 'min':
            improved = currVal < self.monitor
        elif comp == 'max':
            improved = currVal > self.monitor
        else:
            raise ValueError("comp must be 'min' or 'max', got %r" % (comp,))
        if not improved:
            return
        self.monitor = currVal
        torch.save(model.state_dict(), self.PATH)
        if self.verbose:
            print('saving best model...')
def normalize_mel_sp_slides(X, eps=1e-6):
    """Standardise ``X`` and rescale the result to the range [0, 1].

    ``X`` is centred on its global mean and divided by its standard
    deviation plus ``eps``; the standardised values are then min-max scaled.
    If their spread (max - min) does not exceed ``eps``, an all-zero
    ``uint8`` array with the shape of ``X`` is returned instead.
    """
    centered = X - X.mean()
    standardized = centered / (centered.std() + eps)
    lowest, highest = standardized.min(), standardized.max()
    spread = highest - lowest
    if not spread > eps:
        # Degenerate input: nothing to rescale.
        return np.zeros_like(centered, dtype=np.uint8)
    return (standardized - lowest) / spread
def mel_sp_slides_to_image(X, eps=1e-6, resize=False, nrow=224, ncol=224):
    """Convert spectrogram data to viridis-coloured RGB images.

    ``X`` is standardised with its global mean and standard deviation,
    normalised to its own min/max and mapped through the viridis colormap.
    The alpha channel is dropped, so the result has a trailing axis of 3.

    Args:
        X: array holding one 2-D spectrogram or a stack of them along the
            first axis (presumably ``(n_slides, rows, cols)`` — confirm
            against callers).
        eps: added to the standard deviation to avoid division by zero.
        resize: if True, resize every image to ``nrow`` rows by ``ncol``
            columns with bilinear interpolation.
        nrow: target number of rows (image height).
        ncol: target number of columns (image width).

    Returns:
        RGB array shaped like ``X`` plus a trailing axis of 3; with
        ``resize=True`` the two image axes become ``(nrow, ncol)``.
    """
    centered = X - X.mean()
    Xstd = centered / (centered.std() + eps)
    cmap = plt.cm.viridis
    norm = plt.Normalize(vmin=Xstd.min(), vmax=Xstd.max())
    # The colormap yields RGBA on a new trailing axis; keep only RGB. The
    # ellipsis makes this independent of whether X is one image or a stack.
    rgb = cmap(norm(Xstd))[..., :3]
    if not resize:
        return rgb
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    dsize = (ncol, nrow)
    if rgb.ndim == 4:
        # cv2.resize works on one image at a time, so resize slide by slide.
        return np.stack([
            cv2.resize(slide, dsize, interpolation=cv2.INTER_LINEAR)
            for slide in rgb
        ])
    return cv2.resize(rgb, dsize, interpolation=cv2.INTER_LINEAR)
def mel_sp_to_image(X, eps=1e-6, nrow=224, ncol=224):
    """Convert one 2-D spectrogram to a resized viridis-coloured RGB image.

    ``X`` is standardised with its global mean and standard deviation,
    normalised to its own min/max, mapped through the viridis colormap and
    resized with bilinear interpolation. The alpha channel is dropped.

    Args:
        X: 2-D array.
        eps: added to the standard deviation to avoid division by zero.
        nrow: target number of rows (image height).
        ncol: target number of columns (image width).

    Returns:
        Array of shape ``(nrow, ncol, 3)``.
    """
    centered = X - X.mean()
    Xstd = centered / (centered.std() + eps)
    cmap = plt.cm.viridis
    norm = plt.Normalize(vmin=Xstd.min(), vmax=Xstd.max())
    # The colormap yields RGBA (rows x cols x 4); the last channel is the
    # alpha value for transparency and is discarded below.
    image = cmap(norm(Xstd))
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    return cv2.resize(image[:, :, :3], (ncol, nrow),
                      interpolation=cv2.INTER_LINEAR)
def train_seq(model, train_loader, optimizer, epoch, device, verbose = 0,
              lr_schedule = None, weight = None, loss_fn = 'crossEnt'):
    """Train a model that takes an explicit hidden state for one epoch.

    For every batch a fresh hidden state is requested from
    ``model.init_hidden(batch_len)``, moved to ``device`` and passed to the
    model together with the data.

    Args:
        model: module exposing ``init_hidden(batch_size)`` and callable as
            ``model(data, hidden_state)``.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch number; passed to ``lr_schedule`` and shown in
            the progress message.
        device: device the batches, hidden state and loss are moved to.
        verbose: print the loss after every batch when greater than 0.
        lr_schedule: optional callable ``(optimizer, epoch) -> optimizer``
            applied once before training.
        weight: accepted for interface compatibility; currently unused.
        loss_fn: ``'crossEnt'`` (``nn.CrossEntropyLoss``) or ``'bceLogit'``
            (``nn.BCEWithLogitsLoss``).

    Returns:
        Loss of the last processed batch as a float, or NaN if the loader
        yields no batches.

    Raises:
        ValueError: if ``loss_fn`` is not a supported name.
    """
    # Build the loss once, before touching the model, and place it on the
    # requested device instead of hard-coding CUDA.
    if loss_fn == 'crossEnt':
        criteria = nn.CrossEntropyLoss()
    elif loss_fn == 'bceLogit':
        criteria = nn.BCEWithLogitsLoss()
    else:
        raise ValueError(
            "loss_fn must be 'crossEnt' or 'bceLogit', got %r" % (loss_fn,))
    criteria = criteria.to(device)

    if lr_schedule is not None:
        optimizer = lr_schedule(optimizer, epoch)
    model.train()
    loss = None
    for batch_idx, (data, target) in enumerate(progress_bar(train_loader)):
        h_s = model.init_hidden(len(data))
        if isinstance(h_s, tuple):
            h_s = tuple(state.to(device) for state in h_s)
        else:
            h_s = h_s.to(device)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data, h_s)
        loss = criteria(output, target)
        loss.backward()
        optimizer.step()
        if verbose>0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    # An empty loader never assigns a loss; report NaN instead of crashing.
    return loss.item() if loss is not None else float('nan')
def evalModel_seq(data_loader, model, device, verbose=0, stochastic_pass = True,
                  compute_metrics=True, activationName = None,
                  loss_fn = 'crossEnt'):
    """Evaluate a model that takes an explicit hidden state, without gradients.

    Args:
        data_loader: iterable of ``(data, target)`` batches.
        model: module exposing ``init_hidden(batch_size)`` and callable as
            ``model(data, hidden_state)``.
        device: device the batches, hidden state and loss are moved to.
        verbose: unused.
        stochastic_pass: if True the model is put in ``train()`` mode for
            the pass, otherwise in ``eval()`` mode. The mode is not restored
            afterwards.
        compute_metrics: if True return loss/accuracy counts, otherwise
            return softmax predictions.
        activationName: unused.
        loss_fn: ``'crossEnt'`` (targets are class indices) or ``'bceLogit'``
            (targets are one-hot rows). Only consulted when
            ``compute_metrics`` is True.

    Returns:
        ``(test_loss, correct)`` when ``compute_metrics`` is True: the sum of
        the per-batch losses and the number of correctly classified samples.
        Otherwise ``(predictions, activations)``: a list with one softmax
        row (NumPy array) per sample and an always-empty list.

    Raises:
        ValueError: if ``compute_metrics`` is True and ``loss_fn`` is not a
            supported name.
    """
    criteria = None
    if compute_metrics:
        # Build the loss once and place it on the requested device instead
        # of re-creating it per batch with a hard-coded CUDA move.
        if loss_fn == 'crossEnt':
            criteria = nn.CrossEntropyLoss()
        elif loss_fn == 'bceLogit':
            criteria = nn.BCEWithLogitsLoss()
        else:
            raise ValueError(
                "loss_fn must be 'crossEnt' or 'bceLogit', got %r" % (loss_fn,))
        criteria = criteria.to(device)

    if stochastic_pass:
        model.train()
    else:
        model.eval()
    test_loss = 0
    predictions = []
    activations = []
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            h_s = model.init_hidden(len(data))
            if isinstance(h_s, tuple):
                h_s = tuple(state.to(device) for state in h_s)
            else:
                h_s = h_s.to(device)
            output = model(data, h_s)
            if compute_metrics:
                predictionClasses = output.argmax(dim=1, keepdim=True)
                # One-hot targets are reduced to class indices for the
                # accuracy count; index targets are used as they are.
                if loss_fn == 'crossEnt':
                    true_classes = target
                else:
                    true_classes = target.argmax(dim=1)
                correct += predictionClasses.eq(
                    true_classes.view_as(predictionClasses)).sum().item()
                test_loss += criteria(output, target).sum().item()
            else:
                softmaxed = F.softmax(output.cpu(), dim=1)
                predictions.extend(softmaxed.data.numpy())
    if compute_metrics:
        return test_loss, correct
    else:
        return predictions, activations
def test_seq(model, test_loader, device, verbose=0, activationName = None,
             loss_fn = 'crossEnt', epoch = None):
    """Evaluate a hidden-state model and return its mean loss and accuracy.

    Args:
        model: model to evaluate; it is put in ``eval()`` mode.
        test_loader: loader of ``(data, target)`` batches; must expose
            ``dataset`` and ``len()``.
        device: device used for the evaluation.
        verbose: print a summary and two metric lines when greater than 0.
        activationName: forwarded to ``evalModel_seq``.
        loss_fn: forwarded to ``evalModel_seq``.
        epoch: optional epoch number reported in the metric lines; ``null``
            is printed when it is not given.

    Returns:
        ``(test_loss, test_acc)``: the per-batch losses averaged over the
        number of batches, and the fraction of correctly classified samples.
    """
    model.eval()
    total_test_loss, total_corrections = evalModel_seq(test_loader, model, device=device,
                                                       verbose = verbose,
                                                       stochastic_pass = False, compute_metrics = True,
                                                       activationName = activationName, loss_fn = loss_fn)
    n_samples = len(test_loader.dataset)
    test_loss = total_test_loss/ len(test_loader) # loss function already averages over batch size
    test_acc = total_corrections / n_samples
    if verbose>0:
        # Report the counts returned by the evaluation, and tolerate a
        # missing epoch number instead of referencing an undefined name.
        epoch_field = 'null' if epoch is None else epoch
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, total_corrections, n_samples,
            100. * test_acc))
        print('{{"metric": "Eval - cross entropy Loss", "value": {}, "epoch": {}}}'.format(
            test_loss, epoch_field))
        print('{{"metric": "Eval - Accuracy", "value": {}, "epoch": {}}}'.format(
            100. * test_acc, epoch_field))
    return test_loss, test_acc
def train(model, train_loader, optimizer, epoch, device, verbose = 0,
          lr_schedule = None, weight = None, loss_fn = 'crossEnt'):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: module callable as ``model(data)``.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer updating ``model``'s parameters.
        epoch: current epoch number; passed to ``lr_schedule`` and shown in
            the progress message.
        device: device the batches and the loss are moved to.
        verbose: print the loss after every batch when greater than 0.
        lr_schedule: optional callable ``(optimizer, epoch) -> optimizer``
            applied once before training.
        weight: accepted for interface compatibility; currently unused.
        loss_fn: ``'crossEnt'`` (``nn.CrossEntropyLoss``) or ``'bceLogit'``
            (``nn.BCEWithLogitsLoss``).

    Returns:
        Loss of the last processed batch as a float, or NaN if the loader
        yields no batches.

    Raises:
        ValueError: if ``loss_fn`` is not a supported name.
    """
    # Build the loss once, before touching the model, and place it on the
    # requested device instead of hard-coding CUDA.
    if loss_fn == 'crossEnt':
        criteria = nn.CrossEntropyLoss()
    elif loss_fn == 'bceLogit':
        criteria = nn.BCEWithLogitsLoss()
    else:
        raise ValueError(
            "loss_fn must be 'crossEnt' or 'bceLogit', got %r" % (loss_fn,))
    criteria = criteria.to(device)

    if lr_schedule is not None:
        optimizer = lr_schedule(optimizer, epoch)
    model.train()
    loss = None
    for batch_idx, (data, target) in enumerate(progress_bar(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criteria(output, target)
        loss.backward()
        optimizer.step()
        if verbose>0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    # An empty loader never assigns a loss; report NaN instead of crashing.
    return loss.item() if loss is not None else float('nan')
def evalModel(data_loader, model, device, verbose=0, stochastic_pass = True,
              compute_metrics=True, activationName = None,
              loss_fn = 'crossEnt'):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        data_loader: iterable of ``(data, target)`` batches.
        model: module callable as ``model(data)``.
        device: device the batches and the loss are moved to.
        verbose: unused.
        stochastic_pass: if True the model is put in ``train()`` mode for
            the pass, otherwise in ``eval()`` mode. The mode is not restored
            afterwards.
        compute_metrics: if True return loss/accuracy counts, otherwise
            return softmax predictions.
        activationName: unused.
        loss_fn: ``'crossEnt'`` (targets are class indices) or ``'bceLogit'``
            (targets are one-hot rows). Only consulted when
            ``compute_metrics`` is True.

    Returns:
        ``(test_loss, correct)`` when ``compute_metrics`` is True: the sum of
        the per-batch losses and the number of correctly classified samples.
        Otherwise ``(predictions, activations)``: a list with one softmax
        row (NumPy array) per sample and an always-empty list.

    Raises:
        ValueError: if ``compute_metrics`` is True and ``loss_fn`` is not a
            supported name.
    """
    criteria = None
    if compute_metrics:
        # Build the loss once and place it on the requested device instead
        # of re-creating it per batch with a hard-coded CUDA move.
        if loss_fn == 'crossEnt':
            criteria = nn.CrossEntropyLoss()
        elif loss_fn == 'bceLogit':
            criteria = nn.BCEWithLogitsLoss()
        else:
            raise ValueError(
                "loss_fn must be 'crossEnt' or 'bceLogit', got %r" % (loss_fn,))
        criteria = criteria.to(device)

    if stochastic_pass:
        model.train()
    else:
        model.eval()
    test_loss = 0
    predictions = []
    activations = []
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            if compute_metrics:
                predictionClasses = output.argmax(dim=1, keepdim=True)
                # One-hot targets are reduced to class indices for the
                # accuracy count; index targets are used as they are.
                if loss_fn == 'crossEnt':
                    true_classes = target
                else:
                    true_classes = target.argmax(dim=1)
                correct += predictionClasses.eq(
                    true_classes.view_as(predictionClasses)).sum().item()
                test_loss += criteria(output, target).sum().item()
            else:
                softmaxed = F.softmax(output.cpu(), dim=1)
                predictions.extend(softmaxed.data.numpy())
    if compute_metrics:
        return test_loss, correct
    else:
        return predictions, activations
def test(model, test_loader, device, verbose=0, activationName = None,
         loss_fn = 'crossEnt', epoch = None):
    """Evaluate ``model`` and return its mean loss and accuracy.

    Args:
        model: model to evaluate; it is put in ``eval()`` mode.
        test_loader: loader of ``(data, target)`` batches; must expose
            ``dataset`` and ``len()``.
        device: device used for the evaluation.
        verbose: print a summary and two metric lines when greater than 0.
        activationName: forwarded to ``evalModel``.
        loss_fn: forwarded to ``evalModel``.
        epoch: optional epoch number reported in the metric lines; ``null``
            is printed when it is not given.

    Returns:
        ``(test_loss, test_acc)``: the per-batch losses averaged over the
        number of batches, and the fraction of correctly classified samples.
    """
    model.eval()
    total_test_loss, total_corrections = evalModel(test_loader, model, device=device,
                                                   verbose = verbose,
                                                   stochastic_pass = False, compute_metrics = True,
                                                   activationName = activationName,
                                                   loss_fn=loss_fn)
    n_samples = len(test_loader.dataset)
    test_loss = total_test_loss/ len(test_loader) # loss function already averages over batch size
    test_acc = total_corrections / n_samples
    if verbose>0:
        # Report the counts returned by the evaluation, and tolerate a
        # missing epoch number instead of referencing an undefined name.
        epoch_field = 'null' if epoch is None else epoch
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, total_corrections, n_samples,
            100. * test_acc))
        print('{{"metric": "Eval - cross entropy Loss", "value": {}, "epoch": {}}}'.format(
            test_loss, epoch_field))
        print('{{"metric": "Eval - Accuracy", "value": {}, "epoch": {}}}'.format(
            100. * test_acc, epoch_field))
    return test_loss, test_acc