Description
Problem Description
PYTORCH_NO_HIP_MEMORY_CACHING=1 HSA_SVM_GUARD_PAGES=1 HSA_DISABLE_FRAGMENT_ALLOCATOR=1 AMD_SERIALIZE_KERNEL=3 python triton-issue-604/test_backward.py
This triggers a segfault on Triton commit 00e09cf3008b86978f25f838659698e4a0bf6f45.
If you uncomment the tl.device_print calls between lines 154-157 of fwd_kernel.py and run
PYTORCH_NO_HIP_MEMORY_CACHING=1 HSA_SVM_GUARD_PAGES=1 HSA_DISABLE_FRAGMENT_ALLOCATOR=1 AMD_SERIALIZE_KERNEL=3 python triton-issue-604/test_backward.py 1>1 2>2
grep off_h_k 1 | sort -u
you get the following unexpected output:
pid (0, 7, 3) idx () off_h_k: 8
pid (0, 7, 4) idx () off_h_k: 8
pid (0, 7, 5) idx () off_h_k: 8
pid (0, 7, 6) idx () off_h_k: 8
pid (0, 7, 7) idx () off_h_k: 8
In the performance kernel, off_h_k should be identical to the second component of the PID, but the printed value is one larger (8 instead of 7).
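Instead of eyeballing the `grep | sort -u` output, the captured stdout can be scanned for every line where off_h_k differs from the second PID component. The sketch below is a hypothetical helper (`find_mismatches` is not part of the repro); it assumes the print format is exactly the one shown in the output above and that, as stated in this report, off_h_k is expected to equal pid[1].

```python
import re

# Matches lines such as "pid (0, 7, 3) idx () off_h_k: 8", the format shown
# in the output above. Assumption: the format is stable across all prints.
LINE_RE = re.compile(r"pid \((\d+), (\d+), (\d+)\) idx \(\) off_h_k: (-?\d+)")


def find_mismatches(lines):
    """Return sorted (pid, off_h_k) pairs where off_h_k != pid[1]."""
    mismatches = set()  # a set mirrors the deduplication done by `sort -u`
    for line in lines:
        m = LINE_RE.search(line)
        if m is None:
            continue  # skip lines that are not off_h_k prints
        pid = tuple(int(g) for g in m.groups()[:3])
        off_h_k = int(m.group(4))
        if off_h_k != pid[1]:
            mismatches.add((pid, off_h_k))
    return sorted(mismatches)


# The five lines reported above, plus one unrelated line that must be ignored.
sample = [
    "pid (0, 7, 3) idx () off_h_k: 8",
    "pid (0, 7, 4) idx () off_h_k: 8",
    "pid (0, 7, 5) idx () off_h_k: 8",
    "pid (0, 7, 6) idx () off_h_k: 8",
    "pid (0, 7, 7) idx () off_h_k: 8",
    "unrelated log line",
]
for pid, value in find_mismatches(sample):
    print(f"pid {pid}: off_h_k = {value}, expected {pid[1]}")
```

On the real run, the same function can be applied to the redirected stdout file, e.g. `with open("1") as f: find_mismatches(f)`.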
No code in the Triton kernel performs such an increment. Neither branch of the MQA/GQA logic
if group_size != 1:
    off_h_k = off_h_q // group_size
else:
    off_h_k = off_h_q
can increase off_h_k from 7 to 8.
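To make the argument concrete, the MQA/GQA branch from the kernel snippet can be evaluated in plain Python. This is a sketch under two assumptions not stated in the kernel excerpt: off_h_q is the second PID component (7 in the output), and group_size is a positive integer. With non-negative inputs, floor division never yields a value larger than off_h_q, so neither branch can produce 8.

```python
def off_h_k_for(off_h_q: int, group_size: int) -> int:
    # Same branch as the kernel snippet above, evaluated on host integers.
    if group_size != 1:
        return off_h_q // group_size
    return off_h_q


# Illustrative group sizes (hypothetical; the test's actual value is not given here).
results = {g: off_h_k_for(7, g) for g in (1, 2, 4, 7, 8)}
print(results)  # {1: 7, 2: 3, 4: 1, 7: 1, 8: 0}

# For any non-negative head index and positive group size, off_h_k <= off_h_q.
assert all(off_h_k_for(q, g) <= q for q in range(64) for g in range(1, 17))
```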
This points either to a compiler bug or to tl.device_print printing incorrect values.
Operating System
Ubuntu 20.04.6 LTS
CPU
AMD EPYC 7542
GPU
AMD Instinct MI210
ROCm Version
ROCm 6.1.0
ROCm Component
No response
Steps to Reproduce
See Description.
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response