зеркало из https://github.com/microsoft/EdgeML.git
Added GPU Support for srnn
This commit is contained in:
Родитель
850e2a5465
Коммит
835db8ce2e
|
@ -1,207 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SRNN on Speech Commands Dataset\n",
|
||||
"\n",
|
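||||
"Before running this notebook, run `fetch_google.sh` to download the Google Speech Commands Dataset, then run `python process_google.py` to create the feature-extracted data. The cells below load the resulting `.npy` files from `./GoogleSpeech/Extracted/`."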
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:51.914361Z",
|
||||
"start_time": "2019-07-14T12:52:51.667856Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import print_function\n",
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from pytorch_edgeml.graph.rnn import SRNN2\n",
|
||||
"from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
|
||||
"import pytorch_edgeml.utils as utils\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:56.040100Z",
|
||||
"start_time": "2019-07-14T12:52:51.916533Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_DIR = './GoogleSpeech/Extracted/'\n",
|
||||
"x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))\n",
|
||||
"x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))\n",
|
||||
"x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))\n",
|
||||
"# Mean-var normalize\n",
|
||||
"mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
|
||||
"std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
|
||||
"std[std[:] < 0.000001] = 1\n",
|
||||
"x_train_ = (x_train_ - mean) / std\n",
|
||||
"x_val_ = (x_val_ - mean) / std\n",
|
||||
"x_test_ = (x_test_ - mean) / std"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:56.047992Z",
|
||||
"start_time": "2019-07-14T12:52:56.042445Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train shape (99, 51088, 32) (51088, 13)\n",
|
||||
"Val shape (99, 6798, 32) (6798, 13)\n",
|
||||
"Test shape (99, 6835, 32) (6835, 13)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x_train = np.swapaxes(x_train_, 0, 1)\n",
|
||||
"x_val = np.swapaxes(x_val_, 0, 1)\n",
|
||||
"x_test = np.swapaxes(x_test_, 0, 1)\n",
|
||||
"print(\"Train shape\", x_train.shape, y_train.shape)\n",
|
||||
"print(\"Val shape\", x_val.shape, y_val.shape)\n",
|
||||
"print(\"Test shape\", x_test.shape, y_test.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:56.068329Z",
|
||||
"start_time": "2019-07-14T12:52:56.049725Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"numTimeSteps = x_train.shape[0]\n",
|
||||
"numInput = x_train.shape[-1]\n",
|
||||
"brickSize = 11\n",
|
||||
"numClasses = y_train.shape[1]\n",
|
||||
"\n",
|
||||
"hiddenDim0 = 64\n",
|
||||
"hiddenDim1 = 32\n",
|
||||
"cellType = 'LSTM'\n",
|
||||
"learningRate = 0.01\n",
|
||||
"batchSize = 128\n",
|
||||
"epochs = 10"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:56.088534Z",
|
||||
"start_time": "2019-07-14T12:52:56.070114Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using x-entropy loss\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n",
|
||||
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:59:52.893161Z",
|
||||
"start_time": "2019-07-14T12:52:56.090327Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 0 batch 0 loss 2.799139 acc 0.031250\n",
|
||||
"Epoch 0 batch 200 loss 0.784248 acc 0.750000\n",
|
||||
"Epoch 1 batch 0 loss 0.379059 acc 0.875000\n",
|
||||
"Epoch 1 batch 200 loss 0.544366 acc 0.820312\n",
|
||||
"Epoch 2 batch 0 loss 0.272113 acc 0.914062\n",
|
||||
"Epoch 2 batch 200 loss 0.400919 acc 0.867188\n",
|
||||
"Epoch 3 batch 0 loss 0.200825 acc 0.953125\n",
|
||||
"Epoch 3 batch 200 loss 0.248952 acc 0.906250\n",
|
||||
"Epoch 4 batch 0 loss 0.161245 acc 0.960938\n",
|
||||
"Epoch 4 batch 200 loss 0.294340 acc 0.875000\n",
|
||||
"Validation accuracy: 0.913063\n",
|
||||
"Epoch 5 batch 0 loss 0.159573 acc 0.953125\n",
|
||||
"Epoch 5 batch 200 loss 0.233308 acc 0.937500\n",
|
||||
"Epoch 6 batch 0 loss 0.068345 acc 0.984375\n",
|
||||
"Epoch 6 batch 200 loss 0.225371 acc 0.937500\n",
|
||||
"Epoch 7 batch 0 loss 0.112335 acc 0.968750\n",
|
||||
"Epoch 7 batch 200 loss 0.170626 acc 0.945312\n",
|
||||
"Epoch 8 batch 0 loss 0.168985 acc 0.945312\n",
|
||||
"Epoch 8 batch 200 loss 0.160869 acc 0.937500\n",
|
||||
"Epoch 9 batch 0 loss 0.123516 acc 0.953125\n",
|
||||
"Epoch 9 batch 200 loss 0.172936 acc 0.937500\n",
|
||||
"Validation accuracy: 0.908208\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Загрузка…
Ссылка в новой задаче