This commit is contained in:
SachinG007 2019-07-15 16:26:58 +05:30
Parent 850e2a5465
Commit 835db8ce2e
1 changed file with 0 additions and 207 deletions

View file

@@ -1,207 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SRNN on Speech Commands Dataset\n",
"\n",
"Please use `fetch_google.sh` to download the Google Speech Commands Dataset and `python process_google.py` to create feature extracted data."
]
},
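{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is an optional sanity check, not part of the original pipeline: it only verifies that the six `.npy` arrays loaded later are present under `./GoogleSpeech/Extracted/`, assuming `process_google.py` writes its output there."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (assumes process_google.py wrote its output to ./GoogleSpeech/Extracted/)\n",
"import os\n",
"_data_dir = './GoogleSpeech/Extracted/'\n",
"for _name in ['x_train', 'y_train', 'x_val', 'y_val', 'x_test', 'y_test']:\n",
"    _path = os.path.join(_data_dir, _name + '.npy')\n",
"    assert os.path.isfile(_path), 'Missing %s; run fetch_google.sh and process_google.py first' % _path"
]
},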
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:52:51.914361Z",
"start_time": "2019-07-14T12:52:51.667856Z"
}
},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import sys\n",
"import os\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from pytorch_edgeml.graph.rnn import SRNN2\n",
"from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
"import pytorch_edgeml.utils as utils\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:52:56.040100Z",
"start_time": "2019-07-14T12:52:51.916533Z"
}
},
"outputs": [],
"source": [
"DATA_DIR = './GoogleSpeech/Extracted/'\n",
"x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))\n",
"x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))\n",
"x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))\n",
"# Mean-var normalize\n",
"mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
"std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
"std[std[:] < 0.000001] = 1\n",
"x_train_ = (x_train_ - mean) / std\n",
"x_val_ = (x_val_ - mean) / std\n",
"x_test_ = (x_test_ - mean) / std"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:52:56.047992Z",
"start_time": "2019-07-14T12:52:56.042445Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train shape (99, 51088, 32) (51088, 13)\n",
"Val shape (99, 6798, 32) (6798, 13)\n",
"Test shape (99, 6835, 32) (6835, 13)\n"
]
}
],
"source": [
"x_train = np.swapaxes(x_train_, 0, 1)\n",
"x_val = np.swapaxes(x_val_, 0, 1)\n",
"x_test = np.swapaxes(x_test_, 0, 1)\n",
"print(\"Train shape\", x_train.shape, y_train.shape)\n",
"print(\"Val shape\", x_val.shape, y_val.shape)\n",
"print(\"Test shape\", x_test.shape, y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:52:56.068329Z",
"start_time": "2019-07-14T12:52:56.049725Z"
}
},
"outputs": [],
"source": [
"numTimeSteps = x_train.shape[0]\n",
"numInput = x_train.shape[-1]\n",
"brickSize = 11\n",
"numClasses = y_train.shape[1]\n",
"\n",
"hiddenDim0 = 64\n",
"hiddenDim1 = 32\n",
"cellType = 'LSTM'\n",
"learningRate = 0.01\n",
"batchSize = 128\n",
"epochs = 10"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:52:56.088534Z",
"start_time": "2019-07-14T12:52:56.070114Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using x-entropy loss\n"
]
}
],
"source": [
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n",
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-14T12:59:52.893161Z",
"start_time": "2019-07-14T12:52:56.090327Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 batch 0 loss 2.799139 acc 0.031250\n",
"Epoch 0 batch 200 loss 0.784248 acc 0.750000\n",
"Epoch 1 batch 0 loss 0.379059 acc 0.875000\n",
"Epoch 1 batch 200 loss 0.544366 acc 0.820312\n",
"Epoch 2 batch 0 loss 0.272113 acc 0.914062\n",
"Epoch 2 batch 200 loss 0.400919 acc 0.867188\n",
"Epoch 3 batch 0 loss 0.200825 acc 0.953125\n",
"Epoch 3 batch 200 loss 0.248952 acc 0.906250\n",
"Epoch 4 batch 0 loss 0.161245 acc 0.960938\n",
"Epoch 4 batch 200 loss 0.294340 acc 0.875000\n",
"Validation accuracy: 0.913063\n",
"Epoch 5 batch 0 loss 0.159573 acc 0.953125\n",
"Epoch 5 batch 200 loss 0.233308 acc 0.937500\n",
"Epoch 6 batch 0 loss 0.068345 acc 0.984375\n",
"Epoch 6 batch 200 loss 0.225371 acc 0.937500\n",
"Epoch 7 batch 0 loss 0.112335 acc 0.968750\n",
"Epoch 7 batch 200 loss 0.170626 acc 0.945312\n",
"Epoch 8 batch 0 loss 0.168985 acc 0.945312\n",
"Epoch 8 batch 200 loss 0.160869 acc 0.937500\n",
"Epoch 9 batch 0 loss 0.123516 acc 0.953125\n",
"Epoch 9 batch 200 loss 0.172936 acc 0.937500\n",
"Validation accuracy: 0.908208\n"
]
}
],
"source": [
"trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5)"
]
},
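{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test split loaded above is not used during training. The cell below is a minimal evaluation sketch, not part of the original notebook: it assumes `SRNN2` can be called as `srnn2(x, brickSize)` on a time-major float tensor and returns one row of logits per example, so check the `SRNN2` API before relying on it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical test-set evaluation; the forward call below is an assumption about the SRNN2 API.\n",
"srnn2.eval()\n",
"correct, total = 0, 0\n",
"with torch.no_grad():\n",
"    for i in range(0, x_test.shape[1], batchSize):\n",
"        x_batch = torch.tensor(x_test[:, i:i + batchSize, :], dtype=torch.float32, device=device)\n",
"        logits = srnn2(x_batch, brickSize)  # assumed signature: (time-major input, brick size)\n",
"        pred = torch.argmax(logits, dim=1).cpu().numpy()\n",
"        correct += np.sum(pred == np.argmax(y_test[i:i + batchSize], axis=1))\n",
"        total += pred.shape[0]\n",
"print('Test accuracy (sketch): %f' % (correct / total))"
]
},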
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}