[IR] Try to improve nms and get_valid_count (#3282)

* improve nms

* add back get_valid_count syncs
This commit is contained in:
Leyuan Wang 2019-06-04 20:32:31 -07:00 коммит произвёл Wuwei Lin
Родитель befd8c1e48
Коммит f2ddb1961c
1 изменённых файлов: 15 добавлений и 21 удалений

Просмотреть файл

@ -457,15 +457,15 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
box_indices = ib.buffer_ptr(box_indices)
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(
tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
k = bx * max_threads + tx
j = bx * max_threads + tx
iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
@ -480,22 +480,22 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
nkeep = if_then_else( \
tvm.all(top_k > 0, top_k < valid_count[i]),
top_k, valid_count[i])
with ib.for_range(0, nkeep) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + j * box_data_length + k)] = \
data[(base_idx + sorted_index[i * num_anchors + j] \
* box_data_length + k)]
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
with ib.for_range(0, valid_count[i] - nkeep) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < valid_count[i] - nkeep):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
box_indices[i * num_anchors + (j + nkeep)] = -1
# Apply nms
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(k < valid_count[i]):
with ib.for_range(0, valid_count[i]) as k:
offset_k = k * box_data_length
with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
tvm.any(force_suppress > 0, id_index < 0, \
@ -506,35 +506,29 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
with ib.if_scope(iou >= iou_threshold):
out[base_idx + offset_k] = -1.0
box_indices[i * num_anchors + k] = -1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
with ib.else_scope():
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(k < box_data_length):
with ib.for_range(0, box_data_length) as k:
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
box_indices[i * num_anchors + j] = j
# Set invalid entry to be -1
with ib.for_range(0, num_anchors - valid_count[i]) as j:
with ib.if_scope(k < box_data_length):
with ib.if_scope(j < num_anchors - valid_count[i]):
with ib.for_range(0, box_data_length) as k:
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
box_indices[i * num_anchors + j + valid_count[i]] = -1
# Only return max_output_size number of valid boxes
num_valid_boxes[0] = 0
with ib.if_scope(max_output_size > 0):
with ib.for_range(0, valid_count[i]) as j:
with ib.if_scope(j < valid_count[i]):
offset_j = j * box_data_length
with ib.if_scope(out[base_idx + offset_j] >= 0):
with ib.if_scope(num_valid_boxes[0] == max_output_size):
with ib.if_scope(k < box_data_length):
with ib.for_range(0, box_data_length) as k:
out[base_idx + offset_j + k] = -1.0
box_indices[i * num_anchors + j] = -1
with ib.else_scope():
num_valid_boxes[0] += 1
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
return ib.get()