minor code clean up in validation
This commit is contained in:
Родитель
283d13645a
Коммит
607a859e1b
8
valid.py
8
valid.py
|
@ -88,7 +88,7 @@ def valid(datacfg, modelcfg, weightfile):
|
|||
test_width = model.test_width
|
||||
test_height = model.test_height
|
||||
num_keypoints = model.num_keypoints
|
||||
num_labels = num_keypoints * 2 + 3
|
||||
num_labels = num_keypoints * 2 + 3 # +2 for width, height, +1 for class label
|
||||
|
||||
# Get the parser for the test dataset
|
||||
valid_dataset = dataset.listDataset(valid_images,
|
||||
|
@ -122,7 +122,7 @@ def valid(datacfg, modelcfg, weightfile):
|
|||
# Iterate through all batch elements
|
||||
for box_pr, target in zip([all_boxes], [target[0]]):
|
||||
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
|
||||
truths = target.view(-1, num_keypoints*2+3)
|
||||
truths = target.view(-1, num_labels)
|
||||
# Get how many objects are present in the scene
|
||||
num_gts = truths_length(truths)
|
||||
# Iterate through each ground-truth object
|
||||
|
@ -134,8 +134,8 @@ def valid(datacfg, modelcfg, weightfile):
|
|||
box_gt.append(truths[k][0])
|
||||
|
||||
# Denormalize the corner predictions
|
||||
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
|
||||
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
|
||||
corners2D_gt = np.array(np.reshape(box_gt[:18], [-1, 2]), dtype='float32')
|
||||
corners2D_pr = np.array(np.reshape(box_pr[:18], [-1, 2]), dtype='float32')
|
||||
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
|
||||
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
|
||||
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
|
||||
|
|
Загрузка…
Ссылка в новой задаче