Skip to content

Commit

Permalink
Move HandleContainer and Tensor Ops descriptions from burn-fusion to …
Browse files Browse the repository at this point in the history
…burn-tensor (#1654)

* Move HandleContainer and Tensor Ops description to burn-tensor

Move HandleContainer and Tensor operations descriptions to burn-tensor crate.
Removed the FusionDevice and replaced it with a DeviceOps trait bound to Backend::Device.

For now added modules to burn-tensor are excluded from no-std as they rely on Arc.

* [burn-tensor] Flatten module hierarchy for tensor representation

+ Add new repr feature to cargo file.

* Remove prefix on docstring

* [burn-fusion] Require default features of burn-tensor
  • Loading branch information
syl20bnr authored Apr 23, 2024
1 parent e6b1b7a commit c579686
Show file tree
Hide file tree
Showing 51 changed files with 775 additions and 738 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 14 additions & 1 deletion crates/burn-candle/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::marker::PhantomData;

use burn_tensor::{backend::Backend, Device};
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps},
Device,
};
use candle_core::DeviceLocation;

use crate::{
Expand Down Expand Up @@ -60,6 +63,16 @@ impl From<candle_core::Device> for CandleDevice {
}
}

impl DeviceOps for CandleDevice {
    /// Derive a stable [`DeviceId`] for this device.
    ///
    /// The `type_id` encodes the device kind (0 = CPU, 1 = CUDA, 2 = Metal)
    /// and the `index_id` carries the per-kind device index (always 0 for CPU).
    fn id(&self) -> burn_tensor::backend::DeviceId {
        // Compute the (kind, index) pair first, then build the id once.
        let (type_id, index_id) = match self {
            CandleDevice::Cpu => (0, 0),
            CandleDevice::Cuda(index) => (1, *index as u32),
            CandleDevice::Metal(index) => (2, *index as u32),
        };
        DeviceId::new(type_id, index_id)
    }
}

