Skip to content

Commit 929a2a0

Browse files
committed
Feat: Add android example of MNIST inference
1 parent a88c69a commit 929a2a0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1401
-4
lines changed

Cargo.lock

+126
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ members = [
1111
"crates/burn-import/onnx-tests",
1212
"examples/*",
1313
"examples/pytorch-import/model",
14+
"examples/mnist-inference-android/app/src/main/rust",
1415
"xtask",
1516
]
1617

1718
exclude = [
1819
"examples/notebook",
19-
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
20+
"examples/mnist-inference-android",
21+
"examples/raspberry-pi-pico", # will cause dependency building issues otherwise
2022
# "crates/burn-cuda", # comment this line to work on burn-cuda
2123
]
2224

@@ -157,8 +159,8 @@ portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
157159
# cubecl = { path = "../cubecl/crates/cubecl" }
158160
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
159161
### For the release. ###
160-
cubecl = { version="0.2.0", default-features = false }
161-
cubecl-common = { version="0.2.0", default-features = false }
162+
cubecl = { version = "0.2.0", default-features = false }
163+
cubecl-common = { version = "0.2.0", default-features = false }
162164

163165
### For xtask crate ###
164166
tracel-xtask = { git = "https://github.com/tracel-ai/xtask", rev = "921408bc16e74d3ef8ae59356d928fb6706fb8f4" }

_typos.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ extend-exclude = [
66
"*.onnx",
77
"assets/ModuleSerialization.xml",
88
"examples/image-classification-web/src/model/label.txt",
9+
"examples/mnist-inference-android/gradle/*",
910
]
1011

1112
[default.extend-words]
1213
# Don't correct "arange" which is intentional
13-
arange = "arange"
14+
arange = "arange"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
*.iml
2+
.gradle
3+
/local.properties
4+
/.idea/caches
5+
/.idea/libraries
6+
/.idea/modules.xml
7+
/.idea/workspace.xml
8+
/.idea/navEditor.xml
9+
/.idea/assetWizardSettings.xml
10+
.DS_Store
11+
/build
12+
/captures
13+
.externalNativeBuild
14+
.cxx
15+
local.properties
+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# MNIST number detector Android App
2+
3+
This project is a sample Android application that demonstrates how to integrate a `Burn` into an android app using the JNI (Java Native Interface).
4+
5+
## Table of Contents
6+
7+
- [Workflow](#workflow)
8+
- [Prerequisites](#prerequisites)
9+
- [Setup](#setup)
10+
- [How To make your own](#how-to-make-your-own)
11+
- [License](#license)
12+
13+
## Workflow
14+
1. **Image Input:** The user provides an image input through the app's interface.
15+
2. **Image Processing:** The image is converted to a grayscale `byteArray` in Kotlin.
16+
3. **JNI Bridge:** The grayscale `byteArray` is passed to a Rust function via JNI.
17+
4. **Rust Processing:** The Rust function calls the `forward` method from the `burn` library, using a pretrained MNIST ONNX model to perform inference.
18+
5. **Result Handling:** The result, an integer representing the predicted digit, is logged to the android console and returned from Rust to Kotlin.
19+
6. **Output Display:** The predicted digit is displayed on the screen
20+
21+
## Prerequisites
22+
- Android Studio (latest version recommended)
23+
- Rust (installed and configured)
24+
- Android NDK (Native Development Kit)
25+
26+
## Setup
27+
1. **Install Rust dependencies:**
28+
29+
Ensure Rust is installed and the `cargo` command is available:
30+
31+
```bash
32+
rustup update
33+
```
34+
And that you have installed all the rustup toolchains required:
35+
```bash
36+
rustup target add \
37+
aarch64-linux-android \
38+
armv7-linux-androideabi \
39+
i686-linux-android \
40+
x86_64-linux-android
41+
```
42+
43+
2. **Configure the Android NDK:**
44+
45+
Ensure that the Android NDK is installed. You can install it via Android Studio's SDK Manager.
46+
47+
3. **Build the android app:**
48+
49+
Running the android app should automatically build the rust libraries due to the gradle tasks configured at the app level. (More on that later)
50+
51+
52+
## How To make your own
53+
1. There are a few ways to compile a rust library for android -
54+
- Add targets in `.cargo/config.toml` and build with them. Then we can add the `.so` files generated to the jni directory in `app/src/main/jniLibs`
55+
- Add gradle plugins (like [rust-android-gradle](https://github.com/mozilla/rust-android-gradle) or [cargo-ndk-android](https://github.com/willir/cargo-ndk-android-gradle) using `rust-android-gradle` in this project) to do the work for you, so that the rust library is built on each app build. (Might want to change for expensive library builds)
56+
2. To interface with Kotlin(Java) you can either use an interface generator (like [flapigen-rs](https://github.com/Dushistov/flapigen-rs)) or make them by yourself. This sample function doesn't use flapigen.
57+
3. Now the function to be called from android (`infer()` here) needs to follow the [JNI naming conventions](https://docs.oracle.com/javase/1.5.0/docs/guide/jni/spec/design.html) (The correct name is also shown in the call error if it doesn't exist).
58+
4. **Important** The first 2 arguments of the jni interfacing function will be the `env` variable (for interface functions) and the `this` object. The data you pass will start from the 3rd argument.
59+
5. Next for converting the data from java to rust data types, there are multiple functions in the env variable passed to the function. Use as required...
60+
6. Then in the app's `build.gradle` we add the part to run the cargo build before building the app and the also the cargo build details:
61+
```kotlin
62+
// Cargo build details
63+
cargo {
64+
module = "./src/main/rust" // Or whatever directory contains your Cargo.toml
65+
libname = "mnist_inference_android" // Or whatever matches Cargo.toml's [package] name.
66+
targets = listOf(
67+
"arm", "arm64",
68+
"x86",
69+
"x86_64"
70+
)
71+
prebuiltToolchains = true
72+
}
73+
74+
// Used to build cargo before the android build task is run
75+
// See more options here: https://github.com/mozilla/rust-android-gradle/issues/133
76+
project.afterEvaluate {
77+
tasks.withType(com.nishtahir.CargoBuildTask::class)
78+
.forEach { buildTask ->
79+
tasks.withType(com.android.build.gradle.tasks.MergeSourceSetFolders::class)
80+
.configureEach {
81+
this.inputs.dir(
82+
layout.buildDirectory.dir("rustJniLibs" + File.separatorChar + buildTask.toolchain!!.folder)
83+
)
84+
this.dependsOn(buildTask)
85+
}
86+
}
87+
}
88+
```
89+
(In the example we have also added the target directory in `config.toml` since otherwise it will build into the workspace target, which we do not want)
90+
7. Here the library's name is `mnist-android` so we will initialize it in our app:
91+
```kotlin
92+
class MainActivity : ComponentActivity() {
93+
init {
94+
System.loadLibrary("mnist_android") // Note: '-' is changed to '_'
95+
}
96+
...
97+
}
98+
```
99+
8. Finally use it by declaring it as an external function first
100+
```kotlin
101+
external fun infer(inputImage: ByteArray): Int;
102+
103+
...
104+
infer(byteArray)
105+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/build

0 commit comments

Comments
 (0)