Skip to content

Commit a6a40e8

Browse files
Added tests
1 parent a8584a0 commit a6a40e8

File tree

8 files changed

+217
-158
lines changed

8 files changed

+217
-158
lines changed

Cargo.lock

+168-142
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+12-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ members = [
1616

1717
exclude = [
1818
"examples/notebook",
19+
"examples/onnx-inference-rp2040", # will cause dependency building issues otherwise
20+
# "crates/burn-cuda", # comment this line to work on burn-cuda
1921
]
2022

2123
[workspace.package]
@@ -72,7 +74,11 @@ serde_bytes = { version = "0.11.15", default-features = false, features = [
7274
] } # alloc for no_std
7375
serde_rusqlite = "0.35.0"
7476
serial_test = "3.1.1"
75-
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
77+
spin = { version = "0.9.8", features = [
78+
"mutex",
79+
"spin_mutex",
80+
"portable-atomic",
81+
] }
7682
strum = "0.26.3"
7783
strum_macros = "0.26.4"
7884
syn = { version = "2.0.74", features = ["full", "extra-traits"] }
@@ -118,7 +124,7 @@ half = { version = "2.4.1", features = [
118124
"num-traits",
119125
"serde",
120126
], default-features = false }
121-
ndarray = { version = "0.15.6", default-features = false }
127+
ndarray = { version = "0.16.0", default-features = false }
122128
matrixmultiply = { version = "0.3.9", default-features = false }
123129
openblas-src = "0.10.9"
124130
blas-src = { version = "0.10.0", default-features = false }
@@ -142,9 +148,11 @@ nvml-wrapper = "0.10.0"
142148
sysinfo = "0.30.13"
143149
systemstat = "0.2.3"
144150

151+
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
152+
145153
### For the main burn branch. ###
146-
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
147-
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "bee7886b5c3016c425d244136f77442655097f3e" }
154+
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b09821d1d5bd1ee0cb8d80b04916acb4a9096c29" }
155+
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b09821d1d5bd1ee0cb8d80b04916acb4a9096c29" }
148156
### For local development. ###
149157
# cubecl = { path = "../cubecl/crates/cubecl" }
150158
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }

burn-book/src/advanced/no-std.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# No Standard Library
22

3-
In this seciton, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico. This should be universally applicable to other platforms. All the code can be found under the
3+
In this section, you will learn how to run an onnx inference model on an embedded system, with no standard library support on a Raspberry Pi Pico. This should be universally applicable to other platforms. All the code can be found under the
44
[examples directory](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference-rp2040).
55

66
## Step-by-Step Guide

crates/burn-ndarray/src/ops/base.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ fn arg<E: NdArrayElement, const D: usize>(
612612
idx as i64
613613
});
614614

615-
let output = output.into_shape_with_order(Dim(reshape.as_slice())).unwrap();
615+
let output = output.to_shape(Dim(reshape.as_slice())).unwrap();
616616

617617
NdArrayTensor {
618618
array: output.into_shared(),

crates/burn-ndarray/src/ops/conv.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
209209
});
210210

211211
let output = output
212-
.into_shape_with_order([batch_size, out_channels, out_height, out_width])
212+
.to_shape([batch_size, out_channels, out_height, out_width])
213213
.unwrap()
214214
.into_dyn()
215215
.into_shared();
@@ -437,7 +437,7 @@ pub(crate) fn conv3d<E: FloatNdArrayElement, Q: QuantElement>(
437437
});
438438

439439
let output = output
440-
.into_shape_with_order([batch_size, out_channels, out_depth, out_height, out_width])
440+
.to_shape([batch_size, out_channels, out_depth, out_height, out_width])
441441
.unwrap()
442442
.into_dyn()
443443
.into_shared();

