minor code clean up in validation

This commit is contained in:
Bugra Tekin 2019-10-18 19:07:57 +02:00
Родитель 283d13645a
Коммит 607a859e1b
1 изменённых файлов: 4 добавлений и 4 удалений

Просмотреть файл

@ -88,7 +88,7 @@ def valid(datacfg, modelcfg, weightfile):
test_width = model.test_width
test_height = model.test_height
num_keypoints = model.num_keypoints
num_labels = num_keypoints * 2 + 3
num_labels = num_keypoints * 2 + 3 # +2 for width, height, +1 for class label
# Get the parser for the test dataset
valid_dataset = dataset.listDataset(valid_images,
@ -122,7 +122,7 @@ def valid(datacfg, modelcfg, weightfile):
# Iterate through all batch elements
for box_pr, target in zip([all_boxes], [target[0]]):
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target.view(-1, num_keypoints*2+3)
truths = target.view(-1, num_labels)
# Get how many objects are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
@ -134,8 +134,8 @@ def valid(datacfg, modelcfg, weightfile):
box_gt.append(truths[k][0])
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt = np.array(np.reshape(box_gt[:18], [-1, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [-1, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width