diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 4c5d40f..7dbe0c8 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -22,10 +22,17 @@ def register(): def register_model(): """Register the FL model.""" from vllm import ModelRegistry + import vllm.model_executor.models.qwen3_next as qwen3_next_module try: from vllm_fl.models.qwen3_next import Qwen3NextForCausalLM # noqa: F401 + qwen3_next_module.Qwen3NextForCausalLM = Qwen3NextForCausalLM + logger.warning( + "Qwen3NextForCausalLM has been patched to use vllm_fl.models.qwen3_next, " + "original vLLM implementation is overridden" + ) + ModelRegistry.register_model( "Qwen3NextForCausalLM", "vllm_fl.models.qwen3_next:Qwen3NextForCausalLM" ) diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index 5975f70..2db1926 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -21,6 +21,7 @@ from vllm.distributed import ( ensure_model_parallel_initialized, init_distributed_environment, + set_custom_all_reduce, ) from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized from vllm.distributed.kv_transfer import ( @@ -1055,6 +1056,7 @@ def init_worker_distributed_environment( """Initialize the distributed environment.""" attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) from vllm.model_executor.layers.batch_invariant import init_batch_invariance init_batch_invariance(attention_config.backend)