antares/frameworks/pytorch/examples/5_transformer.py

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
from antares_core.frameworks.pytorch.custom_op import CustomOp
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float32
kwargs = {'dtype': dtype,
          'device': device,
          'requires_grad': False}
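
# Model dimensions: B = batch, S = sequence length, N = attention heads,
# H = hidden size per head, I = feed-forward (intermediate) hidden size.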
B, S, N, H, I = 6, 128, 12, 48, 1024

def create_param(name, shape):
    # Small uniform initialization in [-0.0005, 0.0005); `name` is unused here
    # and kept only to mirror a typical parameter-registration signature.
    return (torch.rand(shape, **kwargs) - 0.5) * 0.001

input_tensor = torch.ones([B, S, N, H], **kwargs)
qkv_weight = create_param('qkv_weight', [3, N, H, N, H])
qkv_bias = create_param('qkv_bias', [3, N, H])
attention_weight = create_param('attention_weight', [N, H, N, H])
attention_bias = create_param('attention_bias', [N, H])
intermediate_weight = create_param('intermediate_weight', [N, H, I])
intermediate_bias = create_param('intermediate_bias', [I])
output_weight = create_param('output_weight', [I, N, H])
output_bias = create_param('output_bias', [N, H])
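
# A single fused BERT-style encoder layer, expressed in Antares IR:
# QKV projection -> scaled dot-product attention (row-wise softmax) ->
# output projection + residual + layer norm -> GELU feed-forward ->
# second residual + layer norm.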
layer_output_norm = CustomOp(ir=f'''
merged_layer_local[R, B, S1, N1, H1] +=! input_tensor[B, S1, N, H] * qkv_weight[R, N, H, N1, H1];
merged_layer_trans[R, B, N1, S1, H1] = merged_layer_local[R, B, S1, N1, H1] + qkv_bias[R, N1, H1];
attention_scores[B, N1, S1, S2] +=! merged_layer_trans[0, B, N1, S1, H1] * merged_layer_trans[1, B, N1, S2, H1] / const({H ** 0.5}).cast(`float32`);
softmax_1_temp0[B, N1, S1] >=! attention_scores[B, N1, S1, S2];
softmax_1_temp1[B, N1, S1] +=! (attention_scores[B, N1, S1, S2] - softmax_1_temp0[B, N1, S1]).call(`exp`);
attention_probs[B, N1, S1, S2] = (attention_scores[B, N1, S1, S2] - softmax_1_temp0[B, N1, S1]).call(`exp`) / softmax_1_temp1[B, N1, S1];
context_layer_trans[B, S1, N1, H1] +=! attention_probs[B, N1, S1, S2] * merged_layer_trans[2, B, N1, S2, H1];
attention_local[B, S1, N2, H2] +=! context_layer_trans[B, S1, N1, H1] * attention_weight[N1, H1, N2, H2];
attention_output[B, S1, N2, H2] = attention_local[B, S1, N2, H2] + attention_bias[N2, H2];
layer_norm_1_src[B, S1, N2, H2] = attention_output[B, S1, N2, H2] + input_tensor[B, S1, N2, H2];
layer_norm_1_temp0[B, S1] +=! layer_norm_1_src[B, S1, N2, H2];
layer_norm_1_temp1[B, S1] +=! layer_norm_1_src[B, S1, N2, H2] * layer_norm_1_src[B, S1, N2, H2];
attention_output_norm[B, S1, N2, H2] = (layer_norm_1_src[B, S1, N2, H2] * {N * H} - layer_norm_1_temp0[B, S1]) * (layer_norm_1_temp1[B, S1] * {N * H} - layer_norm_1_temp0[B, S1] * layer_norm_1_temp0[B, S1]).call(`max`, [1e-8]).call(`rsqrt`);
intermediate_local[B, S1, I] +=! attention_output_norm[B, S1, N2, H2] * intermediate_weight[N2, H2, I];
intermediate[B, S1, I] = intermediate_local[B, S1, I] + intermediate_bias[I];
intermediate_gelu[B, S1, I] = 0.5 * intermediate[B, S1, I] * (1.0 + (0.79788456 * (intermediate[B, S1, I] + 0.044715 * intermediate[B, S1, I] * intermediate[B, S1, I] * intermediate[B, S1, I])).call(`tanh`));
layer_output_local[B, S1, N2, H2] +=! intermediate_gelu[B, S1, I] * output_weight[I, N2, H2];
layer_output[B, S1, N2, H2] = layer_output_local[B, S1, N2, H2] + output_bias[N2, H2];
layer_norm_2_src[B, S1, N2, H2] = layer_output[B, S1, N2, H2] + attention_output_norm[B, S1, N2, H2];
layer_norm_2_temp0[B, S1] +=! layer_norm_2_src[B, S1, N2, H2];
layer_norm_2_temp1[B, S1] +=! layer_norm_2_src[B, S1, N2, H2] * layer_norm_2_src[B, S1, N2, H2];
layer_output_norm[B, S1, N2, H2] = (layer_norm_2_src[B, S1, N2, H2] * {N * H} - layer_norm_2_temp0[B, S1]) * (layer_norm_2_temp1[B, S1] * {N * H} - layer_norm_2_temp0[B, S1] * layer_norm_2_temp0[B, S1]).call(`max`, [1e-8]).call(`rsqrt`);
''', input_orders={
    'input_tensor': input_tensor,
    'qkv_weight': qkv_weight,
    'qkv_bias': qkv_bias,
    'attention_weight': attention_weight,
    'attention_bias': attention_bias,
    'intermediate_weight': intermediate_weight,
    'intermediate_bias': intermediate_bias,
    'output_weight': output_weight,
    'output_bias': output_bias,
}, device=device).emit()
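
# The emitted op is invoked with tensors in the same order as `input_orders`.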
result = layer_output_norm(input_tensor, qkv_weight, qkv_bias, attention_weight, attention_bias, intermediate_weight, intermediate_bias, output_weight, output_bias)
print('The result of tensor `%s` is:\n%s' % (layer_output_norm.output_names[0], result))
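
# --- Optional numerical cross-check: a minimal sketch, not part of the original
# example. The hypothetical `reference` helper below re-implements the fused IR
# stage by stage with eager PyTorch ops (same scaling, tanh-approximated GELU,
# and eps clamp), assuming the custom op returns a single [B, S, N, H] tensor.
def reference(x):
    # QKV projection per head, then transpose to [3, B, N, S, H]
    merged = torch.einsum('bsnh,rnhmk->rbsmk', x, qkv_weight) + qkv_bias[:, None, None]
    q, k, v = merged.permute(0, 1, 3, 2, 4)
    # Scaled dot-product attention with a row-wise softmax over the key axis
    scores = torch.einsum('bnsh,bnth->bnst', q, k) / H ** 0.5
    context = torch.einsum('bnst,bnth->bsnh', torch.softmax(scores, dim=-1), v)
    attn = torch.einsum('bsnh,nhmk->bsmk', context, attention_weight) + attention_bias

    def layer_norm(src):
        # Mirrors the IR: normalize over the flattened [N, H] feature axes
        t0 = src.sum(dim=(-2, -1), keepdim=True)
        t1 = (src * src).sum(dim=(-2, -1), keepdim=True)
        return (src * (N * H) - t0) * torch.rsqrt((t1 * (N * H) - t0 * t0).clamp_min(1e-8))

    attn_norm = layer_norm(attn + x)
    # Feed-forward block with the tanh approximation of GELU
    inter = torch.einsum('bsnh,nhi->bsi', attn_norm, intermediate_weight) + intermediate_bias
    gelu = 0.5 * inter * (1.0 + torch.tanh(0.79788456 * (inter + 0.044715 * inter ** 3)))
    out = torch.einsum('bsi,inh->bsnh', gelu, output_weight) + output_bias
    return layer_norm(out + attn_norm)

print('Matches eager PyTorch reference:', torch.allclose(result, reference(input_tensor), atol=1e-3))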