diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index 07b2355774..5478441491 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -1,10 +1,12 @@ use crate::{ + NodeId, checkpoint::strategy::{CheckpointStrategy, NoCheckpointing}, + collections::HashMap, grads::Gradients, runtime::AutodiffClient, tensor::AutodiffTensor, }; -use alloc::{format, string::String}; +use alloc::{format, string::String, sync::Arc}; use burn_tensor::{ backend::{AutodiffBackend, Backend}, ops::{BoolTensor, IntTensor, QuantizedTensor}, @@ -124,4 +126,39 @@ impl AutodiffBackend for Autodiff { fn q_from_inner(tensor: QuantizedTensor) -> QuantizedTensor { tensor } + + fn graph_cleanup() { + // Snapshot one handle per distinct graph (keyed by origin) under the locator lock. + let graphs_to_visit = { + let graph_locator = crate::runtime::graph::STATE.lock(); + // `STATE` starts as `None`: nothing to clean then, so return instead of panicking. + let graph_locator = match graph_locator.as_ref() { + Some(graph_locator) => graph_locator, + None => return, + }; + let mut graphs_to_visit = HashMap::new(); + for (_node_id, graph) in &graph_locator.graphs { + graphs_to_visit + .entry(graph.origin) + .or_insert_with(|| Arc::clone(graph)); + } + graphs_to_visit + }; + + use crate::runtime::NodeCleaner; + let mut cleaner = crate::runtime::graph::GraphCleaner::init(); + cleaner.cleanup_orphaned_entries(); + // Free unavailable nodes of every graph, dropping their `steps`/`actions_builder` entries. + for (_graph_origin, graph) in graphs_to_visit { + let mut state = graph.state.lock().unwrap(); + let server = &mut state.server; + server + .memory_management + .free_unavailable_nodes(|node_id: &NodeId| { + server.steps.remove(node_id); + server.actions_builder.remove(node_id); + cleaner.clean(node_id); + }); + } + } } diff --git a/crates/burn-autodiff/src/runtime/graph.rs b/crates/burn-autodiff/src/runtime/graph.rs index db61060b4a..37cff0ab90 100644 --- a/crates/burn-autodiff/src/runtime/graph.rs +++ b/crates/burn-autodiff/src/runtime/graph.rs @@ -35,7 +35,7 @@ pub struct GraphMutexClient; /// Multiple node ids can point to the same graph, where the autodiff graph is stored. 
#[derive(Default)] pub struct GraphLocator { - graphs: HashMap>, + pub(crate) graphs: HashMap>, /// We keep a mapping of each original node id (graph id) => all nodes that point to that graph. /// This is to ensure that when merging graphs, we correctly move all previous graphs to /// the new merged one. @@ -48,13 +48,13 @@ pub struct GraphLocator { /// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was /// first created. pub(crate) struct Graph { - origin: NodeId, - state: Mutex, + pub(crate) origin: NodeId, + pub(crate) state: Mutex, } #[derive(Default)] -struct GraphState { - server: AutodiffServer, +pub(crate) struct GraphState { + pub(crate) server: AutodiffServer, } impl core::fmt::Debug for Graph { @@ -65,7 +65,7 @@ impl core::fmt::Debug for Graph { } } -static STATE: spin::Mutex> = spin::Mutex::new(None); +pub(crate) static STATE: spin::Mutex> = spin::Mutex::new(None); impl GraphMutexClient { /// Retrieves or creates a graph for the given [NodeId] and parent dependencies. 
@@ -133,12 +133,12 @@ impl AutodiffClient for GraphMutexClient { } } -struct GraphCleaner<'a> { +pub(crate) struct GraphCleaner<'a> { guard: spin::MutexGuard<'a, Option>, } impl<'a> GraphCleaner<'a> { - fn cleanup_orphaned_entries(&mut self) { + pub(crate) fn cleanup_orphaned_entries(&mut self) { if let Some(state) = self.guard.as_mut() { state.cleanup_untracked(); } diff --git a/crates/burn-autodiff/src/runtime/memory_management.rs b/crates/burn-autodiff/src/runtime/memory_management.rs index ad2471569f..cb8f92d111 100644 --- a/crates/burn-autodiff/src/runtime/memory_management.rs +++ b/crates/burn-autodiff/src/runtime/memory_management.rs @@ -4,8 +4,7 @@ use crate::{ graph::Parent, tensor::NodeRefCount, }; -use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec}; -use core::mem; +use alloc::{borrow::ToOwned, sync::Arc, vec::Vec}; #[derive(Default, Debug)] pub struct GraphMemoryManagement { @@ -54,9 +53,8 @@ impl GraphMemoryManagement { /// This function goes into three steps, which must happen for all leaves /// before going into the next step. Then it deletes what can be safely deleted pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) { - let leaves = self.leaves.clone(); - let mut new_leaves = HashSet::new(); - let mut deletables = Vec::new(); + // Take the leaves out of `self` (left empty) to traverse them without borrowing `self`. + let leaves = core::mem::take(&mut self.leaves); // When consuming nodes with a backward pass, some other backward passes become // unavailable because some of their parents have been consumed. They are @@ -71,14 +69,9 @@ impl GraphMemoryManagement { // hence the need to iterate on all leaves. self.useful_propagation(leaves.clone()); - // New leaves are the roots of a useful backward sub-tree. + // Add new leaves as the roots of a useful backward sub-tree. // Deletables are everything not marked as useful. 
- for leaf in leaves { - self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables); - } - - // Replace leaves by the new ones and delete everything not useful anymore - mem::swap(&mut self.leaves, &mut new_leaves); + let mut deletables = self.new_leaves_and_deletables(leaves.into_iter().collect()); self.clear_unused_roots(&mut deletables); @@ -205,14 +198,9 @@ impl GraphMemoryManagement { } } - fn identify_leaves_and_deletables( - &self, - leaf_id: NodeId, - new_leaves: &mut HashSet, - to_delete: &mut Vec, - ) { + fn new_leaves_and_deletables(&mut self, mut to_visit: Vec) -> Vec { + let mut to_delete: Vec = Vec::new(); let mut visited = HashSet::new(); - let mut to_visit = vec![leaf_id]; while let Some(node_id) = to_visit.pop() { visited.insert(node_id); @@ -223,7 +211,8 @@ impl GraphMemoryManagement { .expect("Node should have status") { NodeMemoryStatus::Useful => { - new_leaves.insert(node_id); + // New leaves are the roots of a useful backward sub-tree. + self.leaves.insert(node_id); } _ => { to_delete.push(node_id); @@ -242,6 +231,7 @@ impl GraphMemoryManagement { } }; } + to_delete } fn is_referenced(&self, node_id: NodeId) -> bool { diff --git a/crates/burn-autodiff/src/runtime/mod.rs b/crates/burn-autodiff/src/runtime/mod.rs index fde3a5f71a..63f84a3718 100644 --- a/crates/burn-autodiff/src/runtime/mod.rs +++ b/crates/burn-autodiff/src/runtime/mod.rs @@ -4,3 +4,4 @@ mod server; pub mod graph; pub use client::*; +pub(crate) use server::NodeCleaner; diff --git a/crates/burn-autodiff/src/runtime/server.rs b/crates/burn-autodiff/src/runtime/server.rs index e94baa0e69..6c791ea37c 100644 --- a/crates/burn-autodiff/src/runtime/server.rs +++ b/crates/burn-autodiff/src/runtime/server.rs @@ -14,9 +14,9 @@ use alloc::vec::Vec; #[derive(Default)] pub struct AutodiffServer { - steps: HashMap, - actions_builder: HashMap, - memory_management: GraphMemoryManagement, + pub(crate) steps: HashMap, + pub(crate) actions_builder: HashMap, + pub(crate) 
memory_management: GraphMemoryManagement, } /// Defines how nodes are clean. diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index 9f76897fc6..984e8b21c1 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -277,4 +277,7 @@ pub trait AutodiffBackend: Backend { /// /// The autodiff backend tensor. fn q_from_inner(tensor: QuantizedTensor) -> QuantizedTensor; + + /// Sweeps over all autodiff graphs and removes unused nodes. + fn graph_cleanup(); }