Generate random experiment name in guide example #25

Merged · 4 commits · Jul 16, 2024
898 changes: 532 additions & 366 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions Cargo.toml
@@ -18,24 +18,24 @@ readme = "README.md"
 [workspace.dependencies]
 burn = { git = "https://github.com/tracel-ai/burn", branch = "main" }

-anyhow = "1.0.81"
-clap = { version = "4.5.4", features = ["derive"] }
+anyhow = "1.0.86"
+clap = { version = "4.5.9", features = ["derive"] }
 derive-new = { version = "0.6.0", default-features = false }
-derive_more = { version = "0.99.17", features = ["display"], default-features = false }
+derive_more = { version = "0.99.18", features = ["display"], default-features = false }
 dotenv = "0.15.0"
 env_logger = "0.11.3"
-log = "0.4.21"
+log = "0.4.22"
 rand = "0.8.5"
 rmp-serde = "1.3.0"
 rstest = "0.19.0"
-serde = { version = "1.0.200", default-features = false, features = [
+serde = { version = "1.0.204", default-features = false, features = [
     "derive",
     "alloc",
 ] } # alloc is for no_std, derive is needed
-serde_json = "1.0.64"
-strum = {version = "0.26.2", features = ["derive"]}
-thiserror = "1.0.30"
-reqwest = "0.12.4"
+serde_json = "1.0.120"
+strum = {version = "0.26.3", features = ["derive"]}
+thiserror = "1.0.62"
+reqwest = "0.12.5"

 [profile.dev]
 debug = 0 # Speed up compilation time and not necessary.
1 change: 1 addition & 0 deletions crates/heat-sdk/Cargo.toml
@@ -13,6 +13,7 @@ rmp-serde = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 log = { workspace = true }
+rand = { version = "0.8.5" }
 reqwest = { workspace = true, features = ["blocking", "json"] }
 tungstenite = { version = "0.21.0" }
 thiserror = { workspace = true }
9 changes: 9 additions & 0 deletions crates/heat-sdk/src/http/client.rs
@@ -1,5 +1,6 @@
 use std::collections::HashMap;

+use rand::Rng;
 use reqwest::header::{COOKIE, SET_COOKIE};
 use serde::Serialize;

@@ -125,11 +126,19 @@ impl HttpClient {

         let url = format!("{}/projects/{}/experiments", self.base_url, project_id);

+        let mut body = HashMap::new();
+        let mut rng = rand::thread_rng();
+        body.insert(
+            "experiment_name",
+            format!("guide-{}", rng.gen_range(0..10000)),
+        );
+
         // Create a new experiment
         let exp_uuid = self
             .http_client
             .post(url)
             .header(COOKIE, self.session_cookie.as_ref().unwrap())
+            .json(&body)
             .send()?
             .error_for_status()?
             .json::<CreateExperimentResponseSchema>()?
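The hunk above is the core of this PR: the guide's create-experiment request now carries an experiment name with a random numeric suffix. As a quick illustration, here is a minimal, standalone sketch of that pattern (a hypothetical `main`, reusing the `rand` and `serde_json` crates already in the workspace) showing roughly what `.json(&body)` ends up sending:

```rust
use std::collections::HashMap;

use rand::Rng;

fn main() {
    // Same pattern as the new create-experiment code: a random suffix keeps
    // repeated runs of the guide from reusing one experiment name.
    let mut rng = rand::thread_rng();
    let mut body = HashMap::new();
    body.insert(
        "experiment_name",
        format!("guide-{}", rng.gen_range(0..10000)),
    );

    // Prints something like {"experiment_name":"guide-4271"}.
    println!("{}", serde_json::to_string(&body).unwrap());
}
```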
12 changes: 9 additions & 3 deletions examples/guide/src/data.rs
@@ -1,6 +1,7 @@
 use burn::{
     data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
     prelude::*,
+    tensor::TensorData,
 };

 #[derive(Clone)]
@@ -24,8 +25,8 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
     fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
         let images = items
             .iter()
-            .map(|item| Data::<f32, 2>::from(item.image))
-            .map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device))
+            .map(|item| TensorData::from(item.image))
+            .map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
             .map(|tensor| tensor.reshape([1, 28, 28]))
             // normalize: make between [0,1] and make the mean = 0 and std = 1
             // values mean=0.1307,std=0.3081 were copied from Pytorch Mist Example
@@ -35,7 +36,12 @@ impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {

         let targets = items
             .iter()
-            .map(|item| Tensor::<B, 1, Int>::from_data([(item.label as i64).elem()], &self.device))
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    [(item.label as i64).elem::<B::IntElem>()],
+                    &self.device,
+                )
+            })
             .collect();

         let images = Tensor::cat(images, 0);
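The batcher changes above track Burn's `main`-branch API: the old `Data<f32, 2>` type becomes `TensorData`, and `convert` now names its target element type explicitly. A condensed sketch of the per-item image conversion, assuming the same Burn imports and features as the guide (the `image_to_tensor` helper is hypothetical, not part of this PR):

```rust
use burn::{data::dataset::vision::MnistItem, prelude::*, tensor::TensorData};

/// Hypothetical helper showing the updated conversion for a single MNIST item.
fn image_to_tensor<B: Backend>(item: &MnistItem, device: &B::Device) -> Tensor<B, 3> {
    // `TensorData` replaces the old `Data<f32, 2>`, and the target element
    // type for `convert` is now spelled out as `B::FloatElem`.
    let data = TensorData::from(item.image).convert::<B::FloatElem>();
    Tensor::<B, 2>::from_data(data, device).reshape([1, 28, 28])
}
```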
4 changes: 2 additions & 2 deletions examples/guide/src/main.rs
@@ -7,13 +7,13 @@ mod training;

 use crate::{model::ModelConfig, training::TrainingConfig};
 use burn::{
-    backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
+    backend::{Autodiff, Wgpu},
     data::dataset::Dataset,
     optim::AdamConfig,
 };

 fn main() {
-    type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
+    type MyBackend = Wgpu<f32, i32>;
     type MyAutodiffBackend = Autodiff<MyBackend>;

     let args = Args::parse();
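On Burn's `main` branch the `Wgpu` backend no longer takes a graphics-API type parameter (the API is presumably selected automatically), so only the float and int element types remain. A small sketch of the resulting aliases; the `WgpuDevice` line is an assumption about how the guide obtains a device and is not part of this diff:

```rust
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};

fn main() {
    // Only the element types are specified now; AutoGraphicsApi is gone.
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    // Assumed device setup; the guide would pass this device along with
    // MyAutodiffBackend into its training entry point.
    let _device = WgpuDevice::default();
}
```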
82 changes: 59 additions & 23 deletions xtask/src/commands/check.rs
@@ -1,10 +1,17 @@
 use std::process::Command;

-use anyhow::{Ok, Result, anyhow};
+use anyhow::{anyhow, Ok, Result};
 use clap::{Args, Subcommand};
 use strum::{Display, EnumIter, EnumString, IntoEnumIterator};

-use crate::{endgroup, group, utils::{cargo::ensure_cargo_crate_is_installed, prompt::ask_once, workspace::{get_workspace_members, WorkspaceMemberType}}};
+use crate::{
+    endgroup, group,
+    utils::{
+        cargo::ensure_cargo_crate_is_installed,
+        prompt::ask_once,
+        workspace::{get_workspace_members, WorkspaceMemberType},
+    },
+};

 use super::Target;

@@ -36,17 +43,21 @@ pub(crate) fn handle_command(args: CheckCmdArgs, answer: Option<bool>) -> anyhow
         CheckCommand::Format => run_format(&args.target, answer),
         CheckCommand::Lint => run_lint(&args.target, answer),
         CheckCommand::All => {
-            let answer = ask_once("This will run all the checks with autofix on all members of the workspace.");
+            let answer = ask_once(
+                "This will run all the checks with autofix on all members of the workspace.",
+            );
             CheckCommand::iter()
                 .filter(|c| *c != CheckCommand::All)
-                .try_for_each(|c| handle_command(
-                    CheckCmdArgs {
-                        command: c,
-                        target: args.target.clone()
-                    },
-                    Some(answer),
-                ))
-        },
+                .try_for_each(|c| {
+                    handle_command(
+                        CheckCmdArgs {
+                            command: c,
+                            target: args.target.clone(),
+                        },
+                        Some(answer),
+                    )
+                })
+        }
     }
 }

@@ -71,7 +82,7 @@ pub(crate) fn run_audit(target: &Target, mut answer: Option<bool>) -> anyhow::Re
                 }
                 endgroup!();
             }
-        },
+        }
         Target::All => {
             let answer = ask_once("This will run audit checks on all targets.");
             Target::iter()
@@ -94,7 +105,11 @@ fn run_format(target: &Target, mut answer: Option<bool>) -> Result<()> {
             if answer.is_none() {
                 answer = Some(ask_once(&format!(
                     "This will run format checks on all {} of the workspace.",
-                    if *target == Target::Crates { "crates" } else { "examples" }
+                    if *target == Target::Crates {
+                        "crates"
+                    } else {
+                        "examples"
+                    }
                 )));
             }

@@ -107,22 +122,27 @@ fn run_format(target: &Target, mut answer: Option<bool>) -> Result<()> {
                     .status()
                     .map_err(|e| anyhow!("Failed to execute cargo fmt: {}", e))?;
                 if !status.success() {
-                    return Err(anyhow!("Format check execution failed for {}", &member.name));
+                    return Err(anyhow!(
+                        "Format check execution failed for {}",
+                        &member.name
+                    ));
                 }
                 endgroup!();
             }
         }
-        },
+        }
         Target::All => {
             if answer.is_none() {
-                answer = Some(ask_once("This will run format check on all members of the workspace."));
+                answer = Some(ask_once(
+                    "This will run format check on all members of the workspace.",
+                ));
             }
             if answer.unwrap() {
                 Target::iter()
                     .filter(|t| *t != Target::All)
                     .try_for_each(|t| run_format(&t, answer))?;
             }
-        },
+        }
     }
     Ok(())
 }
@@ -139,16 +159,30 @@ fn run_lint(target: &Target, mut answer: Option<bool>) -> anyhow::Result<()> {
             if answer.is_none() {
                 answer = Some(ask_once(&format!(
                     "This will run lint fix on all {} of the workspace.",
-                    if *target == Target::Crates { "crates" } else { "examples" }
+                    if *target == Target::Crates {
+                        "crates"
+                    } else {
+                        "examples"
+                    }
                 )));
             }

             if answer.unwrap() {
                 for member in members {
                     group!("Lint: {}", member.name);
-                    info!("Command line: cargo clippy --no-deps --fix --allow-dirty -p {}", &member.name);
+                    info!(
+                        "Command line: cargo clippy --no-deps --fix --allow-dirty -p {}",
+                        &member.name
+                    );
                     let status = Command::new("cargo")
-                        .args(["clippy", "--no-deps", "--fix", "--allow-dirty", "-p", &member.name])
+                        .args([
+                            "clippy",
+                            "--no-deps",
+                            "--fix",
+                            "--allow-dirty",
+                            "-p",
+                            &member.name,
+                        ])
                         .status()
                         .map_err(|e| anyhow!("Failed to execute cargo clippy: {}", e))?;
                     if !status.success() {
@@ -157,17 +191,19 @@ fn run_lint(target: &Target, mut answer: Option<bool>) -> anyhow::Result<()> {
                     endgroup!();
                 }
             }
-        },
+        }
         Target::All => {
             if answer.is_none() {
-                answer = Some(ask_once("This will run lint fix on all members of the workspace."));
+                answer = Some(ask_once(
+                    "This will run lint fix on all members of the workspace.",
+                ));
             }
             if answer.unwrap() {
                 Target::iter()
                     .filter(|t| *t != Target::All)
                     .try_for_each(|t| run_lint(&t, answer))?;
             }
-        },
+        }
     }
     Ok(())
 }