зеркало из https://github.com/microsoft/antares.git
update GPU policy for non GPU devices
This commit is contained in:
Родитель
4ab9d09d59
Коммит
f50649c15f
|
@ -60,7 +60,7 @@ def schedule_branch(attrs, output, prefix):
|
|||
|
||||
local_slices = [list(cfg.apply_split(s, output_local, output_local.op.axis[i], [-1, 1] + data_sizes[i][1:])) for i in range(len(output_local.op.axis))]
|
||||
zero, first, second, third, fourth = [x[0] for x in local_slices], [x[1] for x in local_slices], [x[2] for x in local_slices], [x[3] for x in local_slices], [x[4] for x in local_slices]
|
||||
s[output_local].reorder(*(zero + first + second + [output_local_rv_o_o, output_local_rv_o_i] + third + [output_local_rv_i] + fourth))
|
||||
s[output_local].reorder(*(zero + first + second + [output_local_rv_o_o,] + third + [output_local_rv_o_i,] + fourth + [output_local_rv_i]))
|
||||
|
||||
data_slices = [list(cfg.apply_split(s, output, output.op.axis[i], data_sizes[i])) for i in range(len(output.op.axis))]
|
||||
|
||||
|
@ -73,21 +73,22 @@ def schedule_branch(attrs, output, prefix):
|
|||
s[output].bind(s[output].fuse(*second), te.thread_axis("vthread"))
|
||||
s[output].bind(s[output].fuse(*third), te.thread_axis("threadIdx.x"))
|
||||
|
||||
load_stage = []
|
||||
for i, load in enumerate(input_tensors):
|
||||
use_align = cfg.define_knob(f"{prefix}AL{i}", [0, 1], init_vals=[0])
|
||||
load_stage.append(s.cache_read(load, 'shared', [output_local]))
|
||||
if use_align:
|
||||
s[load_stage[-1]].storage_align(load_stage[-1].op.axis[0], 2, 1)
|
||||
s[load_stage[-1]].compute_at(s[output_local], output_local_rv_o_o)
|
||||
if '_intel' not in attrs.backend:
|
||||
load_stage = []
|
||||
for i, load in enumerate(input_tensors):
|
||||
use_align = cfg.define_knob(f"{prefix}AL{i}", [0, 1], init_vals=[0])
|
||||
load_stage.append(s.cache_read(load, 'shared', [output_local]))
|
||||
if use_align:
|
||||
s[load_stage[-1]].storage_align(load_stage[-1].op.axis[0], 2, 1)
|
||||
s[load_stage[-1]].compute_at(s[output_local], output_local_rv_o_o)
|
||||
|
||||
for i, load in enumerate(load_stage):
|
||||
fused_o = s[load].fuse(*s[load].op.axis)
|
||||
val = 1 ## cfg.define_knob(f"{prefix}V{i}", [1, 2, 4] if not attrs.backend.startswith('c-hlsl_') else [1])
|
||||
fused_o, fused_i = s[load].split(fused_o, factor=val)
|
||||
s[load].vectorize(fused_i)
|
||||
fused_o, fused_i = s[load].split(fused_o, factor=num_threads)
|
||||
s[load].bind(fused_i, te.thread_axis("threadIdx.x"))
|
||||
for i, load in enumerate(load_stage):
|
||||
fused_o = s[load].fuse(*s[load].op.axis)
|
||||
val = 1 ## cfg.define_knob(f"{prefix}V{i}", [1, 2, 4] if not attrs.backend.startswith('c-hlsl_') else [1])
|
||||
fused_o, fused_i = s[load].split(fused_o, factor=val)
|
||||
s[load].vectorize(fused_i)
|
||||
fused_o, fused_i = s[load].split(fused_o, factor=num_threads)
|
||||
s[load].bind(fused_i, te.thread_axis("threadIdx.x"))
|
||||
|
||||
# unroll
|
||||
unroll_step = cfg.define_knob(f"{prefix}S", [1, 4, 32, 512])
|
||||
|
|
Загрузка…
Ссылка в новой задаче