Skip to content

Commit

Permalink
Fix a few problems with prod environment (#37)
Browse files Browse the repository at this point in the history
* few detected issues

* Removed env_logger which made logs not upload to heat

* fix endpoint for remote training

* Updated Burn and removed git patch for Burn

* Can now run for dev or prod.

Regenerated Cargo.lock to fix some vulnerabilities

Co-authored-by: Jonathan Richard <[email protected]>

---------

Co-authored-by: Jonathan Richard <[email protected]>
Co-authored-by: Jonathan Richard <[email protected]>
  • Loading branch information
3 people authored Jan 8, 2025
1 parent 876d253 commit eafdbfe
Show file tree
Hide file tree
Showing 17 changed files with 525 additions and 305 deletions.
412 changes: 297 additions & 115 deletions Cargo.lock

Large diffs are not rendered by default.

27 changes: 14 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"

members = [
"crates/*",
"examples/*",
"xtask",
]
members = ["crates/*", "examples/*", "xtask"]

[workspace.package]
edition = "2021"
Expand All @@ -17,22 +13,22 @@ readme = "README.md"
license = "MIT OR Apache-2.0"

[workspace.dependencies]
burn = { git = "https://github.com/tracel-ai/burn", version = "0.14.0", rev="a72a533" }
# burn = { git = "https://github.com/tracel-ai/burn", tag="v0.13.2", version = "*" }
burn = { version = "0.15.0" }

anyhow = "1.0.81"
clap = { version = "4.5.4", features = ["derive"] }
colored = "2.1.0"
derive-new = { version = "0.6.0", default-features = false }
derive_more = { version = "0.99.18", features = ["display"], default-features = false }
dotenv = "0.15.0"
derive_more = { version = "0.99.18", features = [
"display",
], default-features = false }
env_logger = "0.11.3"
log = "0.4.21"
once_cell = "1.19.0"
proc-macro2 = { version = "1.0.86" }
quote = "1.0.36"
rand = "0.8.5"
reqwest = "0.12.4"
reqwest = "0.12.9"
regex = "1.10.5"
rmp-serde = "1.3.0"
rstest = "0.19.0"
Expand All @@ -41,10 +37,15 @@ serde = { version = "1.0.204", default-features = false, features = [
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = "1.0.64"
strum = {version = "0.26.2", features = ["derive"]}
syn = { version = "2.0.71", features = ["extra-traits","full"] }
strum = { version = "0.26.2", features = ["derive"] }
syn = { version = "2.0.71", features = ["extra-traits", "full"] }
thiserror = "1.0.30"
uuid = { version = "1.9.1", features = ["v4","fast-rng","macro-diagnostics", "serde"] }
uuid = { version = "1.9.1", features = [
"v4",
"fast-rng",
"macro-diagnostics",
"serde",
] }

### For xtask crate ###
tracel-xtask = { version = "=1.1.8" }
Expand Down
32 changes: 29 additions & 3 deletions crates/heat-sdk-cli-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod name_value;

use name_value::get_name_value;
use proc_macro::TokenStream;
use quote::quote;

Expand Down Expand Up @@ -120,8 +123,30 @@ pub fn heat(args: TokenStream, item: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream {
let item = parse_macro_input!(item as ItemFn);

let module_path = parse_macro_input!(args as Path); // Parse the module path
let args: Punctuated<Meta, syn::token::Comma> =
parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);

let module_path = args
.first()
.expect("Should be able to get first arg.")
.path()
.clone();
let api_endpoint: Option<String> = get_name_value(&args, "api_endpoint");
let wss: Option<bool> = get_name_value(&args, "wss");

let mut config_block = quote! {
let mut config = tracel::heat::cli::config::Config::default();
};
if let Some(api_endpoint) = api_endpoint {
config_block.extend(quote! {
config.api_endpoint = #api_endpoint.to_string();
});
}
if let Some(wss) = wss {
config_block.extend(quote! {
config.wss = #wss;
});
}

let item_sig = &item.sig;
let item_block = &item.block;
Expand All @@ -147,7 +172,8 @@ pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream {
}

#item_sig {
tracel::heat::cli::cli::cli_main();
#config_block
tracel::heat::cli::cli::cli_main(config);
}
};

Expand Down
42 changes: 42 additions & 0 deletions crates/heat-sdk-cli-macros/src/name_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use syn::{punctuated::Punctuated, Expr, Meta};

pub trait LitMatcher<T> {
fn match_type(&self) -> T;
}

impl LitMatcher<String> for syn::Lit {
fn match_type(&self) -> String {
match self {
syn::Lit::Str(lit) => lit.value(),
_ => panic!("Expected a string literal"),
}
}
}

impl LitMatcher<bool> for syn::Lit {
fn match_type(&self) -> bool {
match self {
syn::Lit::Bool(lit) => lit.value,
_ => panic!("Expected a boolean literal"),
}
}
}

pub fn get_name_value<T>(args: &Punctuated<Meta, syn::token::Comma>, name: &str) -> Option<T>
where
syn::Lit: LitMatcher<T>,
{
args.iter()
.find(|meta| meta.path().is_ident(name))
.and_then(|meta| {
if let Meta::NameValue(meta) = meta {
if let Expr::Lit(lit) = &meta.value {
Some(lit.lit.match_type())
} else {
None
}
} else {
None
}
})
}
8 changes: 3 additions & 5 deletions crates/heat-sdk-cli/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use clap::{Parser, Subcommand};

use crate::commands::time::format_duration;
use crate::config::Config;
use crate::context::HeatCliContext;
use crate::{cli_commands, print_err, print_info};

Expand All @@ -27,7 +28,7 @@ pub enum Commands {
// Logout,
}

pub fn cli_main() {
pub fn cli_main(config: Config) {
print_info!("Running CLI");
let time_begin = std::time::Instant::now();
let args = CliArgs::try_parse();
Expand All @@ -36,10 +37,7 @@ pub fn cli_main() {
std::process::exit(1);
}

let user_project_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME not set");
let user_crate_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");

let context = HeatCliContext::new(user_project_name, user_crate_dir.into()).init();
let context = HeatCliContext::new(&config).init();

let cli_res = match args.unwrap().command {
Commands::Run(run_args) => cli_commands::run::handle_command(run_args, context),
Expand Down
31 changes: 0 additions & 31 deletions crates/heat-sdk-cli/src/cli_commands/run/local/local.rs

This file was deleted.

23 changes: 0 additions & 23 deletions crates/heat-sdk-cli/src/cli_commands/run/remote/remote.rs

This file was deleted.

18 changes: 8 additions & 10 deletions crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ pub struct RemoteTrainingRunArgs {
help = "<required> The Heat API key."
)]
key: String,
/// The Heat API endpoint
#[clap(
short = 'e',
long = "endpoint",
help = "The Heat API endpoint.",
default_value = "http://127.0.0.1:9001"
)]
pub heat_endpoint: String,
/// The runner group name
#[clap(
short = 'r',
Expand All @@ -55,13 +47,14 @@ pub struct RemoteTrainingRunArgs {
pub runner: String,
}

fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient {
fn create_heat_client(api_key: &str, url: &str, wss: bool, project_path: &str) -> HeatClient {
let creds = HeatCredentials::new(api_key.to_owned());
let client_config = HeatClientConfig::builder(
creds,
ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."),
)
.with_endpoint(url)
.with_wss(wss)
.with_num_retries(10)
.build();
HeatClient::create(client_config)
Expand All @@ -72,7 +65,12 @@ pub(crate) fn handle_command(
args: RemoteTrainingRunArgs,
context: HeatCliContext,
) -> anyhow::Result<()> {
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path);
let heat_client = create_heat_client(
&args.key,
context.get_api_endpoint().as_str(),
context.get_wss(),
&args.project_path,
);

let crates = crate::util::cargo::package::package(
&context.get_artifacts_dir_path(),
Expand Down
27 changes: 0 additions & 27 deletions crates/heat-sdk-cli/src/cli_commands/run/run.rs

This file was deleted.

14 changes: 14 additions & 0 deletions crates/heat-sdk-cli/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#[derive(Debug, Clone)]
pub struct Config {
pub api_endpoint: String,
pub wss: bool,
}

impl Default for Config {
fn default() -> Self {
Config {
api_endpoint: String::from("https://heat.tracel.ai/api/"),
wss: true,
}
}
}
Loading

0 comments on commit eafdbfe

Please sign in to comment.