зеркало из https://github.com/microsoft/EdgeML.git
Updated the srnn ipython notebook (gpu support)
This commit is contained in:
Родитель
22654964d8
Коммит
382cb58bd5
|
@ -29,6 +29,7 @@
|
|||
"from pytorch_edgeml.graph.rnn import SRNN2\n",
|
||||
"from pytorch_edgeml.trainer.srnnTrainer import SRNNTrainer\n",
|
||||
"import pytorch_edgeml.utils as utils"
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -127,8 +128,8 @@
|
|||
}
|
||||
],
|
||||
"source": [
|
||||
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType)\n",
|
||||
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy')"
|
||||
"srnn2 = SRNN2(numInput, numClasses, hiddenDim0, hiddenDim1, cellType).to(device)\n",
|
||||
"trainer = SRNNTrainer(srnn2, learningRate, lossType='xentropy', device=device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
Загрузка…
Ссылка в новой задаче