Transformer sequence parallel forward #5560
Conversation
Review updated until commit 68e8a9a

Relevant files: Enhancement, Refactoring
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus area for review: Sequence Parallelism Implementation

!build
Greptile Overview

Greptile Summary

This PR adds sequence parallelism support to the transformer forward pass test, complementing the existing tensor parallelism. Benchmarks show nvFuser achieving 2.1 ms vs TransformerEngine's 2.5 ms on 8 H100 nodes.

Confidence Score: 5/5
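For orientation, here is a minimal sketch of the input handling the summary describes, following the flow in the sequence diagram below. `shard_tensor` here is a stand-in written for illustration (the real test helper likely takes a device mesh), and the sizes are made up:

```python
import torch

# Hypothetical stand-in for the test's shard_tensor helper: chunk along
# `dim` and keep only this rank's slice, moved to the GPU.
def shard_tensor(t: torch.Tensor, dim: int, rank: int, num_devices: int) -> torch.Tensor:
    return t.chunk(num_devices, dim=dim)[rank].cuda()

b, s, e = 2, 1024, 768      # illustrative sizes, not the test's
num_devices, rank = 8, 0    # single-rank view, for the sketch only
sequence_parallel = True

inp = torch.randn(b, s, e)  # the test builds the input on the CPU first

if sequence_parallel:
    # Each rank holds a [b, s/num_devices, e] shard of the sequence axis.
    ins = [shard_tensor(inp, dim=1, rank=rank, num_devices=num_devices)]
else:
    # Tensor parallelism: every rank holds the full [b, s, e] input.
    ins = [inp.cuda()]
```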
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_transformer_forward
    participant FD as FusionDefinition
    participant Schedule as multidevice_schedule
    participant Exec as fd.execute
    Test->>Test: Select parallelism (tp/sp)
    Test->>Test: Create input tensor (CPU)
    alt TENSOR_PARALLEL
        Test->>Test: inp.cuda() (full tensor)
    else SEQUENCE_PARALLEL
        Test->>Test: shard_tensor(inp, dim=1) (split sequence)
    end
    Test->>FD: transformer_forward_definition(b, s, h, e)
    Test->>Schedule: transformer_forward_multidevice_schedule(fd, d, parallelism)
    Schedule->>Schedule: set_device_mesh for all inputs
    alt SEQUENCE_PARALLEL
        Schedule->>Schedule: inp.outer_split(1, num_devices)
        Schedule->>Schedule: inp.axis(1).parallelize(mesh_x)
    end
    Schedule->>Schedule: Split weight tensors for tensor parallelism
    Test->>Exec: fd.execute(ins)
    Exec-->>Test: outputs (with s_local for SP mode)
```
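For reference, a rough Python sketch of the SEQUENCE_PARALLEL branch of the schedule, using only the call names that appear in the diagram (`set_device_mesh`, `outer_split`, `axis(...).parallelize`, `mesh_x`); the surrounding names (`DeviceMesh`, `ParallelType`, `Parallelism`, `fd.fusion.inputs()`) are assumptions, not verified against the actual test:

```python
import nvfuser  # assumed import; the test may use the direct bindings module

def transformer_forward_multidevice_schedule(fd, num_devices, parallelism):
    # Every fusion input lives on the same mesh of num_devices devices.
    mesh = nvfuser.DeviceMesh(range(num_devices))  # assumed constructor
    for tv in fd.fusion.inputs():                  # assumed accessor
        tv.set_device_mesh(mesh)

    if parallelism == Parallelism.SEQUENCE_PARALLEL:
        inp = fd.fusion.inputs()[0]
        # Split the sequence axis (dim 1) into [num_devices, s_local] and
        # bind the outer factor to the mesh, so each rank owns s/d tokens.
        inp.outer_split(1, num_devices)
        inp.axis(1).parallelize(nvfuser.ParallelType.mesh_x)

    # Weight tensors are still split for tensor parallelism in both modes
    # (elided here; see the diagram's final Schedule step).
```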
3 files reviewed, no comments
```diff
-def transformer_forward_multidevice_schedule(fd: FusionDefinition, num_devices: int):
+def transformer_forward_multidevice_schedule(
```
With direct bindings, the schedule doesn't have to live in a separate function from the definition. This way, you can remove the input-unpacking code. (This could have been addressed when the test was migrated from legacy bindings to direct bindings, but I didn’t push for it at the time to keep the migration moving quickly.)
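For what it's worth, a rough sketch of what the suggested inlining might look like; the with-block style and every name here (`mesh`, `num_devices`, the scheduling calls) are assumptions about the direct-bindings API, not a verified snippet:

```python
with FusionDefinition() as fd:
    inp = fd.define_tensor(shape=[b, s, e])
    # ... rest of the forward definition ...

    # Schedule inline: `inp` is still in scope here, so there is no need to
    # re-fetch and unpack fd's inputs in a separate schedule function.
    inp.set_device_mesh(mesh)
    inp.outer_split(1, num_devices)
    inp.axis(1).parallelize(mesh_x)
```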
Got it.
I prefer keeping it separate: it's easier to focus on just the schedule (at the cost of input unpacking), since the definition is quite long.
!test
3 files reviewed, no comments
!test
3 files reviewed, no comments
Wall-clock time measured on 8 80GB H100 nodes: nvFuser 2.1 ms vs TransformerEngine 2.5 ms.