From c579686a8adee6358877ba2c0806787bbc0e0195 Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Tue, 23 Apr 2024 11:27:54 -0400 Subject: [PATCH] Move HandleContainer and Tensor Ops descriptions from burn-fusion to burn-tensor (#1654) * Move HandlerContainer and Tensor Ops description to burn-tensor Move HandleContainer and Tensor operations descriptions to burn-tensor crate. Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device. For now added modules to burn-tensor are excluded from no-std as they rely on Arc. * [burn-tensor] Flatten module hierarchy for tensor representation + Add new repr feature to cargo file. * Remove prefix on dosctring * [burn-fusion] Require default features of burn-tensor --- Cargo.lock | 8 + crates/burn-candle/src/backend.rs | 15 +- crates/burn-fusion/Cargo.toml | 2 +- crates/burn-fusion/src/backend.rs | 61 +--- crates/burn-fusion/src/client/base.rs | 14 +- crates/burn-fusion/src/client/mutex.rs | 41 ++- crates/burn-fusion/src/fusion.rs | 14 +- crates/burn-fusion/src/lib.rs | 2 - crates/burn-fusion/src/ops/binary.rs | 8 +- crates/burn-fusion/src/ops/boolean.rs | 53 +-- crates/burn-fusion/src/ops/float.rs | 82 ++--- crates/burn-fusion/src/ops/int.rs | 113 +++--- crates/burn-fusion/src/ops/module.rs | 142 ++++---- crates/burn-fusion/src/ops/unary.rs | 18 +- crates/burn-fusion/src/server.rs | 25 +- crates/burn-fusion/src/stream/base.rs | 7 +- crates/burn-fusion/src/stream/context.rs | 79 ++-- .../burn-fusion/src/stream/execution/base.rs | 12 +- .../src/stream/execution/explorer.rs | 4 +- .../src/stream/execution/policy.rs | 20 +- .../src/stream/execution/processor.rs | 3 +- .../burn-fusion/src/stream/execution/tests.rs | 41 ++- .../src/stream/execution/validator.rs | 7 +- crates/burn-fusion/src/stream/mod.rs | 2 - crates/burn-fusion/src/stream/multi.rs | 16 +- crates/burn-fusion/src/stream/store/base.rs | 2 +- crates/burn-fusion/src/stream/store/index.rs | 12 +- crates/burn-fusion/src/tensor.rs | 48 +-- crates/burn-jit/src/codegen/compilation.rs | 20 +- crates/burn-jit/src/fusion/base.rs | 44 ++- .../burn-jit/src/fusion/elemwise/builder.rs | 12 +- crates/burn-jit/src/fusion/elemwise/kernel.rs | 3 +- crates/burn-jit/src/fusion/kernel.rs | 3 +- crates/burn-jit/src/fusion/tracing/builder.rs | 8 +- crates/burn-jit/src/fusion/tracing/trace.rs | 2 +- crates/burn-jit/src/runtime.rs | 13 +- crates/burn-ndarray/src/backend.rs | 10 +- crates/burn-tch/src/backend.rs | 13 +- crates/burn-tensor/Cargo.toml | 3 +- crates/burn-tensor/src/lib.rs | 4 + crates/burn-tensor/src/repr/backend.rs | 26 ++ .../src => burn-tensor/src/repr}/handle.rs | 59 +-- crates/burn-tensor/src/repr/mod.rs | 9 + .../src/repr}/operation.rs | 338 +++++++++--------- crates/burn-tensor/src/repr/tensor.rs | 45 +++ crates/burn-tensor/src/tensor/backend/base.rs | 4 +- .../burn-tensor/src/tensor/backend/device.rs | 14 + crates/burn-tensor/src/tensor/backend/mod.rs | 2 + crates/burn-wgpu/src/fusion.rs | 14 - crates/burn-wgpu/src/lib.rs | 3 - crates/burn-wgpu/src/runtime.rs | 13 + 51 files changed, 775 insertions(+), 738 deletions(-) create mode 100644 crates/burn-tensor/src/repr/backend.rs rename crates/{burn-fusion/src => burn-tensor/src/repr}/handle.rs (69%) create mode 100644 crates/burn-tensor/src/repr/mod.rs rename crates/{burn-fusion/src/stream => burn-tensor/src/repr}/operation.rs (77%) create mode 100644 crates/burn-tensor/src/repr/tensor.rs create mode 100644 crates/burn-tensor/src/tensor/backend/device.rs delete mode 100644 crates/burn-wgpu/src/fusion.rs diff --git a/Cargo.lock b/Cargo.lock index c8bbcbe3b4..43c58a6faa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3696,6 +3696,14 @@ dependencies = [ "thiserror", ] +[[package]] +name = "refactor" +version = "0.14.0" +dependencies = [ + "burn", + "serde", +] + [[package]] name = "regex" version = "1.10.4" diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index 9a1e997445..f66a28bb29 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -1,6 +1,9 @@ use std::marker::PhantomData; -use burn_tensor::{backend::Backend, Device}; +use burn_tensor::{ + backend::{Backend, DeviceId, DeviceOps}, + Device, +}; use candle_core::DeviceLocation; use crate::{ @@ -60,6 +63,16 @@ impl From for CandleDevice { } } +impl DeviceOps for CandleDevice { + fn id(&self) -> burn_tensor::backend::DeviceId { + match self { + CandleDevice::Cpu => DeviceId::new(0, 0), + CandleDevice::Cuda(index) => DeviceId::new(1, *index as u32), + CandleDevice::Metal(index) => DeviceId::new(2, *index as u32), + } + } +} + impl Default for CandleDevice { fn default() -> Self { Self::Cpu diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index 02e5640f4f..1d93bc6d20 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -16,7 +16,7 @@ std = ["serde/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.14.0" } burn-common = { path = "../burn-common", version = "0.14.0" } hashbrown = { workspace = true } derive-new = {workspace = true } diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index 79c3680472..3d4167e5cc 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,15 +1,17 @@ use crate::{ - client::FusionClient, - stream::{Context, OperationDescription}, - FusionClientLocator, FusionTensor, PrecisionBridge, + client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge, +}; +use burn_tensor::{ + backend::Backend, + repr::{OperationDescription, ReprBackend}, + Device, }; -use burn_tensor::{backend::Backend, Device, Shape}; use serde::{de::DeserializeOwned, Serialize}; use std::marker::PhantomData; pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new(); -pub(crate) fn get_client(device: &B::FusionDevice) -> B::FusionClient { +pub(crate) fn get_client(device: &B::Device) -> B::FusionClient { CLIENTS.client(device) } @@ -43,7 +45,7 @@ impl Backend for Fusion { } fn sync(device: &Self::Device) { - let client = CLIENTS.client::(&device.clone().into()); + let client = CLIENTS.client::(&device.clone()); client.drain(); B::sync(device) } @@ -114,62 +116,17 @@ pub trait Optimization: Send { fn from_state(device: &B::Device, state: B::OptimizationState) -> Self; } -/// The device id. -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)] -pub struct DeviceId { - /// The type id identifies the type of the device. - pub type_id: u16, - /// The index id identifies the device number. - pub index_id: u32, -} - -/// The handle device trait allows to get an id for a backend device. -pub trait FusionDevice: Clone + Send + Sync + PartialEq { - /// Return the [device id](DeviceId). - fn id(&self) -> DeviceId; -} - /// Trait that allows an existing [backend](Backend) to specify graph optimizations using /// [operation builder](crate::OptimizationBuilder). -pub trait FusionBackend: Backend { +pub trait FusionBackend: Backend + ReprBackend { /// The state that can be serialized for an optimization. type OptimizationState: Serialize + DeserializeOwned; /// Optimization type for the backend. type Optimization: Optimization; - - /// The device type that can return an ID. - /// - /// It can be the same as (Backend::Device), but must implement (FusionDevice). - type FusionDevice: FusionDevice + From + Into + core::fmt::Debug; - /// The type that can be used to point to a tensor of any kind. - type Handle: Sync + Send + Clone; /// What kind of client should be used. type FusionClient: FusionClient; /// The list of optimizations that will be used to optimize the computational graph. fn optimizations(device: Device) -> Vec>>; - - /// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive). - fn float_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::FloatTensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). - fn int_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::IntTensorPrimitive; - /// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). - fn bool_tensor( - handle: Self::Handle, - shape: Shape, - ) -> Self::BoolTensorPrimitive; - - /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](FusionBackend::Handle). - fn float_tensor_handle(tensor: Self::FloatTensorPrimitive) -> Self::Handle; - /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle). - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle; - /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle). - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle; } diff --git a/crates/burn-fusion/src/client/base.rs b/crates/burn-fusion/src/client/base.rs index 2332315dee..377f98b8f0 100644 --- a/crates/burn-fusion/src/client/base.rs +++ b/crates/burn-fusion/src/client/base.rs @@ -1,10 +1,12 @@ use crate::{ - stream::{Operation, OperationDescription, StreamId}, - FusionBackend, FusionTensor, Handle, TensorDescription, TensorId, + stream::{execution::Operation, StreamId}, + FusionBackend, FusionTensor, Handle, }; use burn_tensor::{ + backend::Backend, ops::{FloatElem, IntElem}, - Data, Reader, + repr::{OperationDescription, TensorDescription, TensorId}, + Data, Device, Reader, }; /// Define how to interact with the fusion server. @@ -12,8 +14,8 @@ pub trait FusionClient: Send + Sync + Clone { /// The [fusion backend](FusionBackend) associated type. type FusionBackend: FusionBackend; - /// Create a new client for the given [fusion device](FusionBackend::FusionDevice). - fn new(device: ::FusionDevice) -> Self; + /// Create a new client for the given [device](Backend::Device). + fn new(device: Device) -> Self; /// Register a new [tensor operation description](OperationDescription). fn register + 'static>( &self, @@ -24,7 +26,7 @@ pub trait FusionClient: Send + Sync + Clone { /// Register all lazy computation. fn drain(&self); /// Get the current device used by all operations handled by this client. - fn device(&self) -> &::FusionDevice; + fn device(&self) -> &::Device; /// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it. fn tensor_uninitialized(&self, shape: Vec) -> FusionTensor; /// Create a tensor with the given handle and shape. diff --git a/crates/burn-fusion/src/client/mutex.rs b/crates/burn-fusion/src/client/mutex.rs index 01d5f2a3f1..940d331f0f 100644 --- a/crates/burn-fusion/src/client/mutex.rs +++ b/crates/burn-fusion/src/client/mutex.rs @@ -1,9 +1,13 @@ use super::FusionClient; use crate::{ - stream::{Operation, OperationDescription, StreamId}, + stream::{execution::Operation, StreamId}, FusionBackend, FusionServer, FusionTensor, Handle, }; -use burn_tensor::ops::FloatElem; +use burn_tensor::{ + backend::Backend, + ops::FloatElem, + repr::{OperationDescription, TensorDescription, TensorId}, +}; use spin::Mutex; use std::sync::Arc; @@ -13,7 +17,7 @@ where B: FusionBackend, { server: Arc>>, - device: B::FusionDevice, + device: B::Device, } impl Clone for MutexFusionClient @@ -34,7 +38,7 @@ where { type FusionBackend = B; - fn new(device: B::FusionDevice) -> Self { + fn new(device: B::Device) -> Self { Self { device: device.clone(), server: Arc::new(Mutex::new(FusionServer::new(device))), @@ -63,7 +67,7 @@ where FusionTensor::new(id, shape, self.clone(), StreamId::current()) } - fn device(&self) -> &::FusionDevice { + fn device(&self) -> &::Device { &self.device } fn register_tensor( @@ -82,7 +86,7 @@ where fn read_tensor_float( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, stream: StreamId, ) -> burn_tensor::Reader, D>> { self.server.lock().read_float(tensor, stream) @@ -90,7 +94,7 @@ where fn read_tensor_int( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, id: StreamId, ) -> burn_tensor::Reader, D>> { @@ -99,7 +103,7 @@ where fn read_tensor_bool( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, stream: StreamId, ) -> burn_tensor::Reader> { self.server.lock().read_bool(tensor, stream) @@ -107,17 +111,16 @@ where fn change_client_float( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, client: Self, stream: StreamId, ) -> FusionTensor { - let device = client.device.clone().into(); - let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_float::(&tensor, &device, &mut server_other); + let id = + server_current.change_server_float::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); @@ -127,17 +130,15 @@ where fn change_client_int( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, client: Self, stream: StreamId, ) -> FusionTensor { - let device = client.device.clone().into(); - let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_int::(&tensor, &device, &mut server_other); + let id = server_current.change_server_int::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); @@ -147,17 +148,15 @@ where fn change_client_bool( &self, - tensor: crate::TensorDescription, + tensor: TensorDescription, client: Self, stream: StreamId, ) -> FusionTensor { - let device = client.device.clone().into(); - let mut server_other = client.server.lock(); let mut server_current = self.server.lock(); server_current.drain_stream(stream); - let id = server_current.change_server_bool::(&tensor, &device, &mut server_other); + let id = server_current.change_server_bool::(&tensor, &client.device, &mut server_other); core::mem::drop(server_other); core::mem::drop(server_current); @@ -165,7 +164,7 @@ where FusionTensor::new(id, tensor.shape, client, StreamId::current()) } - fn register_orphan(&self, id: &crate::TensorId) { + fn register_orphan(&self, id: &TensorId) { self.server.lock().drop_tensor_handle(*id); } } diff --git a/crates/burn-fusion/src/fusion.rs b/crates/burn-fusion/src/fusion.rs index f26630a561..f224fdaa33 100644 --- a/crates/burn-fusion/src/fusion.rs +++ b/crates/burn-fusion/src/fusion.rs @@ -1,8 +1,14 @@ -use crate::{client::FusionClient, DeviceId, FusionBackend, FusionDevice}; +use burn_tensor::{ + backend::{Backend, DeviceId, DeviceOps}, + repr::ReprBackend, +}; + +use crate::client::FusionClient; + use std::{any::Any, collections::HashMap, ops::DerefMut}; -/// Type alias for [fusion backend handle](FusionBackend::Handle). -pub type Handle = ::Handle; +/// Type alias for [representation backend handle](burn_tensor::repr::ReprBackend::Handle). +pub type Handle = ::Handle; type Key = (core::any::TypeId, DeviceId); pub(crate) struct FusionClientLocator { @@ -22,7 +28,7 @@ impl FusionClientLocator { /// Provide the init function to create a new client if it isn't already initialized. pub fn client( &self, - device: &::FusionDevice, + device: &::Device, ) -> C { let device_id = device.id(); let client_id = (core::any::TypeId::of::(), device_id); diff --git a/crates/burn-fusion/src/lib.rs b/crates/burn-fusion/src/lib.rs index 217e434f73..1eb1454c73 100644 --- a/crates/burn-fusion/src/lib.rs +++ b/crates/burn-fusion/src/lib.rs @@ -16,7 +16,6 @@ pub mod stream; mod backend; mod bridge; mod fusion; -mod handle; mod ops; mod server; mod tensor; @@ -26,5 +25,4 @@ pub(crate) use server::*; pub use backend::*; pub use bridge::*; pub use fusion::*; -pub use handle::*; pub use tensor::*; diff --git a/crates/burn-fusion/src/ops/binary.rs b/crates/burn-fusion/src/ops/binary.rs index aa5657502f..e7035f674e 100644 --- a/crates/burn-fusion/src/ops/binary.rs +++ b/crates/burn-fusion/src/ops/binary.rs @@ -11,7 +11,7 @@ macro_rules! binary_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor(&self.desc.rhs); let output = $ops(lhs, rhs); @@ -35,7 +35,7 @@ macro_rules! binary_float_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let rhs = handles.get_float_tensor(&self.desc.rhs); let output = $ops(lhs, rhs); @@ -59,7 +59,7 @@ macro_rules! binary_int_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor(&self.desc.rhs); let output = $ops(lhs, rhs); @@ -93,7 +93,7 @@ macro_rules! binary_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let rhs = handles.get_int_tensor(&self.desc.rhs); let output = $ops(lhs, rhs); diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index 97e82008ce..c641209046 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -2,23 +2,24 @@ use crate::{ client::FusionClient, get_client, ops::binary::binary_ops_shape, - stream::{ - BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, - CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, Operation, - OperationDescription, PermuteOperationDescription, RepeatOperationDescription, - ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId, - SwapDimsDescription, UnaryOperationDescription, - }, + stream::{execution::Operation, StreamId}, Fusion, FusionBackend, }; use burn_tensor::{ ops::{BoolTensor, BoolTensorOps}, + repr::{ + BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, + CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, + HandleContainer, OperationDescription, PermuteOperationDescription, + RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, + }, Device, Shape, }; impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let tensor = B::bool_empty(shape.clone(), device); client.register_tensor( @@ -42,7 +43,7 @@ impl BoolTensorOps for Fusion { data: burn_tensor::Data, device: &Device, ) -> BoolTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let tensor = B::bool_from_data(data, device); let shape = B::bool_shape(&tensor); @@ -62,7 +63,7 @@ impl BoolTensorOps for Fusion { } impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_int(input); handles.register_int_tensor(&self.desc.out.id, output); @@ -95,7 +96,7 @@ impl BoolTensorOps for Fusion { } impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_into_float(input); handles.register_float_tensor(&self.desc.out.id, output); @@ -119,15 +120,15 @@ impl BoolTensorOps for Fusion { } fn bool_device(tensor: &BoolTensor) -> Device { - tensor.client.device().clone().into() + tensor.client.device().clone() } fn bool_to_device( tensor: BoolTensor, device: &Device, ) -> BoolTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); + let device_original: &B::Device = tensor.client.device(); + let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; @@ -154,7 +155,7 @@ impl BoolTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_reshape::(input, Shape::from(&self.desc.out.shape)); handles.register_bool_tensor(&self.desc.out.id, output); @@ -188,7 +189,7 @@ impl BoolTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = @@ -232,7 +233,7 @@ impl BoolTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let value = handles.get_bool_tensor::(&self.desc.value); @@ -277,7 +278,7 @@ impl BoolTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors @@ -329,7 +330,7 @@ impl BoolTensorOps for Fusion { } impl Operation for EqualOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_bool_tensor::(&self.desc.lhs); let rhs = handles.get_bool_tensor(&self.desc.rhs); let output = B::bool_equal(lhs, rhs); @@ -364,7 +365,7 @@ impl BoolTensorOps for Fusion { } impl Operation for NotOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_not(input); handles.register_bool_tensor(&self.desc.out.id, output); @@ -381,7 +382,7 @@ impl BoolTensorOps for Fusion { out.client.register( vec![stream], - OperationDescription::Bool(crate::stream::BoolOperationDescription::Not(desc.clone())), + OperationDescription::Bool(BoolOperationDescription::Not(desc.clone())), NotOps::::new(desc), ); @@ -399,7 +400,7 @@ impl BoolTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_bool_tensor(&self.desc.out.id, output); @@ -438,7 +439,7 @@ impl BoolTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::bool_permute(input, axes); @@ -478,7 +479,7 @@ impl BoolTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); @@ -516,7 +517,7 @@ impl BoolTensorOps for Fusion { } impl Operation for FlipOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let output = B::bool_flip(input, self.desc.axes.as_slice()); handles.register_bool_tensor(&self.desc.out.id, output); @@ -552,7 +553,7 @@ impl BoolTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_bool_tensor::(&self.desc.tensor); let output = B::bool_repeat::(tensor, self.desc.dim, self.desc.times); diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 6b3c9e1980..e282f11abd 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -4,21 +4,23 @@ use crate::{ get_client, ops::binary::binary_ops_shape, scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops, - stream::{ + stream::{execution::Operation, StreamId}, + unary_float_ops, Fusion, FusionBackend, +}; + +use burn_tensor::{ + ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, + repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, Operation, + FloatOperationDescription, GatherOperationDescription, HandleContainer, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, - StreamId, SwapDimsDescription, UnaryOperationDescription, + SwapDimsDescription, TensorDescription, UnaryOperationDescription, }, - unary_float_ops, Fusion, FusionBackend, TensorDescription, -}; -use burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, Data, Device, Distribution, ElementConversion, Reader, Shape, }; use std::ops::Range; @@ -28,7 +30,7 @@ impl FloatTensorOps for Fusion { data: Data, D>, device: &Device, ) -> FloatTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let tensor = B::float_from_data(data, device); let shape = B::float_shape(&tensor); @@ -50,7 +52,7 @@ impl FloatTensorOps for Fusion { } impl Operation for RandomOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::FloatTensorPrimitive = B::float_random(shape, self.desc.distribution, &handles.device); @@ -60,7 +62,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = RandomOperationDescription { @@ -83,7 +85,7 @@ impl FloatTensorOps for Fusion { } impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output = B::float_zeros::(shape, &handles.device); handles.register_float_tensor(&self.out.id, output); @@ -92,7 +94,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = out.to_description_out(); @@ -112,7 +114,7 @@ impl FloatTensorOps for Fusion { } impl Operation for OnesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output = B::float_ones::(shape, &handles.device); handles.register_float_tensor(&self.out.id, output); @@ -121,7 +123,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = out.to_description_out(); @@ -146,7 +148,7 @@ impl FloatTensorOps for Fusion { } impl Operation for FullOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.out.shape.clone()); let output: B::FloatTensorPrimitive = B::float_full(shape, self.elem.elem(), &handles.device); @@ -156,7 +158,7 @@ impl FloatTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = (out.to_description_out(), fill_value.elem::()); @@ -180,15 +182,15 @@ impl FloatTensorOps for Fusion { } fn float_device(tensor: &FloatTensor) -> Device { - tensor.client.device().clone().into() + tensor.client.device().clone() } fn float_to_device( tensor: FloatTensor, device: &Device, ) -> FloatTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); + let device_original: &B::Device = tensor.client.device(); + let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; @@ -212,7 +214,7 @@ impl FloatTensorOps for Fusion { } impl Operation for IntoIntOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_into_int(input); @@ -237,7 +239,7 @@ impl FloatTensorOps for Fusion { } fn float_empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let stream = StreamId::current(); let tensor = B::float_empty(shape.clone(), device); @@ -307,7 +309,7 @@ impl FloatTensorOps for Fusion { } impl Operation for ClampOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_clamp(input, self.desc.min.elem(), self.desc.max.elem()); @@ -551,7 +553,7 @@ impl FloatTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_float_tensor(&self.desc.out.id, output); @@ -591,7 +593,7 @@ impl FloatTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_reshape::(input, Shape::from(&self.desc.out.shape)); handles.register_float_tensor(&self.desc.out.id, output); @@ -626,7 +628,7 @@ impl FloatTensorOps for Fusion { } impl Operation for GatherOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); @@ -667,7 +669,7 @@ impl FloatTensorOps for Fusion { } impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); let value = handles.get_float_tensor(&self.desc.value); @@ -712,7 +714,7 @@ impl FloatTensorOps for Fusion { } impl Operation for SelectOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); @@ -754,7 +756,7 @@ impl FloatTensorOps for Fusion { } impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); let value = handles.get_float_tensor(&self.desc.value); @@ -799,7 +801,7 @@ impl FloatTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = @@ -842,7 +844,7 @@ impl FloatTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor::(&self.desc.value); @@ -887,7 +889,7 @@ impl FloatTensorOps for Fusion { } impl Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let value = handles.get_float_tensor(&self.desc.value); let mask = handles.get_bool_tensor(&self.desc.mask); @@ -932,7 +934,7 @@ impl FloatTensorOps for Fusion { } impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor(&self.desc.mask); @@ -1530,7 +1532,7 @@ impl FloatTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors @@ -1607,7 +1609,7 @@ impl FloatTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let output = B::float_repeat::(tensor, self.desc.dim, self.desc.times); @@ -1715,7 +1717,7 @@ impl FloatTensorOps for Fusion { } impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_max_dim_with_indices(tensor, self.desc.dim); @@ -1802,7 +1804,7 @@ impl FloatTensorOps for Fusion { } impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_float_tensor::(&self.desc.tensor); let (output, indices) = B::float_min_dim_with_indices(tensor, self.desc.dim); @@ -1871,7 +1873,7 @@ impl FloatTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::float_permute(input, axes); @@ -1911,7 +1913,7 @@ impl FloatTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::float_expand(input, shape.into()); @@ -1949,7 +1951,7 @@ impl FloatTensorOps for Fusion { } impl Operation for FlipOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = B::float_flip(input, &self.desc.axes); handles.register_float_tensor(&self.desc.out.id, output); diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 52b01e7913..2e691707bf 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -4,28 +4,29 @@ use crate::{ get_client, ops::binary::binary_ops_shape, scalar_int_cmp_ops, scalar_int_ops, - stream::{ - self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, - ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription, - NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription, - RandomOperationDescription, ReduceDimWithIndicesDescription, RepeatOperationDescription, - ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, - SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription, - UnaryOperationDescription, - }, - unary_int_ops, Fusion, FusionBackend, TensorDescription, + stream::{execution::Operation, StreamId}, + unary_int_ops, Fusion, FusionBackend, }; use burn_tensor::{ ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps}, + repr::{ + self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, + ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, + GatherOperationDescription, HandleContainer, MaskFillOperationDescription, + MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, + PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, + RepeatOperationDescription, ReshapeDescription, ScalarOperationDescription, + ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, + SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, + TensorDescription, UnaryOperationDescription, + }, Data, Device, Distribution, ElementConversion, Reader, Shape, }; use core::ops::Range; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let tensor = B::int_empty(shape.clone(), device); let stream = StreamId::current(); @@ -44,7 +45,7 @@ impl IntTensorOps for Fusion { data: Data, D>, device: &Device, ) -> IntTensor { - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let tensor = B::int_from_data(data, device); let shape = B::int_shape(&tensor); let stream = StreamId::current(); @@ -53,15 +54,15 @@ impl IntTensorOps for Fusion { } fn int_device(tensor: &IntTensor) -> Device { - tensor.client.device().clone().into() + tensor.client.device().clone() } fn int_to_device( tensor: IntTensor, device: &Device, ) -> IntTensor { - let device_original: &B::FusionDevice = tensor.client.device(); - let device_target: B::FusionDevice = device.clone().into(); + let device_original: &B::Device = tensor.client.device(); + let device_target: B::Device = device.clone(); if device_original == &device_target { return tensor; @@ -86,7 +87,7 @@ impl IntTensorOps for Fusion { } impl Operation for ReshapeDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_reshape::(input, Shape::from(&self.desc.out.shape)); handles.register_int_tensor(&self.desc.out.id, output); @@ -120,7 +121,7 @@ impl IntTensorOps for Fusion { } impl Operation for SliceOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = @@ -164,7 +165,7 @@ impl IntTensorOps for Fusion { } impl Operation for SliceAssignOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor::(&self.desc.value); @@ -208,7 +209,7 @@ impl IntTensorOps for Fusion { } impl Operation for MaskWhereOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let value = handles.get_int_tensor(&self.desc.value); let mask = handles.get_bool_tensor(&self.desc.mask); @@ -251,7 +252,7 @@ impl IntTensorOps for Fusion { } impl Operation for MaskFillOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let mask = handles.get_bool_tensor(&self.desc.mask); @@ -291,7 +292,7 @@ impl IntTensorOps for Fusion { } impl Operation for GatherOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); @@ -331,7 +332,7 @@ impl IntTensorOps for Fusion { } impl Operation for ScatterOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); let value = handles.get_int_tensor(&self.desc.value); @@ -374,7 +375,7 @@ impl IntTensorOps for Fusion { } impl Operation for SelectOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); @@ -416,7 +417,7 @@ impl IntTensorOps for Fusion { } impl Operation for SelectAssignOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let indices = handles.get_int_tensor(&self.desc.indices); let value = handles.get_int_tensor(&self.desc.value); @@ -457,7 +458,7 @@ impl IntTensorOps for Fusion { } impl Operation for CatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensors = self .desc .tensors @@ -770,9 +771,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - stream::OperationDescription::NumericInt(NumericOperationDescription::Add( - desc.clone(), - )), + repr::OperationDescription::NumericInt(NumericOperationDescription::Add(desc.clone())), AddOps::::new(desc), ); @@ -795,7 +794,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - stream::OperationDescription::NumericInt(NumericOperationDescription::AddScalar( + repr::OperationDescription::NumericInt(NumericOperationDescription::AddScalar( desc.clone(), )), AddOps::::new(desc), @@ -823,9 +822,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - stream::OperationDescription::NumericInt(NumericOperationDescription::Sub( - desc.clone(), - )), + repr::OperationDescription::NumericInt(NumericOperationDescription::Sub(desc.clone())), SubOps::::new(desc), ); @@ -848,7 +845,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - stream::OperationDescription::NumericInt(NumericOperationDescription::SubScalar( + repr::OperationDescription::NumericInt(NumericOperationDescription::SubScalar( desc.clone(), )), SubOps::::new(desc), @@ -876,9 +873,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - stream::OperationDescription::NumericInt(NumericOperationDescription::Mul( - desc.clone(), - )), + repr::OperationDescription::NumericInt(NumericOperationDescription::Mul(desc.clone())), MulOps::::new(desc), ); @@ -901,7 +896,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - stream::OperationDescription::NumericInt(NumericOperationDescription::MulScalar( + repr::OperationDescription::NumericInt(NumericOperationDescription::MulScalar( desc.clone(), )), MulOps::::new(desc), @@ -929,9 +924,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream_1, stream_2], - stream::OperationDescription::NumericInt(NumericOperationDescription::Div( - desc.clone(), - )), + repr::OperationDescription::NumericInt(NumericOperationDescription::Div(desc.clone())), DivOps::::new(desc), ); @@ -954,7 +947,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - stream::OperationDescription::NumericInt(NumericOperationDescription::DivScalar( + repr::OperationDescription::NumericInt(NumericOperationDescription::DivScalar( desc.clone(), )), DivOps::::new(desc), @@ -979,7 +972,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - stream::OperationDescription::NumericInt(NumericOperationDescription::RemScalar( + repr::OperationDescription::NumericInt(NumericOperationDescription::RemScalar( desc.clone(), )), ModOps::::new(desc), @@ -995,7 +988,7 @@ impl IntTensorOps for Fusion { } impl Operation for ZerosOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); let output = B::int_zeros::(shape, &handles.device); handles.register_int_tensor(&self.desc.id, output); @@ -1004,7 +997,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = out.to_description_out(); client.register( @@ -1023,7 +1016,7 @@ impl IntTensorOps for Fusion { } impl Operation for OnesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.shape.clone()); let output = B::int_ones::(shape, &handles.device); handles.register_int_tensor(&self.desc.id, output); @@ -1032,7 +1025,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = out.to_description_out(); @@ -1223,7 +1216,7 @@ impl IntTensorOps for Fusion { } impl Operation for ClampOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_clamp(input, self.desc.min.elem(), self.desc.max.elem()); @@ -1274,7 +1267,7 @@ impl IntTensorOps for Fusion { } impl Operation for IntoFloatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_into_float(input); handles.register_float_tensor(&self.desc.out.id, output); @@ -1289,7 +1282,7 @@ impl IntTensorOps for Fusion { }; out.client.register( vec![stream], - OperationDescription::Int(stream::IntOperationDescription::IntoFloat(desc.clone())), + OperationDescription::Int(repr::IntOperationDescription::IntoFloat(desc.clone())), IntoFloatOps::::new(desc), ); @@ -1307,7 +1300,7 @@ impl IntTensorOps for Fusion { } impl Operation for SwapDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = B::int_swap_dims(input, self.desc.dim1, self.desc.dim2); handles.register_int_tensor(&self.desc.out.id, output); @@ -1387,7 +1380,7 @@ impl IntTensorOps for Fusion { } impl Operation for MaxDimWithIndicesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_max_dim_with_indices(tensor, self.desc.dim); @@ -1470,7 +1463,7 @@ impl IntTensorOps for Fusion { } impl Operation for MinDimWithIndicesOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let (output, indices) = B::int_min_dim_with_indices(tensor, self.desc.dim); @@ -1513,7 +1506,7 @@ impl IntTensorOps for Fusion { } impl Operation for IntRandomOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let shape = Shape::from(self.desc.out.shape.clone()); let output: B::IntTensorPrimitive = B::int_random(shape, self.desc.distribution, &handles.device); @@ -1523,7 +1516,7 @@ impl IntTensorOps for Fusion { let stream = StreamId::current(); let shape: Vec = shape.dims.into(); - let client = get_client::(&device.clone().into()); + let client = get_client::(&device.clone()); let out = client.tensor_uninitialized(shape); let desc = RandomOperationDescription { @@ -1549,7 +1542,7 @@ impl IntTensorOps for Fusion { } impl Operation for PermuteDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let axes: [usize; D] = self.desc.axes.try_into().unwrap(); let output = B::int_permute(input, axes); @@ -1589,7 +1582,7 @@ impl IntTensorOps for Fusion { } impl Operation for ExpandOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_bool_tensor::(&self.desc.input); let shape: [usize; D2] = self.desc.shape.try_into().unwrap(); let output = B::bool_expand(input, shape.into()); @@ -1623,7 +1616,7 @@ impl IntTensorOps for Fusion { } impl Operation for FlipDimsOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let axes = &self.desc.axes; let output = B::int_flip(input, axes); @@ -1661,7 +1654,7 @@ impl IntTensorOps for Fusion { } impl Operation for RepeatOps { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let tensor = handles.get_int_tensor::(&self.desc.tensor); let output = B::int_repeat::(tensor, self.desc.dim, self.desc.times); diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index a7c84a4a4b..a2543559c8 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -1,25 +1,25 @@ -use crate::stream::InterpolateBackwardDescription; -use crate::{ - client::FusionClient, - stream::{ +use crate::{client::FusionClient, stream::execution::Operation, Fusion, FusionBackend}; +use burn_tensor::{ + ops::{ + conv::{ + calculate_conv_output_size, calculate_conv_transpose_output_size, + calculate_pool_output_size, + }, + ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions, + MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, + ModuleOps, + }, + repr::{ AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription, AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, - ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription, - MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, - MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, - MaxPool2dWithIndicesDescription, Operation, OperationDescription, - }, - Fusion, FusionBackend, HandleContainer, -}; -use burn_tensor::ops::{ - conv::{ - calculate_conv_output_size, calculate_conv_transpose_output_size, - calculate_pool_output_size, + ConvTranspose2dDescription, HandleContainer, InterpolateBackwardDescription, + InterpolateDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, + MaxPool1dWithIndicesDescription, MaxPool2dDescription, + MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, + ModuleOperationDescription, OperationDescription, }, - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions, - MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; macro_rules! make_ops { @@ -30,7 +30,7 @@ macro_rules! make_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { #[allow(clippy::redundant_closure_call)] $fn(self.desc, handles) } @@ -88,9 +88,7 @@ impl ModuleOps> for Fusion { }; out.client.clone().register( streams, - OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv1d( - description.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::Conv1d(description.clone())), Conv1dOps::new(description), ); @@ -155,9 +153,7 @@ impl ModuleOps> for Fusion { }; out.client.register( streams, - OperationDescription::Module(crate::stream::ModuleOperationDescription::Conv2d( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::Conv2d(desc.clone())), Conv2dOps::new(desc), ); @@ -216,9 +212,7 @@ impl ModuleOps> for Fusion { }; out.client.register( streams, - OperationDescription::Module( - crate::stream::ModuleOperationDescription::ConvTranspose1d(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::ConvTranspose1d(desc.clone())), ConvTranspose1dOps::new(desc), ); @@ -285,9 +279,7 @@ impl ModuleOps> for Fusion { }; out.client.register( streams, - OperationDescription::Module( - crate::stream::ModuleOperationDescription::ConvTranspose2d(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::ConvTranspose2d(desc.clone())), ConvTranspose2dOps::new(desc), ); @@ -333,9 +325,7 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool1d( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::AvgPool1d(desc.clone())), AvgPool1dOps::new(desc), ); @@ -385,9 +375,7 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module(crate::stream::ModuleOperationDescription::AvgPool2d( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::AvgPool2d(desc.clone())), AvgPool2dOps::new(desc), ); @@ -436,9 +424,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AvgPool1dBackward(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AvgPool1dBackward( + desc.clone(), + )), AvgPool1dBackwardOps::new(desc), ); @@ -487,9 +475,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AvgPool2dBackward(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AvgPool2dBackward( + desc.clone(), + )), AvgPool2dBackwardOps::new(desc), ); @@ -536,9 +524,7 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool1d( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::MaxPool1d(desc.clone())), MaxPool1dOps::new(desc), ); @@ -598,9 +584,7 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module(crate::stream::ModuleOperationDescription::MaxPool2d( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::MaxPool2d(desc.clone())), MaxPool2dOps::new(desc), ); @@ -649,9 +633,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::MaxPool1dWithIndices(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndices( + desc.clone(), + )), MaxPool1dWithIndicesOps::new(desc), ); @@ -714,9 +698,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::MaxPool2dWithIndices(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndices( + desc.clone(), + )), MaxPool2dWithIndicesOps::new(desc), ); @@ -770,11 +754,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::MaxPool1dWithIndicesBackward( - desc.clone(), - ), - ), + OperationDescription::Module(ModuleOperationDescription::MaxPool1dWithIndicesBackward( + desc.clone(), + )), MaxPool1dWithIndicesBackwardOps::new(desc), ); @@ -828,11 +810,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2, stream_3], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::MaxPool2dWithIndicesBackward( - desc.clone(), - ), - ), + OperationDescription::Module(ModuleOperationDescription::MaxPool2dWithIndicesBackward( + desc.clone(), + )), MaxPool2dWithIndicesBackwardOps::new(desc), ); @@ -862,9 +842,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AdaptiveAvgPool1d(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1d( + desc.clone(), + )), AdaptiveAvgPool1dOps::new(desc), ); @@ -897,9 +877,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AdaptiveAvgPool2d(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2d( + desc.clone(), + )), AdaptiveAvgPool2dOps::new(desc), ); @@ -933,9 +913,9 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream_1, stream_2], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool1dBackward( + desc.clone(), + )), AdaptiveAvgPool1dBackwardOps::new(desc), ); @@ -969,9 +949,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::AdaptiveAvgPool2dBackward( + desc.clone(), + )), AdaptiveAvgPool2dBackwardOps::new(desc), ); @@ -1006,9 +986,7 @@ impl ModuleOps> for Fusion { out.client.register( vec![stream], - OperationDescription::Module(crate::stream::ModuleOperationDescription::Interpolate( - desc.clone(), - )), + OperationDescription::Module(ModuleOperationDescription::Interpolate(desc.clone())), InterpolateOps::new(desc), ); @@ -1047,9 +1025,9 @@ impl ModuleOps> for Fusion { }; out.client.register( vec![stream_1, stream_2], - OperationDescription::Module( - crate::stream::ModuleOperationDescription::InterpolateBackward(desc.clone()), - ), + OperationDescription::Module(ModuleOperationDescription::InterpolateBackward( + desc.clone(), + )), InterpolateBackwardOps::new(desc), ); out diff --git a/crates/burn-fusion/src/ops/unary.rs b/crates/burn-fusion/src/ops/unary.rs index f40d339dd5..d6f0833c64 100644 --- a/crates/burn-fusion/src/ops/unary.rs +++ b/crates/burn-fusion/src/ops/unary.rs @@ -18,7 +18,7 @@ macro_rules! scalar_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -38,7 +38,7 @@ macro_rules! scalar_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); @@ -62,7 +62,7 @@ macro_rules! scalar_float2int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs.clone()); @@ -85,7 +85,7 @@ macro_rules! unary_float_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_float_tensor::(&self.desc.input); let output = $ops(input); @@ -108,7 +108,7 @@ macro_rules! unary_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let input = handles.get_int_tensor::(&self.desc.input); let output = $ops(input); @@ -131,7 +131,7 @@ macro_rules! scalar_float_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_float_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -154,7 +154,7 @@ macro_rules! scalar_int_cmp_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -184,7 +184,7 @@ macro_rules! scalar_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, burn_tensor::ElementConversion::elem(self.desc.rhs)); @@ -204,7 +204,7 @@ macro_rules! scalar_int_ops { } impl Operation for $name { - fn execute(self: Box, handles: &mut $crate::HandleContainer) { + fn execute(self: Box, handles: &mut HandleContainer) { let lhs = handles.get_int_tensor::(&self.desc.lhs); let output = $ops(lhs, self.desc.rhs); diff --git a/crates/burn-fusion/src/server.rs b/crates/burn-fusion/src/server.rs index 3ef081e40a..4c6ec62fc7 100644 --- a/crates/burn-fusion/src/server.rs +++ b/crates/burn-fusion/src/server.rs @@ -1,8 +1,11 @@ use crate::{ - stream::{MultiStream, Operation, OperationDescription, StreamId}, - FusionBackend, HandleContainer, TensorId, + stream::{execution::Operation, MultiStream, StreamId}, + FusionBackend, +}; +use burn_tensor::{ + ops::{FloatElem, IntElem}, + repr::{HandleContainer, OperationDescription, TensorDescription, TensorId}, }; -use burn_tensor::ops::{FloatElem, IntElem}; use std::sync::Arc; pub struct FusionServer @@ -11,14 +14,14 @@ where { streams: MultiStream, pub(crate) handles: HandleContainer, - pub device: B::FusionDevice, + pub device: B::Device, } impl FusionServer where B: FusionBackend, { - pub fn new(device: B::FusionDevice) -> Self { + pub fn new(device: B::Device) -> Self { Self { streams: MultiStream::new(device.clone()), handles: HandleContainer::new(device.clone()), @@ -46,7 +49,7 @@ where pub fn read_float( &mut self, - tensor: crate::TensorDescription, + tensor: TensorDescription, id: StreamId, ) -> burn_tensor::Reader, D>> { // Make sure all registered operations are executed. @@ -59,7 +62,7 @@ where pub fn read_int( &mut self, - tensor: crate::TensorDescription, + tensor: TensorDescription, id: StreamId, ) -> burn_tensor::Reader, D>> { // Make sure all registered operations are executed. @@ -72,7 +75,7 @@ where pub fn read_bool( &mut self, - tensor: crate::TensorDescription, + tensor: TensorDescription, id: StreamId, ) -> burn_tensor::Reader> { // Make sure all registered operations are executed. @@ -85,7 +88,7 @@ where pub fn change_server_float( &mut self, - tensor: &crate::TensorDescription, + tensor: &TensorDescription, device: &B::Device, server_device: &mut Self, ) -> Arc { @@ -101,7 +104,7 @@ where } pub fn change_server_int( &mut self, - tensor: &crate::TensorDescription, + tensor: &TensorDescription, device: &B::Device, server_device: &mut Self, ) -> Arc { @@ -117,7 +120,7 @@ where } pub fn change_server_bool( &mut self, - tensor: &crate::TensorDescription, + tensor: &TensorDescription, device: &B::Device, server_device: &mut Self, ) -> Arc { diff --git a/crates/burn-fusion/src/stream/base.rs b/crates/burn-fusion/src/stream/base.rs index 5a017ef5ee..fb3d8f99b3 100644 --- a/crates/burn-fusion/src/stream/base.rs +++ b/crates/burn-fusion/src/stream/base.rs @@ -1,8 +1,9 @@ -use super::Operation; -use super::OperationConverter; -use super::OperationDescription; +use burn_tensor::repr::OperationDescription; + use crate::FusionBackend; +use super::{execution::Operation, OperationConverter, RelativeOps}; + /// A growing list of [tensor operation descriptions](OperationDescription). pub struct OperationQueue { pub(crate) global: Vec, diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index a3afca2ace..11441a40a8 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1,23 +1,5 @@ -use super::{ - AdaptiveAvgPool1dBackwardDescription, AdaptiveAvgPool1dDescription, - AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription, - AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription, - BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription, - Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription, - EmbeddingBackwardDescription, EmbeddingDescription, ExpandOperationDescription, - FlipOperationDescription, FloatOperationDescription, GatherOperationDescription, - IntOperationDescription, InterpolateBackwardDescription, InterpolateDescription, - MaskFillOperationDescription, MaskWhereOperationDescription, MaxPool1dDescription, - MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription, MaxPool2dDescription, - MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, - ModuleOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription, - SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, -}; -use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId}; -use burn_tensor::{Element, ElementConversion}; +use crate::FusionBackend; +use burn_tensor::{repr::*, Element, ElementConversion}; use hashbrown::HashMap; /// The context contains the relative graph tensor mapping so that a relative tensor id can be @@ -49,6 +31,16 @@ pub(crate) struct OperationConverter { scalar_ints: Vec, } +pub(crate) trait RelativeOps { + fn to_relative(&self, converter: &mut OperationConverter) -> Self; +} + +trait RelativeOpsScalar { + fn to_relative(&self, converter: &mut OperationConverter, local_elem: F) -> Self + where + F: Fn(&mut OperationConverter, &E) -> E; +} + impl OperationConverter { pub(crate) fn context<'a, B: FusionBackend>( &'a self, @@ -85,8 +77,8 @@ impl OperationConverter { } } -impl OperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for OperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { OperationDescription::BaseFloat(ops) => { OperationDescription::BaseFloat(ops.to_relative(converter)) @@ -117,8 +109,8 @@ impl OperationDescription { } } -impl ModuleOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for ModuleOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { ModuleOperationDescription::Embedding(desc) => { ModuleOperationDescription::Embedding(EmbeddingDescription { @@ -172,7 +164,7 @@ impl ModuleOperationDescription { }) } ModuleOperationDescription::AvgPool1d(desc) => { - ModuleOperationDescription::AvgPool1d(super::AvgPool1dDescription { + ModuleOperationDescription::AvgPool1d(AvgPool1dDescription { x: desc.x.to_relative(converter), kernel_size: desc.kernel_size, stride: desc.stride, @@ -192,7 +184,7 @@ impl ModuleOperationDescription { }) } ModuleOperationDescription::AvgPool1dBackward(desc) => { - ModuleOperationDescription::AvgPool1dBackward(super::AvgPool1dBackwardDescription { + ModuleOperationDescription::AvgPool1dBackward(AvgPool1dBackwardDescription { x: desc.x.to_relative(converter), grad: desc.grad.to_relative(converter), kernel_size: desc.kernel_size, @@ -336,8 +328,8 @@ impl ModuleOperationDescription { } } -impl FloatOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for FloatOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { FloatOperationDescription::Exp(desc) => { FloatOperationDescription::Exp(UnaryOperationDescription { @@ -423,8 +415,8 @@ impl FloatOperationDescription { } } -impl BoolOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for BoolOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { BoolOperationDescription::IntoFloat(desc) => { BoolOperationDescription::IntoFloat(UnaryOperationDescription { @@ -448,8 +440,8 @@ impl BoolOperationDescription { } } -impl IntOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for IntOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { IntOperationDescription::IntoFloat(desc) => { IntOperationDescription::IntoFloat(UnaryOperationDescription { @@ -461,8 +453,8 @@ impl IntOperationDescription { } } -impl NumericOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter, local_elem: F) -> Self +impl RelativeOpsScalar for NumericOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter, local_elem: F) -> Self where F: Fn(&mut OperationConverter, &E) -> E, { @@ -779,8 +771,8 @@ impl NumericOperationDescription { } } -impl BaseOperationDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for BaseOperationDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { match self { BaseOperationDescription::ToDevice(desc) => { BaseOperationDescription::ToDevice(desc.to_relative(converter)) @@ -828,7 +820,7 @@ impl BaseOperationDescription { }) } BaseOperationDescription::SliceAssign(desc) => { - BaseOperationDescription::SliceAssign(super::SliceAssignOperationDescription { + BaseOperationDescription::SliceAssign(SliceAssignOperationDescription { tensor: desc.tensor.to_relative(converter), ranges: desc.ranges.iter().map(|_range| 0..1).collect(), value: desc.value.to_relative(converter), @@ -836,14 +828,14 @@ impl BaseOperationDescription { }) } BaseOperationDescription::Equal(desc) => { - BaseOperationDescription::Equal(super::BinaryOperationDescription { + BaseOperationDescription::Equal(BinaryOperationDescription { lhs: desc.lhs.to_relative(converter), rhs: desc.rhs.to_relative(converter), out: desc.out.to_relative(converter), }) } BaseOperationDescription::Repeat(desc) => { - BaseOperationDescription::Repeat(super::RepeatOperationDescription { + BaseOperationDescription::Repeat(RepeatOperationDescription { tensor: desc.tensor.to_relative(converter), dim: desc.dim, times: desc.times, @@ -851,7 +843,7 @@ impl BaseOperationDescription { }) } BaseOperationDescription::Cat(desc) => { - BaseOperationDescription::Cat(super::CatOperationDescription { + BaseOperationDescription::Cat(CatOperationDescription { tensors: desc .tensors .iter() @@ -865,8 +857,8 @@ impl BaseOperationDescription { } } -impl TensorDescription { - pub(crate) fn to_relative(&self, converter: &mut OperationConverter) -> Self { +impl RelativeOps for TensorDescription { + fn to_relative(&self, converter: &mut OperationConverter) -> Self { let relative_id = if let Some(value) = converter.tensors_global2relative.get(&self.id) { // If we already have the same tensor registered, we have to update its value, but not // its id. @@ -911,9 +903,8 @@ impl TensorDescription { #[cfg(test)] mod tests { - use crate::TensorStatus; - use super::*; + use burn_tensor::repr::{TensorDescription, TensorId, TensorStatus}; #[test] fn tensor_description_to_relative() { diff --git a/crates/burn-fusion/src/stream/execution/base.rs b/crates/burn-fusion/src/stream/execution/base.rs index f117e69d9f..7b2dce0f50 100644 --- a/crates/burn-fusion/src/stream/execution/base.rs +++ b/crates/burn-fusion/src/stream/execution/base.rs @@ -1,9 +1,11 @@ +use burn_tensor::repr::HandleContainer; + use crate::{ stream::{ store::{ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy}, - OperationQueue, + OperationQueue, RelativeOps, }, - FusionBackend, HandleContainer, Optimization, + FusionBackend, Optimization, }; /// The mode in which the execution is done. @@ -13,6 +15,12 @@ pub(crate) enum ExecutionMode { Sync, } +/// General trait to abstract how a single operation is executed. +pub trait Operation: Send + Sync { + /// Execute the operation. + fn execute(self: Box, handles: &mut HandleContainer); +} + impl OperationQueue { /// Execute the queue partially following the execution strategy from the plan. pub(crate) fn execute( diff --git a/crates/burn-fusion/src/stream/execution/explorer.rs b/crates/burn-fusion/src/stream/execution/explorer.rs index f8f23ee43f..49621b2679 100644 --- a/crates/burn-fusion/src/stream/execution/explorer.rs +++ b/crates/burn-fusion/src/stream/execution/explorer.rs @@ -1,5 +1,7 @@ +use burn_tensor::repr::OperationDescription; + use super::ExecutionMode; -use crate::{stream::OperationDescription, OptimizationBuilder, OptimizationStatus}; +use crate::{OptimizationBuilder, OptimizationStatus}; /// Explore and create new optimization. pub struct Explorer { diff --git a/crates/burn-fusion/src/stream/execution/policy.rs b/crates/burn-fusion/src/stream/execution/policy.rs index 6b723f3109..e56f7bce62 100644 --- a/crates/burn-fusion/src/stream/execution/policy.rs +++ b/crates/burn-fusion/src/stream/execution/policy.rs @@ -1,13 +1,12 @@ +use burn_tensor::repr::OperationDescription; + use super::validator::{ ExecutionPlanOperationsStore, TriggerOperationsStore, TriggerProgress, TriggerValidator, ValidatorState, }; use super::ExecutionMode; use crate::stream::execution::validator::OperationsValidator; -use crate::stream::{ - store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery}, - OperationDescription, -}; +use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger, SearchQuery}; use std::marker::PhantomData; /// The policy keeps track of all possible execution plans for the current operations. @@ -266,14 +265,13 @@ impl Policy { #[cfg(test)] mod tests { - use super::*; - use crate::{ - stream::{ - store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger}, - FloatOperationDescription, UnaryOperationDescription, - }, - TensorDescription, TensorId, TensorStatus, + use burn_tensor::repr::{ + FloatOperationDescription, TensorDescription, TensorId, TensorStatus, + UnaryOperationDescription, }; + + use super::*; + use crate::stream::store::{ExecutionPlan, ExecutionStrategy, ExecutionTrigger}; use std::ops::Range; #[test] diff --git a/crates/burn-fusion/src/stream/execution/processor.rs b/crates/burn-fusion/src/stream/execution/processor.rs index 81bb31bf5b..40c447c181 100644 --- a/crates/burn-fusion/src/stream/execution/processor.rs +++ b/crates/burn-fusion/src/stream/execution/processor.rs @@ -1,9 +1,10 @@ +use burn_tensor::repr::OperationDescription; + use super::{ExecutionMode, Exploration, Explorer}; use crate::stream::execution::{Action, Policy}; use crate::stream::store::{ ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger, }; -use crate::stream::OperationDescription; use crate::OptimizationBuilder; /// Process a [stream segment](StreamSegment) following a [policy](Policy). diff --git a/crates/burn-fusion/src/stream/execution/tests.rs b/crates/burn-fusion/src/stream/execution/tests.rs index 31e249e49c..6755b624b2 100644 --- a/crates/burn-fusion/src/stream/execution/tests.rs +++ b/crates/burn-fusion/src/stream/execution/tests.rs @@ -6,16 +6,17 @@ //! To test these components effectively, we create mock types for the stream, optimization, //! optimization builder, and stream segment. These mock types aid in comprehensively //! understanding the process of optimizing streams. +use burn_tensor::repr::{ + BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, + OperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus, + UnaryOperationDescription, +}; + use crate::{ - stream::{ - store::{ - ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger, - }, - BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, - OperationDescription, ScalarOperationDescription, + stream::store::{ + ExecutionPlan, ExecutionPlanId, ExecutionPlanStore, ExecutionStrategy, ExecutionTrigger, }, - OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId, - TensorStatus, + OptimizationBuilder, OptimizationProperties, OptimizationStatus, }; use super::*; @@ -558,18 +559,16 @@ fn operation_2() -> OperationDescription { /// Just a simple operation. fn operation_3() -> OperationDescription { - OperationDescription::Float(FloatOperationDescription::Log( - crate::stream::UnaryOperationDescription { - input: TensorDescription { - id: TensorId::new(0), - shape: vec![32, 32], - status: TensorStatus::ReadOnly, - }, - out: TensorDescription { - id: TensorId::new(0), - shape: vec![32, 32], - status: TensorStatus::NotInit, - }, + OperationDescription::Float(FloatOperationDescription::Log(UnaryOperationDescription { + input: TensorDescription { + id: TensorId::new(0), + shape: vec![32, 32], + status: TensorStatus::ReadOnly, }, - )) + out: TensorDescription { + id: TensorId::new(0), + shape: vec![32, 32], + status: TensorStatus::NotInit, + }, + })) } diff --git a/crates/burn-fusion/src/stream/execution/validator.rs b/crates/burn-fusion/src/stream/execution/validator.rs index 07ae6efeaa..5cc68a25c0 100644 --- a/crates/burn-fusion/src/stream/execution/validator.rs +++ b/crates/burn-fusion/src/stream/execution/validator.rs @@ -1,7 +1,6 @@ -use crate::stream::{ - store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger}, - OperationDescription, -}; +use burn_tensor::repr::OperationDescription; + +use crate::stream::store::{ExecutionPlanId, ExecutionPlanStore, ExecutionTrigger}; /// Compare each operation in the list of operations provided by the [store](OperationsStore) /// to verify if the newly added operations match the original list. diff --git a/crates/burn-fusion/src/stream/mod.rs b/crates/burn-fusion/src/stream/mod.rs index ac12a69288..37c3455707 100644 --- a/crates/burn-fusion/src/stream/mod.rs +++ b/crates/burn-fusion/src/stream/mod.rs @@ -4,9 +4,7 @@ pub(crate) mod store; mod base; mod context; mod multi; -mod operation; pub use base::*; pub use context::*; pub use multi::*; -pub use operation::*; diff --git a/crates/burn-fusion/src/stream/multi.rs b/crates/burn-fusion/src/stream/multi.rs index 52fdc853ca..035b6ed280 100644 --- a/crates/burn-fusion/src/stream/multi.rs +++ b/crates/burn-fusion/src/stream/multi.rs @@ -1,20 +1,22 @@ +use burn_tensor::repr::{HandleContainer, OperationDescription}; + use super::{ - execution::{ExecutionMode, Processor, StreamSegment}, + execution::{ExecutionMode, Operation, Processor, StreamSegment}, store::{ExecutionPlanId, ExecutionPlanStore}, - Operation, OperationDescription, OperationQueue, StreamId, + OperationQueue, StreamId, }; -use crate::{FusionBackend, HandleContainer}; +use crate::FusionBackend; use std::collections::HashMap; /// Keep track of multiple concurrent streams of operations. pub struct MultiStream { streams: HashMap>, optimizations: ExecutionPlanStore, - device: B::FusionDevice, + device: B::Device, } impl MultiStream { - pub(crate) fn new(device: B::FusionDevice) -> Self { + pub(crate) fn new(device: B::Device) -> Self { Self { streams: HashMap::new(), optimizations: ExecutionPlanStore::new(), @@ -146,9 +148,9 @@ impl<'i, B: FusionBackend> StreamSegment for Segment<'i, B> { } impl Stream { - fn new(device: B::FusionDevice) -> Self { + fn new(device: B::Device) -> Self { Self { - processor: Processor::new(B::optimizations(device.into())), + processor: Processor::new(B::optimizations(device)), queue: OperationQueue::new(), } } diff --git a/crates/burn-fusion/src/stream/store/base.rs b/crates/burn-fusion/src/stream/store/base.rs index ddea281d92..cf32bcce39 100644 --- a/crates/burn-fusion/src/stream/store/base.rs +++ b/crates/burn-fusion/src/stream/store/base.rs @@ -1,5 +1,5 @@ use super::{ExecutionPlanIndex, InsertQuery, SearchQuery}; -use crate::stream::OperationDescription; +use burn_tensor::repr::OperationDescription; use serde::{Deserialize, Serialize}; /// The store that contains all explorations done on a device. diff --git a/crates/burn-fusion/src/stream/store/index.rs b/crates/burn-fusion/src/stream/store/index.rs index 3e77625d15..b1c3111ac9 100644 --- a/crates/burn-fusion/src/stream/store/index.rs +++ b/crates/burn-fusion/src/stream/store/index.rs @@ -1,4 +1,5 @@ -use crate::stream::{store::ExecutionPlanId, OperationDescription}; +use crate::stream::store::ExecutionPlanId; +use burn_tensor::repr::OperationDescription; use serde::{Deserialize, Serialize}; use std::{ collections::{hash_map::DefaultHasher, HashMap}, @@ -115,14 +116,13 @@ impl ExecutionPlanIndex { #[cfg(test)] mod tests { - use super::*; - use crate::{ - stream::{ - BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, - }, + use burn_tensor::repr::{ + BinaryOperationDescription, NumericOperationDescription, ScalarOperationDescription, TensorDescription, TensorId, TensorStatus, }; + use super::*; + #[test] fn should_find_optimization_id_based_on_tensor_ops() { let mut index = ExecutionPlanIndex::default(); diff --git a/crates/burn-fusion/src/tensor.rs b/crates/burn-fusion/src/tensor.rs index 5b430158df..6fad70e723 100644 --- a/crates/burn-fusion/src/tensor.rs +++ b/crates/burn-fusion/src/tensor.rs @@ -2,9 +2,9 @@ use crate::{client::FusionClient, stream::StreamId}; use burn_tensor::{ backend::Backend, ops::{FloatElem, IntElem}, + repr::{TensorDescription, TensorId, TensorStatus}, Data, Reader, Shape, }; -use serde::{Deserialize, Serialize}; use std::sync::Arc; /// Tensor primitive for the [fusion backend](crate::FusionBackend) for all kind. @@ -33,7 +33,7 @@ impl core::fmt::Debug for FusionTensor { self.shape, self.is_orphan, ::name(), - self.client.device().clone().into(), + self.client.device().clone(), ) .as_str(), ) @@ -127,47 +127,3 @@ impl Drop for FusionTensor { } } } - -/// The tensor unique identifier. -#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)] -pub struct TensorId { - value: u64, -} - -/// The status of the current tensor. -#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -pub enum TensorStatus { - /// The tensor can be read, but not written. - ReadOnly, - /// The tensor can be mutated inplace. - ReadWrite, - /// No handle exists for that tensor. - NotInit, -} - -/// A tensor definition represents a snapshot of a tensor when it was used. -/// -/// # Example -/// -/// A tensor that is used multiple times has its status updated for each operation. -/// -/// 1. Status::NotInit -/// 2. Status::ReadOnly -/// 3. Status::ReadOnly -/// 4. Status::ReadWrite -#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct TensorDescription { - /// The [tensor id](TensorId). - pub id: TensorId, - /// The shape of the tensor. - pub shape: Vec, - /// The [status](TensorStatus) of the tensor when it was used. - pub status: TensorStatus, -} - -impl TensorId { - /// Create a new tensor id. - pub fn new(value: u64) -> Self { - Self { value } - } -} diff --git a/crates/burn-jit/src/codegen/compilation.rs b/crates/burn-jit/src/codegen/compilation.rs index 50136a88df..27dc9126dc 100644 --- a/crates/burn-jit/src/codegen/compilation.rs +++ b/crates/burn-jit/src/codegen/compilation.rs @@ -1,7 +1,5 @@ #[cfg(feature = "fusion")] use crate::fusion::JitFusionHandle; -#[cfg(feature = "fusion")] -use burn_fusion::TensorDescription; use super::{ dialect::gpu::{self}, @@ -136,8 +134,8 @@ impl CompilationSettings { pub fn dynamic_settings( self, info: &CompilationInfo, - inputs: &[&TensorDescription], - outputs: &[&TensorDescription], + inputs: &[&burn_tensor::repr::TensorDescription], + outputs: &[&burn_tensor::repr::TensorDescription], handles_inputs: &[JitFusionHandle], stateful: bool, ) -> Self { @@ -154,8 +152,8 @@ impl CompilationSettings { fn dynamic_inplace( self, info: &CompilationInfo, - inputs: &[&TensorDescription], - outputs: &[&TensorDescription], + inputs: &[&burn_tensor::repr::TensorDescription], + outputs: &[&burn_tensor::repr::TensorDescription], handles_inputs: &[JitFusionHandle], ) -> Self { let mut potential_inplace = inputs @@ -170,9 +168,9 @@ impl CompilationSettings { } match desc.status { - burn_fusion::TensorStatus::ReadOnly => return None, - burn_fusion::TensorStatus::NotInit => return None, - burn_fusion::TensorStatus::ReadWrite => (), + burn_tensor::repr::TensorStatus::ReadOnly => return None, + burn_tensor::repr::TensorStatus::NotInit => return None, + burn_tensor::repr::TensorStatus::ReadWrite => (), }; Some((pos, desc, input)) @@ -215,8 +213,8 @@ impl CompilationSettings { fn dynamic_reading_strategy( mut self, info: &CompilationInfo, - inputs: &[&TensorDescription], - outputs: &[&TensorDescription], + inputs: &[&burn_tensor::repr::TensorDescription], + outputs: &[&burn_tensor::repr::TensorDescription], handles_inputs: &[JitFusionHandle], ) -> Self { // First output is chosen for the layout reference. diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 0a0bbe2f72..1cc9cb107f 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -4,7 +4,7 @@ use crate::{ }; use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend}; -use burn_tensor::Shape; +use burn_tensor::{repr::ReprBackend, Shape}; use core::marker::PhantomData; use serde::{Deserialize, Serialize}; @@ -53,53 +53,61 @@ impl burn_fusion::Optimization> for JitOptimization } } -impl FusionBackend for JitBackend { - type OptimizationState = JitOptimizationState; - type Optimization = JitOptimization; - type FusionDevice = R::Device; +impl ReprBackend for JitBackend { type Handle = JitFusionHandle; - type FusionClient = MutexFusionClient; - - fn optimizations( - device: R::Device, - ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::new(device))] - } fn float_tensor( handle: Self::Handle, shape: Shape, - ) -> Self::FloatTensorPrimitive { + ) -> burn_tensor::ops::FloatTensor { handle.into_tensor(shape) } fn int_tensor( handle: Self::Handle, shape: Shape, - ) -> Self::IntTensorPrimitive { + ) -> burn_tensor::ops::IntTensor { handle.into_tensor(shape) } fn bool_tensor( handle: Self::Handle, shape: Shape, - ) -> Self::BoolTensorPrimitive { + ) -> burn_tensor::ops::BoolTensor { handle.into_tensor(shape) } - fn float_tensor_handle(tensor: Self::FloatTensorPrimitive) -> Self::Handle { + fn float_tensor_handle( + tensor: burn_tensor::ops::FloatTensor, + ) -> Self::Handle { tensor.into() } - fn int_tensor_handle(tensor: Self::IntTensorPrimitive) -> Self::Handle { + fn int_tensor_handle( + tensor: burn_tensor::ops::IntTensor, + ) -> Self::Handle { tensor.into() } - fn bool_tensor_handle(tensor: Self::BoolTensorPrimitive) -> Self::Handle { + fn bool_tensor_handle( + tensor: burn_tensor::ops::BoolTensor, + ) -> Self::Handle { tensor.into() } } +impl FusionBackend for JitBackend { + type OptimizationState = JitOptimizationState; + type Optimization = JitOptimization; + type FusionClient = MutexFusionClient; + + fn optimizations( + device: R::Device, + ) -> Vec>> { + vec![Box::new(ElementWiseBuilder::new(device))] + } +} + pub fn strides_dyn_rank(shape: &[usize]) -> Vec { let mut strides = vec![0; shape.len()]; diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index ccd2d93fa0..7958fe7e45 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -7,16 +7,14 @@ use crate::{ fusion::{tracing::TraceBuilder, JitOptimization}, JitBackend, Runtime, }; -use burn_fusion::{ - stream::{ +use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; +use burn_tensor::{ + ops::{FloatElem, IntElem}, + repr::{ BaseOperationDescription, BinaryOperationDescription, FloatOperationDescription, NumericOperationDescription, OperationDescription, ScalarOperationDescription, - UnaryOperationDescription, + TensorDescription, UnaryOperationDescription, }, - OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, -}; -use burn_tensor::{ - ops::{FloatElem, IntElem}, Device, Element, }; diff --git a/crates/burn-jit/src/fusion/elemwise/kernel.rs b/crates/burn-jit/src/fusion/elemwise/kernel.rs index 88566f5c3e..3fb9e2f8dc 100644 --- a/crates/burn-jit/src/fusion/elemwise/kernel.rs +++ b/crates/burn-jit/src/fusion/elemwise/kernel.rs @@ -1,3 +1,5 @@ +use burn_tensor::repr::TensorDescription; + use crate::{ codegen::{ calculate_num_elems_dyn_rank, @@ -11,7 +13,6 @@ use crate::{ kernel::elemwise_workgroup, Runtime, }; -use burn_fusion::TensorDescription; use std::{marker::PhantomData, sync::Arc}; #[derive(new)] diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index aec699b497..d0fe049049 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -15,7 +15,8 @@ use burn_compute::client::ComputeClient; use burn_compute::server::Handle; use burn_compute::tune::AutotuneOperation; use burn_fusion::stream::Context; -use burn_fusion::{TensorDescription, TensorStatus}; +use burn_tensor::repr::TensorDescription; +use burn_tensor::repr::TensorStatus; use burn_tensor::Device; use std::marker::PhantomData; use std::sync::Arc; diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 8642127eea..0e62ae3806 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -1,7 +1,9 @@ use super::{trace::Trace, Scalars}; use crate::codegen::dialect::gpu::{self, Operation, Variable}; -use burn_fusion::{TensorDescription, TensorId}; -use burn_tensor::Element; +use burn_tensor::{ + repr::{TensorDescription, TensorId, TensorStatus}, + Element, +}; use hashbrown::HashMap; /// Type facilitating building a [trace](Trace) by doing most of the conversions between the @@ -415,7 +417,7 @@ impl TraceBuilder { // are going to be used after the fused kernel by other operations. for entry in self.tensors.values() { let (tensor, _) = &entry; - if let burn_fusion::TensorStatus::ReadOnly = tensor.status { + if let TensorStatus::ReadOnly = tensor.status { if self.output_to_local.contains_key(&tensor.id) { outputs.push(entry.clone()); } diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs index 78b3113281..4bda66d858 100644 --- a/crates/burn-jit/src/fusion/tracing/trace.rs +++ b/crates/burn-jit/src/fusion/tracing/trace.rs @@ -1,6 +1,6 @@ use super::Scalars; use crate::codegen::{dialect::gpu, CompilationInfo, InputInfo, OutputInfo}; -use burn_fusion::TensorDescription; +use burn_tensor::repr::TensorDescription; use serde::{Deserialize, Serialize}; /// A trace encapsulates all information necessary to perform the compilation and execution of diff --git a/crates/burn-jit/src/runtime.rs b/crates/burn-jit/src/runtime.rs index 4d6380631f..6954df913c 100644 --- a/crates/burn-jit/src/runtime.rs +++ b/crates/burn-jit/src/runtime.rs @@ -16,8 +16,7 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { /// The channel used to communicate with the compute server. type Channel: ComputeChannel; /// The device used to retrieve the compute client. - #[cfg(any(feature = "fusion", test))] - type Device: burn_fusion::FusionDevice + type Device: burn_tensor::backend::DeviceOps + Default + core::hash::Hash + PartialEq @@ -26,16 +25,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { + core::fmt::Debug + Sync + Send; - /// The device used to retrieve the compute client. - #[cfg(not(any(feature = "fusion", test)))] - type Device: Default - + core::hash::Hash - + PartialEq - + Eq - + Clone - + core::fmt::Debug - + Sync - + Send; /// A version of the runtime that supports full precision. /// diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index d9c5e493c3..39c4686150 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -2,7 +2,7 @@ use crate::NdArrayTensor; use crate::{element::FloatNdArrayElement, PrecisionBridge}; use alloc::string::String; use burn_common::stub::Mutex; -use burn_tensor::backend::Backend; +use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; use core::marker::PhantomData; use rand::{rngs::StdRng, SeedableRng}; @@ -15,6 +15,14 @@ pub enum NdArrayDevice { Cpu, } +impl DeviceOps for NdArrayDevice { + fn id(&self) -> burn_tensor::backend::DeviceId { + match self { + NdArrayDevice::Cpu => DeviceId::new(0, 0), + } + } +} + impl Default for NdArrayDevice { fn default() -> Self { Self::Cpu diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index 81e31020bb..bf436702f9 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -2,7 +2,7 @@ use crate::PrecisionBridge; use super::element::TchElement; use super::TchTensor; -use burn_tensor::backend::Backend; +use burn_tensor::backend::{Backend, DeviceId, DeviceOps}; use burn_tensor::ops::IntTensorOps; use burn_tensor::{Int, Tensor}; @@ -59,6 +59,17 @@ impl From for LibTorchDevice { } } +impl DeviceOps for LibTorchDevice { + fn id(&self) -> burn_tensor::backend::DeviceId { + match self { + LibTorchDevice::Cpu => DeviceId::new(0, 0), + LibTorchDevice::Cuda(index) => DeviceId::new(1, *index as u32), + LibTorchDevice::Mps => DeviceId::new(2, 0), + LibTorchDevice::Vulkan => DeviceId::new(3, 0), + } + } +} + impl Default for LibTorchDevice { fn default() -> Self { Self::Cpu diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index c3e819d777..00022902a5 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -11,11 +11,12 @@ repository = "https://github.com/tracel-ai/burn/tree/main/burn-tensor" version.workspace = true [features] -default = ["std"] +default = ["std", "repr"] doc = ["default"] experimental-named-tensor = [] export_tests = ["burn-tensor-testgen"] std = ["rand/std", "half/std", "num-traits/std"] +repr = [] wasm-sync = [] [dependencies] diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index d16a02e220..18156030b8 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -11,6 +11,10 @@ extern crate alloc; mod tensor; +/// Burn Tensor representaton +#[cfg(feature = "repr")] +pub mod repr; + #[cfg(feature = "export_tests")] #[allow(missing_docs)] mod tests; diff --git a/crates/burn-tensor/src/repr/backend.rs b/crates/burn-tensor/src/repr/backend.rs new file mode 100644 index 0000000000..28e86eda30 --- /dev/null +++ b/crates/burn-tensor/src/repr/backend.rs @@ -0,0 +1,26 @@ +use crate::{ + backend::Backend, + ops::{BoolTensor, FloatTensor, IntTensor}, + Shape, +}; + +/// Backend extension trait that allows an existing [backend](Backend) to use the Burn tensor representation +/// for compilation purpose or other... +pub trait ReprBackend: Backend { + /// The type that can be used to point to a tensor of any kind. + type Handle: Sync + Send + Clone; + + /// Convert a [handle](ReprBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive). + fn float_tensor(handle: Self::Handle, shape: Shape) -> FloatTensor; + /// Convert a [handle](ReprBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive). + fn int_tensor(handle: Self::Handle, shape: Shape) -> IntTensor; + /// Convert a [handle](ReprBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive). + fn bool_tensor(handle: Self::Handle, shape: Shape) -> BoolTensor; + + /// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](ReprBackend::Handle). + fn float_tensor_handle(tensor: FloatTensor) -> Self::Handle; + /// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](ReprBackend::Handle). + fn int_tensor_handle(tensor: IntTensor) -> Self::Handle; + /// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](ReprBackend::Handle). + fn bool_tensor_handle(tensor: BoolTensor) -> Self::Handle; +} diff --git a/crates/burn-fusion/src/handle.rs b/crates/burn-tensor/src/repr/handle.rs similarity index 69% rename from crates/burn-fusion/src/handle.rs rename to crates/burn-tensor/src/repr/handle.rs index a66a80c23b..89cf6b8fc2 100644 --- a/crates/burn-fusion/src/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -1,30 +1,41 @@ -use crate::{FusionBackend, TensorDescription, TensorId, TensorStatus}; -use burn_tensor::Shape; +use crate::{ + backend::Backend, + repr::{ + backend::ReprBackend, + tensor::{TensorDescription, TensorId, TensorStatus}, + }, + Shape, +}; use std::{collections::HashMap, sync::Arc}; -/// Keep all [tensor handles](FusionBackend::Handle) in one place and ensure that all resources +/// Keep all [tensor handles](ReprBackend::Handle) in one place and ensure that all resources /// are used optimally. #[derive(Default)] -pub struct HandleContainer { +pub struct HandleContainer { handles: HashMap>, counter: u64, - pub(crate) handles_orphan: Vec, + /// Handle candidates to be freed. + pub handles_orphan: Vec, /// The device on which all tensors are held. pub device: B::Device, } -enum Handle { +/// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state +pub enum Handle { + /// No [tensor handle](ReprBackend::Handle) has been created yet NotInit, + /// A [tensor handle](ReprBackend::Handle) has been created Existing(B::Handle), } -impl HandleContainer { - pub(crate) fn new(device_handle: B::FusionDevice) -> Self { +impl HandleContainer { + /// Create a new HandleContainer + pub fn new(device_handle: B::Device) -> Self { Self { handles: HashMap::new(), handles_orphan: Vec::new(), counter: 0, - device: device_handle.clone().into(), + device: device_handle.clone(), } } @@ -59,69 +70,69 @@ impl HandleContainer { } } - /// Get the [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) corresponding to the + /// Get the [float tensor](Backend::FloatTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_float_tensor( &mut self, tensor: &TensorDescription, ) -> B::FloatTensorPrimitive { - B::float_tensor( + B::float_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), ) } - /// Get the [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) corresponding to the + /// Get the [int tensor](Backend::IntTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_int_tensor( &mut self, tensor: &TensorDescription, ) -> B::IntTensorPrimitive { - B::int_tensor( + B::int_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), ) } - /// Get the [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) corresponding to the + /// Get the [bool tensor](Backend::BoolTensorPrimitive) corresponding to the /// given [tensor description](TensorDescription). pub fn get_bool_tensor( &mut self, tensor: &TensorDescription, ) -> B::BoolTensorPrimitive { - B::bool_tensor( + B::bool_tensor::( self.get_handle(&tensor.id, &tensor.status), Shape::from(&tensor.shape), ) } - /// Register a new [float tensor](burn_tensor::backend::Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [float tensor](Backend::FloatTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_float_tensor( &mut self, id: &TensorId, tensor: B::FloatTensorPrimitive, ) { - let handle = B::float_tensor_handle(tensor); + let handle = B::float_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [int tensor](burn_tensor::backend::Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [int tensor](Backend::IntTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_int_tensor( &mut self, id: &TensorId, tensor: B::IntTensorPrimitive, ) { - let handle = B::int_tensor_handle(tensor); + let handle = B::int_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } - /// Register a new [bool tensor](burn_tensor::backend::Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). + /// Register a new [bool tensor](Backend::BoolTensorPrimitive) with the corresponding [tensor id](TensorId). pub fn register_bool_tensor( &mut self, id: &TensorId, tensor: B::BoolTensorPrimitive, ) { - let handle = B::bool_tensor_handle(tensor); + let handle = B::bool_tensor_handle::(tensor); self.handles.insert(*id, Handle::Existing(handle)); } @@ -134,7 +145,8 @@ impl HandleContainer { Arc::new(id) } - pub(crate) fn free(&mut self, tensor: &TensorDescription) { + /// Remove tensor handle from container if writable + pub fn free(&mut self, tensor: &TensorDescription) { match tensor.status { TensorStatus::ReadOnly => (), TensorStatus::NotInit => (), @@ -144,7 +156,8 @@ impl HandleContainer { } } - pub(crate) fn free_orphans(&mut self, remaining: &[&TensorId]) { + /// Remove tensor handle from container if not in use + pub fn free_orphans(&mut self, remaining: &[&TensorId]) { let mut handles_orphan = Vec::new(); // TODO: Optimization => Change the for loop order depending of the length of each. diff --git a/crates/burn-tensor/src/repr/mod.rs b/crates/burn-tensor/src/repr/mod.rs new file mode 100644 index 0000000000..b98e43ba3d --- /dev/null +++ b/crates/burn-tensor/src/repr/mod.rs @@ -0,0 +1,9 @@ +mod backend; +mod handle; +mod operation; +mod tensor; + +pub use backend::*; +pub use handle::*; +pub use operation::*; +pub use tensor::*; diff --git a/crates/burn-fusion/src/stream/operation.rs b/crates/burn-tensor/src/repr/operation.rs similarity index 77% rename from crates/burn-fusion/src/stream/operation.rs rename to crates/burn-tensor/src/repr/operation.rs index bba5f30c8e..220d8a90cf 100644 --- a/crates/burn-fusion/src/stream/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -1,15 +1,11 @@ -use crate::FusionBackend; -use crate::{HandleContainer, TensorDescription}; -use burn_tensor::ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions}; -use burn_tensor::{Distribution, Element}; use serde::{Deserialize, Serialize}; use std::ops::Range; -/// General trait to abstract how a single operation is executed. -pub trait Operation: Send + Sync { - /// Execute the operation. - fn execute(self: Box, handles: &mut HandleContainer); -} +use crate::{ + ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions}, + repr::tensor::TensorDescription, + Distribution, Element, +}; /// Describe all tensor operations possible. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] @@ -37,92 +33,92 @@ pub enum OperationDescription { /// Operation description specific to a float tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum FloatOperationDescription { - /// Operation corresponding to [exp](burn_tensor::ops::FloatTensorOps::float_exp). + /// Operation corresponding to [exp](crate::ops::FloatTensorOps::float_exp). Exp(UnaryOperationDescription), - /// Operation corresponding to [log](burn_tensor::ops::FloatTensorOps::float_log). + /// Operation corresponding to [log](crate::ops::FloatTensorOps::float_log). Log(UnaryOperationDescription), - /// Operation corresponding to [log1p](burn_tensor::ops::FloatTensorOps::float_log1p). + /// Operation corresponding to [log1p](crate::ops::FloatTensorOps::float_log1p). Log1p(UnaryOperationDescription), - /// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf). + /// Operation corresponding to [erf](crate::ops::FloatTensorOps::float_erf). Erf(UnaryOperationDescription), - /// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar). + /// Operation corresponding to [powf_scalar](crate::ops::FloatTensorOps::float_powf_scalar). PowfScalar(ScalarOperationDescription), - /// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt). + /// Operation corresponding to [sqrt](crate::ops::FloatTensorOps::float_sqrt). Sqrt(UnaryOperationDescription), - /// Operation corresponding to [cos](burn_tensor::ops::FloatTensorOps::float_cos). + /// Operation corresponding to [cos](crate::ops::FloatTensorOps::float_cos). Cos(UnaryOperationDescription), - /// Operation corresponding to [sin](burn_tensor::ops::FloatTensorOps::float_sin). + /// Operation corresponding to [sin](crate::ops::FloatTensorOps::float_sin). Sin(UnaryOperationDescription), - /// Operation corresponding to [tanh](burn_tensor::ops::FloatTensorOps::float_tanh). + /// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh). Tanh(UnaryOperationDescription), - /// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int). + /// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int). IntoInt(UnaryOperationDescription), - /// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul). + /// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul). Matmul(BinaryOperationDescription), - /// Operation corresponding to [random](burn_tensor::ops::FloatTensorOps::float_random). + /// Operation corresponding to [random](crate::ops::FloatTensorOps::float_random). Random(RandomOperationDescription), - /// Operation corresponding to [recip](burn_tensor::ops::FloatTensorOps::float_recip). + /// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip). Recip(UnaryOperationDescription), } /// Operation description specific to module. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum ModuleOperationDescription { - /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding). + /// Operation corresponding to [embedding](crate::ops::ModuleOps::embedding). Embedding(EmbeddingDescription), - /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward). + /// Operation corresponding to [embedding_backward](crate::ops::ModuleOps::embedding_backward). EmbeddingBackward(EmbeddingBackwardDescription), - /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d). + /// Operation corresponding to [conv1d](crate::ops::ModuleOps::conv1d). Conv1d(Conv1dDescription), - /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d). + /// Operation corresponding to [conv2d](crate::ops::ModuleOps::conv2d). Conv2d(Conv2dDescription), - /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d). + /// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d). ConvTranspose1d(ConvTranspose1dDescription), - /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d). + /// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d). ConvTranspose2d(ConvTranspose2dDescription), - /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d). + /// Operation corresponding to [avg pool 1d](crate::ops::ModuleOps::avg_pool1d). AvgPool1d(AvgPool1dDescription), - /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d). + /// Operation corresponding to [avg pool 2d](crate::ops::ModuleOps::avg_pool2d). AvgPool2d(AvgPool2dDescription), /// Operation corresponding to - /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward). + /// [avg pool 1d backward](crate::ops::ModuleOps::avg_pool1d_backward). AvgPool1dBackward(AvgPool1dBackwardDescription), /// Operation corresponding to - /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward). + /// [avg pool 2d backward](crate::ops::ModuleOps::avg_pool2d_backward). AvgPool2dBackward(AvgPool2dBackwardDescription), /// Operation corresponding to - /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d). + /// [adaptive avg pool 1d](crate::ops::ModuleOps::adaptive_avg_pool1d). AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription), /// Operation corresponding to - /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d). + /// [adaptive avg pool 2d](crate::ops::ModuleOps::adaptive_avg_pool2d). AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription), /// Operation corresponding to - /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward). + /// [adaptive avg pool 1d backward](crate::ops::ModuleOps::adaptive_avg_pool1d_backward). AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription), /// Operation corresponding to - /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward). + /// [adaptive avg pool 2d backward](crate::ops::ModuleOps::adaptive_avg_pool2d_backward). AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription), /// Operation corresponding to - /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d). + /// [max pool 1d](crate::ops::ModuleOps::max_pool1d). MaxPool1d(MaxPool1dDescription), /// Operation corresponding to - /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices). + /// [max pool 1d with indices](crate::ops::ModuleOps::max_pool1d_with_indices). MaxPool1dWithIndices(MaxPool1dWithIndicesDescription), /// Operation corresponding to - /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward). + /// [max pool 1d with indices backward](crate::ops::ModuleOps::max_pool1d_with_indices_backward). MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription), /// Operation corresponding to - /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool1d). + /// [max pool 2d](crate::ops::ModuleOps::max_pool1d). MaxPool2d(MaxPool2dDescription), /// Operation corresponding to - /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices). + /// [max pool 2d with indices](crate::ops::ModuleOps::max_pool2d_with_indices). MaxPool2dWithIndices(MaxPool2dWithIndicesDescription), /// Operation corresponding to - /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward). + /// [max pool 2d with indices backward](crate::ops::ModuleOps::max_pool2d_with_indices_backward). MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription), - /// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate). + /// Operation corresponding to [interpolate](crate::ops::ModuleOps::interpolate). Interpolate(InterpolateDescription), - /// Operation corresponding to [interpolate backward](burn_tensor::ops::ModuleOps::interpolate_backward). + /// Operation corresponding to [interpolate backward](crate::ops::ModuleOps::interpolate_backward). InterpolateBackward(InterpolateBackwardDescription), } @@ -131,73 +127,73 @@ pub enum ModuleOperationDescription { pub enum BaseOperationDescription { /// Operation corresponding to: /// - /// Float => [to device](burn_tensor::ops::FloatTensorOps::float_to_device). - /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device). - /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device). + /// Float => [to device](crate::ops::FloatTensorOps::float_to_device). + /// Int => [to device](crate::ops::IntTensorOps::int_to_device). + /// Bool => [to device](crate::ops::BoolTensorOps::bool_to_device). ToDevice(TensorDescription), /// Operation corresponding to: /// - /// Float => [reshape](burn_tensor::ops::FloatTensorOps::float_reshape). - /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape). - /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape). + /// Float => [reshape](crate::ops::FloatTensorOps::float_reshape). + /// Int => [reshape](crate::ops::IntTensorOps::int_reshape). + /// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape). Reshape(ReshapeDescription), /// Operation corresponding to: /// - /// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims). - /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims). - /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims). + /// Float => [swap_dims](crate::ops::FloatTensorOps::float_swap_dims). + /// Int => [swap_dims](crate::ops::IntTensorOps::int_swap_dims). + /// Bool => [swap_dims](crate::ops::BoolTensorOps::bool_swap_dims). SwapDims(SwapDimsDescription), /// Operation corresponding to: /// - /// Float => [permute](burn_tensor::ops::FloatTensorOps::float_permute). - /// Int => [permute](burn_tensor::ops::IntTensorOps::int_permute). - /// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute). + /// Float => [permute](crate::ops::FloatTensorOps::float_permute). + /// Int => [permute](crate::ops::IntTensorOps::int_permute). + /// Bool => [permute](crate::ops::BoolTensorOps::bool_permute). Permute(PermuteOperationDescription), /// Operation corresponding to: - /// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip). - /// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip). - /// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip). + /// Float => [flip](crate::ops::FloatTensorOps::float_flip). + /// Int => [flip](crate::ops::IntTensorOps::int_flip). + /// Bool => [flip](crate::ops::BoolTensorOps::bool_flip). Flip(FlipOperationDescription), /// Operation corresponding to: /// - /// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand). - /// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand). - /// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand). + /// Float => [expand](crate::ops::FloatTensorOps::float_expand). + /// Int => [expand](crate::ops::IntTensorOps::int_expand). + /// Bool => [expand](crate::ops::BoolTensorOps::bool_expand). Expand(ExpandOperationDescription), /// Operation corresponding to: /// - /// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice). - /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice). - /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice). + /// Float => [slice](crate::ops::FloatTensorOps::float_slice). + /// Int => [slice](crate::ops::IntTensorOps::int_slice). + /// Bool => [slice](crate::ops::BoolTensorOps::bool_slice). Slice(SliceOperationDescription), /// Operation corresponding to: /// - /// Float => [slice assign](burn_tensor::ops::FloatTensorOps::float_slice_assign). - /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign). - /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign). + /// Float => [slice assign](crate::ops::FloatTensorOps::float_slice_assign). + /// Int => [slice assign](crate::ops::IntTensorOps::int_slice_assign). + /// Bool => [slice assign](crate::ops::BoolTensorOps::bool_slice_assign). SliceAssign(SliceAssignOperationDescription), /// Operation corresponding to: /// - /// Float => [equal](burn_tensor::ops::FloatTensorOps::float_equal). - /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal). - /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal). + /// Float => [equal](crate::ops::FloatTensorOps::float_equal). + /// Int => [equal](crate::ops::IntTensorOps::int_equal). + /// Bool => [equal](crate::ops::BoolTensorOps::bool_equal). Equal(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [repeat](burn_tensor::ops::FloatTensorOps::float_repeat). - /// Int => [repeat](burn_tensor::ops::IntTensorOps::int_repeat). - /// Bool => [repeat](burn_tensor::ops::BoolTensorOps::bool_repeat). + /// Float => [repeat](crate::ops::FloatTensorOps::float_repeat). + /// Int => [repeat](crate::ops::IntTensorOps::int_repeat). + /// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat). Repeat(RepeatOperationDescription), /// Operation corresponding to: /// - /// Float => [cat](burn_tensor::ops::FloatTensorOps::float_cat). - /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat). - /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat). + /// Float => [cat](crate::ops::FloatTensorOps::float_cat). + /// Int => [cat](crate::ops::IntTensorOps::int_cat). + /// Bool => [cat](crate::ops::BoolTensorOps::bool_cat). Cat(CatOperationDescription), } @@ -206,248 +202,248 @@ pub enum BaseOperationDescription { pub enum NumericOperationDescription { /// Operation corresponding to: /// - /// Float => [add](burn_tensor::ops::FloatTensorOps::float_add). - /// Int => [add](burn_tensor::ops::IntTensorOps::int_add). + /// Float => [add](crate::ops::FloatTensorOps::float_add). + /// Int => [add](crate::ops::IntTensorOps::int_add). Add(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [add scalar](burn_tensor::ops::FloatTensorOps::float_add_scalar). - /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar). + /// Float => [add scalar](crate::ops::FloatTensorOps::float_add_scalar). + /// Int => [add scalar](crate::ops::IntTensorOps::int_add_scalar). AddScalar(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [sub](burn_tensor::ops::FloatTensorOps::float_sub). - /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub). + /// Float => [sub](crate::ops::FloatTensorOps::float_sub). + /// Int => [sub](crate::ops::IntTensorOps::int_sub). Sub(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [sub scalar](burn_tensor::ops::FloatTensorOps::float_sub_scalar). - /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar). + /// Float => [sub scalar](crate::ops::FloatTensorOps::float_sub_scalar). + /// Int => [sub scalar](crate::ops::IntTensorOps::int_sub_scalar). SubScalar(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [div](burn_tensor::ops::FloatTensorOps::float_div). - /// Int => [div](burn_tensor::ops::IntTensorOps::int_div). + /// Float => [div](crate::ops::FloatTensorOps::float_div). + /// Int => [div](crate::ops::IntTensorOps::int_div). Div(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [div scalar](burn_tensor::ops::FloatTensorOps::float_div_scalar). - /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar). + /// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar). + /// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar). DivScalar(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [div](burn_tensor::ops::FloatTensorOps::float_remainder_scalar). - /// Int => [div](burn_tensor::ops::IntTensorOps::int_remainder_scalar). + /// Float => [div](crate::ops::FloatTensorOps::float_remainder_scalar). + /// Int => [div](crate::ops::IntTensorOps::int_remainder_scalar). RemScalar(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [mul](burn_tensor::ops::FloatTensorOps::float_mul). - /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul). + /// Float => [mul](crate::ops::FloatTensorOps::float_mul). + /// Int => [mul](crate::ops::IntTensorOps::int_mul). Mul(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [mul scalar](burn_tensor::ops::FloatTensorOps::float_mul_scalar). - /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar). + /// Float => [mul scalar](crate::ops::FloatTensorOps::float_mul_scalar). + /// Int => [mul scalar](crate::ops::IntTensorOps::int_mul_scalar). MulScalar(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [abs](burn_tensor::ops::FloatTensorOps::float_abs). - /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs). + /// Float => [abs](crate::ops::FloatTensorOps::float_abs). + /// Int => [abs](crate::ops::IntTensorOps::int_abs). Abs(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [ones](burn_tensor::ops::FloatTensorOps::float_ones). - /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones). + /// Float => [ones](crate::ops::FloatTensorOps::float_ones). + /// Int => [ones](crate::ops::IntTensorOps::int_ones). Ones(TensorDescription), /// Operation corresponding to: /// - /// Float => [zeros](burn_tensor::ops::FloatTensorOps::float_zeros). - /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros). + /// Float => [zeros](crate::ops::FloatTensorOps::float_zeros). + /// Int => [zeros](crate::ops::IntTensorOps::int_zeros). Zeros(TensorDescription), /// Operation corresponding to: /// - /// Float => [full](burn_tensor::ops::FloatTensorOps::float_full). - /// Int => [full](burn_tensor::ops::IntTensorOps::int_full). + /// Float => [full](crate::ops::FloatTensorOps::float_full). + /// Int => [full](crate::ops::IntTensorOps::int_full). Full((TensorDescription, E)), /// Operation corresponding to: /// - /// Float => [gather](burn_tensor::ops::FloatTensorOps::float_gather). - /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather). + /// Float => [gather](crate::ops::FloatTensorOps::float_gather). + /// Int => [gather](crate::ops::IntTensorOps::int_gather). Gather(GatherOperationDescription), /// Operation corresponding to: /// - /// Float => [scatter](burn_tensor::ops::FloatTensorOps::float_scatter). - /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter). + /// Float => [scatter](crate::ops::FloatTensorOps::float_scatter). + /// Int => [scatter](crate::ops::IntTensorOps::int_scatter). Scatter(ScatterOperationDescription), /// Operation corresponding to: /// - /// Float => [select](burn_tensor::ops::FloatTensorOps::float_select). - /// Int => [select](burn_tensor::ops::IntTensorOps::int_select). + /// Float => [select](crate::ops::FloatTensorOps::float_select). + /// Int => [select](crate::ops::IntTensorOps::int_select). Select(SelectOperationDescription), /// Operation corresponding to: /// - /// Float => [select assign](burn_tensor::ops::FloatTensorOps::float_select_assign). - /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign). + /// Float => [select assign](crate::ops::FloatTensorOps::float_select_assign). + /// Int => [select assign](crate::ops::IntTensorOps::int_select_assign). SelectAssign(SelectAssignOperationDescription), /// Operation corresponding to: /// - /// Float => [mask where](burn_tensor::ops::FloatTensorOps::float_mask_where). - /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where). + /// Float => [mask where](crate::ops::FloatTensorOps::float_mask_where). + /// Int => [mask where](crate::ops::IntTensorOps::int_mask_where). MaskWhere(MaskWhereOperationDescription), /// Operation corresponding to: /// - /// Float => [mask fill](burn_tensor::ops::FloatTensorOps::float_mask_fill). - /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill). + /// Float => [mask fill](crate::ops::FloatTensorOps::float_mask_fill). + /// Int => [mask fill](crate::ops::IntTensorOps::int_mask_fill). MaskFill(MaskFillOperationDescription), /// Operation corresponding to: /// - /// Float => [mean dim](burn_tensor::ops::FloatTensorOps::float_mean_dim). - /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim). + /// Float => [mean dim](crate::ops::FloatTensorOps::float_mean_dim). + /// Int => [mean dim](crate::ops::IntTensorOps::int_mean_dim). MeanDim(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [mean](burn_tensor::ops::FloatTensorOps::float_mean). - /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean). + /// Float => [mean](crate::ops::FloatTensorOps::float_mean). + /// Int => [mean](crate::ops::IntTensorOps::int_mean). Mean(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [sum](burn_tensor::ops::FloatTensorOps::float_sum). - /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum). + /// Float => [sum](crate::ops::FloatTensorOps::float_sum). + /// Int => [sum](crate::ops::IntTensorOps::int_sum). Sum(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim). - /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim). + /// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim). + /// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim). SumDim(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod). - /// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod). + /// Float => [prod](crate::ops::FloatTensorOps::float_prod). + /// Int => [prod](crate::ops::IntTensorOps::int_prod). Prod(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim). - /// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim). + /// Float => [prod dim](crate::ops::FloatTensorOps::float_prod_dim). + /// Int => [prod dim](crate::ops::IntTensorOps::int_prod_dim). ProdDim(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem). - /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem). + /// Float => [equal elem](crate::ops::FloatTensorOps::float_equal_elem). + /// Int => [equal elem](crate::ops::IntTensorOps::int_equal_elem). EqualElem(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [greater](burn_tensor::ops::FloatTensorOps::float_greater). - /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater). + /// Float => [greater](crate::ops::FloatTensorOps::float_greater). + /// Int => [greater](crate::ops::IntTensorOps::int_greater). Greater(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [greater elem](burn_tensor::ops::FloatTensorOps::float_greater_elem). - /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). + /// Float => [greater elem](crate::ops::FloatTensorOps::float_greater_elem). + /// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem). GreaterElem(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [greater equal](burn_tensor::ops::FloatTensorOps::float_greater_elem). - /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem). + /// Float => [greater equal](crate::ops::FloatTensorOps::float_greater_elem). + /// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem). GreaterEqual(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [greater equal elem](burn_tensor::ops::FloatTensorOps::float_greater_equal_elem). - /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem). + /// Float => [greater equal elem](crate::ops::FloatTensorOps::float_greater_equal_elem). + /// Int => [greater equal elem](crate::ops::IntTensorOps::int_greater_equal_elem). GreaterEqualElem(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [lower](burn_tensor::ops::FloatTensorOps::float_lower). - /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower). + /// Float => [lower](crate::ops::FloatTensorOps::float_lower). + /// Int => [lower](crate::ops::IntTensorOps::int_lower). Lower(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [lower elem](burn_tensor::ops::FloatTensorOps::float_lower_elem). - /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem). + /// Float => [lower elem](crate::ops::FloatTensorOps::float_lower_elem). + /// Int => [lower elem](crate::ops::IntTensorOps::int_lower_elem). LowerElem(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [lower equal](burn_tensor::ops::FloatTensorOps::float_lower_equal). - /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal). + /// Float => [lower equal](crate::ops::FloatTensorOps::float_lower_equal). + /// Int => [lower equal](crate::ops::IntTensorOps::int_lower_equal). LowerEqual(BinaryOperationDescription), /// Operation corresponding to: /// - /// Float => [lower equal elem](burn_tensor::ops::FloatTensorOps::float_lower_equal_elem). - /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem). + /// Float => [lower equal elem](crate::ops::FloatTensorOps::float_lower_equal_elem). + /// Int => [lower equal elem](crate::ops::IntTensorOps::int_lower_equal_elem). LowerEqualElem(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [argmax](burn_tensor::ops::FloatTensorOps::float_argmax). - /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax). + /// Float => [argmax](crate::ops::FloatTensorOps::float_argmax). + /// Int => [argmax](crate::ops::IntTensorOps::int_argmax). ArgMax(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [argmin](burn_tensor::ops::FloatTensorOps::float_argmin). - /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin). + /// Float => [argmin](crate::ops::FloatTensorOps::float_argmin). + /// Int => [argmin](crate::ops::IntTensorOps::int_argmin). ArgMin(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [max](burn_tensor::ops::FloatTensorOps::float_max). - /// Int => [max](burn_tensor::ops::IntTensorOps::int_max). + /// Float => [max](crate::ops::FloatTensorOps::float_max). + /// Int => [max](crate::ops::IntTensorOps::int_max). Max(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [max dim with indices](burn_tensor::ops::FloatTensorOps::float_max_dim_with_indices). - /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices). + /// Float => [max dim with indices](crate::ops::FloatTensorOps::float_max_dim_with_indices). + /// Int => [max dim with indices](crate::ops::IntTensorOps::int_max_dim_with_indices). MaxDimWithIndices(ReduceDimWithIndicesDescription), /// Operation corresponding to: /// - /// Float => [min dim with indices](burn_tensor::ops::FloatTensorOps::float_min_dim_with_indices). - /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices). + /// Float => [min dim with indices](crate::ops::FloatTensorOps::float_min_dim_with_indices). + /// Int => [min dim with indices](crate::ops::IntTensorOps::int_min_dim_with_indices). MinDimWithIndices(ReduceDimWithIndicesDescription), /// Operation corresponding to: /// - /// Float => [min](burn_tensor::ops::FloatTensorOps::float_min). - /// Int => [min](burn_tensor::ops::IntTensorOps::int_min). + /// Float => [min](crate::ops::FloatTensorOps::float_min). + /// Int => [min](crate::ops::IntTensorOps::int_min). Min(UnaryOperationDescription), /// Operation corresponding to: /// - /// Float => [max dim](burn_tensor::ops::FloatTensorOps::float_max_dim). - /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim). + /// Float => [max dim](crate::ops::FloatTensorOps::float_max_dim). + /// Int => [max dim](crate::ops::IntTensorOps::int_max_dim). MaxDim(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [min dim](burn_tensor::ops::FloatTensorOps::float_min_dim). - /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim). + /// Float => [min dim](crate::ops::FloatTensorOps::float_min_dim). + /// Int => [min dim](crate::ops::IntTensorOps::int_min_dim). MinDim(ScalarOperationDescription), /// Operation corresponding to: /// - /// Float => [clamp](burn_tensor::ops::FloatTensorOps::float_clamp). - /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). + /// Float => [clamp](crate::ops::FloatTensorOps::float_clamp). + /// Int => [clamp](crate::ops::IntTensorOps::int_clamp). Clamp(ClampOperationDescription), /// Operation corresponding to: /// - /// Int => [random](burn_tensor::ops::IntTensorOps::int_random). + /// Int => [random](crate::ops::IntTensorOps::int_random). IntRandom(RandomOperationDescription), /// Operation corresponding to: /// - /// Float => [powf](burn_tensor::ops::FloatTensorOps::float_powf). - /// Int => [powf](burn_tensor::ops::IntTensorOps::int_powf). + /// Float => [powf](crate::ops::FloatTensorOps::float_powf). + /// Int => [powf](crate::ops::IntTensorOps::int_powf). Powf(BinaryOperationDescription), } /// Operation description specific to an int tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum IntOperationDescription { - /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float). + /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float). IntoFloat(UnaryOperationDescription), } /// Operation description specific to a bool tensor. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BoolOperationDescription { - /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float). + /// Operation corresponding to [into float](crate::ops::BoolTensorOps::bool_into_float). IntoFloat(UnaryOperationDescription), - /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int). + /// Operation corresponding to [into int](crate::ops::BoolTensorOps::bool_into_int). IntoInt(UnaryOperationDescription), - /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not). + /// Operation corresponding to [not](crate::ops::BoolTensorOps::bool_not). Not(UnaryOperationDescription), } @@ -1057,7 +1053,7 @@ pub struct InterpolateBackwardDescription { impl OperationDescription { /// Cleanup the remaining tensor handles that have not been used. - pub(crate) fn nodes(&self) -> Vec<&TensorDescription> { + pub fn nodes(&self) -> Vec<&TensorDescription> { match self { OperationDescription::BaseFloat(ops) => ops.nodes(), OperationDescription::BaseInt(ops) => ops.nodes(), diff --git a/crates/burn-tensor/src/repr/tensor.rs b/crates/burn-tensor/src/repr/tensor.rs new file mode 100644 index 0000000000..525ad9c50b --- /dev/null +++ b/crates/burn-tensor/src/repr/tensor.rs @@ -0,0 +1,45 @@ +use serde::{Deserialize, Serialize}; + +/// The tensor unique identifier. +#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize)] +pub struct TensorId { + value: u64, +} + +/// The status of the current tensor. +#[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub enum TensorStatus { + /// The tensor can be read, but not written. + ReadOnly, + /// The tensor can be mutated inplace. + ReadWrite, + /// No handle exists for that tensor. + NotInit, +} + +/// A tensor definition represents a snapshot of a tensor when it was used. +/// +/// # Example +/// +/// A tensor that is used multiple times has its status updated for each operation. +/// +/// 1. Status::NotInit +/// 2. Status::ReadOnly +/// 3. Status::ReadOnly +/// 4. Status::ReadWrite +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct TensorDescription { + /// The [tensor id](TensorId). + pub id: TensorId, + /// The shape of the tensor. + pub shape: Vec, + /// The [status](TensorStatus) of the tensor when it was used. + pub status: TensorStatus, +} + +impl TensorId { + /// Create a new tensor id. + pub fn new(value: u64) -> Self { + Self { value } + } +} diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index 68cc3be975..182235df1d 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -3,7 +3,7 @@ use alloc::string::String; use crate::ops::*; use crate::tensor::Element; -use super::BackendBridge; +use super::{BackendBridge, DeviceOps}; /// This trait defines all types and functions needed for a backend to be used with burn. /// @@ -66,7 +66,7 @@ pub trait Backend: + 'static { /// Device type. - type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; + type Device: DeviceOps; /// A bridge that can cast tensors to full precision. type FullPrecisionBridge: BackendBridge + 'static; diff --git a/crates/burn-tensor/src/tensor/backend/device.rs b/crates/burn-tensor/src/tensor/backend/device.rs new file mode 100644 index 0000000000..64279f5b01 --- /dev/null +++ b/crates/burn-tensor/src/tensor/backend/device.rs @@ -0,0 +1,14 @@ +/// The device id. +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)] +pub struct DeviceId { + /// The type id identifies the type of the device. + pub type_id: u16, + /// The index id identifies the device number. + pub index_id: u32, +} + +/// The handle device trait allows to get an id for a backend device. +pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug { + /// Return the [device id](DeviceId). + fn id(&self) -> DeviceId; +} diff --git a/crates/burn-tensor/src/tensor/backend/mod.rs b/crates/burn-tensor/src/tensor/backend/mod.rs index 64780085c1..931552cb82 100644 --- a/crates/burn-tensor/src/tensor/backend/mod.rs +++ b/crates/burn-tensor/src/tensor/backend/mod.rs @@ -1,8 +1,10 @@ mod base; mod bridge; +mod device; pub use base::*; pub use bridge::*; +pub use device::*; // Not needed for now, useful for different tensor memory layout // pub mod conversion; diff --git a/crates/burn-wgpu/src/fusion.rs b/crates/burn-wgpu/src/fusion.rs deleted file mode 100644 index 727369acd8..0000000000 --- a/crates/burn-wgpu/src/fusion.rs +++ /dev/null @@ -1,14 +0,0 @@ -use crate::WgpuDevice; -use burn_fusion::{DeviceId, FusionDevice}; - -impl FusionDevice for WgpuDevice { - fn id(&self) -> DeviceId { - match self { - WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), - WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), - WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), - WgpuDevice::Cpu => DeviceId::new(3, 0), - WgpuDevice::BestAvailable => DeviceId::new(4, 0), - } - } -} diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index fb4f511994..f34791347b 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -10,9 +10,6 @@ mod element; mod graphics; mod runtime; -#[cfg(feature = "fusion")] -mod fusion; - #[cfg(feature = "template")] pub use burn_jit::{ compute::Kernel, diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index 4267b9faf4..9c01deb29e 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -13,6 +13,7 @@ use burn_compute::{ ComputeRuntime, }; use burn_jit::Runtime; +use burn_tensor::backend::{DeviceId, DeviceOps}; use std::marker::PhantomData; use wgpu::{AdapterInfo, DeviceDescriptor}; @@ -52,6 +53,18 @@ impl Runtime for WgpuRuntime DeviceId { + match self { + WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), + WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), + WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), + WgpuDevice::Cpu => DeviceId::new(3, 0), + WgpuDevice::BestAvailable => DeviceId::new(4, 0), + } + } +} + /// The values that control how a WGPU Runtime will perform its calculations. pub struct RuntimeOptions { /// How the buffers are deallocated.