This commit is contained in:
b-liu14 2018-08-30 10:32:17 +09:00
Родитель 16e0b21015
Коммит 1d44b532c1
1 изменённых файлов: 3 добавлений и 3 удалений

Просмотреть файл

@ -162,7 +162,7 @@ def main():
ndata = train_dataset.__len__()
lemniscate = LinearAverage(args.low_dim, ndata, args.temperature, args.memory_momentum).cuda()
criterion = NCACrossEntropy(torch.LongTensor([y for (p, y) in train_loader.dataset.imgs]),
args.margin / args.temperature).cuda()
cudnn.benchmark = True
@ -238,9 +238,9 @@ def train(train_loader, model, lemniscate, criterion, optimizer, epoch):
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t').format(
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses)
data_time=data_time, loss=losses))
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):