Improve IfrtMergeReshardsPass to allow for more efficient merging
#34803
+102
−106
The current implementation can only merge copy operations whose source operations are the same, yet many parallel copies can share the same source and destination devices while taking arguments produced by different ops. See the added test for an example.
This CL improves the algorithm (and, as a result, simplifies the implementation) so that such parallel copies can be merged. The intuition is that the existing grouping by first reshard user op is already sufficient to prevent any circular dependency after merging: a reshard X that (transitively) depends on a reshard Y can never have the same first reshard user op as Y, by definition. Thus, we can simply iterate over all reshard ops in the function and merge them based on the same keys that we use today.
The algorithm now runs iteratively until it reaches a fixpoint, because merging some reshard ops changes the "first reshard user op" of other reshards, which may create more merging opportunities. The added test fails without this loop.
This is particularly useful when arguments are progressively broadcast over multiple pipeline stage submeshes because argument broadcast and across-stage transfers for intermediates will be completely batchable as long as their broadcast order is the same.
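
To make the grouping and the fixpoint loop described above concrete, here is a minimal Python sketch on a toy op list. It is not the pass's implementation: `Op`, `first_user`, `merge_sweep`, `merge_reshards`, and `example_program` are hypothetical names, the key `(source devices, destination devices, first user)` is an assumption standing in for whatever keys the pass actually uses, and dependencies are tracked per op rather than per value.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    kind: str                                     # "reshard" or "compute"
    operands: list = field(default_factory=list)  # names of producer ops
    src: str = ""                                 # source devices (reshards only)
    dst: str = ""                                 # destination devices (reshards only)


def first_user(ops, name):
    """Name of the earliest op in program order that consumes `name`, or None."""
    for op in ops:
        if name in op.operands:
            return op.name
    return None


def merge_sweep(ops):
    """One sweep: group reshards by (src, dst, first user) and merge each group.

    Returns (new_ops, changed). The input list is left untouched.
    """
    groups = {}
    for op in ops:
        if op.kind == "reshard":
            user = first_user(ops, op.name)
            if user is not None:
                groups.setdefault((op.src, op.dst, user), []).append(op)

    rename = {}     # old reshard name -> merged reshard name
    merged_at = {}  # name of a group's last member -> merged op placed there
    for (src, dst, _), members in groups.items():
        if len(members) < 2:
            continue
        merged = Op("+".join(m.name for m in members), "reshard",
                    [o for m in members for o in m.operands], src, dst)
        for m in members:
            rename[m.name] = merged.name
        # The last member sits after every member's operands and before the
        # shared first user, so the merged op can take its position.
        merged_at[members[-1].name] = merged

    new_ops = []
    for op in ops:
        if op.name in rename:
            if op.name not in merged_at:
                continue  # absorbed into a merged op emitted later
            op = merged_at[op.name]
        # Point operands at merged ops, dropping duplicates but keeping order.
        operands = list(dict.fromkeys(rename.get(o, o) for o in op.operands))
        new_ops.append(Op(op.name, op.kind, operands, op.src, op.dst))
    return new_ops, bool(rename)


def merge_reshards(ops):
    """Repeat sweeps until a sweep merges nothing (fixpoint)."""
    changed = True
    while changed:
        ops, changed = merge_sweep(ops)
    return ops


def example_program():
    """Two values broadcast mesh0 -> mesh1 -> mesh2 from different source ops."""
    return [
        Op("x", "compute"),
        Op("y", "compute"),
        Op("A", "reshard", ["x"], "mesh0", "mesh1"),
        Op("C", "reshard", ["y"], "mesh0", "mesh1"),
        Op("B", "reshard", ["A"], "mesh1", "mesh2"),
        Op("D", "reshard", ["C"], "mesh1", "mesh2"),
        Op("E", "compute", ["B", "D"]),
    ]
```

In this toy program a single sweep can only merge `B` and `D`, because `A` and `C` still have different first users (`B` and `D` respectively). Once `B` and `D` become one op, `A` and `C` share it as their first user, so a second sweep merges them too. That is the situation the fixpoint loop is meant to cover.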