Skip to content

Commit

Permalink
Can now run for dev or prod.
Browse files Browse the repository at this point in the history
Regenerated Cargo.lock to fix some vulnerabilities

Co-authored-by: Jonathan Richard <[email protected]>
  • Loading branch information
ThierryCantin-Demers and jwric committed Nov 29, 2024
1 parent a44e5fb commit f359e79
Show file tree
Hide file tree
Showing 14 changed files with 757 additions and 470 deletions.
922 changes: 543 additions & 379 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ license = "MIT OR Apache-2.0"

[workspace.dependencies]
burn = { version = "0.15.0" }
# burn = { git = "https://github.com/tracel-ai/burn", tag="v0.13.2", version = "*" }

anyhow = "1.0.81"
clap = { version = "4.5.4", features = ["derive"] }
Expand All @@ -23,14 +22,13 @@ derive-new = { version = "0.6.0", default-features = false }
derive_more = { version = "0.99.18", features = [
"display",
], default-features = false }
dotenv = "0.15.0"
env_logger = "0.11.3"
log = "0.4.21"
once_cell = "1.19.0"
proc-macro2 = { version = "1.0.86" }
quote = "1.0.36"
rand = "0.8.5"
reqwest = "0.12.4"
reqwest = "0.12.9"
regex = "1.10.5"
rmp-serde = "1.3.0"
rstest = "0.19.0"
Expand Down
32 changes: 29 additions & 3 deletions crates/heat-sdk-cli-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod name_value;

use name_value::get_name_value;
use proc_macro::TokenStream;
use quote::quote;

Expand Down Expand Up @@ -120,8 +123,30 @@ pub fn heat(args: TokenStream, item: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream {
let item = parse_macro_input!(item as ItemFn);

let module_path = parse_macro_input!(args as Path); // Parse the module path
let args: Punctuated<Meta, syn::token::Comma> =
parse_macro_input!(args with Punctuated::<Meta, syn::Token![,]>::parse_terminated);

let module_path = args
.first()
.expect("Should be able to get first arg.")
.path()
.clone();
let api_endpoint: Option<String> = get_name_value(&args, "api_endpoint");
let wss: Option<bool> = get_name_value(&args, "wss");

let mut config_block = quote! {
let mut config = tracel::heat::cli::config::Config::default();
};
if let Some(api_endpoint) = api_endpoint {
config_block.extend(quote! {
config.api_endpoint = #api_endpoint.to_string();
});
}
if let Some(wss) = wss {
config_block.extend(quote! {
config.wss = #wss;
});
}

let item_sig = &item.sig;
let item_block = &item.block;
Expand All @@ -147,7 +172,8 @@ pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream {
}

#item_sig {
tracel::heat::cli::cli::cli_main();
#config_block
tracel::heat::cli::cli::cli_main(config);
}
};

Expand Down
42 changes: 42 additions & 0 deletions crates/heat-sdk-cli-macros/src/name_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use syn::{punctuated::Punctuated, Expr, Meta};

pub trait LitMatcher<T> {
fn match_type(&self) -> T;
}

impl LitMatcher<String> for syn::Lit {
fn match_type(&self) -> String {
match self {
syn::Lit::Str(lit) => lit.value(),
_ => panic!("Expected a string literal"),
}
}
}

impl LitMatcher<bool> for syn::Lit {
fn match_type(&self) -> bool {
match self {
syn::Lit::Bool(lit) => lit.value,
_ => panic!("Expected a boolean literal"),
}
}
}

pub fn get_name_value<T>(args: &Punctuated<Meta, syn::token::Comma>, name: &str) -> Option<T>
where
syn::Lit: LitMatcher<T>,
{
args.iter()
.find(|meta| meta.path().is_ident(name))
.and_then(|meta| {
if let Meta::NameValue(meta) = meta {
if let Expr::Lit(lit) = &meta.value {
Some(lit.lit.match_type())
} else {
None
}
} else {
None
}
})
}
8 changes: 3 additions & 5 deletions crates/heat-sdk-cli/src/cli.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use clap::{Parser, Subcommand};

use crate::commands::time::format_duration;
use crate::config::Config;
use crate::context::HeatCliContext;
use crate::{cli_commands, print_err, print_info};

Expand All @@ -27,7 +28,7 @@ pub enum Commands {
// Logout,
}

pub fn cli_main() {
pub fn cli_main(config: Config) {
print_info!("Running CLI");
let time_begin = std::time::Instant::now();
let args = CliArgs::try_parse();
Expand All @@ -36,10 +37,7 @@ pub fn cli_main() {
std::process::exit(1);
}

let user_project_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME not set");
let user_crate_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set");

let context = HeatCliContext::new(user_project_name, user_crate_dir.into()).init();
let context = HeatCliContext::new(&config).init();

let cli_res = match args.unwrap().command {
Commands::Run(run_args) => cli_commands::run::handle_command(run_args, context),
Expand Down
18 changes: 8 additions & 10 deletions crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,6 @@ pub struct RemoteTrainingRunArgs {
help = "<required> The Heat API key."
)]
key: String,
/// The Heat API endpoint
#[clap(
short = 'e',
long = "endpoint",
help = "The Heat API endpoint.",
default_value = "https://heat.tracel.ai/api"
)]
pub heat_endpoint: String,
/// The runner group name
#[clap(
short = 'r',
Expand All @@ -55,13 +47,14 @@ pub struct RemoteTrainingRunArgs {
pub runner: String,
}

fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient {
fn create_heat_client(api_key: &str, url: &str, wss: bool, project_path: &str) -> HeatClient {
let creds = HeatCredentials::new(api_key.to_owned());
let client_config = HeatClientConfig::builder(
creds,
ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."),
)
.with_endpoint(url)
.with_wss(wss)
.with_num_retries(10)
.build();
HeatClient::create(client_config)
Expand All @@ -72,7 +65,12 @@ pub(crate) fn handle_command(
args: RemoteTrainingRunArgs,
context: HeatCliContext,
) -> anyhow::Result<()> {
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path);
let heat_client = create_heat_client(
&args.key,
context.get_api_endpoint().as_str(),
context.get_wss(),
&args.project_path,
);

let crates = crate::util::cargo::package::package(
&context.get_artifacts_dir_path(),
Expand Down
14 changes: 14 additions & 0 deletions crates/heat-sdk-cli/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#[derive(Debug, Clone)]
pub struct Config {
pub api_endpoint: String,
pub wss: bool,
}

impl Default for Config {
fn default() -> Self {
Config {
api_endpoint: String::from("https://heat.tracel.ai/api/"),
wss: true,
}
}
}
25 changes: 24 additions & 1 deletion crates/heat-sdk-cli/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
commands::{BuildCommand, RunCommand, RunParams},
config::Config,
generation::{FileTree, GeneratedCrate, HeatDir},
print_info,
};
Expand All @@ -11,10 +12,17 @@ pub struct HeatCliContext {
generated_crate_name: Option<String>,
build_profile: String,
heat_dir: HeatDir,
api_endpoint: url::Url,
wss: bool,
}

impl HeatCliContext {
pub fn new(user_project_name: String, user_crate_dir: PathBuf) -> Self {
pub fn new(config: &Config) -> Self {
let user_project_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME not set");
let user_crate_dir: PathBuf = std::env::var("CARGO_MANIFEST_DIR")
.expect("CARGO_MANIFEST_DIR not set")
.into();

let heat_dir = match HeatDir::try_from_path(&user_crate_dir) {
Ok(heat_dir) => heat_dir,
Err(_) => HeatDir::new(),
Expand All @@ -26,6 +34,11 @@ impl HeatCliContext {
generated_crate_name: None,
build_profile: "release".to_string(),
heat_dir,
api_endpoint: config
.api_endpoint
.parse::<url::Url>()
.expect("API endpoint should be valid"),
wss: config.wss,
}
}

Expand All @@ -38,6 +51,14 @@ impl HeatCliContext {
self.user_project_name.as_str()
}

pub fn get_api_endpoint(&self) -> &url::Url {
&self.api_endpoint
}

pub fn get_wss(&self) -> bool {
self.wss
}

fn get_generated_crate_path(&self) -> PathBuf {
let crate_name = self
.generated_crate_name
Expand Down Expand Up @@ -100,6 +121,8 @@ impl HeatCliContext {
.env("HEAT_PROJECT_DIR", &self.user_crate_dir)
.args(["--project", project])
.args(["--key", key])
.args(["--heat-endpoint", self.get_api_endpoint().as_str()])
.args(["--wss", self.get_wss().to_string().as_str()])
.args(["train", function, config_path]);
command
}
Expand Down
21 changes: 15 additions & 6 deletions crates/heat-sdk-cli/src/generation/crate_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,12 @@ fn generate_clap_cli() -> proc_macro2::TokenStream {
.short('e')
.long("heat-endpoint")
.help("The Heat endpoint")
.default_value("https://heat.tracel.ai/api"),
.required(true),
clap::Arg::new("wss")
.short('w')
.long("wss")
.help("Whether to use WSS")
.required(true),
]);

command
Expand All @@ -248,7 +253,6 @@ fn generate_training_function(
train_func_match: &proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
quote! {
let client = create_heat_client(&key, &heat_endpoint, &project);
let training_config_str = std::fs::read_to_string(&config_path).expect("Config should be read");
let training_config: serde_json::Value = serde_json::from_str(&training_config_str).expect("Config should be deserialized");

Expand Down Expand Up @@ -337,10 +341,11 @@ fn generate_main_rs(main_backend: &BackendType) -> String {
use tracel::heat::command::train::*;
use burn::prelude::*;

fn create_heat_client(api_key: &str, url: &str, project: &str) -> tracel::heat::client::HeatClient {
fn create_heat_client(api_key: &str, url: &str, project: &str, wss: bool) -> tracel::heat::client::HeatClient {
let creds = tracel::heat::client::HeatCredentials::new(api_key.to_owned());
let client_config = tracel::heat::client::HeatClientConfig::builder(creds, tracel::heat::schemas::ProjectPath::try_from(project.to_string()).expect("Project path should be valid."))
.with_endpoint(url)
.with_wss(wss)
.with_num_retries(10)
.build();
tracel::heat::client::HeatClient::create(client_config)
Expand All @@ -352,12 +357,16 @@ fn generate_main_rs(main_backend: &BackendType) -> String {

let device = #backend_default_device;

let key = matches.get_one::<String>("key").expect("key should be set.");
let heat_endpoint = matches.get_one::<String>("heat-endpoint").expect("heat-endpoint should be set.");
let project = matches.get_one::<String>("project").expect("project should be set.");
let wss = matches.get_one::<String>("wss").expect("wss should be set.").parse::<bool>().expect("wss should be a boolean.");

let client = create_heat_client(&key, &heat_endpoint, &project, wss);

if let Some(train_matches) = matches.subcommand_matches("train") {
let func = train_matches.get_one::<String>("func").expect("func should be set.");
let config_path = train_matches.get_one::<String>("config").expect("config should be set.");
let project = matches.get_one::<String>("project").expect("project should be set.");
let key = matches.get_one::<String>("key").expect("key should be set.");
let heat_endpoint = matches.get_one::<String>("heat-endpoint").expect("heat-endpoint should be set.");

#generated_training
}
Expand Down
1 change: 1 addition & 0 deletions crates/heat-sdk-cli/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod cli;
pub mod config;
pub mod registry;

mod cli_commands;
Expand Down
2 changes: 1 addition & 1 deletion crates/heat-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ thiserror = { workspace = true }
tracing = { version = "0.1.40" }
tracing-core = { version = "0.1.32" }
tracing-subscriber = { version = "0.3.18" }
tungstenite = { version = "0.21.0", features = ["native-tls"] }
tungstenite = { version = "0.24.0", features = ["native-tls"] }
uuid = { workspace = true }
regex = { workspace = true }
once_cell = { workspace = true }
16 changes: 15 additions & 1 deletion crates/heat-sdk/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ impl From<HeatCredentials> for String {
pub struct HeatClientConfig {
/// The endpoint of the Heat API
pub endpoint: String,
/// Whether to use a secure WebSocket connection
pub wss: bool,
/// Heat credential to create a session with the Heat API
pub credentials: HeatCredentials,
/// The number of retries to attempt when connecting to the Heat API.
Expand Down Expand Up @@ -70,6 +72,7 @@ impl HeatClientConfigBuilder {
HeatClientConfigBuilder {
config: HeatClientConfig {
endpoint: "http://127.0.0.1:9001".into(),
wss: false,
credentials: creds,
num_retries: 3,
retry_interval: 3,
Expand All @@ -84,6 +87,13 @@ impl HeatClientConfigBuilder {
self
}

/// Set whether to use a secure WebSocket connection
/// If this is set to true, the WebSocket connection will use the `wss` protocol instead of `ws`.
pub fn with_wss(mut self, wss: bool) -> HeatClientConfigBuilder {
self.config.wss = wss;
self
}

/// Set the number of retries to attempt when connecting to the Heat API
pub fn with_num_retries(mut self, num_retries: u8) -> HeatClientConfigBuilder {
self.config.num_retries = num_retries;
Expand Down Expand Up @@ -115,7 +125,11 @@ pub type HeatClientState = HeatClient;

impl HeatClient {
fn new(config: HeatClientConfig) -> HeatClient {
let http_client = HttpClient::new(config.endpoint.clone());
let url = config
.endpoint
.parse()
.expect("Should be able to parse the URL");
let http_client = HttpClient::new(url, config.wss);

HeatClient {
config,
Expand Down
Loading

0 comments on commit f359e79

Please sign in to comment.