This commit is contained in:
Sudipta N. Sinha 2018-06-30 11:11:16 -07:00 коммит произвёл GitHub
Родитель 4be159a331
Коммит 1c1501738e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 3477 добавлений и 10 удалений

Двоичные данные
.DS_Store поставляемый Normal file

Двоичный файл не отображается.

13
LICENSE.txt Normal file
Просмотреть файл

@ -0,0 +1,13 @@
Single Shot Seamless Object Pose Estimation
Copyright (c) Microsoft Corporation
All rights reserved.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Software), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

53
MeshPly.py Normal file
Просмотреть файл

@ -0,0 +1,53 @@
# Class to read
class MeshPly:
def __init__(self, filename, color=[0., 0., 0.]):
f = open(filename, 'r')
self.vertices = []
self.colors = []
self.indices = []
self.normals = []
vertex_mode = False
face_mode = False
nb_vertices = 0
nb_faces = 0
idx = 0
with f as open_file_object:
for line in open_file_object:
elements = line.split()
if vertex_mode:
self.vertices.append([float(i) for i in elements[:3]])
self.normals.append([float(i) for i in elements[3:6]])
if elements[6:9]:
self.colors.append([float(i) / 255. for i in elements[6:9]])
else:
self.colors.append([float(i) / 255. for i in color])
idx += 1
if idx == nb_vertices:
vertex_mode = False
face_mode = True
idx = 0
elif face_mode:
self.indices.append([float(i) for i in elements[1:4]])
idx += 1
if idx == nb_faces:
face_mode = False
elif elements[0] == 'element':
if elements[1] == 'vertex':
nb_vertices = int(elements[2])
elif elements[1] == 'face':
nb_faces = int(elements[2])
elif elements[0] == 'end_header':
vertex_mode = True
if __name__ == '__main__':
path_model = ''
mesh = MeshPly(path_model)

142
README.md
Просмотреть файл

@ -1,14 +1,136 @@
# SINGLESHOTPOSE
This is the code for the following paper:
# Contributing
Bugra Tekin, Sudipta N. Sinha and Pascal Fua, "Real-Time Seamless Single Shot 6D Object Pose Prediction", CVPR 2018.
### Introduction
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.microsoft.com.
We propose a single-shot approach for simultaneously detecting an object in an RGB image and predicting its 6D pose without requiring multiple stages or having to examine multiple hypotheses. The key component of our method is a new CNN architecture inspired by the YOLO network design that directly predicts the 2D image locations of the projected vertices of the object's 3D bounding box. The object's 6D pose is then estimated using a PnP algorithm. [Paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Tekin_Real-Time_Seamless_Single_CVPR_2018_paper.pdf), [arXiv](https://arxiv.org/abs/1711.08848)
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
![SingleShotPose](https://btekin.github.io/single_shot_pose.png)
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
#### Citation
If you use this code, please cite the following
> @article{varol17a,
      TITLE = {{Real-Time Seamless Single Shot 6D Object Pose Prediction}},
      AUTHOR = {Tekin, Bugra and Sinha, Sudipta N. and Fua, Pascal},
      JOURNAL = {CVPR},
      YEAR = {2018}
}
### License
SingleShotPose is released under the MIT License (refer to the LICENSE file for details).
#### Environment and dependencies
The code is tested on Linux with CUDA v8 and cudNN v5.1. The implementation is based on PyTorch and tested on Python2.7. The code requires the following dependencies that could be installed with conda or pip: numpy, scipy, PIL, opencv-python
#### Downloading and preparing the data
Inside the main code directory, run the following to download and extract (1) the preprocessed LINEMOD dataset, (2) trained models for the LINEMOD dataset, (3) the trained model for the OCCLUSION dataset, (4) background images from the VOC2012 dataset respectively.
```
wget -O LINEMOD.tar --no-check-certificate "https://onedrive.live.com/download?cid=05750EBEE1537631&resid=5750EBEE1537631%21135&authkey=AJRHFmZbcjXxTmI"
wget -O backup.tar --no-check-certificate "https://onedrive.live.com/download?cid=0C78B7DE6C569D7B&resid=C78B7DE6C569D7B%21191&authkey=AP183o4PlczZR78"
wget -O multi_obj_pose_estimation/backup_multi.tar --no-check-certificate "https://onedrive.live.com/download?cid=05750EBEE1537631&resid=5750EBEE1537631%21136&authkey=AFQv01OSbvhGnoM"
wget https://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar
wget https://pjreddie.com/media/files/darknet19_448.conv.23 -P cfg/
tar xf LINEMOD.tar
tar xf backup.tar
tar xf multi_obj_pose_estimation/backup_multi.tar -C multi_obj_pose_estimation/
tar xf VOCtrainval_11-May-2012.tar
```
Alternatively, you can directly go to the links above and manually download and extract the files at the corresponding directories. The whole download process might take a long while (~60 minutes).
#### Training the model
To train the model run,
```
python train.py datafile cfgfile initweightfile
```
e.g.
```
python train.py cfg/ape.data cfg/yolo-pose.cfg backup/ape/init.weights
```
[datafile] contains information about the training/test splits and 3D object models
[cfgfile] contains information about the network structure
[initweightfile] contains initialization weights. The weights "backup/[OBJECT_NAME]/init.weights" are pretrained on LINEMOD for faster convergence. We found it effective to pretrain the model without confidence estimation first and fine-tune the network later on with confidence estimation as well. "init.weights" contain the weights of these pretrained networks. However, you can also still train the network from a more crude initialization (with weights trained on ImageNet). This usually results in a slower and sometime slightly worse convergence. You can find in cfg/ folder, the file <<darknet19_448.conv.23>>, includes the network weights pretrained on ImageNet. Alternatively, you can pretrain your own weights by setting the regularization parameter for the confidence loss to 0 as explained in "Pretraining the model" section.
At the start of the training you will see an output like this:
```
layer filters size input output
0 conv 32 3 x 3 / 1 416 x 416 x 3 -> 416 x 416 x 32
1 max 2 x 2 / 2 416 x 416 x 32 -> 208 x 208 x 32
2 conv 64 3 x 3 / 1 208 x 208 x 32 -> 208 x 208 x 64
3 max 2 x 2 / 2 208 x 208 x 64 -> 104 x 104 x 64
...
30 conv 20 1 x 1 / 1 13 x 13 x1024 -> 13 x 13 x 20
31 detection
```
This defines the network structure. During training, the best network model is saved into the "model.weights" file. To train networks for other objects, just change the object name while calling the train function, e.g., "python train.py cfg/duck.data cfg/yolo-pose.cfg backup/duck/init.weights"
#### Testing the model
To test the model run
```
python valid.py datafile cfgfile weightfile
e.g.,
python valid.py cfg/ape.data cfg/yolo-pose.cfg backup/ape/model_backup.weights
```
[weightfile] contains our trained models.
#### Pretraining the model (Optional)
Models are already pretrained but if you would like to pretrain the network from scratch and get the initialization weights yourself, you can run the following:
python train.py cfg/ape.data cfg/yolo-pose-pre.cfg cfg/darknet19_448.conv.23
cp backup/ape/model.weights backup/ape/init.weights
During pretraining the regularization parameter for the confidence term is set to "0" in the config file "cfg/yolo-pose-pre.cfg". "darknet19_448.conv.23" includes the weights of YOLOv2 trained on ImageNet.
#### Multi-object pose estimation on the OCCLUSION dataset
Inside multi_obj_pose_estimation/ folder
Testing:
```
python valid_multi.py cfgfile weightfile
e.g.,
python valid_multi.py cfg/yolo-pose-multi.cfg backup_multi/model_backup.weights
```
Training:
```
python train_multi.py datafile cfgfile weightfile
```
e.g.,
```
python train_multi.py cfg/occlusion.data cfg/yolo-pose-multi.cfg backup_multi/init.weights
```
#### Output Representation
Our output target representation consist of 21 values. We predict 9 points corresponding to the centroid and corners of the 3D object model. Additionally we predict the class in each cell. That makes 9x2+1 = 19 points. In multi-object training, during training, we assign whichever anchor box has the most similar size to the current object as the responsible one to predict the 2D coordinates for that object. To encode the size of the objects, we have additional 2 numbers for the range in x dimension and y dimension. Therefore, we have 9x2+1+2 = 21 numbers
Respectively, 21 numbers correspond to the following: 1st number: class label, 2nd number: x0 (x-coordinate of the centroid), 3rd number: y0 (y-coordinate of the centroid), 4th number: x1 (x-coordinate of the first corner), 5th number: y1 (y-coordinate of the first corner), ..., 18th number: x8 (x-coordinate of the eighth corner), 19th number: y8 (y-coordinate of the eighth corner), 20th number: x range, 21st number: y range.
The coordinates are normalized by the image width and height: x / image_width åand y / image_height. This is useful to have similar output ranges for the coordinate regression and object classification tasks.
#### Acknowledgments
The code is written by [Bugra Tekin](http://bugratekin.info) and is built on the YOLOv2 implementation of the github user [@marvis](https://github.com/marvis)
#### Contact
For any questions or bug reports, please contact [Bugra Tekin](http://bugratekin.info)

208
cfg.py Normal file
Просмотреть файл

@ -0,0 +1,208 @@
import torch
from utils import convert2cpu
def parse_cfg(cfgfile):
blocks = []
fp = open(cfgfile, 'r')
block = None
line = fp.readline()
while line != '':
line = line.rstrip()
if line == '' or line[0] == '#':
line = fp.readline()
continue
elif line[0] == '[':
if block:
blocks.append(block)
block = dict()
block['type'] = line.lstrip('[').rstrip(']')
# set default value
if block['type'] == 'convolutional':
block['batch_normalize'] = 0
else:
key,value = line.split('=')
key = key.strip()
if key == 'type':
key = '_type'
value = value.strip()
block[key] = value
line = fp.readline()
if block:
blocks.append(block)
fp.close()
return blocks
def print_cfg(blocks):
print('layer filters size input output');
prev_width = 416
prev_height = 416
prev_filters = 3
out_filters =[]
out_widths =[]
out_heights =[]
ind = -2
for block in blocks:
ind = ind + 1
if block['type'] == 'net':
prev_width = int(block['width'])
prev_height = int(block['height'])
continue
elif block['type'] == 'convolutional':
filters = int(block['filters'])
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
width = (prev_width + 2*pad - kernel_size)/stride + 1
height = (prev_height + 2*pad - kernel_size)/stride + 1
print('%5d %-6s %4d %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'conv', filters, kernel_size, kernel_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
width = prev_width/stride
height = prev_height/stride
print('%5d %-6s %d x %d / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'max', pool_size, pool_size, stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'avgpool':
width = 1
height = 1
print('%5d %-6s %3d x %3d x%4d -> %3d' % (ind, 'avg', prev_width, prev_height, prev_filters, prev_filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'softmax':
print('%5d %-6s -> %3d' % (ind, 'softmax', prev_filters))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'cost':
print('%5d %-6s -> %3d' % (ind, 'cost', prev_filters))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'reorg':
stride = int(block['stride'])
filters = stride * stride * prev_filters
width = prev_width/stride
height = prev_height/stride
print('%5d %-6s / %d %3d x %3d x%4d -> %3d x %3d x%4d' % (ind, 'reorg', stride, prev_width, prev_height, prev_filters, width, height, filters))
prev_width = width
prev_height = height
prev_filters = filters
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'route':
layers = block['layers'].split(',')
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
print('%5d %-6s %d' % (ind, 'route', layers[0]))
prev_width = out_widths[layers[0]]
prev_height = out_heights[layers[0]]
prev_filters = out_filters[layers[0]]
elif len(layers) == 2:
print('%5d %-6s %d %d' % (ind, 'route', layers[0], layers[1]))
prev_width = out_widths[layers[0]]
prev_height = out_heights[layers[0]]
assert(prev_width == out_widths[layers[1]])
assert(prev_height == out_heights[layers[1]])
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'region':
print('%5d %-6s' % (ind, 'detection'))
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'shortcut':
from_id = int(block['from'])
from_id = from_id if from_id > 0 else from_id+ind
print('%5d %-6s %d' % (ind, 'shortcut', from_id))
prev_width = out_widths[from_id]
prev_height = out_heights[from_id]
prev_filters = out_filters[from_id]
out_widths.append(prev_width)
out_heights.append(prev_height)
out_filters.append(prev_filters)
elif block['type'] == 'connected':
filters = int(block['output'])
print('%5d %-6s %d -> %3d' % (ind, 'connected', prev_filters, filters))
prev_filters = filters
out_widths.append(1)
out_heights.append(1)
out_filters.append(prev_filters)
else:
print('unknown type %s' % (block['type']))
def load_conv(buf, start, conv_model):
num_w = conv_model.weight.numel()
num_b = conv_model.bias.numel()
conv_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_conv(fp, conv_model):
if conv_model.bias.is_cuda:
convert2cpu(conv_model.bias.data).numpy().tofile(fp)
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
else:
conv_model.bias.data.numpy().tofile(fp)
conv_model.weight.data.numpy().tofile(fp)
def load_conv_bn(buf, start, conv_model, bn_model):
num_w = conv_model.weight.numel()
num_b = bn_model.bias.numel()
bn_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_mean.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
bn_model.running_var.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_conv_bn(fp, conv_model, bn_model):
if bn_model.bias.is_cuda:
convert2cpu(bn_model.bias.data).numpy().tofile(fp)
convert2cpu(bn_model.weight.data).numpy().tofile(fp)
convert2cpu(bn_model.running_mean).numpy().tofile(fp)
convert2cpu(bn_model.running_var).numpy().tofile(fp)
convert2cpu(conv_model.weight.data).numpy().tofile(fp)
else:
bn_model.bias.data.numpy().tofile(fp)
bn_model.weight.data.numpy().tofile(fp)
bn_model.running_mean.numpy().tofile(fp)
bn_model.running_var.numpy().tofile(fp)
conv_model.weight.data.numpy().tofile(fp)
def load_fc(buf, start, fc_model):
num_w = fc_model.weight.numel()
num_b = fc_model.bias.numel()
fc_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b])); start = start + num_b
fc_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w])); start = start + num_w
return start
def save_fc(fp, fc_model):
fc_model.bias.data.numpy().tofile(fp)
fc_model.weight.data.numpy().tofile(fp)
if __name__ == '__main__':
import sys
blocks = parse_cfg('cfg/yolo.cfg')
if len(sys.argv) == 2:
blocks = parse_cfg(sys.argv[1])
print_cfg(blocks)

388
darknet.py Normal file
Просмотреть файл

@ -0,0 +1,388 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from region_loss import RegionLoss
from cfg import *
class MaxPoolStride1(nn.Module):
def __init__(self):
super(MaxPoolStride1, self).__init__()
def forward(self, x):
x = F.max_pool2d(F.pad(x, (0,1,0,1), mode='replicate'), 2, stride=1)
return x
class Reorg(nn.Module):
def __init__(self, stride=2):
super(Reorg, self).__init__()
self.stride = stride
def forward(self, x):
stride = self.stride
assert(x.data.dim() == 4)
B = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
assert(H % stride == 0)
assert(W % stride == 0)
ws = stride
hs = stride
x = x.view(B, C, H/hs, hs, W/ws, ws).transpose(3,4).contiguous()
x = x.view(B, C, H/hs*W/ws, hs*ws).transpose(2,3).contiguous()
x = x.view(B, C, hs*ws, H/hs, W/ws).transpose(1,2).contiguous()
x = x.view(B, hs*ws*C, H/hs, W/ws)
return x
class GlobalAvgPool2d(nn.Module):
def __init__(self):
super(GlobalAvgPool2d, self).__init__()
def forward(self, x):
N = x.data.size(0)
C = x.data.size(1)
H = x.data.size(2)
W = x.data.size(3)
x = F.avg_pool2d(x, (H, W))
x = x.view(N, C)
return x
# for route and shortcut
class EmptyModule(nn.Module):
def __init__(self):
super(EmptyModule, self).__init__()
def forward(self, x):
return x
# support route shortcut and reorg
class Darknet(nn.Module):
def __init__(self, cfgfile):
super(Darknet, self).__init__()
self.blocks = parse_cfg(cfgfile)
self.models = self.create_network(self.blocks) # merge conv, bn,leaky
self.loss = self.models[len(self.models)-1]
self.width = int(self.blocks[0]['width'])
self.height = int(self.blocks[0]['height'])
if self.blocks[(len(self.blocks)-1)]['type'] == 'region':
self.anchors = self.loss.anchors
self.num_anchors = self.loss.num_anchors
self.anchor_step = self.loss.anchor_step
self.num_classes = self.loss.num_classes
self.header = torch.IntTensor([0,0,0,0])
self.seen = 0
self.iter = 0
def forward(self, x):
ind = -2
self.loss = None
outputs = dict()
for block in self.blocks:
ind = ind + 1
#if ind > 0:
# return x
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional' or block['type'] == 'maxpool' or block['type'] == 'reorg' or block['type'] == 'avgpool' or block['type'] == 'softmax' or block['type'] == 'connected':
x = self.models[ind](x)
outputs[ind] = x
elif block['type'] == 'route':
layers = block['layers'].split(',')
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
x = outputs[layers[0]]
outputs[ind] = x
elif len(layers) == 2:
x1 = outputs[layers[0]]
x2 = outputs[layers[1]]
x = torch.cat((x1,x2),1)
outputs[ind] = x
elif block['type'] == 'shortcut':
from_layer = int(block['from'])
activation = block['activation']
from_layer = from_layer if from_layer > 0 else from_layer + ind
x1 = outputs[from_layer]
x2 = outputs[ind-1]
x = x1 + x2
if activation == 'leaky':
x = F.leaky_relu(x, 0.1, inplace=True)
elif activation == 'relu':
x = F.relu(x, inplace=True)
outputs[ind] = x
elif block['type'] == 'region':
continue
if self.loss:
self.loss = self.loss + self.models[ind](x)
else:
self.loss = self.models[ind](x)
outputs[ind] = None
elif block['type'] == 'cost':
continue
else:
print('unknown type %s' % (block['type']))
return x
def print_network(self):
print_cfg(self.blocks)
def create_network(self, blocks):
models = nn.ModuleList()
prev_filters = 3
out_filters =[]
conv_id = 0
for block in blocks:
if block['type'] == 'net':
prev_filters = int(block['channels'])
continue
elif block['type'] == 'convolutional':
conv_id = conv_id + 1
batch_normalize = int(block['batch_normalize'])
filters = int(block['filters'])
kernel_size = int(block['size'])
stride = int(block['stride'])
is_pad = int(block['pad'])
pad = (kernel_size-1)/2 if is_pad else 0
activation = block['activation']
model = nn.Sequential()
if batch_normalize:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=False))
model.add_module('bn{0}'.format(conv_id), nn.BatchNorm2d(filters, eps=1e-4))
#model.add_module('bn{0}'.format(conv_id), BN2d(filters))
else:
model.add_module('conv{0}'.format(conv_id), nn.Conv2d(prev_filters, filters, kernel_size, stride, pad))
if activation == 'leaky':
model.add_module('leaky{0}'.format(conv_id), nn.LeakyReLU(0.1, inplace=True))
elif activation == 'relu':
model.add_module('relu{0}'.format(conv_id), nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'maxpool':
pool_size = int(block['size'])
stride = int(block['stride'])
if stride > 1:
model = nn.MaxPool2d(pool_size, stride)
else:
model = MaxPoolStride1()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'avgpool':
model = GlobalAvgPool2d()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'softmax':
model = nn.Softmax()
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'cost':
if block['_type'] == 'sse':
model = nn.MSELoss(size_average=True)
elif block['_type'] == 'L1':
model = nn.L1Loss(size_average=True)
elif block['_type'] == 'smooth':
model = nn.SmoothL1Loss(size_average=True)
out_filters.append(1)
models.append(model)
elif block['type'] == 'reorg':
stride = int(block['stride'])
prev_filters = stride * stride * prev_filters
out_filters.append(prev_filters)
models.append(Reorg(stride))
elif block['type'] == 'route':
layers = block['layers'].split(',')
ind = len(models)
layers = [int(i) if int(i) > 0 else int(i)+ind for i in layers]
if len(layers) == 1:
prev_filters = out_filters[layers[0]]
elif len(layers) == 2:
assert(layers[0] == ind - 1)
prev_filters = out_filters[layers[0]] + out_filters[layers[1]]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'shortcut':
ind = len(models)
prev_filters = out_filters[ind-1]
out_filters.append(prev_filters)
models.append(EmptyModule())
elif block['type'] == 'connected':
filters = int(block['output'])
if block['activation'] == 'linear':
model = nn.Linear(prev_filters, filters)
elif block['activation'] == 'leaky':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.LeakyReLU(0.1, inplace=True))
elif block['activation'] == 'relu':
model = nn.Sequential(
nn.Linear(prev_filters, filters),
nn.ReLU(inplace=True))
prev_filters = filters
out_filters.append(prev_filters)
models.append(model)
elif block['type'] == 'region':
loss = RegionLoss()
anchors = block['anchors'].split(',')
loss.anchors = [float(i) for i in anchors]
loss.num_classes = int(block['classes'])
loss.num_anchors = int(block['num'])
loss.anchor_step = len(loss.anchors)/loss.num_anchors
loss.object_scale = float(block['object_scale'])
loss.noobject_scale = float(block['noobject_scale'])
loss.class_scale = float(block['class_scale'])
loss.coord_scale = float(block['coord_scale'])
out_filters.append(prev_filters)
models.append(loss)
else:
print('unknown type %s' % (block['type']))
return models
def load_weights(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
for block in self.blocks:
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def load_weights_until_last(self, weightfile):
fp = open(weightfile, 'rb')
header = np.fromfile(fp, count=4, dtype=np.int32)
self.header = torch.from_numpy(header)
self.seen = self.header[3]
buf = np.fromfile(fp, dtype = np.float32)
fp.close()
start = 0
ind = -2
blocklen = len(self.blocks)
for i in range(blocklen-2):
block = self.blocks[i]
if start >= buf.size:
break
ind = ind + 1
if block['type'] == 'net':
continue
elif block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
start = load_conv_bn(buf, start, model[0], model[1])
else:
start = load_conv(buf, start, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
start = load_fc(buf, start, model[0])
else:
start = load_fc(buf, start, model)
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
def save_weights(self, outfile, cutoff=0):
if cutoff <= 0:
cutoff = len(self.blocks)-1
fp = open(outfile, 'wb')
self.header[3] = self.seen
header = self.header
header.numpy().tofile(fp)
ind = -1
for blockId in range(1, cutoff+1):
ind = ind + 1
block = self.blocks[blockId]
if block['type'] == 'convolutional':
model = self.models[ind]
batch_normalize = int(block['batch_normalize'])
if batch_normalize:
save_conv_bn(fp, model[0], model[1])
else:
save_conv(fp, model[0])
elif block['type'] == 'connected':
model = self.models[ind]
if block['activation'] != 'linear':
save_fc(fc, model)
else:
save_fc(fc, model[0])
elif block['type'] == 'maxpool':
pass
elif block['type'] == 'reorg':
pass
elif block['type'] == 'route':
pass
elif block['type'] == 'shortcut':
pass
elif block['type'] == 'region':
pass
elif block['type'] == 'avgpool':
pass
elif block['type'] == 'softmax':
pass
elif block['type'] == 'cost':
pass
else:
print('unknown type %s' % (block['type']))
fp.close()

101
dataset.py Normal file
Просмотреть файл

@ -0,0 +1,101 @@
#!/usr/bin/python
# encoding: utf-8
import os
import random
from PIL import Image
import numpy as np
from image import *
import torch
from torch.utils.data import Dataset
from utils import read_truths_args, read_truths, get_all_files
class listDataset(Dataset):
def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4, bg_file_names=None):
with open(root, 'r') as file:
self.lines = file.readlines()
if shuffle:
random.shuffle(self.lines)
self.nSamples = len(self.lines)
self.transform = transform
self.target_transform = target_transform
self.train = train
self.shape = shape
self.seen = seen
self.batch_size = batch_size
self.num_workers = num_workers
self.bg_file_names = bg_file_names
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
imgpath = self.lines[index].rstrip()
if self.train and index % 32== 0:
if self.seen < 400*32:
width = 13*32
self.shape = (width, width)
elif self.seen < 800*32:
width = (random.randint(0,7) + 13)*32
self.shape = (width, width)
elif self.seen < 1200*32:
width = (random.randint(0,9) + 12)*32
self.shape = (width, width)
elif self.seen < 1600*32:
width = (random.randint(0,11) + 11)*32
self.shape = (width, width)
elif self.seen < 2000*32:
width = (random.randint(0,13) + 10)*32
self.shape = (width, width)
elif self.seen < 2400*32:
width = (random.randint(0,15) + 9)*32
self.shape = (width, width)
elif self.seen < 3000*32:
width = (random.randint(0,17) + 8)*32
self.shape = (width, width)
else: # self.seen < 20000*64:
width = (random.randint(0,19) + 7)*32
self.shape = (width, width)
if self.train:
jitter = 0.2
hue = 0.1
saturation = 1.5
exposure = 1.5
# Get background image path
random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
bgpath = self.bg_file_names[random_bg_index]
img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
label = torch.from_numpy(label)
else:
img = Image.open(imgpath).convert('RGB')
if self.shape:
img = img.resize(self.shape)
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
label = torch.zeros(50*21)
if os.path.getsize(labpath):
ow, oh = img.size
tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
tmp = tmp.view(-1)
tsz = tmp.numel()
if tsz > 50*21:
label = tmp[0:50*21]
elif tsz > 0:
label[0:tsz] = tmp
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
label = self.target_transform(label)
self.seen = self.seen + self.num_workers
return (img, label)

184
image.py Normal file
Просмотреть файл

@ -0,0 +1,184 @@
#!/usr/bin/python
# encoding: utf-8
import random
import os
from PIL import Image, ImageChops, ImageMath
import numpy as np
def scale_image_channel(im, c, v):
cs = list(im.split())
cs[c] = cs[c].point(lambda i: i * v)
out = Image.merge(im.mode, tuple(cs))
return out
def distort_image(im, hue, sat, val):
im = im.convert('HSV')
cs = list(im.split())
cs[1] = cs[1].point(lambda i: i * sat)
cs[2] = cs[2].point(lambda i: i * val)
def change_hue(x):
x += hue*255
if x > 255:
x -= 255
if x < 0:
x += 255
return x
cs[0] = cs[0].point(change_hue)
im = Image.merge(im.mode, tuple(cs))
im = im.convert('RGB')
return im
def rand_scale(s):
scale = random.uniform(1, s)
if(random.randint(1,10000)%2):
return scale
return 1./scale
def random_distort_image(im, hue, saturation, exposure):
dhue = random.uniform(-hue, hue)
dsat = rand_scale(saturation)
dexp = rand_scale(exposure)
res = distort_image(im, dhue, dsat, dexp)
return res
def data_augmentation(img, shape, jitter, hue, saturation, exposure):
ow, oh = img.size
dw =int(ow*jitter)
dh =int(oh*jitter)
pleft = random.randint(-dw, dw)
pright = random.randint(-dw, dw)
ptop = random.randint(-dh, dh)
pbot = random.randint(-dh, dh)
swidth = ow - pleft - pright
sheight = oh - ptop - pbot
sx = float(swidth) / ow
sy = float(sheight) / oh
flip = random.randint(1,10000)%2
cropped = img.crop( (pleft, ptop, pleft + swidth - 1, ptop + sheight - 1))
dx = (float(pleft)/ow)/sx
dy = (float(ptop) /oh)/sy
sized = cropped.resize(shape)
img = random_distort_image(sized, hue, saturation, exposure)
return img, flip, dx,dy,sx,sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy):
max_boxes = 50
label = np.zeros((max_boxes,21))
if os.path.getsize(labpath):
bs = np.loadtxt(labpath)
if bs is None:
return label
bs = np.reshape(bs, (-1, 21))
cc = 0
for i in range(bs.shape[0]):
x0 = bs[i][1]
y0 = bs[i][2]
x1 = bs[i][3]
y1 = bs[i][4]
x2 = bs[i][5]
y2 = bs[i][6]
x3 = bs[i][7]
y3 = bs[i][8]
x4 = bs[i][9]
y4 = bs[i][10]
x5 = bs[i][11]
y5 = bs[i][12]
x6 = bs[i][13]
y6 = bs[i][14]
x7 = bs[i][15]
y7 = bs[i][16]
x8 = bs[i][17]
y8 = bs[i][18]
x0 = min(0.999, max(0, x0 * sx - dx))
y0 = min(0.999, max(0, y0 * sy - dy))
x1 = min(0.999, max(0, x1 * sx - dx))
y1 = min(0.999, max(0, y1 * sy - dy))
x2 = min(0.999, max(0, x2 * sx - dx))
y2 = min(0.999, max(0, y2 * sy - dy))
x3 = min(0.999, max(0, x3 * sx - dx))
y3 = min(0.999, max(0, y3 * sy - dy))
x4 = min(0.999, max(0, x4 * sx - dx))
y4 = min(0.999, max(0, y4 * sy - dy))
x5 = min(0.999, max(0, x5 * sx - dx))
y5 = min(0.999, max(0, y5 * sy - dy))
x6 = min(0.999, max(0, x6 * sx - dx))
y6 = min(0.999, max(0, y6 * sy - dy))
x7 = min(0.999, max(0, x7 * sx - dx))
y7 = min(0.999, max(0, y7 * sy - dy))
x8 = min(0.999, max(0, x8 * sx - dx))
y8 = min(0.999, max(0, y8 * sy - dy))
bs[i][1] = x0
bs[i][2] = y0
bs[i][3] = x1
bs[i][4] = y1
bs[i][5] = x2
bs[i][6] = y2
bs[i][7] = x3
bs[i][8] = y3
bs[i][9] = x4
bs[i][10] = y4
bs[i][11] = x5
bs[i][12] = y5
bs[i][13] = x6
bs[i][14] = y6
bs[i][15] = x7
bs[i][16] = y7
bs[i][17] = x8
bs[i][18] = y8
label[cc] = bs[i]
cc += 1
if cc >= 50:
break
label = np.reshape(label, (-1))
return label
def change_background(img, mask, bg):
# oh = img.height
# ow = img.width
ow, oh = img.size
bg = bg.resize((ow, oh)).convert('RGB')
imcs = list(img.split())
bgcs = list(bg.split())
maskcs = list(mask.split())
fics = list(Image.new(img.mode, img.size).split())
for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c + b * d", a=imcs[c], b=bgcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(img.mode, tuple(fics))
return out
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure, bgpath):
labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
maskpath = imgpath.replace('JPEGImages', 'mask').replace('/00', '/').replace('.jpg', '.png')
## data augmentation
img = Image.open(imgpath).convert('RGB')
mask = Image.open(maskpath).convert('RGB')
bg = Image.open(bgpath).convert('RGB')
img = change_background(img, mask, bg)
img,flip,dx,dy,sx,sy = data_augmentation(img, shape, jitter, hue, saturation, exposure)
ow, oh = img.size
label = fill_truth_detection(labpath, ow, oh, flip, dx, dy, 1./sx, 1./sy)
return img,label

