[IR] Try to improve nms and get_valid_count (#3282)
* improve nms * add back get_valid_count syncs
This commit is contained in:
Родитель
befd8c1e48
Коммит
f2ddb1961c
|
@ -457,15 +457,15 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
|
|||
box_indices = ib.buffer_ptr(box_indices)
|
||||
num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
|
||||
|
||||
max_threads = int(math.sqrt(
|
||||
tvm.target.current_target(allow_none=False).max_num_threads))
|
||||
max_threads = int(
|
||||
tvm.target.current_target(allow_none=False).max_num_threads)
|
||||
nthread_tx = max_threads
|
||||
nthread_bx = num_anchors // max_threads + 1
|
||||
tx = tvm.thread_axis("threadIdx.x")
|
||||
bx = tvm.thread_axis("blockIdx.x")
|
||||
ib.scope_attr(tx, "thread_extent", nthread_tx)
|
||||
ib.scope_attr(bx, "thread_extent", nthread_bx)
|
||||
k = bx * max_threads + tx
|
||||
j = bx * max_threads + tx
|
||||
|
||||
iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
|
||||
top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
|
||||
|
@ -480,22 +480,22 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
|
|||
nkeep = if_then_else( \
|
||||
tvm.all(top_k > 0, top_k < valid_count[i]),
|
||||
top_k, valid_count[i])
|
||||
with ib.for_range(0, nkeep) as j:
|
||||
with ib.if_scope(k < box_data_length):
|
||||
with ib.if_scope(j < nkeep):
|
||||
with ib.for_range(0, box_data_length) as k:
|
||||
out[(base_idx + j * box_data_length + k)] = \
|
||||
data[(base_idx + sorted_index[i * num_anchors + j] \
|
||||
* box_data_length + k)]
|
||||
box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
|
||||
with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
|
||||
with ib.for_range(0, valid_count[i] - nkeep) as j:
|
||||
with ib.if_scope(k < box_data_length):
|
||||
with ib.if_scope(j < valid_count[i] - nkeep):
|
||||
with ib.for_range(0, box_data_length) as k:
|
||||
out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
|
||||
box_indices[i * num_anchors + (j + nkeep)] = -1
|
||||
# Apply nms
|
||||
with ib.for_range(0, valid_count[i]) as j:
|
||||
with ib.if_scope(j < valid_count[i]):
|
||||
offset_j = j * box_data_length
|
||||
with ib.if_scope(out[base_idx + offset_j] >= 0):
|
||||
with ib.if_scope(k < valid_count[i]):
|
||||
with ib.for_range(0, valid_count[i]) as k:
|
||||
offset_k = k * box_data_length
|
||||
with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \
|
||||
tvm.any(force_suppress > 0, id_index < 0, \
|
||||
|
@ -506,35 +506,29 @@ def nms_ir(data, sorted_index, valid_count, out, box_indices,
|
|||
with ib.if_scope(iou >= iou_threshold):
|
||||
out[base_idx + offset_k] = -1.0
|
||||
box_indices[i * num_anchors + k] = -1
|
||||
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
|
||||
tvm.convert(['shared']),
|
||||
tvm.expr.Call.Intrinsic, None, 0))
|
||||
with ib.else_scope():
|
||||
with ib.for_range(0, valid_count[i]) as j:
|
||||
with ib.if_scope(j < valid_count[i]):
|
||||
offset_j = j * box_data_length
|
||||
with ib.if_scope(k < box_data_length):
|
||||
with ib.for_range(0, box_data_length) as k:
|
||||
out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
|
||||
box_indices[i * num_anchors + j] = j
|
||||
# Set invalid entry to be -1
|
||||
with ib.for_range(0, num_anchors - valid_count[i]) as j:
|
||||
with ib.if_scope(k < box_data_length):
|
||||
with ib.if_scope(j < num_anchors - valid_count[i]):
|
||||
with ib.for_range(0, box_data_length) as k:
|
||||
out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
|
||||
box_indices[i * num_anchors + j + valid_count[i]] = -1
|
||||
# Only return max_output_size number of valid boxes
|
||||
num_valid_boxes[0] = 0
|
||||
with ib.if_scope(max_output_size > 0):
|
||||
with ib.for_range(0, valid_count[i]) as j:
|
||||
with ib.if_scope(j < valid_count[i]):
|
||||
offset_j = j * box_data_length
|
||||
with ib.if_scope(out[base_idx + offset_j] >= 0):
|
||||
with ib.if_scope(num_valid_boxes[0] == max_output_size):
|
||||
with ib.if_scope(k < box_data_length):
|
||||
with ib.for_range(0, box_data_length) as k:
|
||||
out[base_idx + offset_j + k] = -1.0
|
||||
box_indices[i * num_anchors + j] = -1
|
||||
with ib.else_scope():
|
||||
num_valid_boxes[0] += 1
|
||||
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
|
||||
tvm.convert(['shared']),
|
||||
tvm.expr.Call.Intrinsic, None, 0))
|
||||
|
||||
return ib.get()
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче