Skip to content

Commit

Permalink
Check partial conversion on FP16 to FP32 AVX Cast kernel (#22091)
Browse files Browse the repository at this point in the history
### Description
Added checks to convert partial vectors in the early stages of the FP16
to FP32 cast using AVX NE CONVERT ISA.



### Motivation and Context
Avoid storing data in sections outside of the output buffer, these
checks are missing on the [original
PR](#21183).
This fix prevents memory corruption when the output buffer has a size
[n*16 + 1, n*16 + 7] with 0< n
  • Loading branch information
eralmual authored Sep 16, 2024
1 parent 1a1669f commit e93f14e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b

LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT

test r8, r8 ; Check if we have any elements to convert
test r8, r8 ; Check if we have any elements to convert
jz ExitRoutine
cmp r8, 8
jb ConvertMaskedVectors
Expand All @@ -80,6 +80,8 @@ Convert256Vectors:
jz ExitRoutine ; If we are done, exit
cmp r8, 16 ; If the vector is big enough, we go again
jae Convert256Vectors
cmp r8, 8 ; Check if we have enough elements to convert
jb ConvertMaskedVectors



Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx

test rdx, rdx // Check if we have any elements to convert
jz ExitRoutine

AVX_NE_CONVERT:
cmp rdx, 8
jb ConvertMaskedVectors
cmp rdx, 16
Expand All @@ -75,6 +73,8 @@ Convert256Vectors:
jz ExitRoutine // If we are done, exit
cmp rdx, 16 // If the vector is big enough, we go again
jae Convert256Vectors
cmp rdx, 8 // Check if we have enough elements to convert
jb ConvertMaskedVectors



Expand Down

0 comments on commit e93f14e

Please sign in to comment.