Updated the srnn ipython notebook (gpu support)

This commit is contained in:
Sachin Goyal 2019-07-15 16:25:07 +05:30 коммит произвёл GitHub
Родитель 22654964d8
Коммит 382cb58bd5
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 3 добавлений и 2 удалений

Просмотреть файл

@ -29,6 +29,7 @@
"from pytorch_edgeml.graph.rnn import SRNN2\n",
"from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
"import pytorch_edgeml.utils as utils"
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
@ -127,8 +128,8 @@
}
],
"source": [
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType)\n",
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy')"
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n",
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
]
},
{