GPU support for Bonsai pytorch

This commit is contained in:
Aditya Kusupati 2019-06-21 23:59:18 +00:00
Родитель 4eb690df40
Коммит 193c28e349
3 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -61,7 +61,7 @@ def main():
else:
batchSize = args.batch_size
useMCHLoss = False
useMCHLoss = True
if numClasses == 2:
numClasses = 1

Просмотреть файл

@ -69,7 +69,7 @@ class BonsaiTrainer:
if (self.bonsaiObj.numClasses > 2):
if self.useMCHLoss is True:
marginLoss = utils.multiClassHingeLoss(logits, labels)
marginLoss = utils.multiClassHingeLoss(logits, labels, self.device)
else:
marginLoss = utils.crossEntropyLoss(logits, labels)
loss = marginLoss + regLoss

Просмотреть файл

@ -7,14 +7,14 @@ import torch
import torch.nn.functional as F
def multiClassHingeLoss(logits, labels):
def multiClassHingeLoss(logits, labels, device):
'''
MultiClassHingeLoss to match C++ Version - No pytorch internal version
'''
flatLogits = torch.reshape(logits, [-1, ])
labels_ = labels.argmax(dim=1)
correctId = torch.arange(labels.shape[0]) * labels.shape[1] + labels_
correctId = torch.arange(labels.shape[0]).to(device) * labels.shape[1] + labels_
correctLogit = torch.gather(flatLogits, 0, correctId)
maxLabel = logits.argmax(dim=1)