update to python3, pytorch0.4.1, refactor code

This commit is contained in:
Bugra Tekin 2019-10-18 18:45:05 +02:00
Родитель 766d5eced2
Коммит 283d13645a
67 изменённых файлов: 8690 добавлений и 1933 удалений

Просмотреть файл

@ -1,9 +1,11 @@
# SINGLESHOTPOSE
This is the code for the following paper:
This is the development version of the code for the following paper:
Bugra Tekin, Sudipta N. Sinha and Pascal Fua, "Real-Time Seamless Single Shot 6D Object Pose Prediction", CVPR 2018.
The original repository for the codebase for the above paper can be found in the following [link](https://github.com/Microsoft/singleshotpose/).
### Introduction
We propose a single-shot approach for simultaneously detecting an object in an RGB image and predicting its 6D pose without requiring multiple stages or having to examine multiple hypotheses. The key component of our method is a new CNN architecture inspired by the YOLO network design that directly predicts the 2D image locations of the projected vertices of the object's 3D bounding box. The object's 6D pose is then estimated using a PnP algorithm. [Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Tekin_Real-Time_Seamless_Single_CVPR_2018_paper.pdf), [arXiv](https://arxiv.org/abs/1711.08848)
@ -13,7 +15,7 @@ We propose a single-shot approach for simultaneously detecting an object in an R
#### Citation
If you use this code, please cite the following
> @inproceedings{tekin18,
      TITLE = {{Real-Time Seamless Single Shot 6D Object Pose Prediction}},
      TITLE = {{Real-Time Seamless Single Shot 6D Object Pose Prediction}},
      AUTHOR = {Tekin, Bugra and Sinha, Sudipta N. and Fua, Pascal},
      BOOKTITLE = {CVPR},
      YEAR = {2018}
@ -25,7 +27,7 @@ SingleShotPose is released under the MIT License (refer to the LICENSE file for
#### Environment and dependencies
The code is tested on Linux with CUDA v8 and cudNN v5.1. The implementation is based on PyTorch 0.3.1 and tested on Python2.7. The code requires the following dependencies that could be installed with conda or pip: numpy, scipy, PIL, opencv-python. For a version that is Python 3 and Pytorch 0.4 compatible, you could see [this link](https://github.com/btekin/singleshot6Dpose).
The code is tested on **Windows** with CUDA v8 and cudNN v5.1. The implementation is based on **PyTorch 0.4.1** and tested on **Python3.6**. The code requires the following dependencies that could be installed with conda or pip: numpy, scipy, PIL, opencv-python. For an earlier version that is compatible with PyTorch 0.3.1 and tested on Python2.7, please see ```py2``` folder.
#### Downloading and preparing the data
@ -41,25 +43,30 @@ tar xf backup.tar
tar xf multi_obj_pose_estimation/backup_multi.tar -C multi_obj_pose_estimation/
tar xf VOCtrainval_11-May-2012.tar
```
Alternatively, you can directly go to the links above and manually download and extract the files at the corresponding directories. The whole download process might take a long while (~60 minutes).
Alternatively, you can directly go to the links above and manually download and extract the files at the corresponding directories. The whole download process might take a long while (~60 minutes). Please also be aware that access to OneDrive in some countries might be limited.
#### Training the model
To train the model run,
```
python train.py datafile cfgfile initweightfile
python train.py --datacfg [path_to_data_config_file] --modelcfg [path_to_model_config_file] --initweightfile [path_to_initialization_weights] --pretrain_num_epochs [number_of_epochs to pretrain]
```
e.g.
```
python train.py cfg/ape.data cfg/yolo-pose.cfg backup/ape/init.weights
python train.py --datacfg cfg/ape.data --modelcfg cfg/yolo-pose.cfg --initweightfile cfg/darknet19_448.conv.23 --pretrain_num_epochs 15
```
if you would like to start from ImageNet initialized weights, or
```
python train.py --datacfg cfg/ape.data --modelcfg cfg/yolo-pose.cfg --initweightfile backup/duck/init.weights
```
if you would like to start with an already pretrained model on LINEMOD, for faster convergence.
[datafile] contains information about the training/test splits and 3D object models
**[datacfg]** contains information about the training/test splits, 3D object models and camera parameters
[cfgfile] contains information about the network structure
**[modelcfg]** contains information about the network structure
[initweightfile] contains initialization weights. The weights "backup/[OBJECT_NAME]/init.weights" are pretrained on LINEMOD for faster convergence. We found it effective to pretrain the model without confidence estimation first and fine-tune the network later on with confidence estimation as well. "init.weights" contain the weights of these pretrained networks. However, you can also still train the network from a more crude initialization (with weights trained on ImageNet). This usually results in a slower and sometimes slightly worse convergence. You can find in cfg/ folder the file <<darknet19_448.conv.23>> that includes the network weights pretrained on ImageNet. Alternatively, you can pretrain your own weights by setting the regularization parameter for the confidence loss to 0 as explained in "Pretraining the model" section.
**[initweightfile]** contains initialization weights. <<darknet19_448.conv.23>> contains the network weights pretrained on ImageNet. The weights "backup/[OBJECT_NAME]/init.weights" are pretrained on LINEMOD for faster convergence. We found it effective to pretrain the model without confidence estimation first and fine-tune the network later on with confidence estimation as well. "init.weights" contain the weights of these pretrained networks. However, you can also still train the network from a more crude initialization (with weights trained on ImageNet). This usually results in a slower and sometimes slightly worse convergence. You can find in cfg/ folder the file <<darknet19_448.conv.23>> that includes the network weights pretrained on ImageNet.
At the start of the training you will see an output like this:
@ -74,30 +81,21 @@ layer filters size input output
31 detection
```
This defines the network structure. During training, the best network model is saved into the "model.weights" file. To train networks for other objects, just change the object name while calling the train function, e.g., "python train.py cfg/duck.data cfg/yolo-pose.cfg backup/duck/init.weights"
This defines the network structure. During training, the best network model is saved into the "model.weights" file. To train networks for other objects, just change the object name while calling the train function, e.g., "```python train.py --datacfg cfg/duck.data --modelcfg cfg/yolo-pose.cfg --initweightfile backup/duck/init.weights```". If you come across GPU memory errors while training, you could try lowering the batch size, to for example 16 or 8, to fit into the memory. The open source version of the code has undergone strong refactoring and furthermore some models had to be retrained. The retrained models that we provide do not change much from the initial results that we provide (sometimes slight worse and sometimes slightly better).
#### Testing the model
To test the model run
```
python valid.py datafile cfgfile weightfile
python valid.py --datacfg [path_to_data_config_file] --modelcfg [path_to_model_config_file] --weightfile [path_to_trained_model_weights]
```
e.g.
```
python valid.py cfg/ape.data cfg/yolo-pose.cfg backup/ape/model_backup.weights
python valid.py --datacfg cfg/ape.data --modelcfg cfg/yolo-pose.cfg --weightfile backup/ape/model_backup.weights
```
[weightfile] contains our trained models.
#### Pretraining the model (Optional)
Models are already pretrained but if you would like to pretrain the network from scratch and get the initialization weights yourself, you can run the following:
python train.py cfg/ape.data cfg/yolo-pose-pre.cfg cfg/darknet19_448.conv.23
cp backup/ape/model.weights backup/ape/init.weights
During pretraining the regularization parameter for the confidence term is set to "0" in the config file "cfg/yolo-pose-pre.cfg". "darknet19_448.conv.23" includes the weights of YOLOv2 trained on ImageNet.
You could also use valid.ipynb to test the model and visualize the results.
#### Multi-object pose estimation on the OCCLUSION dataset
@ -125,23 +123,30 @@ python train_multi.py cfg/occlusion.data cfg/yolo-pose-multi.cfg backup_multi/in
#### Label files
Our label files consist of 21 values. We predict 9 points corresponding to the centroid and corners of the 3D object model. Additionally we predict the class in each cell. That makes 9x2+1 = 19 points. In multi-object training, during training, we assign whichever anchor box has the most similar size to the current object as the responsible one to predict the 2D coordinates for that object. To encode the size of the objects, we have additional 2 numbers for the range in x dimension and y dimension. Therefore, we have 9x2+1+2 = 21 numbers.
Our label files consist of 21 ground-truth values. We predict 9 points corresponding to the centroid and corners of the 3D object model. Additionally we predict the class in each cell. That makes 9x2+1 = 19 points. In multi-object training, during training, we assign whichever anchor box has the most similar size to the current object as the responsible one to predict the 2D coordinates for that object. To encode the size of the objects, we have additional 2 numbers for the range in x dimension and y dimension. Therefore, we have 9x2+1+2 = 21 numbers.
Respectively, 21 numbers correspond to the following: 1st number: class label, 2nd number: x0 (x-coordinate of the centroid), 3rd number: y0 (y-coordinate of the centroid), 4th number: x1 (x-coordinate of the first corner), 5th number: y1 (y-coordinate of the first corner), ..., 18th number: x8 (x-coordinate of the eighth corner), 19th number: y8 (y-coordinate of the eighth corner), 20th number: x range, 21st number: y range.
The coordinates are normalized by the image width and height: x / image_width and y / image_height. This is useful to have similar output ranges for the coordinate regression and object classification tasks.
The coordinates are normalized by the image width and height: ```x / image_width``` and ```y / image_height```. This is useful to have similar output ranges for the coordinate regression and object classification tasks.
#### Training on your own dataset
#### Tips for training on your own dataset
To train on your own dataset, simply create the same folder structure with the provided LINEMOD dataset and adjust the paths in cfg/[OBJECT].data, [DATASET]/[OBJECT]/train.txt and [DATASET]/[OBJECT]/test.txt files. The folder for each object should contain the following:
We train and test our models on the LINEMOD dataset using the same train/test splits with [the BB8 method](https://arxiv.org/pdf/1703.10896.pdf) to validate our approach. If you would like to train a model on your own dataset, you could create the same folder structure with the provided LINEMOD dataset and adjust the paths in cfg/[OBJECT].data, [DATASET]/[OBJECT]/train.txt and [DATASET]/[OBJECT]/test.txt files. The folder for each object should contain the following:
(1) a folder containing image files,
(2) a folder containing label files (please refer to [this link](https://github.com/Microsoft/singleshotpose/blob/master/label_file_creation.md) for a detailed explanation on how to create labels),
(3) a text file containing the filenames for training images (train.txt),
(4) a text file containing the filenames for test images (test.txt),
(5) a .ply file containing the 3D object model
(6) optionally, a folder containing segmentation masks (if you want to change the background of your training images to be more robust to diverse backgrounds),
(2) a folder containing label files (Please refer to [this link](https://github.com/Microsoft/singleshotpose/blob/master/label_file_creation.md) for a detailed explanation on how to create labels. You could also find third-party [ObjectDatasetTools](https://github.com/F2Wang/ObjectDatasetTools) toolbox useful to create ground-truth labels for 6D object pose estimation),
(3) a text file containing the filenames for training images (```train.txt```),
(4) a text file containing the filenames for test images (```test.txt```),
(5) a .ply file containing the 3D object model (The unit of the object model is given in meters),
(6) optionally, a folder containing segmentation masks (If you want to change the background of your training images to be more robust to diverse backgrounds, this would be essential for a better generalization ability),
Please also make sure to adjust the following values in the data and model configuration files according to your needs:
- You should change the "```diam```" value in the data configuration file with the diameter of the object model at hand.
- Depending on the size and variability of your training data, the learning rate schedule (steps, scales, max_epochs parameters in the ```yolo-pose.cfg``` file) and some data augmentation parameters (jitter, hue, saturation, exposure parameters in ```dataset.py```) might also need to be adjusted for a better convergence on your dataset.
- For multiple object pose estimation, you should also pre-compute anchor values using the procedure described in Section 3.2 of the paper and specify it in the model configuration file (```yolo-pose-multi.cfg```). Please also make sure to use correct number of classes and specify it in ```yolo-pose-multi.cfg```.
- You should further change the image size and camera parameters (```fx```, ```fy```, ```u0```, ```v0```, ```width```, ```height```) in the data configuration files with the ones specific to your dataset.
While creating a training dataset, sampling a large number of viewpoints/distances and modeling a large variability of illumination/occlusion/background settings would be important in increasing the generalization ability of the approach on your dataset. If you would like to adjust some model & loss parameters (e.g. weighthing factor for different loss terms) for your own purposes, you could do so in the model configuration file (```yolo-pose.cfg```).
#### Acknowledgments

20
cfg.py
Просмотреть файл

@ -53,9 +53,9 @@ def print_cfg(blocks):
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
width = (prev_width + 2*pad - kernel_size)/stride + 1
height = (prev_height + 2*pad - kernel_size)/stride + 1
pad = (kernel_size-1)//2 if is_pad else 0
width = (prev_width + 2*pad - kernel_size)//stride + 1
height = (prev_height + 2*pad - kernel_size)//stride + 1
print('%5d %-6s %4d %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'conv', filters, kernel_size, kernel_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
@ -66,8 +66,8 @@ def print_cfg(blocks):
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
width = prev_width/stride
height = prev_height/stride
width = prev_width//stride
height = prev_height//stride
print('%5d %-6s %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'max', pool_size, pool_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
@ -98,8 +98,8 @@ def print_cfg(blocks):
elif block['type'] == 'reorg':
stride = int(block['stride'])
filters = stride * stride * prev_filters
width = prev_width/stride
height = prev_height/stride
width = prev_width//stride
height = prev_height//stride
print('%5d %-6s / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'reorg', stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
@ -154,7 +154,7 @@ def load_conv(buf, start, conv_model):
num_w = conv_model.weight.numel()
num_b = conv_model.bias.numel()
conv_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w]).view_as(conv_model.weight.data)); start = start + num_w
return start
def save_conv(fp, conv_model):
@ -172,7 +172,7 @@ def load_conv_bn(buf, start, conv_model, bn_model):
bn_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_mean.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_var.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w]).view_as(conv_model.weight.data)); start = start + num_w
return start
def save_conv_bn(fp, conv_model, bn_model):
@ -193,7 +193,7 @@ def load_fc(buf, start, fc_model):
num_w = fc_model.weight.numel()
num_b = fc_model.bias.numel()
fc_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
fc_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
fc_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w]).view_as(fc_model.weight.data)); start = start + num_w
return start
def save_fc(fp, fc_model):

Просмотреть файл

@ -28,10 +28,10 @@ class Reorg(nn.Module):
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view(B, C, H/hs, hs, W/ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H/hs*W/ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H/hs, W/ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H/hs, W/ws)
x = x.view(B, C, H//hs, hs, W//ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H//hs*W//ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H//hs, W//ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H//hs, W//ws)
return x
class GlobalAvgPool2d(nn.Module):
@ -63,8 +63,11 @@ class Darknet(nn.Module):
self.models = self.create_network(self.blocks) # merge conv, bn,leaky
self.loss = self.models[len(self.models)-1]
self.width = int(self.blocks[0]['width'])
self.height = int(self.blocks[0]['height'])
self.width = int(self.blocks[0]['width'])
self.height = int(self.blocks[0]['height'])
self.test_width = int(self.blocks[0]['test_width'])
self.test_height = int(self.blocks[0]['test_height'])
self.num_keypoints = int(self.blocks[0]['num_keypoints'])
if self.blocks[(len(self.blocks)-1)]['type'] == 'region':
self.anchors = self.loss.anchors
@ -146,7 +149,7 @@ class Darknet(nn.Module):
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
pad = (kernel_size-1)//2 if is_pad else 0
activation = block['activation']
model = nn.Sequential()
if batch_normalize:
@ -233,7 +236,7 @@ class Darknet(nn.Module):
loss.anchors = [float(i) for i in anchors]
loss.num_classes = int(block['classes'])
loss.num_anchors = int(block['num'])
loss.anchor_step = len(loss.anchors)/loss.num_anchors
loss.anchor_step = len(loss.anchors)//loss.num_anchors
loss.object_scale = float(block['object_scale'])
loss.noobject_scale = float(block['noobject_scale'])
loss.class_scale = float(block['class_scale'])

Просмотреть файл

@ -13,13 +13,28 @@ from utils import read_truths_args, read_truths, get_all_files
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, bg_file_names=None):
def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, cell_size=32, bg_file_names=None, num_keypoints=9, max_num_gt=50):
# root : list of training or test images
# shape : shape of the image input to the network
# shuffle : whether to shuffle or not
# tranform : any pytorch-specific transformation to the input image
# target_transform : any pytorch-specific tranformation to the target output
# train : whether it is training data or test data
# seen : the number of visited examples (iteration of the batch x batch size) # TODO: check if this is correctly assigned
# batch_size : how many examples there are in the batch
# num_workers : check what this is
# bg_file_names : the filenames for images from which you assign random backgrounds
# read the the list of dataset images
with open(root, 'r') as file:
self.lines = file.readlines()
# Shuffle
if shuffle:
random.shuffle(self.lines)
# Initialize variables
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
@ -29,40 +44,53 @@ class listDataset(Dataset):
self.batch_size = batch_size
self.num_workers = num_workers
self.bg_file_names = bg_file_names
self.cell_size = cell_size
self.nbatches = self.nSamples // self.batch_size
self.num_keypoints = num_keypoints
self.max_num_gt = max_num_gt # maximum number of ground-truth labels an image can have
# Get the number of samples in the dataset
def __len__(self):
return self.nSamples
# Get a sample from the dataset
def __getitem__(self, index):
# Ensure the index is smallet than the number of samples in the dataset, otherwise return error
assert index <= len(self), 'index range error'
# Get the image path
imgpath = self.lines[index].rstrip()
if self.train and index % 32== 0:
if self.seen < 400*32:
width = 13*32
# Decide which size you are going to resize the image depending on the epoch (10, 20, etc.)
if self.train and index % self.batch_size== 0:
if self.seen < 10*self.nbatches*self.batch_size:
width = 13*self.cell_size
self.shape = (width, width)
elif self.seen < 800*32:
width = (random.randint(0,7) + 13)*32
elif self.seen < 20*self.nbatches*self.batch_size:
width = (random.randint(0,7) + 13)*self.cell_size
self.shape = (width, width)
elif self.seen < 1200*32:
width = (random.randint(0,9) + 12)*32
elif self.seen < 30*self.nbatches*self.batch_size:
width = (random.randint(0,9) + 12)*self.cell_size
self.shape = (width, width)
elif self.seen < 1600*32:
width = (random.randint(0,11) + 11)*32
elif self.seen < 40*self.nbatches*self.batch_size:
width = (random.randint(0,11) + 11)*self.cell_size
self.shape = (width, width)
elif self.seen < 2000*32:
width = (random.randint(0,13) + 10)*32
elif self.seen < 50*self.nbatches*self.batch_size:
width = (random.randint(0,13) + 10)*self.cell_size
self.shape = (width, width)
elif self.seen < 2400*32:
width = (random.randint(0,15) + 9)*32
elif self.seen < 60*self.nbatches*self.batch_size:
width = (random.randint(0,15) + 9)*self.cell_size
self.shape = (width, width)
elif self.seen < 3000*32:
width = (random.randint(0,17) + 8)*32
elif self.seen < 70*self.nbatches*self.batch_size:
width = (random.randint(0,17) + 8)*self.cell_size
self.shape = (width, width)
else: # self.seen < 20000*64:
width = (random.randint(0,19) + 7)*32
else:
width = (random.randint(0,19) + 7)*self.cell_size
self.shape = (width, width)
if self.train:
# Decide on how much data augmentation you are going to apply
jitter = 0.2
hue = 0.1
saturation = 1.5
@ -70,32 +98,44 @@ class listDataset(Dataset):
# Get background image path
random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
bgpath = self.bg_file_names[random_bg_index]
bgpath = self.bg_file_names[random_bg_index]
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
# Get the data augmented image and their corresponding labels
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath, self.num_keypoints, self.max_num_gt)
# Convert the labels to PyTorch variables
label = torch.from_numpy(label)
else:
# Get the validation image, resize it to the network input size
img = Image.open(imgpath).convert('RGB')
if self.shape:
img = img.resize(self.shape)
# Read the validation labels, allow upto 50 ground-truth objects in an image
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*21)
num_labels = 2*self.num_keypoints+3 # +2 for ground-truth of width/height , +1 for class label
label = torch.zeros(self.max_num_gt*num_labels)
if os.path.getsize(labpath):
ow, oh = img.size
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
tmp = torch.from_numpy(read_truths_args(labpath))
tmp = tmp.view(-1)
tsz = tmp.numel()
if tsz > 50*21:
label = tmp[0:50*21]
if tsz > self.max_num_gt*num_labels:
label = tmp[0:self.max_num_gt*num_labels]
elif tsz > 0:
label[0:tsz] = tmp
# Tranform the image data to PyTorch tensors
if self.transform is not None:
img = self.transform(img)
# If there is any PyTorch-specific transformation, transform the label data
if self.target_transform is not None:
label = self.target_transform(label)
# Increase the number of seen examples
self.seen = self.seen + self.num_workers
# Return the retrieved image and its corresponding label
return (img, label)

Просмотреть файл

@ -73,73 +73,32 @@ def data_augmentation(img, shape, jitter, hue, saturation, exposure):
return img, flip, dx,dy,sx,sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
max_boxes = 50
label = np.zeros((max_boxes,21))
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy, num_keypoints, max_num_gt):
num_labels = 2 * num_keypoints + 3
label = np.zeros((max_num_gt,num_labels))
if os.path.getsize(labpath):
bs = np.loadtxt(labpath)
if bs is None:
return label
bs = np.reshape(bs, (-1, 21))
bs = np.reshape(bs, (-1, num_labels))
cc = 0
for i in range(bs.shape[0]):
x0 = bs[i][1]
y0 = bs[i][2]
x1 = bs[i][3]
y1 = bs[i][4]
x2 = bs[i][5]
y2 = bs[i][6]
x3 = bs[i][7]
y3 = bs[i][8]
x4 = bs[i][9]
y4 = bs[i][10]
x5 = bs[i][11]
y5 = bs[i][12]
x6 = bs[i][13]
y6 = bs[i][14]
x7 = bs[i][15]
y7 = bs[i][16]
x8 = bs[i][17]
y8 = bs[i][18]
xs = list()
ys = list()
for j in range(num_keypoints):
xs.append(bs[i][2*j+1])
ys.append(bs[i][2*j+2])
x0 = min(0.999, max(0, x0 * sx - dx))
y0 = min(0.999, max(0, y0 * sy - dy))
x1 = min(0.999, max(0, x1 * sx - dx))
y1 = min(0.999, max(0, y1 * sy - dy))
x2 = min(0.999, max(0, x2 * sx - dx))
y2 = min(0.999, max(0, y2 * sy - dy))
x3 = min(0.999, max(0, x3 * sx - dx))
y3 = min(0.999, max(0, y3 * sy - dy))
x4 = min(0.999, max(0, x4 * sx - dx))
y4 = min(0.999, max(0, y4 * sy - dy))
x5 = min(0.999, max(0, x5 * sx - dx))
y5 = min(0.999, max(0, y5 * sy - dy))
x6 = min(0.999, max(0, x6 * sx - dx))
y6 = min(0.999, max(0, y6 * sy - dy))
x7 = min(0.999, max(0, x7 * sx - dx))
y7 = min(0.999, max(0, y7 * sy - dy))
x8 = min(0.999, max(0, x8 * sx - dx))
y8 = min(0.999, max(0, y8 * sy - dy))
bs[i][1] = x0
bs[i][2] = y0
bs[i][3] = x1
bs[i][4] = y1
bs[i][5] = x2
bs[i][6] = y2
bs[i][7] = x3
bs[i][8] = y3
bs[i][9] = x4
bs[i][10] = y4
bs[i][11] = x5
bs[i][12] = y5
bs[i][13] = x6
bs[i][14] = y6
bs[i][15] = x7
bs[i][16] = y7
bs[i][17] = x8
bs[i][18] = y8
# Make sure the centroid of the object/hand is within image
xs[0] = min(0.999, max(0, xs[0] * sx - dx))
ys[0] = min(0.999, max(0, ys[0] * sy - dy))
for j in range(1,num_keypoints):
xs[j] = xs[j] * sx - dx
ys[j] = ys[j] * sy - dy
for j in range(num_keypoints):
bs[i][2*j+1] = xs[j]
bs[i][2*j+2] = ys[j]
label[cc] = bs[i]
cc += 1
if cc >= 50:
@ -167,7 +126,7 @@ def change_background(img, mask, bg):
return out
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath, num_keypoints, max_num_gt):
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
@ -179,6 +138,6 @@ def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpat
img = change_background(img, mask, bg)
img,flip,dx,dy,sx,sy = data_augmentation(img, shape, jitter, hue, saturation, exposure)
ow, oh = img.size
label = fill_truth_detection(labpath, ow, oh, flip, dx, dy, 1./sx, 1./sy)
label = fill_truth_detection(labpath, ow, oh, flip, dx, dy, 1./sx, 1./sy, num_keypoints, max_num_gt)
return img,label

Просмотреть файл

@ -28,10 +28,10 @@ class Reorg(nn.Module):
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view(B, C, H/hs, hs, W/ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H/hs*W/ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H/hs, W/ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H/hs, W/ws)
x = x.view(B, C, H//hs, hs, W//ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H//hs*W//ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H//hs, W//ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H//hs, W//ws)
return x
class GlobalAvgPool2d(nn.Module):
@ -146,7 +146,7 @@ class Darknet(nn.Module):
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
pad = (kernel_size-1)//2 if is_pad else 0
activation = block['activation']
model = nn.Sequential()
if batch_normalize:
@ -230,7 +230,7 @@ class Darknet(nn.Module):
loss.anchors = [float(i) for i in anchors]
loss.num_classes = int(block['classes'])
loss.num_anchors = int(block['num'])
loss.anchor_step = len(loss.anchors)/loss.num_anchors
loss.anchor_step = len(loss.anchors)//loss.num_anchors
loss.object_scale = float(block['object_scale'])
loss.noobject_scale = float(block['noobject_scale'])
loss.class_scale = float(block['class_scale'])

Просмотреть файл

@ -8,12 +8,12 @@ import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from utils import read_truths_args, read_truths, get_all_files
from utils_multi import read_truths_args, read_truths, get_all_files
from image_multi import *
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, transform=None, objclass=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, bg_file_names=None): # bg='/cvlabdata1/home/btekin/ope/data/office_bg'
def __init__(self, root, shape=None, shuffle=True, transform=None, objclass=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, cell_size=32, bg_file_names=None, num_keypoints=9, max_num_gt=50):
with open(root, 'r') as file:
self.lines = file.readlines()
if shuffle:
@ -26,9 +26,12 @@ class listDataset(Dataset):
self.seen = seen
self.batch_size = batch_size
self.num_workers = num_workers
# self.bg_file_names = get_all_files(bg)
self.bg_file_names = bg_file_names
self.objclass = objclass
self.cell_size = cell_size
self.nbatches = self.nSamples // self.batch_size
self.num_keypoints = num_keypoints
self.max_num_gt = max_num_gt # maximum number of ground-truth labels an image can have
def __len__(self):
return self.nSamples
@ -37,25 +40,25 @@ class listDataset(Dataset):
assert index <= len(self), 'index range error'
imgpath = self.lines[index].rstrip()
if self.train and index % 64== 0:
if self.seen < 4000*64:
width = 13*32
if self.train and index % self.batch_size == 0:
if self.seen < 20*self.nbatches*self.batch_size:
width = 13*self.cell_size
self.shape = (width, width)
elif self.seen < 8000*64:
width = (random.randint(0,3) + 13)*32
elif self.seen < 40*self.nbatches*self.batch_size:
width = (random.randint(0,3) + 13)*self.cell_size
self.shape = (width, width)
elif self.seen < 12000*64:
width = (random.randint(0,5) + 12)*32
elif self.seen < 60*self.nbatches*self.batch_size:
width = (random.randint(0,5) + 12)*self.cell_size
self.shape = (width, width)
elif self.seen < 16000*64:
width = (random.randint(0,7) + 11)*32
elif self.seen < 80*self.nbatches*self.batch_size:
width = (random.randint(0,7) + 11)*self.cell_size
self.shape = (width, width)
else: # self.seen < 20000*64:
width = (random.randint(0,9) + 10)*32
else:
width = (random.randint(0,9) + 10)*self.cell_size
self.shape = (width, width)
if self.train:
# jitter = 0.2
# Decide on how much data augmentation you are going to apply
jitter = 0.1
hue = 0.05
saturation = 1.5
@ -65,7 +68,7 @@ class listDataset(Dataset):
random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
bgpath = self.bg_file_names[random_bg_index]
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath, self.num_keypoints, self.max_num_gt)
label = torch.from_numpy(label)
else:
img = Image.open(imgpath).convert('RGB')
@ -73,14 +76,15 @@ class listDataset(Dataset):
img = img.resize(self.shape)
labpath = imgpath.replace('benchvise', self.objclass).replace('images', 'labels_occlusion').replace('JPEGImages', 'labels_occlusion').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*21)
num_labels = 2*self.num_keypoints+3 # +2 for ground-truth of width/height , +1 for class label
label = torch.zeros(self.max_num_gt*num_labels)
if os.path.getsize(labpath):
ow, oh = img.size
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
tmp = torch.from_numpy(read_truths_args(labpath))
tmp = tmp.view(-1)
tsz = tmp.numel()
if tsz > 50*21:
label = tmp[0:50*21]
if tsz > self.max_num_gt*num_labels:
label = tmp[0:self.max_num_gt*num_labels]
elif tsz > 0:
label[0:tsz] = tmp

Просмотреть файл

@ -5,20 +5,6 @@ import os
from PIL import Image, ImageChops, ImageMath
import numpy as np
def load_data_detection_backup(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
## data augmentation
img = Image.open(imgpath).convert('RGB')
mask = Image.open(maskpath).convert('RGB')
bg = Image.open(bgpath).convert('RGB')
img = change_background(img, mask, bg)
img,flip,dx,dy,sx,sy = data_augmentation(img, shape, jitter, hue, saturation, exposure)
label = fill_truth_detection(labpath, img.width, img.height, flip, dx, dy, 1./sx, 1./sy)
return img,label
def get_add_objs(objname):
# Decide how many additional objects you will augment and what will be the other types of objects
if objname == 'ape':
@ -87,7 +73,6 @@ def distort_image(im, hue, sat, val):
im = Image.merge(im.mode, tuple(cs))
im = im.convert('RGB')
#constrain_image(im)
return im
def rand_scale(s):
@ -135,98 +120,45 @@ def data_augmentation(img, shape, jitter, hue, saturation, exposure):
return img, flip, dx,dy,sx,sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
max_boxes = 50
label = np.zeros((max_boxes,21))
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy, num_keypoints, max_num_gt):
num_labels = 2*num_keypoints+3 # +2 for width, height, +1 for class label
label = np.zeros((max_num_gt,num_labels))
if os.path.getsize(labpath):
bs = np.loadtxt(labpath)
if bs is None:
return label
bs = np.reshape(bs, (-1, 21))
bs = np.reshape(bs, (-1, num_labels))
cc = 0
for i in range(bs.shape[0]):
x0 = bs[i][1]
y0 = bs[i][2]
x1 = bs[i][3]
y1 = bs[i][4]
x2 = bs[i][5]
y2 = bs[i][6]
x3 = bs[i][7]
y3 = bs[i][8]
x4 = bs[i][9]
y4 = bs[i][10]
x5 = bs[i][11]
y5 = bs[i][12]
x6 = bs[i][13]
y6 = bs[i][14]
x7 = bs[i][15]
y7 = bs[i][16]
x8 = bs[i][17]
y8 = bs[i][18]
xs = list()
ys = list()
for j in range(num_keypoints):
xs.append(bs[i][2*j+1])
ys.append(bs[i][2*j+2])
x0 = min(0.999, max(0, x0 * sx - dx))
y0 = min(0.999, max(0, y0 * sy - dy))
x1 = min(0.999, max(0, x1 * sx - dx))
y1 = min(0.999, max(0, y1 * sy - dy))
x2 = min(0.999, max(0, x2 * sx - dx))
y2 = min(0.999, max(0, y2 * sy - dy))
x3 = min(0.999, max(0, x3 * sx - dx))
y3 = min(0.999, max(0, y3 * sy - dy))
x4 = min(0.999, max(0, x4 * sx - dx))
y4 = min(0.999, max(0, y4 * sy - dy))
x5 = min(0.999, max(0, x5 * sx - dx))
y5 = min(0.999, max(0, y5 * sy - dy))
x6 = min(0.999, max(0, x6 * sx - dx))
y6 = min(0.999, max(0, y6 * sy - dy))
x7 = min(0.999, max(0, x7 * sx - dx))
y7 = min(0.999, max(0, y7 * sy - dy))
x8 = min(0.999, max(0, x8 * sx - dx))
y8 = min(0.999, max(0, y8 * sy - dy))
# Make sure the centroid of the object/hand is within image
xs[0] = min(0.999, max(0, xs[0] * sx - dx))
ys[0] = min(0.999, max(0, ys[0] * sy - dy))
for j in range(1,num_keypoints):
xs[j] = xs[j] * sx - dx
ys[j] = ys[j] * sy - dy
bs[i][0] = bs[i][0]
bs[i][1] = x0
bs[i][2] = y0
bs[i][3] = x1
bs[i][4] = y1
bs[i][5] = x2
bs[i][6] = y2
bs[i][7] = x3
bs[i][8] = y3
bs[i][9] = x4
bs[i][10] = y4
bs[i][11] = x5
bs[i][12] = y5
bs[i][13] = x6
bs[i][14] = y6
bs[i][15] = x7
bs[i][16] = y7
bs[i][17] = x8
bs[i][18] = y8
for j in range(num_keypoints):
bs[i][2*j+1] = xs[j]
bs[i][2*j+2] = ys[j]
xs = [x1, x2, x3, x4, x5, x6, x7, x8]
ys = [y1, y2, y3, y4, y5, y6, y7, y8]
min_x = min(xs);
max_x = max(xs);
min_y = min(ys);
max_y = max(ys);
bs[i][19] = max_x - min_x;
bs[i][20] = max_y - min_y;
if flip:
bs[i][1] = 0.999 - bs[i][1]
bs[i][3] = 0.999 - bs[i][3]
bs[i][5] = 0.999 - bs[i][5]
bs[i][7] = 0.999 - bs[i][7]
bs[i][9] = 0.999 - bs[i][9]
bs[i][11] = 0.999 - bs[i][11]
bs[i][13] = 0.999 - bs[i][13]
bs[i][15] = 0.999 - bs[i][15]
bs[i][17] = 0.999 - bs[i][17]
bs[i][2*num_keypoints+1] = max_x - min_x;
bs[i][2*num_keypoints+2] = max_y - min_y;
label[cc] = bs[i]
cc += 1
if cc >= 50:
if cc >= max_num_gt:
break
label = np.reshape(label, (-1))
@ -364,10 +296,10 @@ def superimpose_masks(mask, total_mask):
return out
def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure):
def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure, num_keypoints, max_num_gt):
pixelThreshold = 200
num_labels = 2*num_keypoints+3
random.shuffle(add_objs)
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
@ -377,8 +309,8 @@ def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation,
iw, ih = img.size
mask = Image.open(maskpath).convert('RGB')
img,mask,flip,dx,dy,sx,sy = shifted_data_augmentation_with_mask(img, mask, shape, jitter, hue, saturation, exposure)
label = fill_truth_detection(labpath, iw, ih, flip, dx, dy, 1./sx, 1./sy)
total_label = np.reshape(label, (-1, 21))
label = fill_truth_detection(labpath, iw, ih, flip, dx, dy, 1./sx, 1./sy, num_keypoints, max_num_gt)
total_label = np.reshape(label, (-1, num_labels))
# Mask the background
masked_img = mask_background(img, mask)
@ -406,7 +338,7 @@ def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation,
obj_rand_masked_img = mask_background(obj_rand_img, obj_rand_mask)
obj_rand_masked_img,obj_rand_mask,flip,dx,dy,sx,sy = data_augmentation_with_mask(obj_rand_masked_img, obj_rand_mask, shape, jitter, hue, saturation, exposure)
obj_rand_label = fill_truth_detection(obj_rand_lab_path, iw, ih, flip, dx, dy, 1./sx, 1./sy)
obj_rand_label = fill_truth_detection(obj_rand_lab_path, iw, ih, flip, dx, dy, 1./sx, 1./sy, num_keypoints, max_num_gt)
# compute intersection (ratio of the object part intersecting with other object parts over the area of the object)
xx = np.array(obj_rand_mask)
@ -422,7 +354,7 @@ def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation,
successful = True
total_mask = superimpose_masks(obj_rand_mask, total_mask) # total_mask + obj_rand_mask
total_masked_img = superimpose_masked_imgs(obj_rand_masked_img, obj_rand_mask, total_masked_img) # total_masked_img + obj_rand_masked_img
obj_rand_label = np.reshape(obj_rand_label, (-1, 21))
obj_rand_label = np.reshape(obj_rand_label, (-1, num_labels))
total_label[count, :] = obj_rand_label[0, :]
count = count + 1
else:
@ -432,7 +364,7 @@ def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation,
return total_masked_img, np.reshape(total_label, (-1)), total_mask
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath, num_keypoints, max_num_gt):
# Read the background image
bg = Image.open(bgpath).convert('RGB')
@ -441,10 +373,11 @@ def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpat
dirname = os.path.dirname(os.path.dirname(imgpath)) ## dir of dir of file
objname = os.path.basename(dirname)
add_objs = get_add_objs(objname)
num_labels = 2*num_keypoints+3
# Add additional objects in the scene, apply data augmentation on the objects
total_masked_img, label, total_mask = augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure)
total_masked_img, label, total_mask = augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure, num_keypoints, max_num_gt)
img = change_background(total_masked_img, total_mask, bg)
lb = np.reshape(label, (-1, 21))
lb = np.reshape(label, (-1, num_labels))
return img,label

Просмотреть файл

@ -4,178 +4,110 @@ import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *
from utils_multi import *
def build_targets(pred_corners, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
def build_targets(pred_corners, target, num_keypoints, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
nB = target.size(0)
nA = num_anchors
nC = num_classes
anchor_step = len(anchors)/num_anchors
anchor_step = len(anchors)//num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx0 = torch.zeros(nB, nA, nH, nW)
ty0 = torch.zeros(nB, nA, nH, nW)
tx1 = torch.zeros(nB, nA, nH, nW)
ty1 = torch.zeros(nB, nA, nH, nW)
tx2 = torch.zeros(nB, nA, nH, nW)
ty2 = torch.zeros(nB, nA, nH, nW)
tx3 = torch.zeros(nB, nA, nH, nW)
ty3 = torch.zeros(nB, nA, nH, nW)
tx4 = torch.zeros(nB, nA, nH, nW)
ty4 = torch.zeros(nB, nA, nH, nW)
tx5 = torch.zeros(nB, nA, nH, nW)
ty5 = torch.zeros(nB, nA, nH, nW)
tx6 = torch.zeros(nB, nA, nH, nW)
ty6 = torch.zeros(nB, nA, nH, nW)
tx7 = torch.zeros(nB, nA, nH, nW)
ty7 = torch.zeros(nB, nA, nH, nW)
tx8 = torch.zeros(nB, nA, nH, nW)
ty8 = torch.zeros(nB, nA, nH, nW)
txs = list()
tys = list()
for i in range(num_keypoints):
txs.append(torch.zeros(nB, nA, nH, nW))
tys.append(torch.zeros(nB, nA, nH, nW))
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
num_labels = 2 * num_keypoints + 3 # +2 for width, height and +1 for class within label files
nAnchors = nA*nH*nW
nPixels = nH*nW
for b in xrange(nB):
for b in range(nB):
cur_pred_corners = pred_corners[b*nAnchors:(b+1)*nAnchors].t()
cur_confs = torch.zeros(nAnchors)
for t in xrange(50):
if target[b][t*21+1] == 0:
for t in range(50):
if target[b][t*num_labels+1] == 0:
break
gx0 = target[b][t*21+1]*nW
gy0 = target[b][t*21+2]*nH
gx1 = target[b][t*21+3]*nW
gy1 = target[b][t*21+4]*nH
gx2 = target[b][t*21+5]*nW
gy2 = target[b][t*21+6]*nH
gx3 = target[b][t*21+7]*nW
gy3 = target[b][t*21+8]*nH
gx4 = target[b][t*21+9]*nW
gy4 = target[b][t*21+10]*nH
gx5 = target[b][t*21+11]*nW
gy5 = target[b][t*21+12]*nH
gx6 = target[b][t*21+13]*nW
gy6 = target[b][t*21+14]*nH
gx7 = target[b][t*21+15]*nW
gy7 = target[b][t*21+16]*nH
gx8 = target[b][t*21+17]*nW
gy8 = target[b][t*21+18]*nH
g = list()
for i in range(num_keypoints):
g.append(target[b][t*num_labels+2*i+1])
g.append(target[b][t*num_labels+2*i+2])
cur_gt_corners = torch.FloatTensor([gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences9(cur_pred_corners, cur_gt_corners)) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
cur_gt_corners = torch.FloatTensor(g).repeat(nAnchors,1).t() # 18 x nAnchors
cur_confs = torch.max(cur_confs.view_as(conf_mask[b]), corner_confidences(cur_pred_corners, cur_gt_corners).view_as(conf_mask[b])) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
conf_mask[b][cur_confs>sil_thresh] = 0
if seen < -1:#6400:
tx0.fill_(0.5)
ty0.fill_(0.5)
tx1.fill_(0.5)
ty1.fill_(0.5)
tx2.fill_(0.5)
ty2.fill_(0.5)
tx3.fill_(0.5)
ty3.fill_(0.5)
tx4.fill_(0.5)
ty4.fill_(0.5)
tx5.fill_(0.5)
ty5.fill_(0.5)
tx6.fill_(0.5)
ty6.fill_(0.5)
tx7.fill_(0.5)
ty7.fill_(0.5)
tx8.fill_(0.5)
ty8.fill_(0.5)
coord_mask.fill_(1)
nGT = 0
nCorrect = 0
for b in xrange(nB):
for t in xrange(50):
if target[b][t*21+1] == 0:
for b in range(nB):
for t in range(50):
if target[b][t*num_labels+1] == 0:
break
nGT = nGT + 1
best_iou = 0.0
best_n = -1
min_dist = 10000
gx0 = target[b][t*21+1] * nW
gy0 = target[b][t*21+2] * nH
gi0 = int(gx0)
gj0 = int(gy0)
gx1 = target[b][t*21+3] * nW
gy1 = target[b][t*21+4] * nH
gx2 = target[b][t*21+5] * nW
gy2 = target[b][t*21+6] * nH
gx3 = target[b][t*21+7] * nW
gy3 = target[b][t*21+8] * nH
gx4 = target[b][t*21+9] * nW
gy4 = target[b][t*21+10] * nH
gx5 = target[b][t*21+11] * nW
gy5 = target[b][t*21+12] * nH
gx6 = target[b][t*21+13] * nW
gy6 = target[b][t*21+14] * nH
gx7 = target[b][t*21+15] * nW
gy7 = target[b][t*21+16] * nH
gx8 = target[b][t*21+17] * nW
gy8 = target[b][t*21+18] * nH
min_dist = sys.maxsize
gx = list()
gy = list()
gt_box = list()
for i in range(num_keypoints):
gt_box.extend([target[b][t*num_labels+2*i+1], target[b][t*num_labels+2*i+2]])
gx.append(target[b][t*num_labels+2*i+1] * nW)
gy.append(target[b][t*num_labels+2*i+2] * nH)
if i == 0:
gi0 = int(gx[i])
gj0 = int(gy[i])
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence(gt_box, pred_box)
gw = target[b][t*21+19]*nW
gh = target[b][t*21+20]*nH
gt_box = [0, 0, gw, gh]
for n in xrange(nA):
# Decide which anchor to use during prediction
gw = target[b][t*num_labels+num_labels-2]*nW
gh = target[b][t*num_labels+num_labels-1]*nH
gt_2d_box = [0, 0, gw, gh]
for n in range(nA):
aw = anchors[anchor_step*n]
ah = anchors[anchor_step*n+1]
anchor_box = [0, 0, aw, ah]
iou = bbox_iou(anchor_box, gt_box, x1y1x2y2=False)
iou = bbox_iou(anchor_box, gt_2d_box, x1y1x2y2=False)
if iou > best_iou:
best_iou = iou
best_n = n
gt_box = [gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence9(gt_box, pred_box)
coord_mask[b][best_n][gj0][gi0] = 1
cls_mask[b][best_n][gj0][gi0] = 1
conf_mask[b][best_n][gj0][gi0] = object_scale
tx0[b][best_n][gj0][gi0] = target[b][t*21+1] * nW - gi0
ty0[b][best_n][gj0][gi0] = target[b][t*21+2] * nH - gj0
tx1[b][best_n][gj0][gi0] = target[b][t*21+3] * nW - gi0
ty1[b][best_n][gj0][gi0] = target[b][t*21+4] * nH - gj0
tx2[b][best_n][gj0][gi0] = target[b][t*21+5] * nW - gi0
ty2[b][best_n][gj0][gi0] = target[b][t*21+6] * nH - gj0
tx3[b][best_n][gj0][gi0] = target[b][t*21+7] * nW - gi0
ty3[b][best_n][gj0][gi0] = target[b][t*21+8] * nH - gj0
tx4[b][best_n][gj0][gi0] = target[b][t*21+9] * nW - gi0
ty4[b][best_n][gj0][gi0] = target[b][t*21+10] * nH - gj0
tx5[b][best_n][gj0][gi0] = target[b][t*21+11] * nW - gi0
ty5[b][best_n][gj0][gi0] = target[b][t*21+12] * nH - gj0
tx6[b][best_n][gj0][gi0] = target[b][t*21+13] * nW - gi0
ty6[b][best_n][gj0][gi0] = target[b][t*21+14] * nH - gj0
tx7[b][best_n][gj0][gi0] = target[b][t*21+15] * nW - gi0
ty7[b][best_n][gj0][gi0] = target[b][t*21+16] * nH - gj0
tx8[b][best_n][gj0][gi0] = target[b][t*21+17] * nW - gi0
ty8[b][best_n][gj0][gi0] = target[b][t*21+18] * nH - gj0
# Update targets
for i in range(num_keypoints):
txs[i][b][best_n][gj0][gi0] = gx[i]- gi0
tys[i][b][best_n][gj0][gi0] = gy[i]- gj0
tconf[b][best_n][gj0][gi0] = conf
tcls[b][best_n][gj0][gi0] = target[b][t*21]
tcls[b][best_n][gj0][gi0] = target[b][t*num_labels]
if conf > 0.5:
nCorrect = nCorrect + 1
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, txs, tys, tconf, tcls
class RegionLoss(nn.Module):
def __init__(self, num_classes=0, anchors=[], num_anchors=5):
def __init__(self, num_keypoints=9, num_classes=13, anchors=[], num_anchors=5, pretrain_num_epochs=15):
super(RegionLoss, self).__init__()
self.num_classes = num_classes
self.anchors = anchors
self.num_anchors = num_anchors
self.anchor_step = len(anchors)/num_anchors
self.num_keypoints = num_keypoints
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
self.pretrain_num_epochs = pretrain_num_epochs
def forward(self, output, target):
def forward(self, output, target, epoch):
# Parameters
t0 = time.time()
nB = output.data.size(0)
@ -185,81 +117,40 @@ class RegionLoss(nn.Module):
nW = output.data.size(3)
# Activation
output = output.view(nB, nA, (19+nC), nH, nW)
x0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
y0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
x1 = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
y1 = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
x2 = output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)
y2 = output.index_select(2, Variable(torch.cuda.LongTensor([5]))).view(nB, nA, nH, nW)
x3 = output.index_select(2, Variable(torch.cuda.LongTensor([6]))).view(nB, nA, nH, nW)
y3 = output.index_select(2, Variable(torch.cuda.LongTensor([7]))).view(nB, nA, nH, nW)
x4 = output.index_select(2, Variable(torch.cuda.LongTensor([8]))).view(nB, nA, nH, nW)
y4 = output.index_select(2, Variable(torch.cuda.LongTensor([9]))).view(nB, nA, nH, nW)
x5 = output.index_select(2, Variable(torch.cuda.LongTensor([10]))).view(nB, nA, nH, nW)
y5 = output.index_select(2, Variable(torch.cuda.LongTensor([11]))).view(nB, nA, nH, nW)
x6 = output.index_select(2, Variable(torch.cuda.LongTensor([12]))).view(nB, nA, nH, nW)
y6 = output.index_select(2, Variable(torch.cuda.LongTensor([13]))).view(nB, nA, nH, nW)
x7 = output.index_select(2, Variable(torch.cuda.LongTensor([14]))).view(nB, nA, nH, nW)
y7 = output.index_select(2, Variable(torch.cuda.LongTensor([15]))).view(nB, nA, nH, nW)
x8 = output.index_select(2, Variable(torch.cuda.LongTensor([16]))).view(nB, nA, nH, nW)
y8 = output.index_select(2, Variable(torch.cuda.LongTensor([17]))).view(nB, nA, nH, nW)
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([18]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(19,19+nC-1,nC).long().cuda()))
output = output.view(nB, nA, (2*self.num_keypoints+1+nC), nH, nW)
x = list()
y = list()
x.append(torch.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW)))
y.append(torch.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW)))
for i in range(1,self.num_keypoints):
x.append(output.index_select(2, Variable(torch.cuda.LongTensor([2 * i + 0]))).view(nB, nA, nH, nW))
y.append(output.index_select(2, Variable(torch.cuda.LongTensor([2 * i + 1]))).view(nB, nA, nH, nW))
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([2*self.num_keypoints]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(2*self.num_keypoints+1,2*self.num_keypoints+1+nC-1,nC).long().cuda()))
cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC)
t1 = time.time()
# Create pred boxes
pred_corners = torch.cuda.FloatTensor(18, nB*nA*nH*nW)
pred_corners = torch.cuda.FloatTensor(2*self.num_keypoints, nB*nA*nH*nW)
grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
pred_corners[0] = (x0.data + grid_x) / nW
pred_corners[1] = (y0.data + grid_y) / nH
pred_corners[2] = (x1.data + grid_x) / nW
pred_corners[3] = (y1.data + grid_y) / nH
pred_corners[4] = (x2.data + grid_x) / nW
pred_corners[5] = (y2.data + grid_y) / nH
pred_corners[6] = (x3.data + grid_x) / nW
pred_corners[7] = (y3.data + grid_y) / nH
pred_corners[8] = (x4.data + grid_x) / nW
pred_corners[9] = (y4.data + grid_y) / nH
pred_corners[10] = (x5.data + grid_x) / nW
pred_corners[11] = (y5.data + grid_y) / nH
pred_corners[12] = (x6.data + grid_x) / nW
pred_corners[13] = (y6.data + grid_y) / nH
pred_corners[14] = (x7.data + grid_x) / nW
pred_corners[15] = (y7.data + grid_y) / nH
pred_corners[16] = (x8.data + grid_x) / nW
pred_corners[17] = (y8.data + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,18)
for i in range(self.num_keypoints):
pred_corners[2 * i + 0] = (x[i].data.view_as(grid_x) + grid_x) / nW
pred_corners[2 * i + 1] = (y[i].data.view_as(grid_y) + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,2*self.num_keypoints)
pred_corners = convert2cpu(gpu_matrix)
t2 = time.time()
# Build targets
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls = \
build_targets(pred_corners, target.data, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
nGT, nCorrect, coord_mask, conf_mask, cls_mask, txs, tys, tconf, tcls = \
build_targets(pred_corners, target.data, self.num_keypoints, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
cls_mask = (cls_mask == 1)
nProposals = int((conf > 0.25).sum().data[0])
tx0 = Variable(tx0.cuda())
ty0 = Variable(ty0.cuda())
tx1 = Variable(tx1.cuda())
ty1 = Variable(ty1.cuda())
tx2 = Variable(tx2.cuda())
ty2 = Variable(ty2.cuda())
tx3 = Variable(tx3.cuda())
ty3 = Variable(ty3.cuda())
tx4 = Variable(tx4.cuda())
ty4 = Variable(ty4.cuda())
tx5 = Variable(tx5.cuda())
ty5 = Variable(ty5.cuda())
tx6 = Variable(tx6.cuda())
ty6 = Variable(ty6.cuda())
tx7 = Variable(tx7.cuda())
ty7 = Variable(ty7.cuda())
tx8 = Variable(tx8.cuda())
ty8 = Variable(ty8.cuda())
for i in range(self.num_keypoints):
txs[i] = Variable(txs[i].cuda())
tys[i] = Variable(tys[i].cuda())
tconf = Variable(tconf.cuda())
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
tcls = Variable(tcls[cls_mask].long().cuda())
coord_mask = Variable(coord_mask.cuda())
conf_mask = Variable(conf_mask.cuda().sqrt())
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda())
@ -267,35 +158,24 @@ class RegionLoss(nn.Module):
t3 = time.time()
# Create loss
loss_x0 = self.coord_scale * nn.MSELoss(size_average=False)(x0*coord_mask, tx0*coord_mask)/2.0
loss_y0 = self.coord_scale * nn.MSELoss(size_average=False)(y0*coord_mask, ty0*coord_mask)/2.0
loss_x1 = self.coord_scale * nn.MSELoss(size_average=False)(x1*coord_mask, tx1*coord_mask)/2.0
loss_y1 = self.coord_scale * nn.MSELoss(size_average=False)(y1*coord_mask, ty1*coord_mask)/2.0
loss_x2 = self.coord_scale * nn.MSELoss(size_average=False)(x2*coord_mask, tx2*coord_mask)/2.0
loss_y2 = self.coord_scale * nn.MSELoss(size_average=False)(y2*coord_mask, ty2*coord_mask)/2.0
loss_x3 = self.coord_scale * nn.MSELoss(size_average=False)(x3*coord_mask, tx3*coord_mask)/2.0
loss_y3 = self.coord_scale * nn.MSELoss(size_average=False)(y3*coord_mask, ty3*coord_mask)/2.0
loss_x4 = self.coord_scale * nn.MSELoss(size_average=False)(x4*coord_mask, tx4*coord_mask)/2.0
loss_y4 = self.coord_scale * nn.MSELoss(size_average=False)(y4*coord_mask, ty4*coord_mask)/2.0
loss_x5 = self.coord_scale * nn.MSELoss(size_average=False)(x5*coord_mask, tx5*coord_mask)/2.0
loss_y5 = self.coord_scale * nn.MSELoss(size_average=False)(y5*coord_mask, ty5*coord_mask)/2.0
loss_x6 = self.coord_scale * nn.MSELoss(size_average=False)(x6*coord_mask, tx6*coord_mask)/2.0
loss_y6 = self.coord_scale * nn.MSELoss(size_average=False)(y6*coord_mask, ty6*coord_mask)/2.0
loss_x7 = self.coord_scale * nn.MSELoss(size_average=False)(x7*coord_mask, tx7*coord_mask)/2.0
loss_y7 = self.coord_scale * nn.MSELoss(size_average=False)(y7*coord_mask, ty7*coord_mask)/2.0
loss_x8 = self.coord_scale * nn.MSELoss(size_average=False)(x8*coord_mask, tx8*coord_mask)/2.0
loss_y8 = self.coord_scale * nn.MSELoss(size_average=False)(y8*coord_mask, ty8*coord_mask)/2.0
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
loss_x = loss_x0 + loss_x1 + loss_x2 + loss_x3 + loss_x4 + loss_x5 + loss_x6 + loss_x7 + loss_x8
loss_y = loss_y0 + loss_y1 + loss_y2 + loss_y3 + loss_y4 + loss_y5 + loss_y6 + loss_y7 + loss_y8
loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
loss = loss_x + loss_y + loss_conf + loss_cls
print('%d: nGT %d, recall %d, proposals %d, loss: x0: %f x %f, y0: %f y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x0.data[0], loss_x.data[0], loss_y0.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
#else:
# loss = loss_x + loss_y + loss_conf
# print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
loss_xs = list()
loss_ys = list()
for i in range(self.num_keypoints):
loss_xs.append(self.coord_scale * nn.MSELoss(size_average=False)(x[i]*coord_mask, txs[i]*coord_mask)/2.0)
loss_ys.append(self.coord_scale * nn.MSELoss(size_average=False)(y[i]*coord_mask, tys[i]*coord_mask)/2.0)
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
loss_x = np.sum(loss_xs)
loss_y = np.sum(loss_ys)
loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
if epoch > self.pretrain_num_epochs:
loss = loss_x + loss_y + loss_cls + loss_conf # in single object pose estimation, there is no classification loss
else:
# pretrain initially without confidence loss
# once the coordinate predictions get better, start training for confidence as well
loss = loss_x + loss_y + loss_cls
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
t4 = time.time()
if False:

Просмотреть файл

@ -2,6 +2,7 @@ from __future__ import print_function
import os
os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import sys
import argparse
import time
import torch
import torch.nn as nn
@ -17,16 +18,10 @@ from torch.autograd import Variable # Useful info about autograd: http://pytorch
from darknet_multi import Darknet
from MeshPly import MeshPly
from utils import *
from utils_multi import *
from cfg import parse_cfg
import dataset_multi
from region_loss_multi import RegionLoss
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
import dataset_multi
# Adjust learning rate during training, learning schedule can be changed in network config file
def adjust_learning_rate(optimizer, batch):
@ -89,7 +84,7 @@ def train(epoch):
t6 = time.time()
region_loss.seen = region_loss.seen + data.data.size(0)
# Compute loss, grow an array of losses for saving later on
loss = region_loss(output, target)
loss = region_loss(output, target, epoch)
training_iters.append(epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter)
training_losses.append(convert2cpu(loss.data))
niter += 1
@ -125,7 +120,7 @@ def train(epoch):
t1 = time.time()
return epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter - 1
def eval(niter, datacfg, cfgfile):
def eval(niter, datacfg):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
@ -137,13 +132,12 @@ def eval(niter, datacfg, cfgfile):
meshname = options['mesh']
backupdir = options['backup']
name = options['name']
prefix = 'results'
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
internal_calibration = get_camera_intrinsic(u0, v0, fx, fy)
# Get validation file names
with open(valid_images) as fp:
@ -194,8 +188,9 @@ def eval(niter, datacfg, cfgfile):
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
trgt = target[0].view(-1, 21)
all_boxes = get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
trgt = target[0].view(-1, num_labels)
all_boxes = get_multi_region_boxes(output, conf_thresh, num_classes, num_keypoints, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
t4 = time.time()
# Iterate through all batch elements
@ -205,7 +200,7 @@ def eval(niter, datacfg, cfgfile):
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
truths = target[i].view(-1, num_labels)
# Get how many objects are present in the scene
num_gts = truths_length(truths)
@ -213,24 +208,23 @@ def eval(niter, datacfg, cfgfile):
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
box_gt = list()
for j in range(1, num_labels):
box_gt.append(truths[k][j])
box_gt.extend([1.0, 1.0])
box_gt.append(truths[k][0])
# If the prediction has the highest confidence, choose it as our prediction
best_conf_est = -sys.maxsize
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est) and (boxes[j][20] == int(truths[k][0])):
best_conf_est = boxes[j][18]
if (boxes[j][2*num_keypoints] > best_conf_est) and (boxes[j][2*num_keypoints+2] == int(truths[k][0])):
best_conf_est = boxes[j][2*num_keypoints]
box_pr = boxes[j]
bb2d_gt = get_2d_bb(box_gt[:18], output.size(3))
bb2d_pr = get_2d_bb(box_pr[:18], output.size(3))
iou = bbox_iou(bb2d_gt, bb2d_pr)
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
match = corner_confidence(box_gt[:2*num_keypoints], torch.FloatTensor(boxes[j][:2*num_keypoints]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt = np.array(np.reshape(box_gt[:2*num_keypoints], [num_keypoints, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:2*num_keypoints], [num_keypoints, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
@ -244,23 +238,24 @@ def eval(niter, datacfg, cfgfile):
R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration))
proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration))
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Sum errors
testing_error_pixel += pixel_dist
testing_samples += 1
testing_error_pixel += pixel_dist
testing_samples += 1
t5 = time.time()
# Compute 2D reprojection score
eps = 1e-5
for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
@ -281,72 +276,79 @@ def eval(niter, datacfg, cfgfile):
def test(niter):
cfgfile = 'cfg/yolo-pose-multi.cfg'
modelcfg = 'cfg/yolo-pose-multi.cfg'
datacfg = 'cfg/ape_occlusion.data'
logging("Testing ape...")
eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
datacfg = 'cfg/can_occlusion.data'
logging("Testing can...")
eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
datacfg = 'cfg/cat_occlusion.data'
logging("Testing cat...")
eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
datacfg = 'cfg/duck_occlusion.data'
logging("Testing duck...")
eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
datacfg = 'cfg/driller_occlusion.data'
logging("Testing driller...")
eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
datacfg = 'cfg/glue_occlusion.data'
logging("Testing glue...")
eval(niter, datacfg, cfgfile)
# datacfg = 'cfg/holepuncher_occlusion.data'
# logging("Testing holepuncher...")
# eval(niter, datacfg, cfgfile)
eval(niter, datacfg)
if __name__ == "__main__":
# Training settings
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
# Parse command window input
parser = argparse.ArgumentParser(description='SingleShotPose')
parser.add_argument('--datacfg', type=str, default='cfg/occlusion.data') # data config
parser.add_argument('--modelcfg', type=str, default='cfg/yolo-pose-multi.cfg') # network config
parser.add_argument('--initweightfile', type=str, default='backup_multi/init.weights') # initialization weights
parser.add_argument('--pretrain_num_epochs', type=int, default=0) # how many epoch to pretrain
args = parser.parse_args()
datacfg = args.datacfg
modelcfg = args.modelcfg
initweightfile = args.initweightfile
pretrain_num_epochs = args.pretrain_num_epochs
# Parse configuration files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
trainlist = data_options['train']
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
gpus = '0'
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
if not os.path.exists(backupdir):
makedirs(backupdir)
# Parse data configuration file
data_options = read_data_cfg(datacfg)
trainlist = data_options['train']
gpus = data_options['gpus']
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
im_width = int(data_options['im_width'])
im_height = int(data_options['im_height'])
fx = float(data_options['fx'])
fy = float(data_options['fy'])
u0 = float(data_options['u0'])
v0 = float(data_options['v0'])
# Parse network and training configuration parameters
net_options = parse_cfg(modelcfg)[0]
loss_options = parse_cfg(modelcfg)[-1]
batch_size = int(net_options['batch'])
max_batches = int(net_options['max_batches'])
max_epochs = int(net_options['max_epochs'])
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
conf_thresh = float(net_options['conf_thresh'])
num_keypoints = int(net_options['num_keypoints'])
num_classes = int(loss_options['classes'])
num_anchors = int(loss_options['num'])
steps = [float(step) for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
bg_file_names = get_all_files('../VOCdevkit/VOC2012/JPEGImages')
anchors = [float(anchor) for anchor in loss_options['anchors'].split(',')]
# Train parameters
max_epochs = 700 # max_batches*batch_size/nsamples+1
# Further params
if not os.path.exists(backupdir):
makedirs(backupdir)
bg_file_names = get_all_files('../VOCdevkit/VOC2012/JPEGImages')
nsamples = file_lines(trainlist)
use_cuda = True
seed = int(time.time())
eps = 1e-5
save_interval = 10 # epoches
dot_interval = 70 # batches
best_acc = -1
# Test parameters
conf_thresh = 0.05
nms_thresh = 0.4
match_thresh = 0.5
iou_thresh = 0.5
im_width = 640
im_height = 480
best_acc = -sys.maxsize
num_labels = num_keypoints*2+3 # + 2 for image width, height, +1 for image class
# Specify which gpus to use
torch.manual_seed(seed)
@ -355,12 +357,11 @@ if __name__ == "__main__":
torch.cuda.manual_seed(seed)
# Specifiy the model and the loss
model = Darknet(cfgfile)
region_loss = model.loss
model = Darknet(modelcfg)
region_loss = RegionLoss(num_keypoints=num_keypoints, num_classes=num_classes, anchors=anchors, num_anchors=num_anchors, pretrain_num_epochs=pretrain_num_epochs)
# Model settings
# model.load_weights(weightfile)
model.load_weights_until_last(weightfile)
model.load_weights_until_last(initweightfile)
model.print_network()
model.seen = 0
region_loss.iter = model.iter
@ -368,20 +369,18 @@ if __name__ == "__main__":
processed_batches = model.seen/batch_size
init_width = model.width
init_height = model.height
init_epoch = model.seen/nsamples
# Variable to save
training_iters = []
training_losses = []
testing_iters = []
testing_errors_pixel = []
testing_accuracies = []
init_epoch = model.seen//nsamples
# Variables to save
training_iters = []
training_losses = []
testing_iters = []
testing_errors_pixel = []
testing_accuracies = []
# Specify the number of workers
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Pass the model to GPU
if use_cuda:
# model = model.cuda()
@ -396,7 +395,6 @@ if __name__ == "__main__":
else:
params += [{'params': [value], 'weight_decay': decay*batch_size}]
optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
# optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimization
evaluate = False
if evaluate:
@ -416,9 +414,9 @@ if __name__ == "__main__":
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel)
if (np.mean(testing_accuracies[-5:]) > best_acc ):
best_acc = np.mean(testing_accuracies[-5:])
if (np.mean(testing_accuracies[-6:]) > best_acc ): # testing for 6 different objects
best_acc = np.mean(testing_accuracies[-6:])
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.module.save_weights('%s/model.weights' % (backupdir))
shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))
# shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))

Просмотреть файл

@ -0,0 +1,502 @@
import sys
import os
import time
import math
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from torch.autograd import Variable
import torch.nn.functional as F
import cv2
from scipy import spatial
import struct
import imghdr
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
# Get all the files within a directory
def get_all_files(directory):
files = []
for f in os.listdir(directory):
if os.path.isfile(os.path.join(directory, f)):
files.append(os.path.join(directory, f))
else:
files.extend(get_all_files(os.path.join(directory, f)))
return files
# Calculate angular distance between two rotations
def calcAngularDistance(gt_rot, pr_rot):
rotDiff = np.dot(gt_rot, np.transpose(pr_rot))
trace = np.trace(rotDiff)
return np.rad2deg(np.arccos((trace-1.0)/2.0))
# Get camera intrinsic matrix
def get_camera_intrinsic(u0, v0, fx, fy):
'''fx, fy: focal length parameters, u0, v0: principal point offset parameters'''
return np.array([[fx, 0.0, u0], [0.0, fy, v0], [0.0, 0.0, 1.0]])
# Compute the projection of an array of 3D points onto a 2D image given the intrinsics and extrinsics
def compute_projection(points_3D, transformation, internal_calibration):
projections_2d = np.zeros((2, points_3D.shape[1]), dtype='float32')
camera_projection = (internal_calibration.dot(transformation)).dot(points_3D)
projections_2d[0, :] = camera_projection[0, :]/camera_projection[2, :]
projections_2d[1, :] = camera_projection[1, :]/camera_projection[2, :]
return projections_2d
# Transform an array of 3D points in the 3D space using extrinsics
def compute_transformation(points_3D, transformation):
return transformation.dot(points_3D)
# Calculate the diameter of an object model, diameter is defined as the longest distance between all the pairwise distances in the object model
def calc_pts_diameter(pts):
diameter = -1
for pt_id in range(pts.shape[0]):
pt_dup = np.tile(np.array([pts[pt_id, :]]), [pts.shape[0] - pt_id, 1])
pts_diff = pt_dup - pts[pt_id:, :]
max_dist = math.sqrt((pts_diff * pts_diff).sum(axis=1).max())
if max_dist > diameter:
diameter = max_dist
return diameter
# Compute adi metric, see https://github.com/thodan/obj_pose_eval/blob/master/obj_pose_eval/pose_error.py for further info
def adi(pts_est, pts_gt):
nn_index = spatial.cKDTree(pts_est)
nn_dists, _ = nn_index.query(pts_gt, k=1)
e = nn_dists.mean()
return e
# Get the 3D corners of the bounding box surrounding the object model
def get_3D_corners(vertices):
min_x = np.min(vertices[0,:])
max_x = np.max(vertices[0,:])
min_y = np.min(vertices[1,:])
max_y = np.max(vertices[1,:])
min_z = np.min(vertices[2,:])
max_z = np.max(vertices[2,:])
corners = np.array([[min_x, min_y, min_z],
[min_x, min_y, max_z],
[min_x, max_y, min_z],
[min_x, max_y, max_z],
[max_x, min_y, min_z],
[max_x, min_y, max_z],
[max_x, max_y, min_z],
[max_x, max_y, max_z]])
corners = np.concatenate((np.transpose(corners), np.ones((1,8)) ), axis=0)
return corners
# Compute pose using PnP
def pnp(points_3D, points_2D, cameraMatrix):
try:
distCoeffs = pnp.distCoeffs
except:
distCoeffs = np.zeros((8, 1), dtype='float32') # 8 distortion-coefficient model
assert points_3D.shape[0] == points_2D.shape[0], 'points 3D and points 2D must have same number of vertices'
_, R_exp, t = cv2.solvePnP(points_3D,
np.ascontiguousarray(points_2D[:,:2]).reshape((-1,1,2)),
cameraMatrix,
distCoeffs)
R, _ = cv2.Rodrigues(R_exp)
return R, t
# Get the tightest bounding box surrounding keypoints
def get_2d_bb(box, size):
x = box[0]
y = box[1]
min_x = np.min(np.reshape(box, [-1,2])[:,0])
max_x = np.max(np.reshape(box, [-1,2])[:,0])
min_y = np.min(np.reshape(box, [-1,2])[:,1])
max_y = np.max(np.reshape(box, [-1,2])[:,1])
w = max_x - min_x
h = max_y - min_y
new_box = [x*size, y*size, w*size, h*size]
return new_box
# Compute IoU between two bounding boxes
def bbox_iou(box1, box2, x1y1x2y2=False):
if x1y1x2y2:
mx = min(box1[0], box2[0])
Mx = max(box1[2], box2[2])
my = min(box1[1], box2[1])
My = max(box1[3], box2[3])
w1 = box1[2] - box1[0]
h1 = box1[3] - box1[1]
w2 = box2[2] - box2[0]
h2 = box2[3] - box2[1]
else:
mx = min(box1[0]-box1[2]/2.0, box2[0]-box2[2]/2.0)
Mx = max(box1[0]+box1[2]/2.0, box2[0]+box2[2]/2.0)
my = min(box1[1]-box1[3]/2.0, box2[1]-box2[3]/2.0)
My = max(box1[1]+box1[3]/2.0, box2[1]+box2[3]/2.0)
w1 = box1[2]
h1 = box1[3]
w2 = box2[2]
h2 = box2[3]
uw = Mx - mx
uh = My - my
cw = w1 + w2 - uw
ch = h1 + h2 - uh
carea = 0
if cw <= 0 or ch <= 0:
return 0.0
area1 = w1 * h1
area2 = w2 * h2
carea = cw * ch
uarea = area1 + area2 - carea
return carea/uarea
# Compute confidences of current keypoint predictions
def corner_confidences(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
th : distance threshold, type: int
sharpness : sharpness of the exponential that assigns a confidence value to the distance
-----------
return : a torch.FloatTensor of shape (nA,) with 9 confidence values
'''
shape = gt_corners.size()
nA = shape[1]
dist = gt_corners - pr_corners
num_el = dist.numel()
num_keypoints = num_el//(nA*2)
dist = dist.t().contiguous().view(nA, num_keypoints, 2)
dist[:, :, 0] = dist[:, :, 0] * im_width
dist[:, :, 1] = dist[:, :, 1] * im_height
eps = 1e-5
distthresh = torch.FloatTensor([th]).repeat(nA, num_keypoints)
dist = torch.sqrt(torch.sum((dist)**2, dim=2)).squeeze() # nA x 9
mask = (dist < distthresh).type(torch.FloatTensor)
conf = torch.exp(sharpness*(1 - dist/distthresh))-1 # mask * (torch.exp(math.log(2) * (1.0 - dist/rrt)) - 1)
conf0 = torch.exp(sharpness*(1 - torch.zeros(conf.size(0),1))) - 1
conf = conf / conf0.repeat(1, num_keypoints)
# conf = 1 - dist/distthresh
conf = mask * conf # nA x 9
mean_conf = torch.mean(conf, dim=1)
return mean_conf
# Compute confidence of the current keypoint prediction
def corner_confidence(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (18,) type: list
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (18,), type: list
th : distance threshold, type: int
sharpness : sharpness of the exponential that assigns a confidence value to the distance
-----------
return : a list of shape (9,) with 9 confidence values
'''
dist = torch.FloatTensor(gt_corners) - pr_corners
num_keypoints = dist.numel()//2
dist = dist.view(num_keypoints, 2)
dist[:, 0] = dist[:, 0] * im_width
dist[:, 1] = dist[:, 1] * im_height
eps = 1e-5
dist = torch.sqrt(torch.sum((dist)**2, dim=1))
mask = (dist < th).type(torch.FloatTensor)
conf = torch.exp(sharpness * (1.0 - dist/th)) - 1
conf0 = torch.exp(torch.FloatTensor([sharpness])) - 1 + eps
conf = conf / conf0.repeat(num_keypoints, 1)
# conf = 1.0 - dist/th
conf = mask * conf
return torch.mean(conf)
# Compute sigmoid
def sigmoid(x):
return 1.0/(math.exp(-x)+1.)
# Compute softmax function
def softmax(x):
x = torch.exp(x - torch.max(x))
x = x/x.sum()
return x
# Apply non-maxima suppression on a set of bounding boxes
def nms(boxes, nms_thresh):
if len(boxes) == 0:
return boxes
det_confs = torch.zeros(len(boxes))
for i in range(len(boxes)):
det_confs[i] = 1-boxes[i][4]
_,sortIds = torch.sort(det_confs)
out_boxes = []
for i in range(len(boxes)):
box_i = boxes[sortIds[i]]
if box_i[4] > 0:
out_boxes.append(box_i)
for j in range(i+1, len(boxes)):
box_j = boxes[sortIds[j]]
if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
box_j[4] = 0
return out_boxes
# Fix the wrong order of corners on the Occlusion dataset
def fix_corner_order(corners2D_gt):
corners2D_gt_corrected = np.zeros((9, 2), dtype='float32')
corners2D_gt_corrected[0, :] = corners2D_gt[0, :]
corners2D_gt_corrected[1, :] = corners2D_gt[1, :]
corners2D_gt_corrected[2, :] = corners2D_gt[3, :]
corners2D_gt_corrected[3, :] = corners2D_gt[5, :]
corners2D_gt_corrected[4, :] = corners2D_gt[7, :]
corners2D_gt_corrected[5, :] = corners2D_gt[2, :]
corners2D_gt_corrected[6, :] = corners2D_gt[4, :]
corners2D_gt_corrected[7, :] = corners2D_gt[6, :]
corners2D_gt_corrected[8, :] = corners2D_gt[8, :]
return corners2D_gt_corrected
# Convert float tensors in GPU to tensors in CPU
def convert2cpu(gpu_matrix):
return torch.FloatTensor(gpu_matrix.size()).copy_(gpu_matrix)
# Convert long tensors in GPU to tensors in CPU
def convert2cpu_long(gpu_matrix):
return torch.LongTensor(gpu_matrix.size()).copy_(gpu_matrix)
# Get potential sets of predictions at test time
def get_multi_region_boxes(output, conf_thresh, num_classes, num_keypoints, anchors, num_anchors, correspondingclass, only_objectness=1, validation=False):
# Parameters
anchor_step = len(anchors)//num_anchors
if output.dim() == 3:
output = output.unsqueeze(0)
batch = output.size(0)
assert(output.size(1) == (2*num_keypoints+1+num_classes)*num_anchors)
h = output.size(2)
w = output.size(3)
# Activation
t0 = time.time()
all_boxes = []
max_conf = -sys.maxsize
max_cls_conf = -sys.maxsize
output = output.view(batch*num_anchors, 2*num_keypoints+1+num_classes, h*w).transpose(0,1).contiguous().view(2*num_keypoints+1+num_classes, batch*num_anchors*h*w)
grid_x = torch.linspace(0, w-1, w).repeat(h,1).repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
grid_y = torch.linspace(0, h-1, h).repeat(w,1).t().repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
xs = list()
ys = list()
xs.append(torch.sigmoid(output[0]) + grid_x)
ys.append(torch.sigmoid(output[1]) + grid_y)
for j in range(1,num_keypoints):
xs.append(output[2*j + 0] + grid_x)
ys.append(output[2*j + 1] + grid_y)
det_confs = torch.sigmoid(output[2*num_keypoints])
cls_confs = torch.nn.Softmax()(Variable(output[2*num_keypoints+1:2*num_keypoints+1+num_classes].transpose(0,1))).data
cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
cls_max_confs = cls_max_confs.view(-1)
cls_max_ids = cls_max_ids.view(-1)
t1 = time.time()
# GPU to CPU
sz_hw = h*w
sz_hwa = sz_hw*num_anchors
det_confs = convert2cpu(det_confs)
cls_max_confs = convert2cpu(cls_max_confs)
cls_max_ids = convert2cpu_long(cls_max_ids)
for j in range(num_keypoints):
xs[j] = convert2cpu(xs[j])
ys[j] = convert2cpu(ys[j])
if validation:
cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
t2 = time.time()
# Boxes filter
for b in range(batch):
boxes = []
max_conf = -1
for cy in range(h):
for cx in range(w):
for i in range(num_anchors):
ind = b*sz_hwa + i*sz_hw + cy*w + cx
det_conf = det_confs[ind]
if only_objectness:
conf = det_confs[ind]
else:
conf = det_confs[ind] * cls_max_confs[ind]
if (det_confs[ind] > max_conf) and (cls_confs[ind, correspondingclass] > max_cls_conf):
max_conf = det_confs[ind]
max_cls_conf = cls_confs[ind, correspondingclass]
max_ind = ind
if conf > conf_thresh:
bcx = list()
bcy = list()
for j in range(num_keypoints):
bcx.append(xs[j][ind])
bcy.append(ys[j][ind])
cls_max_conf = cls_max_confs[ind]
cls_max_id = cls_max_ids[ind]
box = list()
for j in range(num_keypoints):
box.append(bcx[j]/w)
box.append(bcy[j]/h)
box.append(det_conf)
box.append(cls_max_conf)
box.append(cls_max_id)
if (not only_objectness) and validation:
for c in range(num_classes):
tmp_conf = cls_confs[ind][c]
if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
box.append(tmp_conf)
box.append(c)
boxes.append(box)
if (len(boxes) == 0) or (not (correspondingclass in np.array(boxes)[:,2*num_keypoints+2])):
bcx = list()
bcy = list()
for j in range(num_keypoints):
bcx.append(xs[j][max_ind])
bcy.append(ys[j][max_ind])
cls_max_conf = max_cls_conf # cls_max_confs[max_ind]
cls_max_id = correspondingclass # cls_max_ids[max_ind]
det_conf = max_conf # det_confs[max_ind]
box = list()
for j in range(num_keypoints):
box.append(bcx[j]/w)
box.append(bcy[j]/h)
box.append(det_conf)
box.append(cls_max_conf)
box.append(cls_max_id)
boxes.append(box)
all_boxes.append(boxes)
else:
all_boxes.append(boxes)
t3 = time.time()
if False:
print('---------------------------------')
print('matrix computation : %f' % (t1-t0))
print(' gpu to cpu : %f' % (t2-t1))
print(' boxes filter : %f' % (t3-t2))
print('---------------------------------')
return all_boxes
# Read the labels from the file
def read_truths(lab_path, num_keypoints=9):
num_labels = 2*num_keypoints+3 # +2 for width, height, +1 for class label
if os.path.getsize(lab_path):
truths = np.loadtxt(lab_path)
truths = truths.reshape(truths.size//num_labels, num_labels) # to avoid single truth problem
return truths
else:
return np.array([])
def read_truths_args(lab_path, num_keypoints=9):
num_labels = 2*num_keypoints+1
truths = read_truths(lab_path)
new_truths = []
for i in range(truths.shape[0]):
for j in range(num_labels):
new_truths.append(truths[i][j])
return np.array(new_truths)
def read_pose(lab_path):
if os.path.getsize(lab_path):
truths = np.loadtxt(lab_path)
return truths
else:
return np.array([])
def load_class_names(namesfile):
class_names = []
with open(namesfile, 'r') as fp:
lines = fp.readlines()
for line in lines:
line = line.rstrip()
class_names.append(line)
return class_names
def image2torch(img):
width = img.width
height = img.height
img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))
img = img.view(height, width, 3).transpose(0,1).transpose(0,2).contiguous()
img = img.view(1, 3, height, width)
img = img.float().div(255.0)
return img
def read_data_cfg(datacfg):
options = dict()
options['gpus'] = '0,1,2,3'
options['num_workers'] = '10'
with open(datacfg, 'r') as fp:
lines = fp.readlines()
for line in lines:
line = line.strip()
if line == '':
continue
key,value = line.split('=')
key = key.strip()
value = value.strip()
options[key] = value
return options
def scale_bboxes(bboxes, width, height):
import copy
dets = copy.deepcopy(bboxes)
for i in range(len(dets)):
dets[i][0] = dets[i][0] * width
dets[i][1] = dets[i][1] * height
dets[i][2] = dets[i][2] * width
dets[i][3] = dets[i][3] * height
return dets
def file_lines(thefilepath):
count = 0
thefile = open(thefilepath, 'rb')
while True:
buffer = thefile.read(8192*1024)
if not buffer:
break
count += buffer.count(b'\n')
thefile.close( )
return count
def get_image_size(fname):
'''Determine the image type of fhandle and return its size.
from draco'''
with open(fname, 'rb') as fhandle:
head = fhandle.read(24)
if len(head) != 24:
return
if imghdr.what(fname) == 'png':
check = struct.unpack('>i', head[4:8])[0]
if check != 0x0d0a1a0a:
return
width, height = struct.unpack('>ii', head[16:24])
elif imghdr.what(fname) == 'gif':
width, height = struct.unpack('<HH', head[6:10])
elif imghdr.what(fname) == 'jpeg' or imghdr.what(fname) == 'jpg':
try:
fhandle.seek(0) # Read 0xff next
size = 2
ftype = 0
while not 0xc0 <= ftype <= 0xcf:
fhandle.seek(size, 1)
byte = fhandle.read(1)
while ord(byte) == 0xff:
byte = fhandle.read(1)
ftype = ord(byte)
size = struct.unpack('>H', fhandle.read(2))[0] - 2
# We are at a SOFn block
fhandle.seek(1, 1) # Skip `precision' byte.
height, width = struct.unpack('>HH', fhandle.read(4))
except Exception: #IGNORE:W0703
return
else:
return
return width, height
def logging(message):
print('%s %s' % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), message))

Просмотреть файл

@ -1,77 +1,75 @@
import os
os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import scipy.misc
import warnings
import sys
import argparse
warnings.filterwarnings("ignore")
from torch.autograd import Variable
from torchvision import datasets, transforms
from darknet_multi import Darknet
from utils import *
import dataset_multi
from darknet_multi import Darknet
from utils_multi import *
from cfg import parse_cfg
from MeshPly import MeshPly
def valid(datacfg, cfgfile, weightfile, conf_th):
def valid(datacfg, cfgfile, weightfile):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
name = options['name']
prefix = 'results'
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
diam = float(options['diam'])
# Parse data configuration files
data_options = read_data_cfg(datacfg)
valid_images = data_options['valid']
meshname = data_options['mesh']
name = data_options['name']
im_width = int(data_options['im_width'])
im_height = int(data_options['im_height'])
fx = float(data_options['fx'])
fy = float(data_options['fy'])
u0 = float(data_options['u0'])
v0 = float(data_options['v0'])
# Parse net configuration file
net_options = parse_cfg(cfgfile)[0]
loss_options = parse_cfg(cfgfile)[-1]
conf_thresh = float(net_options['conf_thresh'])
num_keypoints = int(net_options['num_keypoints'])
num_classes = int(loss_options['classes'])
num_anchors = int(loss_options['num'])
anchors = [float(anchor) for anchor in loss_options['anchors'].split(',')]
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
# Read object model information, get 3D bounding box corners, get intrinsics
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
diam = float(data_options['diam'])
intrinsic_calibration = get_camera_intrinsic(u0, v0, fx, fy) # camera params
# Get validation file names
with open(valid_images) as fp:
# Network I/O params
num_labels = 2*num_keypoints+3 # +2 for width, height, +1 for object class
errs_2d = [] # to save
with open(valid_images) as fp: # validation file names
tmp_files = fp.readlines()
valid_files = [item.rstrip() for item in tmp_files]
# Compute-related Parameters
use_cuda = True # whether to use cuda or no
kwargs = {'num_workers': 4, 'pin_memory': True} # number of workers etc.
# Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
model = Darknet(cfgfile)
model.load_weights(weightfile)
model.cuda()
model.eval()
# Get the parser for the test dataset
valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.width, model.height),
shuffle=False,
objclass=name,
transform=transforms.Compose([
transforms.ToTensor(),
]))
valid_batchsize = 1
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
# Parameters
use_cuda = True
num_classes = 13
anchors = [1.4820, 2.2412, 2.0501, 3.1265, 2.3946, 4.6891, 3.1018, 3.9910, 3.4879, 5.8851]
num_anchors = 5
eps = 1e-5
conf_thresh = conf_th
iou_thresh = 0.5
# Parameters to save
errs_2d = []
edges = [[1, 2], [1, 3], [1, 5], [2, 4], [2, 6], [3, 4], [3, 7], [4, 8], [5, 6], [5, 7], [6, 8], [7, 8]]
edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]
# Get the dataloader for the test dataset
valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.width, model.height), shuffle=False, objclass=name, transform=transforms.Compose([transforms.ToTensor(),]))
test_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, **kwargs)
# Iterate through test batches (Batch size for test data is 1)
logging('Testing {}...'.format(name))
@ -92,8 +90,8 @@ def valid(datacfg, cfgfile, weightfile, conf_th):
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
trgt = target[0].view(-1, 21)
all_boxes = get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
trgt = target[0].view(-1, num_labels)
all_boxes = get_multi_region_boxes(output, conf_thresh, num_classes, num_keypoints, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
t4 = time.time()
# Iterate through all images in the batch
@ -103,51 +101,49 @@ def valid(datacfg, cfgfile, weightfile, conf_th):
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
truths = target[i].view(-1, num_labels)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
box_gt = list()
for j in range(1, num_labels):
box_gt.append(truths[k][j])
box_gt.extend([1.0, 1.0])
box_gt.append(truths[k][0])
# If the prediction has the highest confidence, choose it as our prediction
best_conf_est = -sys.maxsize
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est) and (boxes[j][20] == int(truths[k][0])):
best_conf_est = boxes[j][18]
if (boxes[j][2*num_keypoints] > best_conf_est) and (boxes[j][2*num_keypoints+2] == int(truths[k][0])):
best_conf_est = boxes[j][2*num_keypoints]
box_pr = boxes[j]
bb2d_gt = get_2d_bb(box_gt[:18], output.size(3))
bb2d_pr = get_2d_bb(box_pr[:18], output.size(3))
iou = bbox_iou(bb2d_gt, bb2d_pr)
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
match = corner_confidence(box_gt[:2*num_keypoints], torch.FloatTensor(boxes[j][:2*num_keypoints]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640
corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480
corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640
corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480
corners2D_gt = np.array(np.reshape(box_gt[:2*num_keypoints], [-1, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:2*num_keypoints], [-1, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height
corners2D_gt_corrected = fix_corner_order(corners2D_gt) # Fix the order of corners
# Compute [R|t] by pnp
objpoints3D = np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32')
K = np.array(internal_calibration, dtype='float32')
K = np.array(intrinsic_calibration, dtype='float32')
R_gt, t_gt = pnp(objpoints3D, corners2D_gt_corrected, K)
R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration))
proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration))
proj_2d_gt = compute_projection(vertices, Rt_gt, intrinsic_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, intrinsic_calibration)
proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, intrinsic_calibration))
proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, intrinsic_calibration))
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
@ -155,29 +151,28 @@ def valid(datacfg, cfgfile, weightfile, conf_th):
t5 = time.time()
# Compute 2D projection score
eps = 1e-5
for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
# Print test statistics
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
if __name__ == '__main__' and __package__ is None:
import sys
if len(sys.argv) == 3:
conf_th = 0.05
cfgfile = sys.argv[1]
weightfile = sys.argv[2]
datacfg = 'cfg/ape_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/can_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/cat_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/duck_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/glue_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/holepuncher_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
else:
print('Usage:')
print(' python valid.py cfgfile weightfile')
parser = argparse.ArgumentParser(description='SingleShotPose')
parser.add_argument('--modelcfg', type=str, default='cfg/yolo-pose-multi.cfg') # network config
parser.add_argument('--initweightfile', type=str, default='backup_multi/model_backup.weights') # initialization weights
args = parser.parse_args()
datacfg = 'cfg/ape_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)
datacfg = 'cfg/can_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)
datacfg = 'cfg/cat_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)
datacfg = 'cfg/duck_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)
datacfg = 'cfg/glue_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)
datacfg = 'cfg/holepuncher_occlusion.data'
valid(datacfg, args.modelcfg, args.initweightfile)

332
py2/.gitignore поставляемый Normal file
Просмотреть файл

@ -0,0 +1,332 @@
## Ignore Visual Studio temporary files, build results, and
## files generated by popular Visual Studio add-ons.
##
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
# User-specific files
*.suo
*.user
*.userosscache
*.sln.docstates
*.ipynb_checkpoints
*.DS_Store
# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
# Build results
[Dd]ebug/
[Dd]ebugPublic/
[Rr]elease/
[Rr]eleases/
x64/
x86/
bld/
[Bb]in/
[Oo]bj/
[Ll]og/
# Visual Studio 2015/2017 cache/options directory
.vs/
# Uncomment if you have tasks that create the project's static files in wwwroot
#wwwroot/
# Visual Studio 2017 auto generated files
Generated\ Files/
# MSTest test Results
[Tt]est[Rr]esult*/
[Bb]uild[Ll]og.*
# NUNIT
*.VisualState.xml
TestResult.xml
# Build Results of an ATL Project
[Dd]ebugPS/
[Rr]eleasePS/
dlldata.c
# Benchmark Results
BenchmarkDotNet.Artifacts/
# .NET Core
project.lock.json
project.fragment.lock.json
artifacts/
**/Properties/launchSettings.json
# StyleCop
StyleCopReport.xml
# Files built by Visual Studio
*_i.c
*_p.c
*_i.h
*.ilk
*.meta
*.obj
*.iobj
*.pch
*.pdb
*.ipdb
*.pgc
*.pgd
*.rsp
*.sbr
*.tlb
*.tli
*.tlh
*.tmp
*.tmp_proj
*.log
*.vspscc
*.vssscc
.builds
*.pidb
*.svclog
*.scc
# Chutzpah Test files
_Chutzpah*
# Visual C++ cache files
ipch/
*.aps
*.ncb
*.opendb
*.opensdf
*.sdf
*.cachefile
*.VC.db
*.VC.VC.opendb
# Visual Studio profiler
*.psess
*.vsp
*.vspx
*.sap
# Visual Studio Trace Files
*.e2e
# TFS 2012 Local Workspace
$tf/
# Guidance Automation Toolkit
*.gpState
# ReSharper is a .NET coding add-in
_ReSharper*/
*.[Rr]e[Ss]harper
*.DotSettings.user
# JustCode is a .NET coding add-in
.JustCode
# TeamCity is a build add-in
_TeamCity*
# DotCover is a Code Coverage Tool
*.dotCover
# AxoCover is a Code Coverage Tool
.axoCover/*
!.axoCover/settings.json
# Visual Studio code coverage results
*.coverage
*.coveragexml
# NCrunch
_NCrunch_*
.*crunch*.local.xml
nCrunchTemp_*
# MightyMoose
*.mm.*
AutoTest.Net/
# Web workbench (sass)
.sass-cache/
# Installshield output folder
[Ee]xpress/
# DocProject is a documentation generator add-in
DocProject/buildhelp/
DocProject/Help/*.HxT
DocProject/Help/*.HxC
DocProject/Help/*.hhc
DocProject/Help/*.hhk
DocProject/Help/*.hhp
DocProject/Help/Html2
DocProject/Help/html
# Click-Once directory
publish/
# Publish Web Output
*.[Pp]ublish.xml
*.azurePubxml
# Note: Comment the next line if you want to checkin your web deploy settings,
# but database connection strings (with potential passwords) will be unencrypted
*.pubxml
*.publishproj
# Microsoft Azure Web App publish settings. Comment the next line if you want to
# checkin your Azure Web App publish settings, but sensitive information contained
# in these scripts will be unencrypted
PublishScripts/
# NuGet Packages
*.nupkg
# The packages folder can be ignored because of Package Restore
**/[Pp]ackages/*
# except build/, which is used as an MSBuild target.
!**/[Pp]ackages/build/
# Uncomment if necessary however generally it will be regenerated when needed
#!**/[Pp]ackages/repositories.config
# NuGet v3's project.json files produces more ignorable files
*.nuget.props
*.nuget.targets
# Microsoft Azure Build Output
csx/
*.build.csdef
# Microsoft Azure Emulator
ecf/
rcf/
# Windows Store app package directories and files
AppPackages/
BundleArtifacts/
Package.StoreAssociation.xml
_pkginfo.txt
*.appx
# Visual Studio cache files
# files ending in .cache can be ignored
*.[Cc]ache
# but keep track of directories ending in .cache
!*.[Cc]ache/
# Others
ClientBin/
~$*
*~
*.dbmdl
*.dbproj.schemaview
*.jfm
*.pfx
*.publishsettings
orleans.codegen.cs
# Including strong name files can present a security risk
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
#*.snk
# Since there are multiple workflows, uncomment next line to ignore bower_components
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
#bower_components/
# RIA/Silverlight projects
Generated_Code/
# Backup & report files from converting an old project file
# to a newer Visual Studio version. Backup files are not needed,
# because we have git ;-)
_UpgradeReport_Files/
Backup*/
UpgradeLog*.XML
UpgradeLog*.htm
ServiceFabricBackup/
*.rptproj.bak
# SQL Server files
*.mdf
*.ldf
*.ndf
# Business Intelligence projects
*.rdl.data
*.bim.layout
*.bim_*.settings
*.rptproj.rsuser
# Microsoft Fakes
FakesAssemblies/
# GhostDoc plugin setting file
*.GhostDoc.xml
# Node.js Tools for Visual Studio
.ntvs_analysis.dat
node_modules/
# Visual Studio 6 build log
*.plg
# Visual Studio 6 workspace options file
*.opt
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
*.vbw
# Visual Studio LightSwitch build output
**/*.HTMLClient/GeneratedArtifacts
**/*.DesktopClient/GeneratedArtifacts
**/*.DesktopClient/ModelManifest.xml
**/*.Server/GeneratedArtifacts
**/*.Server/ModelManifest.xml
_Pvt_Extensions
# Paket dependency manager
.paket/paket.exe
paket-files/
# FAKE - F# Make
.fake/
# JetBrains Rider
.idea/
*.sln.iml
# CodeRush
.cr/
# Python Tools for Visual Studio (PTVS)
__pycache__/
*.pyc
# Cake - Uncomment if you are using it
# tools/**
# !tools/packages.config
# Tabs Studio
*.tss
# Telerik's JustMock configuration file
*.jmconfig
# BizTalk build output
*.btp.cs
*.btm.cs
*.odx.cs
*.xsd.cs
# OpenCover UI analysis results
OpenCover/
# Azure Stream Analytics local run output
ASALocalRun/
# MSBuild Binary and Structured Log
*.binlog
# NVidia Nsight GPU debugger configuration file
*.nvuser
# MFractors (Xamarin productivity tool) working folder
.mfractor/

13
py2/LICENSE.txt Normal file
Просмотреть файл

@ -0,0 +1,13 @@
Single Shot Seamless Object Pose Estimation
Copyright (c) Microsoft Corporation
All rights reserved.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Software), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

53
py2/MeshPly.py Normal file
Просмотреть файл

@ -0,0 +1,53 @@
# Class to read
class MeshPly:
def __init__(self, filename, color=[0., 0., 0.]):
f = open(filename, 'r')
self.vertices = []
self.colors = []
self.indices = []
self.normals = []
vertex_mode = False
face_mode = False
nb_vertices = 0
nb_faces = 0
idx = 0
with f as open_file_object:
for line in open_file_object:
elements = line.split()
if vertex_mode:
self.vertices.append([float(i) for i in elements[:3]])
self.normals.append([float(i) for i in elements[3:6]])
if elements[6:9]:
self.colors.append([float(i) / 255. for i in elements[6:9]])
else:
self.colors.append([float(i) / 255. for i in color])
idx += 1
if idx == nb_vertices:
vertex_mode = False
face_mode = True
idx = 0
elif face_mode:
self.indices.append([float(i) for i in elements[1:4]])
idx += 1
if idx == nb_faces:
face_mode = False
elif elements[0] == 'element':
if elements[1] == 'vertex':
nb_vertices = int(elements[2])
elif elements[1] == 'face':
nb_faces = int(elements[2])
elif elements[0] == 'end_header':
vertex_mode = True
if __name__ == '__main__':
path_model = ''
mesh = MeshPly(path_model)

208
py2/cfg.py Normal file
Просмотреть файл

@ -0,0 +1,208 @@
import torch
from utils import convert2cpu
def parse_cfg(cfgfile):
blocks = []
fp = open(cfgfile, 'r')
block = None
line = fp.readline()
while line != '':
line = line.rstrip()
if line == '' or line[0] == '#':
line = fp.readline()
continue
elif line[0] == '[':
if block:
blocks.append(block)
block = dict()
block['type'] = line.lstrip('[').rstrip(']')
# set default value
if block['type'] == 'convolutional':
block['batch_normalize'] = 0
else:
key,value = line.split('=')
key = key.strip()
if key == 'type':
key = '_type'
value = value.strip()
block[key] = value
line = fp.readline()
if block:
blocks.append(block)
fp.close()
return blocks
def print_cfg(blocks):
print('layer filters size input output');
prev_width = 416
prev_height = 416
prev_filters = 3
out_filters =[]
out_widths =[]
out_heights =[]
ind = -2
for block in blocks:
ind = ind + 1
if block['type'] == 'net':
prev_width = int(block['width'])
prev_height = int(block['height'])
continue
elif block['type'] == 'convolutional':
filters = int(block['filters'])
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
width = (prev_width + 2*pad - kernel_size)/stride + 1
height = (prev_height + 2*pad - kernel_size)/stride + 1
print('%5d %-6s %4d %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'conv', filters, kernel_size, kernel_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
width = prev_width/stride
height = prev_height/stride
print('%5d %-6s %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'max', pool_size, pool_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'avgpool':
width = 1
height = 1
print('%5d %-6s %3d x %3d x%4d -> %3d' % (ind, 'avg', prev_width, prev_height, prev_filters, prev_filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'softmax':
print('%5d %-6s -> %3d' % (ind, 'softmax', prev_filters))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'cost':
print('%5d %-6s -> %3d' % (ind, 'cost', prev_filters))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'reorg':
stride = int(block['stride'])
filters = stride * stride * prev_filters
width = prev_width/stride
height = prev_height/stride
print('%5d %-6s / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'reorg', stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'route':
layers = block['layers'].split(',')
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
print('%5d %-6s %d' % (ind, 'route', layers[0]))
prev_width = out_widths[layers[0]]
prev_height = out_heights[layers[0]]
prev_filters = out_filters[layers[0]]
elif len(layers) == 2:
print('%5d %-6s %d %d' % (ind, 'route', layers[0], layers[1]))
prev_width = out_widths[layers[0]]
prev_height = out_heights[layers[0]]
assert(prev_width == out_widths[layers[1]])
assert(prev_height == out_heights[layers[1]])
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'region':
print('%5d %-6s' % (ind, 'detection'))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'shortcut':
from_id = int(block['from'])
from_id = from_id if from_id > 0 else from_id+ind
print('%5d %-6s %d' % (ind, 'shortcut', from_id))
prev_width = out_widths[from_id]
prev_height = out_heights[from_id]
prev_filters = out_filters[from_id]
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'connected':
filters = int(block['output'])
print('%5d %-6s %d -> %3d' % (ind, 'connected', prev_filters, filters))
prev_filters = filters
out_widths.append(1)
out_heights.append(1)
out_filters.append(prev_filters)
else:
print('unknown type %s' % (block['type']))
def load_conv(buf, start, conv_model):
num_w = conv_model.weight.numel()
num_b = conv_model.bias.numel()
conv_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_conv(fp, conv_model):
if conv_model.bias.is_cuda:
convert2cpu(conv_model.bias.data).numpy().tofile(fp)
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
else:
conv_model.bias.data.numpy().tofile(fp)
conv_model.weight.data.numpy().tofile(fp)
def load_conv_bn(buf, start, conv_model, bn_model):
num_w = conv_model.weight.numel()
num_b = bn_model.bias.numel()
bn_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_mean.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_var.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_conv_bn(fp, conv_model, bn_model):
if bn_model.bias.is_cuda:
convert2cpu(bn_model.bias.data).numpy().tofile(fp)
convert2cpu(bn_model.weight.data).numpy().tofile(fp)
convert2cpu(bn_model.running_mean).numpy().tofile(fp)
convert2cpu(bn_model.running_var).numpy().tofile(fp)
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
else:
bn_model.bias.data.numpy().tofile(fp)
bn_model.weight.data.numpy().tofile(fp)
bn_model.running_mean.numpy().tofile(fp)
bn_model.running_var.numpy().tofile(fp)
conv_model.weight.data.numpy().tofile(fp)
def load_fc(buf, start, fc_model):
num_w = fc_model.weight.numel()
num_b = fc_model.bias.numel()
fc_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
fc_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_fc(fp, fc_model):
fc_model.bias.data.numpy().tofile(fp)
fc_model.weight.data.numpy().tofile(fp)
if __name__ == '__main__':
import sys
blocks = parse_cfg('cfg/yolo.cfg')
if len(sys.argv) == 2:
blocks = parse_cfg(sys.argv[1])
print_cfg(blocks)

7
py2/cfg/ape.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/ape/train.txt
valid = LINEMOD/ape/test.txt
backup = backup/ape
mesh = LINEMOD/ape/ape.ply
tr_range = LINEMOD/ape/training_range.txt
name = ape
diam = 0.103

7
py2/cfg/benchvise.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/benchvise/train.txt
valid = LINEMOD/benchvise/test.txt
backup = backup/benchvise
mesh = LINEMOD/benchvise/benchvise.ply
tr_range = LINEMOD/benchvise/training_range.txt
name = benchvise
diam = 0.286908

7
py2/cfg/cam.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/cam/train.txt
valid = LINEMOD/cam/test.txt
backup = backup/cam
mesh = LINEMOD/cam/cam.ply
tr_range = LINEMOD/cam/training_range.txt
name = cam
diam = 0.173

7
py2/cfg/can.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/can/train.txt
valid = LINEMOD/can/test.txt
backup = backup/can
mesh = LINEMOD/can/can.ply
tr_range = LINEMOD/can/training_range.txt
name = can
diam = 0.202

7
py2/cfg/cat.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/cat/train.txt
valid = LINEMOD/cat/test.txt
backup = backup/cat
mesh = LINEMOD/cat/cat.ply
tr_range = LINEMOD/cat/training_range.txt
name = cat
diam = 0.155

7
py2/cfg/driller.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/driller/train.txt
valid = LINEMOD/driller/test.txt
backup = backup/driller
mesh = LINEMOD/driller/driller.ply
tr_range = LINEMOD/driller/training_range.txt
name = driller
diam = 0.262

7
py2/cfg/duck.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/duck/train.txt
valid = LINEMOD/duck/test.txt
backup = backup/duck
mesh = LINEMOD/duck/duck.ply
tr_range = LINEMOD/duck/training_range.txt
name = duck
diam = 0.109

7
py2/cfg/eggbox.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/eggbox/train.txt
valid = LINEMOD/eggbox/test.txt
backup = backup/eggbox
mesh = LINEMOD/eggbox/eggbox.ply
tr_range = LINEMOD/eggbox/training_range.txt
name = eggbox
diam = 0.176364

7
py2/cfg/glue.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/glue/train.txt
valid = LINEMOD/glue/test.txt
backup = backup/glue
mesh = LINEMOD/glue/glue.ply
tr_range = LINEMOD/glue/training_range.txt
name = glue
diam = 0.176

7
py2/cfg/holepuncher.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/holepuncher/train.txt
valid = LINEMOD/holepuncher/test.txt
backup = backup/holepuncher
mesh = LINEMOD/holepuncher/holepuncher.ply
tr_range = LINEMOD/holepuncher/training_range.txt
name = holepuncher
diam = 0.162

7
py2/cfg/iron.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/iron/train.txt
valid = LINEMOD/iron/test.txt
backup = backup/iron
mesh = LINEMOD/iron/iron.ply
tr_range = LINEMOD/iron/training_range.txt
name = iron
diam = 0.303153

7
py2/cfg/lamp.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/lamp/train.txt
valid = LINEMOD/lamp/test.txt
backup = backup/lamp
mesh = LINEMOD/lamp/lamp.ply
tr_range = LINEMOD/lamp/training_range.txt
name = lamp
diam = 0.285155

7
py2/cfg/phone.data Normal file
Просмотреть файл

@ -0,0 +1,7 @@
train = LINEMOD/phone/train.txt
valid = LINEMOD/phone/test.txt
backup = backup/phone
mesh = LINEMOD/phone/phone.ply
tr_range = LINEMOD/phone/training_range.txt
name = phone
diam = 0.213

256
py2/cfg/yolo-pose-pre.cfg Normal file
Просмотреть файл

@ -0,0 +1,256 @@
[net]
batch=32
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 80200
policy=steps
# steps=-1,500,20000,30000
# steps=-1,180,360,540
steps=-1,50,1000,2000
scales=0.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[convolutional]
batch_normalize=1
size=1
stride=1
pad=1
filters=64
activation=leaky
[reorg]
stride=2
[route]
layers=-1,-4
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
# filters=125
filters=32
activation=linear
[region]
anchors =
bias_match=1
classes=13
coords=18
num=1
softmax=1
jitter=.3
rescore=1
object_scale=0
noobject_scale=0
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=1

255
py2/cfg/yolo-pose.cfg Normal file
Просмотреть файл

@ -0,0 +1,255 @@
[net]
batch=32
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 80200
policy=steps
# steps=-1,500,20000,30000
steps=-1,50,3000,6000
scales=0.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[convolutional]
batch_normalize=1
size=1
stride=1
pad=1
filters=64
activation=leaky
[reorg]
stride=2
[route]
layers=-1,-4
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
# for a custom dataset, filters should be equal to (num_coords + num_classes + 1 conf value) * num_anchors
filters=20
activation=linear
[region]
anchors =
bias_match=1
classes=1
coords=18
num=1
softmax=1
jitter=.3
rescore=1
object_scale=5
noobject_scale=0.1
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=1

391
py2/darknet.py Normal file
Просмотреть файл

@ -0,0 +1,391 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from region_loss import RegionLoss
from cfg import *
class MaxPoolStride1(nn.Module):
def __init__(self):
super(MaxPoolStride1, self).__init__()
def forward(self, x):
x = F.max_pool2d(F.pad(x, (0,1,0,1), mode='replicate'), 2, stride=1)
return x
class Reorg(nn.Module):
def __init__(self, stride=2):
super(Reorg, self).__init__()
self.stride = stride
def forward(self, x):
stride = self.stride
assert(x.data.dim() == 4)
B = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
assert(H % stride == 0)
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view(B, C, H/hs, hs, W/ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H/hs*W/ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H/hs, W/ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H/hs, W/ws)
return x
class GlobalAvgPool2d(nn.Module):
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
N = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
x = F.avg_pool2d(x, (H, W))
x = x.view(N, C)
return x
# for route and shortcut
class EmptyModule(nn.Module):
def __init__(self):
super(EmptyModule, self).__init__()
def forward(self, x):
return x
# support route shortcut and reorg
class Darknet(nn.Module):
def __init__(self, cfgfile):
super(Darknet, self).__init__()
self.blocks = parse_cfg(cfgfile)
self.models = self.create_network(self.blocks) # merge conv, bn,leaky
self.loss = self.models[len(self.models)-1]
self.width = int(self.blocks[0]['width'])
self.height = int(self.blocks[0]['height'])
if self.blocks[(len(self.blocks)-1)]['type'] == 'region':
self.anchors = self.loss.anchors
self.num_anchors = self.loss.num_anchors
self.anchor_step = self.loss.anchor_step
self.num_classes = self.loss.num_classes
self.header = torch.IntTensor([0,0,0,0])
self.seen = 0
self.iter = 0
def forward(self, x):
ind = -2
self.loss = None
outputs = dict()
for block in self.blocks:
ind = ind + 1
#if ind > 0:
# return x
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional' or block['type'] == 'maxpool' or block['type'] == 'reorg' or block['type'] == 'avgpool' or block['type'] == 'softmax' or block['type'] == 'connected':
x = self.models[ind](x)
outputs[ind] = x
elif block['type'] == 'route':
layers = block['layers'].split(',')
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
x = outputs[layers[0]]
outputs[ind] = x
elif len(layers) == 2:
x1 = outputs[layers[0]]
x2 = outputs[layers[1]]
x = torch.cat((x1,x2),1)
outputs[ind] = x
elif block['type'] == 'shortcut':
from_layer = int(block['from'])
activation = block['activation']
from_layer = from_layer if from_layer > 0 else from_layer + ind
x1 = outputs[from_layer]
x2 = outputs[ind-1]
x = x1 + x2
if activation == 'leaky':
x = F.leaky_relu(x, 0.1, inplace=True)
elif activation == 'relu':
x = F.relu(x, inplace=True)
outputs[ind] = x
elif block['type'] == 'region':
continue
if self.loss:
self.loss = self.loss + self.models[ind](x)
else:
self.loss = self.models[ind](x)
outputs[ind] = None
elif block['type'] == 'cost':
continue
else:
print('unknown type %s' % (block['type']))
return x
def print_network(self):
print_cfg(self.blocks)
def create_network(self, blocks):
models = nn.ModuleList()
prev_filters = 3
out_filters =[]
conv_id = 0
for block in blocks:
if block['type'] == 'net':
prev_filters = int(block['channels'])
continue
elif block['type'] == 'convolutional':
conv_id = conv_id + 1
batch_normalize = int(block['batch_normalize'])
filters = int(block['filters'])
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
activation = block['activation']
model = nn.Sequential()
if batch_normalize:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=False))
model.add_module('bn{0}'.format(conv_id), nn.BatchNorm2d(filters, eps=1e-4))
#model.add_module('bn{0}'.format(conv_id), BN2d(filters))
else:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad))
if activation == 'leaky':
model.add_module('leaky{0}'.format(conv_id), nn.LeakyReLU(0.1, inplace=True))
elif activation == 'relu':
model.add_module('relu{0}'.format(conv_id), nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
if stride > 1:
model = nn.MaxPool2d(pool_size, stride)
else:
model = MaxPoolStride1()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'avgpool':
model = GlobalAvgPool2d()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'softmax':
model = nn.Softmax()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'cost':
if block['_type'] == 'sse':
model = nn.MSELoss(size_average=True)
elif block['_type'] == 'L1':
model = nn.L1Loss(size_average=True)
elif block['_type'] == 'smooth':
model = nn.SmoothL1Loss(size_average=True)
out_filters.append(1)
models.append(model)
elif block['type'] == 'reorg':
stride = int(block['stride'])
prev_filters = stride * stride * prev_filters
out_filters.append(prev_filters)
models.append(Reorg(stride))
elif block['type'] == 'route':
layers = block['layers'].split(',')
ind = len(models)
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
prev_filters = out_filters[layers[0]]
elif len(layers) == 2:
assert(layers[0] == ind - 1)
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'shortcut':
ind = len(models)
prev_filters = out_filters[ind-1]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'connected':
filters = int(block['output'])
if block['activation'] == 'linear':
model = nn.Linear(prev_filters, filters)
elif block['activation'] == 'leaky':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.LeakyReLU(0.1, inplace=True))
elif block['activation'] == 'relu':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'region':
loss = RegionLoss()
anchors = block['anchors'].split(',')
if anchors == ['']:
loss.anchors = []
else:
loss.anchors = [float(i) for i in anchors]
loss.num_classes = int(block['classes'])
loss.num_anchors = int(block['num'])
loss.anchor_step = len(loss.anchors)/loss.num_anchors
loss.object_scale = float(block['object_scale'])
loss.noobject_scale = float(block['noobject_scale'])
loss.class_scale = float(block['class_scale'])
loss.coord_scale = float(block['coord_scale'])
out_filters.append(prev_filters)
models.append(loss)
else:
print('unknown type %s' % (block['type']))
return models
def load_weights(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
for block in self.blocks:
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def load_weights_until_last(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
blocklen = len(self.blocks)
for i in range(blocklen-2):
block = self.blocks[i]
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def save_weights(self, outfile, cutoff=0):
if cutoff <= 0:
cutoff = len(self.blocks)-1
fp = open(outfile, 'wb')
self.header[3] = self.seen
header = self.header
header.numpy().tofile(fp)
ind = -1
for blockId in range(1, cutoff+1):
ind = ind + 1
block = self.blocks[blockId]
if block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
save_conv_bn(fp, model[0], model[1])
else:
save_conv(fp, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
save_fc(fc, model)
else:
save_fc(fc, model[0])
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
fp.close()

101
py2/dataset.py Normal file
Просмотреть файл

@ -0,0 +1,101 @@
#!/usr/bin/python
# encoding: utf-8
import os
import random
from PIL import Image
import numpy as np
from image import *
import torch
from torch.utils.data import Dataset
from utils import read_truths_args, read_truths, get_all_files
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, bg_file_names=None):
with open(root, 'r') as file:
self.lines = file.readlines()
if shuffle:
random.shuffle(self.lines)
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
self.train = train
self.shape = shape
self.seen = seen
self.batch_size = batch_size
self.num_workers = num_workers
self.bg_file_names = bg_file_names
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
imgpath = self.lines[index].rstrip()
if self.train and index % 32== 0:
if self.seen < 400*32:
width = 13*32
self.shape = (width, width)
elif self.seen < 800*32:
width = (random.randint(0,7) + 13)*32
self.shape = (width, width)
elif self.seen < 1200*32:
width = (random.randint(0,9) + 12)*32
self.shape = (width, width)
elif self.seen < 1600*32:
width = (random.randint(0,11) + 11)*32
self.shape = (width, width)
elif self.seen < 2000*32:
width = (random.randint(0,13) + 10)*32
self.shape = (width, width)
elif self.seen < 2400*32:
width = (random.randint(0,15) + 9)*32
self.shape = (width, width)
elif self.seen < 3000*32:
width = (random.randint(0,17) + 8)*32
self.shape = (width, width)
else: # self.seen < 20000*64:
width = (random.randint(0,19) + 7)*32
self.shape = (width, width)
if self.train:
jitter = 0.2
hue = 0.1
saturation = 1.5
exposure = 1.5
# Get background image path
random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
bgpath = self.bg_file_names[random_bg_index]
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
label = torch.from_numpy(label)
else:
img = Image.open(imgpath).convert('RGB')
if self.shape:
img = img.resize(self.shape)
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*21)
if os.path.getsize(labpath):
ow, oh = img.size
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
tmp = tmp.view(-1)
tsz = tmp.numel()
if tsz > 50*21:
label = tmp[0:50*21]
elif tsz > 0:
label[0:tsz] = tmp
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
self.seen = self.seen + self.num_workers
return (img, label)

184
py2/image.py Normal file
Просмотреть файл

@ -0,0 +1,184 @@
#!/usr/bin/python
# encoding: utf-8
import random
import os
from PIL import Image, ImageChops, ImageMath
import numpy as np
def scale_image_channel(im, c, v):
cs = list(im.split())
cs[c] = cs[c].point(lambda i: i * v)
out = Image.merge(im.mode, tuple(cs))
return out
def distort_image(im, hue, sat, val):
im = im.convert('HSV')
cs = list(im.split())
cs[1] = cs[1].point(lambda i: i * sat)
cs[2] = cs[2].point(lambda i: i * val)
def change_hue(x):
x += hue*255
if x > 255:
x -= 255
if x < 0:
x += 255
return x
cs[0] = cs[0].point(change_hue)
im = Image.merge(im.mode, tuple(cs))
im = im.convert('RGB')
return im
def rand_scale(s):
scale = random.uniform(1, s)
if(random.randint(1,10000)%2):
return scale
return 1./scale
def random_distort_image(im, hue, saturation, exposure):
dhue = random.uniform(-hue, hue)
dsat = rand_scale(saturation)
dexp = rand_scale(exposure)
res = distort_image(im, dhue, dsat, dexp)
return res
def data_augmentation(img, shape, jitter, hue, saturation, exposure):
ow, oh = img.size
dw =int(ow*jitter)
dh =int(oh*jitter)
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
sx = float(swidth) / ow
sy = float(sheight) / oh
flip = random.randint(1,10000)%2
cropped = img.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
dx = (float(pleft)/ow)/sx
dy = (float(ptop) /oh)/sy
sized = cropped.resize(shape)
img = random_distort_image(sized, hue, saturation, exposure)
return img, flip, dx,dy,sx,sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
max_boxes = 50
label = np.zeros((max_boxes,21))
if os.path.getsize(labpath):
bs = np.loadtxt(labpath)
if bs is None:
return label
bs = np.reshape(bs, (-1, 21))
cc = 0
for i in range(bs.shape[0]):
x0 = bs[i][1]
y0 = bs[i][2]
x1 = bs[i][3]
y1 = bs[i][4]
x2 = bs[i][5]
y2 = bs[i][6]
x3 = bs[i][7]
y3 = bs[i][8]
x4 = bs[i][9]
y4 = bs[i][10]
x5 = bs[i][11]
y5 = bs[i][12]
x6 = bs[i][13]
y6 = bs[i][14]
x7 = bs[i][15]
y7 = bs[i][16]
x8 = bs[i][17]
y8 = bs[i][18]
x0 = min(0.999, max(0, x0 * sx - dx))
y0 = min(0.999, max(0, y0 * sy - dy))
x1 = min(0.999, max(0, x1 * sx - dx))
y1 = min(0.999, max(0, y1 * sy - dy))
x2 = min(0.999, max(0, x2 * sx - dx))
y2 = min(0.999, max(0, y2 * sy - dy))
x3 = min(0.999, max(0, x3 * sx - dx))
y3 = min(0.999, max(0, y3 * sy - dy))
x4 = min(0.999, max(0, x4 * sx - dx))
y4 = min(0.999, max(0, y4 * sy - dy))
x5 = min(0.999, max(0, x5 * sx - dx))
y5 = min(0.999, max(0, y5 * sy - dy))
x6 = min(0.999, max(0, x6 * sx - dx))
y6 = min(0.999, max(0, y6 * sy - dy))
x7 = min(0.999, max(0, x7 * sx - dx))
y7 = min(0.999, max(0, y7 * sy - dy))
x8 = min(0.999, max(0, x8 * sx - dx))
y8 = min(0.999, max(0, y8 * sy - dy))
bs[i][1] = x0
bs[i][2] = y0
bs[i][3] = x1
bs[i][4] = y1
bs[i][5] = x2
bs[i][6] = y2
bs[i][7] = x3
bs[i][8] = y3
bs[i][9] = x4
bs[i][10] = y4
bs[i][11] = x5
bs[i][12] = y5
bs[i][13] = x6
bs[i][14] = y6
bs[i][15] = x7
bs[i][16] = y7
bs[i][17] = x8
bs[i][18] = y8
label[cc] = bs[i]
cc += 1
if cc >= 50:
break
label = np.reshape(label, (-1))
return label
def change_background(img, mask, bg):
# oh = img.height
# ow = img.width
ow, oh = img.size
bg = bg.resize((ow, oh)).convert('RGB')
imcs = list(img.split())
bgcs = list(bg.split())
maskcs = list(mask.split())
fics = list(Image.new(img.mode, img.size).split())
for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c + b * d", a=imcs[c], b=bgcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(img.mode, tuple(fics))
return out
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
## data augmentation
img = Image.open(imgpath).convert('RGB')
mask = Image.open(maskpath).convert('RGB')
bg = Image.open(bgpath).convert('RGB')
img = change_background(img, mask, bg)
img,flip,dx,dy,sx,sy = data_augmentation(img, shape, jitter, hue, saturation, exposure)
ow, oh = img.size
label = fill_truth_detection(labpath, ow, oh, flip, dx, dy, 1./sx, 1./sy)
return img,label

Просмотреть файл

@ -0,0 +1,13 @@
#### Label file creation
You could follow these steps to create labels for your custom dataset:
1. Get the 3D bounding box surrounding the 3D object model. We use the already provided 3D object model for the LINEMOD dataset to get the 3D bounding box. If you would like to create a 3D model for a custom object, you can refer to the Section 3.5 of the following paper and the references therein: http://cmp.felk.cvut.cz/~hodanto2/data/hodan2017tless.pdf
2. Define the 8 corners of the 3D bounding box and the centroid of the 3D object model as the virtual keypoints of the object. 8 corners correspond to the [[min_x, min_y, min_z], [min_x, min_y, max_z], [min_x, max_y, min_z], [min_x, max_y, max_z], [max_x, min_y, min_z], [max_x, min_y, max_z], [max_x, max_y, min_z], [max_x, max_y, max_z]] positions of the 3D object model, and the centroid corresponds to the [0, 0, 0] position.
3. Project the 3D keypoints to 2D. You can use the [compute_projection](https://github.com/Microsoft/singleshotpose/blob/master/utils.py#L39:L44) function that we provide to project the 3D points in 2D. You would need to know the intrinsic calibration matrix of the camera and the ground-truth rotation and translation to project the 3D points in 2D. Typically, obtaining ground-truth Rt transformation matrices requires a manual and intrusive annotation effort. For an example of how to acquire ground-truth data for 6D pose estimation, please refer to the Section 3.1 of the [paper](http://cmp.felk.cvut.cz/~hodanto2/data/hodan2017tless.pdf) describing the T-LESS dataset.
4. Compute the width and height of a 2D rectangle tightly fitted to a masked region around the object. If you have the 2D bounding box information (e.g. width and height) for the custom object that you have, you can use those values in your label file. In practice, however, we fit a tight bounding box to the 8 corners of the projected 3D bounding box and use the width and height of that bounding box to represent these values.
5. Create an array consisting of the class, 2D keypoint location and the range information and write it into a text file. The label file is organized in the following order. 1st number: class label, 2nd number: x0 (x-coordinate of the centroid), 3rd number: y0 (y-coordinate of the centroid), 4th number: x1 (x-coordinate of the first corner), 5th number: y1 (y-coordinate of the first corner), ..., 18th number: x8 (x-coordinate of the eighth corner), 19th number: y8 (y-coordinate of the eighth corner), 20th number: x range, 21st number: y range.

Просмотреть файл

Просмотреть файл

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/ape/test_occlusion.txt
mesh = ../LINEMOD/ape/ape.ply
backup = backup_multi
name = ape
diam = 0.103

Просмотреть файл

@ -0,0 +1,7 @@
train = ../LINEMOD/benchvise/train.txt
valid = ../LINEMOD/benchvise/test.txt
backup = backup_multi
mesh = ../LINEMOD/benchvise/benchvise.ply
tr_range = ../LINEMOD/benchvise/training_range.txt
name = benchvise
diam = 0.286908

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/can/test_occlusion.txt
mesh = ../LINEMOD/can/can.ply
backup = backup_multi
name = can
diam = 0.202

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/cat/test_occlusion.txt
mesh = ../LINEMOD/cat/cat.ply
backup = backup_multi
name = cat
diam = 0.155

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/driller/test_occlusion.txt
mesh = ../LINEMOD/driller/driller.ply
backup = backup_multi
name = driller
diam = 0.262

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/duck/test_occlusion.txt
mesh = ../LINEMOD/duck/duck.ply
backup = backup_multi
name = duck
diam = 0.109

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/eggbox/test_occlusion.txt
mesh = ../LINEMOD/eggbox/eggbox.ply
backup = backup_multi
name = eggbox
diam = 0.176364

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/glue/test_occlusion.txt
mesh = ../LINEMOD/glue/glue.ply
backup = backup_multi
name = glue
diam = 0.176

Просмотреть файл

@ -0,0 +1,5 @@
valid = ../LINEMOD/holepuncher/test_occlusion.txt
mesh = ../LINEMOD/holepuncher/holepuncher.ply
backup = backup_multi
name = holepuncher
diam = 0.162

Просмотреть файл

@ -0,0 +1,23 @@
train = cfg/train_occlusion.txt
valid1 = ../LINEMOD/ape/test_occlusion.txt
valid4 = ../LINEMOD/can/test_occlusion.txt
valid5 = ../LINEMOD/cat/test_occlusion.txt
valid6 = ../LINEMOD/driller/test_occlusion.txt
valid7 = ../LINEMOD/duck/test_occlusion.txt
valid9 = ../LINEMOD/glue/test_occlusion.txt
valid10 = ../LINEMOD/holepuncher/test_occlusion.txt
backup = backup_multi
mesh1 = ../LINEMOD/ape/ape.ply
mesh4 = ../LINEMOD/can/can.ply
mesh5 = ../LINEMOD/cat/cat.ply
mesh6 = ../LINEMOD/driller/driller.ply
mesh7 = ../LINEMOD/duck/duck.ply
mesh9 = ../LINEMOD/glue/glue.ply
mesh10 = ../LINEMOD/holepuncher/holepuncher.ply
diam1 = 0.103
diam4 = 0.202
diam5 = 0.155
diam6 = 0.262
diam7 = 0.109
diam9 = 0.176
diam10 = 0.162

Просмотреть файл

@ -0,0 +1,183 @@
../LINEMOD/benchvise/JPEGImages/000024.jpg
../LINEMOD/benchvise/JPEGImages/000030.jpg
../LINEMOD/benchvise/JPEGImages/000045.jpg
../LINEMOD/benchvise/JPEGImages/000053.jpg
../LINEMOD/benchvise/JPEGImages/000063.jpg
../LINEMOD/benchvise/JPEGImages/000065.jpg
../LINEMOD/benchvise/JPEGImages/000071.jpg
../LINEMOD/benchvise/JPEGImages/000072.jpg
../LINEMOD/benchvise/JPEGImages/000076.jpg
../LINEMOD/benchvise/JPEGImages/000078.jpg
../LINEMOD/benchvise/JPEGImages/000091.jpg
../LINEMOD/benchvise/JPEGImages/000092.jpg
../LINEMOD/benchvise/JPEGImages/000095.jpg
../LINEMOD/benchvise/JPEGImages/000099.jpg
../LINEMOD/benchvise/JPEGImages/000103.jpg
../LINEMOD/benchvise/JPEGImages/000106.jpg
../LINEMOD/benchvise/JPEGImages/000116.jpg
../LINEMOD/benchvise/JPEGImages/000123.jpg
../LINEMOD/benchvise/JPEGImages/000130.jpg
../LINEMOD/benchvise/JPEGImages/000134.jpg
../LINEMOD/benchvise/JPEGImages/000139.jpg
../LINEMOD/benchvise/JPEGImages/000146.jpg
../LINEMOD/benchvise/JPEGImages/000152.jpg
../LINEMOD/benchvise/JPEGImages/000153.jpg
../LINEMOD/benchvise/JPEGImages/000155.jpg
../LINEMOD/benchvise/JPEGImages/000157.jpg
../LINEMOD/benchvise/JPEGImages/000158.jpg
../LINEMOD/benchvise/JPEGImages/000161.jpg
../LINEMOD/benchvise/JPEGImages/000163.jpg
../LINEMOD/benchvise/JPEGImages/000167.jpg
../LINEMOD/benchvise/JPEGImages/000172.jpg
../LINEMOD/benchvise/JPEGImages/000174.jpg
../LINEMOD/benchvise/JPEGImages/000183.jpg
../LINEMOD/benchvise/JPEGImages/000200.jpg
../LINEMOD/benchvise/JPEGImages/000214.jpg
../LINEMOD/benchvise/JPEGImages/000221.jpg
../LINEMOD/benchvise/JPEGImages/000226.jpg
../LINEMOD/benchvise/JPEGImages/000235.jpg
../LINEMOD/benchvise/JPEGImages/000239.jpg
../LINEMOD/benchvise/JPEGImages/000243.jpg
../LINEMOD/benchvise/JPEGImages/000271.jpg
../LINEMOD/benchvise/JPEGImages/000274.jpg
../LINEMOD/benchvise/JPEGImages/000277.jpg
../LINEMOD/benchvise/JPEGImages/000286.jpg
../LINEMOD/benchvise/JPEGImages/000291.jpg
../LINEMOD/benchvise/JPEGImages/000294.jpg
../LINEMOD/benchvise/JPEGImages/000302.jpg
../LINEMOD/benchvise/JPEGImages/000307.jpg
../LINEMOD/benchvise/JPEGImages/000314.jpg
../LINEMOD/benchvise/JPEGImages/000320.jpg
../LINEMOD/benchvise/JPEGImages/000324.jpg
../LINEMOD/benchvise/JPEGImages/000347.jpg
../LINEMOD/benchvise/JPEGImages/000350.jpg
../LINEMOD/benchvise/JPEGImages/000355.jpg
../LINEMOD/benchvise/JPEGImages/000364.jpg
../LINEMOD/benchvise/JPEGImages/000367.jpg
../LINEMOD/benchvise/JPEGImages/000369.jpg
../LINEMOD/benchvise/JPEGImages/000376.jpg
../LINEMOD/benchvise/JPEGImages/000377.jpg
../LINEMOD/benchvise/JPEGImages/000379.jpg
../LINEMOD/benchvise/JPEGImages/000383.jpg
../LINEMOD/benchvise/JPEGImages/000384.jpg
../LINEMOD/benchvise/JPEGImages/000387.jpg
../LINEMOD/benchvise/JPEGImages/000394.jpg
../LINEMOD/benchvise/JPEGImages/000402.jpg
../LINEMOD/benchvise/JPEGImages/000406.jpg
../LINEMOD/benchvise/JPEGImages/000410.jpg
../LINEMOD/benchvise/JPEGImages/000413.jpg
../LINEMOD/benchvise/JPEGImages/000422.jpg
../LINEMOD/benchvise/JPEGImages/000425.jpg
../LINEMOD/benchvise/JPEGImages/000430.jpg
../LINEMOD/benchvise/JPEGImages/000434.jpg
../LINEMOD/benchvise/JPEGImages/000441.jpg
../LINEMOD/benchvise/JPEGImages/000446.jpg
../LINEMOD/benchvise/JPEGImages/000451.jpg
../LINEMOD/benchvise/JPEGImages/000456.jpg
../LINEMOD/benchvise/JPEGImages/000461.jpg
../LINEMOD/benchvise/JPEGImages/000465.jpg
../LINEMOD/benchvise/JPEGImages/000471.jpg
../LINEMOD/benchvise/JPEGImages/000480.jpg
../LINEMOD/benchvise/JPEGImages/000483.jpg
../LINEMOD/benchvise/JPEGImages/000493.jpg
../LINEMOD/benchvise/JPEGImages/000496.jpg
../LINEMOD/benchvise/JPEGImages/000498.jpg
../LINEMOD/benchvise/JPEGImages/000507.jpg
../LINEMOD/benchvise/JPEGImages/000512.jpg
../LINEMOD/benchvise/JPEGImages/000525.jpg
../LINEMOD/benchvise/JPEGImages/000527.jpg
../LINEMOD/benchvise/JPEGImages/000532.jpg
../LINEMOD/benchvise/JPEGImages/000533.jpg
../LINEMOD/benchvise/JPEGImages/000534.jpg
../LINEMOD/benchvise/JPEGImages/000539.jpg
../LINEMOD/benchvise/JPEGImages/000554.jpg
../LINEMOD/benchvise/JPEGImages/000556.jpg
../LINEMOD/benchvise/JPEGImages/000568.jpg
../LINEMOD/benchvise/JPEGImages/000571.jpg
../LINEMOD/benchvise/JPEGImages/000573.jpg
../LINEMOD/benchvise/JPEGImages/000576.jpg
../LINEMOD/benchvise/JPEGImages/000598.jpg
../LINEMOD/benchvise/JPEGImages/000603.jpg
../LINEMOD/benchvise/JPEGImages/000604.jpg
../LINEMOD/benchvise/JPEGImages/000609.jpg
../LINEMOD/benchvise/JPEGImages/000627.jpg
../LINEMOD/benchvise/JPEGImages/000635.jpg
../LINEMOD/benchvise/JPEGImages/000641.jpg
../LINEMOD/benchvise/JPEGImages/000649.jpg
../LINEMOD/benchvise/JPEGImages/000653.jpg
../LINEMOD/benchvise/JPEGImages/000656.jpg
../LINEMOD/benchvise/JPEGImages/000659.jpg
../LINEMOD/benchvise/JPEGImages/000668.jpg
../LINEMOD/benchvise/JPEGImages/000676.jpg
../LINEMOD/benchvise/JPEGImages/000692.jpg
../LINEMOD/benchvise/JPEGImages/000697.jpg
../LINEMOD/benchvise/JPEGImages/000706.jpg
../LINEMOD/benchvise/JPEGImages/000715.jpg
../LINEMOD/benchvise/JPEGImages/000717.jpg
../LINEMOD/benchvise/JPEGImages/000726.jpg
../LINEMOD/benchvise/JPEGImages/000735.jpg
../LINEMOD/benchvise/JPEGImages/000744.jpg
../LINEMOD/benchvise/JPEGImages/000747.jpg
../LINEMOD/benchvise/JPEGImages/000752.jpg
../LINEMOD/benchvise/JPEGImages/000758.jpg
../LINEMOD/benchvise/JPEGImages/000760.jpg
../LINEMOD/benchvise/JPEGImages/000772.jpg
../LINEMOD/benchvise/JPEGImages/000775.jpg
../LINEMOD/benchvise/JPEGImages/000780.jpg
../LINEMOD/benchvise/JPEGImages/000785.jpg
../LINEMOD/benchvise/JPEGImages/000800.jpg
../LINEMOD/benchvise/JPEGImages/000802.jpg
../LINEMOD/benchvise/JPEGImages/000828.jpg
../LINEMOD/benchvise/JPEGImages/000837.jpg
../LINEMOD/benchvise/JPEGImages/000842.jpg
../LINEMOD/benchvise/JPEGImages/000845.jpg
../LINEMOD/benchvise/JPEGImages/000847.jpg
../LINEMOD/benchvise/JPEGImages/000850.jpg
../LINEMOD/benchvise/JPEGImages/000859.jpg
../LINEMOD/benchvise/JPEGImages/000875.jpg
../LINEMOD/benchvise/JPEGImages/000880.jpg
../LINEMOD/benchvise/JPEGImages/000883.jpg
../LINEMOD/benchvise/JPEGImages/000891.jpg
../LINEMOD/benchvise/JPEGImages/000892.jpg
../LINEMOD/benchvise/JPEGImages/000915.jpg
../LINEMOD/benchvise/JPEGImages/000916.jpg
../LINEMOD/benchvise/JPEGImages/000923.jpg
../LINEMOD/benchvise/JPEGImages/000931.jpg
../LINEMOD/benchvise/JPEGImages/000933.jpg
../LINEMOD/benchvise/JPEGImages/000941.jpg
../LINEMOD/benchvise/JPEGImages/000945.jpg
../LINEMOD/benchvise/JPEGImages/000954.jpg
../LINEMOD/benchvise/JPEGImages/000959.jpg
../LINEMOD/benchvise/JPEGImages/000964.jpg
../LINEMOD/benchvise/JPEGImages/000975.jpg
../LINEMOD/benchvise/JPEGImages/000987.jpg
../LINEMOD/benchvise/JPEGImages/001002.jpg
../LINEMOD/benchvise/JPEGImages/001014.jpg
../LINEMOD/benchvise/JPEGImages/001020.jpg
../LINEMOD/benchvise/JPEGImages/001024.jpg
../LINEMOD/benchvise/JPEGImages/001038.jpg
../LINEMOD/benchvise/JPEGImages/001040.jpg
../LINEMOD/benchvise/JPEGImages/001048.jpg
../LINEMOD/benchvise/JPEGImages/001066.jpg
../LINEMOD/benchvise/JPEGImages/001071.jpg
../LINEMOD/benchvise/JPEGImages/001081.jpg
../LINEMOD/benchvise/JPEGImages/001084.jpg
../LINEMOD/benchvise/JPEGImages/001088.jpg
../LINEMOD/benchvise/JPEGImages/001102.jpg
../LINEMOD/benchvise/JPEGImages/001103.jpg
../LINEMOD/benchvise/JPEGImages/001106.jpg
../LINEMOD/benchvise/JPEGImages/001112.jpg
../LINEMOD/benchvise/JPEGImages/001121.jpg
../LINEMOD/benchvise/JPEGImages/001129.jpg
../LINEMOD/benchvise/JPEGImages/001133.jpg
../LINEMOD/benchvise/JPEGImages/001135.jpg
../LINEMOD/benchvise/JPEGImages/001136.jpg
../LINEMOD/benchvise/JPEGImages/001157.jpg
../LINEMOD/benchvise/JPEGImages/001159.jpg
../LINEMOD/benchvise/JPEGImages/001163.jpg
../LINEMOD/benchvise/JPEGImages/001171.jpg
../LINEMOD/benchvise/JPEGImages/001172.jpg
../LINEMOD/benchvise/JPEGImages/001174.jpg
../LINEMOD/benchvise/JPEGImages/001191.jpg
../LINEMOD/benchvise/JPEGImages/001198.jpg
../LINEMOD/benchvise/JPEGImages/001205.jpg

Просмотреть файл

@ -0,0 +1,261 @@
[net]
# Testing
batch=64
subdivisions=8
# Training
# batch=64
# subdivisions=8
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 80200
policy=steps
steps=-1,500,20000,30000
# steps=-1,180,360,540
scales=0.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[convolutional]
batch_normalize=1
size=1
stride=1
pad=1
filters=64
activation=leaky
[reorg]
stride=2
[route]
layers=-1,-4
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
# filters=125
filters=160
activation=linear
[region]
# anchors = 1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071
anchors = 1.4820, 2.2412, 2.0501, 3.1265, 2.3946, 4.6891, 3.1018, 3.9910, 3.4879, 5.8851
bias_match=1
classes=13
coords=18
num=5
softmax=1
jitter=.3
rescore=1
object_scale=0
noobject_scale=0
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=1

Просмотреть файл

@ -0,0 +1,261 @@
[net]
# Testing
batch=32
subdivisions=8
# Training
# batch=64
# subdivisions=8
height=416
width=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 80200
policy=steps
steps=-1,100,20000,30000
# steps=-1,180,360,540
scales=0.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
#######
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[route]
layers=-9
[convolutional]
batch_normalize=1
size=1
stride=1
pad=1
filters=64
activation=leaky
[reorg]
stride=2
[route]
layers=-1,-4
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
# filters=125
filters=160
activation=linear
[region]
# anchors = 1.3221, 1.73145, 3.19275, 4.00944, 5.05587, 8.09892, 9.47112, 4.84053, 11.2364, 10.0071
anchors = 1.4820, 2.2412, 2.0501, 3.1265, 2.3946, 4.6891, 3.1018, 3.9910, 3.4879, 5.8851
bias_match=1
classes=13
coords=18
num=5
softmax=1
jitter=.3
rescore=1
object_scale=5
noobject_scale=0.1
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=1

Просмотреть файл

@ -0,0 +1,388 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from region_loss_multi import RegionLoss
from cfg import *
class MaxPoolStride1(nn.Module):
def __init__(self):
super(MaxPoolStride1, self).__init__()
def forward(self, x):
x = F.max_pool2d(F.pad(x, (0,1,0,1), mode='replicate'), 2, stride=1)
return x
class Reorg(nn.Module):
def __init__(self, stride=2):
super(Reorg, self).__init__()
self.stride = stride
def forward(self, x):
stride = self.stride
assert(x.data.dim() == 4)
B = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
assert(H % stride == 0)
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view(B, C, H/hs, hs, W/ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H/hs*W/ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H/hs, W/ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H/hs, W/ws)
return x
class GlobalAvgPool2d(nn.Module):
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
N = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
x = F.avg_pool2d(x, (H, W))
x = x.view(N, C)
return x
# for route and shortcut
class EmptyModule(nn.Module):
def __init__(self):
super(EmptyModule, self).__init__()
def forward(self, x):
return x
# support route shortcut and reorg
class Darknet(nn.Module):
def __init__(self, cfgfile):
super(Darknet, self).__init__()
self.blocks = parse_cfg(cfgfile)
self.models = self.create_network(self.blocks) # merge conv, bn,leaky
self.loss = self.models[len(self.models)-1]
self.width = int(self.blocks[0]['width'])
self.height = int(self.blocks[0]['height'])
if self.blocks[(len(self.blocks)-1)]['type'] == 'region':
self.anchors = self.loss.anchors
self.num_anchors = self.loss.num_anchors
self.anchor_step = self.loss.anchor_step
self.num_classes = self.loss.num_classes
self.header = torch.IntTensor([0,0,0,0])
self.seen = 0
self.iter = 0
def forward(self, x):
ind = -2
self.loss = None
outputs = dict()
for block in self.blocks:
ind = ind + 1
#if ind > 0:
# return x
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional' or block['type'] == 'maxpool' or block['type'] == 'reorg' or block['type'] == 'avgpool' or block['type'] == 'softmax' or block['type'] == 'connected':
x = self.models[ind](x)
outputs[ind] = x
elif block['type'] == 'route':
layers = block['layers'].split(',')
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
x = outputs[layers[0]]
outputs[ind] = x
elif len(layers) == 2:
x1 = outputs[layers[0]]
x2 = outputs[layers[1]]
x = torch.cat((x1,x2),1)
outputs[ind] = x
elif block['type'] == 'shortcut':
from_layer = int(block['from'])
activation = block['activation']
from_layer = from_layer if from_layer > 0 else from_layer + ind
x1 = outputs[from_layer]
x2 = outputs[ind-1]
x = x1 + x2
if activation == 'leaky':
x = F.leaky_relu(x, 0.1, inplace=True)
elif activation == 'relu':
x = F.relu(x, inplace=True)
outputs[ind] = x
elif block['type'] == 'region':
continue
if self.loss:
self.loss = self.loss + self.models[ind](x)
else:
self.loss = self.models[ind](x)
outputs[ind] = None
elif block['type'] == 'cost':
continue
else:
print('unknown type %s' % (block['type']))
return x
def print_network(self):
print_cfg(self.blocks)
def create_network(self, blocks):
models = nn.ModuleList()
prev_filters = 3
out_filters =[]
conv_id = 0
for block in blocks:
if block['type'] == 'net':
prev_filters = int(block['channels'])
continue
elif block['type'] == 'convolutional':
conv_id = conv_id + 1
batch_normalize = int(block['batch_normalize'])
filters = int(block['filters'])
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
activation = block['activation']
model = nn.Sequential()
if batch_normalize:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=False))
model.add_module('bn{0}'.format(conv_id), nn.BatchNorm2d(filters, eps=1e-4))
#model.add_module('bn{0}'.format(conv_id), BN2d(filters))
else:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad))
if activation == 'leaky':
model.add_module('leaky{0}'.format(conv_id), nn.LeakyReLU(0.1, inplace=True))
elif activation == 'relu':
model.add_module('relu{0}'.format(conv_id), nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
if stride > 1:
model = nn.MaxPool2d(pool_size, stride)
else:
model = MaxPoolStride1()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'avgpool':
model = GlobalAvgPool2d()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'softmax':
model = nn.Softmax()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'cost':
if block['_type'] == 'sse':
model = nn.MSELoss(size_average=True)
elif block['_type'] == 'L1':
model = nn.L1Loss(size_average=True)
elif block['_type'] == 'smooth':
model = nn.SmoothL1Loss(size_average=True)
out_filters.append(1)
models.append(model)
elif block['type'] == 'reorg':
stride = int(block['stride'])
prev_filters = stride * stride * prev_filters
out_filters.append(prev_filters)
models.append(Reorg(stride))
elif block['type'] == 'route':
layers = block['layers'].split(',')
ind = len(models)
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
prev_filters = out_filters[layers[0]]
elif len(layers) == 2:
assert(layers[0] == ind - 1)
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'shortcut':
ind = len(models)
prev_filters = out_filters[ind-1]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'connected':
filters = int(block['output'])
if block['activation'] == 'linear':
model = nn.Linear(prev_filters, filters)
elif block['activation'] == 'leaky':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.LeakyReLU(0.1, inplace=True))
elif block['activation'] == 'relu':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'region':
loss = RegionLoss()
anchors = block['anchors'].split(',')
loss.anchors = [float(i) for i in anchors]
loss.num_classes = int(block['classes'])
loss.num_anchors = int(block['num'])
loss.anchor_step = len(loss.anchors)/loss.num_anchors
loss.object_scale = float(block['object_scale'])
loss.noobject_scale = float(block['noobject_scale'])
loss.class_scale = float(block['class_scale'])
loss.coord_scale = float(block['coord_scale'])
out_filters.append(prev_filters)
models.append(loss)
else:
print('unknown type %s' % (block['type']))
return models
def load_weights(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
for block in self.blocks:
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def load_weights_until_last(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
blocklen = len(self.blocks)
for i in range(blocklen-2):
block = self.blocks[i]
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def save_weights(self, outfile, cutoff=0):
if cutoff <= 0:
cutoff = len(self.blocks)-1
fp = open(outfile, 'wb')
self.header[3] = self.seen
header = self.header
header.numpy().tofile(fp)
ind = -1
for blockId in range(1, cutoff+1):
ind = ind + 1
block = self.blocks[blockId]
if block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
save_conv_bn(fp, model[0], model[1])
else:
save_conv(fp, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
save_fc(fc, model)
else:
save_fc(fc, model[0])
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
fp.close()

Просмотреть файл

@ -0,0 +1,94 @@
#!/usr/bin/python
# encoding: utf-8
import os
import random
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from utils import read_truths_args, read_truths, get_all_files
from image_multi import *
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, transform=None, objclass=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, bg_file_names=None): # bg='/cvlabdata1/home/btekin/ope/data/office_bg'
with open(root, 'r') as file:
self.lines = file.readlines()
if shuffle:
random.shuffle(self.lines)
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
self.train = train
self.shape = shape
self.seen = seen
self.batch_size = batch_size
self.num_workers = num_workers
# self.bg_file_names = get_all_files(bg)
self.bg_file_names = bg_file_names
self.objclass = objclass
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
imgpath = self.lines[index].rstrip()
if self.train and index % 64== 0:
if self.seen < 4000*64:
width = 13*32
self.shape = (width, width)
elif self.seen < 8000*64:
width = (random.randint(0,3) + 13)*32
self.shape = (width, width)
elif self.seen < 12000*64:
width = (random.randint(0,5) + 12)*32
self.shape = (width, width)
elif self.seen < 16000*64:
width = (random.randint(0,7) + 11)*32
self.shape = (width, width)
else: # self.seen < 20000*64:
width = (random.randint(0,9) + 10)*32
self.shape = (width, width)
if self.train:
# jitter = 0.2
jitter = 0.1
hue = 0.05
saturation = 1.5
exposure = 1.5
# Get background image path
random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
bgpath = self.bg_file_names[random_bg_index]
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
label = torch.from_numpy(label)
else:
img = Image.open(imgpath).convert('RGB')
if self.shape:
img = img.resize(self.shape)
labpath = imgpath.replace('benchvise', self.objclass).replace('images', 'labels_occlusion').replace('JPEGImages', 'labels_occlusion').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*21)
if os.path.getsize(labpath):
ow, oh = img.size
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
tmp = tmp.view(-1)
tsz = tmp.numel()
if tsz > 50*21:
label = tmp[0:50*21]
elif tsz > 0:
label[0:tsz] = tmp
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
self.seen = self.seen + self.num_workers
return (img, label)

Просмотреть файл

@ -0,0 +1,450 @@
#!/usr/bin/python
# encoding: utf-8
import random
import os
from PIL import Image, ImageChops, ImageMath
import numpy as np
def load_data_detection_backup(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
## data augmentation
img = Image.open(imgpath).convert('RGB')
mask = Image.open(maskpath).convert('RGB')
bg = Image.open(bgpath).convert('RGB')
img = change_background(img, mask, bg)
img,flip,dx,dy,sx,sy = data_augmentation(img, shape, jitter, hue, saturation, exposure)
label = fill_truth_detection(labpath, img.width, img.height, flip, dx, dy, 1./sx, 1./sy)
return img,label
def get_add_objs(objname):
# Decide how many additional objects you will augment and what will be the other types of objects
if objname == 'ape':
add_objs = ['can', 'cat', 'duck', 'glue', 'holepuncher', 'iron', 'phone'] # eggbox
elif objname == 'benchvise':
add_objs = ['ape', 'can', 'cat', 'driller', 'duck', 'glue', 'holepuncher']
elif objname == 'cam':
add_objs = ['ape', 'benchvise', 'can', 'cat', 'driller', 'duck', 'holepuncher']
elif objname == 'can':
add_objs = ['ape', 'benchvise', 'cat', 'driller', 'duck', 'eggbox', 'holepuncher']
elif objname == 'cat':
add_objs = ['ape', 'can', 'duck', 'glue', 'holepuncher', 'eggbox', 'phone']
elif objname == 'driller':
add_objs = ['ape', 'benchvise', 'can', 'cat', 'duck', 'glue', 'holepuncher']
elif objname == 'duck':
add_objs = ['ape', 'can', 'cat', 'eggbox', 'glue', 'holepuncher', 'phone']
elif objname == 'eggbox':
add_objs = ['ape', 'benchvise', 'cam', 'can', 'cat', 'duck', 'glue', 'holepuncher']
elif objname == 'glue':
add_objs = ['ape', 'benchvise', 'cam', 'driller', 'duck', 'eggbox', 'holepuncher' ]
elif objname == 'holepuncher':
add_objs = ['benchvise', 'cam', 'can', 'cat', 'driller', 'duck', 'eggbox']
elif objname == 'iron':
add_objs = ['ape', 'benchvise', 'can', 'cat', 'driller', 'duck', 'glue']
elif objname == 'lamp':
add_objs = ['ape', 'benchvise', 'can', 'driller', 'eggbox', 'holepuncher', 'iron']
elif objname == 'phone':
add_objs = ['ape', 'benchvise', 'cam', 'can', 'driller', 'duck', 'holepuncher']
return add_objs
def mask_background(img, mask):
ow, oh = img.size
imcs = list(img.split())
maskcs = list(mask.split())
fics = list(Image.new(img.mode, img.size).split())
for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c", a=imcs[c], c=posmask).convert('L')
out = Image.merge(img.mode, tuple(fics))
return out
def scale_image_channel(im, c, v):
cs = list(im.split())
cs[c] = cs[c].point(lambda i: i * v)
out = Image.merge(im.mode, tuple(cs))
return out
def distort_image(im, hue, sat, val):
im = im.convert('HSV')
cs = list(im.split())
cs[1] = cs[1].point(lambda i: i * sat)
cs[2] = cs[2].point(lambda i: i * val)
def change_hue(x):
x += hue*255
if x > 255:
x -= 255
if x < 0:
x += 255
return x
cs[0] = cs[0].point(change_hue)
im = Image.merge(im.mode, tuple(cs))
im = im.convert('RGB')
#constrain_image(im)
return im
def rand_scale(s):
scale = random.uniform(1, s)
if(random.randint(1,10000)%2):
return scale
return 1./scale
def random_distort_image(im, hue, saturation, exposure):
dhue = random.uniform(-hue, hue)
dsat = rand_scale(saturation)
dexp = rand_scale(exposure)
res = distort_image(im, dhue, dsat, dexp)
return res
def data_augmentation(img, shape, jitter, hue, saturation, exposure):
oh = img.height
ow = img.width
dw =int(ow*jitter)
dh =int(oh*jitter)
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
sx = float(swidth) / ow
sy = float(sheight) / oh
flip = random.randint(1,10000)%2
cropped = img.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
dx = (float(pleft)/ow)/sx
dy = (float(ptop) /oh)/sy
sized = cropped.resize(shape)
if flip:
sized = sized.transpose(Image.FLIP_LEFT_RIGHT)
img = random_distort_image(sized, hue, saturation, exposure)
return img, flip, dx,dy,sx,sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
max_boxes = 50
label = np.zeros((max_boxes,21))
if os.path.getsize(labpath):
bs = np.loadtxt(labpath)
if bs is None:
return label
bs = np.reshape(bs, (-1, 21))
cc = 0
for i in range(bs.shape[0]):
x0 = bs[i][1]
y0 = bs[i][2]
x1 = bs[i][3]
y1 = bs[i][4]
x2 = bs[i][5]
y2 = bs[i][6]
x3 = bs[i][7]
y3 = bs[i][8]
x4 = bs[i][9]
y4 = bs[i][10]
x5 = bs[i][11]
y5 = bs[i][12]
x6 = bs[i][13]
y6 = bs[i][14]
x7 = bs[i][15]
y7 = bs[i][16]
x8 = bs[i][17]
y8 = bs[i][18]
x0 = min(0.999, max(0, x0 * sx - dx))
y0 = min(0.999, max(0, y0 * sy - dy))
x1 = min(0.999, max(0, x1 * sx - dx))
y1 = min(0.999, max(0, y1 * sy - dy))
x2 = min(0.999, max(0, x2 * sx - dx))
y2 = min(0.999, max(0, y2 * sy - dy))
x3 = min(0.999, max(0, x3 * sx - dx))
y3 = min(0.999, max(0, y3 * sy - dy))
x4 = min(0.999, max(0, x4 * sx - dx))
y4 = min(0.999, max(0, y4 * sy - dy))
x5 = min(0.999, max(0, x5 * sx - dx))
y5 = min(0.999, max(0, y5 * sy - dy))
x6 = min(0.999, max(0, x6 * sx - dx))
y6 = min(0.999, max(0, y6 * sy - dy))
x7 = min(0.999, max(0, x7 * sx - dx))
y7 = min(0.999, max(0, y7 * sy - dy))
x8 = min(0.999, max(0, x8 * sx - dx))
y8 = min(0.999, max(0, y8 * sy - dy))
bs[i][0] = bs[i][0]
bs[i][1] = x0
bs[i][2] = y0
bs[i][3] = x1
bs[i][4] = y1
bs[i][5] = x2
bs[i][6] = y2
bs[i][7] = x3
bs[i][8] = y3
bs[i][9] = x4
bs[i][10] = y4
bs[i][11] = x5
bs[i][12] = y5
bs[i][13] = x6
bs[i][14] = y6
bs[i][15] = x7
bs[i][16] = y7
bs[i][17] = x8
bs[i][18] = y8
xs = [x1, x2, x3, x4, x5, x6, x7, x8]
ys = [y1, y2, y3, y4, y5, y6, y7, y8]
min_x = min(xs);
max_x = max(xs);
min_y = min(ys);
max_y = max(ys);
bs[i][19] = max_x - min_x;
bs[i][20] = max_y - min_y;
if flip:
bs[i][1] = 0.999 - bs[i][1]
bs[i][3] = 0.999 - bs[i][3]
bs[i][5] = 0.999 - bs[i][5]
bs[i][7] = 0.999 - bs[i][7]
bs[i][9] = 0.999 - bs[i][9]
bs[i][11] = 0.999 - bs[i][11]
bs[i][13] = 0.999 - bs[i][13]
bs[i][15] = 0.999 - bs[i][15]
bs[i][17] = 0.999 - bs[i][17]
label[cc] = bs[i]
cc += 1
if cc >= 50:
break
label = np.reshape(label, (-1))
return label
def change_background(img, mask, bg):
ow, oh = img.size
bg = bg.resize((ow, oh)).convert('RGB')
imcs = list(img.split())
bgcs = list(bg.split())
maskcs = list(mask.split())
fics = list(Image.new(img.mode, img.size).split())
for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c + b * d", a=imcs[c], b=bgcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(img.mode, tuple(fics))
return out
def shifted_data_augmentation_with_mask(img, mask, shape, jitter, hue, saturation, exposure):
ow, oh = img.size
dw =int(ow*jitter)
dh =int(oh*jitter)
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
sx = float(swidth) / ow
sy = float(sheight) / oh
flip = random.randint(1,10000)%2
cropped = img.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
mask_cropped = mask.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
cw, ch = cropped.size
shift_x = random.randint(-80, 80)
shift_y = random.randint(-80, 80)
dx = (float(pleft)/ow)/sx - (float(shift_x)/shape[0]) # FIX HERE
dy = (float(ptop) /oh)/sy - (float(shift_y)/shape[1]) # FIX HERE
# dx = (float(pleft)/ow)/sx - (float(shift_x)/ow)
# dy = (float(ptop) /oh)/sy - (float(shift_y)/oh)
sized = cropped.resize(shape)
mask_sized = mask_cropped.resize(shape)
sized = ImageChops.offset(sized, shift_x, shift_y)
mask_sized = ImageChops.offset(mask_sized, shift_x, shift_y)
if flip:
sized = sized.transpose(Image.FLIP_LEFT_RIGHT)
mask_sized = mask_sized.transpose(Image.FLIP_LEFT_RIGHT)
img = sized
mask = mask_sized
return img, mask, flip, dx,dy,sx,sy
def data_augmentation_with_mask(img, mask, shape, jitter, hue, saturation, exposure):
ow, oh = img.size
dw =int(ow*jitter)
dh =int(oh*jitter)
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
sx = float(swidth) / ow
sy = float(sheight) / oh
flip = random.randint(1,10000)%2
cropped = img.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
mask_cropped = mask.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
dx = (float(pleft)/ow)/sx
dy = (float(ptop) /oh)/sy
sized = cropped.resize(shape)
mask_sized = mask_cropped.resize(shape)
if flip:
sized = sized.transpose(Image.FLIP_LEFT_RIGHT)
mask_sized = mask_sized.transpose(Image.FLIP_LEFT_RIGHT)
img = sized
mask = mask_sized
return img, mask, flip, dx,dy,sx,sy
def superimpose_masked_imgs(masked_img, mask, total_mask):
ow, oh = masked_img.size
total_mask = total_mask.resize((ow, oh)).convert('RGB')
imcs = list(masked_img.split())
bgcs = list(total_mask.split())
maskcs = list(mask.split())
fics = list(Image.new(masked_img.mode, masked_img.size).split())
for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c + b * d", a=imcs[c], b=bgcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(masked_img.mode, tuple(fics))
return out
def superimpose_masks(mask, total_mask):
# bg: total_mask
ow, oh = mask.size
total_mask = total_mask.resize((ow, oh)).convert('RGB')
total_maskcs = list(total_mask.split())
maskcs = list(mask.split())
fics = list(Image.new(mask.mode, mask.size).split())
for c in range(len(maskcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i)
fics[c] = ImageMath.eval("c + b * d", b=total_maskcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(mask.mode, tuple(fics))
return out
def augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure):
pixelThreshold = 200
random.shuffle(add_objs)
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
# Read the image and the mask
img = Image.open(imgpath).convert('RGB')
iw, ih = img.size
mask = Image.open(maskpath).convert('RGB')
img,mask,flip,dx,dy,sx,sy = shifted_data_augmentation_with_mask(img, mask, shape, jitter, hue, saturation, exposure)
label = fill_truth_detection(labpath, iw, ih, flip, dx, dy, 1./sx, 1./sy)
total_label = np.reshape(label, (-1, 21))
# Mask the background
masked_img = mask_background(img, mask)
mask = mask.resize(shape)
masked_img = masked_img.resize(shape)
# Initialize the total mask and total masked image
total_mask = mask
total_masked_img = masked_img
count = 1
for obj in add_objs:
successful = False
while not successful:
objpath = '../LINEMOD/' + obj + '/train.txt'
with open(objpath, 'r') as objfile:
objlines = objfile.readlines()
rand_index = random.randint(0, len(objlines) - 1)
obj_rand_img_path = '../' + objlines[rand_index].rstrip()
obj_rand_mask_path = obj_rand_img_path.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
obj_rand_lab_path = obj_rand_img_path.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
obj_rand_img = Image.open(obj_rand_img_path).convert('RGB')
obj_rand_mask = Image.open(obj_rand_mask_path).convert('RGB')
obj_rand_masked_img = mask_background(obj_rand_img, obj_rand_mask)
obj_rand_masked_img,obj_rand_mask,flip,dx,dy,sx,sy = data_augmentation_with_mask(obj_rand_masked_img, obj_rand_mask, shape, jitter, hue, saturation, exposure)
obj_rand_label = fill_truth_detection(obj_rand_lab_path, iw, ih, flip, dx, dy, 1./sx, 1./sy)
# compute intersection (ratio of the object part intersecting with other object parts over the area of the object)
xx = np.array(obj_rand_mask)
xx = np.where(xx > pixelThreshold, 1, 0)
yy = np.array(total_mask)
yy = np.where(yy > pixelThreshold, 1, 0)
intersection = (xx * yy)
if (np.sum(xx) < 0.01) and (np.sum(xx) > -0.01):
successful = False
continue
intersection_ratio = float(np.sum(intersection)) / float(np.sum(xx))
if intersection_ratio < 0.2:
successful = True
total_mask = superimpose_masks(obj_rand_mask, total_mask) # total_mask + obj_rand_mask
total_masked_img = superimpose_masked_imgs(obj_rand_masked_img, obj_rand_mask, total_masked_img) # total_masked_img + obj_rand_masked_img
obj_rand_label = np.reshape(obj_rand_label, (-1, 21))
total_label[count, :] = obj_rand_label[0, :]
count = count + 1
else:
successful = False
total_masked_img = superimpose_masked_imgs(masked_img, mask, total_masked_img)
return total_masked_img, np.reshape(total_label, (-1)), total_mask
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
# Read the background image
bg = Image.open(bgpath).convert('RGB')
# Understand which object it is and get the neighboring objects
dirname = os.path.dirname(os.path.dirname(imgpath)) ## dir of dir of file
objname = os.path.basename(dirname)
add_objs = get_add_objs(objname)
# Add additional objects in the scene, apply data augmentation on the objects
total_masked_img, label, total_mask = augment_objects(imgpath, objname, add_objs, shape, jitter, hue, saturation, exposure)
img = change_background(total_masked_img, total_mask, bg)
lb = np.reshape(label, (-1, 21))
return img,label

Просмотреть файл

@ -0,0 +1,309 @@
import time
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *
def build_targets(pred_corners, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
nB = target.size(0)
nA = num_anchors
nC = num_classes
anchor_step = len(anchors)/num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx0 = torch.zeros(nB, nA, nH, nW)
ty0 = torch.zeros(nB, nA, nH, nW)
tx1 = torch.zeros(nB, nA, nH, nW)
ty1 = torch.zeros(nB, nA, nH, nW)
tx2 = torch.zeros(nB, nA, nH, nW)
ty2 = torch.zeros(nB, nA, nH, nW)
tx3 = torch.zeros(nB, nA, nH, nW)
ty3 = torch.zeros(nB, nA, nH, nW)
tx4 = torch.zeros(nB, nA, nH, nW)
ty4 = torch.zeros(nB, nA, nH, nW)
tx5 = torch.zeros(nB, nA, nH, nW)
ty5 = torch.zeros(nB, nA, nH, nW)
tx6 = torch.zeros(nB, nA, nH, nW)
ty6 = torch.zeros(nB, nA, nH, nW)
tx7 = torch.zeros(nB, nA, nH, nW)
ty7 = torch.zeros(nB, nA, nH, nW)
tx8 = torch.zeros(nB, nA, nH, nW)
ty8 = torch.zeros(nB, nA, nH, nW)
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
nAnchors = nA*nH*nW
nPixels = nH*nW
for b in xrange(nB):
cur_pred_corners = pred_corners[b*nAnchors:(b+1)*nAnchors].t()
cur_confs = torch.zeros(nAnchors)
for t in xrange(50):
if target[b][t*21+1] == 0:
break
gx0 = target[b][t*21+1]*nW
gy0 = target[b][t*21+2]*nH
gx1 = target[b][t*21+3]*nW
gy1 = target[b][t*21+4]*nH
gx2 = target[b][t*21+5]*nW
gy2 = target[b][t*21+6]*nH
gx3 = target[b][t*21+7]*nW
gy3 = target[b][t*21+8]*nH
gx4 = target[b][t*21+9]*nW
gy4 = target[b][t*21+10]*nH
gx5 = target[b][t*21+11]*nW
gy5 = target[b][t*21+12]*nH
gx6 = target[b][t*21+13]*nW
gy6 = target[b][t*21+14]*nH
gx7 = target[b][t*21+15]*nW
gy7 = target[b][t*21+16]*nH
gx8 = target[b][t*21+17]*nW
gy8 = target[b][t*21+18]*nH
cur_gt_corners = torch.FloatTensor([gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences9(cur_pred_corners, cur_gt_corners)) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
conf_mask[b][cur_confs>sil_thresh] = 0
if seen < -1:#6400:
tx0.fill_(0.5)
ty0.fill_(0.5)
tx1.fill_(0.5)
ty1.fill_(0.5)
tx2.fill_(0.5)
ty2.fill_(0.5)
tx3.fill_(0.5)
ty3.fill_(0.5)
tx4.fill_(0.5)
ty4.fill_(0.5)
tx5.fill_(0.5)
ty5.fill_(0.5)
tx6.fill_(0.5)
ty6.fill_(0.5)
tx7.fill_(0.5)
ty7.fill_(0.5)
tx8.fill_(0.5)
ty8.fill_(0.5)
coord_mask.fill_(1)
nGT = 0
nCorrect = 0
for b in xrange(nB):
for t in xrange(50):
if target[b][t*21+1] == 0:
break
nGT = nGT + 1
best_iou = 0.0
best_n = -1
min_dist = 10000
gx0 = target[b][t*21+1] * nW
gy0 = target[b][t*21+2] * nH
gi0 = int(gx0)
gj0 = int(gy0)
gx1 = target[b][t*21+3] * nW
gy1 = target[b][t*21+4] * nH
gx2 = target[b][t*21+5] * nW
gy2 = target[b][t*21+6] * nH
gx3 = target[b][t*21+7] * nW
gy3 = target[b][t*21+8] * nH
gx4 = target[b][t*21+9] * nW
gy4 = target[b][t*21+10] * nH
gx5 = target[b][t*21+11] * nW
gy5 = target[b][t*21+12] * nH
gx6 = target[b][t*21+13] * nW
gy6 = target[b][t*21+14] * nH
gx7 = target[b][t*21+15] * nW
gy7 = target[b][t*21+16] * nH
gx8 = target[b][t*21+17] * nW
gy8 = target[b][t*21+18] * nH
gw = target[b][t*21+19]*nW
gh = target[b][t*21+20]*nH
gt_box = [0, 0, gw, gh]
for n in xrange(nA):
aw = anchors[anchor_step*n]
ah = anchors[anchor_step*n+1]
anchor_box = [0, 0, aw, ah]
iou = bbox_iou(anchor_box, gt_box, x1y1x2y2=False)
if iou > best_iou:
best_iou = iou
best_n = n
gt_box = [gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence9(gt_box, pred_box)
coord_mask[b][best_n][gj0][gi0] = 1
cls_mask[b][best_n][gj0][gi0] = 1
conf_mask[b][best_n][gj0][gi0] = object_scale
tx0[b][best_n][gj0][gi0] = target[b][t*21+1] * nW - gi0
ty0[b][best_n][gj0][gi0] = target[b][t*21+2] * nH - gj0
tx1[b][best_n][gj0][gi0] = target[b][t*21+3] * nW - gi0
ty1[b][best_n][gj0][gi0] = target[b][t*21+4] * nH - gj0
tx2[b][best_n][gj0][gi0] = target[b][t*21+5] * nW - gi0
ty2[b][best_n][gj0][gi0] = target[b][t*21+6] * nH - gj0
tx3[b][best_n][gj0][gi0] = target[b][t*21+7] * nW - gi0
ty3[b][best_n][gj0][gi0] = target[b][t*21+8] * nH - gj0
tx4[b][best_n][gj0][gi0] = target[b][t*21+9] * nW - gi0
ty4[b][best_n][gj0][gi0] = target[b][t*21+10] * nH - gj0
tx5[b][best_n][gj0][gi0] = target[b][t*21+11] * nW - gi0
ty5[b][best_n][gj0][gi0] = target[b][t*21+12] * nH - gj0
tx6[b][best_n][gj0][gi0] = target[b][t*21+13] * nW - gi0
ty6[b][best_n][gj0][gi0] = target[b][t*21+14] * nH - gj0
tx7[b][best_n][gj0][gi0] = target[b][t*21+15] * nW - gi0
ty7[b][best_n][gj0][gi0] = target[b][t*21+16] * nH - gj0
tx8[b][best_n][gj0][gi0] = target[b][t*21+17] * nW - gi0
ty8[b][best_n][gj0][gi0] = target[b][t*21+18] * nH - gj0
tconf[b][best_n][gj0][gi0] = conf
tcls[b][best_n][gj0][gi0] = target[b][t*21]
if conf > 0.5:
nCorrect = nCorrect + 1
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls
class RegionLoss(nn.Module):
def __init__(self, num_classes=0, anchors=[], num_anchors=5):
super(RegionLoss, self).__init__()
self.num_classes = num_classes
self.anchors = anchors
self.num_anchors = num_anchors
self.anchor_step = len(anchors)/num_anchors
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
def forward(self, output, target):
# Parameters
t0 = time.time()
nB = output.data.size(0)
nA = self.num_anchors
nC = self.num_classes
nH = output.data.size(2)
nW = output.data.size(3)
# Activation
output = output.view(nB, nA, (19+nC), nH, nW)
x0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
y0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
x1 = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
y1 = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
x2 = output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)
y2 = output.index_select(2, Variable(torch.cuda.LongTensor([5]))).view(nB, nA, nH, nW)
x3 = output.index_select(2, Variable(torch.cuda.LongTensor([6]))).view(nB, nA, nH, nW)
y3 = output.index_select(2, Variable(torch.cuda.LongTensor([7]))).view(nB, nA, nH, nW)
x4 = output.index_select(2, Variable(torch.cuda.LongTensor([8]))).view(nB, nA, nH, nW)
y4 = output.index_select(2, Variable(torch.cuda.LongTensor([9]))).view(nB, nA, nH, nW)
x5 = output.index_select(2, Variable(torch.cuda.LongTensor([10]))).view(nB, nA, nH, nW)
y5 = output.index_select(2, Variable(torch.cuda.LongTensor([11]))).view(nB, nA, nH, nW)
x6 = output.index_select(2, Variable(torch.cuda.LongTensor([12]))).view(nB, nA, nH, nW)
y6 = output.index_select(2, Variable(torch.cuda.LongTensor([13]))).view(nB, nA, nH, nW)
x7 = output.index_select(2, Variable(torch.cuda.LongTensor([14]))).view(nB, nA, nH, nW)
y7 = output.index_select(2, Variable(torch.cuda.LongTensor([15]))).view(nB, nA, nH, nW)
x8 = output.index_select(2, Variable(torch.cuda.LongTensor([16]))).view(nB, nA, nH, nW)
y8 = output.index_select(2, Variable(torch.cuda.LongTensor([17]))).view(nB, nA, nH, nW)
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([18]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(19,19+nC-1,nC).long().cuda()))
cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC)
t1 = time.time()
# Create pred boxes
pred_corners = torch.cuda.FloatTensor(18, nB*nA*nH*nW)
grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
pred_corners[0] = (x0.data + grid_x) / nW
pred_corners[1] = (y0.data + grid_y) / nH
pred_corners[2] = (x1.data + grid_x) / nW
pred_corners[3] = (y1.data + grid_y) / nH
pred_corners[4] = (x2.data + grid_x) / nW
pred_corners[5] = (y2.data + grid_y) / nH
pred_corners[6] = (x3.data + grid_x) / nW
pred_corners[7] = (y3.data + grid_y) / nH
pred_corners[8] = (x4.data + grid_x) / nW
pred_corners[9] = (y4.data + grid_y) / nH
pred_corners[10] = (x5.data + grid_x) / nW
pred_corners[11] = (y5.data + grid_y) / nH
pred_corners[12] = (x6.data + grid_x) / nW
pred_corners[13] = (y6.data + grid_y) / nH
pred_corners[14] = (x7.data + grid_x) / nW
pred_corners[15] = (y7.data + grid_y) / nH
pred_corners[16] = (x8.data + grid_x) / nW
pred_corners[17] = (y8.data + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,18)
pred_corners = convert2cpu(gpu_matrix)
t2 = time.time()
# Build targets
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls = \
build_targets(pred_corners, target.data, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
cls_mask = (cls_mask == 1)
nProposals = int((conf > 0.25).sum().data[0])
tx0 = Variable(tx0.cuda())
ty0 = Variable(ty0.cuda())
tx1 = Variable(tx1.cuda())
ty1 = Variable(ty1.cuda())
tx2 = Variable(tx2.cuda())
ty2 = Variable(ty2.cuda())
tx3 = Variable(tx3.cuda())
ty3 = Variable(ty3.cuda())
tx4 = Variable(tx4.cuda())
ty4 = Variable(ty4.cuda())
tx5 = Variable(tx5.cuda())
ty5 = Variable(ty5.cuda())
tx6 = Variable(tx6.cuda())
ty6 = Variable(ty6.cuda())
tx7 = Variable(tx7.cuda())
ty7 = Variable(ty7.cuda())
tx8 = Variable(tx8.cuda())
ty8 = Variable(ty8.cuda())
tconf = Variable(tconf.cuda())
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
coord_mask = Variable(coord_mask.cuda())
conf_mask = Variable(conf_mask.cuda().sqrt())
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda())
cls = cls[cls_mask].view(-1, nC)
t3 = time.time()
# Create loss
loss_x0 = self.coord_scale * nn.MSELoss(size_average=False)(x0*coord_mask, tx0*coord_mask)/2.0
loss_y0 = self.coord_scale * nn.MSELoss(size_average=False)(y0*coord_mask, ty0*coord_mask)/2.0
loss_x1 = self.coord_scale * nn.MSELoss(size_average=False)(x1*coord_mask, tx1*coord_mask)/2.0
loss_y1 = self.coord_scale * nn.MSELoss(size_average=False)(y1*coord_mask, ty1*coord_mask)/2.0
loss_x2 = self.coord_scale * nn.MSELoss(size_average=False)(x2*coord_mask, tx2*coord_mask)/2.0
loss_y2 = self.coord_scale * nn.MSELoss(size_average=False)(y2*coord_mask, ty2*coord_mask)/2.0
loss_x3 = self.coord_scale * nn.MSELoss(size_average=False)(x3*coord_mask, tx3*coord_mask)/2.0
loss_y3 = self.coord_scale * nn.MSELoss(size_average=False)(y3*coord_mask, ty3*coord_mask)/2.0
loss_x4 = self.coord_scale * nn.MSELoss(size_average=False)(x4*coord_mask, tx4*coord_mask)/2.0
loss_y4 = self.coord_scale * nn.MSELoss(size_average=False)(y4*coord_mask, ty4*coord_mask)/2.0
loss_x5 = self.coord_scale * nn.MSELoss(size_average=False)(x5*coord_mask, tx5*coord_mask)/2.0
loss_y5 = self.coord_scale * nn.MSELoss(size_average=False)(y5*coord_mask, ty5*coord_mask)/2.0
loss_x6 = self.coord_scale * nn.MSELoss(size_average=False)(x6*coord_mask, tx6*coord_mask)/2.0
loss_y6 = self.coord_scale * nn.MSELoss(size_average=False)(y6*coord_mask, ty6*coord_mask)/2.0
loss_x7 = self.coord_scale * nn.MSELoss(size_average=False)(x7*coord_mask, tx7*coord_mask)/2.0
loss_y7 = self.coord_scale * nn.MSELoss(size_average=False)(y7*coord_mask, ty7*coord_mask)/2.0
loss_x8 = self.coord_scale * nn.MSELoss(size_average=False)(x8*coord_mask, tx8*coord_mask)/2.0
loss_y8 = self.coord_scale * nn.MSELoss(size_average=False)(y8*coord_mask, ty8*coord_mask)/2.0
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
loss_x = loss_x0 + loss_x1 + loss_x2 + loss_x3 + loss_x4 + loss_x5 + loss_x6 + loss_x7 + loss_x8
loss_y = loss_y0 + loss_y1 + loss_y2 + loss_y3 + loss_y4 + loss_y5 + loss_y6 + loss_y7 + loss_y8
loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
loss = loss_x + loss_y + loss_conf + loss_cls
print('%d: nGT %d, recall %d, proposals %d, loss: x0: %f x %f, y0: %f y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x0.data[0], loss_x.data[0], loss_y0.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
#else:
# loss = loss_x + loss_y + loss_conf
# print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
t4 = time.time()
if False:
print('-----------------------------------')
print(' activation : %f' % (t1 - t0))
print(' create pred_corners : %f' % (t2 - t1))
print(' build targets : %f' % (t3 - t2))
print(' create loss : %f' % (t4 - t3))
print(' total : %f' % (t4 - t0))
return loss

Просмотреть файл

@ -0,0 +1,424 @@
from __future__ import print_function
import os
os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import random
import math
import shutil
from torchvision import datasets, transforms
from torch.autograd import Variable # Useful info about autograd: http://pytorch.org/docs/master/notes/autograd.html
from darknet_multi import Darknet
from MeshPly import MeshPly
from utils import *
from cfg import parse_cfg
import dataset_multi
from region_loss_multi import RegionLoss
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
# Adjust learning rate during training, learning schedule can be changed in network config file
def adjust_learning_rate(optimizer, batch):
lr = learning_rate
for i in range(len(steps)):
scale = scales[i] if i < len(scales) else 1
if batch >= steps[i]:
lr = lr * scale
if batch == steps[i]:
break
else:
break
for param_group in optimizer.param_groups:
param_group['lr'] = lr/batch_size
return lr
def train(epoch):
global processed_batches
# Initialize timer
t0 = time.time()
# Get the dataloader for training dataset
train_loader = torch.utils.data.DataLoader(dataset_multi.listDataset(trainlist, shape=(init_width, init_height),
shuffle=True,
transform=transforms.Compose([transforms.ToTensor(),]),
train=True,
seen=model.module.seen,
batch_size=batch_size,
num_workers=num_workers, bg_file_names=bg_file_names),
batch_size=batch_size, shuffle=False, **kwargs)
# TRAINING
lr = adjust_learning_rate(optimizer, processed_batches)
logging('epoch %d, processed %d samples, lr %f' % (epoch, epoch * len(train_loader.dataset), lr))
# Start training
model.train()
t1 = time.time()
avg_time = torch.zeros(9)
niter = 0
# Iterate through batches
for batch_idx, (data, target) in enumerate(train_loader):
t2 = time.time()
# adjust learning rate
adjust_learning_rate(optimizer, processed_batches)
processed_batches = processed_batches + 1
# Pass the data to GPU
if use_cuda:
data = data.cuda()
t3 = time.time()
# Wrap tensors in Variable class for automatic differentiation
data, target = Variable(data), Variable(target)
t4 = time.time()
# Zero the gradients before running the backward pass
optimizer.zero_grad()
t5 = time.time()
# Forward pass
output = model(data)
t6 = time.time()
region_loss.seen = region_loss.seen + data.data.size(0)
# Compute loss, grow an array of losses for saving later on
loss = region_loss(output, target)
training_iters.append(epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter)
training_losses.append(convert2cpu(loss.data))
niter += 1
t7 = time.time()
# Backprop: compute gradient of the loss with respect to model parameters
loss.backward()
t8 = time.time()
# Update weights
optimizer.step()
t9 = time.time()
# Print time statistics
if False and batch_idx > 1:
avg_time[0] = avg_time[0] + (t2-t1)
avg_time[1] = avg_time[1] + (t3-t2)
avg_time[2] = avg_time[2] + (t4-t3)
avg_time[3] = avg_time[3] + (t5-t4)
avg_time[4] = avg_time[4] + (t6-t5)
avg_time[5] = avg_time[5] + (t7-t6)
avg_time[6] = avg_time[6] + (t8-t7)
avg_time[7] = avg_time[7] + (t9-t8)
avg_time[8] = avg_time[8] + (t9-t1)
print('-------------------------------')
print(' load data : %f' % (avg_time[0]/(batch_idx)))
print(' cpu to cuda : %f' % (avg_time[1]/(batch_idx)))
print('cuda to variable : %f' % (avg_time[2]/(batch_idx)))
print(' zero_grad : %f' % (avg_time[3]/(batch_idx)))
print(' forward feature : %f' % (avg_time[4]/(batch_idx)))
print(' forward loss : %f' % (avg_time[5]/(batch_idx)))
print(' backward : %f' % (avg_time[6]/(batch_idx)))
print(' step : %f' % (avg_time[7]/(batch_idx)))
print(' total : %f' % (avg_time[8]/(batch_idx)))
t1 = time.time()
t1 = time.time()
return epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter - 1
def eval(niter, datacfg, cfgfile):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
backupdir = options['backup']
name = options['name']
prefix = 'results'
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
# Get validation file names
with open(valid_images) as fp:
tmp_files = fp.readlines()
valid_files = [item.rstrip() for item in tmp_files]
# Specify model, load pretrained weights, pass to GPU and set the module in evaluation mode
model.eval()
# Get the parser for the test dataset
valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.module.width, model.module.height),
shuffle=False,
objclass=name,
transform=transforms.Compose([
transforms.ToTensor(),
]))
valid_batchsize = 1
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
# Parameters
num_classes = model.module.num_classes
anchors = model.module.anchors
num_anchors = model.module.num_anchors
testing_error_pixel = 0.0
testing_samples = 0.0
errs_2d = []
logging(" Number of test samples: %d" % len(test_loader.dataset))
# Iterate through test examples
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass the data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Formward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
trgt = target[0].view(-1, 21)
all_boxes = get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
t4 = time.time()
# Iterate through all batch elements
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many objects are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est) and (boxes[j][20] == int(truths[k][0])):
best_conf_est = boxes[j][18]
box_pr = boxes[j]
bb2d_gt = get_2d_bb(box_gt[:18], output.size(3))
bb2d_pr = get_2d_bb(box_pr[:18], output.size(3))
iou = bbox_iou(bb2d_gt, bb2d_pr)
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height
corners2D_gt_corrected = fix_corner_order(corners2D_gt) # Fix the order of the corners in OCCLUSION
# Compute [R|t] by pnp
objpoints3D = np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32')
K = np.array(internal_calibration, dtype='float32')
R_gt, t_gt = pnp(objpoints3D, corners2D_gt_corrected, K)
R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration))
proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration))
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Sum errors
testing_error_pixel += pixel_dist
testing_samples += 1
t5 = time.time()
# Compute 2D reprojection score
for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
if True:
logging('-----------------------------------')
logging(' tensor to cuda : %f' % (t2 - t1))
logging(' predict : %f' % (t3 - t2))
logging('get_region_boxes : %f' % (t4 - t3))
logging(' eval : %f' % (t5 - t4))
logging(' total : %f' % (t5 - t1))
logging('-----------------------------------')
# Register losses and errors for saving later on
testing_iters.append(niter)
testing_errors_pixel.append(testing_error_pixel/(float(testing_samples)+eps))
testing_accuracies.append(acc)
def test(niter):
cfgfile = 'cfg/yolo-pose-multi.cfg'
datacfg = 'cfg/ape_occlusion.data'
logging("Testing ape...")
eval(niter, datacfg, cfgfile)
datacfg = 'cfg/can_occlusion.data'
logging("Testing can...")
eval(niter, datacfg, cfgfile)
datacfg = 'cfg/cat_occlusion.data'
logging("Testing cat...")
eval(niter, datacfg, cfgfile)
datacfg = 'cfg/duck_occlusion.data'
logging("Testing duck...")
eval(niter, datacfg, cfgfile)
datacfg = 'cfg/driller_occlusion.data'
logging("Testing driller...")
eval(niter, datacfg, cfgfile)
datacfg = 'cfg/glue_occlusion.data'
logging("Testing glue...")
eval(niter, datacfg, cfgfile)
# datacfg = 'cfg/holepuncher_occlusion.data'
# logging("Testing holepuncher...")
# eval(niter, datacfg, cfgfile)
if __name__ == "__main__":
# Training settings
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
# Parse configuration files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
trainlist = data_options['train']
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
gpus = '0'
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
if not os.path.exists(backupdir):
makedirs(backupdir)
batch_size = int(net_options['batch'])
max_batches = int(net_options['max_batches'])
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
steps = [float(step) for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
bg_file_names = get_all_files('../VOCdevkit/VOC2012/JPEGImages')
# Train parameters
max_epochs = 700 # max_batches*batch_size/nsamples+1
use_cuda = True
seed = int(time.time())
eps = 1e-5
save_interval = 10 # epoches
dot_interval = 70 # batches
best_acc = -1
# Test parameters
conf_thresh = 0.05
nms_thresh = 0.4
match_thresh = 0.5
iou_thresh = 0.5
im_width = 640
im_height = 480
# Specify which gpus to use
torch.manual_seed(seed)
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
# Specifiy the model and the loss
model = Darknet(cfgfile)
region_loss = model.loss
# Model settings
# model.load_weights(weightfile)
model.load_weights_until_last(weightfile)
model.print_network()
model.seen = 0
region_loss.iter = model.iter
region_loss.seen = model.seen
processed_batches = model.seen/batch_size
init_width = model.width
init_height = model.height
init_epoch = model.seen/nsamples
# Variable to save
training_iters = []
training_losses = []
testing_iters = []
testing_errors_pixel = []
testing_accuracies = []
# Specify the number of workers
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Pass the model to GPU
if use_cuda:
# model = model.cuda()
model = torch.nn.DataParallel(model, device_ids=[0]).cuda() # Multiple GPU parallelism
# Get the optimizer
params_dict = dict(model.named_parameters())
params = []
for key, value in params_dict.items():
if key.find('.bn') >= 0 or key.find('.bias') >= 0:
params += [{'params': [value], 'weight_decay': 0.0}]
else:
params += [{'params': [value], 'weight_decay': decay*batch_size}]
optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
# optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimization
evaluate = False
if evaluate:
logging('evaluating ...')
test(0, 0)
else:
for epoch in range(init_epoch, max_epochs):
# TRAIN
niter = train(epoch)
# TEST and SAVE
if (epoch % 20 == 0) and (epoch is not 0):
test(niter)
logging('save training stats to %s/costs.npz' % (backupdir))
np.savez(os.path.join(backupdir, "costs.npz"),
training_iters=training_iters,
training_losses=training_losses,
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel)
if (np.mean(testing_accuracies[-5:]) > best_acc ):
best_acc = np.mean(testing_accuracies[-5:])
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.module.save_weights('%s/model.weights' % (backupdir))
shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))

Просмотреть файл

@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"os.sys.path.append('..')\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
"import torch\n",
"from torch.autograd import Variable\n",
"from torchvision import datasets, transforms\n",
"from scipy.misc import imsave\n",
"import scipy.io\n",
"import warnings\n",
"import sys\n",
"warnings.filterwarnings(\"ignore\")\n",
"import matplotlib.pyplot as plt\n",
"import scipy.misc\n",
"\n",
"from darknet_multi import Darknet\n",
"from utils import *\n",
"import dataset_multi\n",
"from MeshPly import MeshPly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2018-05-06 14:09:50 Testing ape...\n",
"2018-05-06 14:10:15 Acc using 5 px 2D Projection = 7.01%\n",
"2018-05-06 14:10:15 Acc using 10 px 2D Projection = 40.43%\n",
"2018-05-06 14:10:15 Acc using 15 px 2D Projection = 59.83%\n",
"2018-05-06 14:10:15 Acc using 20 px 2D Projection = 68.55%\n",
"2018-05-06 14:10:15 Acc using 25 px 2D Projection = 72.05%\n",
"2018-05-06 14:10:15 Acc using 30 px 2D Projection = 73.68%\n",
"2018-05-06 14:10:15 Acc using 35 px 2D Projection = 74.53%\n",
"2018-05-06 14:10:15 Acc using 40 px 2D Projection = 75.13%\n",
"2018-05-06 14:10:15 Acc using 45 px 2D Projection = 75.73%\n",
"2018-05-06 14:10:15 Acc using 50 px 2D Projection = 76.50%\n",
"2018-05-06 14:10:18 Testing can...\n",
"2018-05-06 14:10:47 Acc using 5 px 2D Projection = 11.18%\n",
"2018-05-06 14:10:47 Acc using 10 px 2D Projection = 57.83%\n",
"2018-05-06 14:10:47 Acc using 15 px 2D Projection = 79.95%\n",
"2018-05-06 14:10:47 Acc using 20 px 2D Projection = 85.75%\n",
"2018-05-06 14:10:47 Acc using 25 px 2D Projection = 88.73%\n",
"2018-05-06 14:10:47 Acc using 30 px 2D Projection = 90.39%\n",
"2018-05-06 14:10:47 Acc using 35 px 2D Projection = 91.80%\n",
"2018-05-06 14:10:47 Acc using 40 px 2D Projection = 93.21%\n",
"2018-05-06 14:10:47 Acc using 45 px 2D Projection = 93.62%\n",
"2018-05-06 14:10:47 Acc using 50 px 2D Projection = 93.79%\n",
"2018-05-06 14:10:50 Testing cat...\n",
"2018-05-06 14:11:16 Acc using 5 px 2D Projection = 3.62%\n",
"2018-05-06 14:11:16 Acc using 10 px 2D Projection = 23.25%\n",
"2018-05-06 14:11:16 Acc using 15 px 2D Projection = 39.51%\n",
"2018-05-06 14:11:16 Acc using 20 px 2D Projection = 49.45%\n",
"2018-05-06 14:11:16 Acc using 25 px 2D Projection = 54.76%\n",
"2018-05-06 14:11:16 Acc using 30 px 2D Projection = 57.96%\n",
"2018-05-06 14:11:16 Acc using 35 px 2D Projection = 59.56%\n",
"2018-05-06 14:11:16 Acc using 40 px 2D Projection = 60.99%\n",
"2018-05-06 14:11:16 Acc using 45 px 2D Projection = 62.51%\n",
"2018-05-06 14:11:16 Acc using 50 px 2D Projection = 63.27%\n",
"2018-05-06 14:11:19 Testing duck...\n",
"2018-05-06 14:11:42 Acc using 5 px 2D Projection = 5.07%\n",
"2018-05-06 14:11:42 Acc using 10 px 2D Projection = 18.20%\n",
"2018-05-06 14:11:42 Acc using 15 px 2D Projection = 30.88%\n",
"2018-05-06 14:11:42 Acc using 20 px 2D Projection = 55.12%\n",
"2018-05-06 14:11:42 Acc using 25 px 2D Projection = 75.15%\n",
"2018-05-06 14:11:42 Acc using 30 px 2D Projection = 81.45%\n",
"2018-05-06 14:11:42 Acc using 35 px 2D Projection = 83.20%\n",
"2018-05-06 14:11:42 Acc using 40 px 2D Projection = 83.64%\n",
"2018-05-06 14:11:42 Acc using 45 px 2D Projection = 83.90%\n",
"2018-05-06 14:11:42 Acc using 50 px 2D Projection = 84.16%\n",
"2018-05-06 14:11:45 Testing driller...\n",
"2018-05-06 14:12:10 Acc using 5 px 2D Projection = 1.40%\n",
"2018-05-06 14:12:10 Acc using 10 px 2D Projection = 17.38%\n",
"2018-05-06 14:12:10 Acc using 15 px 2D Projection = 39.87%\n",
"2018-05-06 14:12:10 Acc using 20 px 2D Projection = 62.93%\n",
"2018-05-06 14:12:10 Acc using 25 px 2D Projection = 80.64%\n",
"2018-05-06 14:12:10 Acc using 30 px 2D Projection = 89.87%\n",
"2018-05-06 14:12:10 Acc using 35 px 2D Projection = 94.89%\n",
"2018-05-06 14:12:10 Acc using 40 px 2D Projection = 95.88%\n",
"2018-05-06 14:12:10 Acc using 45 px 2D Projection = 96.54%\n",
"2018-05-06 14:12:10 Acc using 50 px 2D Projection = 96.87%\n",
"2018-05-06 14:12:13 Testing glue...\n",
"2018-05-06 14:12:31 Acc using 5 px 2D Projection = 6.53%\n",
"2018-05-06 14:12:31 Acc using 10 px 2D Projection = 26.91%\n",
"2018-05-06 14:12:31 Acc using 15 px 2D Projection = 39.65%\n",
"2018-05-06 14:12:31 Acc using 20 px 2D Projection = 46.18%\n",
"2018-05-06 14:12:31 Acc using 25 px 2D Projection = 49.50%\n",
"2018-05-06 14:12:31 Acc using 30 px 2D Projection = 51.83%\n",
"2018-05-06 14:12:31 Acc using 35 px 2D Projection = 53.05%\n",
"2018-05-06 14:12:31 Acc using 40 px 2D Projection = 53.16%\n",
"2018-05-06 14:12:31 Acc using 45 px 2D Projection = 53.93%\n",
"2018-05-06 14:12:31 Acc using 50 px 2D Projection = 54.71%\n",
"2018-05-06 14:12:45 Testing holepuncher...\n",
"2018-05-06 14:19:31 Acc using 5 px 2D Projection = 8.26%\n",
"2018-05-06 14:19:31 Acc using 10 px 2D Projection = 39.50%\n",
"2018-05-06 14:19:31 Acc using 15 px 2D Projection = 53.31%\n",
"2018-05-06 14:19:31 Acc using 20 px 2D Projection = 62.56%\n",
"2018-05-06 14:19:31 Acc using 25 px 2D Projection = 68.02%\n",
"2018-05-06 14:19:31 Acc using 30 px 2D Projection = 74.71%\n",
"2018-05-06 14:19:31 Acc using 35 px 2D Projection = 80.74%\n",
"2018-05-06 14:19:31 Acc using 40 px 2D Projection = 85.62%\n",
"2018-05-06 14:19:31 Acc using 45 px 2D Projection = 89.59%\n",
"2018-05-06 14:19:31 Acc using 50 px 2D Projection = 91.49%\n"
]
}
],
"source": [
"def valid(datacfg, cfgfile, weightfile, conf_th):\n",
" def truths_length(truths):\n",
" for i in range(50):\n",
" if truths[i][1] == 0:\n",
" return i\n",
"\n",
" # Parse configuration files\n",
" options = read_data_cfg(datacfg)\n",
" valid_images = options['valid']\n",
" meshname = options['mesh']\n",
" backupdir = options['backup']\n",
" name = options['name']\n",
" prefix = 'results'\n",
" # Read object model information, get 3D bounding box corners\n",
" mesh = MeshPly(meshname)\n",
" vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
" corners3D = get_3D_corners(vertices)\n",
" # Read intrinsic camera parameters\n",
" internal_calibration = get_camera_intrinsic()\n",
"\n",
" # Get validation file names\n",
" with open(valid_images) as fp:\n",
" tmp_files = fp.readlines()\n",
" valid_files = [item.rstrip() for item in tmp_files]\n",
" \n",
" # Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode\n",
" model = Darknet(cfgfile)\n",
" model.load_weights(weightfile)\n",
" model.cuda()\n",
" model.eval()\n",
"\n",
" # Get the parser for the test dataset\n",
" valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.width, model.height),\n",
" shuffle=False,\n",
" objclass=name,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),\n",
" ]))\n",
" valid_batchsize = 1\n",
"\n",
" # Specify the number of workers for multiple processing, get the dataloader for the test dataset\n",
" kwargs = {'num_workers': 4, 'pin_memory': True}\n",
" test_loader = torch.utils.data.DataLoader(\n",
" valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs) \n",
"\n",
" # Parameters\n",
" visualize = False\n",
" use_cuda = True\n",
" num_classes = 13\n",
" anchors = [1.4820, 2.2412, 2.0501, 3.1265, 2.3946, 4.6891, 3.1018, 3.9910, 3.4879, 5.8851]\n",
" num_anchors = 5\n",
" eps = 1e-5\n",
" conf_thresh = conf_th\n",
" iou_thresh = 0.5\n",
"\n",
" # Parameters to save\n",
" errs_2d = []\n",
" edges = [[1, 2], [1, 3], [1, 5], [2, 4], [2, 6], [3, 4], [3, 7], [4, 8], [5, 6], [5, 7], [6, 8], [7, 8]]\n",
" edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]\n",
"\n",
" # Iterate through test batches (Batch size for test data is 1)\n",
" count = 0\n",
" logging('Testing {}...'.format(name))\n",
" for batch_idx, (data, target) in enumerate(test_loader):\n",
" \n",
" # Images\n",
" img = data[0, :, :, :]\n",
" img = img.numpy().squeeze()\n",
" img = np.transpose(img, (1, 2, 0))\n",
" \n",
" t1 = time.time()\n",
" # Pass data to GPU\n",
" if use_cuda:\n",
" data = data.cuda()\n",
" target = target.cuda()\n",
" \n",
" # Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference\n",
" data = Variable(data, volatile=True)\n",
" t2 = time.time()\n",
" \n",
" # Forward pass\n",
" output = model(data).data \n",
" t3 = time.time()\n",
" \n",
" # Using confidence threshold, eliminate low-confidence predictions\n",
" trgt = target[0].view(-1, 21)\n",
" all_boxes = get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, int(trgt[0][0]), only_objectness=0) \n",
" t4 = time.time()\n",
" \n",
" # Iterate through all images in the batch\n",
" for i in range(output.size(0)):\n",
" \n",
" # For each image, get all the predictions\n",
" boxes = all_boxes[i]\n",
" \n",
" # For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)\n",
" truths = target[i].view(-1, 21)\n",
" \n",
" # Get how many object are present in the scene\n",
" num_gts = truths_length(truths)\n",
"\n",
" # Iterate through each ground-truth object\n",
" for k in range(num_gts):\n",
" box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6], \n",
" truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12], \n",
" truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]\n",
" best_conf_est = -1\n",
" \n",
"\n",
" # If the prediction has the highest confidence, choose it as our prediction\n",
" for j in range(len(boxes)):\n",
" if (boxes[j][18] > best_conf_est) and (boxes[j][20] == int(truths[k][0])):\n",
" best_conf_est = boxes[j][18]\n",
" box_pr = boxes[j]\n",
" bb2d_gt = get_2d_bb(box_gt[:18], output.size(3))\n",
" bb2d_pr = get_2d_bb(box_pr[:18], output.size(3))\n",
" iou = bbox_iou(bb2d_gt, bb2d_pr)\n",
" match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))\n",
" \n",
" # Denormalize the corner predictions \n",
" corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')\n",
" corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')\n",
" corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640\n",
" corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480 \n",
" corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640\n",
" corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480\n",
" corners2D_gt_corrected = fix_corner_order(corners2D_gt) # Fix the order of the corners in OCCLUSION\n",
" \n",
" # Compute [R|t] by pnp\n",
" objpoints3D = np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32')\n",
" K = np.array(internal_calibration, dtype='float32')\n",
" R_gt, t_gt = pnp(objpoints3D, corners2D_gt_corrected, K)\n",
" R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)\n",
" \n",
" # Compute pixel error\n",
" Rt_gt = np.concatenate((R_gt, t_gt), axis=1)\n",
" Rt_pr = np.concatenate((R_pr, t_pr), axis=1)\n",
" proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration) \n",
" proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration) \n",
" proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration)) \n",
" proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration)) \n",
" norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)\n",
" pixel_dist = np.mean(norm)\n",
" errs_2d.append(pixel_dist)\n",
"\n",
" \n",
" if visualize:\n",
" # Visualize\n",
" plt.xlim((0, 640))\n",
" plt.ylim((0, 480))\n",
" plt.imshow(scipy.misc.imresize(img, (480, 640)))\n",
" # Projections\n",
" for edge in edges_corners:\n",
" plt.plot(proj_corners_gt[edge, 0], proj_corners_gt[edge, 1], color='g', linewidth=3.0)\n",
" plt.plot(proj_corners_pr[edge, 0], proj_corners_pr[edge, 1], color='b', linewidth=3.0)\n",
" plt.gca().invert_yaxis()\n",
" plt.show()\n",
"\n",
" t5 = time.time()\n",
"\n",
" # Compute 2D projection score\n",
" for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:\n",
" acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)\n",
" # Print test statistics\n",
" logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))\n",
"\n",
"conf_th = 0.05\n",
"cfgfile = 'cfg/yolo-pose-multi.cfg'\n",
"weightfile = 'backup_multi/model_backup2.weights'\n",
"datacfg = 'cfg/ape_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/can_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/cat_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/duck_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/driller_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/glue_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"datacfg = 'cfg/holepuncher_occlusion.data'\n",
"valid(datacfg, cfgfile, weightfile, conf_th)\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Просмотреть файл

@ -0,0 +1,183 @@
import os
os.sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import scipy.misc
import warnings
warnings.filterwarnings("ignore")
from darknet_multi import Darknet
from utils import *
import dataset_multi
from MeshPly import MeshPly
def valid(datacfg, cfgfile, weightfile, conf_th):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
name = options['name']
prefix = 'results'
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
diam = float(options['diam'])
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
# Get validation file names
with open(valid_images) as fp:
tmp_files = fp.readlines()
valid_files = [item.rstrip() for item in tmp_files]
# Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
model = Darknet(cfgfile)
model.load_weights(weightfile)
model.cuda()
model.eval()
# Get the parser for the test dataset
valid_dataset = dataset_multi.listDataset(valid_images, shape=(model.width, model.height),
shuffle=False,
objclass=name,
transform=transforms.Compose([
transforms.ToTensor(),
]))
valid_batchsize = 1
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
# Parameters
use_cuda = True
num_classes = 13
anchors = [1.4820, 2.2412, 2.0501, 3.1265, 2.3946, 4.6891, 3.1018, 3.9910, 3.4879, 5.8851]
num_anchors = 5
eps = 1e-5
conf_thresh = conf_th
iou_thresh = 0.5
# Parameters to save
errs_2d = []
edges = [[1, 2], [1, 3], [1, 5], [2, 4], [2, 6], [3, 4], [3, 7], [4, 8], [5, 6], [5, 7], [6, 8], [7, 8]]
edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]
# Iterate through test batches (Batch size for test data is 1)
logging('Testing {}...'.format(name))
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass data to GPU
if use_cuda:
data = data.cuda()
# target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Forward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
trgt = target[0].view(-1, 21)
all_boxes = get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, int(trgt[0][0]), only_objectness=0)
t4 = time.time()
# Iterate through all images in the batch
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est) and (boxes[j][20] == int(truths[k][0])):
best_conf_est = boxes[j][18]
box_pr = boxes[j]
bb2d_gt = get_2d_bb(box_gt[:18], output.size(3))
bb2d_pr = get_2d_bb(box_pr[:18], output.size(3))
iou = bbox_iou(bb2d_gt, bb2d_pr)
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640
corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480
corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640
corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480
corners2D_gt_corrected = fix_corner_order(corners2D_gt) # Fix the order of corners
# Compute [R|t] by pnp
objpoints3D = np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32')
K = np.array(internal_calibration, dtype='float32')
R_gt, t_gt = pnp(objpoints3D, corners2D_gt_corrected, K)
R_pr, t_pr = pnp(objpoints3D, corners2D_pr, K)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration))
proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration))
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
t5 = time.time()
# Compute 2D projection score
for px_threshold in [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]:
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
# Print test statistics
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
if __name__ == '__main__' and __package__ is None:
import sys
if len(sys.argv) == 3:
conf_th = 0.05
cfgfile = sys.argv[1]
weightfile = sys.argv[2]
datacfg = 'cfg/ape_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/can_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/cat_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/duck_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/glue_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
datacfg = 'cfg/holepuncher_occlusion.data'
valid(datacfg, cfgfile, weightfile, conf_th)
else:
print('Usage:')
print(' python valid.py cfgfile weightfile')

301
py2/region_loss.py Normal file
Просмотреть файл

@ -0,0 +1,301 @@
import time
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *
def build_targets(pred_corners, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
nB = target.size(0)
nA = num_anchors
nC = num_classes
anchor_step = len(anchors)/num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx0 = torch.zeros(nB, nA, nH, nW)
ty0 = torch.zeros(nB, nA, nH, nW)
tx1 = torch.zeros(nB, nA, nH, nW)
ty1 = torch.zeros(nB, nA, nH, nW)
tx2 = torch.zeros(nB, nA, nH, nW)
ty2 = torch.zeros(nB, nA, nH, nW)
tx3 = torch.zeros(nB, nA, nH, nW)
ty3 = torch.zeros(nB, nA, nH, nW)
tx4 = torch.zeros(nB, nA, nH, nW)
ty4 = torch.zeros(nB, nA, nH, nW)
tx5 = torch.zeros(nB, nA, nH, nW)
ty5 = torch.zeros(nB, nA, nH, nW)
tx6 = torch.zeros(nB, nA, nH, nW)
ty6 = torch.zeros(nB, nA, nH, nW)
tx7 = torch.zeros(nB, nA, nH, nW)
ty7 = torch.zeros(nB, nA, nH, nW)
tx8 = torch.zeros(nB, nA, nH, nW)
ty8 = torch.zeros(nB, nA, nH, nW)
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
nAnchors = nA*nH*nW
nPixels = nH*nW
for b in range(nB):
cur_pred_corners = pred_corners[b*nAnchors:(b+1)*nAnchors].t()
cur_confs = torch.zeros(nAnchors)
for t in range(50):
if target[b][t*21+1] == 0:
break
gx0 = target[b][t*21+1]*nW
gy0 = target[b][t*21+2]*nH
gx1 = target[b][t*21+3]*nW
gy1 = target[b][t*21+4]*nH
gx2 = target[b][t*21+5]*nW
gy2 = target[b][t*21+6]*nH
gx3 = target[b][t*21+7]*nW
gy3 = target[b][t*21+8]*nH
gx4 = target[b][t*21+9]*nW
gy4 = target[b][t*21+10]*nH
gx5 = target[b][t*21+11]*nW
gy5 = target[b][t*21+12]*nH
gx6 = target[b][t*21+13]*nW
gy6 = target[b][t*21+14]*nH
gx7 = target[b][t*21+15]*nW
gy7 = target[b][t*21+16]*nH
gx8 = target[b][t*21+17]*nW
gy8 = target[b][t*21+18]*nH
cur_gt_corners = torch.FloatTensor([gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences9(cur_pred_corners, cur_gt_corners)) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
conf_mask[b][cur_confs>sil_thresh] = 0
if seen < -1:#6400:
tx0.fill_(0.5)
ty0.fill_(0.5)
tx1.fill_(0.5)
ty1.fill_(0.5)
tx2.fill_(0.5)
ty2.fill_(0.5)
tx3.fill_(0.5)
ty3.fill_(0.5)
tx4.fill_(0.5)
ty4.fill_(0.5)
tx5.fill_(0.5)
ty5.fill_(0.5)
tx6.fill_(0.5)
ty6.fill_(0.5)
tx7.fill_(0.5)
ty7.fill_(0.5)
tx8.fill_(0.5)
ty8.fill_(0.5)
coord_mask.fill_(1)
nGT = 0
nCorrect = 0
for b in range(nB):
for t in range(50):
if target[b][t*21+1] == 0:
break
nGT = nGT + 1
best_iou = 0.0
best_n = -1
min_dist = 10000
gx0 = target[b][t*21+1] * nW
gy0 = target[b][t*21+2] * nH
gi0 = int(gx0)
gj0 = int(gy0)
gx1 = target[b][t*21+3] * nW
gy1 = target[b][t*21+4] * nH
gx2 = target[b][t*21+5] * nW
gy2 = target[b][t*21+6] * nH
gx3 = target[b][t*21+7] * nW
gy3 = target[b][t*21+8] * nH
gx4 = target[b][t*21+9] * nW
gy4 = target[b][t*21+10] * nH
gx5 = target[b][t*21+11] * nW
gy5 = target[b][t*21+12] * nH
gx6 = target[b][t*21+13] * nW
gy6 = target[b][t*21+14] * nH
gx7 = target[b][t*21+15] * nW
gy7 = target[b][t*21+16] * nH
gx8 = target[b][t*21+17] * nW
gy8 = target[b][t*21+18] * nH
best_n = 0 # 1 anchor box
gt_box = [gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence9(gt_box, pred_box)
coord_mask[b][best_n][gj0][gi0] = 1
cls_mask[b][best_n][gj0][gi0] = 1
conf_mask[b][best_n][gj0][gi0] = object_scale
tx0[b][best_n][gj0][gi0] = target[b][t*21+1] * nW - gi0
ty0[b][best_n][gj0][gi0] = target[b][t*21+2] * nH - gj0
tx1[b][best_n][gj0][gi0] = target[b][t*21+3] * nW - gi0
ty1[b][best_n][gj0][gi0] = target[b][t*21+4] * nH - gj0
tx2[b][best_n][gj0][gi0] = target[b][t*21+5] * nW - gi0
ty2[b][best_n][gj0][gi0] = target[b][t*21+6] * nH - gj0
tx3[b][best_n][gj0][gi0] = target[b][t*21+7] * nW - gi0
ty3[b][best_n][gj0][gi0] = target[b][t*21+8] * nH - gj0
tx4[b][best_n][gj0][gi0] = target[b][t*21+9] * nW - gi0
ty4[b][best_n][gj0][gi0] = target[b][t*21+10] * nH - gj0
tx5[b][best_n][gj0][gi0] = target[b][t*21+11] * nW - gi0
ty5[b][best_n][gj0][gi0] = target[b][t*21+12] * nH - gj0
tx6[b][best_n][gj0][gi0] = target[b][t*21+13] * nW - gi0
ty6[b][best_n][gj0][gi0] = target[b][t*21+14] * nH - gj0
tx7[b][best_n][gj0][gi0] = target[b][t*21+15] * nW - gi0
ty7[b][best_n][gj0][gi0] = target[b][t*21+16] * nH - gj0
tx8[b][best_n][gj0][gi0] = target[b][t*21+17] * nW - gi0
ty8[b][best_n][gj0][gi0] = target[b][t*21+18] * nH - gj0
tconf[b][best_n][gj0][gi0] = conf
tcls[b][best_n][gj0][gi0] = target[b][t*21]
if conf > 0.5:
nCorrect = nCorrect + 1
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls
class RegionLoss(nn.Module):
def __init__(self, num_classes=0, anchors=[], num_anchors=1):
super(RegionLoss, self).__init__()
self.num_classes = num_classes
self.anchors = anchors
self.num_anchors = num_anchors
self.anchor_step = len(anchors)/num_anchors
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
def forward(self, output, target):
# Parameters
t0 = time.time()
nB = output.data.size(0)
nA = self.num_anchors
nC = self.num_classes
nH = output.data.size(2)
nW = output.data.size(3)
# Activation
output = output.view(nB, nA, (19+nC), nH, nW)
x0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
y0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
x1 = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
y1 = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
x2 = output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)
y2 = output.index_select(2, Variable(torch.cuda.LongTensor([5]))).view(nB, nA, nH, nW)
x3 = output.index_select(2, Variable(torch.cuda.LongTensor([6]))).view(nB, nA, nH, nW)
y3 = output.index_select(2, Variable(torch.cuda.LongTensor([7]))).view(nB, nA, nH, nW)
x4 = output.index_select(2, Variable(torch.cuda.LongTensor([8]))).view(nB, nA, nH, nW)
y4 = output.index_select(2, Variable(torch.cuda.LongTensor([9]))).view(nB, nA, nH, nW)
x5 = output.index_select(2, Variable(torch.cuda.LongTensor([10]))).view(nB, nA, nH, nW)
y5 = output.index_select(2, Variable(torch.cuda.LongTensor([11]))).view(nB, nA, nH, nW)
x6 = output.index_select(2, Variable(torch.cuda.LongTensor([12]))).view(nB, nA, nH, nW)
y6 = output.index_select(2, Variable(torch.cuda.LongTensor([13]))).view(nB, nA, nH, nW)
x7 = output.index_select(2, Variable(torch.cuda.LongTensor([14]))).view(nB, nA, nH, nW)
y7 = output.index_select(2, Variable(torch.cuda.LongTensor([15]))).view(nB, nA, nH, nW)
x8 = output.index_select(2, Variable(torch.cuda.LongTensor([16]))).view(nB, nA, nH, nW)
y8 = output.index_select(2, Variable(torch.cuda.LongTensor([17]))).view(nB, nA, nH, nW)
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([18]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(19,19+nC-1,nC).long().cuda()))
cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC)
t1 = time.time()
# Create pred boxes
pred_corners = torch.cuda.FloatTensor(18, nB*nA*nH*nW)
grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
pred_corners[0] = (x0.data + grid_x) / nW
pred_corners[1] = (y0.data + grid_y) / nH
pred_corners[2] = (x1.data + grid_x) / nW
pred_corners[3] = (y1.data + grid_y) / nH
pred_corners[4] = (x2.data + grid_x) / nW
pred_corners[5] = (y2.data + grid_y) / nH
pred_corners[6] = (x3.data + grid_x) / nW
pred_corners[7] = (y3.data + grid_y) / nH
pred_corners[8] = (x4.data + grid_x) / nW
pred_corners[9] = (y4.data + grid_y) / nH
pred_corners[10] = (x5.data + grid_x) / nW
pred_corners[11] = (y5.data + grid_y) / nH
pred_corners[12] = (x6.data + grid_x) / nW
pred_corners[13] = (y6.data + grid_y) / nH
pred_corners[14] = (x7.data + grid_x) / nW
pred_corners[15] = (y7.data + grid_y) / nH
pred_corners[16] = (x8.data + grid_x) / nW
pred_corners[17] = (y8.data + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,18)
pred_corners = convert2cpu(gpu_matrix)
t2 = time.time()
# Build targets
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls = \
build_targets(pred_corners, target.data, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
cls_mask = (cls_mask == 1)
nProposals = int((conf > 0.25).sum().data[0])
tx0 = Variable(tx0.cuda())
ty0 = Variable(ty0.cuda())
tx1 = Variable(tx1.cuda())
ty1 = Variable(ty1.cuda())
tx2 = Variable(tx2.cuda())
ty2 = Variable(ty2.cuda())
tx3 = Variable(tx3.cuda())
ty3 = Variable(ty3.cuda())
tx4 = Variable(tx4.cuda())
ty4 = Variable(ty4.cuda())
tx5 = Variable(tx5.cuda())
ty5 = Variable(ty5.cuda())
tx6 = Variable(tx6.cuda())
ty6 = Variable(ty6.cuda())
tx7 = Variable(tx7.cuda())
ty7 = Variable(ty7.cuda())
tx8 = Variable(tx8.cuda())
ty8 = Variable(ty8.cuda())
tconf = Variable(tconf.cuda())
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
coord_mask = Variable(coord_mask.cuda())
conf_mask = Variable(conf_mask.cuda().sqrt())
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda())
cls = cls[cls_mask].view(-1, nC)
t3 = time.time()
# Create loss
loss_x0 = self.coord_scale * nn.MSELoss(size_average=False)(x0*coord_mask, tx0*coord_mask)/2.0
loss_y0 = self.coord_scale * nn.MSELoss(size_average=False)(y0*coord_mask, ty0*coord_mask)/2.0
loss_x1 = self.coord_scale * nn.MSELoss(size_average=False)(x1*coord_mask, tx1*coord_mask)/2.0
loss_y1 = self.coord_scale * nn.MSELoss(size_average=False)(y1*coord_mask, ty1*coord_mask)/2.0
loss_x2 = self.coord_scale * nn.MSELoss(size_average=False)(x2*coord_mask, tx2*coord_mask)/2.0
loss_y2 = self.coord_scale * nn.MSELoss(size_average=False)(y2*coord_mask, ty2*coord_mask)/2.0
loss_x3 = self.coord_scale * nn.MSELoss(size_average=False)(x3*coord_mask, tx3*coord_mask)/2.0
loss_y3 = self.coord_scale * nn.MSELoss(size_average=False)(y3*coord_mask, ty3*coord_mask)/2.0
loss_x4 = self.coord_scale * nn.MSELoss(size_average=False)(x4*coord_mask, tx4*coord_mask)/2.0
loss_y4 = self.coord_scale * nn.MSELoss(size_average=False)(y4*coord_mask, ty4*coord_mask)/2.0
loss_x5 = self.coord_scale * nn.MSELoss(size_average=False)(x5*coord_mask, tx5*coord_mask)/2.0
loss_y5 = self.coord_scale * nn.MSELoss(size_average=False)(y5*coord_mask, ty5*coord_mask)/2.0
loss_x6 = self.coord_scale * nn.MSELoss(size_average=False)(x6*coord_mask, tx6*coord_mask)/2.0
loss_y6 = self.coord_scale * nn.MSELoss(size_average=False)(y6*coord_mask, ty6*coord_mask)/2.0
loss_x7 = self.coord_scale * nn.MSELoss(size_average=False)(x7*coord_mask, tx7*coord_mask)/2.0
loss_y7 = self.coord_scale * nn.MSELoss(size_average=False)(y7*coord_mask, ty7*coord_mask)/2.0
loss_x8 = self.coord_scale * nn.MSELoss(size_average=False)(x8*coord_mask, tx8*coord_mask)/2.0
loss_y8 = self.coord_scale * nn.MSELoss(size_average=False)(y8*coord_mask, ty8*coord_mask)/2.0
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
# loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
loss_cls = 0
loss_x = loss_x0 + loss_x1 + loss_x2 + loss_x3 + loss_x4 + loss_x5 + loss_x6 + loss_x7 + loss_x8
loss_y = loss_y0 + loss_y1 + loss_y2 + loss_y3 + loss_y4 + loss_y5 + loss_y6 + loss_y7 + loss_y8
if False:
loss = loss_x + loss_y + loss_conf + loss_cls
else:
loss = loss_x + loss_y + loss_conf
t4 = time.time()
if False:
print('-----------------------------------')
print(' activation : %f' % (t1 - t0))
print(' create pred_corners : %f' % (t2 - t1))
print(' build targets : %f' % (t3 - t2))
print(' create loss : %f' % (t4 - t3))
print(' total : %f' % (t4 - t0))
if False:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
else:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
return loss

417
py2/train.py Normal file
Просмотреть файл

@ -0,0 +1,417 @@
from __future__ import print_function
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import os
import random
import math
import shutil
from torchvision import datasets, transforms
from torch.autograd import Variable # Useful info about autograd: http://pytorch.org/docs/master/notes/autograd.html
import dataset
from utils import *
from cfg import parse_cfg
from region_loss import RegionLoss
from darknet import Darknet
from MeshPly import MeshPly
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
# Adjust learning rate during training, learning schedule can be changed in network config file
def adjust_learning_rate(optimizer, batch):
lr = learning_rate
for i in range(len(steps)):
scale = scales[i] if i < len(scales) else 1
if batch >= steps[i]:
lr = lr * scale
if batch == steps[i]:
break
else:
break
for param_group in optimizer.param_groups:
param_group['lr'] = lr/batch_size
return lr
def train(epoch):
global processed_batches
# Initialize timer
t0 = time.time()
# Get the dataloader for training dataset
train_loader = torch.utils.data.DataLoader(dataset.listDataset(trainlist, shape=(init_width, init_height),
shuffle=True,
transform=transforms.Compose([transforms.ToTensor(),]),
train=True,
seen=model.seen,
batch_size=batch_size,
num_workers=num_workers, bg_file_names=bg_file_names),
batch_size=batch_size, shuffle=False, **kwargs)
# TRAINING
lr = adjust_learning_rate(optimizer, processed_batches)
logging('epoch %d, processed %d samples, lr %f' % (epoch, epoch * len(train_loader.dataset), lr))
# Start training
model.train()
t1 = time.time()
avg_time = torch.zeros(9)
niter = 0
# Iterate through batches
for batch_idx, (data, target) in enumerate(train_loader):
t2 = time.time()
# adjust learning rate
adjust_learning_rate(optimizer, processed_batches)
processed_batches = processed_batches + 1
# Pass the data to GPU
if use_cuda:
data = data.cuda()
t3 = time.time()
# Wrap tensors in Variable class for automatic differentiation
data, target = Variable(data), Variable(target)
t4 = time.time()
# Zero the gradients before running the backward pass
optimizer.zero_grad()
t5 = time.time()
# Forward pass
output = model(data)
t6 = time.time()
model.seen = model.seen + data.data.size(0)
region_loss.seen = region_loss.seen + data.data.size(0)
# Compute loss, grow an array of losses for saving later on
loss = region_loss(output, target)
training_iters.append(epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter)
training_losses.append(convert2cpu(loss.data))
niter += 1
t7 = time.time()
# Backprop: compute gradient of the loss with respect to model parameters
loss.backward()
t8 = time.time()
# Update weights
optimizer.step()
t9 = time.time()
# Print time statistics
if False and batch_idx > 1:
avg_time[0] = avg_time[0] + (t2-t1)
avg_time[1] = avg_time[1] + (t3-t2)
avg_time[2] = avg_time[2] + (t4-t3)
avg_time[3] = avg_time[3] + (t5-t4)
avg_time[4] = avg_time[4] + (t6-t5)
avg_time[5] = avg_time[5] + (t7-t6)
avg_time[6] = avg_time[6] + (t8-t7)
avg_time[7] = avg_time[7] + (t9-t8)
avg_time[8] = avg_time[8] + (t9-t1)
print('-------------------------------')
print(' load data : %f' % (avg_time[0]/(batch_idx)))
print(' cpu to cuda : %f' % (avg_time[1]/(batch_idx)))
print('cuda to variable : %f' % (avg_time[2]/(batch_idx)))
print(' zero_grad : %f' % (avg_time[3]/(batch_idx)))
print(' forward feature : %f' % (avg_time[4]/(batch_idx)))
print(' forward loss : %f' % (avg_time[5]/(batch_idx)))
print(' backward : %f' % (avg_time[6]/(batch_idx)))
print(' step : %f' % (avg_time[7]/(batch_idx)))
print(' total : %f' % (avg_time[8]/(batch_idx)))
t1 = time.time()
t1 = time.time()
return epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter - 1
def test(epoch, niter):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Set the module in evaluation mode (turn off dropout, batch normalization etc.)
model.eval()
# Parameters
num_classes = model.num_classes
anchors = model.anchors
num_anchors = model.num_anchors
testtime = True
testing_error_trans = 0.0
testing_error_angle = 0.0
testing_error_pixel = 0.0
testing_samples = 0.0
errs_2d = []
errs_3d = []
errs_trans = []
errs_angle = []
errs_corner2D = []
logging(" Testing...")
logging(" Number of test samples: %d" % len(test_loader.dataset))
notpredicted = 0
# Iterate through test examples
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass the data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Formward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors)
t4 = time.time()
# Iterate through all batch elements
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction
for j in range(len(boxes)):
if boxes[j][18] > best_conf_est:
best_conf_est = boxes[j][18]
box_pr = boxes[j]
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height
# Compute corner prediction error
corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)
corner_dist = np.mean(corner_norm)
errs_corner2D.append(corner_dist)
# Compute [R|t] by pnp
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
# Compute errors
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
errs_trans.append(trans_dist)
# Compute angle error
angle_dist = calcAngularDistance(R_gt, R_pr)
errs_angle.append(angle_dist)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Compute 3D distances
transform_3d_gt = compute_transformation(vertices, Rt_gt)
transform_3d_pred = compute_transformation(vertices, Rt_pr)
norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)
vertex_dist = np.mean(norm3d)
errs_3d.append(vertex_dist)
# Sum errors
testing_error_trans += trans_dist
testing_error_angle += angle_dist
testing_error_pixel += pixel_dist
testing_samples += 1
t5 = time.time()
# Compute 2D projection, 6D pose and 5cm5degree scores
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc3d = len(np.where(np.array(errs_3d) <= vx_threshold)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
if testtime:
print('-----------------------------------')
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' eval : %f' % (t5 - t4))
print(' total : %f' % (t5 - t1))
print('-----------------------------------')
# Print test statistics
logging(" Mean corner error is %f" % (mean_corner_err_2d))
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
logging(' Acc using {} vx 3D Transformation = {:.2f}%'.format(vx_threshold, acc3d))
logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))
logging(' Translation error: %f, angle error: %f' % (testing_error_trans/(nts+eps), testing_error_angle/(nts+eps)) )
# Register losses and errors for saving later on
testing_iters.append(niter)
testing_errors_trans.append(testing_error_trans/(nts+eps))
testing_errors_angle.append(testing_error_angle/(nts+eps))
testing_errors_pixel.append(testing_error_pixel/(nts+eps))
testing_accuracies.append(acc)
if __name__ == "__main__":
# Training settings
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
# Parse configuration files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
trainlist = data_options['train']
testlist = data_options['valid']
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
gpus = '0'
meshname = data_options['mesh']
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
diam = float(data_options['diam'])
vx_threshold = diam * 0.1
if not os.path.exists(backupdir):
makedirs(backupdir)
batch_size = int(net_options['batch'])
max_batches = int(net_options['max_batches'])
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
steps = [float(step) for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
bg_file_names = get_all_files('VOCdevkit/VOC2012/JPEGImages')
# Train parameters
max_epochs = 700 # max_batches*batch_size/nsamples+1
use_cuda = True
seed = int(time.time())
eps = 1e-5
save_interval = 10 # epoches
dot_interval = 70 # batches
best_acc = -1
# Test parameters
conf_thresh = 0.1
nms_thresh = 0.4
iou_thresh = 0.5
im_width = 640
im_height = 480
# Specify which gpus to use
torch.manual_seed(seed)
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
# Specifiy the model and the loss
model = Darknet(cfgfile)
region_loss = model.loss
# Model settings
# model.load_weights(weightfile)
model.load_weights_until_last(weightfile)
model.print_network()
model.seen = 0
region_loss.iter = model.iter
region_loss.seen = model.seen
processed_batches = model.seen/batch_size
init_width = model.width
init_height = model.height
test_width = 672
test_height = 672
init_epoch = model.seen/nsamples
# Variable to save
training_iters = []
training_losses = []
testing_iters = []
testing_losses = []
testing_errors_trans = []
testing_errors_angle = []
testing_errors_pixel = []
testing_accuracies = []
# Get the intrinsic camerea matrix, mesh, vertices and corners of the model
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
internal_calibration = get_camera_intrinsic()
print("vertex unit")
print(np.max(vertices))
print(np.min(vertices))
# Specify the number of workers
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Get the dataloader for test data
test_loader = torch.utils.data.DataLoader(dataset.listDataset(testlist, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([transforms.ToTensor(),]),
train=False),
batch_size=1, shuffle=False, **kwargs)
# Pass the model to GPU
if use_cuda:
model = model.cuda() # model = torch.nn.DataParallel(model, device_ids=[0]).cuda() # Multiple GPU parallelism
# Get the optimizer
params_dict = dict(model.named_parameters())
params = []
for key, value in params_dict.items():
if key.find('.bn') >= 0 or key.find('.bias') >= 0:
params += [{'params': [value], 'weight_decay': 0.0}]
else:
params += [{'params': [value], 'weight_decay': decay*batch_size}]
optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
# optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimization
evaluate = False
if evaluate:
logging('evaluating ...')
test(0, 0)
else:
for epoch in range(init_epoch, max_epochs):
# TRAIN
niter = train(epoch)
# TEST and SAVE
if (epoch % 10 == 0) and (epoch is not 0):
test(epoch, niter)
logging('save training stats to %s/costs.npz' % (backupdir))
np.savez(os.path.join(backupdir, "costs.npz"),
training_iters=training_iters,
training_losses=training_losses,
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel,
testing_errors_angle=testing_errors_angle)
if (testing_accuracies[-1] > best_acc ):
best_acc = testing_accuracies[-1]
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.save_weights('%s/model.weights' % (backupdir))
shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))

1058
py2/utils.py Normal file

Разница между файлами не показана из-за своего большого размера Загрузить разницу

357
py2/valid.ipynb Normal file
Просмотреть файл

@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import time\n",
"import torch\n",
"from torch.autograd import Variable\n",
"from torchvision import datasets, transforms\n",
"import scipy.io\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import matplotlib.pyplot as plt\n",
"import scipy.misc\n",
"\n",
"from darknet import Darknet\n",
"import dataset\n",
"from utils import *\n",
"from MeshPly import MeshPly\n",
"\n",
"# Create new directory\n",
"def makedirs(path):\n",
" if not os.path.exists( path ):\n",
" os.makedirs( path )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "IOError",
"evalue": "[Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIOError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0mcfgfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cfg/yolo-pose.cfg'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0mweightfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'backup/ape/model_backup.weights'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0mvalid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfgfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36mvalid\u001b[0;34m(datacfg, cfgfile, weightfile)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# Read object model information, get 3D bounding box corners\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmesh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMeshPly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeshname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0mcorners3D\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_3D_corners\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/cvlabdata1/home/btekin/ope/singleshotpose_release/MeshPly.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, color)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'r'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIOError\u001b[0m: [Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'"
]
}
],
"source": [
"def valid(datacfg, cfgfile, weightfile):\n",
" def truths_length(truths):\n",
" for i in range(50):\n",
" if truths[i][1] == 0:\n",
" return i\n",
"\n",
" # Parse configuration files\n",
" options = read_data_cfg(datacfg)\n",
" valid_images = options['valid']\n",
" meshname = options['mesh']\n",
" backupdir = options['backup']\n",
" name = options['name']\n",
" if not os.path.exists(backupdir):\n",
" makedirs(backupdir)\n",
"\n",
" # Parameters\n",
" prefix = 'results'\n",
" seed = int(time.time())\n",
" gpus = '0' # Specify which gpus to use\n",
" test_width = 544\n",
" test_height = 544\n",
" torch.manual_seed(seed)\n",
" use_cuda = True\n",
" if use_cuda:\n",
" os.environ['CUDA_VISIBLE_DEVICES'] = gpus\n",
" torch.cuda.manual_seed(seed)\n",
" save = False\n",
" visualize = True\n",
" testtime = True\n",
" use_cuda = True\n",
" num_classes = 1\n",
" testing_samples = 0.0\n",
" eps = 1e-5\n",
" notpredicted = 0 \n",
" conf_thresh = 0.1\n",
" nms_thresh = 0.4\n",
" match_thresh = 0.5\n",
" edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]\n",
"\n",
" if save:\n",
" makedirs(backupdir + '/test')\n",
" makedirs(backupdir + '/test/gt')\n",
" makedirs(backupdir + '/test/pr')\n",
"\n",
" # To save\n",
" testing_error_trans = 0.0\n",
" testing_error_angle = 0.0\n",
" testing_error_pixel = 0.0\n",
" errs_2d = []\n",
" errs_3d = []\n",
" errs_trans = []\n",
" errs_angle = []\n",
" errs_corner2D = []\n",
" preds_trans = []\n",
" preds_rot = []\n",
" preds_corners2D = []\n",
" gts_trans = []\n",
" gts_rot = []\n",
" gts_corners2D = []\n",
" ious = []\n",
"\n",
" # Read object model information, get 3D bounding box corners\n",
" mesh = MeshPly(meshname)\n",
" vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
" corners3D = get_3D_corners(vertices)\n",
" # diam = calc_pts_diameter(np.array(mesh.vertices))\n",
" diam = float(options['diam'])\n",
"\n",
" # Read intrinsic camera parameters\n",
" internal_calibration = get_camera_intrinsic()\n",
"\n",
" # Get validation file names\n",
" with open(valid_images) as fp:\n",
" tmp_files = fp.readlines()\n",
" valid_files = [item.rstrip() for item in tmp_files]\n",
" \n",
" # Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode\n",
" model = Darknet(cfgfile)\n",
" model.print_network()\n",
" model.load_weights(weightfile)\n",
" model.cuda()\n",
" model.eval()\n",
"\n",
" # Get the parser for the test dataset\n",
" valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),\n",
" shuffle=False,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),]))\n",
" valid_batchsize = 1\n",
"\n",
" # Specify the number of workers for multiple processing, get the dataloader for the test dataset\n",
" kwargs = {'num_workers': 4, 'pin_memory': True}\n",
" test_loader = torch.utils.data.DataLoader(\n",
" valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs) \n",
"\n",
" logging(\" Testing {}...\".format(name))\n",
" logging(\" Number of test samples: %d\" % len(test_loader.dataset))\n",
" # Iterate through test batches (Batch size for test data is 1)\n",
" count = 0\n",
" z = np.zeros((3, 1))\n",
" for batch_idx, (data, target) in enumerate(test_loader):\n",
" \n",
" # Images\n",
" img = data[0, :, :, :]\n",
" img = img.numpy().squeeze()\n",
" img = np.transpose(img, (1, 2, 0))\n",
" \n",
" t1 = time.time()\n",
" # Pass data to GPU\n",
" if use_cuda:\n",
" data = data.cuda()\n",
" target = target.cuda()\n",
" \n",
" # Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference\n",
" data = Variable(data, volatile=True)\n",
" t2 = time.time()\n",
" \n",
" # Forward pass\n",
" output = model(data).data \n",
" t3 = time.time()\n",
" \n",
" # Using confidence threshold, eliminate low-confidence predictions\n",
" all_boxes = get_region_boxes(output, conf_thresh, num_classes) \n",
" t4 = time.time()\n",
"\n",
" # Iterate through all images in the batch\n",
" for i in range(output.size(0)):\n",
" \n",
" # For each image, get all the predictions\n",
" boxes = all_boxes[i]\n",
" \n",
" # For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)\n",
" truths = target[i].view(-1, 21)\n",
" \n",
" # Get how many object are present in the scene\n",
" num_gts = truths_length(truths)\n",
"\n",
" # Iterate through each ground-truth object\n",
" for k in range(num_gts):\n",
" box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6], \n",
" truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12], \n",
" truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]\n",
" best_conf_est = -1\n",
"\n",
" # If the prediction has the highest confidence, choose it as our prediction for single object pose estimation\n",
" for j in range(len(boxes)):\n",
" if (boxes[j][18] > best_conf_est):\n",
" match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))\n",
" box_pr = boxes[j]\n",
" best_conf_est = boxes[j][18]\n",
"\n",
" # Denormalize the corner predictions \n",
" corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')\n",
" corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')\n",
" corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640\n",
" corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480 \n",
" corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640\n",
" corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480\n",
" preds_corners2D.append(corners2D_pr)\n",
" gts_corners2D.append(corners2D_gt)\n",
"\n",
" # Compute corner prediction error\n",
" corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)\n",
" corner_dist = np.mean(corner_norm)\n",
" errs_corner2D.append(corner_dist)\n",
" \n",
" # Compute [R|t] by pnp\n",
" R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))\n",
" R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))\n",
"\n",
" if save:\n",
" preds_trans.append(t_pr)\n",
" gts_trans.append(t_gt)\n",
" preds_rot.append(R_pr)\n",
" gts_rot.append(R_gt)\n",
"\n",
" np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))\n",
" \n",
" # Compute translation error\n",
" trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))\n",
" errs_trans.append(trans_dist)\n",
" \n",
" # Compute angle error\n",
" angle_dist = calcAngularDistance(R_gt, R_pr)\n",
" errs_angle.append(angle_dist)\n",
" \n",
" # Compute pixel error\n",
" Rt_gt = np.concatenate((R_gt, t_gt), axis=1)\n",
" Rt_pr = np.concatenate((R_pr, t_pr), axis=1)\n",
" proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)\n",
" proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration) \n",
" proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration)) \n",
" proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration)) \n",
" norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)\n",
" pixel_dist = np.mean(norm)\n",
" errs_2d.append(pixel_dist)\n",
"\n",
" if visualize:\n",
" # Visualize\n",
" plt.xlim((0, 640))\n",
" plt.ylim((0, 480))\n",
" plt.imshow(scipy.misc.imresize(img, (480, 640)))\n",
" # Projections\n",
" for edge in edges_corners:\n",
" plt.plot(proj_corners_gt[edge, 0], proj_corners_gt[edge, 1], color='g', linewidth=3.0)\n",
" plt.plot(proj_corners_pr[edge, 0], proj_corners_pr[edge, 1], color='b', linewidth=3.0)\n",
" plt.gca().invert_yaxis()\n",
" plt.show()\n",
" \n",
" # Compute IoU score\n",
" bb_gt = compute_2d_bb_from_orig_pix(proj_2d_gt, output.size(3))\n",
" bb_pred = compute_2d_bb_from_orig_pix(proj_2d_pred, output.size(3))\n",
" iou = bbox_iou(bb_gt, bb_pred)\n",
" ious.append(iou)\n",
"\n",
" # Compute 3D distances\n",
" transform_3d_gt = compute_transformation(vertices, Rt_gt) \n",
" transform_3d_pred = compute_transformation(vertices, Rt_pr) \n",
" norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)\n",
" vertex_dist = np.mean(norm3d) \n",
" errs_3d.append(vertex_dist) \n",
"\n",
" # Sum errors\n",
" testing_error_trans += trans_dist\n",
" testing_error_angle += angle_dist\n",
" testing_error_pixel += pixel_dist\n",
" testing_samples += 1\n",
" count = count + 1\n",
"\n",
" t5 = time.time()\n",
"\n",
" # Compute 2D projection error, 6D pose error, 5cm5degree error\n",
" px_threshold = 5\n",
" acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)\n",
" acciou = len(np.where(np.array(errs_2d) >= 0.5)[0]) * 100. / (len(ious)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)\n",
" mean_err_2d = np.mean(errs_2d)\n",
" mean_corner_err_2d = np.mean(errs_corner2D)\n",
" nts = float(testing_samples)\n",
"\n",
" if testtime:\n",
" print('-----------------------------------')\n",
" print(' tensor to cuda : %f' % (t2 - t1))\n",
" print(' predict : %f' % (t3 - t2))\n",
" print('get_region_boxes : %f' % (t4 - t3))\n",
" print(' nms : %f' % (t5 - t4))\n",
" print(' total : %f' % (t5 - t1))\n",
" print('-----------------------------------')\n",
"\n",
" # Print test statistics\n",
" logging('Results of {}'.format(name))\n",
" logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))\n",
" logging(' Acc using the IoU metric = {:.6f}%'.format(acciou))\n",
" logging(' Acc using 10% threshold - {} vx 3D Transformation = {:.2f}%'.format(diam * 0.1, acc3d10))\n",
" logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))\n",
" logging(\" Mean 2D pixel error is %f, Mean vertex error is %f, mean corner error is %f\" % (mean_err_2d, np.mean(errs_3d), mean_corner_err_2d))\n",
" logging(' Translation error: %f m, angle error: %f degree, pixel error: % f pix' % (testing_error_trans/nts, testing_error_angle/nts, testing_error_pixel/nts) )\n",
"\n",
" if save:\n",
" predfile = backupdir + '/predictions_linemod_' + name + '.mat'\n",
" scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})\n",
"\n",
"datacfg = 'cfg/ape.data'\n",
"cfgfile = 'cfg/yolo-pose.cfg'\n",
"weightfile = 'backup/ape/model_backup.weights'\n",
"valid(datacfg, cfgfile, weightfile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

269
py2/valid.py Normal file
Просмотреть файл

@ -0,0 +1,269 @@
import os
import time
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import scipy.io
import warnings
warnings.filterwarnings("ignore")
from darknet import Darknet
import dataset
from utils import *
from MeshPly import MeshPly
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
def valid(datacfg, cfgfile, weightfile, outfile):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
backupdir = options['backup']
name = options['name']
if not os.path.exists(backupdir):
makedirs(backupdir)
# Parameters
prefix = 'results'
seed = int(time.time())
gpus = '0' # Specify which gpus to use
test_width = 544
test_height = 544
torch.manual_seed(seed)
use_cuda = True
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
save = False
testtime = True
use_cuda = True
num_classes = 1
testing_samples = 0.0
eps = 1e-5
notpredicted = 0
conf_thresh = 0.1
nms_thresh = 0.4
match_thresh = 0.5
if save:
makedirs(backupdir + '/test')
makedirs(backupdir + '/test/gt')
makedirs(backupdir + '/test/pr')
# To save
testing_error_trans = 0.0
testing_error_angle = 0.0
testing_error_pixel = 0.0
errs_2d = []
errs_3d = []
errs_trans = []
errs_angle = []
errs_corner2D = []
preds_trans = []
preds_rot = []
preds_corners2D = []
gts_trans = []
gts_rot = []
gts_corners2D = []
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
# diam = calc_pts_diameter(np.array(mesh.vertices))
diam = float(options['diam'])
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
# Get validation file names
with open(valid_images) as fp:
tmp_files = fp.readlines()
valid_files = [item.rstrip() for item in tmp_files]
# Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
model = Darknet(cfgfile)
model.print_network()
model.load_weights(weightfile)
model.cuda()
model.eval()
# Get the parser for the test dataset
valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([
transforms.ToTensor(),]))
valid_batchsize = 1
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
logging(" Testing {}...".format(name))
logging(" Number of test samples: %d" % len(test_loader.dataset))
# Iterate through test batches (Batch size for test data is 1)
count = 0
z = np.zeros((3, 1))
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Forward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes)
t4 = time.time()
# Iterate through all images in the batch
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction for single object pose estimation
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est):
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
box_pr = boxes[j]
best_conf_est = boxes[j][18]
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640
corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480
corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640
corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480
preds_corners2D.append(corners2D_pr)
gts_corners2D.append(corners2D_gt)
# Compute corner prediction error
corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)
corner_dist = np.mean(corner_norm)
errs_corner2D.append(corner_dist)
# Compute [R|t] by pnp
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
if save:
preds_trans.append(t_pr)
gts_trans.append(t_gt)
preds_rot.append(R_pr)
gts_rot.append(R_gt)
np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))
np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))
np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))
np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
errs_trans.append(trans_dist)
# Compute angle error
angle_dist = calcAngularDistance(R_gt, R_pr)
errs_angle.append(angle_dist)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Compute 3D distances
transform_3d_gt = compute_transformation(vertices, Rt_gt)
transform_3d_pred = compute_transformation(vertices, Rt_pr)
norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)
vertex_dist = np.mean(norm3d)
errs_3d.append(vertex_dist)
# Sum errors
testing_error_trans += trans_dist
testing_error_angle += angle_dist
testing_error_pixel += pixel_dist
testing_samples += 1
count = count + 1
t5 = time.time()
# Compute 2D projection error, 6D pose error, 5cm5degree error
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
if testtime:
print('-----------------------------------')
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' eval : %f' % (t5 - t4))
print(' total : %f' % (t5 - t1))
print('-----------------------------------')
# Print test statistics
logging('Results of {}'.format(name))
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
logging(' Acc using 10% threshold - {} vx 3D Transformation = {:.2f}%'.format(diam * 0.1, acc3d10))
logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))
logging(" Mean 2D pixel error is %f, Mean vertex error is %f, mean corner error is %f" % (mean_err_2d, np.mean(errs_3d), mean_corner_err_2d))
logging(' Translation error: %f m, angle error: %f degree, pixel error: % f pix' % (testing_error_trans/nts, testing_error_angle/nts, testing_error_pixel/nts) )
if save:
predfile = backupdir + '/predictions_linemod_' + name + '.mat'
scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})
if __name__ == '__main__':
import sys
if len(sys.argv) == 4:
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
outfile = 'comp4_det_test_'
valid(datacfg, cfgfile, weightfile, outfile)
else:
print('Usage:')
print(' python valid.py datacfg cfgfile weightfile')

Просмотреть файл

@ -6,165 +6,93 @@ import torch.nn.functional as F
from torch.autograd import Variable
from utils import *
def build_targets(pred_corners, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
def build_targets(pred_corners, target, num_keypoints, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
nB = target.size(0)
nA = num_anchors
nC = num_classes
anchor_step = len(anchors)/num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx0 = torch.zeros(nB, nA, nH, nW)
ty0 = torch.zeros(nB, nA, nH, nW)
tx1 = torch.zeros(nB, nA, nH, nW)
ty1 = torch.zeros(nB, nA, nH, nW)
tx2 = torch.zeros(nB, nA, nH, nW)
ty2 = torch.zeros(nB, nA, nH, nW)
tx3 = torch.zeros(nB, nA, nH, nW)
ty3 = torch.zeros(nB, nA, nH, nW)
tx4 = torch.zeros(nB, nA, nH, nW)
ty4 = torch.zeros(nB, nA, nH, nW)
tx5 = torch.zeros(nB, nA, nH, nW)
ty5 = torch.zeros(nB, nA, nH, nW)
tx6 = torch.zeros(nB, nA, nH, nW)
ty6 = torch.zeros(nB, nA, nH, nW)
tx7 = torch.zeros(nB, nA, nH, nW)
ty7 = torch.zeros(nB, nA, nH, nW)
tx8 = torch.zeros(nB, nA, nH, nW)
ty8 = torch.zeros(nB, nA, nH, nW)
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
txs = list()
tys = list()
for i in range(num_keypoints):
txs.append(torch.zeros(nB, nA, nH, nW))
tys.append(torch.zeros(nB, nA, nH, nW))
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
num_labels = 2 * num_keypoints + 3 # +2 for width, height and +1 for class within label files
nAnchors = nA*nH*nW
nPixels = nH*nW
for b in range(nB):
cur_pred_corners = pred_corners[b*nAnchors:(b+1)*nAnchors].t()
cur_confs = torch.zeros(nAnchors)
for t in range(50):
if target[b][t*21+1] == 0:
if target[b][t*num_labels+1] == 0:
break
gx0 = target[b][t*21+1]*nW
gy0 = target[b][t*21+2]*nH
gx1 = target[b][t*21+3]*nW
gy1 = target[b][t*21+4]*nH
gx2 = target[b][t*21+5]*nW
gy2 = target[b][t*21+6]*nH
gx3 = target[b][t*21+7]*nW
gy3 = target[b][t*21+8]*nH
gx4 = target[b][t*21+9]*nW
gy4 = target[b][t*21+10]*nH
gx5 = target[b][t*21+11]*nW
gy5 = target[b][t*21+12]*nH
gx6 = target[b][t*21+13]*nW
gy6 = target[b][t*21+14]*nH
gx7 = target[b][t*21+15]*nW
gy7 = target[b][t*21+16]*nH
gx8 = target[b][t*21+17]*nW
gy8 = target[b][t*21+18]*nH
g = list()
for i in range(num_keypoints):
g.append(target[b][t*num_labels+2*i+1])
g.append(target[b][t*num_labels+2*i+2])
cur_gt_corners = torch.FloatTensor([gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences9(cur_pred_corners, cur_gt_corners)) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
cur_gt_corners = torch.FloatTensor(g).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences(cur_pred_corners, cur_gt_corners)).view_as(conf_mask[b]) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
conf_mask[b][cur_confs>sil_thresh] = 0
if seen < -1:#6400:
tx0.fill_(0.5)
ty0.fill_(0.5)
tx1.fill_(0.5)
ty1.fill_(0.5)
tx2.fill_(0.5)
ty2.fill_(0.5)
tx3.fill_(0.5)
ty3.fill_(0.5)
tx4.fill_(0.5)
ty4.fill_(0.5)
tx5.fill_(0.5)
ty5.fill_(0.5)
tx6.fill_(0.5)
ty6.fill_(0.5)
tx7.fill_(0.5)
ty7.fill_(0.5)
tx8.fill_(0.5)
ty8.fill_(0.5)
coord_mask.fill_(1)
nGT = 0
nCorrect = 0
for b in range(nB):
for t in range(50):
if target[b][t*21+1] == 0:
if target[b][t*num_labels+1] == 0:
break
# Get gt box for the current label
nGT = nGT + 1
best_iou = 0.0
best_n = -1
min_dist = 10000
gx0 = target[b][t*21+1] * nW
gy0 = target[b][t*21+2] * nH
gi0 = int(gx0)
gj0 = int(gy0)
gx1 = target[b][t*21+3] * nW
gy1 = target[b][t*21+4] * nH
gx2 = target[b][t*21+5] * nW
gy2 = target[b][t*21+6] * nH
gx3 = target[b][t*21+7] * nW
gy3 = target[b][t*21+8] * nH
gx4 = target[b][t*21+9] * nW
gy4 = target[b][t*21+10] * nH
gx5 = target[b][t*21+11] * nW
gy5 = target[b][t*21+12] * nH
gx6 = target[b][t*21+13] * nW
gy6 = target[b][t*21+14] * nH
gx7 = target[b][t*21+15] * nW
gy7 = target[b][t*21+16] * nH
gx8 = target[b][t*21+17] * nW
gy8 = target[b][t*21+18] * nH
gx = list()
gy = list()
gt_box = list()
for i in range(num_keypoints):
gt_box.extend([target[b][t*num_labels+2*i+1], target[b][t*num_labels+2*i+2]])
gx.append(target[b][t*num_labels+2*i+1] * nW)
gy.append(target[b][t*num_labels+2*i+2] * nH)
if i == 0:
gi0 = int(gx[i])
gj0 = int(gy[i])
# Update masks
best_n = 0 # 1 anchor box
gt_box = [gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence9(gt_box, pred_box)
conf = corner_confidence(gt_box, pred_box)
coord_mask[b][best_n][gj0][gi0] = 1
cls_mask[b][best_n][gj0][gi0] = 1
conf_mask[b][best_n][gj0][gi0] = object_scale
tx0[b][best_n][gj0][gi0] = target[b][t*21+1] * nW - gi0
ty0[b][best_n][gj0][gi0] = target[b][t*21+2] * nH - gj0
tx1[b][best_n][gj0][gi0] = target[b][t*21+3] * nW - gi0
ty1[b][best_n][gj0][gi0] = target[b][t*21+4] * nH - gj0
tx2[b][best_n][gj0][gi0] = target[b][t*21+5] * nW - gi0
ty2[b][best_n][gj0][gi0] = target[b][t*21+6] * nH - gj0
tx3[b][best_n][gj0][gi0] = target[b][t*21+7] * nW - gi0
ty3[b][best_n][gj0][gi0] = target[b][t*21+8] * nH - gj0
tx4[b][best_n][gj0][gi0] = target[b][t*21+9] * nW - gi0
ty4[b][best_n][gj0][gi0] = target[b][t*21+10] * nH - gj0
tx5[b][best_n][gj0][gi0] = target[b][t*21+11] * nW - gi0
ty5[b][best_n][gj0][gi0] = target[b][t*21+12] * nH - gj0
tx6[b][best_n][gj0][gi0] = target[b][t*21+13] * nW - gi0
ty6[b][best_n][gj0][gi0] = target[b][t*21+14] * nH - gj0
tx7[b][best_n][gj0][gi0] = target[b][t*21+15] * nW - gi0
ty7[b][best_n][gj0][gi0] = target[b][t*21+16] * nH - gj0
tx8[b][best_n][gj0][gi0] = target[b][t*21+17] * nW - gi0
ty8[b][best_n][gj0][gi0] = target[b][t*21+18] * nH - gj0
# Update targets
for i in range(num_keypoints):
txs[i][b][best_n][gj0][gi0] = gx[i]- gi0
tys[i][b][best_n][gj0][gi0] = gy[i]- gj0
tconf[b][best_n][gj0][gi0] = conf
tcls[b][best_n][gj0][gi0] = target[b][t*21]
if conf > 0.5:
tcls[b][best_n][gj0][gi0] = target[b][t*num_labels]
# Update recall during training
if conf > 0.5:
nCorrect = nCorrect + 1
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, txs, tys, tconf, tcls
class RegionLoss(nn.Module):
def __init__(self, num_classes=0, anchors=[], num_anchors=1):
def __init__(self, num_keypoints=9, num_classes=1, anchors=[], num_anchors=1, pretrain_num_epochs=15):
# Define the loss layer
super(RegionLoss, self).__init__()
self.num_classes = num_classes
self.anchors = anchors
self.num_anchors = num_anchors
self.anchor_step = len(anchors)/num_anchors
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
self.num_classes = num_classes
self.num_anchors = num_anchors # for single object pose estimation, there is only 1 trivial predictor (anchor)
self.num_keypoints = num_keypoints
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
self.pretrain_num_epochs = pretrain_num_epochs
def forward(self, output, target):
def forward(self, output, target, epoch):
# Parameters
t0 = time.time()
nB = output.data.size(0)
@ -172,83 +100,43 @@ class RegionLoss(nn.Module):
nC = self.num_classes
nH = output.data.size(2)
nW = output.data.size(3)
num_keypoints = self.num_keypoints
# Activation
output = output.view(nB, nA, (19+nC), nH, nW)
x0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
y0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
x1 = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
y1 = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
x2 = output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)
y2 = output.index_select(2, Variable(torch.cuda.LongTensor([5]))).view(nB, nA, nH, nW)
x3 = output.index_select(2, Variable(torch.cuda.LongTensor([6]))).view(nB, nA, nH, nW)
y3 = output.index_select(2, Variable(torch.cuda.LongTensor([7]))).view(nB, nA, nH, nW)
x4 = output.index_select(2, Variable(torch.cuda.LongTensor([8]))).view(nB, nA, nH, nW)
y4 = output.index_select(2, Variable(torch.cuda.LongTensor([9]))).view(nB, nA, nH, nW)
x5 = output.index_select(2, Variable(torch.cuda.LongTensor([10]))).view(nB, nA, nH, nW)
y5 = output.index_select(2, Variable(torch.cuda.LongTensor([11]))).view(nB, nA, nH, nW)
x6 = output.index_select(2, Variable(torch.cuda.LongTensor([12]))).view(nB, nA, nH, nW)
y6 = output.index_select(2, Variable(torch.cuda.LongTensor([13]))).view(nB, nA, nH, nW)
x7 = output.index_select(2, Variable(torch.cuda.LongTensor([14]))).view(nB, nA, nH, nW)
y7 = output.index_select(2, Variable(torch.cuda.LongTensor([15]))).view(nB, nA, nH, nW)
x8 = output.index_select(2, Variable(torch.cuda.LongTensor([16]))).view(nB, nA, nH, nW)
y8 = output.index_select(2, Variable(torch.cuda.LongTensor([17]))).view(nB, nA, nH, nW)
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([18]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(19,19+nC-1,nC).long().cuda()))
output = output.view(nB, nA, (num_keypoints*2+1+nC), nH, nW)
x = list()
y = list()
x.append(torch.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW)))
y.append(torch.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW)))
for i in range(1,num_keypoints):
x.append(output.index_select(2, Variable(torch.cuda.LongTensor([2 * i + 0]))).view(nB, nA, nH, nW))
y.append(output.index_select(2, Variable(torch.cuda.LongTensor([2 * i + 1]))).view(nB, nA, nH, nW))
conf = torch.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([2 * num_keypoints]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(2*num_keypoints+1,2*num_keypoints+1+nC-1,nC).long().cuda()))
cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC)
t1 = time.time()
# Create pred boxes
pred_corners = torch.cuda.FloatTensor(18, nB*nA*nH*nW)
pred_corners = torch.cuda.FloatTensor(2*num_keypoints, nB*nA*nH*nW)
grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
pred_corners[0] = (x0.data + grid_x) / nW
pred_corners[1] = (y0.data + grid_y) / nH
pred_corners[2] = (x1.data + grid_x) / nW
pred_corners[3] = (y1.data + grid_y) / nH
pred_corners[4] = (x2.data + grid_x) / nW
pred_corners[5] = (y2.data + grid_y) / nH
pred_corners[6] = (x3.data + grid_x) / nW
pred_corners[7] = (y3.data + grid_y) / nH
pred_corners[8] = (x4.data + grid_x) / nW
pred_corners[9] = (y4.data + grid_y) / nH
pred_corners[10] = (x5.data + grid_x) / nW
pred_corners[11] = (y5.data + grid_y) / nH
pred_corners[12] = (x6.data + grid_x) / nW
pred_corners[13] = (y6.data + grid_y) / nH
pred_corners[14] = (x7.data + grid_x) / nW
pred_corners[15] = (y7.data + grid_y) / nH
pred_corners[16] = (x8.data + grid_x) / nW
pred_corners[17] = (y8.data + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,18)
for i in range(num_keypoints):
pred_corners[2 * i + 0] = (x[i].data.view_as(grid_x) + grid_x) / nW
pred_corners[2 * i + 1] = (y[i].data.view_as(grid_y) + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,2*num_keypoints)
pred_corners = convert2cpu(gpu_matrix)
t2 = time.time()
# Build targets
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls = \
build_targets(pred_corners, target.data, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
nGT, nCorrect, coord_mask, conf_mask, cls_mask, txs, tys, tconf, tcls = \
build_targets(pred_corners, target.data, num_keypoints, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
cls_mask = (cls_mask == 1)
nProposals = int((conf > 0.25).sum().data[0])
tx0 = Variable(tx0.cuda())
ty0 = Variable(ty0.cuda())
tx1 = Variable(tx1.cuda())
ty1 = Variable(ty1.cuda())
tx2 = Variable(tx2.cuda())
ty2 = Variable(ty2.cuda())
tx3 = Variable(tx3.cuda())
ty3 = Variable(ty3.cuda())
tx4 = Variable(tx4.cuda())
ty4 = Variable(ty4.cuda())
tx5 = Variable(tx5.cuda())
ty5 = Variable(ty5.cuda())
tx6 = Variable(tx6.cuda())
ty6 = Variable(ty6.cuda())
tx7 = Variable(tx7.cuda())
ty7 = Variable(ty7.cuda())
tx8 = Variable(tx8.cuda())
ty8 = Variable(ty8.cuda())
for i in range(num_keypoints):
txs[i] = Variable(txs[i].cuda())
tys[i] = Variable(tys[i].cuda())
tconf = Variable(tconf.cuda())
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
tcls = Variable(tcls[cls_mask].long().cuda())
coord_mask = Variable(coord_mask.cuda())
conf_mask = Variable(conf_mask.cuda().sqrt())
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda())
@ -256,33 +144,22 @@ class RegionLoss(nn.Module):
t3 = time.time()
# Create loss
loss_x0 = self.coord_scale * nn.MSELoss(size_average=False)(x0*coord_mask, tx0*coord_mask)/2.0
loss_y0 = self.coord_scale * nn.MSELoss(size_average=False)(y0*coord_mask, ty0*coord_mask)/2.0
loss_x1 = self.coord_scale * nn.MSELoss(size_average=False)(x1*coord_mask, tx1*coord_mask)/2.0
loss_y1 = self.coord_scale * nn.MSELoss(size_average=False)(y1*coord_mask, ty1*coord_mask)/2.0
loss_x2 = self.coord_scale * nn.MSELoss(size_average=False)(x2*coord_mask, tx2*coord_mask)/2.0
loss_y2 = self.coord_scale * nn.MSELoss(size_average=False)(y2*coord_mask, ty2*coord_mask)/2.0
loss_x3 = self.coord_scale * nn.MSELoss(size_average=False)(x3*coord_mask, tx3*coord_mask)/2.0
loss_y3 = self.coord_scale * nn.MSELoss(size_average=False)(y3*coord_mask, ty3*coord_mask)/2.0
loss_x4 = self.coord_scale * nn.MSELoss(size_average=False)(x4*coord_mask, tx4*coord_mask)/2.0
loss_y4 = self.coord_scale * nn.MSELoss(size_average=False)(y4*coord_mask, ty4*coord_mask)/2.0
loss_x5 = self.coord_scale * nn.MSELoss(size_average=False)(x5*coord_mask, tx5*coord_mask)/2.0
loss_y5 = self.coord_scale * nn.MSELoss(size_average=False)(y5*coord_mask, ty5*coord_mask)/2.0
loss_x6 = self.coord_scale * nn.MSELoss(size_average=False)(x6*coord_mask, tx6*coord_mask)/2.0
loss_y6 = self.coord_scale * nn.MSELoss(size_average=False)(y6*coord_mask, ty6*coord_mask)/2.0
loss_x7 = self.coord_scale * nn.MSELoss(size_average=False)(x7*coord_mask, tx7*coord_mask)/2.0
loss_y7 = self.coord_scale * nn.MSELoss(size_average=False)(y7*coord_mask, ty7*coord_mask)/2.0
loss_x8 = self.coord_scale * nn.MSELoss(size_average=False)(x8*coord_mask, tx8*coord_mask)/2.0
loss_y8 = self.coord_scale * nn.MSELoss(size_average=False)(y8*coord_mask, ty8*coord_mask)/2.0
loss_xs = list()
loss_ys = list()
for i in range(num_keypoints):
loss_xs.append(self.coord_scale * nn.MSELoss(size_average=False)(x[i]*coord_mask, txs[i]*coord_mask)/2.0)
loss_ys.append(self.coord_scale * nn.MSELoss(size_average=False)(y[i]*coord_mask, tys[i]*coord_mask)/2.0)
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
# loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
loss_cls = 0
loss_x = loss_x0 + loss_x1 + loss_x2 + loss_x3 + loss_x4 + loss_x5 + loss_x6 + loss_x7 + loss_x8
loss_y = loss_y0 + loss_y1 + loss_y2 + loss_y3 + loss_y4 + loss_y5 + loss_y6 + loss_y7 + loss_y8
if False:
loss = loss_x + loss_y + loss_conf + loss_cls
loss_x = np.sum(loss_xs)
loss_y = np.sum(loss_ys)
if epoch > self.pretrain_num_epochs:
loss = loss_x + loss_y + loss_conf # in single object pose estimation, there is no classification loss
else:
loss = loss_x + loss_y + loss_conf
# pretrain initially without confidence loss
# once the coordinate predictions get better, start training for confidence as well
loss = loss_x + loss_y
t4 = time.time()
if False:
@ -293,9 +170,6 @@ class RegionLoss(nn.Module):
print(' create loss : %f' % (t4 - t3))
print(' total : %f' % (t4 - t0))
if False:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
else:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
return loss

199
train.py
Просмотреть файл

@ -11,6 +11,7 @@ import os
import random
import math
import shutil
import argparse
from torchvision import datasets, transforms
from torch.autograd import Variable # Useful info about autograd: http://pytorch.org/docs/master/notes/autograd.html
@ -21,6 +22,9 @@ from region_loss import RegionLoss
from darknet import Darknet
from MeshPly import MeshPly
import warnings
warnings.filterwarnings("ignore")
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
@ -49,13 +53,15 @@ def train(epoch):
t0 = time.time()
# Get the dataloader for training dataset
train_loader = torch.utils.data.DataLoader(dataset.listDataset(trainlist, shape=(init_width, init_height),
shuffle=True,
transform=transforms.Compose([transforms.ToTensor(),]),
train=True,
seen=model.seen,
batch_size=batch_size,
num_workers=num_workers, bg_file_names=bg_file_names),
train_loader = torch.utils.data.DataLoader(dataset.listDataset(trainlist,
shape=(init_width, init_height),
shuffle=True,
transform=transforms.Compose([transforms.ToTensor(),]),
train=True,
seen=model.seen,
batch_size=batch_size,
num_workers=num_workers,
bg_file_names=bg_file_names),
batch_size=batch_size, shuffle=False, **kwargs)
# TRAINING
@ -88,7 +94,7 @@ def train(epoch):
model.seen = model.seen + data.data.size(0)
region_loss.seen = region_loss.seen + data.data.size(0)
# Compute loss, grow an array of losses for saving later on
loss = region_loss(output, target)
loss = region_loss(output, target, epoch)
training_iters.append(epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter)
training_losses.append(convert2cpu(loss.data))
niter += 1
@ -147,7 +153,6 @@ def test(epoch, niter):
errs_trans = []
errs_angle = []
errs_corner2D = []
logging(" Testing...")
logging(" Number of test samples: %d" % len(test_loader.dataset))
notpredicted = 0
@ -165,34 +170,25 @@ def test(epoch, niter):
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors)
all_boxes = get_region_boxes(output, num_classes, num_keypoints)
t4 = time.time()
# Iterate through all batch elements
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
for box_pr, target in zip([all_boxes], [target[0]]):
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
truths = target.view(-1, num_keypoints*2+3)
# Get how many objects are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction
for j in range(len(boxes)):
if boxes[j][18] > best_conf_est:
best_conf_est = boxes[j][18]
box_pr = boxes[j]
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
box_gt = list()
for j in range(1, 2*num_keypoints+1):
box_gt.append(truths[k][j])
box_gt.extend([1.0, 1.0])
box_gt.append(truths[k][0])
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt = np.array(np.reshape(box_gt[:num_keypoints*2], [num_keypoints, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:num_keypoints*2], [num_keypoints, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
@ -208,7 +204,6 @@ def test(epoch, niter):
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
# Compute errors
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
errs_trans.append(trans_dist)
@ -242,12 +237,13 @@ def test(epoch, niter):
t5 = time.time()
# Compute 2D projection, 6D pose and 5cm5degree scores
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc3d = len(np.where(np.array(errs_3d) <= vx_threshold)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
px_threshold = 5 # 5 pixel threshold for 2D reprojection error is standard in recent sota 6D object pose estimation works
eps = 1e-5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc3d = len(np.where(np.array(errs_3d) <= vx_threshold)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
@ -276,24 +272,28 @@ def test(epoch, niter):
if __name__ == "__main__":
# Training settings
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
# Parse configuration files
parser = argparse.ArgumentParser(description='SingleShotPose')
parser.add_argument('--datacfg', type=str, default='cfg/ape.data') # data config
parser.add_argument('--modelcfg', type=str, default='cfg/yolo-pose.cfg') # network config
parser.add_argument('--initweightfile', type=str, default='cfg/darknet19_448.conv.23') # imagenet initialized weights
parser.add_argument('--pretrain_num_epochs', type=int, default=15) # how many epoch to pretrain
args = parser.parse_args()
datacfg = args.datacfg
modelcfg = args.modelcfg
initweightfile = args.initweightfile
pretrain_num_epochs = args.pretrain_num_epochs
# Parse configuration files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
net_options = parse_cfg(modelcfg)[0]
trainlist = data_options['train']
testlist = data_options['valid']
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
gpus = '0'
gpus = data_options['gpus']
meshname = data_options['mesh']
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
diam = float(data_options['diam'])
vx_threshold = diam * 0.1
vx_threshold = float(data_options['diam']) * 0.1 # threshold for the ADD metric
if not os.path.exists(backupdir):
makedirs(backupdir)
batch_size = int(net_options['batch'])
@ -301,49 +301,49 @@ if __name__ == "__main__":
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
steps = [float(step) for step in net_options['steps'].split(',')]
nsamples = file_lines(trainlist)
batch_size = int(net_options['batch'])
nbatches = nsamples / batch_size
steps = [float(step)*nbatches for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
bg_file_names = get_all_files('VOCdevkit/VOC2012/JPEGImages')
# Train parameters
max_epochs = 700 # max_batches*batch_size/nsamples+1
use_cuda = True
seed = int(time.time())
eps = 1e-5
save_interval = 10 # epoches
dot_interval = 70 # batches
best_acc = -1
max_epochs = int(net_options['max_epochs'])
num_keypoints = int(net_options['num_keypoints'])
# Test parameters
conf_thresh = 0.1
nms_thresh = 0.4
iou_thresh = 0.5
im_width = 640
im_height = 480
im_width = int(data_options['width'])
im_height = int(data_options['height'])
fx = float(data_options['fx'])
fy = float(data_options['fy'])
u0 = float(data_options['u0'])
v0 = float(data_options['v0'])
test_width = int(net_options['test_width'])
test_height = int(net_options['test_height'])
# Specify which gpus to use
use_cuda = True
seed = int(time.time())
torch.manual_seed(seed)
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
# Specifiy the model and the loss
model = Darknet(cfgfile)
region_loss = model.loss
model = Darknet(modelcfg)
region_loss = RegionLoss(num_keypoints=9, num_classes=1, anchors=[], num_anchors=1, pretrain_num_epochs=15)
# Model settings
# model.load_weights(weightfile)
model.load_weights_until_last(weightfile)
model.load_weights_until_last(initweightfile)
model.print_network()
model.seen = 0
region_loss.iter = model.iter
region_loss.seen = model.seen
processed_batches = model.seen/batch_size
processed_batches = model.seen//batch_size
init_width = model.width
init_height = model.height
test_width = 672
test_height = 672
init_epoch = model.seen/nsamples
init_epoch = model.seen//nsamples
# Variable to save
training_iters = []
@ -359,16 +359,18 @@ if __name__ == "__main__":
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
internal_calibration = get_camera_intrinsic()
internal_calibration = get_camera_intrinsic(u0, v0, fx, fy)
# Specify the number of workers
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Get the dataloader for test data
test_loader = torch.utils.data.DataLoader(dataset.listDataset(testlist, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([transforms.ToTensor(),]),
train=False),
test_loader = torch.utils.data.DataLoader(dataset.listDataset(testlist,
shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([transforms.ToTensor(),]),
train=False),
batch_size=1, shuffle=False, **kwargs)
# Pass the model to GPU
@ -384,30 +386,25 @@ if __name__ == "__main__":
else:
params += [{'params': [value], 'weight_decay': decay*batch_size}]
optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
# optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimization
evaluate = False
if evaluate:
logging('evaluating ...')
test(0, 0)
else:
for epoch in range(init_epoch, max_epochs):
# TRAIN
niter = train(epoch)
# TEST and SAVE
if (epoch % 10 == 0) and (epoch is not 0):
test(epoch, niter)
logging('save training stats to %s/costs.npz' % (backupdir))
np.savez(os.path.join(backupdir, "costs.npz"),
training_iters=training_iters,
training_losses=training_losses,
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel,
testing_errors_angle=testing_errors_angle)
if (testing_accuracies[-1] > best_acc ):
best_acc = testing_accuracies[-1]
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.save_weights('%s/model.weights' % (backupdir))
shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))
best_acc = -sys.maxsize
for epoch in range(init_epoch, max_epochs):
# TRAIN
niter = train(epoch)
# TEST and SAVE
if (epoch % 10 == 0) and (epoch > 15):
test(epoch, niter)
logging('save training stats to %s/costs.npz' % (backupdir))
np.savez(os.path.join(backupdir, "costs.npz"),
training_iters=training_iters,
training_losses=training_losses,
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel,
testing_errors_angle=testing_errors_angle)
if (testing_accuracies[-1] > best_acc ):
best_acc = testing_accuracies[-1]
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.save_weights('%s/model.weights' % (backupdir))
# shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))

758
utils.py
Просмотреть файл

@ -13,6 +13,11 @@ from scipy import spatial
import struct
import imghdr
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
def get_all_files(directory):
files = []
@ -29,12 +34,8 @@ def calcAngularDistance(gt_rot, pr_rot):
trace = np.trace(rotDiff)
return np.rad2deg(np.arccos((trace-1.0)/2.0))
def get_camera_intrinsic():
K = np.zeros((3, 3), dtype='float64')
K[0, 0], K[0, 2] = 572.4114, 325.2611
K[1, 1], K[1, 2] = 573.5704, 242.0489
K[2, 2] = 1.
return K
def get_camera_intrinsic(u0, v0, fx, fy):
return np.array([[fx, 0.0, u0], [0.0, fy, v0], [0.0, 0.0, 1.0]])
def compute_projection(points_3D, transformation, internal_calibration):
projections_2d = np.zeros((2, points_3D.shape[1]), dtype='float32')
@ -70,7 +71,6 @@ def get_3D_corners(vertices):
max_y = np.max(vertices[1,:])
min_z = np.min(vertices[2,:])
max_z = np.max(vertices[2,:])
corners = np.array([[min_x, min_y, min_z],
[min_x, min_y, max_z],
[min_x, max_y, min_z],
@ -87,35 +87,25 @@ def pnp(points_3D, points_2D, cameraMatrix):
try:
distCoeffs = pnp.distCoeffs
except:
distCoeffs = np.zeros((8, 1), dtype='float32')
distCoeffs = np.zeros((8, 1), dtype='float32')
assert points_2D.shape[0] == points_2D.shape[0], 'points 3D and points 2D must have same number of vertices'
_, R_exp, t = cv2.solvePnP(points_3D,
# points_2D,
np.ascontiguousarray(points_2D[:,:2]).reshape((-1,1,2)),
cameraMatrix,
distCoeffs)
# , None, None, False, cv2.SOLVEPNP_UPNP)
# R_exp, t, _ = cv2.solvePnPRansac(points_3D,
# points_2D,
# cameraMatrix,
# distCoeffs,
# reprojectionError=12.0)
#
distCoeffs)
R, _ = cv2.Rodrigues(R_exp)
# Rt = np.c_[R, t]
return R, t
def get_2d_bb(box, size):
x = box[0]
y = box[1]
min_x = np.min(np.reshape(box, [9,2])[:,0])
max_x = np.max(np.reshape(box, [9,2])[:,0])
min_y = np.min(np.reshape(box, [9,2])[:,1])
max_y = np.max(np.reshape(box, [9,2])[:,1])
min_x = np.min(np.reshape(box, [-1,2])[:,0])
max_x = np.max(np.reshape(box, [-1,2])[:,0])
min_y = np.min(np.reshape(box, [-1,2])[:,1])
max_y = np.max(np.reshape(box, [-1,2])[:,1])
w = max_x - min_x
h = max_y - min_y
new_box = [x*size, y*size, w*size, h*size]
@ -145,89 +135,7 @@ def compute_2d_bb_from_orig_pix(pts, size):
new_box = [cx*size, cy*size, w*size, h*size]
return new_box
def bbox_iou(box1, box2, x1y1x2y2=False):
if x1y1x2y2:
mx = min(box1[0], box2[0])
Mx = max(box1[2], box2[2])
my = min(box1[1], box2[1])
My = max(box1[3], box2[3])
w1 = box1[2] - box1[0]
h1 = box1[3] - box1[1]
w2 = box2[2] - box2[0]
h2 = box2[3] - box2[1]
else:
mx = min(box1[0]-box1[2]/2.0, box2[0]-box2[2]/2.0)
Mx = max(box1[0]+box1[2]/2.0, box2[0]+box2[2]/2.0)
my = min(box1[1]-box1[3]/2.0, box2[1]-box2[3]/2.0)
My = max(box1[1]+box1[3]/2.0, box2[1]+box2[3]/2.0)
w1 = box1[2]
h1 = box1[3]
w2 = box2[2]
h2 = box2[3]
uw = Mx - mx
uh = My - my
cw = w1 + w2 - uw
ch = h1 + h2 - uh
carea = 0
if cw <= 0 or ch <= 0:
return 0.0
area1 = w1 * h1
area2 = w2 * h2
carea = cw * ch
uarea = area1 + area2 - carea
return carea/uarea
def corner_confidences(gt_corners, pr_corners, th=30, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
th : distance threshold, type: int
sharpness : sharpness of the exponential that assigns a confidence value to the distance
-----------
return : a torch.FloatTensor of shape (nA,) with 8 confidence values
'''
shape = gt_corners.size()
nA = shape[1]
dist = gt_corners - pr_corners
dist = dist.t().contiguous().view(nA, 8, 2)
dist[:, :, 0] = dist[:, :, 0] * im_width
dist[:, :, 1] = dist[:, :, 1] * im_height
eps = 1e-5
distthresh = torch.FloatTensor([th]).repeat(nA, 8)
dist = torch.sqrt(torch.sum((dist)**2, dim=2)).squeeze() # nA x 8
mask = (dist < distthresh).type(torch.FloatTensor)
conf = torch.exp(sharpness*(1 - dist/distthresh))-1 # mask * (torch.exp(math.log(2) * (1.0 - dist/rrt)) - 1)
conf0 = torch.exp(sharpness*(1 - torch.zeros(conf.size(0),1))) - 1
conf = conf / conf0.repeat(1, 8)
# conf = 1 - dist/distthresh
conf = mask * conf # nA x 8
mean_conf = torch.mean(conf, dim=1)
return mean_conf
def corner_confidence(gt_corners, pr_corners, th=30, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (16,) type: list
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (16,), type: list
th : distance threshold, type: int
sharpness : sharpness of the exponential that assigns a confidence value to the distance
-----------
return : a list of shape (8,) with 8 confidence values
'''
dist = torch.FloatTensor(gt_corners) - pr_corners
dist = dist.view(8, 2)
dist[:, 0] = dist[:, 0] * im_width
dist[:, 1] = dist[:, 1] * im_height
eps = 1e-5
dist = torch.sqrt(torch.sum((dist)**2, dim=1))
mask = (dist < th).type(torch.FloatTensor)
conf = torch.exp(sharpness * (1.0 - dist/th)) - 1
conf0 = torch.exp(torch.FloatTensor([sharpness])) - 1 + eps
conf = conf / conf0.repeat(8, 1)
# conf = 1.0 - dist/th
conf = mask * conf
return torch.mean(conf)
def corner_confidences9(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
def corner_confidences(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (16 x nA), type: torch.FloatTensor
th : distance threshold, type: int
@ -238,23 +146,25 @@ def corner_confidences9(gt_corners, pr_corners, th=80, sharpness=2, im_width=640
shape = gt_corners.size()
nA = shape[1]
dist = gt_corners - pr_corners
dist = dist.t().contiguous().view(nA, 9, 2)
num_el = dist.numel()
num_keypoints = num_el//(nA*2)
dist = dist.t().contiguous().view(nA, num_keypoints, 2)
dist[:, :, 0] = dist[:, :, 0] * im_width
dist[:, :, 1] = dist[:, :, 1] * im_height
eps = 1e-5
distthresh = torch.FloatTensor([th]).repeat(nA, 9)
distthresh = torch.FloatTensor([th]).repeat(nA, num_keypoints)
dist = torch.sqrt(torch.sum((dist)**2, dim=2)).squeeze() # nA x 9
mask = (dist < distthresh).type(torch.FloatTensor)
conf = torch.exp(sharpness*(1 - dist/distthresh))-1 # mask * (torch.exp(math.log(2) * (1.0 - dist/rrt)) - 1)
conf0 = torch.exp(sharpness*(1 - torch.zeros(conf.size(0),1))) - 1
conf = conf / conf0.repeat(1, 9)
conf = conf / conf0.repeat(1, num_keypoints)
# conf = 1 - dist/distthresh
conf = mask * conf # nA x 9
mean_conf = torch.mean(conf, dim=1)
return mean_conf
def corner_confidence9(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
def corner_confidence(gt_corners, pr_corners, th=80, sharpness=2, im_width=640, im_height=480):
''' gt_corners: Ground-truth 2D projections of the 3D bounding box corners, shape: (18,) type: list
pr_corners: Prediction for the 2D projections of the 3D bounding box corners, shape: (18,), type: list
th : distance threshold, type: int
@ -263,7 +173,8 @@ def corner_confidence9(gt_corners, pr_corners, th=80, sharpness=2, im_width=640,
return : a list of shape (9,) with 9 confidence values
'''
dist = torch.FloatTensor(gt_corners) - pr_corners
dist = dist.view(9, 2)
num_keypoints = dist.numel()//2
dist = dist.view(num_keypoints, 2)
dist[:, 0] = dist[:, 0] * im_width
dist[:, 1] = dist[:, 1] * im_height
eps = 1e-5
@ -271,8 +182,7 @@ def corner_confidence9(gt_corners, pr_corners, th=80, sharpness=2, im_width=640,
mask = (dist < th).type(torch.FloatTensor)
conf = torch.exp(sharpness * (1.0 - dist/th)) - 1
conf0 = torch.exp(torch.FloatTensor([sharpness])) - 1 + eps
conf = conf / conf0.repeat(9, 1)
# conf = 1.0 - dist/th
conf = conf / conf0.repeat(num_keypoints, 1)
conf = mask * conf
return torch.mean(conf)
@ -284,27 +194,6 @@ def softmax(x):
x = x/x.sum()
return x
def nms(boxes, nms_thresh):
if len(boxes) == 0:
return boxes
det_confs = torch.zeros(len(boxes))
for i in range(len(boxes)):
det_confs[i] = 1-boxes[i][4]
_,sortIds = torch.sort(det_confs)
out_boxes = []
for i in range(len(boxes)):
box_i = boxes[sortIds[i]]
if box_i[4] > 0:
out_boxes.append(box_i)
for j in range(i+1, len(boxes)):
box_j = boxes[sortIds[j]]
if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
#print(box_i, box_j, bbox_iou(box_i, box_j, x1y1x2y2=False))
box_j[4] = 0
return out_boxes
def fix_corner_order(corners2D_gt):
corners2D_gt_corrected = np.zeros((9, 2), dtype='float32')
corners2D_gt_corrected[0, :] = corners2D_gt[0, :]
@ -324,44 +213,33 @@ def convert2cpu(gpu_matrix):
def convert2cpu_long(gpu_matrix):
return torch.LongTensor(gpu_matrix.size()).copy_(gpu_matrix)
def get_region_boxes(output, conf_thresh, num_classes, only_objectness=1, validation=False):
def get_region_boxes(output, num_classes, num_keypoints, only_objectness=1, validation=True):
# Parameters
anchor_dim = 1
if output.dim() == 3:
output = output.unsqueeze(0)
batch = output.size(0)
assert(output.size(1) == (19+num_classes)*anchor_dim)
assert(output.size(1) == (2*num_keypoints+1+num_classes)*anchor_dim)
h = output.size(2)
w = output.size(3)
# Activation
t0 = time.time()
all_boxes = []
max_conf = -100000
output = output.view(batch*anchor_dim, 19+num_classes, h*w).transpose(0,1).contiguous().view(19+num_classes, batch*anchor_dim*h*w)
max_conf = -sys.maxsize
output = output.view(batch*anchor_dim, 2*num_keypoints+1+num_classes, h*w).transpose(0,1).contiguous().view(2*num_keypoints+1+num_classes, batch*anchor_dim*h*w)
grid_x = torch.linspace(0, w-1, w).repeat(h,1).repeat(batch*anchor_dim, 1, 1).view(batch*anchor_dim*h*w).cuda()
grid_y = torch.linspace(0, h-1, h).repeat(w,1).t().repeat(batch*anchor_dim, 1, 1).view(batch*anchor_dim*h*w).cuda()
xs0 = torch.sigmoid(output[0]) + grid_x
ys0 = torch.sigmoid(output[1]) + grid_y
xs1 = output[2] + grid_x
ys1 = output[3] + grid_y
xs2 = output[4] + grid_x
ys2 = output[5] + grid_y
xs3 = output[6] + grid_x
ys3 = output[7] + grid_y
xs4 = output[8] + grid_x
ys4 = output[9] + grid_y
xs5 = output[10] + grid_x
ys5 = output[11] + grid_y
xs6 = output[12] + grid_x
ys6 = output[13] + grid_y
xs7 = output[14] + grid_x
ys7 = output[15] + grid_y
xs8 = output[16] + grid_x
ys8 = output[17] + grid_y
det_confs = torch.sigmoid(output[18])
cls_confs = torch.nn.Softmax()(Variable(output[19:19+num_classes].transpose(0,1))).data
xs = list()
ys = list()
xs.append(torch.sigmoid(output[0]) + grid_x)
ys.append(torch.sigmoid(output[1]) + grid_y)
for j in range(1,num_keypoints):
xs.append(output[2*j + 0] + grid_x)
ys.append(output[2*j + 1] + grid_y)
det_confs = torch.sigmoid(output[2*num_keypoints])
cls_confs = torch.nn.Softmax()(Variable(output[2*num_keypoints+1:2*num_keypoints+1+num_classes].transpose(0,1))).data
cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
cls_max_confs = cls_max_confs.view(-1)
cls_max_ids = cls_max_ids.view(-1)
@ -373,32 +251,15 @@ def get_region_boxes(output, conf_thresh, num_classes, only_objectness=1, valida
det_confs = convert2cpu(det_confs)
cls_max_confs = convert2cpu(cls_max_confs)
cls_max_ids = convert2cpu_long(cls_max_ids)
xs0 = convert2cpu(xs0)
ys0 = convert2cpu(ys0)
xs1 = convert2cpu(xs1)
ys1 = convert2cpu(ys1)
xs2 = convert2cpu(xs2)
ys2 = convert2cpu(ys2)
xs3 = convert2cpu(xs3)
ys3 = convert2cpu(ys3)
xs4 = convert2cpu(xs4)
ys4 = convert2cpu(ys4)
xs5 = convert2cpu(xs5)
ys5 = convert2cpu(ys5)
xs6 = convert2cpu(xs6)
ys6 = convert2cpu(ys6)
xs7 = convert2cpu(xs7)
ys7 = convert2cpu(ys7)
xs8 = convert2cpu(xs8)
ys8 = convert2cpu(ys8)
for j in range(num_keypoints):
xs[j] = convert2cpu(xs[j])
ys[j] = convert2cpu(ys[j])
if validation:
cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
t2 = time.time()
# Boxes filter
for b in range(batch):
boxes = []
max_conf = -1
for cy in range(h):
for cx in range(w):
for i in range(anchor_dim):
@ -411,66 +272,20 @@ def get_region_boxes(output, conf_thresh, num_classes, only_objectness=1, valida
if conf > max_conf:
max_conf = conf
max_ind = ind
if conf > conf_thresh:
bcx0 = xs0[ind]
bcy0 = ys0[ind]
bcx1 = xs1[ind]
bcy1 = ys1[ind]
bcx2 = xs2[ind]
bcy2 = ys2[ind]
bcx3 = xs3[ind]
bcy3 = ys3[ind]
bcx4 = xs4[ind]
bcy4 = ys4[ind]
bcx5 = xs5[ind]
bcy5 = ys5[ind]
bcx6 = xs6[ind]
bcy6 = ys6[ind]
bcx7 = xs7[ind]
bcy7 = ys7[ind]
bcx8 = xs8[ind]
bcy8 = ys8[ind]
bcx = list()
bcy = list()
for j in range(num_keypoints):
bcx.append(xs[j][ind])
bcy.append(ys[j][ind])
cls_max_conf = cls_max_confs[ind]
cls_max_id = cls_max_ids[ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
if (not only_objectness) and validation:
for c in range(num_classes):
tmp_conf = cls_confs[ind][c]
if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
box.append(tmp_conf)
box.append(c)
boxes.append(box)
if len(boxes) == 0:
bcx0 = xs0[max_ind]
bcy0 = ys0[max_ind]
bcx1 = xs1[max_ind]
bcy1 = ys1[max_ind]
bcx2 = xs2[max_ind]
bcy2 = ys2[max_ind]
bcx3 = xs3[max_ind]
bcy3 = ys3[max_ind]
bcx4 = xs4[max_ind]
bcy4 = ys4[max_ind]
bcx5 = xs5[max_ind]
bcy5 = ys5[max_ind]
bcx6 = xs6[max_ind]
bcy6 = ys6[max_ind]
bcx7 = xs7[max_ind]
bcy7 = ys7[max_ind]
bcx8 = xs8[max_ind]
bcy8 = ys8[max_ind]
cls_max_conf = cls_max_confs[max_ind]
cls_max_id = cls_max_ids[max_ind]
det_conf = det_confs[max_ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
boxes.append(box)
all_boxes.append(boxes)
else:
all_boxes.append(boxes)
all_boxes.append(boxes)
box = list()
for j in range(num_keypoints):
box.append(bcx[j]/w)
box.append(bcy[j]/h)
box.append(det_conf)
box.append(cls_max_conf)
box.append(cls_max_id)
t3 = time.time()
if False:
print('---------------------------------')
@ -478,424 +293,25 @@ def get_region_boxes(output, conf_thresh, num_classes, only_objectness=1, valida
print(' gpu to cpu : %f' % (t2-t1))
print(' boxes filter : %f' % (t3-t2))
print('---------------------------------')
return all_boxes
return box
def get_corresponding_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, correspondingclass, only_objectness=1, validation=False):
# Parameters
anchor_step = len(anchors)/num_anchors
if output.dim() == 3:
output = output.unsqueeze(0)
batch = output.size(0)
assert(output.size(1) == (19+num_classes)*num_anchors)
h = output.size(2)
w = output.size(3)
# Activation
t0 = time.time()
all_boxes = []
max_conf = -100000
max_cls_conf = -100000
output = output.view(batch*num_anchors, 19+num_classes, h*w).transpose(0,1).contiguous().view(19+num_classes, batch*num_anchors*h*w)
grid_x = torch.linspace(0, w-1, w).repeat(h,1).repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
grid_y = torch.linspace(0, h-1, h).repeat(w,1).t().repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
xs0 = torch.sigmoid(output[0]) + grid_x
ys0 = torch.sigmoid(output[1]) + grid_y
xs1 = output[2] + grid_x
ys1 = output[3] + grid_y
xs2 = output[4] + grid_x
ys2 = output[5] + grid_y
xs3 = output[6] + grid_x
ys3 = output[7] + grid_y
xs4 = output[8] + grid_x
ys4 = output[9] + grid_y
xs5 = output[10] + grid_x
ys5 = output[11] + grid_y
xs6 = output[12] + grid_x
ys6 = output[13] + grid_y
xs7 = output[14] + grid_x
ys7 = output[15] + grid_y
xs8 = output[16] + grid_x
ys8 = output[17] + grid_y
det_confs = torch.sigmoid(output[18])
cls_confs = torch.nn.Softmax()(Variable(output[19:19+num_classes].transpose(0,1))).data
cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
cls_max_confs = cls_max_confs.view(-1)
cls_max_ids = cls_max_ids.view(-1)
t1 = time.time()
# GPU to CPU
sz_hw = h*w
sz_hwa = sz_hw*num_anchors
det_confs = convert2cpu(det_confs)
cls_max_confs = convert2cpu(cls_max_confs)
cls_max_ids = convert2cpu_long(cls_max_ids)
xs0 = convert2cpu(xs0)
ys0 = convert2cpu(ys0)
xs1 = convert2cpu(xs1)
ys1 = convert2cpu(ys1)
xs2 = convert2cpu(xs2)
ys2 = convert2cpu(ys2)
xs3 = convert2cpu(xs3)
ys3 = convert2cpu(ys3)
xs4 = convert2cpu(xs4)
ys4 = convert2cpu(ys4)
xs5 = convert2cpu(xs5)
ys5 = convert2cpu(ys5)
xs6 = convert2cpu(xs6)
ys6 = convert2cpu(ys6)
xs7 = convert2cpu(xs7)
ys7 = convert2cpu(ys7)
xs8 = convert2cpu(xs8)
ys8 = convert2cpu(ys8)
if validation:
cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
t2 = time.time()
# Boxes filter
for b in range(batch):
boxes = []
max_conf = -1
for cy in range(h):
for cx in range(w):
for i in range(num_anchors):
ind = b*sz_hwa + i*sz_hw + cy*w + cx
det_conf = det_confs[ind]
if only_objectness:
conf = det_confs[ind]
else:
conf = det_confs[ind] * cls_max_confs[ind]
if (det_confs[ind] > max_conf) and (cls_confs[ind, correspondingclass] > max_cls_conf):
max_conf = det_confs[ind]
max_cls_conf = cls_confs[ind, correspondingclass]
max_ind = ind
if conf > conf_thresh:
bcx0 = xs0[ind]
bcy0 = ys0[ind]
bcx1 = xs1[ind]
bcy1 = ys1[ind]
bcx2 = xs2[ind]
bcy2 = ys2[ind]
bcx3 = xs3[ind]
bcy3 = ys3[ind]
bcx4 = xs4[ind]
bcy4 = ys4[ind]
bcx5 = xs5[ind]
bcy5 = ys5[ind]
bcx6 = xs6[ind]
bcy6 = ys6[ind]
bcx7 = xs7[ind]
bcy7 = ys7[ind]
bcx8 = xs8[ind]
bcy8 = ys8[ind]
cls_max_conf = cls_max_confs[ind]
cls_max_id = cls_max_ids[ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
if (not only_objectness) and validation:
for c in range(num_classes):
tmp_conf = cls_confs[ind][c]
if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
box.append(tmp_conf)
box.append(c)
boxes.append(box)
boxesnp = np.array(boxes)
if (len(boxes) == 0) or (not (correspondingclass in boxesnp[:,20])):
bcx0 = xs0[max_ind]
bcy0 = ys0[max_ind]
bcx1 = xs1[max_ind]
bcy1 = ys1[max_ind]
bcx2 = xs2[max_ind]
bcy2 = ys2[max_ind]
bcx3 = xs3[max_ind]
bcy3 = ys3[max_ind]
bcx4 = xs4[max_ind]
bcy4 = ys4[max_ind]
bcx5 = xs5[max_ind]
bcy5 = ys5[max_ind]
bcx6 = xs6[max_ind]
bcy6 = ys6[max_ind]
bcx7 = xs7[max_ind]
bcy7 = ys7[max_ind]
bcx8 = xs8[max_ind]
bcy8 = ys8[max_ind]
cls_max_conf = max_cls_conf # cls_max_confs[max_ind]
cls_max_id = correspondingclass # cls_max_ids[max_ind]
det_conf = max_conf # det_confs[max_ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
boxes.append(box)
# print(boxes)
all_boxes.append(boxes)
else:
all_boxes.append(boxes)
t3 = time.time()
if False:
print('---------------------------------')
print('matrix computation : %f' % (t1-t0))
print(' gpu to cpu : %f' % (t2-t1))
print(' boxes filter : %f' % (t3-t2))
print('---------------------------------')
return all_boxes
def get_boxes(output, conf_thresh, num_classes, anchors, num_anchors, correspondingclass, only_objectness=1, validation=False):
# Parameters
anchor_step = len(anchors)/num_anchors
if output.dim() == 3:
output = output.unsqueeze(0)
batch = output.size(0)
assert(output.size(1) == (19+num_classes)*num_anchors)
h = output.size(2)
w = output.size(3)
# Activation
t0 = time.time()
all_boxes = []
max_conf = -100000
max_cls_conf = -100000
output = output.view(batch*num_anchors, 19+num_classes, h*w).transpose(0,1).contiguous().view(19+num_classes, batch*num_anchors*h*w)
grid_x = torch.linspace(0, w-1, w).repeat(h,1).repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
grid_y = torch.linspace(0, h-1, h).repeat(w,1).t().repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
xs0 = torch.sigmoid(output[0]) + grid_x
ys0 = torch.sigmoid(output[1]) + grid_y
xs1 = output[2] + grid_x
ys1 = output[3] + grid_y
xs2 = output[4] + grid_x
ys2 = output[5] + grid_y
xs3 = output[6] + grid_x
ys3 = output[7] + grid_y
xs4 = output[8] + grid_x
ys4 = output[9] + grid_y
xs5 = output[10] + grid_x
ys5 = output[11] + grid_y
xs6 = output[12] + grid_x
ys6 = output[13] + grid_y
xs7 = output[14] + grid_x
ys7 = output[15] + grid_y
xs8 = output[16] + grid_x
ys8 = output[17] + grid_y
det_confs = torch.sigmoid(output[18])
cls_confs = torch.nn.Softmax()(Variable(output[19:19+num_classes].transpose(0,1))).data
cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
cls_max_confs = cls_max_confs.view(-1)
cls_max_ids = cls_max_ids.view(-1)
t1 = time.time()
# GPU to CPU
sz_hw = h*w
sz_hwa = sz_hw*num_anchors
det_confs = convert2cpu(det_confs)
cls_max_confs = convert2cpu(cls_max_confs)
cls_max_ids = convert2cpu_long(cls_max_ids)
xs0 = convert2cpu(xs0)
ys0 = convert2cpu(ys0)
xs1 = convert2cpu(xs1)
ys1 = convert2cpu(ys1)
xs2 = convert2cpu(xs2)
ys2 = convert2cpu(ys2)
xs3 = convert2cpu(xs3)
ys3 = convert2cpu(ys3)
xs4 = convert2cpu(xs4)
ys4 = convert2cpu(ys4)
xs5 = convert2cpu(xs5)
ys5 = convert2cpu(ys5)
xs6 = convert2cpu(xs6)
ys6 = convert2cpu(ys6)
xs7 = convert2cpu(xs7)
ys7 = convert2cpu(ys7)
xs8 = convert2cpu(xs8)
ys8 = convert2cpu(ys8)
if validation:
cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
t2 = time.time()
# Boxes filter
for b in range(batch):
boxes = []
max_conf = -1
for cy in range(h):
for cx in range(w):
for i in range(num_anchors):
ind = b*sz_hwa + i*sz_hw + cy*w + cx
det_conf = det_confs[ind]
if only_objectness:
conf = det_confs[ind]
else:
conf = det_confs[ind] * cls_max_confs[ind]
if (conf > max_conf) and (cls_confs[ind, correspondingclass] > max_cls_conf):
max_conf = conf
max_cls_conf = cls_confs[ind, correspondingclass]
max_ind = ind
if conf > conf_thresh:
bcx0 = xs0[ind]
bcy0 = ys0[ind]
bcx1 = xs1[ind]
bcy1 = ys1[ind]
bcx2 = xs2[ind]
bcy2 = ys2[ind]
bcx3 = xs3[ind]
bcy3 = ys3[ind]
bcx4 = xs4[ind]
bcy4 = ys4[ind]
bcx5 = xs5[ind]
bcy5 = ys5[ind]
bcx6 = xs6[ind]
bcy6 = ys6[ind]
bcx7 = xs7[ind]
bcy7 = ys7[ind]
bcx8 = xs8[ind]
bcy8 = ys8[ind]
cls_max_conf = cls_max_confs[ind]
cls_max_id = cls_max_ids[ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
if (not only_objectness) and validation:
for c in range(num_classes):
tmp_conf = cls_confs[ind][c]
if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
box.append(tmp_conf)
box.append(c)
boxes.append(box)
boxesnp = np.array(boxes)
if (len(boxes) == 0) or (not (correspondingclass in boxesnp[:,20])):
bcx0 = xs0[max_ind]
bcy0 = ys0[max_ind]
bcx1 = xs1[max_ind]
bcy1 = ys1[max_ind]
bcx2 = xs2[max_ind]
bcy2 = ys2[max_ind]
bcx3 = xs3[max_ind]
bcy3 = ys3[max_ind]
bcx4 = xs4[max_ind]
bcy4 = ys4[max_ind]
bcx5 = xs5[max_ind]
bcy5 = ys5[max_ind]
bcx6 = xs6[max_ind]
bcy6 = ys6[max_ind]
bcx7 = xs7[max_ind]
bcy7 = ys7[max_ind]
bcx8 = xs8[max_ind]
bcy8 = ys8[max_ind]
cls_max_conf = max_cls_conf # cls_max_confs[max_ind]
cls_max_id = correspondingclass # cls_max_ids[max_ind]
det_conf = det_confs[max_ind]
box = [bcx0/w, bcy0/h, bcx1/w, bcy1/h, bcx2/w, bcy2/h, bcx3/w, bcy3/h, bcx4/w, bcy4/h, bcx5/w, bcy5/h, bcx6/w, bcy6/h, bcx7/w, bcy7/h, bcx8/w, bcy8/h, det_conf, cls_max_conf, cls_max_id]
boxes.append(box)
# print(boxes)
all_boxes.append(boxes)
else:
all_boxes.append(boxes)
t3 = time.time()
if False:
print('---------------------------------')
print('matrix computation : %f' % (t1-t0))
print(' gpu to cpu : %f' % (t2-t1))
print(' boxes filter : %f' % (t3-t2))
print('---------------------------------')
return all_boxes
def plot_boxes_cv2(img, boxes, savename=None, class_names=None, color=None):
import cv2
colors = torch.FloatTensor([[1,0,1],[0,0,1],[0,1,1],[0,1,0],[1,1,0],[1,0,0]]);
def get_color(c, x, max_val):
ratio = float(x)/max_val * 5
i = int(math.floor(ratio))
j = int(math.ceil(ratio))
ratio = ratio - i
r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
return int(r*255)
width = img.shape[1]
height = img.shape[0]
for i in range(len(boxes)):
box = boxes[i]
x1 = int(round((box[0] - box[2]/2.0) * width))
y1 = int(round((box[1] - box[3]/2.0) * height))
x2 = int(round((box[0] + box[2]/2.0) * width))
y2 = int(round((box[1] + box[3]/2.0) * height))
if color:
rgb = color
else:
rgb = (255, 0, 0)
if len(box) >= 7 and class_names:
cls_conf = box[5]
cls_id = box[6]
print('%s: %f' % (class_names[cls_id], cls_conf))
classes = len(class_names)
offset = cls_id * 123457 % classes
red = get_color(2, offset, classes)
green = get_color(1, offset, classes)
blue = get_color(0, offset, classes)
if color is None:
rgb = (red, green, blue)
img = cv2.putText(img, class_names[cls_id], (x1,y1), cv2.FONT_HERSHEY_SIMPLEX, 1.2, rgb, 1)
img = cv2.rectangle(img, (x1,y1), (x2,y2), rgb, 1)
if savename:
print("save plot results to %s" % savename)
cv2.imwrite(savename, img)
return img
def plot_boxes(img, boxes, savename=None, class_names=None):
colors = torch.FloatTensor([[1,0,1],[0,0,1],[0,1,1],[0,1,0],[1,1,0],[1,0,0]]);
def get_color(c, x, max_val):
ratio = float(x)/max_val * 5
i = int(math.floor(ratio))
j = int(math.ceil(ratio))
ratio = ratio - i
r = (1-ratio) * colors[i][c] + ratio*colors[j][c]
return int(r*255)
width = img.width
height = img.height
draw = ImageDraw.Draw(img)
for i in range(len(boxes)):
box = boxes[i]
x1 = (box[0] - box[2]/2.0) * width
y1 = (box[1] - box[3]/2.0) * height
x2 = (box[0] + box[2]/2.0) * width
y2 = (box[1] + box[3]/2.0) * height
rgb = (255, 0, 0)
if len(box) >= 7 and class_names:
cls_conf = box[5]
cls_id = box[6]
print('%s: %f' % (class_names[cls_id], cls_conf))
classes = len(class_names)
offset = cls_id * 123457 % classes
red = get_color(2, offset, classes)
green = get_color(1, offset, classes)
blue = get_color(0, offset, classes)
rgb = (red, green, blue)
draw.text((x1, y1), class_names[cls_id], fill=rgb)
draw.rectangle([x1, y1, x2, y2], outline = rgb)
if savename:
print("save plot results to %s" % savename)
img.save(savename)
return img
def read_truths(lab_path):
def read_truths(lab_path, num_keypoints=9):
num_labels = 2*num_keypoints+3 # +2 for width, height, +1 for class label
if os.path.getsize(lab_path):
truths = np.loadtxt(lab_path)
truths = truths.reshape(truths.size/21, 21) # to avoid single truth problem
truths = truths.reshape(truths.size//num_labels, num_labels) # to avoid single truth problem
return truths
else:
return np.array([])
def read_truths_args(lab_path, min_box_scale):
def read_truths_args(lab_path, num_keypoints=9):
num_labels = 2 * num_keypoints + 1
truths = read_truths(lab_path)
new_truths = []
for i in range(truths.shape[0]):
new_truths.append([truths[i][0], truths[i][1], truths[i][2], truths[i][3], truths[i][4],
truths[i][5], truths[i][6], truths[i][7], truths[i][8], truths[i][9], truths[i][10],
truths[i][11], truths[i][12], truths[i][13], truths[i][14], truths[i][15], truths[i][16], truths[i][17], truths[i][18]])
for j in range(num_labels):
new_truths.append(truths[i][j])
return np.array(new_truths)
def read_pose(lab_path):
@ -924,59 +340,9 @@ def image2torch(img):
img = img.float().div(255.0)
return img
def do_detect(model, img, conf_thresh, nms_thresh, use_cuda=1):
model.eval()
t0 = time.time()
if isinstance(img, Image.Image):
width = img.width
height = img.height
img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))
img = img.view(height, width, 3).transpose(0,1).transpose(0,2).contiguous()
img = img.view(1, 3, height, width)
img = img.float().div(255.0)
elif type(img) == np.ndarray: # cv2 image
img = torch.from_numpy(img.transpose(2,0,1)).float().div(255.0).unsqueeze(0)
else:
print("unknow image type")
exit(-1)
t1 = time.time()
if use_cuda:
img = img.cuda()
img = torch.autograd.Variable(img)
t2 = time.time()
output = model(img)
output = output.data
#for j in range(100):
# sys.stdout.write('%f ' % (output.storage()[j]))
#print('')
t3 = time.time()
boxes = get_region_boxes(output, conf_thresh, model.num_classes, model.anchors, model.num_anchors)[0]
#for j in range(len(boxes)):
# print(boxes[j])
t4 = time.time()
boxes = nms(boxes, nms_thresh)
t5 = time.time()
if False:
print('-----------------------------------')
print(' image to tensor : %f' % (t1 - t0))
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' nms : %f' % (t5 - t4))
print(' total : %f' % (t5 - t0))
print('-----------------------------------')
return boxes
def read_data_cfg(datacfg):
options = dict()
options['gpus'] = '0,1,2,3'
options['gpus'] = '0'
options['num_workers'] = '10'
with open(datacfg, 'r') as fp:
lines = fp.readlines()
@ -1008,7 +374,7 @@ def file_lines(thefilepath):
buffer = thefile.read(8192*1024)
if not buffer:
break
count += buffer.count('\n')
count += buffer.count(b'\n')
thefile.close( )
return count

Просмотреть файл

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -31,68 +31,46 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"ename": "IOError",
"evalue": "[Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIOError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0mcfgfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cfg/yolo-pose.cfg'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0mweightfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'backup/ape/model_backup.weights'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0mvalid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfgfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36mvalid\u001b[0;34m(datacfg, cfgfile, weightfile)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# Read object model information, get 3D bounding box corners\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmesh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMeshPly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeshname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0mcorners3D\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_3D_corners\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/cvlabdata1/home/btekin/ope/singleshotpose_release/MeshPly.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, color)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'r'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIOError\u001b[0m: [Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'"
]
}
],
"outputs": [],
"source": [
"def valid(datacfg, cfgfile, weightfile):\n",
" def truths_length(truths):\n",
" for i in range(50):\n",
"def valid(datacfg, modelcfg, weightfile):\n",
" def truths_length(truths, max_num_gt=50):\n",
" for i in range(max_num_gt):\n",
" if truths[i][1] == 0:\n",
" return i\n",
"\n",
" # Parse configuration files\n",
" options = read_data_cfg(datacfg)\n",
" valid_images = options['valid']\n",
" meshname = options['mesh']\n",
" backupdir = options['backup']\n",
" name = options['name']\n",
" data_options = read_data_cfg(datacfg)\n",
" valid_images = data_options['valid']\n",
" meshname = data_options['mesh']\n",
" backupdir = data_options['backup']\n",
" name = data_options['name']\n",
" gpus = data_options['gpus'] \n",
" fx = float(data_options['fx'])\n",
" fy = float(data_options['fy'])\n",
" u0 = float(data_options['u0'])\n",
" v0 = float(data_options['v0'])\n",
" im_width = int(data_options['width'])\n",
" im_height = int(data_options['height'])\n",
" if not os.path.exists(backupdir):\n",
" makedirs(backupdir)\n",
"\n",
" # Parameters\n",
" prefix = 'results'\n",
" seed = int(time.time())\n",
" gpus = '0' # Specify which gpus to use\n",
" test_width = 544\n",
" test_height = 544\n",
" torch.manual_seed(seed)\n",
" use_cuda = True\n",
" if use_cuda:\n",
" os.environ['CUDA_VISIBLE_DEVICES'] = gpus\n",
" torch.cuda.manual_seed(seed)\n",
" seed = int(time.time())\n",
" os.environ['CUDA_VISIBLE_DEVICES'] = gpus\n",
" torch.cuda.manual_seed(seed)\n",
" save = False\n",
" visualize = True\n",
" testtime = True\n",
" use_cuda = True\n",
" num_classes = 1\n",
" testing_samples = 0.0\n",
" eps = 1e-5\n",
" notpredicted = 0 \n",
" conf_thresh = 0.1\n",
" nms_thresh = 0.4\n",
" match_thresh = 0.5\n",
" edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]\n",
"\n",
" if save:\n",
" makedirs(backupdir + '/test')\n",
" makedirs(backupdir + '/test/gt')\n",
" makedirs(backupdir + '/test/pr')\n",
"\n",
" # To save\n",
" testing_error_trans = 0.0\n",
" testing_error_angle = 0.0\n",
@ -108,17 +86,18 @@
" gts_trans = []\n",
" gts_rot = []\n",
" gts_corners2D = []\n",
" ious = []\n",
"\n",
" # Read object model information, get 3D bounding box corners\n",
" mesh = MeshPly(meshname)\n",
" vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
" corners3D = get_3D_corners(vertices)\n",
" # diam = calc_pts_diameter(np.array(mesh.vertices))\n",
" diam = float(options['diam'])\n",
"\n",
" mesh = MeshPly(meshname)\n",
" vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
" corners3D = get_3D_corners(vertices)\n",
" try:\n",
" diam = float(options['diam'])\n",
" except:\n",
" diam = calc_pts_diameter(np.array(mesh.vertices))\n",
" \n",
" # Read intrinsic camera parameters\n",
" internal_calibration = get_camera_intrinsic()\n",
" intrinsic_calibration = get_camera_intrinsic(u0, v0, fx, fy)\n",
"\n",
" # Get validation file names\n",
" with open(valid_images) as fp:\n",
@ -126,29 +105,30 @@
" valid_files = [item.rstrip() for item in tmp_files]\n",
" \n",
" # Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode\n",
" model = Darknet(cfgfile)\n",
" model = Darknet(modelcfg)\n",
" model.print_network()\n",
" model.load_weights(weightfile)\n",
" model.cuda()\n",
" model.eval()\n",
" test_width = model.test_width\n",
" test_height = model.test_height\n",
" num_keypoints = model.num_keypoints \n",
" num_labels = num_keypoints * 2 + 3\n",
"\n",
" # Get the parser for the test dataset\n",
" valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),\n",
" shuffle=False,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),]))\n",
" valid_batchsize = 1\n",
" valid_dataset = dataset.listDataset(valid_images, \n",
" shape=(test_width, test_height),\n",
" shuffle=False,\n",
" transform=transforms.Compose([transforms.ToTensor(),]))\n",
"\n",
" # Specify the number of workers for multiple processing, get the dataloader for the test dataset\n",
" kwargs = {'num_workers': 4, 'pin_memory': True}\n",
" test_loader = torch.utils.data.DataLoader(\n",
" valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs) \n",
" test_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, **kwargs) \n",
"\n",
" logging(\" Testing {}...\".format(name))\n",
" logging(\" Number of test samples: %d\" % len(test_loader.dataset))\n",
" # Iterate through test batches (Batch size for test data is 1)\n",
" count = 0\n",
" z = np.zeros((3, 1))\n",
" for batch_idx, (data, target) in enumerate(test_loader):\n",
" \n",
" # Images\n",
@ -158,55 +138,39 @@
" \n",
" t1 = time.time()\n",
" # Pass data to GPU\n",
" if use_cuda:\n",
" data = data.cuda()\n",
" target = target.cuda()\n",
" \n",
" data = data.cuda()\n",
" target = target.cuda()\n",
" # Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference\n",
" data = Variable(data, volatile=True)\n",
" t2 = time.time()\n",
" \n",
" # Forward pass\n",
" output = model(data).data \n",
" t3 = time.time()\n",
" \n",
" # Using confidence threshold, eliminate low-confidence predictions\n",
" all_boxes = get_region_boxes(output, conf_thresh, num_classes) \n",
" all_boxes = get_region_boxes(output, num_classes, num_keypoints) \n",
" t4 = time.time()\n",
"\n",
" # Iterate through all images in the batch\n",
" for i in range(output.size(0)):\n",
" \n",
" # For each image, get all the predictions\n",
" boxes = all_boxes[i]\n",
" \n",
" # Evaluation\n",
" # Iterate through all batch elements\n",
" for box_pr, target in zip([all_boxes], [target[0]]):\n",
" # For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)\n",
" truths = target[i].view(-1, 21)\n",
" \n",
" # Get how many object are present in the scene\n",
" num_gts = truths_length(truths)\n",
"\n",
" # Iterate through each ground-truth object\n",
" truths = target.view(-1, num_keypoints*2+3)\n",
" # Get how many objects are present in the scene\n",
" num_gts = truths_length(truths)\n",
" # Iterate through each ground-truth object\n",
" for k in range(num_gts):\n",
" box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6], \n",
" truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12], \n",
" truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]\n",
" best_conf_est = -1\n",
"\n",
" # If the prediction has the highest confidence, choose it as our prediction for single object pose estimation\n",
" for j in range(len(boxes)):\n",
" if (boxes[j][18] > best_conf_est):\n",
" match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))\n",
" box_pr = boxes[j]\n",
" best_conf_est = boxes[j][18]\n",
" box_gt = list()\n",
" for j in range(1, 2*num_keypoints+1):\n",
" box_gt.append(truths[k][j])\n",
" box_gt.extend([1.0, 1.0])\n",
" box_gt.append(truths[k][0])\n",
"\n",
" # Denormalize the corner predictions \n",
" corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')\n",
" corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')\n",
" corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640\n",
" corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480 \n",
" corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640\n",
" corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480\n",
" corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width\n",
" corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height \n",
" corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width\n",
" corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height\n",
" preds_corners2D.append(corners2D_pr)\n",
" gts_corners2D.append(corners2D_gt)\n",
"\n",
@ -216,21 +180,8 @@
" errs_corner2D.append(corner_dist)\n",
" \n",
" # Compute [R|t] by pnp\n",
" R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))\n",
" R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))\n",
"\n",
" if save:\n",
" preds_trans.append(t_pr)\n",
" gts_trans.append(t_gt)\n",
" preds_rot.append(R_pr)\n",
" gts_rot.append(R_gt)\n",
"\n",
" np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))\n",
" R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(intrinsic_calibration, dtype='float32'))\n",
" R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(intrinsic_calibration, dtype='float32'))\n",
" \n",
" # Compute translation error\n",
" trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))\n",
@ -243,19 +194,19 @@
" # Compute pixel error\n",
" Rt_gt = np.concatenate((R_gt, t_gt), axis=1)\n",
" Rt_pr = np.concatenate((R_pr, t_pr), axis=1)\n",
" proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)\n",
" proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration) \n",
" proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration)) \n",
" proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration)) \n",
" proj_2d_gt = compute_projection(vertices, Rt_gt, intrinsic_calibration)\n",
" proj_2d_pred = compute_projection(vertices, Rt_pr, intrinsic_calibration) \n",
" proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, intrinsic_calibration)) \n",
" proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, intrinsic_calibration)) \n",
" norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)\n",
" pixel_dist = np.mean(norm)\n",
" errs_2d.append(pixel_dist)\n",
"\n",
" if visualize:\n",
" # Visualize\n",
" plt.xlim((0, 640))\n",
" plt.ylim((0, 480))\n",
" plt.imshow(scipy.misc.imresize(img, (480, 640)))\n",
" plt.xlim((0, im_width))\n",
" plt.ylim((0, im_height))\n",
" plt.imshow(scipy.misc.imresize(img, (im_height, im_width)))\n",
" # Projections\n",
" for edge in edges_corners:\n",
" plt.plot(proj_corners_gt[edge, 0], proj_corners_gt[edge, 1], color='g', linewidth=3.0)\n",
@ -263,12 +214,6 @@
" plt.gca().invert_yaxis()\n",
" plt.show()\n",
" \n",
" # Compute IoU score\n",
" bb_gt = compute_2d_bb_from_orig_pix(proj_2d_gt, output.size(3))\n",
" bb_pred = compute_2d_bb_from_orig_pix(proj_2d_pred, output.size(3))\n",
" iou = bbox_iou(bb_gt, bb_pred)\n",
" ious.append(iou)\n",
"\n",
" # Compute 3D distances\n",
" transform_3d_gt = compute_transformation(vertices, Rt_gt) \n",
" transform_3d_pred = compute_transformation(vertices, Rt_pr) \n",
@ -283,33 +228,46 @@
" testing_samples += 1\n",
" count = count + 1\n",
"\n",
" if save:\n",
" preds_trans.append(t_pr)\n",
" gts_trans.append(t_gt)\n",
" preds_rot.append(R_pr)\n",
" gts_rot.append(R_gt)\n",
"\n",
" np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))\n",
"\n",
"\n",
" t5 = time.time()\n",
"\n",
" # Compute 2D projection error, 6D pose error, 5cm5degree error\n",
" px_threshold = 5\n",
" acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)\n",
" acciou = len(np.where(np.array(errs_2d) >= 0.5)[0]) * 100. / (len(ious)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)\n",
" mean_err_2d = np.mean(errs_2d)\n",
" px_threshold = 5 # 5 pixel threshold for 2D reprojection error is standard in recent sota 6D object pose estimation works \n",
" eps = 1e-5\n",
" acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)\n",
" mean_err_2d = np.mean(errs_2d)\n",
" mean_corner_err_2d = np.mean(errs_corner2D)\n",
" nts = float(testing_samples)\n",
"\n",
" if testtime:\n",
" print('-----------------------------------')\n",
" print(' tensor to cuda : %f' % (t2 - t1))\n",
" print(' predict : %f' % (t3 - t2))\n",
" print(' forward pass : %f' % (t3 - t2))\n",
" print('get_region_boxes : %f' % (t4 - t3))\n",
" print(' nms : %f' % (t5 - t4))\n",
" print(' total : %f' % (t5 - t1))\n",
" print(' prediction time : %f' % (t4 - t1))\n",
" print(' eval : %f' % (t5 - t4))\n",
" print('-----------------------------------')\n",
"\n",
" # Print test statistics\n",
" logging('Results of {}'.format(name))\n",
" logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))\n",
" logging(' Acc using the IoU metric = {:.6f}%'.format(acciou))\n",
" logging(' Acc using 10% threshold - {} vx 3D Transformation = {:.2f}%'.format(diam * 0.1, acc3d10))\n",
" logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))\n",
" logging(\" Mean 2D pixel error is %f, Mean vertex error is %f, mean corner error is %f\" % (mean_err_2d, np.mean(errs_3d), mean_corner_err_2d))\n",
@ -319,10 +277,11 @@
" predfile = backupdir + '/predictions_linemod_' + name + '.mat'\n",
" scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})\n",
"\n",
"datacfg = 'cfg/ape.data'\n",
"cfgfile = 'cfg/yolo-pose.cfg'\n",
"datacfg = 'cfg/ape.data'\n",
"modelcfg = 'cfg/yolo-pose.cfg'\n",
"weightfile = 'backup/ape/model_backup.weights'\n",
"valid(datacfg, cfgfile, weightfile)"
"valid(datacfg, modelcfg, weightfile)\n",
" "
]
},
{
@ -335,21 +294,21 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"display_name": "Python 3",
"language": "python",
"name": "python2"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,

221
valid.py
Просмотреть файл

@ -1,63 +1,51 @@
import os
import time
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import argparse
import scipy.io
import warnings
warnings.filterwarnings("ignore")
from torch.autograd import Variable
from torchvision import datasets, transforms
from darknet import Darknet
import dataset
from darknet import Darknet
from utils import *
from MeshPly import MeshPly
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
def valid(datacfg, cfgfile, weightfile, outfile):
def truths_length(truths):
for i in range(50):
def valid(datacfg, modelcfg, weightfile):
def truths_length(truths, max_num_gt=50):
for i in range(max_num_gt):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
backupdir = options['backup']
name = options['name']
data_options = read_data_cfg(datacfg)
valid_images = data_options['valid']
meshname = data_options['mesh']
backupdir = data_options['backup']
name = data_options['name']
gpus = data_options['gpus']
fx = float(data_options['fx'])
fy = float(data_options['fy'])
u0 = float(data_options['u0'])
v0 = float(data_options['v0'])
im_width = int(data_options['width'])
im_height = int(data_options['height'])
if not os.path.exists(backupdir):
makedirs(backupdir)
# Parameters
prefix = 'results'
seed = int(time.time())
gpus = '0' # Specify which gpus to use
test_width = 544
test_height = 544
torch.manual_seed(seed)
use_cuda = True
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
seed = int(time.time())
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
save = False
testtime = True
use_cuda = True
num_classes = 1
testing_samples = 0.0
eps = 1e-5
notpredicted = 0
conf_thresh = 0.1
nms_thresh = 0.4
match_thresh = 0.5
if save:
makedirs(backupdir + '/test')
makedirs(backupdir + '/test/gt')
makedirs(backupdir + '/test/pr')
# To save
testing_error_trans = 0.0
testing_error_angle = 0.0
@ -75,14 +63,16 @@ def valid(datacfg, cfgfile, weightfile, outfile):
gts_corners2D = []
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
# diam = calc_pts_diameter(np.array(mesh.vertices))
diam = float(options['diam'])
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
try:
diam = float(options['diam'])
except:
diam = calc_pts_diameter(np.array(mesh.vertices))
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
intrinsic_calibration = get_camera_intrinsic(u0, v0, fx, fy)
# Get validation file names
with open(valid_images) as fp:
@ -90,82 +80,66 @@ def valid(datacfg, cfgfile, weightfile, outfile):
valid_files = [item.rstrip() for item in tmp_files]
# Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
model = Darknet(cfgfile)
model = Darknet(modelcfg)
model.print_network()
model.load_weights(weightfile)
model.cuda()
model.eval()
test_width = model.test_width
test_height = model.test_height
num_keypoints = model.num_keypoints
num_labels = num_keypoints * 2 + 3
# Get the parser for the test dataset
valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([
transforms.ToTensor(),]))
valid_batchsize = 1
valid_dataset = dataset.listDataset(valid_images,
shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([transforms.ToTensor(),]))
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
test_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, shuffle=False, **kwargs)
logging(" Testing {}...".format(name))
logging(" Number of test samples: %d" % len(test_loader.dataset))
# Iterate through test batches (Batch size for test data is 1)
count = 0
z = np.zeros((3, 1))
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Forward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes)
all_boxes = get_region_boxes(output, num_classes, num_keypoints)
t4 = time.time()
# Iterate through all images in the batch
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# Evaluation
# Iterate through all batch elements
for box_pr, target in zip([all_boxes], [target[0]]):
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
truths = target.view(-1, num_keypoints*2+3)
# Get how many objects are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction for single object pose estimation
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est):
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
box_pr = boxes[j]
best_conf_est = boxes[j][18]
box_gt = list()
for j in range(1, 2*num_keypoints+1):
box_gt.append(truths[k][j])
box_gt.extend([1.0, 1.0])
box_gt.append(truths[k][0])
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640
corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480
corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640
corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height
preds_corners2D.append(corners2D_pr)
gts_corners2D.append(corners2D_gt)
@ -175,21 +149,8 @@ def valid(datacfg, cfgfile, weightfile, outfile):
errs_corner2D.append(corner_dist)
# Compute [R|t] by pnp
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
if save:
preds_trans.append(t_pr)
gts_trans.append(t_gt)
preds_rot.append(R_pr)
gts_rot.append(R_gt)
np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))
np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))
np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))
np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(intrinsic_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(intrinsic_calibration, dtype='float32'))
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
@ -202,8 +163,8 @@ def valid(datacfg, cfgfile, weightfile, outfile):
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
proj_2d_gt = compute_projection(vertices, Rt_gt, intrinsic_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, intrinsic_calibration)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
@ -222,26 +183,41 @@ def valid(datacfg, cfgfile, weightfile, outfile):
testing_samples += 1
count = count + 1
if save:
preds_trans.append(t_pr)
gts_trans.append(t_gt)
preds_rot.append(R_pr)
gts_rot.append(R_gt)
np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))
np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))
np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))
np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))
t5 = time.time()
# Compute 2D projection error, 6D pose error, 5cm5degree error
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
px_threshold = 5 # 5 pixel threshold for 2D reprojection error is standard in recent sota 6D object pose estimation works
eps = 1e-5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
if testtime:
print('-----------------------------------')
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print(' forward pass : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' prediction time : %f' % (t4 - t1))
print(' eval : %f' % (t5 - t4))
print(' total : %f' % (t5 - t1))
print('-----------------------------------')
# Print test statistics
@ -257,13 +233,14 @@ def valid(datacfg, cfgfile, weightfile, outfile):
scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})
if __name__ == '__main__':
import sys
if len(sys.argv) == 4:
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
outfile = 'comp4_det_test_'
valid(datacfg, cfgfile, weightfile, outfile)
else:
print('Usage:')
print(' python valid.py datacfg cfgfile weightfile')
# Parse configuration files
parser = argparse.ArgumentParser(description='SingleShotPose')
parser.add_argument('--datacfg', type=str, default='cfg/ape.data') # data config
parser.add_argument('--modelcfg', type=str, default='cfg/yolo-pose.cfg') # network config
parser.add_argument('--weightfile', type=str, default='backup/ape/model_backup.weights') # imagenet initialized weights
args = parser.parse_args()
datacfg = args.datacfg
modelcfg = args.modelcfg
weightfile = args.weightfile
valid(datacfg, modelcfg, weightfile)