Fused Cross Entropy Loss #1601

Open · winglian wants to merge 8 commits into base: main
Conversation

@winglian (Collaborator) commented May 7, 2024

pytorch/pytorch#124480 served as the impetus for getting this integrated.

Adapted the code from @mgmalek's https://github.com/mgmalek/efficient_cross_entropy/blob/main/modules.py so that it properly handles causal/next-token prediction.
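For reference, here is a minimal sketch (illustrative only, not this PR's actual code) of what "handling causal/next-token prediction" means before the fused projection + cross-entropy: the hidden state at position t is scored against the label at position t+1. The function and variable names below are assumptions for the sketch.

```python
# Minimal sketch (not the PR's implementation) of the causal shift applied
# before a fused lm_head projection + cross-entropy loss.
import torch
import torch.nn.functional as F

def shifted_projection_cross_entropy(hidden, proj_weight, labels, ignore_index=-100):
    # hidden: (B, T, D) final hidden states; proj_weight: (V, D) lm_head weight
    shift_hidden = hidden[:, :-1, :].reshape(-1, hidden.size(-1))  # predict token t+1 from position t
    shift_labels = labels[:, 1:].reshape(-1)                       # labels shifted left by one
    logits = shift_hidden @ proj_weight.t()                        # (B*(T-1), V)
    return F.cross_entropy(logits.float(), shift_labels, ignore_index=ignore_index)
```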

@YouJiacheng

mgmalek's implementation might have numerical issues when using half precision dtype.
The gradient w.r.t. the weight is accumulated through chunks with a half precision accumulator.
Some users have reported loss curve mismatch on Twitter.
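To illustrate the concern with a toy sketch (not the actual Triton kernel; shapes and chunk count are arbitrary): accumulating the per-chunk partial products of the weight gradient into a bf16 buffer rounds after every addition, whereas accumulating in fp32 behaves like a single GEMM's accumulator.

```python
# Toy illustration of bf16 chunk-accumulation drift for grad_proj_weight.
import torch

torch.manual_seed(0)
T, D, V, n_chunks = 4096, 256, 512, 8
hidden = torch.randn(T, D)        # hidden states (fp32 reference copy)
dlogits = torch.randn(T, V)       # upstream gradient w.r.t. logits

ref = hidden.t() @ dlogits        # fp32 reference: one GEMM over the full token axis

acc_bf16 = torch.zeros(D, V, dtype=torch.bfloat16)  # bf16 accumulator (the problem)
acc_fp32 = torch.zeros(D, V)                         # fp32 accumulator (the fix)
for h, g in zip(hidden.chunk(n_chunks), dlogits.chunk(n_chunks)):
    partial = h.bfloat16().t() @ g.bfloat16()        # bf16 partial product per chunk
    acc_bf16 += partial                              # rounded to bf16 on every add
    acc_fp32 += partial.float()                      # accumulated in fp32

print("bf16 accumulator max error:", (acc_bf16.float() - ref).abs().max().item())
print("fp32 accumulator max error:", (acc_fp32 - ref).abs().max().item())
```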

@winglian (Collaborator, Author)

> mgmalek's implementation might have numerical issues when using half precision dtype.
> The gradient w.r.t. the weight is accumulated through chunks with a half precision accumulator.
> Some users have reported loss curve mismatch on Twitter.

Haha, yeah, that was probably me reporting the loss curve differences on Twitter. Do you have any insights on the best way to fix the kernels? Standard CEL uses bfloat16 as well, IIRC.

@YouJiacheng commented May 10, 2024

The problem is the lm_head weight (proj_weight), not CEL. By default, a bf16 matmul uses an fp32 accumulator (along the GEMM K axis). Since the matmul is computed tile by tile over the M and N axes (an M×K @ K×N matmul), registers serve as the accumulator, with no need for an fp32 global-memory buffer.

However, here the chunking is along the GEMM K axis (a D×T @ T×V matmul), so each chunk's partial product has to be accumulated into a global-memory buffer between chunks…

A simple fix would be to use an fp32 grad_proj_weight, but that incurs a significant memory overhead.

Can you check the loss curve with an fp32 grad_proj_weight?
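For concreteness, a hedged sketch of the suggested fix (names are illustrative, not the PR's identifiers): keep the chunk-accumulated weight gradient in an fp32 buffer and cast back to the parameter dtype once at the end.

```python
# Sketch of accumulating grad_proj_weight in fp32 across chunks along the
# token (GEMM K) axis; illustrative only, not the PR's backward pass.
import torch

def chunked_grad_proj_weight(hidden, grad_logits, n_chunks, out_dtype=torch.bfloat16):
    # hidden: (T, D), grad_logits: (T, V), both typically bf16
    D, V = hidden.size(1), grad_logits.size(1)
    grad_w = torch.zeros(V, D, dtype=torch.float32, device=hidden.device)  # fp32 buffer
    for h, g in zip(hidden.chunk(n_chunks), grad_logits.chunk(n_chunks)):
        # each chunk's (V, D) partial product is added into the fp32 buffer,
        # mimicking the fp32 accumulator a single fused GEMM would use
        grad_w += g.t().float() @ h.float()
    return grad_w.to(out_dtype)  # cast back to bf16 only once, at the end
```

The memory cost is the one mentioned above: with illustrative LLaMA-like sizes (D = 4096, V = 32,000), the fp32 buffer is roughly 0.5 GB versus roughly 0.25 GB for a bf16 one.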

@winglian (Collaborator, Author)

@YouJiacheng
[Screenshot: loss-curve comparison, 2024-05-10]

Gave your recommendation a try, but changing grad_proj_weight to fp32 didn't seem to help much.

@YouJiacheng

Uh oh, now we need more investigation…

@YouJiacheng

It seems that the fused version has a different loss & grad_norm from the first step? That looks strange…
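If it helps with debugging, one quick sanity check (a hedged sketch with made-up sizes; the fused-path call is left as a placeholder rather than guessing this PR's module name) is to compare the loss and input-gradient norm of a plain bf16 lm_head + F.cross_entropy against the fused path on identical inputs, before any optimizer step:

```python
# First-step comparison sketch: reference (unfused) loss and grad norm.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, D, V = 512, 256, 4096
hidden = torch.randn(T, D, dtype=torch.bfloat16, requires_grad=True)
weight = (torch.randn(V, D) * 0.02).bfloat16()
labels = torch.randint(0, V, (T,))

logits = hidden @ weight.t()
loss = F.cross_entropy(logits.float(), labels)
loss.backward()
print("reference loss:", loss.item(), "| grad norm:", hidden.grad.float().norm().item())
# ...then run the fused module on the same hidden/weight/labels and compare both numbers.
```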
