Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions vllm_fl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,17 @@ def register():
def register_model():
"""Register the FL model."""
from vllm import ModelRegistry
import vllm.model_executor.models.qwen3_next as qwen3_next_module

try:
from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401

qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM
logger.warning(
"Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, "
"original vLLM implementation is overridden"
)

ModelRegistry.register_model(
"Qwen3NextForCausalLM", "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM"
)
Expand Down
2 changes: 2 additions & 0 deletions vllm_fl/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from vllm.distributed import (
ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce,
)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
from vllm.distributed.kv_transfer import (
Expand Down Expand Up @@ -1055,6 +1056,7 @@ def init_worker_distributed_environment(
"""Initialize the distributed environment."""
attention_config = vllm_config.attention_config
parallel_config = vllm_config.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
from vllm.model_executor.layers.batch_invariant import init_batch_invariance

init_batch_invariance(attention_config.backend)
Expand Down