[Core] Refactor Worker and ModelRunner to consolidate control plane communication #5408

Merged 64 commits on Jun 26, 2024

Commits
725b0b2  tmp (stephanie-wang, Jun 11, 2024)
b74eb10  fix (stephanie-wang, Jun 11, 2024)
38b0ddf  ray and mp backends work (stephanie-wang, Jun 11, 2024)
0d11e92  embedding model runner works (stephanie-wang, Jun 11, 2024)
2cdc218  GPU executor works (stephanie-wang, Jun 11, 2024)
c728512  remove comment (stephanie-wang, Jun 11, 2024)
2bf752b  use the right ModelInput class (stephanie-wang, Jun 11, 2024)
f35a23f  CPU worker (stephanie-wang, Jun 11, 2024)
11133fe  remove commented (stephanie-wang, Jun 11, 2024)
174bdb1  lint (stephanie-wang, Jun 11, 2024)
c0e98ca  Worker.execute_model vs execute_model_local (stephanie-wang, Jun 11, 2024)
dccec95  lint (stephanie-wang, Jun 11, 2024)
dad94ba  neuron model runner (stephanie-wang, Jun 11, 2024)
fca606e  disallow distributed comms (stephanie-wang, Jun 11, 2024)
6ed3c2a  disable communication (stephanie-wang, Jun 12, 2024)
1803e33  Update worker.py (stephanie-wang, Jun 12, 2024)
dde799e  fix tests (stephanie-wang, Jun 12, 2024)
0398631  update (stephanie-wang, Jun 12, 2024)
5c41cc6  Merge branch 'control-refactor-2' of github.com:stephanie-wang/vllm i… (stephanie-wang, Jun 12, 2024)
72f0383  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 12, 2024)
eef6623  merge (stephanie-wang, Jun 12, 2024)
3004ceb  update (Jun 13, 2024)
8d852e9  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 13, 2024)
9380ed8  fix (stephanie-wang, Jun 13, 2024)
3c4de6d  fix (stephanie-wang, Jun 13, 2024)
5053f30  fix (stephanie-wang, Jun 13, 2024)
db38556  x (stephanie-wang, Jun 14, 2024)
456185d  rm (stephanie-wang, Jun 14, 2024)
e860652  lint (stephanie-wang, Jun 14, 2024)
3d4f242  add missing (stephanie-wang, Jun 14, 2024)
11304cb  revert (stephanie-wang, Jun 14, 2024)
99f532e  refactor (stephanie-wang, Jun 15, 2024)
797a7cf  doc (stephanie-wang, Jun 15, 2024)
6ad2513  revert spec decode and doc (stephanie-wang, Jun 15, 2024)
97ec303  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 15, 2024)
e10bace  typing (stephanie-wang, Jun 15, 2024)
ce087ae  fix (stephanie-wang, Jun 18, 2024)
f851b00  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 18, 2024)
0e2acc4  XPU worker and rename (stephanie-wang, Jun 18, 2024)
d318ec8  lint (stephanie-wang, Jun 18, 2024)
b48f783  lint (stephanie-wang, Jun 18, 2024)
c93afc1  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 18, 2024)
30ac400  fix (stephanie-wang, Jun 18, 2024)
01688d5  x (stephanie-wang, Jun 18, 2024)
7dbb646  fix (stephanie-wang, Jun 18, 2024)
d2e4c41  fix (stephanie-wang, Jun 19, 2024)
3e46253  lint (stephanie-wang, Jun 19, 2024)
0a2890a  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 21, 2024)
36dfce1  merge (stephanie-wang, Jun 21, 2024)
ea5412e  Merge remote-tracking branch 'upstream/main' into control-refactor-2 (stephanie-wang, Jun 21, 2024)
dc2f103  x (stephanie-wang, Jun 21, 2024)
fbf074d  x (stephanie-wang, Jun 22, 2024)
660a8d5  rename ModelInput -> ModelInputBase, override as_broadcastable_tensor… (stephanie-wang, Jun 23, 2024)
8cca634  fixes (stephanie-wang, Jun 23, 2024)
0a25c19  rename (stephanie-wang, Jun 23, 2024)
0b26877  fix (stephanie-wang, Jun 24, 2024)
e7052d5  do not filter Nones (stephanie-wang, Jun 24, 2024)
df5551f  dupe (stephanie-wang, Jun 24, 2024)
6745b3b  update (stephanie-wang, Jun 25, 2024)
ebae970  lint (stephanie-wang, Jun 25, 2024)
5763621  revert (stephanie-wang, Jun 25, 2024)
46d5b18  Merge branch 'main' into control-refactor-2 (stephanie-wang, Jun 25, 2024)
d16d5fe  rm (stephanie-wang, Jun 25, 2024)
f6c6234  fix (stephanie-wang, Jun 25, 2024)
Files changed
vllm/attention/backends/abstract.py (5 additions, 1 deletion)

@@ -21,9 +21,13 @@ def get_impl_cls() -> Type["AttentionImpl"]:

     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError

+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
vllm/attention/backends/blocksparse_attn.py (2 additions, 2 deletions)

@@ -90,8 +90,8 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
-        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return BlocksparseFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flash_attn.py (2 additions, 2 deletions)

@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
-        return FlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/flashinfer.py (2 additions, 2 deletions)

@@ -22,8 +22,8 @@ def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
-        return FlashInferMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashInferMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/rocm_flash_attn.py (2 additions, 2 deletions)

@@ -25,8 +25,8 @@ def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
-        return ROCmFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return ROCmFlashAttentionMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/torch_sdpa.py (2 additions, 2 deletions)

@@ -23,8 +23,8 @@ def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
-        return TorchSDPAMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return TorchSDPAMetadata

     @staticmethod
     def get_kv_cache_shape(

vllm/attention/backends/xformers.py (2 additions, 2 deletions)

@@ -28,8 +28,8 @@ def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl

     @staticmethod
-    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
-        return XFormersMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return XFormersMetadata

     @staticmethod
     def get_kv_cache_shape(
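Taken together, the backend changes above follow a single pattern: each concrete backend now declares its metadata type through get_metadata_cls(), and the abstract base's new make_metadata() classmethod instantiates whatever type the backend declared. Below is a minimal, self-contained sketch of that pattern; the Toy* names are illustrative stand-ins, not vLLM classes.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Type


@dataclass
class ToyMetadata:
    # Hypothetical metadata type standing in for e.g. FlashAttentionMetadata.
    num_prefill_tokens: int = 0


class ToyAttentionBackend(ABC):

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type[ToyMetadata]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> ToyMetadata:
        # The base class constructs the backend-specific metadata type, so
        # concrete backends only declare which class to use.
        return cls.get_metadata_cls()(*args, **kwargs)


class ToyFlashBackend(ToyAttentionBackend):

    @staticmethod
    def get_metadata_cls() -> Type[ToyMetadata]:
        return ToyMetadata


# Callers can now inspect the metadata type (useful when rebuilding metadata
# on workers) without constructing an instance, or construct one as before.
assert ToyFlashBackend.get_metadata_cls() is ToyMetadata
meta = ToyFlashBackend.make_metadata(num_prefill_tokens=4)
```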
vllm/distributed/communication_op.py (43 additions, 0 deletions)

@@ -14,6 +14,43 @@
     get_tp_pynccl_communicator)


+@dataclass
+class DistributedContext:
+    communication_allowed: bool = True
+
+    @staticmethod
+    def get_current() -> "DistributedContext":
+        """
+        Get the singleton context.
+        """
+        global _default_context
+        return _default_context
+
+
+_default_context: DistributedContext = DistributedContext()
+
+
+def disable_communication(fn):
+    """
+    Helper decorator to disable control plane communication, i.e.
+    calling broadcast_tensor_dict will throw a RuntimeError. This can be used
+    to ensure that decorated code stays worker-local.
+    """
+
+    def wrapper(*args, **kwargs):
+        # Disallow control plane communication.
+        comm_ctx = DistributedContext.get_current()
+        original_comm_allowed = comm_ctx.communication_allowed
+        comm_ctx.communication_allowed = False
+
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            comm_ctx.communication_allowed = original_comm_allowed
+
+    return wrapper
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream

@@ -235,6 +272,12 @@ def broadcast_tensor_dict(
     to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
     dtypes).
     """
+    ctx = DistributedContext.get_current()
+    if not ctx.communication_allowed:
+        raise RuntimeError(
+            "Control plane communication not allowed in functions decorated "
+            "with @disable_communication")
+
     # Bypass the function if we are using only 1 GPU.
     if (not torch.distributed.is_initialized()
             or torch.distributed.get_world_size(group=group) == 1):
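This decorator is the enforcement side of consolidating control plane communication: any code wrapped in @disable_communication cannot call broadcast_tensor_dict, so the wrapped model-runner logic is guaranteed to stay worker-local. The following is a hedged, standalone sketch of the intended behavior; it re-implements the pattern in miniature rather than importing from vLLM, and execute_model_locally is a made-up example function.

```python
from dataclasses import dataclass


@dataclass
class DistributedContext:
    communication_allowed: bool = True


_default_context = DistributedContext()


def disable_communication(fn):
    def wrapper(*args, **kwargs):
        original = _default_context.communication_allowed
        _default_context.communication_allowed = False
        try:
            return fn(*args, **kwargs)
        finally:
            # Restore the previous state even if fn raises.
            _default_context.communication_allowed = original

    return wrapper


def broadcast_tensor_dict(tensor_dict):
    # The real implementation broadcasts across ranks; here we only model
    # the control plane guard.
    if not _default_context.communication_allowed:
        raise RuntimeError("Control plane communication not allowed")
    return tensor_dict


@disable_communication
def execute_model_locally():
    # Any broadcast attempted inside a decorated function raises, which keeps
    # the decorated code path worker-local.
    broadcast_tensor_dict({"step": 1})


try:
    execute_model_locally()
except RuntimeError as err:
    print(f"blocked as expected: {err}")
```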
vllm/executor/distributed_gpu_executor.py (8 additions, 8 deletions)

@@ -64,8 +64,8 @@ def initialize_cache(self, num_gpu_blocks: int,
                           num_cpu_blocks=num_cpu_blocks)

     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[SamplerOutput]]:
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
                 "start_worker_execution_loop",
@@ -79,7 +79,7 @@ def stop_remote_worker_execution_loop(self) -> None:
         if self.parallel_worker_tasks is None:
             return

-        self._driver_execute_model()
+        self._driver_execute_model(execute_model_req=None)
         parallel_worker_tasks = self.parallel_worker_tasks
         self.parallel_worker_tasks = None
         # Ensure that workers exit model loop cleanly
@@ -116,13 +116,13 @@ def save_sharded_state(

     @abstractmethod
     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.

-        Passing None will cause the driver to stop the model execution
-        loop running in each of the remote workers.
+        Passing None will cause the driver to stop the model execution loop
+        running in each of the remote workers. In this case, this method
+        returns None. Otherwise, this method returns the model output.
         """
         raise NotImplementedError

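The updated docstring spells out a two-mode contract for _driver_execute_model: a real ExecuteModelRequest produces sampler output, while None acts as a stop signal for the remote workers' execution loops and yields no output. A small illustrative sketch of that contract, using a hypothetical ToyDriver rather than the actual executor classes:

```python
from typing import List, Optional


class ToyDriver:
    """Hypothetical stand-in for the driver worker (not a vLLM class)."""

    def execute_model(self, request: Optional[dict]) -> Optional[List[str]]:
        if request is None:
            # Stop signal: remote workers exit their execution loop and no
            # model output is produced.
            return None
        # Normal step: run the model and return one output per sequence.
        return [f"output for seq {i}" for i in request["seq_ids"]]


driver = ToyDriver()
assert driver.execute_model(None) is None          # stops the loop, returns None
print(driver.execute_model({"seq_ids": [0, 1]}))   # regular execution step
```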
vllm/executor/executor_base.py (2 additions, 2 deletions)

@@ -69,8 +69,8 @@ def initialize_cache(self, num_gpu_blocks: int,

     @abstractmethod
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError

vllm/executor/gpu_executor.py (1 addition, 1 deletion)

@@ -87,7 +87,7 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:

     def execute_model(
         self, execute_model_req: ExecuteModelRequest
-    ) -> List[Union[SamplerOutput, PoolerOutput]]:
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output

vllm/executor/multiproc_gpu_executor.py (3 additions, 5 deletions)

@@ -76,16 +76,14 @@ def shutdown(self):
             worker_monitor.close()

     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.

         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
-        return self.driver_worker.execute_model(
-            execute_model_req=execute_model_req)
+        return self.driver_worker.execute_model(execute_model_req)

     def _run_workers(
         self,

vllm/executor/neuron_executor.py (1 addition, 2 deletions)

@@ -55,8 +55,7 @@ def execute_model(
         assert execute_model_req.num_lookahead_slots == 0, (
             "lookahead not supported for Neuron backend.")

-        output = self.driver_worker.execute_model(
-            execute_model_req.seq_group_metadata_list)
+        output = self.driver_worker.execute_model(execute_model_req)
         return output

     def add_lora(self, lora_request: LoRARequest) -> bool:

vllm/executor/ray_gpu_executor.py (2 additions, 3 deletions)

@@ -174,9 +174,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                               max_parallel_loading_workers)

     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.

         Passing None will cause the driver to stop the model execution