Skip to content

Commit

Permalink
Feat/cube/slice (#2004)
Browse files Browse the repository at this point in the history
* Refactor Variable types

* Sice

* Implement slice wgsl

* handle lifetime correctly

* Add cuda impl

* Update cmma

* Cleanup

* Fix tests

* Fix slice signature
  • Loading branch information
nathanielsimard authored Jul 11, 2024
1 parent c30ffcf commit 35345de
Show file tree
Hide file tree
Showing 70 changed files with 1,663 additions and 565 deletions.
26 changes: 20 additions & 6 deletions crates/burn-cube/src/codegen/integrator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,17 +423,24 @@ impl KernelIntegrator {
} else {
item
};
let elem_adapted = bool_item(item);
let item_adapted = bool_item(item);

self.output_bindings.push(Binding {
item: elem_adapted,
item: item_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalOutputArray {
id: index,
item: item_adapted,
},
position,
);
index += 1;
Expand All @@ -451,8 +458,15 @@ impl KernelIntegrator {
};

self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalInputArray {
id: input,
item: bool_item(item),
},
position,
);
}
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-cube/src/frontend/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ where

if unroll {
let start = match start.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant start can be unrolled."),
};
let end = match end.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant end can be unrolled."),
};

Expand Down
26 changes: 16 additions & 10 deletions crates/burn-cube/src/frontend/cmma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::ColMajor,
//! cmma::MatrixLayout::RowMajor,
//! );
//! let b = cmma::Matrix::<F16>::new(
//! cmma::MatrixIdent::B,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::RowMajor,
//! cmma::MatrixLayout::ColMajor,
//! );
//! let c = cmma::Matrix::<F32>::new(
//! cmma::MatrixIdent::Accumulator,
Expand All @@ -32,12 +32,17 @@
//! cmma::MatrixLayout::Undefined,
//! );
//! cmma::fill::<F32>(&c, F32::new(0.0));
//! cmma::load::<F16>(&a, lhs, UInt::new(16));
//! cmma::load::<F16>(&b, rhs, UInt::new(16));
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
//!
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//! cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
//! cmma::store::<F32>(
//! out.as_slice_mut(),
//! &c,
//! UInt::new(16),
//! cmma::MatrixLayout::RowMajor,
//! );
//! }
//! ```
Expand All @@ -49,7 +54,8 @@ use crate::{
};

use super::{
Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
UInt,
};

pub use ir::{MatrixIdent, MatrixLayout};
Expand Down Expand Up @@ -137,7 +143,7 @@ pub fn fill_expand<C: CubeType>(

/// Load the matrix with the provided array using the stride.
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
unexpanded!()
}

Expand All @@ -146,7 +152,7 @@ pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
pub fn load_expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElementTyped<Array<C>>,
value: ExpandElementTyped<Slice<'static, C>>,
stride: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Load {
Expand All @@ -159,7 +165,7 @@ pub fn load_expand<C: CubeType>(
/// Store the matrix in the given array following the given stride and layout.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
output: &Array<C>,
output: &mut SliceMut<'_, C>,
mat: &Matrix<C>,
stride: UInt,
layout: MatrixLayout,
Expand All @@ -171,7 +177,7 @@ pub fn store<C: CubePrimitive>(
#[allow(unused_variables)]
pub fn store_expand<C: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<Array<C>>,
output: ExpandElementTyped<SliceMut<'static, C>>,
mat: MatrixExpand,
stride: ExpandElement,
layout: MatrixLayout,
Expand Down
30 changes: 14 additions & 16 deletions crates/burn-cube/src/frontend/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;

use super::{CubePrimitive, SharedMemoryExpand};

#[derive(Default, Clone)]
pub struct VariablePool {
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
Expand Down Expand Up @@ -114,34 +112,34 @@ impl CubeContext {
ExpandElement::Plain(variable)
}

pub fn create_shared<T: CubePrimitive>(
&mut self,
item: Item,
size: u32,
) -> SharedMemoryExpand<T> {
SharedMemoryExpand {
val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)),
}
/// Create a new slice element.
pub fn create_slice(&mut self, item: Item) -> ExpandElement {
let variable = self.scope.borrow_mut().create_slice(item);
ExpandElement::Plain(variable)
}

pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
}

pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size))
}

/// Obtain the index-th input
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
}

/// Obtain the index-th output
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray(index, item);
pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray { id, item };
self.scope.borrow_mut().write_global_custom(var);
ExpandElement::Plain(var)
}