301
region_loss.py Normal file
Просмотреть файл

@ -0,0 +1,301 @@
import time
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *
def build_targets(pred_corners, target, anchors, num_anchors, num_classes, nH, nW, noobject_scale, object_scale, sil_thresh, seen):
nB = target.size(0)
nA = num_anchors
nC = num_classes
anchor_step = len(anchors)/num_anchors
conf_mask = torch.ones(nB, nA, nH, nW) * noobject_scale
coord_mask = torch.zeros(nB, nA, nH, nW)
cls_mask = torch.zeros(nB, nA, nH, nW)
tx0 = torch.zeros(nB, nA, nH, nW)
ty0 = torch.zeros(nB, nA, nH, nW)
tx1 = torch.zeros(nB, nA, nH, nW)
ty1 = torch.zeros(nB, nA, nH, nW)
tx2 = torch.zeros(nB, nA, nH, nW)
ty2 = torch.zeros(nB, nA, nH, nW)
tx3 = torch.zeros(nB, nA, nH, nW)
ty3 = torch.zeros(nB, nA, nH, nW)
tx4 = torch.zeros(nB, nA, nH, nW)
ty4 = torch.zeros(nB, nA, nH, nW)
tx5 = torch.zeros(nB, nA, nH, nW)
ty5 = torch.zeros(nB, nA, nH, nW)
tx6 = torch.zeros(nB, nA, nH, nW)
ty6 = torch.zeros(nB, nA, nH, nW)
tx7 = torch.zeros(nB, nA, nH, nW)
ty7 = torch.zeros(nB, nA, nH, nW)
tx8 = torch.zeros(nB, nA, nH, nW)
ty8 = torch.zeros(nB, nA, nH, nW)
tconf = torch.zeros(nB, nA, nH, nW)
tcls = torch.zeros(nB, nA, nH, nW)
nAnchors = nA*nH*nW
nPixels = nH*nW
for b in range(nB):
cur_pred_corners = pred_corners[b*nAnchors:(b+1)*nAnchors].t()
cur_confs = torch.zeros(nAnchors)
for t in range(50):
if target[b][t*21+1] == 0:
break
gx0 = target[b][t*21+1]*nW
gy0 = target[b][t*21+2]*nH
gx1 = target[b][t*21+3]*nW
gy1 = target[b][t*21+4]*nH
gx2 = target[b][t*21+5]*nW
gy2 = target[b][t*21+6]*nH
gx3 = target[b][t*21+7]*nW
gy3 = target[b][t*21+8]*nH
gx4 = target[b][t*21+9]*nW
gy4 = target[b][t*21+10]*nH
gx5 = target[b][t*21+11]*nW
gy5 = target[b][t*21+12]*nH
gx6 = target[b][t*21+13]*nW
gy6 = target[b][t*21+14]*nH
gx7 = target[b][t*21+15]*nW
gy7 = target[b][t*21+16]*nH
gx8 = target[b][t*21+17]*nW
gy8 = target[b][t*21+18]*nH
cur_gt_corners = torch.FloatTensor([gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]).repeat(nAnchors,1).t() # 16 x nAnchors
cur_confs = torch.max(cur_confs, corner_confidences9(cur_pred_corners, cur_gt_corners)) # some irrelevant areas are filtered, in the same grid multiple anchor boxes might exceed the threshold
conf_mask[b][cur_confs>sil_thresh] = 0
if seen < -1:#6400:
tx0.fill_(0.5)
ty0.fill_(0.5)
tx1.fill_(0.5)
ty1.fill_(0.5)
tx2.fill_(0.5)
ty2.fill_(0.5)
tx3.fill_(0.5)
ty3.fill_(0.5)
tx4.fill_(0.5)
ty4.fill_(0.5)
tx5.fill_(0.5)
ty5.fill_(0.5)
tx6.fill_(0.5)
ty6.fill_(0.5)
tx7.fill_(0.5)
ty7.fill_(0.5)
tx8.fill_(0.5)
ty8.fill_(0.5)
coord_mask.fill_(1)
nGT = 0
nCorrect = 0
for b in range(nB):
for t in range(50):
if target[b][t*21+1] == 0:
break
nGT = nGT + 1
best_iou = 0.0
best_n = -1
min_dist = 10000
gx0 = target[b][t*21+1] * nW
gy0 = target[b][t*21+2] * nH
gi0 = int(gx0)
gj0 = int(gy0)
gx1 = target[b][t*21+3] * nW
gy1 = target[b][t*21+4] * nH
gx2 = target[b][t*21+5] * nW
gy2 = target[b][t*21+6] * nH
gx3 = target[b][t*21+7] * nW
gy3 = target[b][t*21+8] * nH
gx4 = target[b][t*21+9] * nW
gy4 = target[b][t*21+10] * nH
gx5 = target[b][t*21+11] * nW
gy5 = target[b][t*21+12] * nH
gx6 = target[b][t*21+13] * nW
gy6 = target[b][t*21+14] * nH
gx7 = target[b][t*21+15] * nW
gy7 = target[b][t*21+16] * nH
gx8 = target[b][t*21+17] * nW
gy8 = target[b][t*21+18] * nH
best_n = 0 # 1 anchor box
gt_box = [gx0/nW,gy0/nH,gx1/nW,gy1/nH,gx2/nW,gy2/nH,gx3/nW,gy3/nH,gx4/nW,gy4/nH,gx5/nW,gy5/nH,gx6/nW,gy6/nH,gx7/nW,gy7/nH,gx8/nW,gy8/nH]
pred_box = pred_corners[b*nAnchors+best_n*nPixels+gj0*nW+gi0]
conf = corner_confidence9(gt_box, pred_box)
coord_mask[b][best_n][gj0][gi0] = 1
cls_mask[b][best_n][gj0][gi0] = 1
conf_mask[b][best_n][gj0][gi0] = object_scale
tx0[b][best_n][gj0][gi0] = target[b][t*21+1] * nW - gi0
ty0[b][best_n][gj0][gi0] = target[b][t*21+2] * nH - gj0
tx1[b][best_n][gj0][gi0] = target[b][t*21+3] * nW - gi0
ty1[b][best_n][gj0][gi0] = target[b][t*21+4] * nH - gj0
tx2[b][best_n][gj0][gi0] = target[b][t*21+5] * nW - gi0
ty2[b][best_n][gj0][gi0] = target[b][t*21+6] * nH - gj0
tx3[b][best_n][gj0][gi0] = target[b][t*21+7] * nW - gi0
ty3[b][best_n][gj0][gi0] = target[b][t*21+8] * nH - gj0
tx4[b][best_n][gj0][gi0] = target[b][t*21+9] * nW - gi0
ty4[b][best_n][gj0][gi0] = target[b][t*21+10] * nH - gj0
tx5[b][best_n][gj0][gi0] = target[b][t*21+11] * nW - gi0
ty5[b][best_n][gj0][gi0] = target[b][t*21+12] * nH - gj0
tx6[b][best_n][gj0][gi0] = target[b][t*21+13] * nW - gi0
ty6[b][best_n][gj0][gi0] = target[b][t*21+14] * nH - gj0
tx7[b][best_n][gj0][gi0] = target[b][t*21+15] * nW - gi0
ty7[b][best_n][gj0][gi0] = target[b][t*21+16] * nH - gj0
tx8[b][best_n][gj0][gi0] = target[b][t*21+17] * nW - gi0
ty8[b][best_n][gj0][gi0] = target[b][t*21+18] * nH - gj0
tconf[b][best_n][gj0][gi0] = conf
tcls[b][best_n][gj0][gi0] = target[b][t*21]
if conf > 0.5:
nCorrect = nCorrect + 1
return nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls
class RegionLoss(nn.Module):
def __init__(self, num_classes=0, anchors=[], num_anchors=1):
super(RegionLoss, self).__init__()
self.num_classes = num_classes
self.anchors = anchors
self.num_anchors = num_anchors
self.anchor_step = len(anchors)/num_anchors
self.coord_scale = 1
self.noobject_scale = 1
self.object_scale = 5
self.class_scale = 1
self.thresh = 0.6
self.seen = 0
def forward(self, output, target):
# Parameters
t0 = time.time()
nB = output.data.size(0)
nA = self.num_anchors
nC = self.num_classes
nH = output.data.size(2)
nW = output.data.size(3)
# Activation
output = output.view(nB, nA, (19+nC), nH, nW)
x0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([0]))).view(nB, nA, nH, nW))
y0 = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([1]))).view(nB, nA, nH, nW))
x1 = output.index_select(2, Variable(torch.cuda.LongTensor([2]))).view(nB, nA, nH, nW)
y1 = output.index_select(2, Variable(torch.cuda.LongTensor([3]))).view(nB, nA, nH, nW)
x2 = output.index_select(2, Variable(torch.cuda.LongTensor([4]))).view(nB, nA, nH, nW)
y2 = output.index_select(2, Variable(torch.cuda.LongTensor([5]))).view(nB, nA, nH, nW)
x3 = output.index_select(2, Variable(torch.cuda.LongTensor([6]))).view(nB, nA, nH, nW)
y3 = output.index_select(2, Variable(torch.cuda.LongTensor([7]))).view(nB, nA, nH, nW)
x4 = output.index_select(2, Variable(torch.cuda.LongTensor([8]))).view(nB, nA, nH, nW)
y4 = output.index_select(2, Variable(torch.cuda.LongTensor([9]))).view(nB, nA, nH, nW)
x5 = output.index_select(2, Variable(torch.cuda.LongTensor([10]))).view(nB, nA, nH, nW)
y5 = output.index_select(2, Variable(torch.cuda.LongTensor([11]))).view(nB, nA, nH, nW)
x6 = output.index_select(2, Variable(torch.cuda.LongTensor([12]))).view(nB, nA, nH, nW)
y6 = output.index_select(2, Variable(torch.cuda.LongTensor([13]))).view(nB, nA, nH, nW)
x7 = output.index_select(2, Variable(torch.cuda.LongTensor([14]))).view(nB, nA, nH, nW)
y7 = output.index_select(2, Variable(torch.cuda.LongTensor([15]))).view(nB, nA, nH, nW)
x8 = output.index_select(2, Variable(torch.cuda.LongTensor([16]))).view(nB, nA, nH, nW)
y8 = output.index_select(2, Variable(torch.cuda.LongTensor([17]))).view(nB, nA, nH, nW)
conf = F.sigmoid(output.index_select(2, Variable(torch.cuda.LongTensor([18]))).view(nB, nA, nH, nW))
cls = output.index_select(2, Variable(torch.linspace(19,19+nC-1,nC).long().cuda()))
cls = cls.view(nB*nA, nC, nH*nW).transpose(1,2).contiguous().view(nB*nA*nH*nW, nC)
t1 = time.time()
# Create pred boxes
pred_corners = torch.cuda.FloatTensor(18, nB*nA*nH*nW)
grid_x = torch.linspace(0, nW-1, nW).repeat(nH,1).repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
grid_y = torch.linspace(0, nH-1, nH).repeat(nW,1).t().repeat(nB*nA, 1, 1).view(nB*nA*nH*nW).cuda()
pred_corners[0] = (x0.data + grid_x) / nW
pred_corners[1] = (y0.data + grid_y) / nH
pred_corners[2] = (x1.data + grid_x) / nW
pred_corners[3] = (y1.data + grid_y) / nH
pred_corners[4] = (x2.data + grid_x) / nW
pred_corners[5] = (y2.data + grid_y) / nH
pred_corners[6] = (x3.data + grid_x) / nW
pred_corners[7] = (y3.data + grid_y) / nH
pred_corners[8] = (x4.data + grid_x) / nW
pred_corners[9] = (y4.data + grid_y) / nH
pred_corners[10] = (x5.data + grid_x) / nW
pred_corners[11] = (y5.data + grid_y) / nH
pred_corners[12] = (x6.data + grid_x) / nW
pred_corners[13] = (y6.data + grid_y) / nH
pred_corners[14] = (x7.data + grid_x) / nW
pred_corners[15] = (y7.data + grid_y) / nH
pred_corners[16] = (x8.data + grid_x) / nW
pred_corners[17] = (y8.data + grid_y) / nH
gpu_matrix = pred_corners.transpose(0,1).contiguous().view(-1,18)
pred_corners = convert2cpu(gpu_matrix)
t2 = time.time()
# Build targets
nGT, nCorrect, coord_mask, conf_mask, cls_mask, tx0, tx1, tx2, tx3, tx4, tx5, tx6, tx7, tx8, ty0, ty1, ty2, ty3, ty4, ty5, ty6, ty7, ty8, tconf, tcls = \
build_targets(pred_corners, target.data, self.anchors, nA, nC, nH, nW, self.noobject_scale, self.object_scale, self.thresh, self.seen)
cls_mask = (cls_mask == 1)
nProposals = int((conf > 0.25).sum().data[0])
tx0 = Variable(tx0.cuda())
ty0 = Variable(ty0.cuda())
tx1 = Variable(tx1.cuda())
ty1 = Variable(ty1.cuda())
tx2 = Variable(tx2.cuda())
ty2 = Variable(ty2.cuda())
tx3 = Variable(tx3.cuda())
ty3 = Variable(ty3.cuda())
tx4 = Variable(tx4.cuda())
ty4 = Variable(ty4.cuda())
tx5 = Variable(tx5.cuda())
ty5 = Variable(ty5.cuda())
tx6 = Variable(tx6.cuda())
ty6 = Variable(ty6.cuda())
tx7 = Variable(tx7.cuda())
ty7 = Variable(ty7.cuda())
tx8 = Variable(tx8.cuda())
ty8 = Variable(ty8.cuda())
tconf = Variable(tconf.cuda())
tcls = Variable(tcls.view(-1)[cls_mask].long().cuda())
coord_mask = Variable(coord_mask.cuda())
conf_mask = Variable(conf_mask.cuda().sqrt())
cls_mask = Variable(cls_mask.view(-1, 1).repeat(1,nC).cuda())
cls = cls[cls_mask].view(-1, nC)
t3 = time.time()
# Create loss
loss_x0 = self.coord_scale * nn.MSELoss(size_average=False)(x0*coord_mask, tx0*coord_mask)/2.0
loss_y0 = self.coord_scale * nn.MSELoss(size_average=False)(y0*coord_mask, ty0*coord_mask)/2.0
loss_x1 = self.coord_scale * nn.MSELoss(size_average=False)(x1*coord_mask, tx1*coord_mask)/2.0
loss_y1 = self.coord_scale * nn.MSELoss(size_average=False)(y1*coord_mask, ty1*coord_mask)/2.0
loss_x2 = self.coord_scale * nn.MSELoss(size_average=False)(x2*coord_mask, tx2*coord_mask)/2.0
loss_y2 = self.coord_scale * nn.MSELoss(size_average=False)(y2*coord_mask, ty2*coord_mask)/2.0
loss_x3 = self.coord_scale * nn.MSELoss(size_average=False)(x3*coord_mask, tx3*coord_mask)/2.0
loss_y3 = self.coord_scale * nn.MSELoss(size_average=False)(y3*coord_mask, ty3*coord_mask)/2.0
loss_x4 = self.coord_scale * nn.MSELoss(size_average=False)(x4*coord_mask, tx4*coord_mask)/2.0
loss_y4 = self.coord_scale * nn.MSELoss(size_average=False)(y4*coord_mask, ty4*coord_mask)/2.0
loss_x5 = self.coord_scale * nn.MSELoss(size_average=False)(x5*coord_mask, tx5*coord_mask)/2.0
loss_y5 = self.coord_scale * nn.MSELoss(size_average=False)(y5*coord_mask, ty5*coord_mask)/2.0
loss_x6 = self.coord_scale * nn.MSELoss(size_average=False)(x6*coord_mask, tx6*coord_mask)/2.0
loss_y6 = self.coord_scale * nn.MSELoss(size_average=False)(y6*coord_mask, ty6*coord_mask)/2.0
loss_x7 = self.coord_scale * nn.MSELoss(size_average=False)(x7*coord_mask, tx7*coord_mask)/2.0
loss_y7 = self.coord_scale * nn.MSELoss(size_average=False)(y7*coord_mask, ty7*coord_mask)/2.0
loss_x8 = self.coord_scale * nn.MSELoss(size_average=False)(x8*coord_mask, tx8*coord_mask)/2.0
loss_y8 = self.coord_scale * nn.MSELoss(size_average=False)(y8*coord_mask, ty8*coord_mask)/2.0
loss_conf = nn.MSELoss(size_average=False)(conf*conf_mask, tconf*conf_mask)/2.0
# loss_cls = self.class_scale * nn.CrossEntropyLoss(size_average=False)(cls, tcls)
loss_cls = 0
loss_x = loss_x0 + loss_x1 + loss_x2 + loss_x3 + loss_x4 + loss_x5 + loss_x6 + loss_x7 + loss_x8
loss_y = loss_y0 + loss_y1 + loss_y2 + loss_y3 + loss_y4 + loss_y5 + loss_y6 + loss_y7 + loss_y8
if False:
loss = loss_x + loss_y + loss_conf + loss_cls
else:
loss = loss_x + loss_y + loss_conf
t4 = time.time()
if False:
print('-----------------------------------')
print(' activation : %f' % (t1 - t0))
print(' create pred_corners : %f' % (t2 - t1))
print(' build targets : %f' % (t3 - t2))
print(' create loss : %f' % (t4 - t3))
print(' total : %f' % (t4 - t0))
if False:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, cls %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss_cls.data[0], loss.data[0]))
else:
print('%d: nGT %d, recall %d, proposals %d, loss: x %f, y %f, conf %f, total %f' % (self.seen, nGT, nCorrect, nProposals, loss_x.data[0], loss_y.data[0], loss_conf.data[0], loss.data[0]))
return loss

