fix minor bug in main.py
This commit is contained in:
Родитель
16e0b21015
Коммит
1d44b532c1
6
main.py
6
main.py
|
@ -162,7 +162,7 @@ def main():
|
|||
ndata = train_dataset.__len__()
|
||||
lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda()
|
||||
|
||||
|
||||
|
||||
criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]),
|
||||
args.margin / args.temperature).cuda()
|
||||
cudnn.benchmark = True
|
||||
|
@ -238,9 +238,9 @@ def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
|
|||
print('Epoch: [{0}][{1}/{2}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
|
||||
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
|
||||
epoch, i, len(train_loader), batch_time=batch_time,
|
||||
data_time=data_time, loss=losses)
|
||||
data_time=data_time, loss=losses))
|
||||
|
||||
|
||||
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
|
||||
|
|
Загрузка…
Ссылка в новой задаче