FIX: use args.save_preds_file to store the results, remove async=True
This commit is contained in:
Родитель
7698501ad0
Коммит
f09be21995
|
@ -49,7 +49,7 @@ def main():
|
||||||
' (default: resnext101)')
|
' (default: resnext101)')
|
||||||
parser.add_argument('--image_size', default=224, nargs='+',
|
parser.add_argument('--image_size', default=224, nargs='+',
|
||||||
type=int, metavar='RESOLUTION', help='The side length of the CNN input image ' + \
|
type=int, metavar='RESOLUTION', help='The side length of the CNN input image ' + \
|
||||||
'(default: 448). For ensembles, provide one resolution for each network.')
|
'(default: 224). For ensembles, provide one resolution for each network.')
|
||||||
parser.add_argument('--epochs', default=200,
|
parser.add_argument('--epochs', default=200,
|
||||||
type=int, metavar='N', help='Number of total epochs to run.')
|
type=int, metavar='N', help='Number of total epochs to run.')
|
||||||
parser.add_argument('--start_epoch', default=None,
|
parser.add_argument('--start_epoch', default=None,
|
||||||
|
@ -329,7 +329,7 @@ def main():
|
||||||
# write predictions to file
|
# write predictions to file
|
||||||
if args.save_preds:
|
if args.save_preds:
|
||||||
prec1, prec3, prec5, preds, im_ids = validate(val_loader, model, criterion, 0, True)
|
prec1, prec3, prec5, preds, im_ids = validate(val_loader, model, criterion, 0, True)
|
||||||
with open(args.op_file_name, 'w') as opfile:
|
with open(args.save_preds_file, 'w') as opfile:
|
||||||
opfile.write('id,predicted\n')
|
opfile.write('id,predicted\n')
|
||||||
for ii in range(len(im_ids)):
|
for ii in range(len(im_ids)):
|
||||||
opfile.write(str(im_ids[ii]) + ',' + ' '.join(str(x) for x in preds[ii,:]) + '\n')
|
opfile.write(str(im_ids[ii]) + ',' + ' '.join(str(x) for x in preds[ii,:]) + '\n')
|
||||||
|
@ -451,7 +451,7 @@ def train(train_loader, model, criterion, optimizer, epoch, param_copy = None):
|
||||||
|
|
||||||
input = input.cuda()
|
input = input.cuda()
|
||||||
|
|
||||||
target = target.cuda(async=True)
|
target = target.cuda()
|
||||||
input_var = torch.autograd.Variable(input)
|
input_var = torch.autograd.Variable(input)
|
||||||
target_var = torch.autograd.Variable(target)
|
target_var = torch.autograd.Variable(target)
|
||||||
|
|
||||||
|
@ -543,7 +543,7 @@ def validate(val_loader, model, criterion, epoch, global_step, save_preds=False)
|
||||||
#output = torch.max(output, outputNew)
|
#output = torch.max(output, outputNew)
|
||||||
output /= len(inputIn)
|
output /= len(inputIn)
|
||||||
|
|
||||||
target = target.cuda(async=True)
|
target = target.cuda()
|
||||||
target_var = torch.autograd.Variable(target)
|
target_var = torch.autograd.Variable(target)
|
||||||
loss = criterion(output, target_var)
|
loss = criterion(output, target_var)
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче