[TOPI][CUDA] Add faster-rcnn proposal op (#2420)
* [TOPI][CUDA] Add faster-rcnn proposal op * Fix doc * Add global barrier * Use vthread in argsort * Update sort and nms ir * Fix lint * Update sort ir in ssd nms
This commit is contained in:
Родитель
6b0157bfa2
Коммит
d20646c790
|
@ -18,3 +18,4 @@ from .vision import *
|
|||
from . import ssd
|
||||
from .ssd import *
|
||||
from .nms import *
|
||||
from .rcnn import *
|
||||
|
|
|
@ -35,7 +35,7 @@ def sort_ir(data, index, output):
|
|||
p_index = ib.buffer_ptr(index)
|
||||
p_out = ib.buffer_ptr(output)
|
||||
nthread_tx = max_threads
|
||||
nthread_bx = num_anchors // max_threads + 1
|
||||
nthread_bx = (num_anchors + 1) // 2 // max_threads + 1
|
||||
tx = tvm.thread_axis("threadIdx.x")
|
||||
bx = tvm.thread_axis("vthread")
|
||||
ib.scope_attr(tx, "thread_extent", nthread_tx)
|
||||
|
@ -46,8 +46,10 @@ def sort_ir(data, index, output):
|
|||
|
||||
with ib.for_range(0, batch, for_type="unroll") as b:
|
||||
start = b * num_anchors
|
||||
with ib.if_scope(tid < num_anchors):
|
||||
p_out[start + tid] = tid
|
||||
for i in range(2):
|
||||
bbox_id = tid * 2 + i
|
||||
with ib.if_scope(bbox_id < num_anchors):
|
||||
p_out[start + bbox_id] = bbox_id
|
||||
# OddEvenTransposeSort
|
||||
with ib.for_range(0, p_index[b]) as k:
|
||||
with ib.if_scope(tid < (p_index[b] + 1) // 2):
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
# pylint: disable=wildcard-import
|
||||
"""Faster R-CNN and Mask R-CNN operators"""
|
||||
from .proposal import *
|
|
@ -0,0 +1,356 @@
|
|||
# pylint: disable=invalid-name, singleton-comparison
|
||||
"""Proposal operator"""
|
||||
import math
|
||||
import tvm
|
||||
from ...vision.rcnn import proposal, generate_anchor, reg_bbox, reg_iou
|
||||
from ...util import get_const_tuple, get_const_int
|
||||
|
||||
|
||||
def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
                    feature_stride, rpn_min_size, iou_loss):
    """Build the IR that turns anchors, scores and deltas into bounding boxes.

    One GPU thread is assigned to every (batch, h, w) location of the feature
    map; each thread walks over all anchors of that location.

    Parameters
    ----------
    cls_prob_buf : tvm.schedule.Buffer
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred_buf : tvm.schedule.Buffer
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info_buf : tvm.schedule.Buffer
        2-D with shape [batch, 3]

    out_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]
        The last dimension is in format of [w_start, h_start, w_end, h_end, score]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_channels, height, width = get_const_tuple(cls_prob_buf.shape)
    # The channel axis holds two scores per anchor.
    num_anchors = num_channels // 2
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = (batch * height * width) // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    tid = bx * max_threads + tx
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)

    p_score = ib.buffer_ptr(cls_prob_buf)
    p_delta = ib.buffer_ptr(bbox_pred_buf)
    p_im_info = ib.buffer_ptr(im_info_buf)
    p_out = ib.buffer_ptr(out_buf)

    # Selected once at IR-construction time; it does not depend on the anchor.
    regression_func = reg_iou if iou_loss else reg_bbox

    with ib.if_scope(tid < batch * height * width):
        # Decompose the flat thread id into (b, h, w).
        w = tid % width
        h = (tid // width) % height
        b = tid // width // height

        # Per-image quantities read from im_info: [height, width, <third entry>].
        im_height = p_im_info[b * 3]
        im_width = p_im_info[b * 3 + 1]
        real_height = (im_height / feature_stride).astype('int32')
        real_width = (im_width / feature_stride).astype('int32')
        min_size = p_im_info[b * 3 + 2] * rpn_min_size
        half_min_size = min_size / 2.0

        for k in range(num_anchors):
            out_base = (tid * num_anchors + k) * 5
            anchor = generate_anchor(ratios[k // len(scales)],
                                     scales[k % len(scales)],
                                     feature_stride)
            # Shift the base anchor to this feature-map location.
            x1 = anchor[0] + w * feature_stride
            y1 = anchor[1] + h * feature_stride
            x2 = anchor[2] + w * feature_stride
            y2 = anchor[3] + h * feature_stride

            delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
                     for i in range(4)]
            pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

            # Clip the predicted box to the image.
            pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
            pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
            pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
            pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

            bbox_w = pred_x2 - pred_x1 + 1.0
            bbox_h = pred_y2 - pred_y1 + 1.0

            # The second half of the channel axis is read as the score.
            pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
            # Locations past the real image extent get the sentinel score -1.
            pred_score = tvm.expr.Select(tvm.any(h >= real_height, w >= real_width),
                                         -1.0, pred_score)

            for field, value in enumerate((pred_x1, pred_y1, pred_x2, pred_y2, pred_score)):
                p_out[out_base + field] = value

            # Boxes smaller than min_size are widened and marked with score -1.
            with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
                p_out[out_base + 0] = p_out[out_base + 0] - half_min_size
                p_out[out_base + 1] = p_out[out_base + 1] - half_min_size
                p_out[out_base + 2] = p_out[out_base + 2] + half_min_size
                p_out[out_base + 3] = p_out[out_base + 3] + half_min_size
                p_out[out_base + 4] = -1.0

    return ib.get()
|
||||
|
||||
|
||||
def argsort_ir(data_buf, out_index_buf):
    """Batched odd-even transposition sort (descending order).

    The values in ``data_buf`` are reordered in place while the matching
    permutation is tracked in ``out_index_buf``.

    Parameters
    ----------
    data_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Sorted in place, largest value first.

    out_index_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Indices of data in sorted order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_bbox = get_const_tuple(data_buf.shape)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(data_buf)
    index_out = ib.buffer_ptr(out_index_buf)
    nthread_tx = max_threads
    # Each thread owns one adjacent pair, so only ceil(num_bbox / 2) threads are needed.
    nthread_bx = (num_bbox + 1) // 2 // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")

    with ib.for_range(0, batch, for_type="unroll") as b:
        start = b * num_bbox
        # Initialise the permutation of this batch to the identity.
        for i in range(2):
            bbox_id = tid * 2 + i
            with ib.if_scope(bbox_id < num_bbox):
                index_out[start + bbox_id] = bbox_id
        with ib.for_range(0, num_bbox) as k:
            # Even passes compare pairs (0,1),(2,3),...; odd passes (1,2),(3,4),...
            local_idx = 2 * tid + (k % 2)
            offset = start + local_idx
            # The bounds check must use the index *within* the batch. Comparing
            # the flat offset against num_bbox is always false for b >= 1 and
            # would leave every batch except the first one unsorted.
            with ib.if_scope(
                    tvm.all(local_idx + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
                temp_data[0] = p_data[offset]
                p_data[offset] = p_data[offset + 1]
                p_data[offset + 1] = temp_data[0]
                temp_index[0] = index_out[offset]
                index_out[offset] = index_out[offset + 1]
                index_out[offset + 1] = temp_index[0]
            # Storage sync emitted after every pass, before the next pass reads.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()
|
||||
|
||||
|
||||
def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
    """Non-maximum suppression.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    nms_threshold : float
        Non-maximum suppression threshold.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Return intersection-over-union of the boxes starting at the two flat offsets."""
        a_x1, a_y1 = out_tensor[box_a_idx], out_tensor[box_a_idx + 1]
        a_x2, a_y2 = out_tensor[box_a_idx + 2], out_tensor[box_a_idx + 3]
        b_x1, b_y1 = out_tensor[box_b_idx], out_tensor[box_b_idx + 1]
        b_x2, b_y2 = out_tensor[box_b_idx + 2], out_tensor[box_b_idx + 3]

        # Intersection extents, clamped at zero when the boxes do not overlap.
        w = tvm.max(0.0, tvm.min(a_x2, b_x2) - tvm.max(a_x1, b_x1) + 1.0)
        h = tvm.max(0.0, tvm.min(a_y2, b_y2) - tvm.max(a_y1, b_y1) + 1.0)
        i = w * h
        area_a = (a_x2 - a_x1 + 1.0) * (a_y2 - a_y1 + 1.0)
        area_b = (b_x2 - b_x1 + 1.0) * (b_y2 - b_y1 + 1.0)
        u = area_a + area_b - i
        return i / u

    batch, num_bbox = get_const_tuple(out_buf.shape)
    max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib = tvm.ir_builder.create()
    p_data = ib.buffer_ptr(sorted_bbox_buf)
    p_out = ib.buffer_ptr(out_buf)
    nthread_tx = max_threads
    nthread_bx = num_bbox // max_threads + 1
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Each thread decides the fate of one candidate box.
    box = bx * max_threads + tx
    with ib.for_range(0, batch, for_type="unroll", name="n") as b:
        base_idx = b * num_bbox
        with ib.if_scope(box < num_bbox):
            p_out[base_idx + box] = False
        with ib.for_range(0, num_bbox - 1) as ref:
            # Only earlier boxes that are still kept may suppress this one.
            still_kept = p_out[base_idx + ref] == False
            with ib.if_scope(tvm.all(box < num_bbox, box > ref, still_kept)):
                iou = calculate_overlap(p_data, (base_idx + ref) * 5, (base_idx + box) * 5)
                with ib.if_scope(iou > nms_threshold):
                    p_out[base_idx + box] = True
    return ib.get()
|
||||
|
||||
|
||||
def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
    """Copy the boxes that survived nms into contiguous output memory.

    Parameters
    ----------
    sorted_bbox_buf : tvm.schedule.Buffer
        3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
        [w_start, h_start, w_end, h_end, score].

    remove_mask_buf : tvm.schedule.Buffer
        2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.

    out_buf : tvm.schedule.Buffer
        2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
    rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
    # One thread per batch entry.
    nthread_tx = batch
    tx = tvm.thread_axis("threadIdx.x")
    ib = tvm.ir_builder.create()
    ib.scope_attr(tx, "thread_extent", nthread_tx)

    # Number of rows already written to the output of this batch entry.
    written = ib.allocate('int32', (1,), 'i', scope='local')
    written[0] = 0
    p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
    p_remove = ib.buffer_ptr(remove_mask_buf)
    p_out = ib.buffer_ptr(out_buf)
    b = tx

    nkeep = ib.allocate('int32', (1,), 'nkeep', scope='local')
    nkeep[0] = 0  # number of bbox after nms

    with ib.for_range(0, num_bbox) as j:
        with ib.if_scope(p_remove[b * num_bbox + j] == False):
            nkeep[0] += 1

    with ib.if_scope(nkeep[0] > 0):
        # Sweep the kept boxes repeatedly so that the output is filled even
        # when fewer than rpn_post_nms_top_n boxes survived.
        num_sweeps = tvm.ceil(
            tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[0]).astype('int32')
        with ib.for_range(0, num_sweeps):
            with ib.for_range(0, num_bbox) as j:
                offset_j = (b * num_bbox + j) * 5
                offset_i = (b * rpn_post_nms_top_n + written[0]) * 5
                with ib.if_scope(tvm.all(written[0] < rpn_post_nms_top_n,
                                         p_remove[(b*num_bbox+j)] == False)):
                    p_out[offset_i] = tvm.expr.Cast('float32', b)
                    with ib.for_range(0, 4, for_type='unroll') as k:
                        p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
                    written[0] = written[0] + 1

    return ib.get()
|
||||
|
||||
|
||||
@proposal.register("cuda")
def proposal_cuda(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
                  rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
    """Proposal operator (CUDA implementation).

    Chains box prediction, score sorting, non-maximum suppression and output
    packing into a single compute graph.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred : tvm.Tensor
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info : tvm.Tensor
        2-D with shape [batch, 3]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to apply NMS. -1 to use all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN proposals.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    out : tvm.Tensor
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].
    """
    batch, _, height, width = get_const_tuple(cls_prob.shape)
    num_anchors = len(scales) * len(ratios)
    num_bbox = height * width * num_anchors
    # A non-positive pre-nms limit means "use every box".
    if rpn_pre_nms_top_n > 0:
        rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox)
    else:
        rpn_pre_nms_top_n = num_bbox

    def _predict(ins, outs):
        return predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
                               feature_stride, rpn_min_size, iou_loss)

    def _argsort(ins, outs):
        return argsort_ir(ins[0], outs[0])

    def _nms(ins, outs):
        return nms_ir(ins[0], outs[0], threshold)

    def _pack(ins, outs):
        return prepare_output_ir(ins[0], ins[1], outs[0])

    bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], _predict,
                      dtype=bbox_pred.dtype)
    # The score is the fifth field of every predicted box.
    score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
    sorted_index = tvm.extern([score.shape], [score], _argsort, dtype='int32')
    # Gather only the first rpn_pre_nms_top_n boxes in sorted order.
    sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
                              lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
    nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox], _nms,
                                 dtype='bool')
    nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
                         _pack, dtype=sorted_bbox.dtype)
    return nms_out
|
|
@ -151,3 +151,32 @@ def schedule_multibox_detection(outs):
|
|||
@generic.schedule_roi_align.register(["cuda", "gpu"])
|
||||
def schedule_roi_align(outs):
|
||||
return schedule_pool(outs, 'NCHW')
|
||||
|
||||
@generic.schedule_proposal.register(["cuda", "gpu"])
def schedule_proposal(outs):
    """Schedule for proposal operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of proposal
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []
    from .injective import _schedule_injective

    # Tags of the compute stages that get the injective schedule.
    injective_tags = ['bbox_score', 'sorted_bbox']

    def traverse(op):
        """Walk the producers of ``op`` and schedule the tagged compute stages."""
        if op.tag in injective_tags:
            _schedule_injective(op, s)
        for tensor in op.input_tensors:
            producer = tensor.op
            # Ops without inputs are not descended into.
            if not producer.input_tensors:
                continue
            if producer in scheduled_ops:
                continue
            traverse(producer)
        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
|
||||
|
|
|
@ -157,3 +157,20 @@ def schedule_roi_align(outs):
|
|||
The computation schedule for the op.
|
||||
"""
|
||||
return _default_schedule(outs, False)
|
||||
|
||||
@tvm.target.generic_func
def schedule_proposal(outs):
    """Schedule for proposal operator.

    Generic fallback that delegates to the default schedule.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of proposal
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    schedule = _default_schedule(outs, False)
    return schedule
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
# pylint: disable=wildcard-import
|
||||
"""Faster R-CNN and Mask R-CNN operators"""
|
||||
from .roi_align import *
|
||||
from .proposal import *
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
# pylint: disable=invalid-name
|
||||
"""Proposal operator"""
|
||||
import math
|
||||
import tvm
|
||||
|
||||
|
||||
def generate_anchor(ratio, scale, base_size):
    """Generate one anchor box for the given aspect ratio and scale.

    The anchor is centred on a ``base_size`` x ``base_size`` window and
    returned as a tuple ``(x1, y1, x2, y2)``.
    """
    side = float(base_size)
    # The base window is square, so both centre coordinates coincide.
    center = 0.5 * (side - 1.)
    area = side * side
    area_over_ratio = math.floor(area / ratio)
    anchor_w = math.floor(math.sqrt(area_over_ratio) + 0.5) * scale
    anchor_h = math.floor((anchor_w / scale * ratio) + 0.5) * scale
    half_w = 0.5 * (anchor_w - 1.0)
    half_h = 0.5 * (anchor_h - 1.0)
    return (center - half_w, center - half_h,
            center + half_w, center + half_h)
|
||||
|
||||
|
||||
def reg_bbox(x1, y1, x2, y2, dx, dy, dw, dh):
    """Bounding box regression with centre offsets and log-scale size deltas.

    ``dx``/``dy`` shift the box centre relative to the box size, while
    ``dw``/``dh`` scale the width and height through ``exp``. Returns the
    predicted corners ``(x1, y1, x2, y2)``.
    """
    bbox_w = x2 - x1 + 1.0
    bbox_h = y2 - y1 + 1.0

    # Move the centre of the input box by the predicted offsets.
    pred_ctr_x = dx * bbox_w + (x1 + 0.5 * (bbox_w - 1.0))
    pred_ctr_y = dy * bbox_h + (y1 + 0.5 * (bbox_h - 1.0))

    # Half extents of the rescaled box.
    half_w = 0.5 * (tvm.exp(dw) * bbox_w - 1.0)
    half_h = 0.5 * (tvm.exp(dh) * bbox_h - 1.0)

    return (pred_ctr_x - half_w, pred_ctr_y - half_h,
            pred_ctr_x + half_w, pred_ctr_y + half_h)
|
||||
|
||||
|
||||
def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2):
    """Bounding box regression that adds one delta to each box corner coordinate.

    Returns the predicted corners ``(x1, y1, x2, y2)``.
    """
    corners = (x1, y1, x2, y2)
    deltas = (dx1, dy1, dx2, dy2)
    return tuple(corner + delta for corner, delta in zip(corners, deltas))
|
||||
|
||||
|
||||
@tvm.target.generic_func
def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
             rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
    """Proposal operator.

    Generic entry point: this body only raises, so an implementation has to be
    registered for the target in use.

    Parameters
    ----------
    cls_prob : tvm.Tensor
        4-D with shape [batch, 2 * num_anchors, height, width]

    bbox_pred : tvm.Tensor
        4-D with shape [batch, 4 * num_anchors, height, width]

    im_info : tvm.Tensor
        2-D with shape [batch, 3]

    scales : list/tuple of float
        Scales of anchor windows.

    ratios : list/tuple of float
        Ratios of anchor windows.

    feature_stride : int
        The size of the receptive field each unit in the convolution layer of the rpn, for example
        the product of all stride's prior to this layer.

    threshold : float
        Non-maximum suppression threshold.

    rpn_pre_nms_top_n : int
        Number of top scoring boxes to apply NMS. -1 to use all boxes.

    rpn_post_nms_top_n : int
        Number of top scoring boxes to keep after applying NMS to RPN proposals.

    rpn_min_size : int
        Minimum height or width in proposal.

    iou_loss : bool
        Usage of IoU loss.

    Returns
    -------
    out : tvm.Tensor
        2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
        [batch_index, w_start, h_start, w_end, h_end].

    Raises
    ------
    ValueError
        Always, when no target-specific implementation handles the call.
    """
    # pylint: disable=unused-argument
    raise ValueError("missing register for topi.vision.rcnn.proposal")
|
|
@ -1,4 +1,5 @@
|
|||
"""Test code for vision package"""
|
||||
from __future__ import print_function
|
||||
import math
|
||||
import numpy as np
|
||||
import tvm
|
||||
|
@ -206,8 +207,75 @@ def test_roi_align():
|
|||
verify_roi_align(4, 16, 32, 64, 7, 0.5, 2)
|
||||
|
||||
|
||||
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
    """Build the proposal op for each available device and compare with ``np_out``."""
    cls_prob = tvm.placeholder(np_cls_prob.shape)
    bbox_pred = tvm.placeholder(np_bbox_pred.shape)
    im_info = tvm.placeholder(np_im_info.shape, dtype='int32')

    def check_device(device):
        """Compile and run the op on ``device``; skip when it is not enabled."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            out = topi.vision.proposal(cls_prob, bbox_pred, im_info, **attrs)
            s = topi.generic.schedule_proposal(out)
            f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], device)
            # Upload the inputs in the same order as the build arguments.
            device_inputs = [tvm.nd.array(host_array, ctx=ctx)
                             for host_array in (np_cls_prob, np_bbox_pred, np_im_info)]
            tvm_out = tvm.nd.empty(ctx=ctx, shape=out.shape, dtype=out.dtype)
            f(device_inputs[0], device_inputs[1], device_inputs[2], tvm_out)
            tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)

    for device in ['cuda']:
        check_device(device)
|
||||
|
||||
|
||||
def test_proposal():
    """Check the proposal op against reference outputs for both regression modes."""
    attrs = dict(
        scales=(0.5,),
        ratios=(0.5,),
        feature_stride=16,
        iou_loss=False,
        rpn_min_size=16,
        threshold=0.7,
        rpn_pre_nms_top_n=200,
        rpn_post_nms_top_n=4,
    )
    # One image, one anchor, 3x3 feature map.
    np_cls_prob = np.array([[
        [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]],
        [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]]
    ]], dtype='float32')
    np_bbox_pred = np.array([[
        [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]],
        [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]],
        [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]],
        [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]],
    ]], dtype='float32')
    np_im_info = np.array([[48, 48, 1]], dtype='int32')

    # Reference output with iou_loss disabled.
    expected_bbox = np.array([
        [0., 0., 2.8451548, 28.38012, 18.154846],
        [0., 0., 15.354933, 41.96971, 41.245064],
        [0., 18.019852, 1.0538368, 51.98015, 25.946163],
        [0., 27.320923, -1.266357, 55., 24.666357]
    ], dtype='float32')
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, expected_bbox, attrs)

    # Reference output with iou_loss enabled.
    expected_iou = np.array([
        [0., -5.25, -2.5, 21.75, 19.],
        [0., 11.25, -2., 37.25, 18.5],
        [0., 26.849998, -2.3000002, 53.45, 18.6],
        [0., -4.95, 13.799999, 22.25, 35.5]
    ], dtype='float32')
    attrs['iou_loss'] = True
    verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, expected_iou, attrs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every vision test in the original order when executed as a script.
    for run_test in (test_nms, test_multibox_prior, test_multibox_detection,
                     test_roi_align, test_proposal):
        run_test()
|
||||
|
|
Загрузка…
Ссылка в новой задаче