Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BjornTheProgrammer committed Aug 16, 2024
1 parent a8584a0 commit a3f17e0
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 234 deletions.
282 changes: 144 additions & 138 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 12 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ members = [

exclude = [
"examples/notebook",
"examples/onnx-inference-rp2040", # will cause dependency building issues otherwise
# "crates/burn-cuda", # comment this line to work on burn-cuda
]

[workspace.package]
Expand Down Expand Up @@ -72,7 +74,11 @@ serde_bytes = { version = "0.11.15", default-features = false, features = [
] } # alloc for no_std
serde_rusqlite = "0.35.0"
serial_test = "3.1.1"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
spin = { version = "0.9.8", features = [
"mutex",
"spin_mutex",
"portable-atomic",
] }
strum = "0.26.3"
strum_macros = "0.26.4"
syn = { version = "2.0.74", features = ["full", "extra-traits"] }
Expand Down Expand Up @@ -118,7 +124,7 @@ half = { version = "2.4.1", features = [
"num-traits",
"serde",
], default-features = false }
ndarray = { version = "0.15.6", default-features = false }
ndarray = { version = "0.16.0", default-features = false }
matrixmultiply = { version = "0.3.9", default-features = false }
openblas-src = "0.10.9"
blas-src = { version = "0.10.0", default-features = false }
Expand All @@ -142,9 +148,11 @@ nvml-wrapper = "0.10.0"
sysinfo = "0.30.13"
systemstat = "0.2.3"

portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "034f667da6e92a81b7da9f303e8507db944cc2a4" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "034f667da6e92a81b7da9f303e8507db944cc2a4" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl" }
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
Expand Down
2 changes: 1 addition & 1 deletion burn-book/src/advanced/no-std.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# No Standard Library

In this seciton, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico. This should be universally applicable to other platforms. All the code can be found under the
In this section, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico. This should be universally applicable to other platforms. All the code can be found under the
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference-rp2040).

## Step-by-Step Guide
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-ndarray/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ fn arg<E: NdArrayElement, const D: usize>(
idx as i64
});

let output = output.into_shape_with_order(Dim(reshape.as_slice())).unwrap();
let output = output.to_shape(Dim(reshape.as_slice())).unwrap();

NdArrayTensor {
array: output.into_shared(),
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-ndarray/src/ops/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
});

let output = output
.into_shape_with_order([batch_size, out_channels, out_height, out_width])
.to_shape([batch_size, out_channels, out_height, out_width])
.unwrap()
.into_dyn()
.into_shared();
Expand Down Expand Up @@ -437,7 +437,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
});

let output = output
.into_shape_with_order([batch_size, out_channels, out_depth, out_height, out_width])
.to_shape([batch_size, out_channels, out_depth, out_height, out_width])
.unwrap()
.into_dyn()
.into_shared();
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-ndarray/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ macro_rules! reshape {
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
true => $array
.into_shape_with_order(dim)
.to_shape(dim)
.expect("Safe to change shape without relayout")
.into_shared(),
false => $array.into_shape_with_order(dim).unwrap(),
false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
};
let array = array.into_dyn();

Expand Down
Loading

0 comments on commit a3f17e0

Please sign in to comment.