Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion crates/burn-autodiff/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::{
NodeId,
checkpoint::strategy::{CheckpointStrategy, NoCheckpointing},
collections::HashMap,
grads::Gradients,
runtime::AutodiffClient,
tensor::AutodiffTensor,
};
use alloc::{format, string::String};
use alloc::{format, string::String, sync::Arc};
use burn_tensor::{
backend::{AutodiffBackend, Backend},
ops::{BoolTensor, IntTensor, QuantizedTensor},
Expand Down Expand Up @@ -124,4 +126,33 @@ impl<B: Backend, C: CheckpointStrategy> AutodiffBackend for Autodiff<B, C> {
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self> {
tensor
}

fn graph_cleanup() {
let graphs_to_visit = {
let graph_locator = crate::runtime::graph::STATE.lock();
let graph_locator = graph_locator.as_ref().unwrap();
let mut graphs_to_visit = HashMap::new();
for (_node_id, graph) in &graph_locator.graphs {
graphs_to_visit
.entry(graph.origin)
.or_insert_with(|| Arc::clone(graph));
}
graphs_to_visit
};

use crate::runtime::NodeCleaner;
let mut cleaner = crate::runtime::graph::GraphCleaner::init();
cleaner.cleanup_orphaned_entries();
for (_graph_origin, graph) in graphs_to_visit {
let mut state = graph.state.lock().unwrap();
let server = &mut state.server;
server
.memory_management
.free_unavailable_nodes(|node_id: &NodeId| {
server.steps.remove(node_id);
server.actions_builder.remove(node_id);
cleaner.clean(node_id);
});
}
}
Comment on lines +130 to +157
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep the implementation more local to the GraphMutexClient, and simply call the cleanup procedure here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That makes sense as the (autograd) device isn't used at all. As I think this change is not really demanded (I'll be closing the issue and PR), I'll leave the adaption for posterity (in the case the PR is intended to be merged).

}
16 changes: 8 additions & 8 deletions crates/burn-autodiff/src/runtime/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct GraphMutexClient;
/// Multiple node ids can point to the same graph, where the autodiff graph is stored.
#[derive(Default)]
pub struct GraphLocator {
graphs: HashMap<NodeId, Arc<Graph>>,
pub(crate) graphs: HashMap<NodeId, Arc<Graph>>,
/// We keep a mapping of each original node id (graph id) => all nodes that point to that graph.
/// This is to ensure that when merging graphs, we correctly move all previous graphs to
/// the new merged one.
Expand All @@ -48,13 +48,13 @@ pub struct GraphLocator {
/// Each `Graph` contains an [AutodiffServer] and the original [NodeId] where the server was
/// first created.
pub(crate) struct Graph {
origin: NodeId,
state: Mutex<GraphState>,
pub(crate) origin: NodeId,
pub(crate) state: Mutex<GraphState>,
}

#[derive(Default)]
struct GraphState {
server: AutodiffServer,
pub(crate) struct GraphState {
pub(crate) server: AutodiffServer,
}

impl core::fmt::Debug for Graph {
Expand All @@ -65,7 +65,7 @@ impl core::fmt::Debug for Graph {
}
}

static STATE: spin::Mutex<Option<GraphLocator>> = spin::Mutex::new(None);
pub(crate) static STATE: spin::Mutex<Option<GraphLocator>> = spin::Mutex::new(None);

impl GraphMutexClient {
/// Retrieves or creates a graph for the given [NodeId] and parent dependencies.
Expand Down Expand Up @@ -133,12 +133,12 @@ impl AutodiffClient for GraphMutexClient {
}
}

struct GraphCleaner<'a> {
pub(crate) struct GraphCleaner<'a> {
guard: spin::MutexGuard<'a, Option<GraphLocator>>,
}

impl<'a> GraphCleaner<'a> {
fn cleanup_orphaned_entries(&mut self) {
pub(crate) fn cleanup_orphaned_entries(&mut self) {
if let Some(state) = self.guard.as_mut() {
state.cleanup_untracked();
}
Expand Down
30 changes: 10 additions & 20 deletions crates/burn-autodiff/src/runtime/memory_management.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ use crate::{
graph::Parent,
tensor::NodeRefCount,
};
use alloc::{borrow::ToOwned, sync::Arc, vec, vec::Vec};
use core::mem;
use alloc::{borrow::ToOwned, sync::Arc, vec::Vec};

#[derive(Default, Debug)]
pub struct GraphMemoryManagement {
Expand Down Expand Up @@ -54,9 +53,8 @@ impl GraphMemoryManagement {
/// This function goes into three steps, which must happen for all leaves
/// before going into the next step. Then it deletes what can be safely deleted
pub(crate) fn free_unavailable_nodes(&mut self, mut on_free_graph: impl FnMut(&NodeId)) {
let leaves = self.leaves.clone();
let mut new_leaves = HashSet::new();
let mut deletables = Vec::new();
// Leaves cache to avoid having a ref to self
let leaves = core::mem::take(&mut self.leaves);

// When consuming nodes with a backward pass, some other backward passes become
// unavailable because some of their parents have been consumed. They are
Expand All @@ -71,14 +69,9 @@ impl GraphMemoryManagement {
// hence the need to iterate on all leaves.
self.useful_propagation(leaves.clone());

// New leaves are the roots of a useful backward sub-tree.
// Add new leaves as the roots of useful backward sub-tree.
// Deletables are everything not marked as useful.
for leaf in leaves {
self.identify_leaves_and_deletables(leaf, &mut new_leaves, &mut deletables);
}

// Replace leaves by the new ones and delete everything not useful anymore
mem::swap(&mut self.leaves, &mut new_leaves);
let mut deletables = self.new_leaves_and_deletables(leaves.into_iter().collect());

self.clear_unused_roots(&mut deletables);

Expand Down Expand Up @@ -205,14 +198,9 @@ impl GraphMemoryManagement {
}
}

fn identify_leaves_and_deletables(
&self,
leaf_id: NodeId,
new_leaves: &mut HashSet<NodeId>,
to_delete: &mut Vec<NodeId>,
) {
fn new_leaves_and_deletables(&mut self, mut to_visit: Vec<NodeId>) -> Vec<NodeId> {
let mut to_delete: Vec<NodeId> = Vec::new();
let mut visited = HashSet::new();
let mut to_visit = vec![leaf_id];

while let Some(node_id) = to_visit.pop() {
visited.insert(node_id);
Expand All @@ -223,7 +211,8 @@ impl GraphMemoryManagement {
.expect("Node should have status")
{
NodeMemoryStatus::Useful => {
new_leaves.insert(node_id);
// New leaves are the roots of a useful backward sub-tree.
self.leaves.insert(node_id);
}
_ => {
to_delete.push(node_id);
Expand All @@ -242,6 +231,7 @@ impl GraphMemoryManagement {
}
};
}
to_delete
}

fn is_referenced(&self, node_id: NodeId) -> bool {
Expand Down
1 change: 1 addition & 0 deletions crates/burn-autodiff/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ mod server;

pub mod graph;
pub use client::*;
pub(crate) use server::NodeCleaner;
6 changes: 3 additions & 3 deletions crates/burn-autodiff/src/runtime/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use alloc::vec::Vec;

#[derive(Default)]
pub struct AutodiffServer {
steps: HashMap<NodeId, StepBoxed>,
actions_builder: HashMap<NodeId, CheckpointerBuilder>,
memory_management: GraphMemoryManagement,
pub(crate) steps: HashMap<NodeId, StepBoxed>,
pub(crate) actions_builder: HashMap<NodeId, CheckpointerBuilder>,
pub(crate) memory_management: GraphMemoryManagement,
}

/// Defines how nodes are clean.
Expand Down
3 changes: 3 additions & 0 deletions crates/burn-tensor/src/tensor/backend/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,7 @@ pub trait AutodiffBackend: Backend {
///
/// The autodiff backend tensor.
fn q_from_inner(tensor: QuantizedTensor<Self::InnerBackend>) -> QuantizedTensor<Self>;

/// Sweeps over all autodiff graphs and remove unused nodes.
fn graph_cleanup();
}