update GPU policy for non-GPU devices

ghostplant 2023-07-17 12:20:35 +08:00
Parent 4ab9d09d59
Commit f50649c15f
1 changed file with 16 additions and 15 deletions


@@ -60,7 +60,7 @@ def schedule_branch(attrs, output, prefix):
     local_slices = [list(cfg.apply_split(s, output_local, output_local.op.axis[i], [-1, 1] + data_sizes[i][1:])) for i in range(len(output_local.op.axis))]
     zero, first, second, third, fourth = [x[0] for x in local_slices], [x[1] for x in local_slices], [x[2] for x in local_slices], [x[3] for x in local_slices], [x[4] for x in local_slices]
-    s[output_local].reorder(*(zero + first + second + [output_local_rv_o_o, output_local_rv_o_i] + third + [output_local_rv_i] + fourth))
+    s[output_local].reorder(*(zero + first + second + [output_local_rv_o_o,] + third + [output_local_rv_o_i,] + fourth + [output_local_rv_i]))
     data_slices = [list(cfg.apply_split(s, output, output.op.axis[i], data_sizes[i])) for i in range(len(output.op.axis))]
@@ -73,21 +73,22 @@ def schedule_branch(attrs, output, prefix):
     s[output].bind(s[output].fuse(*second), te.thread_axis("vthread"))
     s[output].bind(s[output].fuse(*third), te.thread_axis("threadIdx.x"))
-    load_stage = []
-    for i, load in enumerate(input_tensors):
-        use_align = cfg.define_knob(f"{prefix}AL{i}", [0, 1], init_vals=[0])
-        load_stage.append(s.cache_read(load, 'shared', [output_local]))
-        if use_align:
-            s[load_stage[-1]].storage_align(load_stage[-1].op.axis[0], 2, 1)
-        s[load_stage[-1]].compute_at(s[output_local], output_local_rv_o_o)
+    if '_intel' not in attrs.backend:
+        load_stage = []
+        for i, load in enumerate(input_tensors):
+            use_align = cfg.define_knob(f"{prefix}AL{i}", [0, 1], init_vals=[0])
+            load_stage.append(s.cache_read(load, 'shared', [output_local]))
+            if use_align:
+                s[load_stage[-1]].storage_align(load_stage[-1].op.axis[0], 2, 1)
+            s[load_stage[-1]].compute_at(s[output_local], output_local_rv_o_o)
-    for i, load in enumerate(load_stage):
-        fused_o = s[load].fuse(*s[load].op.axis)
-        val = 1 ## cfg.define_knob(f"{prefix}V{i}", [1, 2, 4] if not attrs.backend.startswith('c-hlsl_') else [1])
-        fused_o, fused_i = s[load].split(fused_o, factor=val)
-        s[load].vectorize(fused_i)
-        fused_o, fused_i = s[load].split(fused_o, factor=num_threads)
-        s[load].bind(fused_i, te.thread_axis("threadIdx.x"))
+        for i, load in enumerate(load_stage):
+            fused_o = s[load].fuse(*s[load].op.axis)
+            val = 1 ## cfg.define_knob(f"{prefix}V{i}", [1, 2, 4] if not attrs.backend.startswith('c-hlsl_') else [1])
+            fused_o, fused_i = s[load].split(fused_o, factor=val)
+            s[load].vectorize(fused_i)
+            fused_o, fused_i = s[load].split(fused_o, factor=num_threads)
+            s[load].bind(fused_i, te.thread_axis("threadIdx.x"))
     # unroll
     unroll_step = cfg.define_knob(f"{prefix}S", [1, 4, 32, 512])
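
The net effect of the second hunk is that shared-memory staging of the inputs (cache_read into 'shared', the optional storage_align, compute_at, and the vectorized cooperative load) now runs only when '_intel' does not appear in attrs.backend, so Intel non-GPU backends fall back to reading operands directly from global buffers. Below is a minimal, hypothetical sketch of that gating pattern written against plain TVM TE on a toy matmul; the backend string, tensor shapes, and tile factors are assumptions for illustration and do not come from this repository's cfg/knob machinery.

# Minimal sketch: gate shared-memory staging on the backend (assumed names throughout).
import tvm
from tvm import te

backend = 'c-cuda'      # hypothetical backend id; an '_intel' backend would skip the staging branch
n = 512
A = te.placeholder((n, n), name='A')
B = te.placeholder((n, n), name='B')
k = te.reduce_axis((0, n), name='k')
C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name='C')

s = te.create_schedule(C.op)
CL = s.cache_write(C, 'local')                 # accumulate in a register/local stage
i, j = s[C].op.axis
io, ii = s[C].split(i, factor=16)
jo, ji = s[C].split(j, factor=16)
s[C].bind(io, te.thread_axis('blockIdx.x'))
s[C].bind(jo, te.thread_axis('blockIdx.y'))
s[C].bind(ii, te.thread_axis('threadIdx.x'))
s[C].bind(ji, te.thread_axis('threadIdx.y'))
s[CL].compute_at(s[C], ji)
ko, ki = s[CL].split(s[CL].op.reduce_axis[0], factor=8)

if '_intel' not in backend:
    # Shared-memory staging only helps GPU-like targets, so gate it on the backend,
    # mirroring the condition this commit adds around the load_stage block.
    for inp in (A, B):
        stage = s.cache_read(inp, 'shared', [CL])
        s[stage].compute_at(s[CL], ko)

print(tvm.lower(s, [A, B, C], simple_mode=True))

With a backend string containing '_intel', the branch above is skipped and the lowered IR indexes A and B directly instead of going through shared-memory stages, which is the behavior this commit selects for non-GPU devices.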