413
train.py Normal file
Просмотреть файл

@ -0,0 +1,413 @@
from __future__ import print_function
import sys
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import numpy as np
import os
import random
import math
import shutil
from torchvision import datasets, transforms
from torch.autograd import Variable # Useful info about autograd: http://pytorch.org/docs/master/notes/autograd.html
import dataset
from utils import *
from cfg import parse_cfg
from region_loss import RegionLoss
from darknet import Darknet
from MeshPly import MeshPly
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
# Adjust learning rate during training, learning schedule can be changed in network config file
def adjust_learning_rate(optimizer, batch):
lr = learning_rate
for i in range(len(steps)):
scale = scales[i] if i < len(scales) else 1
if batch >= steps[i]:
lr = lr * scale
if batch == steps[i]:
break
else:
break
for param_group in optimizer.param_groups:
param_group['lr'] = lr/batch_size
return lr
def train(epoch):
global processed_batches
# Initialize timer
t0 = time.time()
# Get the dataloader for training dataset
train_loader = torch.utils.data.DataLoader(dataset.listDataset(trainlist, shape=(init_width, init_height),
shuffle=True,
transform=transforms.Compose([transforms.ToTensor(),]),
train=True,
seen=model.seen,
batch_size=batch_size,
num_workers=num_workers, bg_file_names=bg_file_names),
batch_size=batch_size, shuffle=False, **kwargs)
# TRAINING
lr = adjust_learning_rate(optimizer, processed_batches)
logging('epoch %d, processed %d samples, lr %f' % (epoch, epoch * len(train_loader.dataset), lr))
# Start training
model.train()
t1 = time.time()
avg_time = torch.zeros(9)
niter = 0
# Iterate through batches
for batch_idx, (data, target) in enumerate(train_loader):
t2 = time.time()
# adjust learning rate
adjust_learning_rate(optimizer, processed_batches)
processed_batches = processed_batches + 1
# Pass the data to GPU
if use_cuda:
data = data.cuda()
t3 = time.time()
# Wrap tensors in Variable class for automatic differentiation
data, target = Variable(data), Variable(target)
t4 = time.time()
# Zero the gradients before running the backward pass
optimizer.zero_grad()
t5 = time.time()
# Forward pass
output = model(data)
t6 = time.time()
region_loss.seen = region_loss.seen + data.data.size(0)
# Compute loss, grow an array of losses for saving later on
loss = region_loss(output, target)
training_iters.append(epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter)
training_losses.append(convert2cpu(loss.data))
niter += 1
t7 = time.time()
# Backprop: compute gradient of the loss with respect to model parameters
loss.backward()
t8 = time.time()
# Update weights
optimizer.step()
t9 = time.time()
# Print time statistics
if False and batch_idx > 1:
avg_time[0] = avg_time[0] + (t2-t1)
avg_time[1] = avg_time[1] + (t3-t2)
avg_time[2] = avg_time[2] + (t4-t3)
avg_time[3] = avg_time[3] + (t5-t4)
avg_time[4] = avg_time[4] + (t6-t5)
avg_time[5] = avg_time[5] + (t7-t6)
avg_time[6] = avg_time[6] + (t8-t7)
avg_time[7] = avg_time[7] + (t9-t8)
avg_time[8] = avg_time[8] + (t9-t1)
print('-------------------------------')
print(' load data : %f' % (avg_time[0]/(batch_idx)))
print(' cpu to cuda : %f' % (avg_time[1]/(batch_idx)))
print('cuda to variable : %f' % (avg_time[2]/(batch_idx)))
print(' zero_grad : %f' % (avg_time[3]/(batch_idx)))
print(' forward feature : %f' % (avg_time[4]/(batch_idx)))
print(' forward loss : %f' % (avg_time[5]/(batch_idx)))
print(' backward : %f' % (avg_time[6]/(batch_idx)))
print(' step : %f' % (avg_time[7]/(batch_idx)))
print(' total : %f' % (avg_time[8]/(batch_idx)))
t1 = time.time()
t1 = time.time()
return epoch * math.ceil(len(train_loader.dataset) / float(batch_size) ) + niter - 1
def test(epoch, niter):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Set the module in evaluation mode (turn off dropout, batch normalization etc.)
model.eval()
# Parameters
num_classes = model.num_classes
anchors = model.anchors
num_anchors = model.num_anchors
testtime = True
testing_error_trans = 0.0
testing_error_angle = 0.0
testing_error_pixel = 0.0
testing_samples = 0.0
errs_2d = []
errs_3d = []
errs_trans = []
errs_angle = []
errs_corner2D = []
logging(" Testing...")
logging(" Number of test samples: %d" % len(test_loader.dataset))
notpredicted = 0
# Iterate through test examples
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass the data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Formward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors)
t4 = time.time()
# Iterate through all batch elements
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction
for j in range(len(boxes)):
if boxes[j][18] > best_conf_est:
best_conf_est = boxes[j][18]
box_pr = boxes[j]
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * im_width
corners2D_gt[:, 1] = corners2D_gt[:, 1] * im_height
corners2D_pr[:, 0] = corners2D_pr[:, 0] * im_width
corners2D_pr[:, 1] = corners2D_pr[:, 1] * im_height
# Compute corner prediction error
corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)
corner_dist = np.mean(corner_norm)
errs_corner2D.append(corner_dist)
# Compute [R|t] by pnp
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
# Compute errors
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
errs_trans.append(trans_dist)
# Compute angle error
angle_dist = calcAngularDistance(R_gt, R_pr)
errs_angle.append(angle_dist)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Compute 3D distances
transform_3d_gt = compute_transformation(vertices, Rt_gt)
transform_3d_pred = compute_transformation(vertices, Rt_pr)
norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)
vertex_dist = np.mean(norm3d)
errs_3d.append(vertex_dist)
# Sum errors
testing_error_trans += trans_dist
testing_error_angle += angle_dist
testing_error_pixel += pixel_dist
testing_samples += 1
t5 = time.time()
# Compute 2D projection, 6D pose and 5cm5degree scores
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc3d = len(np.where(np.array(errs_3d) <= vx_threshold)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
if testtime:
print('-----------------------------------')
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' eval : %f' % (t5 - t4))
print(' total : %f' % (t5 - t1))
print('-----------------------------------')
# Print test statistics
logging(" Mean corner error is %f" % (mean_corner_err_2d))
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
logging(' Acc using {} vx 3D Transformation = {:.2f}%'.format(vx_threshold, acc3d))
logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))
logging(' Acc using iou metric = {:.2f}%'.format(accious))
logging(' Translation error: %f, angle error: %f' % (testing_error_trans/(nts+eps), testing_error_angle/(nts+eps)) )
# Register losses and errors for saving later on
testing_iters.append(niter)
testing_errors_trans.append(testing_error_trans/(nts+eps))
testing_errors_angle.append(testing_error_angle/(nts+eps))
testing_errors_pixel.append(testing_error_pixel/(nts+eps))
testing_accuracies.append(acc)
if __name__ == "__main__":
# Training settings
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
# Parse configuration files
data_options = read_data_cfg(datacfg)
net_options = parse_cfg(cfgfile)[0]
trainlist = data_options['train']
testlist = data_options['valid']
nsamples = file_lines(trainlist)
gpus = data_options['gpus'] # e.g. 0,1,2,3
gpus = '0'
meshname = data_options['mesh']
num_workers = int(data_options['num_workers'])
backupdir = data_options['backup']
diam = float(data_options['diam'])
vx_threshold = diam * 0.1
if not os.path.exists(backupdir):
makedirs(backupdir)
batch_size = int(net_options['batch'])
max_batches = int(net_options['max_batches'])
learning_rate = float(net_options['learning_rate'])
momentum = float(net_options['momentum'])
decay = float(net_options['decay'])
steps = [float(step) for step in net_options['steps'].split(',')]
scales = [float(scale) for scale in net_options['scales'].split(',')]
bg_file_names = get_all_files('VOCdevkit/VOC2012/JPEGImages')
# Train parameters
max_epochs = 700 # max_batches*batch_size/nsamples+1
use_cuda = True
seed = int(time.time())
eps = 1e-5
save_interval = 10 # epoches
dot_interval = 70 # batches
best_acc = -1
# Test parameters
conf_thresh = 0.1
nms_thresh = 0.4
iou_thresh = 0.5
im_width = 640
im_height = 480
# Specify which gpus to use
torch.manual_seed(seed)
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
# Specifiy the model and the loss
model = Darknet(cfgfile)
region_loss = model.loss
# Model settings
# model.load_weights(weightfile)
model.load_weights_until_last(weightfile)
model.print_network()
model.seen = 0
region_loss.iter = model.iter
region_loss.seen = model.seen
processed_batches = model.seen/batch_size
init_width = model.width
init_height = model.height
test_width = 672
test_height = 672
init_epoch = model.seen/nsamples
# Variable to save
training_iters = []
training_losses = []
testing_iters = []
testing_losses = []
testing_errors_trans = []
testing_errors_angle = []
testing_errors_pixel = []
testing_accuracies = []
# Get the intrinsic camerea matrix, mesh, vertices and corners of the model
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
internal_calibration = get_camera_intrinsic()
# Specify the number of workers
kwargs = {'num_workers': num_workers, 'pin_memory': True} if use_cuda else {}
# Get the dataloader for test data
test_loader = torch.utils.data.DataLoader(dataset.listDataset(testlist, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([transforms.ToTensor(),]),
train=False),
batch_size=1, shuffle=False, **kwargs)
# Pass the model to GPU
if use_cuda:
model = model.cuda() # model = torch.nn.DataParallel(model, device_ids=[0]).cuda() # Multiple GPU parallelism
# Get the optimizer
params_dict = dict(model.named_parameters())
params = []
for key, value in params_dict.items():
if key.find('.bn') >= 0 or key.find('.bias') >= 0:
params += [{'params': [value], 'weight_decay': 0.0}]
else:
params += [{'params': [value], 'weight_decay': decay*batch_size}]
optimizer = optim.SGD(model.parameters(), lr=learning_rate/batch_size, momentum=momentum, dampening=0, weight_decay=decay*batch_size)
# optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam optimization
evaluate = False
if evaluate:
logging('evaluating ...')
test(0, 0)
else:
for epoch in range(init_epoch, max_epochs):
# TRAIN
niter = train(epoch)
# TEST and SAVE
if (epoch % 10 == 0) and (epoch is not 0):
test(epoch, niter)
logging('save training stats to %s/costs.npz' % (backupdir))
np.savez(os.path.join(backupdir, "costs.npz"),
training_iters=training_iters,
training_losses=training_losses,
testing_iters=testing_iters,
testing_accuracies=testing_accuracies,
testing_errors_pixel=testing_errors_pixel,
testing_errors_angle=testing_errors_angle)
if (testing_accuracies[-1] > best_acc ):
best_acc = testing_accuracies[-1]
logging('best model so far!')
logging('save weights to %s/model.weights' % (backupdir))
model.save_weights('%s/model.weights' % (backupdir))
shutil.copy2('%s/model.weights' % (backupdir), '%s/model_backup.weights' % (backupdir))

