[Issue]: FA Performance forward Kernel Segfault under certain cases #604

@xinyazhang

Problem Description

triton-issue-604.tar.gz

Running

PYTORCH_NO_HIP_MEMORY_CACHING=1 HSA_SVM_GUARD_PAGES=1 HSA_DISABLE_FRAGMENT_ALLOCATOR=1 AMD_SERIALIZE_KERNEL=3 python triton-issue-604/test_backward.py

triggers a segfault on Triton commit 00e09cf3008b86978f25f838659698e4a0bf6f45.

If you uncomment the tl.device_print calls between lines 154 and 157 of fwd_kernel.py and run

PYTORCH_NO_HIP_MEMORY_CACHING=1 HSA_SVM_GUARD_PAGES=1 HSA_DISABLE_FRAGMENT_ALLOCATOR=1 AMD_SERIALIZE_KERNEL=3 python triton-issue-604/test_backward.py 1>1 2>2
grep off_h_k 1|sort -u

you get the following unexpected output:

pid (0, 7, 3) idx () off_h_k: 8
pid (0, 7, 4) idx () off_h_k: 8
pid (0, 7, 5) idx () off_h_k: 8
pid (0, 7, 6) idx () off_h_k: 8
pid (0, 7, 7) idx () off_h_k: 8
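
To scan a longer log for the same symptom, a small script along these lines may help. It is only a sketch: it assumes the tl.device_print line format shown above and the expectation stated in this issue that off_h_k should match the second component of the program ID. The names LINE_RE and find_mismatches are hypothetical helpers, not part of the attached reproducer.

```python
import re

# Matches lines such as "pid (0, 7, 3) idx () off_h_k: 8".
# The format is assumed from the output quoted in this issue.
LINE_RE = re.compile(r"pid \((\d+), (\d+), (\d+)\) idx \(\) off_h_k: (-?\d+)")


def find_mismatches(lines):
    """Return (pid, off_h_k) pairs where off_h_k differs from pid[1]."""
    mismatches = []
    for line in lines:
        m = LINE_RE.search(line)
        if m is None:
            continue  # not an off_h_k print line
        pid = (int(m.group(1)), int(m.group(2)), int(m.group(3)))
        off_h_k = int(m.group(4))
        if off_h_k != pid[1]:
            mismatches.append((pid, off_h_k))
    return mismatches


# Sample input copied from the output above.
sample = [
    "pid (0, 7, 3) idx () off_h_k: 8",
    "pid (0, 7, 4) idx () off_h_k: 8",
    "pid (0, 7, 5) idx () off_h_k: 8",
    "pid (0, 7, 6) idx () off_h_k: 8",
    "pid (0, 7, 7) idx () off_h_k: 8",
]

mismatches = find_mismatches(sample)
for pid, off_h_k in mismatches:
    print(f"pid {pid}: off_h_k={off_h_k}, expected {pid[1]}")
```

Since the command above redirects stdout to a file named 1, the same function can be applied to the full log with find_mismatches(open("1")).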

For the performance kernel, off_h_k should equal the second component of the program ID (7 in the output above), but the printed value is larger by one.
No code in the Triton kernel performs such a change. Neither branch of the MQA/GQA logic

if group_size != 1:
    off_h_k = off_h_q // group_size
else:
    off_h_k = off_h_q

should increase off_h_k from 7 to 8.

Either this is a compiler problem, or tl.device_print is broken.
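
The claim that neither branch can produce 8 can be checked with a host-side mirror of the branch logic. This is a minimal sketch, assuming off_h_q is 7 (taken from the second program ID component in the output above) and that group_size is a positive integer; expected_off_h_k is a hypothetical helper written for illustration, not code from the kernel.

```python
def expected_off_h_k(off_h_q: int, group_size: int) -> int:
    """Host-side mirror of the kernel's MQA/GQA branch."""
    if group_size != 1:
        # GQA/MQA: several query heads share one key/value head.
        return off_h_q // group_size
    # Plain multi-head attention: one key/value head per query head.
    return off_h_q


off_h_q = 7  # assumed: second component of pid in the reported output

# Evaluate the branch for a range of positive group sizes (assumed values).
results = {g: expected_off_h_k(off_h_q, g) for g in range(1, 9)}

for group_size, off_h_k in results.items():
    print(f"group_size={group_size}: off_h_k={off_h_k}")

# Floor division by a positive integer can never exceed the dividend.
assert all(value <= off_h_q for value in results.values())
```

Under these assumptions every result is at most 7, so a printed value of 8 cannot come from this branch alone.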

Operating System

Ubuntu 20.04.6 LTS

CPU

AMD EPYC 7542

GPU

AMD Instinct MI210

ROCm Version

ROCm 6.1.0

ROCm Component

No response

Steps to Reproduce

See Description.

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
