fix a sharding bug in IPU backend (#251)

This commit is contained in:
ghostplant 2021-04-25 22:54:22 +00:00 коммит произвёл GitHub
Родитель f10f2e52ab
Коммит ed7f3ad849
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 28 добавлений и 19 удалений

Просмотреть файл

@ -16,19 +16,13 @@ def schedule(attrs):
for i in range(len(output.op.axis)):
ax_name = 'axis_%d' % i
cfg.define_split(ax_name, attrs.get_extent(output.op.axis[i]), num_outputs=2)
num_cores, align_width = 1216, 64
cfg.define_knob('start_core', [x * align_width for x in range(num_cores // align_width)])
# num_cores, align_width = 1216, 64
# cfg.define_knob('start_core', [x * align_width for x in range(num_cores // align_width)])
return
loop_axes = []
for i in range(len(output.op.axis)):
lo, li = s[output].split(output.op.axis[i], nparts=1)
if i == 0:
s[output].bind(lo, te.thread_axis('blockIdx.x'))
else:
li = s[output].fuse(lo, li)
s[output].bind(li, te.thread_axis('vthread'))
loop_axes.append(li)
la = [s[output].split(x, nparts=1) for x in output.op.axis]
s[output].bind(la[0][0], te.thread_axis('blockIdx.x'))
la = [ax[1] for ax in la]
s[output].reorder(*reversed(loop_axes))
la = la + [x for x in output.op.reduce_axis]
s[output].reorder(*la)

Просмотреть файл

@ -143,3 +143,11 @@ def run_pass_v2(ast_seq, global_input_dict, global_output_dict):
ast['props']['shard']['local_shape'] = [x['range'] for x in ast['props']['data_axes']]
with open(local_get_dir_file('range_book.json'), 'w') as fp:
json.dump(ast['props']['shard'], fp)
for k in global_input_dict:
if k in ast['props']['input_dict']:
global_input_dict[k] = ast['props']['input_dict'][k]
assert len(global_output_dict) == 1
for k in global_output_dict:
global_output_dict[k]['shape'] = [x['range'] for x in ast['props']['data_axes']]
break

Просмотреть файл

@ -15,8 +15,10 @@ sess = onnxruntime.InferenceSession(model_path)
space_input, space_output = {}, {}
for it in sess.get_inputs():
print('input:', it)
space_input[it.name] = np.array([1.0] * np.product(it.shape), dtype=np.float32)
for it in sess.get_outputs():
print('output:', it)
space_output[it.name] = popart.AnchorReturnType("ALL")
if 'PROF' in os.environ:
@ -40,11 +42,16 @@ stepio = popart.PyStepIO(space_input, anchors)
session.run(stepio)
import time
t1 = time.time()
step = 100
for i in range(step):
session.run(stepio)
t2 = time.time()
def run(step):
t1 = time.time()
for i in range(step):
session.run(stepio)
t2 = time.time()
return (t2 - t1) / step
print("=>", anchors)
print('Time:', (t2 - t1) / step)
print('Time:', run(1))
print('Time:', run(1))
print('Time:', run(10))
print('Time:', run(100))