FIX: use args.save_preds_file to store the results, remove async=True

This commit is contained in:
soumya ranjan 2019-05-28 02:32:22 +05:30
Родитель 7698501ad0
Коммит f09be21995
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -49,7 +49,7 @@ def main():
' (default: resnext101)') ' (default: resnext101)')
parser.add_argument('--image_size', default=224, nargs='+', parser.add_argument('--image_size', default=224, nargs='+',
type=int, metavar='RESOLUTION', help='The side length of the CNN input image ' + \ type=int, metavar='RESOLUTION', help='The side length of the CNN input image ' + \
'(default: 448). For ensembles, provide one resolution for each network.') '(default: 224). For ensembles, provide one resolution for each network.')
parser.add_argument('--epochs', default=200, parser.add_argument('--epochs', default=200,
type=int, metavar='N', help='Number of total epochs to run.') type=int, metavar='N', help='Number of total epochs to run.')
parser.add_argument('--start_epoch', default=None, parser.add_argument('--start_epoch', default=None,
@ -329,7 +329,7 @@ def main():
# write predictions to file # write predictions to file
if args.save_preds: if args.save_preds:
prec1, prec3, prec5, preds, im_ids = validate(val_loader, model, criterion, 0, True) prec1, prec3, prec5, preds, im_ids = validate(val_loader, model, criterion, 0, True)
with open(args.op_file_name, 'w') as opfile: with open(args.save_preds_file, 'w') as opfile:
opfile.write('id,predicted\n') opfile.write('id,predicted\n')
for ii in range(len(im_ids)): for ii in range(len(im_ids)):
opfile.write(str(im_ids[ii]) + ',' + ' '.join(str(x) for x in preds[ii,:]) + '\n') opfile.write(str(im_ids[ii]) + ',' + ' '.join(str(x) for x in preds[ii,:]) + '\n')
@ -451,7 +451,7 @@ def train(train_loader, model, criterion, optimizer, epoch, param_copy = None):
input = input.cuda() input = input.cuda()
target = target.cuda(async=True) target = target.cuda()
input_var = torch.autograd.Variable(input) input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target) target_var = torch.autograd.Variable(target)
@ -543,7 +543,7 @@ def validate(val_loader, model, criterion, epoch, global_step, save_preds=False)
#output = torch.max(output, outputNew) #output = torch.max(output, outputNew)
output /= len(inputIn) output /= len(inputIn)
target = target.cuda(async=True) target = target.cuda()
target_var = torch.autograd.Variable(target) target_var = torch.autograd.Variable(target)
loss = criterion(output, target_var) loss = criterion(output, target_var)