-
Notifications
You must be signed in to change notification settings - Fork 16
shared memory multiprocess prefetch for weight update #430
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
I think it makes little sense to split, see below.
Currently TorchStore RDMA only works with CPU-CPU.
This actually comes automatically once you have separate processes fetching the weights to shared memory.
Will do
Ah maybe it's confusing because I am trying to do two things at once. Basically the speed up comes from 1. multiprocess shared memory (this saves 30 seconds) & 2. prefetch while completing on-the-fly generation (this saves about 10 seconds on average) |
Ahha, yes, I got it now. Okay, so the boolean guard is not "use prefetch or not" -- is should be "use shared memory or not". I think you should try profiling on proc = 8 / replica for policy. |
@JenniferWang
|
@allenwang28 @JenniferWang ptal |
Yes, I was thinking that tp = 8 on policy would be worse without shared memory ? |
I think we should make using shared memory by default for CPU based weight sync, with a flag to turn it off. |
This commit fixes multiple memory leak issues in the SharedTensor implementation by introducing explicit lifecycle management and proper cleanup patterns. Key Changes: 1. Fixed __del__ bug: changed hasattr(self, "shm") to check "_shm" 2. Added explicit close() method for releasing shared memory handles 3. Changed tensor from @cached_property to @Property with manual caching 4. Added closed state tracking with is_closed property 5. Made tensor access after close() raise RuntimeError (fail-fast) 6. Made get_handle() after close() raise RuntimeError 7. Updated drop() to call close() first, then unlink 8. Added context manager support (__enter__/__exit__) 9. Fixed _WeightFetcher to explicitly close after getting handle 10. Fixed GeneratorWorker to close shared memory after loading weights 11. Optimized SharedTensorHandle.drop() to not create unnecessary instances Memory Leak Prevention: - Creators must call close() after getting handle - Receivers must call close() after using tensor - One process should call drop() to unlink after all are done - close() and drop() are idempotent and safe to call multiple times Documentation: - Added comprehensive class docstring with lifecycle model - Documented that cached tensor references become invalid after close() - Added warnings about not relying on __del__ for cleanup - Added 12 new tests for close/cleanup behavior Test Results: 65/65 tests pass with no warnings
Refactor generator code to use context manager pattern (with statement) for SharedTensor cleanup instead of explicit close() calls. This provides: - Clearer intent: context manager makes lifecycle explicit - Automatic cleanup: ensures close() is called even on exceptions - More idiomatic Python: standard pattern for resource management Changes: - GeneratorWorker.update_weights(): Use 'with' for SharedTensor from handles - _WeightFetcher.fetch(): Use 'with' when creating SharedTensor and getting handle The context manager automatically calls close() on exit, making the code more concise and safer.
yes I believe it was ~100 seconds without shared memory for tp=8, but I have some problem with my slurm node and can't test now. |
Change prefetch_weights_to_shm from False to True to enable the new shared memory-based weight prefetching feature by default.
done |
Remove qwen3_32b_experimental.yaml as the shared memory weight prefetching feature is now enabled by default and no longer experimental.
fixed a memory leak |
shm.close() | ||
shm.unlink() | ||
except Exception: | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the consideration behind swallowing the exceptions in cleaning up the resource?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make this idempotent and safe to be called from multiple-processes. Open to other ideas.
Co-authored-by: Jiyue Wang <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is awesome! I think we'll want to go over how we're doing prefetch again after some of this is upstreamed to torchstore, but otherwise it looks great. I wonder if this is too risky of a change to make before PTC thought?
) | ||
|
||
|
||
class _WeightFetcher(ForgeActor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious if this has to actually be a separate process since this is an async method and I would think most of the time it's waiting on ts.get.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has to be a separate actor because it has to be launched in a separate process
param_key = get_param_key(version, name) | ||
param = await ts.get(param_key) | ||
# Use context manager to ensure cleanup after getting handle | ||
with SharedTensor(tensor=param) as shared_tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the plan to move this to TS and hide the rdma/shared memory logic from the user?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hopefully yes.
) | ||
self._start_processing() | ||
fetcher_procs = this_host().spawn_procs( | ||
per_host={"procs": self.n_fetcher_procs} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs) | ||
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) | ||
use_dcp_for_weight_sync: bool | None = None | ||
prefetch_weights_to_shm: bool = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general we should try to avoid changing the "public" api when we expect to quickly change the backend again. After launch we should try to keep this in mind.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
agreed.
Changed exception handling in SharedTensor methods: - close(): Changed from silently logging a warning to properly logging an error with the exception message. This provides better visibility when shared memory cleanup fails, while still not raising exceptions that could cause issues in cleanup paths (like __del__). - drop(): Fixed exception chaining to use 'raise ... from e' for better error traceability and to comply with flake8 B904. Test plan: - Ran pytest tests/unit_tests/util/test_shared_tensor.py - 64 tests passed (1 pre-existing segfault in test_multiprocess_bidirectional)
Yes it's 2x faster than 1 process. I haven't tuned this parameter too much though. @pbontrager |
I am testing its stability right now. But fwiw, the current main is not stable / well tested either. |
we can also switch the flag to be False by default |
Fixed tests that were using the anti-pattern: SharedTensor(handle=handle).tensor This pattern creates a dangling tensor reference because the SharedTensor object is immediately garbage collected, causing __del__ to close the shared memory, which invalidates the tensor reference. Changed all multiprocess worker functions to use context managers: - test_multiprocess_read - test_multiprocess_write - test_multiprocess_bidirectional - test_to_shared_tensor_multiprocess - test_multiple_receivers_close_independently Test plan: - All 65 tests now pass (previously 1 segfault) - python -m pytest tests/unit_tests/util/test_shared_tensor.py -v
landing this as the apparent leak is in torchstore and not caused by this PR. with fix in torchstore: |
What this PR does
Perf
TL;DR e2e weight sync time is now ~50s for QWen3 32b; one training step takes <70s
Tested with