1058
utils.py Normal file

Разница между файлами не показана из-за своего большого размера Загрузить разницу

357
valid.ipynb Normal file
Просмотреть файл

@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import os\n",
"import time\n",
"import torch\n",
"from torch.autograd import Variable\n",
"from torchvision import datasets, transforms\n",
"import scipy.io\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import matplotlib.pyplot as plt\n",
"import scipy.misc\n",
"\n",
"from darknet import Darknet\n",
"import dataset\n",
"from utils import *\n",
"from MeshPly import MeshPly\n",
"\n",
"# Create new directory\n",
"def makedirs(path):\n",
" if not os.path.exists( path ):\n",
" os.makedirs( path )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "IOError",
"evalue": "[Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIOError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0mcfgfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cfg/yolo-pose.cfg'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0mweightfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'backup/ape/model_backup.weights'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m \u001b[0mvalid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatacfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcfgfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweightfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-7-cd23ddeac3d5>\u001b[0m in \u001b[0;36mvalid\u001b[0;34m(datacfg, cfgfile, weightfile)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;31m# Read object model information, get 3D bounding box corners\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mmesh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMeshPly\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmeshname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmesh\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0mcorners3D\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_3D_corners\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvertices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/cvlabdata1/home/btekin/ope/singleshotpose_release/MeshPly.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, filename, color)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'r'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIOError\u001b[0m: [Errno 2] No such file or directory: 'LINEMOD/ape/ape.ply'"
]
}
],
"source": [
"def valid(datacfg, cfgfile, weightfile):\n",
" def truths_length(truths):\n",
" for i in range(50):\n",
" if truths[i][1] == 0:\n",
" return i\n",
"\n",
" # Parse configuration files\n",
" options = read_data_cfg(datacfg)\n",
" valid_images = options['valid']\n",
" meshname = options['mesh']\n",
" backupdir = options['backup']\n",
" name = options['name']\n",
" if not os.path.exists(backupdir):\n",
" makedirs(backupdir)\n",
"\n",
" # Parameters\n",
" prefix = 'results'\n",
" seed = int(time.time())\n",
" gpus = '0' # Specify which gpus to use\n",
" test_width = 544\n",
" test_height = 544\n",
" torch.manual_seed(seed)\n",
" use_cuda = True\n",
" if use_cuda:\n",
" os.environ['CUDA_VISIBLE_DEVICES'] = gpus\n",
" torch.cuda.manual_seed(seed)\n",
" save = False\n",
" visualize = True\n",
" testtime = True\n",
" use_cuda = True\n",
" num_classes = 1\n",
" testing_samples = 0.0\n",
" eps = 1e-5\n",
" notpredicted = 0 \n",
" conf_thresh = 0.1\n",
" nms_thresh = 0.4\n",
" match_thresh = 0.5\n",
" edges_corners = [[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]]\n",
"\n",
" if save:\n",
" makedirs(backupdir + '/test')\n",
" makedirs(backupdir + '/test/gt')\n",
" makedirs(backupdir + '/test/pr')\n",
"\n",
" # To save\n",
" testing_error_trans = 0.0\n",
" testing_error_angle = 0.0\n",
" testing_error_pixel = 0.0\n",
" errs_2d = []\n",
" errs_3d = []\n",
" errs_trans = []\n",
" errs_angle = []\n",
" errs_corner2D = []\n",
" preds_trans = []\n",
" preds_rot = []\n",
" preds_corners2D = []\n",
" gts_trans = []\n",
" gts_rot = []\n",
" gts_corners2D = []\n",
" ious = []\n",
"\n",
" # Read object model information, get 3D bounding box corners\n",
" mesh = MeshPly(meshname)\n",
" vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()\n",
" corners3D = get_3D_corners(vertices)\n",
" # diam = calc_pts_diameter(np.array(mesh.vertices))\n",
" diam = float(options['diam'])\n",
"\n",
" # Read intrinsic camera parameters\n",
" internal_calibration = get_camera_intrinsic()\n",
"\n",
" # Get validation file names\n",
" with open(valid_images) as fp:\n",
" tmp_files = fp.readlines()\n",
" valid_files = [item.rstrip() for item in tmp_files]\n",
" \n",
" # Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode\n",
" model = Darknet(cfgfile)\n",
" model.print_network()\n",
" model.load_weights(weightfile)\n",
" model.cuda()\n",
" model.eval()\n",
"\n",
" # Get the parser for the test dataset\n",
" valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),\n",
" shuffle=False,\n",
" transform=transforms.Compose([\n",
" transforms.ToTensor(),]))\n",
" valid_batchsize = 1\n",
"\n",
" # Specify the number of workers for multiple processing, get the dataloader for the test dataset\n",
" kwargs = {'num_workers': 4, 'pin_memory': True}\n",
" test_loader = torch.utils.data.DataLoader(\n",
" valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs) \n",
"\n",
" logging(\" Testing {}...\".format(name))\n",
" logging(\" Number of test samples: %d\" % len(test_loader.dataset))\n",
" # Iterate through test batches (Batch size for test data is 1)\n",
" count = 0\n",
" z = np.zeros((3, 1))\n",
" for batch_idx, (data, target) in enumerate(test_loader):\n",
" \n",
" # Images\n",
" img = data[0, :, :, :]\n",
" img = img.numpy().squeeze()\n",
" img = np.transpose(img, (1, 2, 0))\n",
" \n",
" t1 = time.time()\n",
" # Pass data to GPU\n",
" if use_cuda:\n",
" data = data.cuda()\n",
" target = target.cuda()\n",
" \n",
" # Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference\n",
" data = Variable(data, volatile=True)\n",
" t2 = time.time()\n",
" \n",
" # Forward pass\n",
" output = model(data).data \n",
" t3 = time.time()\n",
" \n",
" # Using confidence threshold, eliminate low-confidence predictions\n",
" all_boxes = get_region_boxes(output, conf_thresh, num_classes) \n",
" t4 = time.time()\n",
"\n",
" # Iterate through all images in the batch\n",
" for i in range(output.size(0)):\n",
" \n",
" # For each image, get all the predictions\n",
" boxes = all_boxes[i]\n",
" \n",
" # For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)\n",
" truths = target[i].view(-1, 21)\n",
" \n",
" # Get how many object are present in the scene\n",
" num_gts = truths_length(truths)\n",
"\n",
" # Iterate through each ground-truth object\n",
" for k in range(num_gts):\n",
" box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6], \n",
" truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12], \n",
" truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]\n",
" best_conf_est = -1\n",
"\n",
" # If the prediction has the highest confidence, choose it as our prediction for single object pose estimation\n",
" for j in range(len(boxes)):\n",
" if (boxes[j][18] > best_conf_est):\n",
" match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))\n",
" box_pr = boxes[j]\n",
" best_conf_est = boxes[j][18]\n",
"\n",
" # Denormalize the corner predictions \n",
" corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')\n",
" corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')\n",
" corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640\n",
" corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480 \n",
" corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640\n",
" corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480\n",
" preds_corners2D.append(corners2D_pr)\n",
" gts_corners2D.append(corners2D_gt)\n",
"\n",
" # Compute corner prediction error\n",
" corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)\n",
" corner_dist = np.mean(corner_norm)\n",
" errs_corner2D.append(corner_dist)\n",
" \n",
" # Compute [R|t] by pnp\n",
" R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))\n",
" R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))\n",
"\n",
" if save:\n",
" preds_trans.append(t_pr)\n",
" gts_trans.append(t_gt)\n",
" preds_rot.append(R_pr)\n",
" gts_rot.append(R_gt)\n",
"\n",
" np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))\n",
" np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))\n",
" \n",
" # Compute translation error\n",
" trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))\n",
" errs_trans.append(trans_dist)\n",
" \n",
" # Compute angle error\n",
" angle_dist = calcAngularDistance(R_gt, R_pr)\n",
" errs_angle.append(angle_dist)\n",
" \n",
" # Compute pixel error\n",
" Rt_gt = np.concatenate((R_gt, t_gt), axis=1)\n",
" Rt_pr = np.concatenate((R_pr, t_pr), axis=1)\n",
" proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)\n",
" proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration) \n",
" proj_corners_gt = np.transpose(compute_projection(corners3D, Rt_gt, internal_calibration)) \n",
" proj_corners_pr = np.transpose(compute_projection(corners3D, Rt_pr, internal_calibration)) \n",
" norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)\n",
" pixel_dist = np.mean(norm)\n",
" errs_2d.append(pixel_dist)\n",
"\n",
" if visualize:\n",
" # Visualize\n",
" plt.xlim((0, 640))\n",
" plt.ylim((0, 480))\n",
" plt.imshow(scipy.misc.imresize(img, (480, 640)))\n",
" # Projections\n",
" for edge in edges_corners:\n",
" plt.plot(proj_corners_gt[edge, 0], proj_corners_gt[edge, 1], color='g', linewidth=3.0)\n",
" plt.plot(proj_corners_pr[edge, 0], proj_corners_pr[edge, 1], color='b', linewidth=3.0)\n",
" plt.gca().invert_yaxis()\n",
" plt.show()\n",
" \n",
" # Compute IoU score\n",
" bb_gt = compute_2d_bb_from_orig_pix(proj_2d_gt, output.size(3))\n",
" bb_pred = compute_2d_bb_from_orig_pix(proj_2d_pred, output.size(3))\n",
" iou = bbox_iou(bb_gt, bb_pred)\n",
" ious.append(iou)\n",
"\n",
" # Compute 3D distances\n",
" transform_3d_gt = compute_transformation(vertices, Rt_gt) \n",
" transform_3d_pred = compute_transformation(vertices, Rt_pr) \n",
" norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)\n",
" vertex_dist = np.mean(norm3d) \n",
" errs_3d.append(vertex_dist) \n",
"\n",
" # Sum errors\n",
" testing_error_trans += trans_dist\n",
" testing_error_angle += angle_dist\n",
" testing_error_pixel += pixel_dist\n",
" testing_samples += 1\n",
" count = count + 1\n",
"\n",
" t5 = time.time()\n",
"\n",
" # Compute 2D projection error, 6D pose error, 5cm5degree error\n",
" px_threshold = 5\n",
" acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)\n",
" acciou = len(np.where(np.array(errs_2d) >= 0.5)[0]) * 100. / (len(ious)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)\n",
" acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)\n",
" corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)\n",
" mean_err_2d = np.mean(errs_2d)\n",
" mean_corner_err_2d = np.mean(errs_corner2D)\n",
" nts = float(testing_samples)\n",
"\n",
" if testtime:\n",
" print('-----------------------------------')\n",
" print(' tensor to cuda : %f' % (t2 - t1))\n",
" print(' predict : %f' % (t3 - t2))\n",
" print('get_region_boxes : %f' % (t4 - t3))\n",
" print(' nms : %f' % (t5 - t4))\n",
" print(' total : %f' % (t5 - t1))\n",
" print('-----------------------------------')\n",
"\n",
" # Print test statistics\n",
" logging('Results of {}'.format(name))\n",
" logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))\n",
" logging(' Acc using the IoU metric = {:.6f}%'.format(acciou))\n",
" logging(' Acc using 10% threshold - {} vx 3D Transformation = {:.2f}%'.format(diam * 0.1, acc3d10))\n",
" logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))\n",
" logging(\" Mean 2D pixel error is %f, Mean vertex error is %f, mean corner error is %f\" % (mean_err_2d, np.mean(errs_3d), mean_corner_err_2d))\n",
" logging(' Translation error: %f m, angle error: %f degree, pixel error: % f pix' % (testing_error_trans/nts, testing_error_angle/nts, testing_error_pixel/nts) )\n",
"\n",
" if save:\n",
" predfile = backupdir + '/predictions_linemod_' + name + '.mat'\n",
" scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})\n",
"\n",
"datacfg = 'cfg/ape.data'\n",
"cfgfile = 'cfg/yolo-pose.cfg'\n",
"weightfile = 'backup/ape/model_backup.weights'\n",
"valid(datacfg, cfgfile, weightfile)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

269
valid.py Normal file
Просмотреть файл

@ -0,0 +1,269 @@
import os
import time
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
import scipy.io
import warnings
warnings.filterwarnings("ignore")
from darknet import Darknet
import dataset
from utils import *
from MeshPly import MeshPly
# Create new directory
def makedirs(path):
if not os.path.exists( path ):
os.makedirs( path )
def valid(datacfg, cfgfile, weightfile, outfile):
def truths_length(truths):
for i in range(50):
if truths[i][1] == 0:
return i
# Parse configuration files
options = read_data_cfg(datacfg)
valid_images = options['valid']
meshname = options['mesh']
backupdir = options['backup']
name = options['name']
if not os.path.exists(backupdir):
makedirs(backupdir)
# Parameters
prefix = 'results'
seed = int(time.time())
gpus = '0' # Specify which gpus to use
test_width = 544
test_height = 544
torch.manual_seed(seed)
use_cuda = True
if use_cuda:
os.environ['CUDA_VISIBLE_DEVICES'] = gpus
torch.cuda.manual_seed(seed)
save = False
testtime = True
use_cuda = True
num_classes = 1
testing_samples = 0.0
eps = 1e-5
notpredicted = 0
conf_thresh = 0.1
nms_thresh = 0.4
match_thresh = 0.5
if save:
makedirs(backupdir + '/test')
makedirs(backupdir + '/test/gt')
makedirs(backupdir + '/test/pr')
# To save
testing_error_trans = 0.0
testing_error_angle = 0.0
testing_error_pixel = 0.0
errs_2d = []
errs_3d = []
errs_trans = []
errs_angle = []
errs_corner2D = []
preds_trans = []
preds_rot = []
preds_corners2D = []
gts_trans = []
gts_rot = []
gts_corners2D = []
# Read object model information, get 3D bounding box corners
mesh = MeshPly(meshname)
vertices = np.c_[np.array(mesh.vertices), np.ones((len(mesh.vertices), 1))].transpose()
corners3D = get_3D_corners(vertices)
# diam = calc_pts_diameter(np.array(mesh.vertices))
diam = float(options['diam'])
# Read intrinsic camera parameters
internal_calibration = get_camera_intrinsic()
# Get validation file names
with open(valid_images) as fp:
tmp_files = fp.readlines()
valid_files = [item.rstrip() for item in tmp_files]
# Specicy model, load pretrained weights, pass to GPU and set the module in evaluation mode
model = Darknet(cfgfile)
model.print_network()
model.load_weights(weightfile)
model.cuda()
model.eval()
# Get the parser for the test dataset
valid_dataset = dataset.listDataset(valid_images, shape=(test_width, test_height),
shuffle=False,
transform=transforms.Compose([
transforms.ToTensor(),]))
valid_batchsize = 1
# Specify the number of workers for multiple processing, get the dataloader for the test dataset
kwargs = {'num_workers': 4, 'pin_memory': True}
test_loader = torch.utils.data.DataLoader(
valid_dataset, batch_size=valid_batchsize, shuffle=False, **kwargs)
logging(" Testing {}...".format(name))
logging(" Number of test samples: %d" % len(test_loader.dataset))
# Iterate through test batches (Batch size for test data is 1)
count = 0
z = np.zeros((3, 1))
for batch_idx, (data, target) in enumerate(test_loader):
t1 = time.time()
# Pass data to GPU
if use_cuda:
data = data.cuda()
target = target.cuda()
# Wrap tensors in Variable class, set volatile=True for inference mode and to use minimal memory during inference
data = Variable(data, volatile=True)
t2 = time.time()
# Forward pass
output = model(data).data
t3 = time.time()
# Using confidence threshold, eliminate low-confidence predictions
all_boxes = get_region_boxes(output, conf_thresh, num_classes)
t4 = time.time()
# Iterate through all images in the batch
for i in range(output.size(0)):
# For each image, get all the predictions
boxes = all_boxes[i]
# For each image, get all the targets (for multiple object pose estimation, there might be more than 1 target per image)
truths = target[i].view(-1, 21)
# Get how many object are present in the scene
num_gts = truths_length(truths)
# Iterate through each ground-truth object
for k in range(num_gts):
box_gt = [truths[k][1], truths[k][2], truths[k][3], truths[k][4], truths[k][5], truths[k][6],
truths[k][7], truths[k][8], truths[k][9], truths[k][10], truths[k][11], truths[k][12],
truths[k][13], truths[k][14], truths[k][15], truths[k][16], truths[k][17], truths[k][18], 1.0, 1.0, truths[k][0]]
best_conf_est = -1
# If the prediction has the highest confidence, choose it as our prediction for single object pose estimation
for j in range(len(boxes)):
if (boxes[j][18] > best_conf_est):
match = corner_confidence9(box_gt[:18], torch.FloatTensor(boxes[j][:18]))
box_pr = boxes[j]
best_conf_est = boxes[j][18]
# Denormalize the corner predictions
corners2D_gt = np.array(np.reshape(box_gt[:18], [9, 2]), dtype='float32')
corners2D_pr = np.array(np.reshape(box_pr[:18], [9, 2]), dtype='float32')
corners2D_gt[:, 0] = corners2D_gt[:, 0] * 640
corners2D_gt[:, 1] = corners2D_gt[:, 1] * 480
corners2D_pr[:, 0] = corners2D_pr[:, 0] * 640
corners2D_pr[:, 1] = corners2D_pr[:, 1] * 480
preds_corners2D.append(corners2D_pr)
gts_corners2D.append(corners2D_gt)
# Compute corner prediction error
corner_norm = np.linalg.norm(corners2D_gt - corners2D_pr, axis=1)
corner_dist = np.mean(corner_norm)
errs_corner2D.append(corner_dist)
# Compute [R|t] by pnp
R_gt, t_gt = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_gt, np.array(internal_calibration, dtype='float32'))
R_pr, t_pr = pnp(np.array(np.transpose(np.concatenate((np.zeros((3, 1)), corners3D[:3, :]), axis=1)), dtype='float32'), corners2D_pr, np.array(internal_calibration, dtype='float32'))
if save:
preds_trans.append(t_pr)
gts_trans.append(t_gt)
preds_rot.append(R_pr)
gts_rot.append(R_gt)
np.savetxt(backupdir + '/test/gt/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_gt, dtype='float32'))
np.savetxt(backupdir + '/test/gt/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/R_' + valid_files[count][-8:-3] + 'txt', np.array(R_pr, dtype='float32'))
np.savetxt(backupdir + '/test/pr/t_' + valid_files[count][-8:-3] + 'txt', np.array(t_pr, dtype='float32'))
np.savetxt(backupdir + '/test/gt/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_gt, dtype='float32'))
np.savetxt(backupdir + '/test/pr/corners_' + valid_files[count][-8:-3] + 'txt', np.array(corners2D_pr, dtype='float32'))
# Compute translation error
trans_dist = np.sqrt(np.sum(np.square(t_gt - t_pr)))
errs_trans.append(trans_dist)
# Compute angle error
angle_dist = calcAngularDistance(R_gt, R_pr)
errs_angle.append(angle_dist)
# Compute pixel error
Rt_gt = np.concatenate((R_gt, t_gt), axis=1)
Rt_pr = np.concatenate((R_pr, t_pr), axis=1)
proj_2d_gt = compute_projection(vertices, Rt_gt, internal_calibration)
proj_2d_pred = compute_projection(vertices, Rt_pr, internal_calibration)
norm = np.linalg.norm(proj_2d_gt - proj_2d_pred, axis=0)
pixel_dist = np.mean(norm)
errs_2d.append(pixel_dist)
# Compute 3D distances
transform_3d_gt = compute_transformation(vertices, Rt_gt)
transform_3d_pred = compute_transformation(vertices, Rt_pr)
norm3d = np.linalg.norm(transform_3d_gt - transform_3d_pred, axis=0)
vertex_dist = np.mean(norm3d)
errs_3d.append(vertex_dist)
# Sum errors
testing_error_trans += trans_dist
testing_error_angle += angle_dist
testing_error_pixel += pixel_dist
testing_samples += 1
count = count + 1
t5 = time.time()
# Compute 2D projection error, 6D pose error, 5cm5degree error
px_threshold = 5
acc = len(np.where(np.array(errs_2d) <= px_threshold)[0]) * 100. / (len(errs_2d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
acc3d10 = len(np.where(np.array(errs_3d) <= diam * 0.1)[0]) * 100. / (len(errs_3d)+eps)
acc5cm5deg = len(np.where((np.array(errs_trans) <= 0.05) & (np.array(errs_angle) <= 5))[0]) * 100. / (len(errs_trans)+eps)
corner_acc = len(np.where(np.array(errs_corner2D) <= px_threshold)[0]) * 100. / (len(errs_corner2D)+eps)
mean_err_2d = np.mean(errs_2d)
mean_corner_err_2d = np.mean(errs_corner2D)
nts = float(testing_samples)
if testtime:
print('-----------------------------------')
print(' tensor to cuda : %f' % (t2 - t1))
print(' predict : %f' % (t3 - t2))
print('get_region_boxes : %f' % (t4 - t3))
print(' eval : %f' % (t5 - t4))
print(' total : %f' % (t5 - t1))
print('-----------------------------------')
# Print test statistics
logging('Results of {}'.format(name))
logging(' Acc using {} px 2D Projection = {:.2f}%'.format(px_threshold, acc))
logging(' Acc using 10% threshold - {} vx 3D Transformation = {:.2f}%'.format(diam * 0.1, acc3d10))
logging(' Acc using 5 cm 5 degree metric = {:.2f}%'.format(acc5cm5deg))
logging(" Mean 2D pixel error is %f, Mean vertex error is %f, mean corner error is %f" % (mean_err_2d, np.mean(errs_3d), mean_corner_err_2d))
logging(' Translation error: %f m, angle error: %f degree, pixel error: % f pix' % (testing_error_trans/nts, testing_error_angle/nts, testing_error_pixel/nts) )
if save:
predfile = backupdir + '/predictions_linemod_' + name + '.mat'
scipy.io.savemat(predfile, {'R_gts': gts_rot, 't_gts':gts_trans, 'corner_gts': gts_corners2D, 'R_prs': preds_rot, 't_prs':preds_trans, 'corner_prs': preds_corners2D})
if __name__ == '__main__':
import sys
if len(sys.argv) == 4:
datacfg = sys.argv[1]
cfgfile = sys.argv[2]
weightfile = sys.argv[3]
outfile = 'comp4_det_test_'
valid(datacfg, cfgfile, weightfile, outfile)
else:
print('Usage:')
print(' python valid.py datacfg cfgfile weightfile')