Mirror of https://github.com/microsoft/antares.git
custom matmul layout (#12)
This commit is contained in:
Parent: 9a71f46ee6
Commit: 4a4a027ff3
@@ -114,4 +114,4 @@ BACKEND=c-rocm COMPUTE_V1='- einstein_v2("output0[N, M] +=! input0[N, K].call(\"
 BACKEND=c-mcpu COMPUTE_V1='- einstein_v2("output0[N] = input0[N].call(\"fastadd\", [input1[N]])", input_dict={"input0": {"dtype": "avx256@256", "shape": [16]}, "input1": {"dtype": "avx256@256", "shape": [16]}}) ## @: plan/c-mcpu=blend.avx_add' make
 
 # [INTRINSIC SPEC] CUDA FP16 Tensorcore
 
-BACKEND=c-cuda COMPUTE_V1='- einstein_v2("output0[N, M] +=! input0[N, K].cast(\"float32\") * input1[K, M].cast(\"float32\")", { "input0": {"dtype": "float16", "shape": [1024, 1024]}, "input1": {"dtype": "float16", "shape": [1024, 1024]}}) ## @: plan/c-cuda=blend.matmul_fp16_tensorcore' make
+BACKEND=c-cuda COMPUTE_V1='- einstein_v2("output0[N, M] +=! input0[N, K].cast(\"float32\") * input1[K, M].cast(\"float32\")", { "input0": {"dtype": "float16", "shape": [1024, 1024]}, "input1": {"dtype": "float16", "shape": [1024, 1024]}}) ## @: plan/c-cuda=blend.matmul_fp16_tensorcore|layout=NN' make
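The only functional change above is the new |layout=NN suffix on the plan annotation. Judging from the option parser added in the schedule hunk below, Antares appears to hand each key=value plan suffix to the schedule as a key/value token in attrs.options. A minimal sketch of that assumed decomposition; the helper name is hypothetical, not Antares API:

# Hypothetical illustration of how a plan string such as
# "blend.matmul_fp16_tensorcore|layout=NN" could decompose into a schedule
# name plus option tokens like "layout/NN". This mapping is an assumption
# inferred from the opt.startswith('layout/') check in the next hunk.
def split_plan(plan):
    name, *suffixes = plan.split('|')
    return name, [s.replace('=', '/', 1) for s in suffixes]

name, options = split_plan('blend.matmul_fp16_tensorcore|layout=NN')
assert name == 'blend.matmul_fp16_tensorcore'
assert options == ['layout/NN']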
@@ -27,12 +27,16 @@ def schedule(attrs):
   offset = 8
 
+  layout = 'NN'
+  for opt in attrs.options:
+    if opt.startswith('layout/'):
+      layout = opt[len('layout/'):]
+      break
+
   '''
   if dtype == 'int8':
     factor = 32
     offset = 16
   '''
 
   # create cache stages
   AA = s.cache_read(A, "shared", [C])
+  if (layout == "NN" or layout == "TN"):
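The parser defaults to 'NN' and takes the first layout/ option it finds. The two letters follow the usual BLAS naming: the first says whether input0 is stored transposed, the second whether input1 is (N = as-is, T = transposed). A hedged NumPy reference of that reading, illustrative only and not Antares code:

import numpy as np

# N = operand stored as-is, T = operand stored transposed (BLAS-style);
# this mirrors the assumed meaning of the layout string, nothing more.
def matmul_with_layout(a, b, layout='NN'):
    a = a.T if layout[0] == 'T' else a
    b = b.T if layout[1] == 'T' else b
    return a @ b

a = np.random.rand(64, 32).astype(np.float16)
b = np.random.rand(32, 48).astype(np.float16)
a_t = a.T.copy()   # A kept in transposed storage
np.testing.assert_allclose(matmul_with_layout(a, b, 'NN'), a @ b, rtol=1e-3)
np.testing.assert_allclose(matmul_with_layout(a_t, b, 'TN'), a @ b, rtol=1e-3)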
@@ -55,10 +59,11 @@ def schedule(attrs):
   v = cfg['v'].val
 
   # thread tile
-  TX = 8
-  TY = 1
+  TX, TY = 8, 1
 
   # warp tile
-  warp_tile_m = 16 # it could also be 8 or 32 on CUDA version >= 10.0
+  cfg.define_knob("warp_m", [16, 8, 32])
+  warp_tile_m = cfg['warp_m'].val # it could be 8, 16, 32 on CUDA version >= 10.0
   warp_tile_k = 16 # it must be 16
   # block tile
   tile_x = bx * TX
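Replacing the hard-coded warp_tile_m with a knob lets the tuner sweep 16, 8, and 32; these presumably correspond to the three FP16 WMMA fragment shapes (m16n16k16, m8n32k16, m32n8k16) available since CUDA 10, which is also why warp_tile_k stays fixed at 16. A minimal stand-in for the define_knob / .val pattern (AutoTVM-style); the classes below are illustrative, not the real config object:

# Illustrative stand-in for the cfg.define_knob / cfg['warp_m'].val pattern.
class _Entity:
    def __init__(self, val):
        self.val = val

class _FakeConfig(dict):
    def define_knob(self, name, candidates):
        # a real tuner sweeps all candidates; this stub just fixes the first
        self[name] = _Entity(candidates[0])

cfg = _FakeConfig()
cfg.define_knob("warp_m", [16, 8, 32])
warp_tile_m = cfg['warp_m'].val   # 16 here; the tuner would also try 8 and 32
warp_tile_k = 16                  # WMMA FP16 fragments all use k == 16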
@@ -171,7 +171,7 @@ int main() {
     assert(0 == MAT_BLAS_FUNC(hCublas, rocblas_operation_transpose, rocblas_operation_transpose, __mat_M, __mat_N, __mat_K, (MAT_DATA_TYPE*)&alpha, (MAT_DATA_TYPE*)d_m[1], __mat_K, (MAT_DATA_TYPE*)d_m[0], __mat_N, (MAT_DATA_TYPE*)&beta, (MAT_DATA_TYPE*)d_m[2], __mat_M));
   } else
     assert(0);
-  }, __half{0}, [](__half val) -> double { return __half2float(val); });
+  }, __half(0.0f), [](__half val) -> double { return __half2float(val); });
 
 #ifdef __HIP_PLATFORM_HCC__
 #undef MAT_DATA_TYPE