
Conversation

@swfsql
Contributor

@swfsql swfsql commented Nov 4, 2025

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Changes

  • Adds the graph_cleanup method to the burn::tensor::backend::AutodiffBackend trait.
  • Also changes how the ordering of node removal is determined, ensuring that all leaf nodes are removed first (see the sketch below).
    • This is not strictly necessary, since a given parent can only have a single child, but I think it is a worthwhile small change.
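
For illustration, here is a minimal, self-contained sketch of what a leaf-first removal order could look like. This is not the PR's actual implementation; the node-id type and the parent map are stand-ins.

use std::collections::{HashMap, HashSet};

/// Returns node ids ordered so that leaves (nodes that no remaining node lists
/// as a parent) come first. Assumes the graph is a DAG, as autodiff graphs are.
fn leaf_first_order(parents: &HashMap<u32, Vec<u32>>) -> Vec<u32> {
    let mut remaining = parents.clone();
    let mut order = Vec::new();
    while !remaining.is_empty() {
        // A leaf is a node that no remaining node references as a parent.
        let referenced: HashSet<u32> = remaining.values().flatten().copied().collect();
        let leaves: Vec<u32> = remaining
            .keys()
            .copied()
            .filter(|id| !referenced.contains(id))
            .collect();
        for id in &leaves {
            remaining.remove(id);
        }
        order.extend(leaves);
    }
    order
}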

Testing

Example of usage:

// (assuming the same code as in the linked issue)
AutoB::graph_cleanup(); // either after the tensor is dropped, or after the loop

Then, by monitoring the system's RAM usage, the memory usage no longer grows (when the call is placed inside the loop). There are also no vector/hashmap/hashset re-allocations, which may improve runtime slightly.
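
For reference, here is a fuller sketch of the usage, assuming AutoB is an Autodiff-wrapped backend as in the linked issue and that the graph_cleanup method from this PR is available:

use burn::tensor::{backend::AutodiffBackend, Tensor};

fn stress_test<AutoB: AutodiffBackend>(train_device: &AutoB::Device) {
    for _ in 0..1_000_000 {
        let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], train_device);
        let a = a.require_grad(); // an autodiff node is created
        drop(a); // the tensor is dropped, but its graph would otherwise persist
        // Proposed in this PR: sweep graph entries that are no longer reachable.
        AutoB::graph_cleanup();
    }
}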

@codecov

codecov bot commented Nov 4, 2025

Codecov Report

❌ Patch coverage is 25.80645% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 65.40%. Comparing base (1c5c777) to head (8f4aa51).
⚠️ Report is 18 commits behind head on main.

Files with missing lines                Patch %   Lines
crates/burn-autodiff/src/backend.rs       0.00%   23 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3977      +/-   ##
==========================================
- Coverage   65.42%   65.40%   -0.03%     
==========================================
  Files        1183     1183              
  Lines      139251   139266      +15     
==========================================
- Hits        91103    91084      -19     
- Misses      48148    48182      +34     

☔ View full report in Codecov by Sentry.

Member

@laggui laggui left a comment


A little late on the review sorry 😅

This could be useful to expose for users that want more fine-grained control over the cleanup, especially if they know they are creating tracked autodiff tensors that are not used.

Though I am not sure if this is a very widespread pattern, to explicitly create tensors to require grad without actually using them 🤔

As per your example in the linked issue:

let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
let a = a.require_grad(); // an autodiff node is created
drop(a); // the tensor is dropped but its unusable graph persists

If you don't explicitly mark that tensor to require grad, then it will not be tracked and thus should be cleaned up with the current "orphaned" cleanup strategy.
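
For contrast, a small sketch of the untracked case described above, reusing the same AutoB and train_device as in the example:

let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
// No `require_grad()` here: the tensor is not tracked, so once dropped it
// should be handled by the current "orphaned" cleanup strategy.
drop(a);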

Comment on lines +130 to +157
fn graph_cleanup() {
    let graphs_to_visit = {
        let graph_locator = crate::runtime::graph::STATE.lock();
        let graph_locator = graph_locator.as_ref().unwrap();
        let mut graphs_to_visit = HashMap::new();
        for (_node_id, graph) in &graph_locator.graphs {
            graphs_to_visit
                .entry(graph.origin)
                .or_insert_with(|| Arc::clone(graph));
        }
        graphs_to_visit
    };

    use crate::runtime::NodeCleaner;
    let mut cleaner = crate::runtime::graph::GraphCleaner::init();
    cleaner.cleanup_orphaned_entries();
    for (_graph_origin, graph) in graphs_to_visit {
        let mut state = graph.state.lock().unwrap();
        let server = &mut state.server;
        server
            .memory_management
            .free_unavailable_nodes(|node_id: &NodeId| {
                server.steps.remove(node_id);
                server.actions_builder.remove(node_id);
                cleaner.clean(node_id);
            });
    }
}
Member


I would keep the implementation more local to the GraphMutexClient, and simply call the cleanup procedure here.
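
As a rough, self-contained illustration of that structure (every name below is a stand-in, not burn's actual internals): the cleanup logic lives on the graph client, and the backend-level function only delegates to it.

use std::collections::HashMap;
use std::sync::Mutex;

// Stand-in for the client that owns the graph state.
struct GraphClient {
    graphs: Mutex<HashMap<u64, Vec<u64>>>, // origin -> node ids (placeholder)
}

impl GraphClient {
    // The sweep is implemented next to the state it mutates.
    fn sweep(&self) {
        let mut graphs = self.graphs.lock().unwrap();
        graphs.retain(|_origin, nodes| !nodes.is_empty());
    }
}

// The backend-level cleanup then simply calls the client's procedure.
fn graph_cleanup(client: &GraphClient) {
    client.sweep();
}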

Contributor Author


Thanks! That makes sense, since the (autograd) device isn't used at all. As this change isn't really in demand (I'll be closing the issue and PR), I'll leave the adaptation here for posterity, in case the PR is ever intended to be merged.

@swfsql
Contributor Author

swfsql commented Nov 13, 2025

Thanks @laggui, I think the majority of situations are solved by your previous solution, and I agree that this PR could actually encourage some anti-patterns (leaving grad-requiring tensors hanging around that never take part in a backward pass). On the other hand, I think this could be useful for debugging -- if a user has a RAM leak and isn't sure whether "hanging nodes" are the cause, this sort of method can help confirm or rule out that possibility -- but actual demand is indeed low.

I think I'll be closing the issue and the PR. They can always be referred to or re-opened if this kind of demand comes up again in the future.

@swfsql swfsql closed this Nov 13, 2025
@laggui
Member

laggui commented Nov 13, 2025

Yeah it could be useful for debugging, I think we should expose more utilities for that.

Though in this case, it might actually only be useful when there is a bug 😅 which hopefully shouldn't happen again. But if we want to add something like this, I would at least put it behind a feature flag so it is not exposed at the general level.

Sounds good! Thanks for the overall feedback nonetheless, and helping us find the initial memory leak issue 🙏

@vadixidav

vadixidav commented Nov 18, 2025

@laggui I ran my code with main and with the branch from this PR merged into main. On main I got horrible performance, starting at about 2.1 s per step and gradually increasing to around 15 s per step before settling there. With this branch merged into main, I call <B as AutodiffBackend>::graph_cleanup(); after each training step, and performance stays at 2.1 seconds per step. It might be a memory leak, but I also think this behavior is generally very useful. I would reconsider opening this, just because I find it so valuable right now. It might be helpful to be able to tell burn, in manual training loops, to clean up the autograd graph between steps. I understand that in theory it isn't needed, but it helps users who find that the autograd graph is blowing up and consuming tons of memory. I have hit this issue on both RCs that have been released so far.

Edit: If you need it, I am maintaining main with graph-sweep merged here for now: https://github.com/oxideye/burn/tree/graph-sweep
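
A minimal sketch of the workflow described above, assuming the graph_cleanup method from this PR is available; the training-step closure is a placeholder:

use burn::tensor::backend::AutodiffBackend;

fn train<B: AutodiffBackend>(num_steps: usize, mut train_step: impl FnMut()) {
    for _ in 0..num_steps {
        train_step(); // forward + backward + optimizer update (placeholder)
        // Sweep leftover autodiff graph state between steps (from this PR).
        B::graph_cleanup();
    }
}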

@laggui
Member

laggui commented Nov 18, 2025

@vadixidav I think I agree actually. As long as we document the usage, I think this can be useful.

I was actually debugging a program yesterday where a bunch of tensors explicitly marked as require grad were never used as part of the backward pass, creating an increasing number of tiny graphs during training that were never freed. And because some steps might require the tensors as part of the backward state, the GPU handles for these tensors were also not being freed.
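
A sketch of the pattern being described (the names are placeholders, and device is assumed to be the autodiff device):

let w: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &device).require_grad();
let _stat = w.clone().sum(); // e.g. used only for logging
// `w` is tracked but never reaches a backward pass, so its tiny graph (and any
// GPU handles kept for the backward state) is never freed.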

I was trying to think of a good way to handle this kind of "orphaned graph" cleanup. But exposing a cleanup procedure similar to the one proposed in this PR would be helpful for more advanced users.

I'll have a look at your code this week to validate, but probably you have a similar behavior to what I described above.

@vadixidav

vadixidav commented Nov 19, 2025

@laggui

I'll have a look at your code this week to validate, but probably you have a similar behavior to what I described above.

Don't worry about it, my code is literally just this PR merged into main. It has no conflicts currently. You can reopen this PR as-is and get the same behavior.

@laggui
Member

laggui commented Nov 19, 2025

Don't worry about it, my code is literally just this PR merged into main. It has no conflicts currently. You can reopen this PR as-is and get the same behavior.

Oh lol my brain shortened the URL when reading on my phone so I thought it linked to a repo (not burn fork 😅). Forget my comment haha

@swfsql swfsql mentioned this pull request Nov 26, 2025
@swfsql
Contributor Author

swfsql commented Nov 26, 2025

Hey @laggui @vadixidav, I wanted to mention that I made a new PR with the graph-sweep branch rebased onto main. I'm also keeping the branch in my local environment and trying to keep it updated, but it turns out that because this PR was closed, all updates are ignored in the PR timeline, and I can no longer re-open it since there was a force-push (rebase). So I opened a new PR ^^' and I'd just like to refer you to it - #4075 - but I agree that the motivation and the preferable API for this kind of functionality are uncertain, so I'm fine with it staying an open-ended (open) PR, if that's ok of course.

@laggui
Member

laggui commented Nov 27, 2025

FYI, not sure if you've seen #4039, which now takes care of all unused graphs, including tensors explicitly marked as require grad that are never used within a backward pass. So technically the manual "sweep" should not be required.

@swfsql
Contributor Author

swfsql commented Nov 27, 2025

Thanks @laggui! Yes, I've updated the sweep to match your changes. I've tested main on the "stress test" (which doesn't call backward, but where I call the cleanup method manually), and main is still leaving some nodes alive. I'm not sure of the reason; it could be the lack of a proper backward call, or some additional Arc ref count, or something similar.

Edit: but on main I think the memory increases more slowly than before. I had to bump the loop to 10 million iterations to reach 1 GB of RAM usage.

@laggui
Member

laggui commented Nov 27, 2025

I've tested main on the "stress test" (which doesn't call a backward, but I call the cleaning method manually)

Ohhh right I forgot that your stress test just never called backward lol. That will not trigger a cleanup, it only happens on backward. But it now also takes care of all unused tensors. As long as your program calls backward at least once for autodiff, it should clean them up (even if not involved in the backward graph).

@swfsql
Contributor Author

swfsql commented Nov 27, 2025

Yes, but sorry @laggui for my insistence lol. I have adapted the test to create two independent tensors, both requiring grad, where only one of them has a backward call. On main, a small leak still occurs (I had to bump the number of iterations to 10 million to be able to notice it).

In this code snippet:

for _ in 0..10_000_000 {
    let a: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
    let b: Tensor<AutoB, 2> = Tensor::zeros([2, 2], &train_device);
    drop(a.require_grad());
    drop(b.require_grad().sum().backward());
}

Development

Successfully merging this pull request may close these issues.

Remove autodiff unused nodes according to dropped tensors
