зеркало из https://github.com/microsoft/EdgeML.git
SRNN Example, 90% accuracy
This commit is contained in:
Родитель
cfadd7181f
Коммит
d43bcc7faa
|
@ -1,34 +1,21 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SRNN on Speech Commands Dataset\n",
|
||||
"\n",
|
||||
"Please use `fetch_google.sh` to download the Google Speech Commands Dataset and `python process_google.py` to create feature extracted data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:51.084894Z",
|
||||
"start_time": "2019-07-13T09:44:51.073356Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# SRNN on Speech Commands Dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:51.378080Z",
|
||||
"start_time": "2019-07-13T09:44:51.086975Z"
|
||||
"end_time": "2019-07-14T12:52:51.914361Z",
|
||||
"start_time": "2019-07-14T12:52:51.667856Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
|
@ -46,18 +33,55 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:52.118652Z",
|
||||
"start_time": "2019-07-13T09:44:51.380331Z"
|
||||
"end_time": "2019-07-14T12:52:56.040100Z",
|
||||
"start_time": "2019-07-14T12:52:51.916533Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_DIR = '/datadrive/t-dodenn/Divide_and_Conquer/data/Google-13/nonmildata/'\n",
|
||||
"DATA_DIR = './GoogleSpeech/Extracted/'\n",
|
||||
"x_train_, y_train = np.squeeze(np.load(DATA_DIR + 'x_train.npy')), np.squeeze(np.load(DATA_DIR + 'y_train.npy'))\n",
|
||||
"x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))"
|
||||
"x_val_, y_val = np.squeeze(np.load(DATA_DIR + 'x_val.npy')), np.squeeze(np.load(DATA_DIR + 'y_val.npy'))\n",
|
||||
"x_test_, y_test = np.squeeze(np.load(DATA_DIR + 'x_test.npy')), np.squeeze(np.load(DATA_DIR + 'y_test.npy'))\n",
|
||||
"# Mean-var normalize\n",
|
||||
"mean = np.mean(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
|
||||
"std = np.std(np.reshape(x_train_, [-1, x_train_.shape[-1]]), axis=0)\n",
|
||||
"std[std[:] < 0.000001] = 1\n",
|
||||
"x_train_ = (x_train_ - mean) / std\n",
|
||||
"x_val_ = (x_val_ - mean) / std\n",
|
||||
"x_test_ = (x_test_ - mean) / std"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-14T12:52:56.047992Z",
|
||||
"start_time": "2019-07-14T12:52:56.042445Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train shape (99, 51088, 32) (51088, 13)\n",
|
||||
"Val shape (99, 6798, 32) (6798, 13)\n",
|
||||
"Test shape (99, 6835, 32) (6835, 13)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"x_train = np.swapaxes(x_train_, 0, 1)\n",
|
||||
"x_val = np.swapaxes(x_val_, 0, 1)\n",
|
||||
"x_test = np.swapaxes(x_test_, 0, 1)\n",
|
||||
"print(\"Train shape\", x_train.shape, y_train.shape)\n",
|
||||
"print(\"Val shape\", x_val.shape, y_val.shape)\n",
|
||||
"print(\"Test shape\", x_test.shape, y_test.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -65,26 +89,23 @@
|
|||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:52.143493Z",
|
||||
"start_time": "2019-07-13T09:44:52.121480Z"
|
||||
"end_time": "2019-07-14T12:52:56.068329Z",
|
||||
"start_time": "2019-07-14T12:52:56.049725Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"((99, 51088, 32), (51088, 13), (99, 6798, 32), (6798, 13))"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x_train = np.swapaxes(x_train_, 0, 1)\n",
|
||||
"x_val = np.swapaxes(x_val_, 0, 1)\n",
|
||||
"x_train.shape, y_train.shape, x_val.shape, y_val.shape"
|
||||
"numTimeSteps = x_train.shape[0]\n",
|
||||
"numInput = x_train.shape[-1]\n",
|
||||
"brickSize = 11\n",
|
||||
"numClasses = y_train.shape[1]\n",
|
||||
"\n",
|
||||
"hiddenDim0 = 64\n",
|
||||
"hiddenDim1 = 32\n",
|
||||
"cellType = 'LSTM'\n",
|
||||
"learningRate = 0.01\n",
|
||||
"batchSize = 128\n",
|
||||
"epochs = 10"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -92,23 +113,22 @@
|
|||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:52.159481Z",
|
||||
"start_time": "2019-07-13T09:44:52.145491Z"
|
||||
"end_time": "2019-07-14T12:52:56.088534Z",
|
||||
"start_time": "2019-07-14T12:52:56.070114Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using x-entropy loss\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"numTimeSteps = x_train.shape[0]\n",
|
||||
"numInput = x_train.shape[-1]\n",
|
||||
"brickSize = 9\n",
|
||||
"numClasses = y_train.shape[1]\n",
|
||||
"\n",
|
||||
"hiddenDim0 = 32\n",
|
||||
"hiddenDim1 = 32\n",
|
||||
"cellType = 'LSTM'\n",
|
||||
"learningRate = 0.01\n",
|
||||
"batchSize = 128\n",
|
||||
"epochs = 50"
|
||||
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType)\n",
|
||||
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -116,40 +136,42 @@
|
|||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:52.246650Z",
|
||||
"start_time": "2019-07-13T09:44:52.161240Z"
|
||||
"end_time": "2019-07-14T12:59:52.893161Z",
|
||||
"start_time": "2019-07-14T12:52:56.090327Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "TypeError",
|
||||
"evalue": "__init__() got an unexpected keyword argument 'device'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m<ipython-input-6-4d0448ee71f4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0msrnn2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSRNN2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumInput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnumClasses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhiddenDim0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhiddenDim1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcellType\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSRNNTrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrnn2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearningRate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'gpu'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'device'"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 0 batch 0 loss 2.799139 acc 0.031250\n",
|
||||
"Epoch 0 batch 200 loss 0.784248 acc 0.750000\n",
|
||||
"Epoch 1 batch 0 loss 0.379059 acc 0.875000\n",
|
||||
"Epoch 1 batch 200 loss 0.544366 acc 0.820312\n",
|
||||
"Epoch 2 batch 0 loss 0.272113 acc 0.914062\n",
|
||||
"Epoch 2 batch 200 loss 0.400919 acc 0.867188\n",
|
||||
"Epoch 3 batch 0 loss 0.200825 acc 0.953125\n",
|
||||
"Epoch 3 batch 200 loss 0.248952 acc 0.906250\n",
|
||||
"Epoch 4 batch 0 loss 0.161245 acc 0.960938\n",
|
||||
"Epoch 4 batch 200 loss 0.294340 acc 0.875000\n",
|
||||
"Validation accuracy: 0.913063\n",
|
||||
"Epoch 5 batch 0 loss 0.159573 acc 0.953125\n",
|
||||
"Epoch 5 batch 200 loss 0.233308 acc 0.937500\n",
|
||||
"Epoch 6 batch 0 loss 0.068345 acc 0.984375\n",
|
||||
"Epoch 6 batch 200 loss 0.225371 acc 0.937500\n",
|
||||
"Epoch 7 batch 0 loss 0.112335 acc 0.968750\n",
|
||||
"Epoch 7 batch 200 loss 0.170626 acc 0.945312\n",
|
||||
"Epoch 8 batch 0 loss 0.168985 acc 0.945312\n",
|
||||
"Epoch 8 batch 200 loss 0.160869 acc 0.937500\n",
|
||||
"Epoch 9 batch 0 loss 0.123516 acc 0.953125\n",
|
||||
"Epoch 9 batch 200 loss 0.172936 acc 0.937500\n",
|
||||
"Validation accuracy: 0.908208\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType)\n",
|
||||
"trainer = SRNNTrainer(srnn2, learningRate, device='gpu')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"ExecuteTime": {
|
||||
"end_time": "2019-07-13T09:44:52.248622Z",
|
||||
"start_time": "2019-07-13T09:44:50.923Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200)"
|
||||
"trainer.train(brickSize, batchSize, epochs, x_train, x_val, y_train, y_val, printStep=200, valStep=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче