
Conversation

@R2-Y (Contributor) commented Dec 16, 2025


Purpose

Placing multiple instances of the same model on the same device can cause memory-calculation errors, because non_torch_memory mistakenly includes memory consumed by other Omni instances. As a result, non_kv_cache_memory comes out larger than its real value, leaving less memory for the KV cache and triggering an assertion failure.

When we calculate non_torch_memory:
  • cuda_memory = total_memory - free_memory → memory used by all processes on the GPU
  • torch_memory = torch.cuda.memory_reserved() → memory reserved by PyTorch's CUDA allocator for the current process only
  • non_torch_memory = cuda_memory - torch_memory
Hence, any process that does not belong to the current instance will disrupt the current instance's memory calculation.
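
For reference, a minimal sketch of the calculation in question (variable names are illustrative, not the exact vllm-omni code):

import torch

# cuda_memory is device-wide while torch_memory is per-process, so any other
# process initializing on the same GPU inflates non_torch_memory.
free_memory, total_memory = torch.cuda.mem_get_info()
cuda_memory = total_memory - free_memory       # all processes on this GPU
torch_memory = torch.cuda.memory_reserved()    # this process's allocator only
non_torch_memory = cuda_memory - torch_memory  # over-counted if another
                                               # instance is also initializing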

Take qwen3-omni as an example.

We init thinker tp:0 on device 0:
[screenshots: thinker tp0 memory snapshot, non_torch_memory, available KV-cache memory]

We init thinker tp:1 on device 1:
[screenshots: thinker tp1 memory snapshot, non_torch_memory, available KV-cache memory]

The talker is initialized on device 1 at the same time:
[screenshots: talker memory snapshot, non_torch_memory, available KV-cache memory]

Because the two processes on device 1 interfere with each other during initialization, their calculated non_torch_memory values are distorted even when there is sufficient memory, which leads to an incorrect available_kvcache_mem.

I added a file lock to the initialization phase so that multiple instances placed on the same device start their initialization in sequence.
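
For illustration only, the same per-device serialization expressed with the third-party filelock package (the PR itself uses a hand-rolled /tmp lock file; init_stage_on_device and init_fn are hypothetical names):

from filelock import FileLock  # pip install filelock

def init_stage_on_device(device_id: int, init_fn):
    # Blocks until any other stage holding this device's lock finishes, so
    # model loading and memory profiling never overlap on one GPU.
    lock = FileLock(f"/tmp/vllm_omni_device_{device_id}_init.lock", timeout=300)
    with lock:
        return init_fn()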

Test Plan

Verify that the memory calculation is correct when multiple stages are placed on the same device.

Test Result

Thinker:
[screenshot]
Talker:
[screenshot]



@R2-Y requested a review from hsliuustc0106 as a code owner, December 16, 2025 04:27
@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 523 to 526
# Create our lock file
try:
    with open(lock_file, "w") as f:
        f.write(f"{_os.getpid()}\n")


P1: Make the device init lock atomic

The new sequential init uses while _os.path.exists(lock_file) and then creates the lock with open(lock_file, "w") only after that check. If two stage processes on the same CUDA device start at the same moment while no lock file exists, both will fall through the wait loop and each will create/truncate the file, so they proceed to initialize concurrently and the non_torch_memory miscalculation you’re trying to prevent can still happen. The lock needs an atomic creation (e.g., os.O_EXCL) or the existence check must be combined with the creation step; the same pattern is repeated in the async worker block.

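A minimal sketch of the atomic creation this suggests (function names and the retry policy are illustrative, not the PR's code):

import os
import time

def acquire_device_lock(lock_file: str, timeout: float = 300.0) -> None:
    # O_CREAT | O_EXCL makes creation fail if the file already exists, so
    # only one process can win the race; the others keep polling.
    deadline = time.time() + timeout
    while True:
        try:
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.write(fd, f"{os.getpid()}\n".encode())
            os.close(fd)
            return
        except FileExistsError:
            if time.time() > deadline:
                raise TimeoutError(f"timed out waiting for {lock_file}")
            time.sleep(1.0)

def release_device_lock(lock_file: str) -> None:
    # Best-effort cleanup so the next stage does not have to wait for the
    # stale-file age check.
    try:
        os.remove(lock_file)
    except FileNotFoundError:
        pass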

lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock"

# Wait for other instances to finish initialization
max_wait_time = 300 # 5 minutes max wait
Collaborator

This part seems to duplicate the input arg --init-sleep-time.

Contributor Author

init_sleep_time only sets a fixed delay after which the instance starts its init. I've configured this to retry continuously within max_wait_time instead: as soon as the instance that is already initializing finishes, this instance begins its own initialization immediately.

if torch.cuda.is_available():
    # Get the current device ID (logical device 0 after set_stage_devices)
    device_id = torch.cuda.current_device()
    lock_file = f"/tmp/vllm_omni_device_{device_id}_init.lock"
Collaborator

  1. Is there any method to delete the temp lock files in case a stage init fails? Otherwise, the GPU will stay locked forever.
  2. What if Stage A is on GPU 1 and Stage B uses TP with size 2 on GPUs 0 and 1?
  3. Or can we do it another way:
     1. Determine the initialization order in the OmniLLM before initializing the stages.
     2. Read the required GPU IDs from the stage configs and apply an initialization-sorting algorithm: stages without GPU overlap are initialized in parallel, and stages with GPU overlap are initialized one by one.
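
A rough sketch of the sorting idea in point 3 (stage_gpu_ids is a hypothetical mapping, not the actual stage-config API):

from typing import Dict, List, Set

def plan_init_waves(stage_gpu_ids: Dict[str, Set[int]]) -> List[List[str]]:
    # Stages within a wave share no GPUs and can initialize in parallel;
    # a stage that overlaps every existing wave starts a new, later wave.
    waves: List[List[str]] = []
    wave_gpu_sets: List[Set[int]] = []
    for stage, gpus in stage_gpu_ids.items():
        for wave, wave_gpus in zip(waves, wave_gpu_sets):
            if not (gpus & wave_gpus):
                wave.append(stage)
                wave_gpus |= gpus
                break
        else:
            waves.append([stage])
            wave_gpu_sets.append(set(gpus))
    return waves

# thinker on GPUs 0,1 and talker on GPU 1 overlap, so they init one by one:
# plan_init_waves({"thinker": {0, 1}, "talker": {1}}) -> [["thinker"], ["talker"]]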

Contributor Author

  1. I use max_wait_time to remove stale tmp lock files:

# Check if the lock file is stale (older than 5 minutes)
if _time.time() - _os.path.getmtime(lock_file) > max_wait_time:
    _os.remove(lock_file)
    break

  2. Because the device ID I get is the logical ID, even if stage A is on GPU 1 and stage B runs TP on GPUs 0 and 1, both of their lock files end up with device ID 0. This is a bug; I fixed it in my next commit to use the physical device ID and will push soon.
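
One way to derive that physical ID, assuming CUDA_VISIBLE_DEVICES contains plain integer indices (a sketch, not necessarily the committed fix):

import os
import torch

def physical_device_id() -> int:
    # Map the logical index back through CUDA_VISIBLE_DEVICES so that all
    # processes sharing a physical GPU agree on the lock-file name.
    logical_id = torch.cuda.current_device()
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if not visible:
        return logical_id  # no remapping in effect
    return int(visible.split(",")[logical_id])

lock_file = f"/tmp/vllm_omni_device_{physical_device_id()}_init.lock"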

Contributor Author

  3. In the current design, stages on different devices are initialized in parallel, while stages that share a device are initialized sequentially.

@hsliuustc0106 (Collaborator)

Could you please add an entry to the FAQ docs about memory usage when placing different stages on the same device?


Labels

ready (label to trigger buildkite CI)

Development

Successfully merging this pull request may close this issue:
[Bug]: ValueError: Free memory on device is less than desired GPU memory utilization when serving Qwen3-Omni-30B-A3B-Instruct with 2×A100 (80GB)