update mnist and alexnet examples (#270)

This commit is contained in:
ghostplant 2021-06-21 03:33:31 +00:00 коммит произвёл GitHub
Родитель 03546f0e53
Коммит 0ccc7b8653
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
4 изменённых файлов: 28 добавлений и 16 удалений

Просмотреть файл

@ -27,9 +27,11 @@ b2 = create_param('dense_b2', [10])
custom_op = CustomOp(ir='''
data_0[N, M] +=! data[N, K] * weight_0[K, M];
data_1[N, K] = (data_0[N, K] + bias_0[K]).call(`max`, [0.0]);
data_0_bias[N, K] = data_0[N, K] + bias_0[K];
data_1[N, K] = data_0_bias[N, K].call(`max`, [0.0]);
data_2[N, M] +=! data_1[N, K] * weight_1[K, M];
data_3[N, K] = (data_2[N, K] + bias_1[K]).call(`max`, [0.0]);
data_2_bias[N, K] = data_2[N, K] + bias_1[K];
data_3[N, K] = data_2_bias[N, K].call(`max`, [0.0]);
data_4[N, M] +=! data_3[N, K] * weight_2[K, M];
data_5[N, K] = (data_4[N, K] + bias_2[K]);
''', feed_dict={'data': x, 'weight_0': w0, 'weight_1': w1, 'weight_2': w2, 'bias_0': b0, 'bias_1': b1, 'bias_2': b2}).to(device, dtype).tune(step=100, use_cache=True, timeout=600).emit()

Просмотреть файл

@ -20,17 +20,21 @@ def create_param(name, shape):
output_logits = CustomOp(ir=f'''
conv_0[N, F, HO, WO] +=! input_tensor[N, C, HO * 4 + KH, WO * 4 + KW] * const_0_[KH, KW, C, F] where HO in 55, WO in 55;
mpool_0[N, C, HO, WO ] >=! conv_0[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 27, WO in 27, KH in 3, KW in 3;
mpool_0[N, C, HO, WO] >=! conv_0[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 27, WO in 27, KH in 3, KW in 3;
conv_1[N, F, HO, WO] +=! mpool_0[N, C, -2 + HO + KH, -2 + WO + KW].when([-2 + HO + KH >= 0, -2 + HO + KH < 27, -2 + WO + KW >= 0, -2 + WO + KW < 27], 0.0) * const_1_[KH, KW, C, F] where HO in 27, WO in 27;
mpool_1[N, C, HO, WO ] >=! conv_1[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 13, WO in 13, KH in 3, KW in 3;
mpool_1[N, C, HO, WO] >=! conv_1[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 13, WO in 13, KH in 3, KW in 3;
conv_2[N, F, HO, WO] +=! mpool_1[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_2_[KH, KW, C, F] where HO in 13, WO in 13;
conv_3[N, F, HO, WO] +=! conv_2[N, C, -1 + HO + KH, -1 + WO + KW].call(`max`, [0.0]).when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_3_[KH, KW, C, F] where HO in 13, WO in 13;
conv_4[N, F, HO, WO] +=! conv_3[N, C, -1 + HO + KH, -1 + WO + KW].call(`max`, [0.0]).when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_4_[KH, KW, C, F] where HO in 13, WO in 13;
conv_2_relu[N, F, HO, WO] = conv_2[N, F, HO, WO].call(`max`, [0.0]);
conv_3[N, F, HO, WO] +=! conv_2_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_3_[KH, KW, C, F] where HO in 13, WO in 13;
conv_3_relu[N, F, HO, WO] = conv_3[N, F, HO, WO].call(`max`, [0.0]);
conv_4[N, F, HO, WO] +=! conv_3_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_4_[KH, KW, C, F] where HO in 13, WO in 13;
mpool_2[N, C, HO, WO] >=! conv_4[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 6, WO in 6, KH in 3, KW in 3;
reshape_0[N0, N1] = mpool_2[N0, N1 // 36 % 256, N1 // 6 % 6, N1 % 6] where N1 in 9216;
dense_0[N, M] +=! reshape_0[N, K] * const_5_[K, M];
dense_1[N, M] +=! dense_0[N, K].call(`max`, [0.0]) * const_6_[K, M];
dense_2[N, M] +=! dense_1[N, K].call(`max`, [0.0]) * const_7_[K, M];
dense_0_relu[N, M] = dense_0[N, M].call(`max`, [0.0]);
dense_1[N, M] +=! dense_0_relu[N, K] * const_6_[K, M];
dense_1_relu[N, M] = dense_1[N, M].call(`max`, [0.0]);
dense_2[N, M] +=! dense_1_relu[N, K] * const_7_[K, M];
''', feed_dict={
'input_tensor': input_tensor,
'const_0_': create_param('const_0_', [11, 11, 3, 64]),

Просмотреть файл

@ -33,9 +33,11 @@ tf_out = tf.add(tf.matmul(tf_out, w2), b2)
out = x
out = antares.make_op(ir='''
data_0[N, M] +=! data[N, K] * weight_0[K, M];
data_1[N, K] = (data_0[N, K] + bias_0[K]).call(`max`, [0.0]);
data_0_bias[N, K] = data_0[N, K] + bias_0[K];
data_1[N, K] = data_0_bias[N, K].call(`max`, [0.0]);
data_2[N, M] +=! data_1[N, K] * weight_1[K, M];
data_3[N, K] = (data_2[N, K] + bias_1[K]).call(`max`, [0.0]);
data_2_bias[N, K] = data_2[N, K] + bias_1[K];
data_3[N, K] = data_2_bias[N, K].call(`max`, [0.0]);
data_4[N, M] +=! data_3[N, K] * weight_2[K, M];
data_5[N, K] = (data_4[N, K] + bias_2[K]);
''', feed_dict={'data': x, 'weight_0': w0, 'weight_1': w1, 'weight_2': w2, 'bias_0': b0, 'bias_1': b1, 'bias_2': b2}).tune(step=200, use_cache=True, timeout=600).emit()

Просмотреть файл

@ -44,17 +44,21 @@ output_logits_tf = dense_2
output_logits = antares.make_op(ir=f'''
conv_0[N, F, HO, WO] +=! input_tensor[N, C, HO * 4 + KH, WO * 4 + KW] * const_0_[KH, KW, C, F] where HO in 55, WO in 55;
mpool_0[N, C, HO, WO ] >=! conv_0[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 27, WO in 27, KH in 3, KW in 3;
mpool_0[N, C, HO, WO] >=! conv_0[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 27, WO in 27, KH in 3, KW in 3;
conv_1[N, F, HO, WO] +=! mpool_0[N, C, -2 + HO + KH, -2 + WO + KW].when([-2 + HO + KH >= 0, -2 + HO + KH < 27, -2 + WO + KW >= 0, -2 + WO + KW < 27], 0.0) * const_1_[KH, KW, C, F] where HO in 27, WO in 27;
mpool_1[N, C, HO, WO ] >=! conv_1[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 13, WO in 13, KH in 3, KW in 3;
mpool_1[N, C, HO, WO] >=! conv_1[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 13, WO in 13, KH in 3, KW in 3;
conv_2[N, F, HO, WO] +=! mpool_1[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_2_[KH, KW, C, F] where HO in 13, WO in 13;
conv_3[N, F, HO, WO] +=! conv_2[N, C, -1 + HO + KH, -1 + WO + KW].call(`max`, [0.0]).when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_3_[KH, KW, C, F] where HO in 13, WO in 13;
conv_4[N, F, HO, WO] +=! conv_3[N, C, -1 + HO + KH, -1 + WO + KW].call(`max`, [0.0]).when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_4_[KH, KW, C, F] where HO in 13, WO in 13;
conv_2_relu[N, F, HO, WO] = conv_2[N, F, HO, WO].call(`max`, [0.0]);
conv_3[N, F, HO, WO] +=! conv_2_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_3_[KH, KW, C, F] where HO in 13, WO in 13;
conv_3_relu[N, F, HO, WO] = conv_3[N, F, HO, WO].call(`max`, [0.0]);
conv_4[N, F, HO, WO] +=! conv_3_relu[N, C, -1 + HO + KH, -1 + WO + KW].when([-1 + HO + KH >= 0, -1 + HO + KH < 13, -1 + WO + KW >= 0, -1 + WO + KW < 13], 0.0) * const_4_[KH, KW, C, F] where HO in 13, WO in 13;
mpool_2[N, C, HO, WO] >=! conv_4[N, C, HO * 2 + KH, WO * 2 + KW].call(`max`, [0.0]) where HO in 6, WO in 6, KH in 3, KW in 3;
reshape_0[N0, N1] = mpool_2[N0, N1 // 36 % 256, N1 // 6 % 6, N1 % 6] where N1 in 9216;
dense_0[N, M] +=! reshape_0[N, K] * const_5_[K, M];
dense_1[N, M] +=! dense_0[N, K].call(`max`, [0.0]) * const_6_[K, M];
dense_2[N, M] +=! dense_1[N, K].call(`max`, [0.0]) * const_7_[K, M];
dense_0_relu[N, M] = dense_0[N, M].call(`max`, [0.0]);
dense_1[N, M] +=! dense_0_relu[N, K] * const_6_[K, M];
dense_1_relu[N, M] = dense_1[N, M].call(`max`, [0.0]);
dense_2[N, M] +=! dense_1_relu[N, K] * const_7_[K, M];
''', feed_dict=feed_dict).emit()
config = tf.ConfigProto()