Skip to content

Commit

Permalink
Updates for new cubecl (#2350)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ArthurBrussee authored Oct 9, 2024
1 parent e3b10c6 commit c98b689
Show file tree
Hide file tree
Showing 11 changed files with 26 additions and 32 deletions.
17 changes: 9 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ tch = "0.15.0"
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c40c0464aab2707487143b2e8d3cf71917c2f841" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c40c0464aab2707487143b2e8d3cf71917c2f841" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "52e21ff91890c8b0ae4943816fdcaa7e3f4be058" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "52e21ff91890c8b0ae4943816fdcaa7e3f4be058" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-jit/src/fusion/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ pub struct JitFusionHandle<R: JitRuntime> {
/// Compute client for jit.
pub client: ComputeClient<R::Server, R::Channel>,
/// The buffer where the data are stored.
pub handle: cubecl::server::Handle<R::Server>,
pub handle: cubecl::server::Handle,
/// The device of the current tensor.
pub device: R::Device,
pub(crate) strides: Vec<usize>,
Expand Down Expand Up @@ -250,6 +250,7 @@ impl<R: JitRuntime> JitFusionHandle<R> {
handle: &self.handle,
strides: &self.strides,
shape,
runtime: PhantomData,
}
}
/// Return the reference to a tensor argument.
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/fusion/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub trait FusionKernelFactory<R: JitRuntime> {
/// An instantiation of a [kernel](Kernel) that can be executed.
#[derive(new)]
pub struct ExecutableKernel<R: JitRuntime> {
kernel: Box<dyn CubeTask>,
kernel: Box<dyn CubeTask<R::Compiler>>,
cube_count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
Expand All @@ -54,7 +54,7 @@ pub struct ExecutableKernel<R: JitRuntime> {
/// The clone function used is defined in the trait [AutotuneOperation] instead of [Clone].
#[derive(new)]
pub struct AutotunableKernel<R: JitRuntime> {
kernel: Arc<dyn CubeTask>,
kernel: Arc<dyn CubeTask<R::Compiler>>,
count: CubeCount<R::Server>,
bindings: Vec<Binding<R::Server>>,
client: ComputeClient<R::Server, R::Channel>,
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-jit/src/kernel/matmul/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ pub fn matmul<R: JitRuntime, E: FloatElement>(
}
}

pub(crate) fn simple_cube_count<R: JitRuntime>(
pub(crate) fn simple_cube_count(
lhs_shape: &Shape,
rhs_shape: &Shape,
output_shape: &Shape,
cube_dim_x: usize,
cube_dim_y: usize,
) -> CubeCount<R::Server> {
) -> CubeCount {
let ndims = lhs_shape.num_dims();
let num_rows = lhs_shape.dims[ndims - 2];
let num_cols = rhs_shape.dims[ndims - 1];
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/matmul/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement>(
// consecutively in memory, which allows fetching them with fewer memory instructions
let rhs = into_contiguous(swap_dims(rhs, ndims - 1, ndims - 2));

let cube_count = simple_cube_count::<R>(
let cube_count = simple_cube_count(
&lhs.shape,
&rhs_original_shape,
&out.shape,
Expand Down
8 changes: 2 additions & 6 deletions crates/burn-jit/src/kernel/prng/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement>(
.outputs(&[output.as_handle_ref()])
.with_scalars(&seeds)
.with_scalars(&prng.args())
.execute(CubeCountSettings::Custom(prng_cube_count::<R>(
.execute(CubeCountSettings::Custom(prng_cube_count(
num_elems,
SUBCUBE_DIM_APPROX,
N_VALUES_PER_THREAD,
Expand All @@ -43,11 +43,7 @@ pub(crate) fn random<P: Prng<E>, R: JitRuntime, E: JitElement>(
output
}

fn prng_cube_count<R: JitRuntime>(
num_elems: usize,
cube_dim: usize,
n_values_per_thread: usize,
) -> CubeCount<R::Server> {
fn prng_cube_count(num_elems: usize, cube_dim: usize, n_values_per_thread: usize) -> CubeCount {
let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
let num_elems_per_cube = cube_dim * cube_dim;
let num_invocations = f32::ceil(num_threads / num_elems_per_cube as f32);
Expand Down
3 changes: 1 addition & 2 deletions crates/burn-jit/src/kernel/reduce/naive/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ pub fn reduce_dim_naive<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, E
dim: usize,
) -> JitTensor<R, EO> {
let cube_dim = CubeDim::default();
let cube_count =
calculate_cube_count_elemwise::<R::Server>(output.shape.num_elements(), cube_dim);
let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);

unsafe {
naive_reduce_dim_compute_shader::launch_unchecked::<RD, EI, EO, R>(
Expand Down
6 changes: 1 addition & 5 deletions crates/burn-jit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ pub(crate) mod tune;
pub mod element;

use burn_tensor::backend::{DeviceId, DeviceOps};
use cubecl::{
compute::{CubeCount, CubeTask},
Feature, Runtime,
};
use cubecl::{compute::CubeTask, Feature, Runtime};
pub use element::{FloatElement, IntElement, JitElement};

mod backend;
Expand Down Expand Up @@ -55,7 +52,6 @@ pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer
/// The cube server with the [JitAutotuneKey].
type JitServer: cubecl::server::ComputeServer<
Kernel = Box<dyn CubeTask<Self::Compiler>>,
DispatchOptions = CubeCount<Self::JitServer>,
Feature = Feature,
>;
}
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-jit/src/tensor/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
/// Compute client for the [runtime](JitRuntime).
pub client: ComputeClient<R::Server, R::Channel>,
/// The buffer where the data are stored.
pub handle: Handle<R::Server>,
pub handle: Handle,
/// The shape of the tensor.
pub shape: Shape,
/// The device of the tensor.
Expand Down Expand Up @@ -79,7 +79,7 @@ where
client: ComputeClient<R::Server, R::Channel>,
device: R::Device,
shape: Shape,
handle: Handle<R::Server>,
handle: Handle,
) -> Self {
let ndims = shape.num_dims();
let mut strides = vec![0; ndims];
Expand Down Expand Up @@ -133,6 +133,7 @@ where
handle: &self.handle,
strides: &self.strides,
shape: &self.shape.dims,
runtime: PhantomData,
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ autotune = ["burn-jit/autotune"]
default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
fusion = ["burn-fusion", "burn-jit/fusion"]
simple-memory-management = ["cubecl/simple-memory-management"]
exclusive-memory-only = ["cubecl/exclusive-memory-only"]
std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]

Expand Down

0 comments on commit c98b689

Please sign in to comment.