crates/burn-ndarray/src/tensor.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ macro_rules! reshape {
7070
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
7171
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match $array.is_standard_layout() {
7272
true => $array
73-
.into_shape_with_order(dim)
73+
.to_shape(dim)
7474
.expect("Safe to change shape without relayout")
7575
.into_shared(),
76-
false => $array.into_shape_with_order(dim).unwrap(),
76+
false => $array.to_shape(dim).unwrap().as_standard_layout().into_shared(),
7777
};
7878
let array = array.into_dyn();
7979

examples/onnx-inference-rp2040/src/bin/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async fn main(_spawner: Spawner) {
3939
// Run the model
4040
let output = run_model(&model, &device, input);
4141

42-
// Ouput the values
42+
// Output the values
4343
match output.into_primitive().tensor().array.as_slice() {
4444
Some(slice) => info!("input: {} - output: {}", input, slice),
4545
None => defmt::panic!("Failed to get value")

xtask/src/runchecks.rs

+30-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::{endgroup, group};
2222
// Targets constants
2323
const WASM32_TARGET: &str = "wasm32-unknown-unknown";
2424
const ARM_TARGET: &str = "thumbv7m-none-eabi";
25+
const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
2526

2627
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
2728
pub(crate) enum CheckType {
@@ -81,12 +82,12 @@ impl CheckType {
8182
}
8283

8384
/// Run cargo build command
84-
fn cargo_build(params: Params) {
85+
fn cargo_build(params: Params, envs: Option<HashMap<&str, String>>) {
8586
// Run cargo build
8687
run_cargo(
8788
"build",
8889
params + "--color=always",
89-
HashMap::new(),
90+
envs.unwrap_or_default(),
9091
"Failed to run cargo build",
9192
);
9293
}
@@ -155,7 +156,10 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
155156
group!("Checks: {} (no-std)", crate_name);
156157

157158
// Run cargo build --no-default-features
158-
cargo_build(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
159+
cargo_build(
160+
Params::from(["-p", crate_name, "--no-default-features"]) + extra_args,
161+
None,
162+
);
159163

160164
// Run cargo test --no-default-features
161165
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
@@ -169,6 +173,7 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
169173
"--target",
170174
WASM32_TARGET,
171175
]) + extra_args,
176+
None,
172177
);
173178

174179
// Run cargo build --no-default-features --target thumbv7m-none-eabi
@@ -180,6 +185,22 @@ fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]
180185
"--target",
181186
ARM_TARGET,
182187
]) + extra_args,
188+
None,
189+
);
190+
191+
// Run cargo build --no-default-features --target thumbv6m-none-eabi
192+
cargo_build(
193+
Params::from([
194+
"-p",
195+
crate_name,
196+
"--no-default-features",
197+
"--target",
198+
ARM_NO_ATOMIC_PTR_TARGET,
199+
]) + extra_args,
200+
Some(HashMap::from([(
201+
"RUSTFLAGS",
202+
"--cfg portable_atomic_unsafe_assume_single_core".to_string(),
203+
)])),
183204
);
184205

185206
endgroup!();
@@ -228,6 +249,9 @@ fn no_std_checks() {
228249
// Install ARM target
229250
rustup_add_target(ARM_TARGET);
230251

252+
// Install ARM no atomic ptr target
253+
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET);
254+
231255
// Run checks for the following crates
232256
build_and_test_no_std("burn", []);
233257
build_and_test_no_std("burn-core", []);
@@ -265,7 +289,7 @@ fn burn_dataset_features_std() {
265289
group!("Checks: burn-dataset (all-features)");
266290

267291
// Run cargo build --all-features
268-
cargo_build(["-p", "burn-dataset", "--all-features"].into());
292+
cargo_build(["-p", "burn-dataset", "--all-features"].into(), None);
269293

270294
// Run cargo test --all-features
271295
cargo_test(["-p", "burn-dataset", "--all-features"].into());
@@ -334,7 +358,7 @@ fn std_checks() {
334358
}
335359

336360
group!("Checks: {}", member.name);
337-
cargo_build(Params::from(["-p", &member.name]));
361+
cargo_build(Params::from(["-p", &member.name]), None);
338362
cargo_test(Params::from(["-p", &member.name]));
339363
endgroup!();
340364
}
@@ -373,6 +397,7 @@ fn check_typos() {
373397

374398
// Run typos command as child process
375399
let typos = Command::new("typos")
400+
.args(["--exclude", "**/*.onnx"])
376401
.stdout(Stdio::inherit()) // Send stdout directly to terminal
377402
.stderr(Stdio::inherit()) // Send stderr directly to terminal
378403
.spawn()

0 commit comments

Comments
 (0)