bird-acoustics-rcnn/CNN_CNN.ipynb

311 строки
10 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.optim as optim\n",
"import torch.utils.data as data_utils\n",
"import os\n",
"import numpy as np\n",
"from sklearn.preprocessing import LabelBinarizer, LabelEncoder\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"import csv\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import glob\n",
"import gc\n",
"import h5py\n",
"import pickle as pk\n",
"\n",
"from utils import log_results, SaveBestModel, train_seq, test_seq\n",
"from utils import normalize_mel_sp_slides\n",
"\n",
"from models import cnn_cnn"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# Use the first CUDA device when one is available; otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Set directories"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Folder names used by the rest of the notebook.\n",
"dataDir = 'processed'\n",
"resultsDir = 'Results'\n",
"tempDir = 'temp'\n",
"\n",
"# Create the two output folders if they are not there yet.\n",
"for out_dir in (resultsDir, tempDir):\n",
"    if not os.path.exists(out_dir):\n",
"        os.makedirs(out_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load data"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"fname = 'birds_cornell_spectr_slide_100_species_sr_32000_len_7_sec_500_250_New.h5'\n",
"fileLoc = os.path.join(dataDir, fname)  # 19707 samples per class\n",
"\n",
"# Read the spectrogram array inside a context manager so the HDF5 handle is\n",
"# closed even when the read fails. Indexing by key raises a clear KeyError if\n",
"# the 'mel_spectr' dataset is missing.\n",
"with h5py.File(fileLoc, 'r') as hf:\n",
"    mel_sp = hf['mel_spectr'][()]\n",
"\n",
"# The metadata table is read with pandas from the 'info' key of the same file.\n",
"metadata_total = pd.read_hdf(fileLoc, 'info')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of Species: 100\n"
]
}
],
"source": [
"# The 'ebird_code' column of the metadata table as a plain Python list.\n",
"original_label = list(metadata_total['ebird_code'])\n",
"\n",
"# Two encodings of the same labels: a binarized label matrix and integer class ids.\n",
"lb_bin = LabelBinarizer()\n",
"labels_one_hot = lb_bin.fit_transform(original_label)\n",
"lb_enc = LabelEncoder()\n",
"labels_multi_lbl = lb_enc.fit_transform(original_label)\n",
"\n",
"number_of_sample_classes = len(lb_enc.classes_)\n",
"print(\"Number of Species: \", number_of_sample_classes)\n",
"\n",
"# Map every species code to its position in lb_bin.classes_.\n",
"species_id_class_dict_tp = {\n",
"    species_id: class_label for class_label, species_id in enumerate(lb_bin.classes_)\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Transform data"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Normalize each sample, cast it to float32 and insert a singleton axis third from the end,\n",
"# then stack everything into a single array.\n",
"mel_sp_normalized = np.array([\n",
"    np.expand_dims(normalize_mel_sp_slides(sample).astype('float32'), axis=-3)\n",
"    for sample in mel_sp\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configs used in the current paper"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Other configs can be generated using a similar pattern."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# Layer specs handed to cnn_cnn as its first positional argument.\n",
"# NOTE(review): integers look like channel widths and 'M' like a pooling marker\n",
"# (VGG-style notation) -- confirm against models.cnn_cnn.\n",
"cfg_cnn = [\n",
"    [32, 'M', 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M'],\n",
"    [32, 64, 'M', 64, 64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M'],\n",
"    [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],\n",
"]\n",
"\n",
"# Layer specs handed to cnn_cnn as its second positional argument, same notation.\n",
"cfg_ts = [\n",
"    [64, 64, 'M', 128, 128, 128, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M'],\n",
"    [64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 256, 256, 256, 'M'],\n",
"    [64, 64, 'M', 128, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M'],\n",
"]\n",
"\n",
"# One entry per cfg_ts row: the training loop passes nunits[j] together with cfg_ts[j].\n",
"nunits = [1024, 1024, 2048]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Training hyper-parameters read by the experiment loop below.\n",
"batch_size = 32\n",
"num_classes = 100\n",
"shuffleBatches = True\n",
"num_epoch = 50"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run all configs of CNN+CNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train and evaluate every (cfg_cnn, cfg_ts) pairing with 5-fold stratified cross-validation.\n",
"# In each fold 4/5 of the data is used for training; the held-out 1/5 is split evenly\n",
"# (stratified) into a validation set and a test set.\n",
"\n",
"# Batch size for the validation and test loaders.\n",
"eval_batch_size = 32\n",
"\n",
"# Unshuffled folds are deterministic on their own. Passing random_state without\n",
"# shuffle=True raises ValueError on scikit-learn >= 0.24, so it is left out here.\n",
"skf = StratifiedKFold(n_splits=5)\n",
"\n",
"exp_no = 0\n",
"for ii, cfg1 in enumerate(cfg_cnn):\n",
"    for jj, cfg2 in enumerate(cfg_ts):\n",
"        exp_no += 1\n",
"\n",
"        # One results log per configuration pair.\n",
"        log_file_name = f'100_species_spectr_cnn_p_cnn_7sec_{exp_no}.p'\n",
"        store_ = log_results(file_name=log_file_name, results_dir=resultsDir)\n",
"\n",
"        fold_splits = skf.split(mel_sp_normalized, labels_multi_lbl)  # 5-fold resampling\n",
"        for exp_ind, (train_ind, test_ind) in enumerate(fold_splits):\n",
"\n",
"            # Checkpoint file for this configuration pair and fold.\n",
"            PATH_curr = os.path.join(tempDir, f'currentModel_cnn_p_cnn_{exp_no}_{exp_ind}.pt')\n",
"            saveModel = SaveBestModel(PATH=PATH_curr, monitor=-np.inf, verbose=True)\n",
"\n",
"            X_train, X_test_p_valid = mel_sp_normalized[train_ind, :], mel_sp_normalized[test_ind, :]\n",
"            y_train, y_test_p_valid = labels_one_hot[train_ind], labels_one_hot[test_ind]\n",
"            y_test_p_valid_mlbl = labels_multi_lbl[test_ind]\n",
"\n",
"            # Split the held-out fold in half, stratified on the integer labels.\n",
"            X_valid, X_test, y_valid, y_test = train_test_split(\n",
"                X_test_p_valid, y_test_p_valid,\n",
"                test_size=0.5,\n",
"                stratify=y_test_p_valid_mlbl,\n",
"                random_state=42)\n",
"\n",
"            print('X_train shape: ', X_train.shape)\n",
"            print('X_valid shape: ', X_valid.shape)\n",
"            print('X_test shape: ', X_test.shape)\n",
"\n",
"            X_train, X_valid = torch.from_numpy(X_train).float(), torch.from_numpy(X_valid).float()\n",
"            y_train, y_valid = torch.from_numpy(y_train).float(), torch.from_numpy(y_valid).float()\n",
"\n",
"            train_use = data_utils.TensorDataset(X_train, y_train)\n",
"            train_loader = data_utils.DataLoader(train_use, batch_size=batch_size, shuffle=shuffleBatches)\n",
"\n",
"            val_use = data_utils.TensorDataset(X_valid, y_valid)\n",
"            val_loader = data_utils.DataLoader(val_use, batch_size=eval_batch_size, shuffle=False)\n",
"\n",
"            # num_classes comes from the hyper-parameter cell instead of a repeated literal.\n",
"            model = cnn_cnn(cfg1,\n",
"                            cfg2,\n",
"                            nunits[jj],\n",
"                            num_classes=num_classes)\n",
"            model.to(device)\n",
"            optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-7)\n",
"\n",
"            val_acc_epochs = []\n",
"            val_loss_epochs = []\n",
"            for epoch in range(1, num_epoch + 1):\n",
"                train_loss = train_seq(model, train_loader, optimizer, epoch,\n",
"                                       device,\n",
"                                       verbose=1, loss_fn='bceLogit')\n",
"                val_loss, val_acc = test_seq(model, val_loader,\n",
"                                             device,\n",
"                                             loss_fn='bceLogit')\n",
"                val_acc_epochs.append(val_acc)\n",
"                val_loss_epochs.append(val_loss)\n",
"                print('val loss = %f, val acc = %f'%(val_loss, val_acc))\n",
"                # Hand the epoch's validation accuracy to the checkpoint helper.\n",
"                saveModel.check(model, val_acc, comp='max')\n",
"\n",
"            # loading best validated model\n",
"            model = cnn_cnn(cfg1,\n",
"                            cfg2,\n",
"                            nunits[jj],\n",
"                            num_classes=num_classes)\n",
"            model.to(device)\n",
"            # map_location keeps the load working when the checkpoint was written on another device.\n",
"            model.load_state_dict(torch.load(PATH_curr, map_location=device))\n",
"\n",
"            X_test, y_test = torch.from_numpy(X_test).float(), torch.from_numpy(y_test).float()\n",
"\n",
"            test_use = data_utils.TensorDataset(X_test, y_test)\n",
"            test_loader = data_utils.DataLoader(test_use, batch_size=eval_batch_size, shuffle=False)\n",
"            test_loss, test_acc = test_seq(model, test_loader,\n",
"                                           device,\n",
"                                           loss_fn='bceLogit')\n",
"            print('test loss = %f, test acc = %f'%(test_loss, test_acc))\n",
"\n",
"            log_ = dict(\n",
"                exp_ind=exp_ind,\n",
"                epochs=num_epoch,\n",
"                validation_accuracy=val_acc_epochs,\n",
"                validation_loss=val_loss_epochs,\n",
"                test_loss=test_loss,\n",
"                test_accuracy=test_acc,\n",
"                X_train_shape=X_train.shape,\n",
"                X_valid_shape=X_valid.shape,\n",
"                batch_size=batch_size,\n",
"            )\n",
"            store_.update(log_)\n",
"        print(f'COMPLETED for configs {ii, jj}...')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.5",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}