Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/burn-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
rayon = ["dep:rayon"]

[dependencies]
derive-new = { workspace = true }
serde = { workspace = true }

# Network downloader
Expand Down
78 changes: 78 additions & 0 deletions crates/burn-common/src/baselib/indexing/as_index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use core::fmt::Debug;

/// Helper trait for implementing indexing with support for negative indices.
///
/// # Example
/// ```rust
/// use burn_common::baselib::indexing::{AsIndex, canonicalize_dim};
///
/// fn example<I: AsIndex, const D: usize>(dim: I, size: usize) -> isize {
/// let dim: usize = canonicalize_dim(dim, D, false);
/// unimplemented!()
/// }
/// ```
pub trait AsIndex: Debug + Copy + Sized {
/// Converts into a slice index.
fn index(self) -> isize;
}

impl AsIndex for usize {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for isize {
fn index(self) -> isize {
self
}
}

impl AsIndex for i64 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u64 {
fn index(self) -> isize {
self as isize
}
}

// Default integer type
impl AsIndex for i32 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u32 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for i16 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u16 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for i8 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u8 {
fn index(self) -> isize {
self as isize
}
}
Original file line number Diff line number Diff line change
@@ -1,83 +1,14 @@
//! A module for indexing utility machinery.

use core::fmt::Debug;

/// Helper trait for implementing indexing with support for negative indices.
///
/// # Example
/// ```rust
/// use burn_tensor::indexing::{AsIndex, canonicalize_dim};
///
/// fn example<I: AsIndex, const D: usize>(dim: I, size: usize) -> isize {
/// let dim: usize = canonicalize_dim(dim, D, false);
/// unimplemented!()
/// }
/// ```
pub trait AsIndex: Debug + Copy + Sized {
/// Converts into a slice index.
fn index(self) -> isize;
}

impl AsIndex for usize {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for isize {
fn index(self) -> isize {
self
}
}

impl AsIndex for i64 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u64 {
fn index(self) -> isize {
self as isize
}
}

// Default integer type
impl AsIndex for i32 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u32 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for i16 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u16 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for i8 {
fn index(self) -> isize {
self as isize
}
}

impl AsIndex for u8 {
fn index(self) -> isize {
self as isize
}
}
mod as_index;
mod range;
mod shape;
mod slice;

pub use as_index::*;
pub use range::*;
pub use shape::*;
pub use slice::*;

/// Canonicalizes and bounds checks an index with negative indexing support.
///
Expand Down
39 changes: 39 additions & 0 deletions crates/burn-common/src/baselib/indexing/range.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::baselib::indexing::shape::Shape;
use crate::baselib::indexing::slice::Slice;
use std::ops::Range;

/// Trait used for slice dim arguments.
pub trait RangeArg {
/// Converts into a range for the `tensor.slice_dim()` function
fn into_range(self, shape_dim: usize) -> Range<usize>;
}

impl<T: Into<Slice>> RangeArg for T {
fn into_range(self, shape_dim: usize) -> Range<usize> {
self.into().into_range(shape_dim)
}
}

/// Trait used for slice arguments
pub trait RangesArg<const D2: usize> {
/// Converts into a set of ranges to `[Range<usize>; D2]` for the `tensor.slice()` function
fn into_ranges(self, shape: Shape) -> [Range<usize>; D2];
}

impl<const D2: usize, T: Into<Slice>> RangesArg<D2> for [T; D2] {
fn into_ranges(self, shape: Shape) -> [Range<usize>; D2] {
// clamp the ranges to the shape dimensions
let ranges = self
.into_iter()
.enumerate()
.map(|(i, range)| range.into().into_range(shape.dims[i]))
.collect::<Vec<_>>();
ranges.try_into().unwrap()
}
}

impl<T: Into<Slice>> RangesArg<1> for T {
fn into_ranges(self, shape: Shape) -> [Range<usize>; 1] {
[self.into().into_range(shape.dims[0])]
}
}
Loading
Loading