/// Obtain the index-th scalar
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))
pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
}
}
4 changes: 2 additions & 2 deletions crates/burn-cube/src/frontend/element/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Array need constant initialization value"),
};
context
Expand All @@ -55,7 +55,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context
Expand Down
21 changes: 11 additions & 10 deletions crates/burn-cube/src/frontend/element/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl ExpandElement {
pub fn can_mut(&self) -> bool {
match self {
ExpandElement::Managed(var) => {
if let Variable::Local(_, _, _) = var.as_ref() {
if let Variable::Local { .. } = var.as_ref() {
Rc::strong_count(var) <= 2
} else {
false
Expand Down Expand Up @@ -201,10 +201,10 @@ impl Init for ExpandElement {
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);

match *self {
Variable::GlobalScalar(_, _) => init(self),
Variable::LocalScalar(_, _, _) => init(self),
Variable::ConstantScalar(_, _) => init(self),
Variable::Local(_, _, _) => init(self),
Variable::GlobalScalar { .. } => init(self),
Variable::LocalScalar { .. } => init(self),
Variable::ConstantScalar { .. } => init(self),
Variable::Local { .. } => init(self),
// Constant should be initialized since the new variable can be mutated afterward.
// And it is assumed those values are cloned.
Variable::Rank
Expand All @@ -230,11 +230,12 @@ impl Init for ExpandElement {
| Variable::AbsolutePosY
| Variable::AbsolutePosZ => init(self),
// Array types can't be copied, so we should simply return the same variable.
Variable::SharedMemory(_, _, _)
| Variable::GlobalInputArray(_, _)
| Variable::GlobalOutputArray(_, _)
| Variable::LocalArray(_, _, _, _)
| Variable::Matrix(_, _) => self,
Variable::SharedMemory { .. }
| Variable::GlobalInputArray { .. }
| Variable::GlobalOutputArray { .. }
| Variable::LocalArray { .. }
| Variable::Slice { .. }
| Variable::Matrix { .. } => self,
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-cube/src/frontend/element/cube_elem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ impl_into_expand_element!(i64);
/// Useful for Comptime
impl From<UInt> for ExpandElement {
fn from(value: UInt) -> Self {
ExpandElement::Plain(crate::ir::Variable::ConstantScalar(
value.val as f64,
UInt::as_elem(),
))
ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
value: value.val as f64,
elem: UInt::as_elem(),
})
}
}
5 changes: 4 additions & 1 deletion crates/burn-cube/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ macro_rules! impl_float {
}

fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

Expand Down
5 changes: 4 additions & 1 deletion crates/burn-cube/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ macro_rules! impl_int {
}

fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

Expand Down
2 changes: 2 additions & 0 deletions crates/burn-cube/src/frontend/element/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod float;
mod int;
mod numeric;
mod shared_memory;
mod slice;
mod tensor;
mod uint;
mod vectorized;
Expand All @@ -19,6 +20,7 @@ pub use float::*;
pub use int::*;
pub use numeric::*;
pub use shared_memory::*;
pub use slice::*;
pub use tensor::*;
pub use uint::*;
pub use vectorized::*;
5 changes: 4 additions & 1 deletion crates/burn-cube/src/frontend/element/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ pub trait Numeric:

/// Expand version of from_int
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

Expand Down
29 changes: 10 additions & 19 deletions crates/burn-cube/src/frontend/element/shared_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,21 @@ use crate::{
ir::Item,
};

use super::{ExpandElement, Init, UInt};
use super::{ExpandElementTyped, Init, UInt};

#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
_val: PhantomData<T>,
}

#[derive(Clone)]
pub struct SharedMemoryExpand<T: CubePrimitive> {
pub val: <T as CubeType>::ExpandType,
}

impl<T: CubePrimitive> From<SharedMemoryExpand<T>> for ExpandElement {
fn from(shared_memory_expand: SharedMemoryExpand<T>) -> Self {
shared_memory_expand.val
}
}

impl<T: CubePrimitive> Init for SharedMemoryExpand<T> {
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}

impl<T: CubePrimitive> CubeType for SharedMemory<T> {
type ExpandType = SharedMemoryExpand<T>;
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
}

impl<T: CubePrimitive + Clone> SharedMemory<T> {
Expand All @@ -44,10 +33,11 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(Item::new(T::as_elem()), size)
let var = context.create_shared(Item::new(T::as_elem()), size);
ExpandElementTyped::new(var)
}

pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
Expand All @@ -61,12 +51,13 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(
let var = context.create_shared(
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
size,
)
);
ExpandElementTyped::new(var)
}
}
Loading

0 comments on commit 35345de

Please sign in to comment.