Check partial conversion on FP16 to FP32 AVX Cast kernel (#22091)

### Description
Added checks to convert partial vectors in the early stages of the FP16
to FP32 cast using AVX NE CONVERT ISA.



### Motivation and Context
Avoid storing data in sections outside of the output buffer, these
checks are missing on the [original
PR](https://github.com/microsoft/onnxruntime/pull/21183).
This fix prevents memory corruption when the output buffer has a size
[n*16 + 1, n*16 + 7] with 0< n
This commit is contained in:
Erick Muñoz 2024-09-16 10:20:06 -06:00 коммит произвёл GitHub
Родитель 1a1669fe81
Коммит e93f14e00d
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
2 изменённых файлов: 5 добавлений и 3 удалений

Просмотреть файл

@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b
LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT
test r8, r8 ; Check if we have any elements to convert test r8, r8 ; Check if we have any elements to convert
jz ExitRoutine jz ExitRoutine
cmp r8, 8 cmp r8, 8
jb ConvertMaskedVectors jb ConvertMaskedVectors
@ -80,6 +80,8 @@ Convert256Vectors:
jz ExitRoutine ; If we are done, exit jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors jae Convert256Vectors
cmp r8, 8 ; Check if we have enough elements to convert
jb ConvertMaskedVectors

Просмотреть файл

@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx
test rdx, rdx // Check if we have any elements to convert test rdx, rdx // Check if we have any elements to convert
jz ExitRoutine jz ExitRoutine
AVX_NE_CONVERT:
cmp rdx, 8 cmp rdx, 8
jb ConvertMaskedVectors jb ConvertMaskedVectors
cmp rdx, 16 cmp rdx, 16
@ -75,6 +73,8 @@ Convert256Vectors:
jz ExitRoutine // If we are done, exit jz ExitRoutine // If we are done, exit
cmp rdx, 16 // If the vector is big enough, we go again cmp rdx, 16 // If the vector is big enough, we go again
jae Convert256Vectors jae Convert256Vectors
cmp rdx, 8 // Check if we have enough elements to convert
jb ConvertMaskedVectors