зеркало из https://github.com/microsoft/antares.git
fix a sharding bug in IPU backend (#251)
This commit is contained in:
Родитель
f10f2e52ab
Коммит
ed7f3ad849
|
@ -16,19 +16,13 @@ def schedule(attrs):
|
|||
for i in range(len(output.op.axis)):
|
||||
ax_name = 'axis_%d' % i
|
||||
cfg.define_split(ax_name, attrs.get_extent(output.op.axis[i]), num_outputs=2)
|
||||
|
||||
num_cores, align_width = 1216, 64
|
||||
cfg.define_knob('start_core', [x * align_width for x in range(num_cores // align_width)])
|
||||
# num_cores, align_width = 1216, 64
|
||||
# cfg.define_knob('start_core', [x * align_width for x in range(num_cores // align_width)])
|
||||
return
|
||||
|
||||
loop_axes = []
|
||||
for i in range(len(output.op.axis)):
|
||||
lo, li = s[output].split(output.op.axis[i], nparts=1)
|
||||
if i == 0:
|
||||
s[output].bind(lo, te.thread_axis('blockIdx.x'))
|
||||
else:
|
||||
li = s[output].fuse(lo, li)
|
||||
s[output].bind(li, te.thread_axis('vthread'))
|
||||
loop_axes.append(li)
|
||||
la = [s[output].split(x, nparts=1) for x in output.op.axis]
|
||||
s[output].bind(la[0][0], te.thread_axis('blockIdx.x'))
|
||||
la = [ax[1] for ax in la]
|
||||
|
||||
s[output].reorder(*reversed(loop_axes))
|
||||
la = la + [x for x in output.op.reduce_axis]
|
||||
s[output].reorder(*la)
|
||||
|
|
|
@ -143,3 +143,11 @@ def run_pass_v2(ast_seq, global_input_dict, global_output_dict):
|
|||
ast['props']['shard']['local_shape'] = [x['range'] for x in ast['props']['data_axes']]
|
||||
with open(local_get_dir_file('range_book.json'), 'w') as fp:
|
||||
json.dump(ast['props']['shard'], fp)
|
||||
for k in global_input_dict:
|
||||
if k in ast['props']['input_dict']:
|
||||
global_input_dict[k] = ast['props']['input_dict'][k]
|
||||
|
||||
assert len(global_output_dict) == 1
|
||||
for k in global_output_dict:
|
||||
global_output_dict[k]['shape'] = [x['range'] for x in ast['props']['data_axes']]
|
||||
break
|
||||
|
|
|
@ -15,8 +15,10 @@ sess = onnxruntime.InferenceSession(model_path)
|
|||
|
||||
space_input, space_output = {}, {}
|
||||
for it in sess.get_inputs():
|
||||
print('input:', it)
|
||||
space_input[it.name] = np.array([1.0] * np.product(it.shape), dtype=np.float32)
|
||||
for it in sess.get_outputs():
|
||||
print('output:', it)
|
||||
space_output[it.name] = popart.AnchorReturnType("ALL")
|
||||
|
||||
if 'PROF' in os.environ:
|
||||
|
@ -40,11 +42,16 @@ stepio = popart.PyStepIO(space_input, anchors)
|
|||
|
||||
session.run(stepio)
|
||||
import time
|
||||
t1 = time.time()
|
||||
step = 100
|
||||
for i in range(step):
|
||||
|
||||
def run(step):
|
||||
t1 = time.time()
|
||||
for i in range(step):
|
||||
session.run(stepio)
|
||||
t2 = time.time()
|
||||
t2 = time.time()
|
||||
return (t2 - t1) / step
|
||||
|
||||
print("=>", anchors)
|
||||
print('Time:', (t2 - t1) / step)
|
||||
print('Time:', run(1))
|
||||
print('Time:', run(1))
|
||||
print('Time:', run(10))
|
||||
print('Time:', run(100))
|
||||
|
|
Загрузка…
Ссылка в новой задаче