impl Default for CandleDevice {
fn default() -> Self {
Self::Cpu
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-fusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ std = ["serde/std"]
doc = ["default"]

[dependencies]
burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
burn-common = { path = "../burn-common", version = "0.14.0" }
hashbrown = { workspace = true }
derive-new = {workspace = true }
Expand Down
61 changes: 9 additions & 52 deletions crates/burn-fusion/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use crate::{
client::FusionClient,
stream::{Context, OperationDescription},
FusionClientLocator, FusionTensor, PrecisionBridge,
client::FusionClient, stream::Context, FusionClientLocator, FusionTensor, PrecisionBridge,
};
use burn_tensor::{
backend::Backend,
repr::{OperationDescription, ReprBackend},
Device,
};
use burn_tensor::{backend::Backend, Device, Shape};
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;

pub(crate) static CLIENTS: FusionClientLocator = FusionClientLocator::new();

pub(crate) fn get_client<B: FusionBackend>(device: &B::FusionDevice) -> B::FusionClient {
pub(crate) fn get_client<B: FusionBackend>(device: &B::Device) -> B::FusionClient {
CLIENTS.client(device)
}

Expand Down Expand Up @@ -43,7 +45,7 @@ impl<B: FusionBackend> Backend for Fusion<B> {
}

fn sync(device: &Self::Device) {
let client = CLIENTS.client::<B::FusionClient>(&device.clone().into());
let client = CLIENTS.client::<B::FusionClient>(&device.clone());
client.drain();
B::sync(device)
}
Expand Down Expand Up @@ -114,62 +116,17 @@ pub trait Optimization<B: FusionBackend>: Send {
fn from_state(device: &B::Device, state: B::OptimizationState) -> Self;
}

/// The device id.
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, new)]
pub struct DeviceId {
/// The type id identifies the type of the device.
pub type_id: u16,
/// The index id identifies the device number.
pub index_id: u32,
}

/// The handle device trait allows to get an id for a backend device.
pub trait FusionDevice: Clone + Send + Sync + PartialEq {
/// Return the [device id](DeviceId).
fn id(&self) -> DeviceId;
}

/// Trait that allows an existing [backend](Backend) to specify graph optimizations using
/// [operation builder](crate::OptimizationBuilder).
pub trait FusionBackend: Backend {
pub trait FusionBackend: Backend + ReprBackend {
/// The state that can be serialized for an optimization.
type OptimizationState: Serialize + DeserializeOwned;
/// Optimization type for the backend.
type Optimization: Optimization<Self>;

/// The device type that can return an ID.
///
/// It can be the same as (Backend::Device), but must implement (FusionDevice).
type FusionDevice: FusionDevice + From<Self::Device> + Into<Self::Device> + core::fmt::Debug;
/// The type that can be used to point to a tensor of any kind.
type Handle: Sync + Send + Clone;
/// What kind of client should be used.
type FusionClient: FusionClient<FusionBackend = Self>;

/// The list of optimizations that will be used to optimize the computational graph.
fn optimizations(device: Device<Self>)
-> Vec<Box<dyn OptimizationBuilder<Self::Optimization>>>;

/// Convert a [handle](FusionBackend::Handle) to a [float tensor](Backend::FloatTensorPrimitive).
fn float_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::FloatTensorPrimitive<D>;
/// Convert a [handle](FusionBackend::Handle) to an [int tensor](Backend::IntTensorPrimitive).
fn int_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::IntTensorPrimitive<D>;
/// Convert a [handle](FusionBackend::Handle) to a [bool tensor](Backend::BoolTensorPrimitive).
fn bool_tensor<const D: usize>(
handle: Self::Handle,
shape: Shape<D>,
) -> Self::BoolTensorPrimitive<D>;

/// Convert a [float tensor](Backend::FloatTensorPrimitive) to a [handle](FusionBackend::Handle).
fn float_tensor_handle<const D: usize>(tensor: Self::FloatTensorPrimitive<D>) -> Self::Handle;
/// Convert an [int tensor](Backend::IntTensorPrimitive) to a [handle](FusionBackend::Handle).
fn int_tensor_handle<const D: usize>(tensor: Self::IntTensorPrimitive<D>) -> Self::Handle;
/// Convert a [bool tensor](Backend::BoolTensorPrimitive) to a [handle](FusionBackend::Handle).
fn bool_tensor_handle<const D: usize>(tensor: Self::BoolTensorPrimitive<D>) -> Self::Handle;
}
14 changes: 8 additions & 6 deletions crates/burn-fusion/src/client/base.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use crate::{
stream::{Operation, OperationDescription, StreamId},
FusionBackend, FusionTensor, Handle, TensorDescription, TensorId,
stream::{execution::Operation, StreamId},
FusionBackend, FusionTensor, Handle,
};
use burn_tensor::{
backend::Backend,
ops::{FloatElem, IntElem},
Data, Reader,
repr::{OperationDescription, TensorDescription, TensorId},
Data, Device, Reader,
};

/// Define how to interact with the fusion server.
pub trait FusionClient: Send + Sync + Clone {
/// The [fusion backend](FusionBackend) associated type.
type FusionBackend: FusionBackend;

/// Create a new client for the given [fusion device](FusionBackend::FusionDevice).
fn new(device: <Self::FusionBackend as FusionBackend>::FusionDevice) -> Self;
/// Create a new client for the given [device](Backend::Device).
fn new(device: Device<Self::FusionBackend>) -> Self;
/// Register a new [tensor operation description](OperationDescription).
fn register<O: Operation<Self::FusionBackend> + 'static>(
&self,
Expand All @@ -24,7 +26,7 @@ pub trait FusionClient: Send + Sync + Clone {
/// Register all lazy computation.
fn drain(&self);
/// Get the current device used by all operations handled by this client.
fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice;
fn device(&self) -> &<Self::FusionBackend as Backend>::Device;
/// Create a new [fusion tensor](FusionTensor), but with no resources allocated to it.
fn tensor_uninitialized(&self, shape: Vec<usize>) -> FusionTensor<Self>;
/// Create a tensor with the given handle and shape.
Expand Down
41 changes: 20 additions & 21 deletions crates/burn-fusion/src/client/mutex.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
use super::FusionClient;
use crate::{
stream::{Operation, OperationDescription, StreamId},
stream::{execution::Operation, StreamId},
FusionBackend, FusionServer, FusionTensor, Handle,
};
use burn_tensor::ops::FloatElem;
use burn_tensor::{
backend::Backend,
ops::FloatElem,
repr::{OperationDescription, TensorDescription, TensorId},
};
use spin::Mutex;
use std::sync::Arc;

Expand All @@ -13,7 +17,7 @@ where
B: FusionBackend,
{
server: Arc<Mutex<FusionServer<B>>>,
device: B::FusionDevice,
device: B::Device,
}

impl<B> Clone for MutexFusionClient<B>
Expand All @@ -34,7 +38,7 @@ where
{
type FusionBackend = B;

fn new(device: B::FusionDevice) -> Self {
fn new(device: B::Device) -> Self {
Self {
device: device.clone(),
server: Arc::new(Mutex::new(FusionServer::new(device))),
Expand Down Expand Up @@ -63,7 +67,7 @@ where
FusionTensor::new(id, shape, self.clone(), StreamId::current())
}

fn device(&self) -> &<Self::FusionBackend as FusionBackend>::FusionDevice {
fn device(&self) -> &<Self::FusionBackend as Backend>::Device {
&self.device
}
fn register_tensor(
Expand All @@ -82,15 +86,15 @@ where

fn read_tensor_float<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<FloatElem<Self::FusionBackend>, D>> {
self.server.lock().read_float(tensor, stream)
}

fn read_tensor_int<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<burn_tensor::ops::IntElem<Self::FusionBackend>, D>>
{
Expand All @@ -99,25 +103,24 @@ where

fn read_tensor_bool<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::Reader<burn_tensor::Data<bool, D>> {
self.server.lock().read_bool(tensor, stream)
}

fn change_client_float<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();

let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);

let id = server_current.change_server_float::<D>(&tensor, &device, &mut server_other);
let id =
server_current.change_server_float::<D>(&tensor, &client.device, &mut server_other);

core::mem::drop(server_other);
core::mem::drop(server_current);
Expand All @@ -127,17 +130,15 @@ where

fn change_client_int<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();

let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);

let id = server_current.change_server_int::<D>(&tensor, &device, &mut server_other);
let id = server_current.change_server_int::<D>(&tensor, &client.device, &mut server_other);

core::mem::drop(server_other);
core::mem::drop(server_current);
Expand All @@ -147,25 +148,23 @@ where

fn change_client_bool<const D: usize>(
&self,
tensor: crate::TensorDescription,
tensor: TensorDescription,
client: Self,
stream: StreamId,
) -> FusionTensor<Self> {
let device = client.device.clone().into();

let mut server_other = client.server.lock();
let mut server_current = self.server.lock();
server_current.drain_stream(stream);

let id = server_current.change_server_bool::<D>(&tensor, &device, &mut server_other);
let id = server_current.change_server_bool::<D>(&tensor, &client.device, &mut server_other);

core::mem::drop(server_other);
core::mem::drop(server_current);

FusionTensor::new(id, tensor.shape, client, StreamId::current())
}

fn register_orphan(&self, id: &crate::TensorId) {
fn register_orphan(&self, id: &TensorId) {
self.server.lock().drop_tensor_handle(*id);
}
}
14 changes: 10 additions & 4 deletions crates/burn-fusion/src/fusion.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
use crate::{client::FusionClient, DeviceId, FusionBackend, FusionDevice};
use burn_tensor::{
backend::{Backend, DeviceId, DeviceOps},
repr::ReprBackend,
};

use crate::client::FusionClient;

use std::{any::Any, collections::HashMap, ops::DerefMut};

/// Type alias for [fusion backend handle](FusionBackend::Handle).
pub type Handle<B> = <B as FusionBackend>::Handle;
/// Type alias for [representation backend handle](burn_tensor::repr::ReprBackend::Handle).
pub type Handle<B> = <B as ReprBackend>::Handle;
type Key = (core::any::TypeId, DeviceId);

pub(crate) struct FusionClientLocator {
Expand All @@ -22,7 +28,7 @@ impl FusionClientLocator {
/// Provide the init function to create a new client if it isn't already initialized.
pub fn client<C: FusionClient + 'static>(
&self,
device: &<C::FusionBackend as FusionBackend>::FusionDevice,
device: &<C::FusionBackend as Backend>::Device,
) -> C {
let device_id = device.id();
let client_id = (core::any::TypeId::of::<C>(), device_id);
Expand Down
2 changes: 0 additions & 2 deletions crates/burn-fusion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub mod stream;
mod backend;
mod bridge;
mod fusion;
mod handle;
mod ops;
mod server;
mod tensor;
Expand All @@ -26,5 +25,4 @@ pub(crate) use server::*;
pub use backend::*;
pub use bridge::*;
pub use fusion::*;
pub use handle::*;
pub use tensor::*;
Loading

0 comments on commit c579686

Please sign in to comment.