Conversation

@casteryh (Contributor) commented Oct 15, 2025

What this PR does

  • Add a SharedTensor class that is backed by shared memory. It exposes a handle that can be trivially serialized and shared across processes on the same host. This is necessary because PyTorch's native shared-memory-backed tensor is serialized by value under monarch/cloudpickle. (A minimal sketch of the idea follows this list.)
  • Spawn dedicated procs, colocated with the workers, that fetch weights from torchstore into shared memory, which is then forwarded to GeneratorWorker. Note that naively parallelizing the weight fetch with asyncio.gather() causes monarch RDMA operations to fail; using multiple procs circumvents this.
  • Let the dedicated procs prefetch while the generator is still generating.
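
A minimal sketch of the SharedTensor idea, written against Python's stdlib multiprocessing.shared_memory. The method names mirror the ones this PR introduces (get_handle, close, drop, context manager), but the bodies are illustrative assumptions rather than the actual implementation, and the sketch only handles contiguous float32 CPU tensors:

```python
from dataclasses import dataclass
from multiprocessing import shared_memory

import torch


@dataclass
class SharedTensorHandle:
    # Small, picklable description of the shared block; this is what crosses
    # process boundaries (via monarch/cloudpickle) instead of the tensor data.
    shm_name: str
    shape: tuple
    numel: int


class SharedTensor:
    def __init__(self, tensor=None, handle=None):
        if tensor is not None:
            # Creator side: copy a CPU float32 tensor into a named shared block.
            assert tensor.device.type == "cpu" and tensor.dtype == torch.float32
            self._shm = shared_memory.SharedMemory(
                create=True, size=tensor.numel() * tensor.element_size()
            )
            self._handle = SharedTensorHandle(
                self._shm.name, tuple(tensor.shape), tensor.numel()
            )
            self.tensor.copy_(tensor)
        else:
            # Receiver side: attach to an existing block by name.
            self._shm = shared_memory.SharedMemory(name=handle.shm_name)
            self._handle = handle

    @property
    def tensor(self) -> torch.Tensor:
        # Zero-copy view over the shared block; invalid once close() is called.
        return torch.frombuffer(
            self._shm.buf, dtype=torch.float32, count=self._handle.numel
        ).view(self._handle.shape)

    def get_handle(self) -> SharedTensorHandle:
        return self._handle

    def close(self) -> None:
        # Release this process's mapping; the block itself stays alive.
        self._shm.close()

    def drop(self) -> None:
        # Free the block for good; call from exactly one process at the end.
        self.close()
        self._shm.unlink()

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.close()
```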

Perf

TL;DR: e2e weight sync time is now ~50s for Qwen3 32B; one training step takes <70s.

|                                             | before  | after     |
|---------------------------------------------|---------|-----------|
| wait for existing generations               | 5 - 15s | 5 - 15s   |
| update_weights (total)                      | ~80s    | ~50s      |
| update_weights (since generation completes) | ~80s    | 30s - 45s |
| fetch_weights                               | n/a     | ~45s      |
| shared_memory -> vllm worker                | n/a     | ~5s       |

Tested with

@casteryh (Contributor Author) commented Oct 16, 2025

> @casteryh could you split the PR with separate concerns?

I think it makes little sense to split, see below.

> 1. Using shared memory for weight sync when CPU is involved (TorchStore RDMA CPU-CPU, DCP). This can be turned on by default for now.

Currently TorchStore RDMA only works with CPU-CPU.
I can enable this for DCP in a separate PR.

> 1. Prefetching the weights while completing the on-the-fly generation requests.

This actually comes automatically once you have separate processes fetching the weights into shared memory.
I can add a flag that waits until all generation completes before fetching the weights, but why would someone want to do that in the first place?
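
A hedged sketch of the control flow that gives this overlap; fetcher, generator, and the method names here are illustrative stand-ins, not the PR's exact endpoints:

```python
import asyncio


async def update_weights(generator, fetcher, version: int):
    # Kick off the fetch immediately: it copies weights from torchstore into
    # host shared memory and never touches the GPU, so generation can continue.
    fetch_task = asyncio.create_task(fetcher.fetch(version))

    # Let in-flight generation requests finish (the 5 - 15s row in the table).
    await generator.wait_for_pending_requests()

    # Most of the ~45s fetch has already elapsed by now, so the generator only
    # waits for the remainder, then loads from shared memory to GPU (~5s).
    handles = await fetch_task
    await generator.update_weights_from_shared_memory(handles)
```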

> Also maybe add a comment somewhere saying the following is up for discussion:

> 1. TorchStore RDMA GPU-GPU.
> 2. Multi-node vLLM

Will do

> I also don't quite get the "before" and "after" table.

Ah, maybe it's confusing because I am trying to do two things at once. Basically the speedup comes from (1) multiprocess shared memory (this saves ~30 seconds) and (2) prefetching while completing on-the-fly generation (this saves about 10 seconds on average).

@JenniferWang (Contributor) commented:

> This actually comes automatically once you have separate processes fetching the weights into shared memory.
> I can add a flag that waits until all generation completes before fetching the weights, but why would someone want to do that in the first place?

Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- it should be "use shared memory or not". I think you should try profiling on proc = 8 / replica for policy.

@casteryh (Contributor Author) commented:

> Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- it should be "use shared memory or not". I think you should try profiling on proc = 8 / replica for policy.

@JenniferWang For tp=8 on policy it's almost the same, except loading from shared memory to GPU is now faster (because each worker has a smaller shard):

generator_perf/_fetch_weights/total_duration_avg_s: 43.30411313060904
generator_perf/_fetch_weights/total_duration_max_s: 44.24197581084445
generator_perf/waiting_for_fetch_weights/total_duration_avg_s: 28.329474024591036
generator_perf/waiting_for_fetch_weights/total_duration_max_s: 40.779690923169255
generator_worker_perf/update_weights_from_shared_memory/total_duration_avg_s: 3.1711842537115444
generator_worker_perf/update_weights_from_shared_memory/total_duration_max_s: 3.380359285045415

@casteryh (Contributor Author) commented:

@allenwang28 @JenniferWang PTAL.
Do you want me to enable this by default?

@JenniferWang (Contributor) commented:

> Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- it should be "use shared memory or not". I think you should try profiling on proc = 8 / replica for policy.
>
> @JenniferWang For tp=8 on policy it's almost the same, except loading from shared memory to GPU is now faster (because each worker has a smaller shard):
>
> generator_perf/_fetch_weights/total_duration_avg_s: 43.30411313060904
> generator_perf/_fetch_weights/total_duration_max_s: 44.24197581084445
> generator_perf/waiting_for_fetch_weights/total_duration_avg_s: 28.329474024591036
> generator_perf/waiting_for_fetch_weights/total_duration_max_s: 40.779690923169255
> generator_worker_perf/update_weights_from_shared_memory/total_duration_avg_s: 3.1711842537115444
> generator_worker_perf/update_weights_from_shared_memory/total_duration_max_s: 3.380359285045415

Yes, I was thinking that tp = 8 on policy would be worse without shared memory?

@JenniferWang (Contributor) commented:

I think we should make shared memory the default for CPU-based weight sync, with a flag to turn it off.

This commit fixes multiple memory leak issues in the SharedTensor
implementation by introducing explicit lifecycle management and proper
cleanup patterns.

Key Changes:
1. Fixed __del__ bug: changed hasattr(self, "shm") to check "_shm"
2. Added explicit close() method for releasing shared memory handles
3. Changed tensor from @cached_property to @property with manual caching
4. Added closed state tracking with is_closed property
5. Made tensor access after close() raise RuntimeError (fail-fast)
6. Made get_handle() after close() raise RuntimeError
7. Updated drop() to call close() first, then unlink
8. Added context manager support (__enter__/__exit__)
9. Fixed _WeightFetcher to explicitly close after getting handle
10. Fixed GeneratorWorker to close shared memory after loading weights
11. Optimized SharedTensorHandle.drop() to not create unnecessary instances

Memory Leak Prevention:
- Creators must call close() after getting handle
- Receivers must call close() after using tensor
- One process should call drop() to unlink after all are done
- close() and drop() are idempotent and safe to call multiple times

Documentation:
- Added comprehensive class docstring with lifecycle model
- Documented that cached tensor references become invalid after close()
- Added warnings about not relying on __del__ for cleanup
- Added 12 new tests for close/cleanup behavior

Test Results: 65/65 tests pass with no warnings
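
The lifecycle described above, as a short sketch; it reuses the illustrative SharedTensor from the PR description and the names from this commit (the real code also offers SharedTensorHandle.drop(), which avoids building a throwaway instance):

```python
import torch

# Creator (e.g. a weight-fetcher proc): materialize the weight in shared
# memory, hand out the picklable handle, then close its own mapping.
with SharedTensor(tensor=torch.randn(1024, 1024)) as shared:
    handle = shared.get_handle()  # send this to the consumer process
# close() ran on exit; the block itself is still alive and reachable via handle.

# Receiver (e.g. a generator worker): attach via the handle, copy the data
# out (or load it onto the GPU), and close its mapping when done.
with SharedTensor(handle=handle) as shared:
    param = shared.tensor.clone()

# Exactly one process unlinks the block once all readers are finished.
SharedTensor(handle=handle).drop()
```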
Refactor generator code to use context manager pattern (with statement)
for SharedTensor cleanup instead of explicit close() calls. This provides:

- Clearer intent: context manager makes lifecycle explicit
- Automatic cleanup: ensures close() is called even on exceptions
- More idiomatic Python: standard pattern for resource management

Changes:
- GeneratorWorker.update_weights(): Use 'with' for SharedTensor from handles
- _WeightFetcher.fetch(): Use 'with' when creating SharedTensor and getting handle

The context manager automatically calls close() on exit, making the code
more concise and safer.
@casteryh (Contributor Author) commented:

> Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- it should be "use shared memory or not". I think you should try profiling on proc = 8 / replica for policy.
>
> @JenniferWang For tp=8 on policy it's almost the same, except loading from shared memory to GPU is now faster (because each worker has a smaller shard):
>
> generator_perf/_fetch_weights/total_duration_avg_s: 43.30411313060904
> generator_perf/_fetch_weights/total_duration_max_s: 44.24197581084445
> generator_perf/waiting_for_fetch_weights/total_duration_avg_s: 28.329474024591036
> generator_perf/waiting_for_fetch_weights/total_duration_max_s: 40.779690923169255
> generator_worker_perf/update_weights_from_shared_memory/total_duration_avg_s: 3.1711842537115444
> generator_worker_perf/update_weights_from_shared_memory/total_duration_max_s: 3.380359285045415
>
> Yes, I was thinking that tp = 8 on policy would be worse without shared memory?

Yes, I believe it was ~100 seconds without shared memory for tp=8, but I have some problems with my Slurm node and can't test right now.

Change prefetch_weights_to_shm from False to True to enable the new
shared memory-based weight prefetching feature by default.
@casteryh (Contributor Author) commented:

> I think we should make shared memory the default for CPU-based weight sync, with a flag to turn it off.

done

Remove qwen3_32b_experimental.yaml as the shared memory weight prefetching
feature is now enabled by default and no longer experimental.
@casteryh (Contributor Author) commented:

fixed a memory leak

        shm.close()
        shm.unlink()
    except Exception:
        pass
Contributor:

What's the consideration behind swallowing the exceptions in cleaning up the resource?

@casteryh (Contributor Author):

To make this idempotent and safe to call from multiple processes. Open to other ideas.
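
One alternative, sketched here rather than taken from the PR: keep the cleanup idempotent by treating an already-removed segment as success while still surfacing unexpected failures. A later commit in this PR moves in a similar direction by logging errors in close() instead of silently swallowing them.

```python
import logging
from multiprocessing import shared_memory

logger = logging.getLogger(__name__)


def unlink_quietly(name: str) -> None:
    """Idempotent unlink: safe to call from several processes or several times."""
    try:
        shm = shared_memory.SharedMemory(name=name)
        shm.close()
        shm.unlink()
    except FileNotFoundError:
        pass  # someone else already unlinked it; that's the idempotent case
    except Exception:
        logger.exception("Failed to clean up shared memory segment %s", name)
```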

@pbontrager (Contributor) left a comment:

This is awesome! I think we'll want to go over how we're doing prefetch again after some of this is upstreamed to torchstore, but otherwise it looks great. I wonder if this is too risky of a change to make before PTC though?

)


class _WeightFetcher(ForgeActor):
Contributor:

I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious if this actually has to be a separate process, since this is an async method and I would think most of the time it's waiting on ts.get.

@casteryh (Contributor Author):

This has to be a separate actor because it has to be launched in a separate process.

param_key = get_param_key(version, name)
param = await ts.get(param_key)
# Use context manager to ensure cleanup after getting handle
with SharedTensor(tensor=param) as shared_tensor:
Contributor:

Is the plan to move this to TS and hide the rdma/shared memory logic from the user?

@casteryh (Contributor Author):

Hopefully yes.

)
self._start_processing()
fetcher_procs = this_host().spawn_procs(
per_host={"procs": self.n_fetcher_procs}
Contributor:

I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?

engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
use_dcp_for_weight_sync: bool | None = None
prefetch_weights_to_shm: bool = True
Contributor:

In general we should try to avoid changing the "public" API when we expect to quickly change the backend again. After launch we should try to keep this in mind.

@casteryh (Contributor Author):

agreed.

Changed exception handling in SharedTensor methods:
- close(): Changed from silently logging a warning to properly logging
  an error with the exception message. This provides better visibility
  when shared memory cleanup fails, while still not raising exceptions
  that could cause issues in cleanup paths (like __del__).
- drop(): Fixed exception chaining to use 'raise ... from e' for better
  error traceability and to comply with flake8 B904.

Test plan:
- Ran pytest tests/unit_tests/util/test_shared_tensor.py
- 64 tests passed (1 pre-existing segfault in test_multiprocess_bidirectional)
@casteryh (Contributor Author) commented:

> I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?

Yes, it's 2x faster than with 1 process. I haven't tuned this parameter too much though. @pbontrager
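
A rough sketch of why several fetcher procs help: each proc issues its own torchstore reads, so the transfers proceed in parallel across processes instead of being funneled through (or breaking under) a single proc's asyncio.gather. The call style and the fetch signature below are illustrative, not monarch's or the PR's exact API:

```python
import asyncio


async def fetch_all(fetchers, param_names: list[str], version: int) -> dict:
    # Round-robin the parameter keys across the fetcher processes.
    shards = [param_names[i::len(fetchers)] for i in range(len(fetchers))]

    # Each fetcher pulls its shard from torchstore into shared memory and
    # returns a {param_name: SharedTensorHandle} mapping for its keys.
    results = await asyncio.gather(
        *(fetcher.fetch(version, names) for fetcher, names in zip(fetchers, shards))
    )

    handles: dict = {}
    for partial in results:
        handles.update(partial)
    return handles
```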

@casteryh (Contributor Author) commented:

> This is awesome! I think we'll want to go over how we're doing prefetch again after some of this is upstreamed to torchstore, but otherwise it looks great. I wonder if this is too risky of a change to make before PTC though?

I am testing its stability right now. But fwiw, the current main is not stable / well tested either.

@casteryh (Contributor Author) commented:

> This is awesome! I think we'll want to go over how we're doing prefetch again after some of this is upstreamed to torchstore, but otherwise it looks great. I wonder if this is too risky of a change to make before PTC though?

we can also switch the flag to be False by default

Fixed tests that were using the anti-pattern:
  SharedTensor(handle=handle).tensor

This pattern creates a dangling tensor reference because the SharedTensor
object is immediately garbage collected, causing __del__ to close the
shared memory, which invalidates the tensor reference.

Changed all multiprocess worker functions to use context managers:
- test_multiprocess_read
- test_multiprocess_write
- test_multiprocess_bidirectional
- test_to_shared_tensor_multiprocess
- test_multiple_receivers_close_independently

Test plan:
- All 65 tests now pass (previously 1 segfault)
- python -m pytest tests/unit_tests/util/test_shared_tensor.py -v
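
Spelling out the anti-pattern versus the fix (names as in this PR; a sketch, not the actual test code):

```python
# Dangling view: the temporary SharedTensor is garbage-collected right away,
# its __del__ closes the shared memory, and `t` points at a closed mapping.
t = SharedTensor(handle=handle).tensor

# Fixed: keep the wrapper alive while the view is in use, and copy the data
# out if it must outlive the mapping.
with SharedTensor(handle=handle) as shared:
    t = shared.tensor.clone()
```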
@casteryh (Contributor Author) commented:

Landing this, as the apparent leak is in torchstore and not caused by this PR.

With the fix in torchstore:
[image]
https://meta.wandb.io/torchforge/grpo-training/runs/huufaif1

@casteryh casteryh merged commit 6899e95 into main Oct 18, 2025
8 of 9 checks passed