Skip to content

Commit

Permalink
Generate random experiment name in guide example (#25)
Browse files Browse the repository at this point in the history
* Generate random experiment name in guide example

* Cargo fmt

* Update cargo dependencies to fix RUSTSEC-2024-0350

* Update guide example with last burn API
  • Loading branch information
syl20bnr authored Jul 16, 2024
1 parent adbef96 commit 1df2c3a
Show file tree
Hide file tree
Showing 14 changed files with 737 additions and 462 deletions.
898 changes: 532 additions & 366 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,24 @@ readme = "README.md"
[workspace.dependencies]
burn = { git = "https://github.com/tracel-ai/burn", branch = "main" }

anyhow = "1.0.81"
clap = { version = "4.5.4", features = ["derive"] }
anyhow = "1.0.86"
clap = { version = "4.5.9", features = ["derive"] }
derive-new = { version = "0.6.0", default-features = false }
derive_more = { version = "0.99.17", features = ["display"], default-features = false }
derive_more = { version = "0.99.18", features = ["display"], default-features = false }
dotenv = "0.15.0"
env_logger = "0.11.3"
log = "0.4.21"
log = "0.4.22"
rand = "0.8.5"
rmp-serde = "1.3.0"
rstest = "0.19.0"
serde = { version = "1.0.200", default-features = false, features = [
serde = { version = "1.0.204", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = "1.0.64"
strum = {version = "0.26.2", features = ["derive"]}
thiserror = "1.0.30"
reqwest = "0.12.4"
serde_json = "1.0.120"
strum = {version = "0.26.3", features = ["derive"]}
thiserror = "1.0.62"
reqwest = "0.12.5"

[profile.dev]
debug = 0 # Speed up compilation time and not necessary.
1 change: 1 addition & 0 deletions crates/heat-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rmp-serde = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
log = { workspace = true }
rand = { version = "0.8.5" }
reqwest = { workspace = true, features = ["blocking", "json"] }
tungstenite = { version = "0.21.0" }
thiserror = { workspace = true }
Expand Down
9 changes: 9 additions & 0 deletions crates/heat-sdk/src/http/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;

use rand::Rng;
use reqwest::header::{COOKIE, SET_COOKIE};
use serde::Serialize;

Expand Down Expand Up @@ -125,11 +126,19 @@ impl HttpClient {

let url = format!("{}/projects/{}/experiments", self.base_url, project_id);

let mut body = HashMap::new();
let mut rng = rand::thread_rng();
body.insert(
"experiment_name",
format!("guide-{}", rng.gen_range(0..10000)),
);

// Create a new experiment
let exp_uuid = self
.http_client
.post(url)
.header(COOKIE, self.session_cookie.as_ref().unwrap())
.json(&body)
.send()?
.error_for_status()?
.json::<CreateExperimentResponseSchema>()?
Expand Down
12 changes: 9 additions & 3 deletions examples/guide/src/data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
tensor::TensorData,
};

#[derive(Clone)]
Expand All @@ -24,8 +25,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| Data::<f32, 2>::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device))
.map(|item| TensorData::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
.map(|tensor| tensor.reshape([1, 28, 28]))
// normalize: make between [0,1] and make the mean = 0 and std = 1
// values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
Expand All @@ -35,7 +36,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {

let targets = items
.iter()
.map(|item| Tensor::<B, 1, Int>::from_data([(item.label as i64).elem()], &self.device))
.map(|item| {
Tensor::<B, 1, Int>::from_data(
[(item.label as i64).elem::<B::IntElem>()],
&self.device,
)
})
.collect();

let images = Tensor::cat(images, 0);
Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};

fn main() {
type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

let args = Args::parse();
Expand Down
82 changes: 59 additions & 23 deletions xtask/src/commands/check.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use std::process::Command;

use anyhow::{Ok, Result, anyhow};
use anyhow::{anyhow, Ok, Result};
use clap::{Args, Subcommand};
use strum::{Display, EnumIter, EnumString, IntoEnumIterator};

use crate::{endgroup, group, utils::{cargo::ensure_cargo_crate_is_installed, prompt::ask_once, workspace::{get_workspace_members, WorkspaceMemberType}}};
use crate::{
endgroup, group,
utils::{
cargo::ensure_cargo_crate_is_installed,
prompt::ask_once,
workspace::{get_workspace_members, WorkspaceMemberType},
},
};

use super::Target;

Expand Down Expand Up @@ -36,17 +43,21 @@ pub(crate) fn handle_command(args: CheckCmdArgs, answer: Option<bool>) -> anyhow
CheckCommand::Format => run_format(&args.target, answer),
CheckCommand::Lint => run_lint(&args.target, answer),
CheckCommand::All => {
let answer = ask_once("This will run all the checks with autofix on all members of the workspace.");
let answer = ask_once(
"This will run all the checks with autofix on all members of the workspace.",
);
CheckCommand::iter()
.filter(|c| *c != CheckCommand::All)
.try_for_each(|c| handle_command(
CheckCmdArgs {
command: c,
target: args.target.clone()
},
Some(answer),
))
},
.try_for_each(|c| {
handle_command(
CheckCmdArgs {
command: c,
target: args.target.clone(),
},
Some(answer),
)
})
}
}
}

Expand All @@ -71,7 +82,7 @@ pub(crate) fn run_audit(target: &Target, mut answer: Option<bool>) -> anyhow::Re
}
endgroup!();
}
},
}
Target::All => {
let answer = ask_once("This will run audit checks on all targets.");
Target::iter()
Expand All @@ -94,7 +105,11 @@ fn run_format(target: &Target, mut answer: Option<bool>) -> Result<()> {
if answer.is_none() {
answer = Some(ask_once(&format!(
"This will run format checks on all {} of the workspace.",
if *target == Target::Crates { "crates" } else { "examples" }
if *target == Target::Crates {
"crates"
} else {
"examples"
}
)));
}

Expand All @@ -107,22 +122,27 @@ fn run_format(target: &Target, mut answer: Option<bool>) -> Result<()> {
.status()
.map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?;
if !status.success() {
return Err(anyhow!("Format check execution failed for {}", &member.name));
return Err(anyhow!(
"Format check execution failed for {}",
&member.name
));
}
endgroup!();
}
}
},
}
Target::All => {
if answer.is_none() {
answer = Some(ask_once("This will run format check on all members of the workspace."));
answer = Some(ask_once(
"This will run format check on all members of the workspace.",
));
}
if answer.unwrap() {
Target::iter()
.filter(|t| *t != Target::All)
.try_for_each(|t| run_format(&t, answer))?;
}
},
}
}
Ok(())
}
Expand All @@ -139,16 +159,30 @@ fn run_lint(target: &Target, mut answer: Option<bool>) -> anyhow::Result<()> {
if answer.is_none() {
answer = Some(ask_once(&format!(
"This will run lint fix on all {} of the workspace.",
if *target == Target::Crates { "crates" } else { "examples" }
if *target == Target::Crates {
"crates"
} else {
"examples"
}
)));
}

if answer.unwrap() {
for member in members {
group!("Lint: {}", member.name);
info!("Command line: cargo clippy --no-deps --fix --allow-dirty -p {}", &member.name);
info!(
"Command line: cargo clippy --no-deps --fix --allow-dirty -p {}",
&member.name
);
let status = Command::new("cargo")
.args(["clippy", "--no-deps", "--fix", "--allow-dirty", "-p", &member.name])
.args([
"clippy",
"--no-deps",
"--fix",
"--allow-dirty",
"-p",
&member.name,
])
.status()
.map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?;
if !status.success() {
Expand All @@ -157,17 +191,19 @@ fn run_lint(target: &Target, mut answer: Option<bool>) -> anyhow::Result<()> {
endgroup!();
}
}
},
}
Target::All => {
if answer.is_none() {
answer = Some(ask_once("This will run lint fix on all members of the workspace."));
answer = Some(ask_once(
"This will run lint fix on all members of the workspace.",
));
}
if answer.unwrap() {
Target::iter()
.filter(|t| *t != Target::All)
.try_for_each(|t| run_lint(&t, answer))?;
}
},
}
}
Ok(())
}
Loading

0 comments on commit 1df2c3a

Please sign in to comment.