diff --git a/src/compiler_utils.rs b/src/compiler_utils.rs index 401aa311..2b4ad556 100644 --- a/src/compiler_utils.rs +++ b/src/compiler_utils.rs @@ -112,6 +112,7 @@ impl ToIds for () { } } +#[allow(clippy::implicit_hasher)] impl ToIds for FxHashMap { fn to_ids(&self) -> Vec { self.values().flat_map(|i| i.to_ids()).collect() @@ -192,7 +193,7 @@ impl Compiler for Looped { fn compile(&self, graph: &mut Graph, mut remap: T) { let mut linearized = None; loop { - self.0.compile(graph, &mut remap); + let _cur_output = self.0.compile(graph, &mut remap); graph.toposort(); if linearized == graph.linearized_graph { break; @@ -317,7 +318,7 @@ tuple_impls!( // Helpers impl Graph { - /// Add op on the graph, and get back a NewOp + /// Add `op` on the graph, and get back a `NewOp` /// /// ```rust /// use luminal::prelude::*; @@ -337,7 +338,7 @@ impl Graph { num_srcs: 0, } } - /// Add op on the graph, and get back a NewOp. Just like add_op, except a boxed op is expected. + /// Add `op` on the graph, and get back a `NewOp`. Just like `add_op`, except a boxed op is expected. pub fn add_boxed_op(&mut self, op: Box) -> NewOp<'_> { self.linearized_graph = None; NewOp { @@ -407,11 +408,7 @@ impl Graph { } if show_shapes && new_graph.contains_node(id_map[&edge.target()]) - && edge - .weight() - .as_data() - .map(|d| !d.2.is_empty()) - .unwrap_or_default() + && edge.weight().as_data().is_some_and(|d| !d.2.is_empty()) { new_graph .node_weight_mut(id_map[&edge.target()]) @@ -712,7 +709,7 @@ fn backtrack_match( mapping.insert(pattern_root, main_root); let main_parents = get_parents(main_graph, main_root, |e| !e.weight().is_schedule()); 'pattern_loop: for pattern_parent in get_parents(pattern_graph, pattern_root, |_| true) { - for parent in main_parents.iter() { + for parent in &main_parents { if mapping.values().any(|&v| v == *parent) { // This main node was used already, skip it continue; @@ -764,24 +761,21 @@ fn test_node( return false; } for (a, b) in a_sh.iter().zip(b_sh.dims().into_iter()) { - match a.to_usize() { - Some(n) => { - if b.to_usize().map(|i| i != n).unwrap_or(true) { - return false; - } + if let Some(n) = a.to_usize() { + if b.to_usize() != Some(n) { + return false; } - None => { - let c = a - .to_symbols() - .pop() - .expect("Selector dimension must be either a symbol or number"); - if let Some(expected) = shape_map.get(&c) { - if b != *expected { - return false; - } - } else { - shape_map.insert(c, b); + } else { + let c = a + .to_symbols() + .pop() + .expect("Selector dimension must be either a symbol or number"); + if let Some(expected) = shape_map.get(&c) { + if b != *expected { + return false; } + } else { + shape_map.insert(c, b); } } } @@ -955,6 +949,5 @@ pub fn debug() -> bool { std::env::var("DEBUG") .unwrap_or_default() .parse::() - .map(|i| i == 1) - .unwrap_or_default() + .is_ok_and(|i| i == 1) } diff --git a/src/generic_compiler.rs b/src/generic_compiler.rs index 80fbe2b3..56aacbe7 100644 --- a/src/generic_compiler.rs +++ b/src/generic_compiler.rs @@ -34,23 +34,11 @@ impl Compiler for CSE { while eliminated { eliminated = false; let mut srcs_set: HashMap, Vec> = HashMap::new(); - for node in graph.graph.node_indices().collect_vec() { - if graph - .graph - .node_weight(node) - .unwrap() - .as_any() - .is::() - { + for node in graph.collect_node_indices() { + if graph.this_node_is::(node) { continue; } - let srcs = graph - .graph - .edges_directed(node, petgraph::Direction::Incoming) - .filter(|e| !e.weight().is_schedule()) - .sorted_by_key(|e| e.weight().as_data().unwrap().0) - .map(|e| e.source()) - .collect_vec(); + let srcs = graph.get_incomings(node); if let Some(other_nodes) = srcs_set.get(&srcs) { for other_node in other_nodes { @@ -108,40 +96,23 @@ impl Compiler for RemoveSingleReductions { type Output = (); fn compile(&self, graph: &mut Graph, mut ids: T) { for node in graph.graph.node_indices().collect::>() { - let dim = if let Some(red) = graph - .graph - .node_weight(node) - .unwrap() - .as_any() - .downcast_ref::() - { + let dim = if let Some(red) = graph.get_this_node_is::(node) { Some(red.0) } else { - graph - .graph - .node_weight(node) - .unwrap() - .as_any() - .downcast_ref::() - .map(|red| red.0) + graph.get_this_node_is::(node).map(|red| red.0) }; if let Some(dim) = dim { if graph .graph .edges_directed(node, Direction::Incoming) .next() - .map(|e| { - e.weight() - .as_data() - .map(|w| { - w.2.dims[w.2.indexes[dim]] - .to_usize() - .map(|i| i == 1) - .unwrap_or_default() - }) - .unwrap_or_default() + .is_some_and(|e| { + e.weight().as_data().is_some_and(|w| { + w.2.dims[w.2.indexes[dim]] + .to_usize() + .is_some_and(|i| i == 1) + }) }) - .unwrap_or_default() { let upstream = graph .graph @@ -243,6 +214,7 @@ pub struct ArithmeticElimination; impl Compiler for ArithmeticElimination { type Output = (); + #[allow(clippy::too_many_lines)] fn compile(&self, graph: &mut Graph, mut ids: T) { // x + 0, 0 + x let zero = constant(0.); diff --git a/src/graph.rs b/src/graph.rs index 76aa79be..1d10a08a 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -8,8 +8,12 @@ use std::{ use super::compiler_utils::{ToIds, ToIdsMut}; use colored::Colorize; use itertools::Itertools; -use petgraph::{stable_graph::StableGraph, visit::EdgeRef, Direction}; -use rustc_hash::{FxHashMap, FxHashSet}; +use petgraph::{ + stable_graph::StableGraph, + visit::{EdgeRef, IntoEdgeReferences}, + Direction, +}; +use rustc_hash::{FxBuildHasher, FxHashMap, FxHashSet}; pub type StorageGraph = StableGraph, Dependency>; @@ -27,7 +31,8 @@ pub struct Graph { pub graph: StorageGraph, /// Tensors marked in this set will not get deleted when the graph is ran pub no_delete: FxHashSet, - /// Tensors marked in this set need to be retrieved later (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently) + /// Tensors marked in this set need to be retrieved later + /// (mostly for optimizers to insert copy back calls, the graph itself doesn't treat these differently) pub to_retrieve: FxHashMap, /// A cached list of nodes to run, source nodes, and view nodes to delete after execution. #[allow(clippy::type_complexity)] @@ -205,7 +210,7 @@ impl Graph { get_source_tensors(&self.no_delete, &mut self.tensors, src_ids, &consumers); // Substitute in the dyn dims - for (_, st) in srcs.iter_mut() { + for (_, st) in &mut srcs { st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack); } @@ -230,7 +235,7 @@ impl Graph { self.toposort(); } let mut dim_stack = Vec::new(); - for (node, src_ids) in self.linearized_graph.as_ref().unwrap().iter() { + for (node, src_ids) in self.linearized_graph.as_ref().unwrap() { if self.tensors.contains_key(&(*node, 0)) { continue; } @@ -245,7 +250,7 @@ impl Graph { .collect_vec(); // Substitute in the dyn dims - for (_, st) in srcs.iter_mut() { + for (_, st) in &mut srcs { st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack); } @@ -284,7 +289,7 @@ impl Graph { (width.saturating_sub(" Executing ".len())) / 2 ); let start = std::time::Instant::now(); - for (node, src_ids) in self.linearized_graph.as_ref().unwrap().iter() { + for (node, src_ids) in self.linearized_graph.as_ref().unwrap() { if self.tensors.contains_key(&(*node, 0)) { continue; } @@ -295,7 +300,7 @@ impl Graph { get_source_tensors(&self.no_delete, &mut self.tensors, src_ids, &consumers); // Substitute in the dyn dims - for (_, st) in srcs.iter_mut() { + for (_, st) in &mut srcs { st.resolve_global_dyn_dims_stack(&self.dyn_map, &mut dim_stack); } @@ -363,6 +368,134 @@ impl Graph { println!("Total: {}", format_duration(&start.elapsed()).bold()); self.reset(); } + + /// The `Operator` associated to this `node` is of type `F`? + /// Assuming this `node` exists in the graph + #[inline] + pub fn this_node_is(&self, node: NodeIndex) -> bool { + self.graph.node_weight(node).unwrap().as_any().is::() + } + + /// If the `Operator` associated to this `node` is of type `F`, + /// give that, otherwise None. + /// Assuming this `node` exists in the graph + #[inline] + pub fn get_this_node_is(&self, node: NodeIndex) -> Option<&F> { + self.graph + .node_weight(node) + .unwrap() + .as_any() + .downcast_ref::() + } + + /// Gather the nodes which have an edge going to `node` (assume exists) + /// and where the edge connecting them is a data dependency + /// They are properly sorted according to the `input_order` + /// field of the data dependencies. + #[inline] + pub fn get_incomings(&self, node: NodeIndex) -> Vec { + self.graph + .edges_directed(node, petgraph::Direction::Incoming) + .filter(|e| !e.weight().is_schedule()) + .sorted_by_key(|e| e.weight().as_data().unwrap().0) + .map(|e| e.source()) + .collect_vec() + } + + /// All of the node indices + #[inline] + pub fn collect_node_indices(&self) -> Vec { + self.graph.node_indices().collect_vec() + } + + /// The shape trackers for all the sources of `cur_node` (assume exists) + #[inline] + pub fn get_source_shapes(&self, cur_node: &NodeIndex) -> Vec { + self.get_sources(*cur_node) + .into_iter() + .map(|(_, _, a)| a) + .collect_vec() + } + + /// Return all those in `to_retrieve` whose operators are not `FType` + #[inline] + pub fn do_to_retrieve(&self) -> Vec<(NodeIndex, (u8, ShapeTracker))> { + self.to_retrieve + .iter() + .map(|(a, b)| (*a, *b)) + // Filter to non-FType + .filter(|(n, _)| !self.node_weight(*n).unwrap().as_any().is::()) + .collect::>() + } + + pub fn to_retrieve_graph_tensors(&mut self) -> impl Iterator + '_ { + let mut retrieving = self + .to_retrieve + .iter() + .map(|(a, (b, c))| (*a, (*b, *c))) + .collect_vec(); + retrieving.sort_by_key(|(a, (_, _))| *a); + retrieving + .into_iter() + .map(|(id, (_, shape))| GraphTensor::from_id(id, shape, self)) + } + + /// CAUTION: All `GraphTensor`s which refer to self or other + /// will be unusable from here on + pub fn disjoint_union(mut self, mut other: Self) -> Self { + let mut other_remap = FxHashMap::::with_capacity_and_hasher( + other.graph.node_count(), + FxBuildHasher, + ); + for other_node_idx in other.collect_node_indices() { + let other_node_weight = other.graph.node_weight_mut(other_node_idx).unwrap(); + let mut dummy: Box = Box::new(Add); + core::mem::swap(&mut dummy, other_node_weight); + let new_node_idx = self.graph.add_node(dummy); + other_remap.insert(other_node_idx, new_node_idx); + } + for other_edge_ref in other.edge_references() { + let (a, b) = (other_edge_ref.source(), other_edge_ref.target()); + let (a_new, b_new) = (*other_remap.get(&a).unwrap(), *other_remap.get(&b).unwrap()); + self.graph.add_edge(a_new, b_new, *other_edge_ref.weight()); + } + self.tensors.extend( + other + .tensors + .into_iter() + .map(|((a, b), c)| ((*other_remap.get(&a).unwrap(), b), c)), + ); + let bad_keys: Vec = self + .dyn_map + .keys() + .filter_map(|key| { + if other.dyn_map.contains_key(key) + && other.dyn_map.get(key) != self.dyn_map.get(key) + { + Some(*key) + } else { + None + } + }) + .collect(); + assert!(bad_keys.is_empty()); + self.dyn_map.extend(other.dyn_map); + self.no_delete.extend( + other + .no_delete + .into_iter() + .map(|a| *other_remap.get(&a).unwrap()), + ); + self.to_retrieve.extend( + other + .to_retrieve + .into_iter() + .map(|(a, (b, c))| (*other_remap.get(&a).unwrap(), (b, c))), + ); + self.linearized_graph = None; + self.consumers_map = None; + self + } } impl Deref for Graph { diff --git a/src/graph_tensor.rs b/src/graph_tensor.rs index 7949a9ef..995914dc 100644 --- a/src/graph_tensor.rs +++ b/src/graph_tensor.rs @@ -28,7 +28,7 @@ impl From<&GraphTensor> for GraphTensor { } impl GraphTensor { - /// Create a GraphTensor from a NodeIndex + /// Create a `GraphTensor` from a `NodeIndex` pub fn from_id(id: NodeIndex, shape: ShapeTracker, graph_ref: *mut Graph) -> Self { Self { id, diff --git a/src/hl_ops/matmul.rs b/src/hl_ops/matmul.rs index f6eec47f..32a0db0b 100644 --- a/src/hl_ops/matmul.rs +++ b/src/hl_ops/matmul.rs @@ -1,6 +1,7 @@ use crate::prelude::*; impl GraphTensor { + #[allow(clippy::many_single_char_names)] pub fn matmul(mut self, mut rhs: GraphTensor) -> Self { if (self.shape.len() == 1 || self.shape.len() == 2) && rhs.shape.len() == 2 { let vec = self.shape.len() == 1; diff --git a/src/hl_ops/other.rs b/src/hl_ops/other.rs index 02d38b0c..029d9577 100644 --- a/src/hl_ops/other.rs +++ b/src/hl_ops/other.rs @@ -62,7 +62,7 @@ impl Graph { /// ARange from 0 to N pub fn arange(&mut self, to: impl Into) -> GraphTensor { let to = to.into(); - if to.to_usize().map(|i| i == 1).unwrap_or_default() { + if to.to_usize().is_some_and(|i| i == 1) { // Single number ARange is just 0 self.constant(0.).expand_dim(0, to) } else { @@ -88,7 +88,7 @@ impl Graph { /// Lower left-hand triangle of 1s. Currently required to be square /// - /// Same API as https://pytorch.org/docs/stable/generated/torch.tril + /// Same API as pub fn tril(&mut self, size: impl Into, diagonal: i32) -> GraphTensor { let size = size.into(); let horizontal = self.arange(size).expand_dim(0, size); @@ -99,7 +99,7 @@ impl Graph { /// Upper right-hand triangle of 1s /// - /// Same API as https://pytorch.org/docs/stable/generated/torch.triu + /// Same API as pub fn triu(&mut self, size: impl Into, diagonal: i32) -> GraphTensor { let size = size.into(); let horizontal = self.arange(size).expand_dim(0, size).contiguous(); @@ -153,6 +153,7 @@ impl GraphTensor { } /// Check the tensor value against a binary file + #[allow(clippy::too_many_lines)] pub fn diff(&self, file: impl Into, atol: f32, rtol: f32) -> Self { let path = file.into(); let id = self diff --git a/src/hl_ops/unary.rs b/src/hl_ops/unary.rs index dd967e36..6227633b 100644 --- a/src/hl_ops/unary.rs +++ b/src/hl_ops/unary.rs @@ -194,8 +194,8 @@ impl GraphTensor { /// The Gaussian Error Linear Unit activation function #[allow(clippy::excessive_precision)] pub fn gelu(self) -> GraphTensor { - // Based on https://github.com/tinygrad/tinygrad/blob/9fc4465557831b614b56dd645eebc940ca0fa1bb/tinygrad/tensor.py#L1162C26-L1162C104 - 0.5 * self * (1. + (0.7978845608 * self * (1. + 0.044715 * self * self)).tanh()) + // Based on + 0.5 * self * (1. + (0.797_884_560_8 * self * (1. + 0.044_715 * self * self)).tanh()) } } diff --git a/src/module.rs b/src/module.rs index 2198fabd..9bdea0ac 100644 --- a/src/module.rs +++ b/src/module.rs @@ -40,9 +40,11 @@ pub fn transfer_data( dest_graph.tensors.insert((dest, output_num), tensor); output_num += 1; } - if output_num == 0 { - panic!("No source tensor found for node {}", src.index()); - } + assert!( + output_num != 0, + "No source tensor found for node {}", + src.index() + ); } } @@ -54,9 +56,11 @@ pub fn transfer_data_same_graph(srcs: impl ToIds, dests: impl ToIds, graph: &mut graph.tensors.insert((dest, output_num), tensor); output_num += 1; } - if output_num == 0 { - panic!("No source tensor found for node {}", src.index()); - } + assert!( + output_num != 0, + "No source tensor found for node {}", + src.index() + ); } } @@ -158,7 +162,7 @@ impl> Module for Vec { impl> Module for &[M] { type Output = X; fn forward(&self, mut x: X) -> Self::Output { - for layer in self.iter() { + for layer in *self { x = layer.forward(x); } x @@ -168,7 +172,7 @@ impl> Module for &[M] { impl> Module for [M; N] { type Output = X; fn forward(&self, mut x: X) -> Self::Output { - for layer in self.iter() { + for layer in self { x = layer.forward(x); } x @@ -211,14 +215,14 @@ tuple_impls!([M1, M2, M3, M4, M5, M6, M7, M8] [0, 1, 2, 3, 4, 5, 6, 7], M8, [M7, tuple_impls!([M1, M2, M3, M4, M5, M6, M7, M8, M9] [0, 1, 2, 3, 4, 5, 6, 7, 8], M9, [M8, M7, M6, M5, M4, M3, M2, M1]); tuple_impls!([M1, M2, M3, M4, M5, M6, M7, M8, M9, M10] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], M10, [M9, M8, M7, M6, M5, M4, M3, M2, M1]); -/// Tell luminal how to represent the module as a dict of (String, NodeIndex)'s +/// Tell luminal how to represent the module as a dict of `(String, NodeIndex)`'s pub trait SerializeModule { fn serialize(&self, s: &mut Serializer); } impl SerializeModule for &T { fn serialize(&self, s: &mut Serializer) { - (*self).serialize(s) + (*self).serialize(s); } } diff --git a/src/op.rs b/src/op.rs index d5bc634b..acd244a0 100644 --- a/src/op.rs +++ b/src/op.rs @@ -10,7 +10,7 @@ use crate::prelude::*; use dyn_clone::{clone_trait_object, DynClone}; use rustc_hash::FxHashMap; -/// A tensor with data. The data can be anything that implements the Data trait +/// A tensor with data. The data can be anything that implements the `Data` trait #[derive(Debug, Clone)] pub struct Tensor { data: Box, @@ -33,7 +33,10 @@ impl Tensor { } } -/// Some sort of data, for instance a Vec on CPU, CudaSlice on Nvidia GPUs, or metal::Buffer for Apple GPUs +/// Some sort of data, for instance a +/// `Vec` on CPU +/// `CudaSlice` on Nvidia GPUs +/// `metal::Buffer` for Apple GPUs pub trait Data: Any + Debug + DynClone { fn as_any(&self) -> &dyn Any; fn as_any_mut(&mut self) -> &mut dyn Any; @@ -110,6 +113,8 @@ impl Operator for Arc> { #[allow(clippy::type_complexity)] pub struct Function( pub String, + // This could have been a closure `FnMut` that did capture mutable state without issue + // but the use here is intended for honest functions pub Box) -> Vec>, ); @@ -348,8 +353,9 @@ impl Operator for LessThan { let rexpr = (inp[1].1.index_expression(), inp[1].1.valid_expression()); let mut stack = vec![]; for (i, out) in out_data.iter_mut().enumerate() { - *out = (get_index(lhs, &lexpr, &mut stack, i) < get_index(rhs, &rexpr, &mut stack, i)) - as i32 as f32; + *out = i32::from( + get_index(lhs, &lexpr, &mut stack, i) < get_index(rhs, &rexpr, &mut stack, i), + ) as f32; } vec![Tensor::new(out_data)] } @@ -414,11 +420,11 @@ fn get_vec<'a>(tensor: &'a InputTensor<'a>) -> &'a Vec { fn get_index( data: &[f32], - (ind, val): &(Expression, Expression), + (ind, valid_ind): &(Expression, Expression), stack: &mut Vec, index: usize, ) -> f32 { - if val.exec_single_var_stack(index, stack) != 0 { + if valid_ind.exec_single_var_stack(index, stack) != 0 { let i = ind.exec_single_var_stack(index, stack); data[i] } else { diff --git a/src/shape/mod.rs b/src/shape/mod.rs index f9472f7d..631608f4 100644 --- a/src/shape/mod.rs +++ b/src/shape/mod.rs @@ -503,7 +503,7 @@ impl> ToShape for [A; E] { } } -impl> ToShape for Vec { +impl> ToShape for Vec { fn to_shape(self) -> Vec { self.into_iter().map(|i| i.into()).collect() } diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index bcab926a..e3d4f472 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -1,4 +1,7 @@ -use egg::*; +use egg::{ + define_language, merge_option, rewrite, Analysis, AstSize, DidMerge, Extractor, FromOp, Id, + Language, RecExpr, RecExprParseError, Runner, Subst, Symbol, Var, +}; use generational_box::{AnyStorage, GenerationalBox, Owner, SyncStorage}; use rustc_hash::FxHashMap; use serde::{Serialize, Serializer}; @@ -116,10 +119,10 @@ impl Term { Term::Mod => Some(|a, b| a.checked_rem(b)), Term::Max => Some(|a, b| Some(a.max(b))), Term::Min => Some(|a, b| Some(a.min(b))), - Term::And => Some(|a, b| Some((a != 0 && b != 0) as i64)), - Term::Or => Some(|a, b| Some((a != 0 || b != 0) as i64)), - Term::Gte => Some(|a, b| Some((a >= b) as i64)), - Term::Lt => Some(|a, b| Some((a < b) as i64)), + Term::And => Some(|a, b| Some(i64::from(a != 0 && b != 0))), + Term::Or => Some(|a, b| Some(i64::from(a != 0 || b != 0))), + Term::Gte => Some(|a, b| Some(i64::from(a >= b))), + Term::Lt => Some(|a, b| Some(i64::from(a < b))), _ => None, } } @@ -132,10 +135,10 @@ impl Term { Term::Mod => Some(|a, b| a % b), Term::Max => Some(|a, b| a.max(b)), Term::Min => Some(|a, b| a.min(b)), - Term::And => Some(|a, b| (a.abs() > 1e-4 && b.abs() > 1e-4) as i32 as f64), - Term::Or => Some(|a, b| (a.abs() > 1e-4 || b.abs() > 1e-4) as i32 as f64), - Term::Gte => Some(|a, b| (a >= b) as i32 as f64), - Term::Lt => Some(|a, b| (a < b) as i32 as f64), + Term::And => Some(|a, b| f64::from(i32::from(a.abs() > 1e-4 && b.abs() > 1e-4))), + Term::Or => Some(|a, b| f64::from(i32::from(a.abs() > 1e-4 || b.abs() > 1e-4))), + Term::Gte => Some(|a, b| f64::from(i32::from(a >= b))), + Term::Lt => Some(|a, b| f64::from(i32::from(a < b))), _ => None, } } @@ -421,7 +424,7 @@ impl Expression { pub fn exec_single_var_stack(&self, value: usize, stack: &mut Vec) -> usize { for term in self.terms.read().iter() { match term { - Term::Num(n) => stack.push(*n as i64), + Term::Num(n) => stack.push(i64::from(*n)), Term::Acc(_) => stack.push(1), Term::Var(_) => stack.push(value as i64), _ => { @@ -445,13 +448,13 @@ impl Expression { ) -> Option { for term in self.terms.read().iter() { match term { - Term::Num(n) => stack.push(*n as i64), + Term::Num(n) => stack.push(i64::from(*n)), Term::Acc(_) => stack.push(1), Term::Var(c) => { #[allow(clippy::needless_borrow)] if let Some(n) = variables.get(&c) { - stack.push(*n as i64) + stack.push(*n as i64); } else { return None; } @@ -477,13 +480,13 @@ impl Expression { ) -> Option { for term in self.terms.read().iter() { match term { - Term::Num(n) => stack.push(*n as f64), + Term::Num(n) => stack.push(f64::from(*n)), Term::Acc(_) => stack.push(1.0), Term::Var(c) => { #[allow(clippy::needless_borrow)] if let Some(n) = variables.get(&c) { - stack.push(*n as f64) + stack.push(*n as f64); } else { return None; } @@ -562,13 +565,13 @@ impl From<&i32> for Expression { impl From for Expression { fn from(value: bool) -> Self { - Expression::new(vec![Term::Num(value as i32)]) + Expression::new(vec![Term::Num(i32::from(value))]) } } impl From<&bool> for Expression { fn from(value: &bool) -> Self { - Expression::new(vec![Term::Num(*value as i32)]) + Expression::new(vec![Term::Num(i32::from(*value))]) } } @@ -903,7 +906,7 @@ fn luminal_to_egg(expr: &Expression) -> RecExpr { for term in expr.terms.read().iter() { match term { Term::Num(_) | Term::Var(_) => { - stack.push(symbolic_expressions::Sexp::String(format!("{term:?}"))) + stack.push(symbolic_expressions::Sexp::String(format!("{term:?}"))); } Term::Acc(_) => stack.push(symbolic_expressions::Sexp::String("1".to_string())), _ => { @@ -918,6 +921,7 @@ fn luminal_to_egg(expr: &Expression) -> RecExpr { } } } + #[allow(clippy::items_after_statements)] fn parse_sexp_into( sexp: &Sexp, expr: &mut RecExpr, @@ -955,68 +959,68 @@ fn egg_to_luminal(expr: RecExpr) -> Expression { match expr.last().unwrap() { Math::Num(i) => vec![Term::Num(*i)], Math::Add([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Add], ] .concat(), Math::Sub([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Sub], ] .concat(), Math::Mul([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Mul], ] .concat(), Math::Div([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Div], ] .concat(), Math::Mod([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Mod], ] .concat(), Math::Min([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Min], ] .concat(), Math::Max([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Max], ] .concat(), Math::And([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::And], ] .concat(), Math::Or([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Or], ] .concat(), Math::LessThan([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Lt], ] .concat(), Math::GreaterThanEqual([a, b]) => [ - create_postfix(&expr[..usize::from(*b) + 1]), - create_postfix(&expr[..usize::from(*a) + 1]), + create_postfix(&expr[..=usize::from(*b)]), + create_postfix(&expr[..=usize::from(*a)]), vec![Term::Gte], ] .concat(), @@ -1065,17 +1069,16 @@ impl Analysis for ConstantFold { let (a, b) = (x(a)?, x(b)?); if a % b != 0 { return None; - } else { - a.checked_div(b)? } + a.checked_div(b)? } Math::Mod([a, b]) => x(a)?.checked_rem(x(b)?)?, Math::Min([a, b]) => x(a)?.min(x(b)?), Math::Max([a, b]) => x(a)?.max(x(b)?), - Math::And([a, b]) => (x(a)? != 0 && x(b)? != 0) as i32, - Math::Or([a, b]) => (x(a)? != 0 || x(b)? != 0) as i32, - Math::LessThan([a, b]) => (x(a)? < x(b)?) as i32, - Math::GreaterThanEqual([a, b]) => (x(a)? >= x(b)?) as i32, + Math::And([a, b]) => i32::from(x(a)? != 0 && x(b)? != 0), + Math::Or([a, b]) => i32::from(x(a)? != 0 || x(b)? != 0), + Math::LessThan([a, b]) => i32::from(x(a)? < x(b)?), + Math::GreaterThanEqual([a, b]) => i32::from(x(a)? >= x(b)?), _ => return None, }) } @@ -1102,14 +1105,14 @@ impl Analysis for ConstantFold { fn is_not_zero(var: &str) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let var = var.parse().unwrap(); - move |egraph, _, subst| egraph[subst[var]].data.map(|i| i != 0).unwrap_or(true) + move |egraph, _, subst| egraph[subst[var]].data != Some(0) } fn is_const_positive(vars: &[&str]) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { let vars: Vec = vars.iter().map(|i| i.parse().unwrap()).collect::>(); move |egraph, _, subst| { vars.iter() - .all(|i| egraph[subst[*i]].data.map(|i| i >= 0).unwrap_or(false)) + .all(|i| egraph[subst[*i]].data.is_some_and(|i| i >= 0)) } } @@ -1191,6 +1194,7 @@ fn egg_simplify(e: Expression, lower_bound_zero: bool) -> Expression { #[cfg(test)] mod tests { use super::*; + #[test] fn test_expressions() { let n = Expression::from('x') + (256 - (Expression::from('x') % 256)); @@ -1215,13 +1219,19 @@ mod tests { let sub = Expression::from('x') / 2; let new = main.substitute('x', sub).simplify(); assert_eq!(new.len(), 5); + assert_eq!(new, (Expression::from('x') + -255 * 2) / 2); } #[test] fn test_group_terms() { let s = Expression::from('s'); let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1))); - assert_eq!(expr.simplify().len(), 7); + let new = expr.simplify(); + assert_eq!(new.len(), 7); + assert_eq!( + new, + (Expression::from('s') + -3) * (Expression::from('s') + 1) + ); } #[test] diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index 31c995fa..db6cdad1 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -13,8 +13,8 @@ impl NewShapeTracker { /// Make a new row-major shape tracker pub fn new(dims: impl ToShape) -> NewShapeTracker { let mut s = Self { - dims: Default::default(), - strides: Default::default(), + dims: ArrayVec::default(), + strides: ArrayVec::default(), }; let mut stride = Expression::from(1); for d in dims.to_shape().into_iter().rev() { @@ -35,8 +35,8 @@ impl NewShapeTracker { "Dimensions and strides need to be the same size!" ); let mut s = Self { - dims: Default::default(), - strides: Default::default(), + dims: ArrayVec::default(), + strides: ArrayVec::default(), }; for (dim, stride) in dims.into_iter().zip(strides) { s.dims.push(dim); @@ -59,11 +59,11 @@ impl ShapeTracker { #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn new(dims: impl ToShape) -> Self { let mut s = Self { - dims: Default::default(), - indexes: Default::default(), - fake: Default::default(), - mask: Default::default(), - padding: Default::default(), + dims: ArrayVec::default(), + indexes: ArrayVec::default(), + fake: ArrayVec::default(), + mask: ArrayVec::default(), + padding: ArrayVec::default(), }; for (i, d) in dims.to_shape().into_iter().enumerate() { s.dims.push(d); @@ -132,7 +132,7 @@ impl ShapeTracker { pub fn remove_dim(&mut self, axis: usize) -> Expression { let index = self.indexes.remove(axis); self.fake.remove(index); - for i in self.indexes.iter_mut() { + for i in &mut self.indexes { if *i > index { *i -= 1; } @@ -160,11 +160,11 @@ impl ShapeTracker { .rev() .scan(Expression::from(1), |state, i| { let ret = *state; - if !self.fake[i] { + if self.fake[i] { + Some(Expression::from(0)) + } else { *state *= self.dims[i]; Some(ret) - } else { - Some(Expression::from(0)) } }) .collect::>(); @@ -234,11 +234,9 @@ impl ShapeTracker { ret &= dim_ind.gte(greater_than); } ret &= dim_ind.lt(self.dims[i] + self.padding[i].0); - if top_slice - .to_usize() - .map(|s| self.dims[i].to_usize().map(|dim| s < dim).unwrap_or(true)) - .unwrap_or(true) - { + if top_slice.to_usize().map_or(true, |s| { + self.dims[i].to_usize().map_or(true, |dim| s < dim) + }) { ret = ret.min(top_slice); } } @@ -319,7 +317,11 @@ impl ShapeTracker { .collect() } - /// Realize the true shape and convert it to usizes. All dyn dims must be replaced already + /// Realize the true shape and convert it to usizes. + /// + /// # Panics + /// + /// All dyn dims must be replaced already pub fn shape_usize(&self) -> Vec { self.dims().iter().map(|e| e.to_usize().unwrap()).collect() } @@ -340,18 +342,13 @@ impl ShapeTracker { .map(|(i, m)| (self.indexes[i], m)) { // Make sure we aren't padding a masked dimension - if (e.to_usize().map(|n| n != 0).unwrap_or(true) + if ((e.to_usize() != Some(0)) && self.mask[ind] .1 .to_usize() - .map(|n| n as i32 != i32::MAX) - .unwrap_or(true)) - || (s.to_usize().map(|n| n != 0).unwrap_or(true) - && self.mask[ind] - .0 - .to_usize() - .map(|n| n as i32 != 0) - .unwrap_or(true)) + .map_or(true, |n| n as i32 != i32::MAX)) + || ((s.to_usize() != Some(0)) + && self.mask[ind].0.to_usize().map_or(true, |n| n as i32 != 0)) { panic!("Adding padding to a masked shape isn't supported") } @@ -372,14 +369,14 @@ impl ShapeTracker { dyn_dim_map: &FxHashMap, stack: &mut Vec, ) { - for d in self.dims.iter_mut() { + for d in &mut self.dims { *d = d.exec_stack(dyn_dim_map, stack).unwrap().into(); } - for (a, b) in self.padding.iter_mut() { + for (a, b) in &mut self.padding { *a = a.exec_stack(dyn_dim_map, stack).unwrap().into(); *b = b.exec_stack(dyn_dim_map, stack).unwrap().into(); } - for (a, b) in self.mask.iter_mut() { + for (a, b) in &mut self.mask { *a = a.exec_stack(dyn_dim_map, stack).unwrap().into(); *b = b.exec_stack(dyn_dim_map, stack).unwrap().into(); } @@ -387,16 +384,14 @@ impl ShapeTracker { pub fn is_sliced(&self) -> bool { self.mask.iter().any(|(b, e)| { - b.to_usize().map(|i| i != 0).unwrap_or(true) - || e.to_usize().map(|n| n as i32 != i32::MAX).unwrap_or(true) + (b.to_usize() != Some(0)) || e.to_usize().map_or(true, |n| n as i32 != i32::MAX) }) } pub fn is_padded(&self) -> bool { - self.padding.iter().any(|(b, e)| { - b.to_usize().map(|i| i != 0).unwrap_or(true) - || e.to_usize().map(|n| n != 0).unwrap_or(true) - }) + self.padding + .iter() + .any(|(b, e)| (b.to_usize() != Some(0)) || (e.to_usize() != Some(0))) } } @@ -421,8 +416,26 @@ mod tests { tracker.permute(&[2, 0, 1]); println!("Shape: [10, 5, 3]"); println!("Strides: {:?}", tracker.strides()); - println!("Ind: {:?}", tracker.index_expression()); + assert_eq!( + tracker.strides(), + vec![ + Expression::from(1), + Expression::from(15), + Expression::from(3) + ] + ); + let index_expression = tracker.index_expression(); + println!("Ind: {:?}", index_expression); + assert_eq!( + index_expression.substitute('z', 0).simplify(), + Expression::from(0) + ); + assert_eq!( + index_expression.substitute('z', 50).simplify(), + Expression::from(1) + ); println!("Val: {:?}", tracker.valid_expression()); + assert_eq!(tracker.valid_expression(), Expression::from(1)); } #[test] diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 6eeb32ce..f126c0d8 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -68,31 +68,32 @@ pub fn assert_close(a_vec: &[f32], b_vec: &[f32]) { assert_close_precision(a_vec, b_vec, 1e-3); } +/// # Panics +/// /// Ensure two arrays are nearly equal to a decimal place pub fn assert_close_precision(a_vec: &[f32], b_vec: &[f32], threshold: f32) { assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match"); for (i, (a, b)) in a_vec.iter().zip(b_vec.iter()).enumerate() { - if (a - b).abs() > threshold { - panic!( - "{a} is not close to {b}, index {i}, avg distance: {}", - a_vec - .iter() - .zip(b_vec.iter()) - .map(|(a, b)| (a - b).abs()) - .sum::() - / a_vec.len() as f32 - ); - } + assert!( + (a - b).abs() <= threshold, + "{a} is not close to {b}, index {i}, avg distance: {}", + a_vec + .iter() + .zip(b_vec.iter()) + .map(|(a, b)| (a - b).abs()) + .sum::() + / a_vec.len() as f32 + ); } } +/// # Panics +/// /// Ensure two arrays are exactly equal pub fn assert_exact(a_vec: &[T], b_vec: &[T]) { assert_eq!(a_vec.len(), b_vec.len(), "Number of elements doesn't match"); for (i, (a, b)) in a_vec.iter().zip(b_vec.iter()).enumerate() { - if a != b { - panic!("{a:?} is not equal to {b:?}, index {i}"); - } + assert_eq!(a, b, "{a:?} is not equal to {b:?}, index {i}"); } } diff --git a/src/tests/test_graphs.rs b/src/tests/test_graphs.rs index 6c51ea5e..86cb6244 100644 --- a/src/tests/test_graphs.rs +++ b/src/tests/test_graphs.rs @@ -53,3 +53,34 @@ fn execute_no_delete_keeps_tensors() { super::assert_close(&c.data(), &d_c.as_vec()); } + +#[test] +fn disjoint_pieces() { + let mut cx0 = Graph::new(); + let a0 = cx0.tensor(3).set([1., 2., 3.]); + let b0 = cx0.tensor(3).set([4., 5., 6.]); + let _c0 = (a0 + b0).retrieve(); + let mut cx1 = Graph::new(); + let a1 = cx1.tensor(3).set([7., 8., 9.]); + let b1 = cx1.tensor(3).set([8., 7., 8.]); + let _c1 = (a1 * b1).retrieve(); + + let mut cx = cx0.disjoint_union(cx1); + + cx.execute(); + + let d_dev = dfdx::tensor::Cpu::default(); + let d_a = d_dev.tensor([1., 2., 3.]); + let d_b = d_dev.tensor([4., 5., 6.]); + let d_c0 = d_a + d_b; + let d_a = d_dev.tensor([7., 8., 9.]); + let d_b = d_dev.tensor([8., 7., 8.]); + let d_c1 = d_a * d_b; + + let mut c0c1 = cx.to_retrieve_graph_tensors(); + let c0 = c0c1.next().unwrap(); + let c1 = c0c1.next().unwrap(); + assert_eq!(c0c1.count(), 0); + super::assert_close(&c0.data(), &d_c0.as_vec()); + super::assert_close(&c1.data(), &d_c1.as_vec()); +}