зеркало из https://github.com/microsoft/EdgeML.git
GPU support for Bonsai pytorch
This commit is contained in:
Родитель
4eb690df40
Коммит
193c28e349
|
@ -61,7 +61,7 @@ def main():
|
|||
else:
|
||||
batchSize = args.batch_size
|
||||
|
||||
useMCHLoss = False
|
||||
useMCHLoss = True
|
||||
|
||||
if numClasses == 2:
|
||||
numClasses = 1
|
||||
|
|
|
@ -69,7 +69,7 @@ class BonsaiTrainer:
|
|||
|
||||
if (self.bonsaiObj.numClasses > 2):
|
||||
if self.useMCHLoss is True:
|
||||
marginLoss = utils.multiClassHingeLoss(logits, labels)
|
||||
marginLoss = utils.multiClassHingeLoss(logits, labels, self.device)
|
||||
else:
|
||||
marginLoss = utils.crossEntropyLoss(logits, labels)
|
||||
loss = marginLoss + regLoss
|
||||
|
|
|
@ -7,14 +7,14 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def multiClassHingeLoss(logits, labels):
|
||||
def multiClassHingeLoss(logits, labels, device):
|
||||
'''
|
||||
MultiClassHingeLoss to match C++ Version - No pytorch internal version
|
||||
'''
|
||||
flatLogits = torch.reshape(logits, [-1, ])
|
||||
labels_ = labels.argmax(dim=1)
|
||||
|
||||
correctId = torch.arange(labels.shape[0]) * labels.shape[1] + labels_
|
||||
correctId = torch.arange(labels.shape[0]).to(device) * labels.shape[1] + labels_
|
||||
correctLogit = torch.gather(flatLogits, 0, correctId)
|
||||
|
||||
maxLabel = logits.argmax(dim=1)
|
||||
|
|
Загрузка…
Ссылка в новой задаче