diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs
index 14b723f4a3..ae6f49a7b9 100644
--- a/crates/burn-autodiff/src/backend.rs
+++ b/crates/burn-autodiff/src/backend.rs
@@ -135,4 +135,8 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
     fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
         tensor
     }
+
+    fn graph_cleanup() {
+        let () = crate::runtime::graph::GraphCleaner::cleanup_orphaned_entries();
+    }
 }
diff --git a/crates/burn-autodiff/src/runtime/graph.rs b/crates/burn-autodiff/src/runtime/graph.rs
index c367873f59..36e4718c52 100644
--- a/crates/burn-autodiff/src/runtime/graph.rs
+++ b/crates/burn-autodiff/src/runtime/graph.rs
@@ -52,7 +52,7 @@ pub struct GraphLocator {
 ///
 /// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
 /// first created.
-pub(crate) struct Graph {
+struct Graph {
     origin: NodeId,
     state: Mutex,
 }
@@ -121,12 +121,41 @@ impl AutodiffClient for GraphMutexClient {
     }
 }
 
-struct GraphCleaner<'a> {
+pub(crate) struct GraphCleaner<'a> {
     guard: MutexGuard<'a, Option<GraphLocator>>,
 }
 
 impl<'a> GraphCleaner<'a> {
-    fn cleanup_orphaned_entries() {
+    pub(crate) fn cleanup_orphaned_entries() {
+        // Extra cleanup procedure
+        {
+            let graphs_to_visit = {
+                let graph_locator = crate::runtime::graph::STATE.lock();
+                let graph_locator = graph_locator.as_ref().unwrap();
+                let mut graphs_to_visit = HashMap::new();
+                for (_node_id, graph) in &graph_locator.graphs {
+                    graphs_to_visit
+                        .entry(graph.origin)
+                        .or_insert_with(|| Arc::clone(graph));
+                }
+                graphs_to_visit
+            };
+
+            let mut cleaner = crate::runtime::graph::GraphCleaner::init();
+            for (_graph_origin, graph) in graphs_to_visit {
+                let mut state = graph.state.lock();
+                let server = &mut state.server;
+                server
+                    .memory_management
+                    .free_unavailable_nodes(|node_id: &NodeId| {
+                        server.steps.remove(node_id);
+                        server.actions_builder.remove(node_id);
+                        cleaner.clean(node_id);
+                    });
+            }
+            drop(cleaner);
+        }
+
         let graphs = {
             // Get the available graphs and release the lock
             match STATE.lock().as_ref() {
@@ -185,7 +214,7 @@ impl GraphLocator {
     /// # Returns
     ///
     /// An `Arc` representing the selected or merged graph.
-    pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
+    fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
         match self.analyse(node, parents) {
             GraphAnalysis::NoCollision(graph) => {
                 if graph.origin != node {
diff --git a/crates/burn-autodiff/src/runtime/memory_management.rs b/crates/burn-autodiff/src/runtime/memory_management.rs
index e2df6be30d..bfed1ef6aa 100644
--- a/crates/burn-autodiff/src/runtime/memory_management.rs
+++ b/crates/burn-autodiff/src/runtime/memory_management.rs
@@ -4,8 +4,7 @@ use crate::{
     graph::Parent,
     tensor::NodeRefCount,
 };
-use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
-use core::mem;
+use alloc::{borrow::ToOwned, sync::Arc, vec::Vec};
 
 #[derive(Default, Debug)]
 pub struct GraphMemoryManagement {
@@ -54,9 +53,8 @@ impl GraphMemoryManagement {
     /// This function goes into three steps, which must happen for all leaves
     /// before going into the next step. Then it deletes what can be safely deleted
     pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
-        let leaves = self.leaves.clone();
-        let mut new_leaves = HashSet::new();
-        let mut deletables = Vec::new();
+        // Take the leaves to avoid holding a reference to self
+        let leaves = core::mem::take(&mut self.leaves);
 
         // When consuming nodes with a backward pass, some other backward passes become
         // unavailable because some of their parents have been consumed. They are
@@ -71,14 +69,9 @@
         // hence the need to iterate on all leaves.
         self.useful_propagation(leaves.clone());
 
-        // New leaves are the roots of a useful backward sub-tree.
+        // Add new leaves as the roots of a useful backward sub-tree.
         // Deletables are everything not marked as useful.
-        for leaf in leaves {
-            self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
-        }
-
-        // Replace leaves by the new ones and delete everything not useful anymore
-        mem::swap(&mut self.leaves, &mut new_leaves);
+        let mut deletables = self.new_leaves_and_deletables(leaves.into_iter().collect());
 
         self.clear_unused_roots(&mut deletables);
 
@@ -215,14 +208,9 @@
         }
     }
 
-    fn identify_leaves_and_deletables(
-        &self,
-        leaf_id: NodeId,
-        new_leaves: &mut HashSet<NodeId>,
-        to_delete: &mut Vec<NodeId>,
-    ) {
+    fn new_leaves_and_deletables(&mut self, mut to_visit: Vec<NodeId>) -> Vec<NodeId> {
+        let mut to_delete: Vec<NodeId> = Vec::new();
         let mut visited = HashSet::new();
-        let mut to_visit = vec![leaf_id];
 
         while let Some(node_id) = to_visit.pop() {
             visited.insert(node_id);
@@ -233,7 +221,8 @@
                 .expect("Node should have status")
             {
                 NodeMemoryStatus::Useful => {
-                    new_leaves.insert(node_id);
+                    // New leaves are the roots of a useful backward sub-tree.
+                    self.leaves.insert(node_id);
                 }
                 _ => {
                     to_delete.push(node_id);
@@ -252,6 +241,7 @@
                 }
             };
         }
+        to_delete
     }
 
     fn is_referenced(&self, node_id: NodeId) -> bool {
diff --git a/crates/burn-autodiff/src/runtime/server.rs b/crates/burn-autodiff/src/runtime/server.rs
index d3e7c41e14..738b215645 100644
--- a/crates/burn-autodiff/src/runtime/server.rs
+++ b/crates/burn-autodiff/src/runtime/server.rs
@@ -14,9 +14,9 @@ use alloc::vec::Vec;
 
 #[derive(Default)]
 pub struct AutodiffServer {
-    steps: HashMap,
-    actions_builder: HashMap,
-    memory_management: GraphMemoryManagement,
+    pub(crate) steps: HashMap,
+    pub(crate) actions_builder: HashMap,
+    pub(crate) memory_management: GraphMemoryManagement,
 }
 
 /// Defines how nodes are clean.
diff --git a/crates/burn-backend/src/backend/base.rs b/crates/burn-backend/src/backend/base.rs
index 98d9a27663..afa112fc2f 100644
--- a/crates/burn-backend/src/backend/base.rs
+++ b/crates/burn-backend/src/backend/base.rs
@@ -332,4 +332,7 @@ pub trait AutodiffBackend: Backend {
     ///
     /// The autodiff backend tensor.
     fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;
+
+    /// Sweeps over all autodiff graphs and removes unused nodes.
+    fn graph_cleanup();
 }
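
Below is a minimal, illustrative sketch of how a caller might use the new `AutodiffBackend::graph_cleanup()` hook added by this patch. The training-loop shape, the every-100-iterations policy, and the `burn_backend::backend::AutodiffBackend` import path are assumptions made for the example, not part of the diff.

// Illustrative sketch only. The import path and the cleanup interval are
// assumptions; only the `graph_cleanup` method itself comes from this patch.
use burn_backend::backend::AutodiffBackend;

fn train<B: AutodiffBackend>(num_iters: usize) {
    for iter in 0..num_iters {
        // ... forward pass, backward pass, optimizer step ...

        // Periodically sweep every autodiff graph and drop nodes that can no
        // longer contribute to a backward pass.
        if iter % 100 == 99 {
            B::graph_cleanup();
        }
    }
}

On the `Autodiff` backend this call forwards to `GraphCleaner::cleanup_orphaned_entries()`, which locks each registered graph and runs `free_unavailable_nodes` so that the corresponding entries in `steps` and `actions_builder` are removed from the `AutodiffServer`.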