Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,8 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
// Wraps an inner-backend quantized tensor as a quantized tensor of this backend.
// The body hands the value back untouched, so no conversion work happens here.
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
tensor
}

/// Triggers a cleanup pass over the autodiff graphs.
///
/// Delegates entirely to `GraphCleaner::cleanup_orphaned_entries`.
fn graph_cleanup() {
    use crate::runtime::graph::GraphCleaner;

    GraphCleaner::cleanup_orphaned_entries();
}
}
37 changes: 33 additions & 4 deletions crates/burn-autodiff/src/runtime/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct GraphLocator {
///
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
/// first created.
pub(crate) struct Graph {
struct Graph {
origin: NodeId,
state: Mutex<GraphState>,
}
Expand Down Expand Up @@ -121,12 +121,41 @@ impl AutodiffClient for GraphMutexClient {
}
}

struct GraphCleaner<'a> {
pub(crate) struct GraphCleaner<'a> {
guard: MutexGuard<'a, Option<GraphLocator>>,
}

impl<'a> GraphCleaner<'a> {
fn cleanup_orphaned_entries() {
pub(crate) fn cleanup_orphaned_entries() {
// Extra cleanup pass: for every distinct graph, let its memory manager free the
// nodes it reports as unavailable and purge the matching bookkeeping entries.
{
    // Snapshot one handle per distinct graph (keyed by its origin) while holding
    // the global lock, then release the lock at the end of this inner scope.
    //
    // The locator is stored as an `Option`; when it is `None` there is nothing
    // to visit, so the whole pass is skipped instead of panicking on `unwrap`.
    let graphs_to_visit = {
        let graph_locator = crate::runtime::graph::STATE.lock();
        graph_locator.as_ref().map(|locator| {
            let mut graphs_to_visit = HashMap::new();
            for (_node_id, graph) in &locator.graphs {
                graphs_to_visit
                    .entry(graph.origin)
                    .or_insert_with(|| Arc::clone(graph));
            }
            graphs_to_visit
        })
    };

    if let Some(graphs_to_visit) = graphs_to_visit {
        // NOTE(review): `GraphCleaner` owns a guard over the locator, so it is
        // presumably holding the global lock while each graph's state is locked
        // below — confirm this lock ordering matches the rest of the module.
        let mut cleaner = crate::runtime::graph::GraphCleaner::init();
        for (_graph_origin, graph) in graphs_to_visit {
            let mut state = graph.state.lock();
            let server = &mut state.server;
            server
                .memory_management
                .free_unavailable_nodes(|node_id: &NodeId| {
                    // Every freed node is removed from the server's step and
                    // checkpoint-builder maps and reported to the cleaner.
                    server.steps.remove(node_id);
                    server.actions_builder.remove(node_id);
                    cleaner.clean(node_id);
                });
        }
        // Drop the cleaner (and the guard it owns) before the code that follows.
        drop(cleaner);
    }
}

let graphs = {
// Get the available graphs and release the lock
match STATE.lock().as_ref() {
Expand Down Expand Up @@ -185,7 +214,7 @@ impl GraphLocator {
/// # Returns
///
/// An `Arc<Graph>` representing the selected or merged graph.
pub(crate) fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
fn select(&mut self, node: NodeId, parents: &[Parent]) -> Arc<Graph> {
match self.analyse(node, parents) {
GraphAnalysis::NoCollision(graph) => {
if graph.origin != node {
Expand Down
30 changes: 10 additions & 20 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use crate::{
graph::Parent,
tensor::NodeRefCount,
};
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
use core::mem;
use alloc::{borrow::ToOwned, sync::Arc, vec::Vec};

#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
Expand Down Expand Up @@ -54,9 +53,8 @@ impl GraphMemoryManagement {
/// This function proceeds in three steps; each step must complete for all leaves
/// before the next one starts. It then deletes whatever can be safely deleted.
pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
let leaves = self.leaves.clone();
let mut new_leaves = HashSet::new();
let mut deletables = Vec::new();
// Leaves cache to avoid having a ref to self
let leaves = core::mem::take(&mut self.leaves);

// When consuming nodes with a backward pass, some other backward passes become
// unavailable because some of their parents have been consumed. They are
Expand All @@ -71,14 +69,9 @@ impl GraphMemoryManagement {
// hence the need to iterate on all leaves.
self.useful_propagation(leaves.clone());

// New leaves are the roots of a useful backward sub-tree.
// Add new leaves as the roots of useful backward sub-tree.
// Deletables are everything not marked as useful.
for leaf in leaves {
self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
}

// Replace leaves by the new ones and delete everything not useful anymore
mem::swap(&mut self.leaves, &mut new_leaves);
let mut deletables = self.new_leaves_and_deletables(leaves.into_iter().collect());

self.clear_unused_roots(&mut deletables);

Expand Down Expand Up @@ -215,14 +208,9 @@ impl GraphMemoryManagement {
}
}

fn identify_leaves_and_deletables(
&self,
leaf_id: NodeId,
new_leaves: &mut HashSet<NodeId>,
to_delete: &mut Vec<NodeId>,
) {
fn new_leaves_and_deletables(&mut self, mut to_visit: Vec<NodeId>) -> Vec<NodeId> {
let mut to_delete: Vec<NodeId> = Vec::new();
let mut visited = HashSet::new();
let mut to_visit = vec![leaf_id];

while let Some(node_id) = to_visit.pop() {
visited.insert(node_id);
Expand All @@ -233,7 +221,8 @@ impl GraphMemoryManagement {
.expect("Node should have status")
{
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
// New leaves are the roots of a useful backward sub-tree.
self.leaves.insert(node_id);
}
_ => {
to_delete.push(node_id);
Expand All @@ -252,6 +241,7 @@ impl GraphMemoryManagement {
}
};
}
to_delete
}

fn is_referenced(&self, node_id: NodeId) -> bool {
Expand Down
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/runtime/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use alloc::vec::Vec;

#[derive(Default)]
pub struct AutodiffServer {
steps: HashMap<NodeId, StepBoxed>,
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
memory_management: GraphMemoryManagement,
pub(crate) steps: HashMap<NodeId, StepBoxed>,
pub(crate) actions_builder: HashMap<NodeId, CheckpointerBuilder>,
pub(crate) memory_management: GraphMemoryManagement,
}

/// Defines how nodes are clean.
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-backend/src/backend/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,7 @@ pub trait AutodiffBackend: Backend {
///
/// The autodiff backend tensor.
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

/// Sweeps over all autodiff graphs and removes unused nodes.
fn graph_cleanup();
}
Loading