[Core][Distributed] merge two broadcast_tensor_dict #5354
Conversation
@zhuohan123 PTAL, if you agree with this design, I can refactor the rest of the model runners as well.
The performance gain seems to be non-trivial. (Benchmark results with/without this PR and the script were attached as images.)
machine: 8*H100
Hey, still a general comment: can we make the code more specific? Why do we design the function arguments with a general name like aux?
@@ -47,12 +46,14 @@ def __init__(
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        broadcast_inputs: Dict[str, Any],
        aux: Optional[List[Any]],
Why do we add an aux argument? I feel like this makes the arguments more confusing.
How about the current design:

def execute_model(
    self,
    modelrunner_input: ModelRunnerInput,
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:

I feel this is quite general: the driver worker prepares the input and separates it into objects to broadcast and objects to keep for itself (i.e. the …
But the method is called …
How about this design, for non-driver workers: …
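The split described above (driver prepares the full input, broadcasts one part, and keeps driver-only objects for itself) can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the in-process `_channel` stands in for the real `torch.distributed` broadcast, and the field names (`input_tokens`, `sampling_metadata`) are hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# Hypothetical stand-in for the distributed broadcast: the driver (src)
# publishes a dict and every rank receives a copy of it.
_channel: Dict[str, Any] = {}

def broadcast_tensor_dict(d: Optional[Dict[str, Any]] = None,
                          src: int = 0) -> Dict[str, Any]:
    global _channel
    if d is not None:          # driver publishes its dict
        _channel = dict(d)
    return dict(_channel)      # every rank receives the same dict

@dataclass
class ModelRunnerInput:
    input_tokens: List[int]
    sampling_metadata: Optional[Any]  # driver-only "aux" object, never broadcast

def driver_execute(seq_data: List[int]) -> ModelRunnerInput:
    # Driver prepares inputs, splitting them into the part to broadcast
    # and the part it keeps for itself.
    to_broadcast = {"input_tokens": seq_data}
    aux = {"sampling": "driver-local"}  # e.g. SamplingMetadata
    broadcast_tensor_dict(to_broadcast, src=0)
    return ModelRunnerInput(seq_data, aux)

def non_driver_execute() -> ModelRunnerInput:
    # Non-driver workers rebuild their input from the single broadcast.
    received = broadcast_tensor_dict(src=0)
    return ModelRunnerInput(received["input_tokens"], None)
```

The point of the design is that only one broadcast happens per step, and the driver-only objects never cross the wire.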
Thanks @youkaichao, I am very much in favor of this (I had been planning to do it myself at some point)!
Some other thoughts:
- This halves the number of broadcasts done; we can halve it again by adding [Core] Avoid one broadcast op when propagating metadata #4844 or equivalent, which I expect would give an additional non-negligible latency benefit
- Spec decoding adds another per-step broadcast, which I hope can be similarly coalesced with this one
        return metadata_dict, [sampling_metadata]

    def convert_broadcast_inputs_to_modelrunner_input(
            self, metadata_dict: Dict[str, Any], aux: Optional[List[Any]]):
Add return typing?
I deleted it to avoid a typing conflict. This function, convert_broadcast_inputs_to_modelrunner_input, is implemented in embedding_model_runner.py as well, but the return types differ (although both are named ModelRunnerInput, they live in different modules, so mypy complains about it).
Suggestions are welcome for fixing this while keeping mypy happy.
        if attn_metadata:
            metadata_dict.update(attn_metadata.asdict_zerocopy())
        broadcast_tensor_dict(metadata_dict, src=0)

    def prepare_inputs_to_broadcast(
I wonder whether we should have this return ModelRunnerInput and then add a method to it, get_dict_to_broadcast(), which you only call in the TP > 1 case. This would be more efficient for the non-TP case.
Then this method could be called prepare_modelrunner_input(), which would make more sense to me since it's used even for non-TP, where broadcasting isn't being done.
That would also obviate the need for this separate aux arg/variable.
The problem is who takes the responsibility for driving the broadcast process, and how it knows which function to call, given that we have inheritance in model_runner.py and embedding_model_runner.py.
Previously, each model runner drove the broadcast operation itself, inside execute_model, which led to a separate broadcast.
In this PR, the worker drives the broadcast operation, so it needs ModelRunner to return inputs_to_broadcast, broadcast it, and then feed it back to ModelRunner.
If we use ModelRunnerInput.get_dict_to_broadcast, how can the worker know which ModelRunnerInput to use? We have multiple ModelRunnerInputs, used in both model_runner.py and embedding_model_runner.py.
> If we use ModelRunnerInput.get_dict_to_broadcast, how can the worker know which ModelRunnerInput to use? We have multiple ModelRunnerInputs, used in both model_runner.py and embedding_model_runner.py.

You kind of answered your own question :) ... both ModelRunnerInputs implement that, and the worker just calls it (it could be a protocol)
This can be solved by adding an abstract class and letting ModelRunnerInput inherit from it. For reasons I don't know, inheritance is discouraged in vLLM, and several model runners just duplicate the code.
Inheritance is used in various places already. IMO it makes sense to use it judiciously, without overdoing it.
In any case, that's why I suggested using a protocol here; you can do the same thing without inheritance.
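For illustration, the protocol idea could look roughly like this. A minimal sketch, not the PR's code: the protocol name and the two input classes are hypothetical, standing in for the inputs defined in model_runner.py and embedding_model_runner.py.

```python
from typing import Any, Dict, Protocol

class BroadcastableInput(Protocol):
    """Hypothetical protocol: any input object that can report which of
    its fields must be broadcast to non-driver workers under TP > 1."""
    def get_dict_to_broadcast(self) -> Dict[str, Any]: ...

# Two unrelated classes (say, one per model-runner module) satisfy the
# protocol structurally; no shared base class or inheritance is needed.
class DecodeModelRunnerInput:
    def __init__(self, tokens, sampling_metadata):
        self.tokens = tokens
        self.sampling_metadata = sampling_metadata  # driver-only
    def get_dict_to_broadcast(self) -> Dict[str, Any]:
        return {"tokens": self.tokens}

class EmbeddingModelRunnerInput:
    def __init__(self, tokens, pooling_metadata):
        self.tokens = tokens
        self.pooling_metadata = pooling_metadata  # driver-only
    def get_dict_to_broadcast(self) -> Dict[str, Any]:
        return {"tokens": self.tokens}

def worker_broadcast(model_input: BroadcastableInput) -> Dict[str, Any]:
    # The worker doesn't need to know the concrete input type; it only
    # relies on the protocol method, so mypy checks both classes against
    # one shared signature without any inheritance.
    return model_input.get_dict_to_broadcast()
```

Because typing.Protocol is structural, the worker can accept either module's input type without the two modules importing each other.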
I'm planning to use shared memory for transport, so that we don't need to broadcast metadata at all. Only tensors will be broadcast.
Sure, we can do it in a follow-up PR.
Closing, as #5408 is a superset of this PR.
We have two broadcast_tensor_dict calls, one in the worker and one in the model runner. We can merge them into a single broadcast_tensor_dict call.
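The merge described above can be sketched by counting collective launches per step. This is an illustration only: `broadcast_tensor_dict` here is a recording stub for the real `torch.distributed`-backed call, and the payload keys are hypothetical.

```python
from typing import Any, Dict, List

# Record each "collective launch" so we can count broadcasts per step.
launches: List[Dict[str, Any]] = []

def broadcast_tensor_dict(d: Dict[str, Any], src: int = 0) -> Dict[str, Any]:
    launches.append(d)  # stand-in for the real distributed broadcast
    return d

def step_before():
    # Before the merge: the worker and the model runner each launch
    # their own broadcast, i.e. two collectives per step.
    broadcast_tensor_dict({"num_seq_groups": 4}, src=0)        # worker
    broadcast_tensor_dict({"input_tokens": [1, 2, 3]}, src=0)  # model runner

def step_after():
    # After the merge: the worker combines both payloads into one dict
    # and launches a single broadcast per step.
    merged = {"num_seq_groups": 4, "input_tokens": [1, 2, 3]}
    broadcast_tensor_dict(merged, src=0)

step_before()
n_before = len(launches)
launches.clear()
step_after()
n_after = len(launches)
```

Since each collective launch carries fixed overhead, halving the count per step is where the latency win reported in this PR comes from.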