From eafdbfe8d0c7379f79b9464424989ba04bd4c7a3 Mon Sep 17 00:00:00 2001 From: Thierry Cantin-Demers <48504765+ThierryCantin-Demers@users.noreply.github.com> Date: Wed, 8 Jan 2025 09:25:06 -0500 Subject: [PATCH] Fix a few problems with prod environment (#37) * few detected issues * Removed env_logger which made logs not upload to heat * fix endpoint for remote training * Updated Burn and removed git patch for Burn * Can now run for dev or prod. Regenerated Cargo.lock to fix some vulnerabilities Co-authored-by: Jonathan Richard --------- Co-authored-by: Jonathan Richard Co-authored-by: Jonathan Richard --- Cargo.lock | 412 +++++++++++++----- Cargo.toml | 27 +- crates/heat-sdk-cli-macros/src/lib.rs | 32 +- crates/heat-sdk-cli-macros/src/name_value.rs | 42 ++ crates/heat-sdk-cli/src/cli.rs | 8 +- .../src/cli_commands/run/local/local.rs | 31 -- .../src/cli_commands/run/remote/remote.rs | 23 - .../src/cli_commands/run/remote/training.rs | 18 +- .../heat-sdk-cli/src/cli_commands/run/run.rs | 27 -- crates/heat-sdk-cli/src/config.rs | 14 + crates/heat-sdk-cli/src/context.rs | 34 +- .../src/generation/crate_gen/mod.rs | 21 +- crates/heat-sdk-cli/src/lib.rs | 1 + crates/heat-sdk/Cargo.toml | 2 +- crates/heat-sdk/src/client.rs | 16 +- crates/heat-sdk/src/http/client.rs | 118 ++--- examples/guide-cli/src/main.rs | 4 +- 17 files changed, 525 insertions(+), 305 deletions(-) create mode 100644 crates/heat-sdk-cli-macros/src/name_value.rs delete mode 100644 crates/heat-sdk-cli/src/cli_commands/run/local/local.rs delete mode 100644 crates/heat-sdk-cli/src/cli_commands/run/remote/remote.rs delete mode 100644 crates/heat-sdk-cli/src/cli_commands/run/run.rs create mode 100644 crates/heat-sdk-cli/src/config.rs diff --git a/Cargo.lock b/Cargo.lock index 081c3c9..3b0921b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -178,12 +178,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-lock" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_float" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" + [[package]] name = "autocfg" version = "1.4.0" @@ -328,8 +345,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09d130fc29965cae23afcc594423e55d977fb142cee37e70034209745bb515e2" dependencies = [ "burn-core", "burn-train", @@ -337,8 +355,9 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb84f4c9c9e5e90bfdde7cf18d2f7684f60079629be80672d30add20df0d790f" dependencies = [ "burn-common", "burn-tensor", @@ -349,8 +368,9 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c8be4878fcd5f166755cf7f828edb6f98a120a41a3dca08f2b7a4195c27b78f" dependencies = [ "burn-tensor", "candle-core", @@ -360,11 +380,11 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b812952386979b756241df0f1f3a0bdd5d2969dc39aedc028b0873c33df5e076" dependencies = [ "cubecl-common", - "data-encoding", "getrandom", "indicatif", "rayon", @@ -375,36 +395,60 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32dfd86c6615420077a089b01d83d9732f16e6dc692a1ad8184d5d46b69691cb" dependencies = [ + "ahash", "bincode", "burn-autodiff", "burn-candle", "burn-common", + "burn-cuda", "burn-dataset", "burn-derive", + "burn-hip", "burn-ndarray", "burn-tch", "burn-tensor", "burn-wgpu", + "data-encoding", "derive-new", "flate2", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "log", "num-traits", + "portable-atomic-util", "rand", "rmp-serde", "serde", "serde_json", "spin", + "uuid", +] + +[[package]] +name = "burn-cuda" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8a396e355968cfff8ec00f4ad225ccc3bbfc62f99036722e54af68066b0cd5" +dependencies = [ + "burn-fusion", + "burn-jit", + "burn-tensor", + "bytemuck", + "cubecl", + "derive-new", + "half", + "log", ] [[package]] name = "burn-dataset" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c6dc764e250fdccda1323a045881d369af0add9756f2ca95b2df3fc63810b8e" dependencies = [ "burn-common", "csv", @@ -431,8 +475,9 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "893ec21275e5b0fceca831e6b751ab0df82fcb8ac4a987fe5d704639d23b1e0a" dependencies = [ "derive-new", "proc-macro2", @@ -442,22 +487,41 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3de03914bb49c972f138a2ad6f6d92896a0e8a537a50aca6463fc55bb7bf8d9" dependencies = [ "burn-common", "burn-tensor", "derive-new", - "hashbrown 0.14.5", + "half", + "hashbrown 0.15.2", "log", "serde", "spin", ] +[[package]] +name = "burn-hip" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28af539a8dbf799747b00e1c28f45a876673de3ff5818f9e7f1e47c0806dd795" +dependencies = [ + "burn-fusion", + "burn-jit", + "burn-tensor", + "bytemuck", + "cubecl", + "derive-new", + "half", + "log", +] + [[package]] name = "burn-jit" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d8af3155e5004a67ee0b378854912374c4b61e2d8d57fcd35794953dff336fc" dependencies = [ "burn-common", "burn-fusion", @@ -465,8 +529,9 @@ dependencies = [ "bytemuck", "cubecl", "derive-new", + "futures-lite", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "log", "num-traits", "rand", @@ -477,45 +542,53 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df139ffaf1f3b5805dcc5e531e7fa3119ad52dcc5b42883661fbf0a694cd82b" dependencies = [ + "atomic_float", "burn-autodiff", "burn-common", "burn-tensor", "derive-new", "libm", "matrixmultiply", - "ndarray", + "ndarray 0.16.1", "num-traits", + "portable-atomic-util", "rand", "spin", ] [[package]] name = "burn-tch" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4dcb26f4a82f4f2a735d2390e390c499e84e1d5707eeabeefdbb8e160c870a7" dependencies = [ "burn-tensor", "half", "libc", + "log", "rand", "tch", ] [[package]] name = "burn-tensor" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb51351b03b2ab2a415c4de924df0409922c75268f9294793e38ae783754013c" dependencies = [ "burn-common", "bytemuck", + "colored", "cubecl", "derive-new", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "num-traits", + "portable-atomic-util", "rand", "rand_distr", "serde", @@ -524,11 +597,11 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e25321e3fb5633335217a0efe84ac2441c70bb61b4cdf66cc41f433ab38bef" dependencies = [ "burn-core", - "crossterm", "derive-new", "log", "nvml-wrapper", @@ -543,8 +616,9 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.14.0" -source = "git+https://github.com/tracel-ai/burn?rev=a72a533#a72a533855ce45f519d0c234474073a100e65af3" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd433918264c8233d104eee59294d13be2942740d369afa2284e08b143927c84" dependencies = [ "burn-fusion", "burn-jit", @@ -734,6 +808,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "cipher" version = "0.4.4" @@ -855,13 +935,14 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.7.1" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" dependencies = [ "castaway", "cfg-if", "itoa", + "rustversion", "ryu", "static_assertions", ] @@ -975,15 +1056,15 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crossterm" -version = "0.27.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ "bitflags 2.6.0", "crossterm_winapi", - "libc", - "mio 0.8.11", + "mio", "parking_lot", + "rustix", "signal-hook", "signal-hook-mio", "winapi", @@ -1037,23 +1118,30 @@ dependencies = [ [[package]] name = "cubecl" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e75c7e982b943380665c5901fe0b69d5df2627644e0e50199c52b64d8d5a1c" dependencies = [ "cubecl-core", "cubecl-cuda", + "cubecl-hip", "cubecl-linalg", + "cubecl-runtime", "cubecl-wgpu", ] [[package]] name = "cubecl-common" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4" dependencies = [ "derive-new", + "embassy-futures", + "futures-lite", "getrandom", - "pollster", + "log", + "portable-atomic", "rand", "serde", "spin", @@ -1062,27 +1150,47 @@ dependencies = [ [[package]] name = "cubecl-core" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec33b64139d1dfc747df8aed5834d10c3c55c716f5219041c6eb17241c96c929" dependencies = [ "bytemuck", + "cubecl-common", "cubecl-macros", "cubecl-runtime", "derive-new", "half", "log", "num-traits", + "paste", "serde", ] +[[package]] +name = "cubecl-cpp" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ded461feb0ff342a4f675131dc0ae8ad94e58f66bad11e57f852cb7f190a731" +dependencies = [ + "bytemuck", + "cubecl-common", + "cubecl-core", + "cubecl-runtime", + "derive-new", + "half", + "log", +] + [[package]] name = "cubecl-cuda" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88dfdfe616124d2abe5e82052ff56f86843c369440e181d6936f7409e161dd82" dependencies = [ "bytemuck", "cubecl-common", "cubecl-core", + "cubecl-cpp", "cubecl-runtime", "cudarc", "derive-new", @@ -1090,10 +1198,37 @@ dependencies = [ "log", ] +[[package]] +name = "cubecl-hip" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "409e0e176152ab51a60bbebb940b7a72aba210cd42b5f8cd2e87e7d7e674a13a" +dependencies = [ + "bytemuck", + "cubecl-common", + "cubecl-core", + "cubecl-cpp", + "cubecl-hip-sys", + "cubecl-runtime", + "derive-new", + "half", + "log", +] + +[[package]] +name = "cubecl-hip-sys" +version = "0.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2553766b483a28dd7db67cc4be9c61a7aa8cc7f02b3b8059ffdaeea1d8c8590e" +dependencies = [ + "libc", +] + [[package]] name = "cubecl-linalg" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5634782d790e9b6562fc267ffd15e9a510b4d6ec32c144cd2b166af2ba0cfb" dependencies = [ "bytemuck", "cubecl-core", @@ -1103,10 +1238,15 @@ dependencies = [ [[package]] name = "cubecl-macros" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2d22663257d9cdbcd67f5048d6f4e6eb965dd87104c3a173a7b0ea0d720e99b" dependencies = [ + "cubecl-common", + "darling", "derive-new", + "ident_case", + "prettyplease", "proc-macro2", "quote", "syn 2.0.90", @@ -1114,45 +1254,50 @@ dependencies = [ [[package]] name = "cubecl-runtime" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3468467f412dff4bbf97fb5061a3557445f017299e2fb73ef7b96c6cdb799bc3" dependencies = [ "async-channel", + "async-lock", + "cfg_aliases 0.2.1", "cubecl-common", "derive-new", "dirs", "hashbrown 0.14.5", "log", "md5", - "pollster", + "sanitize-filename", "serde", "serde_json", "spin", - "web-time", + "wasm-bindgen-futures", ] [[package]] name = "cubecl-wgpu" -version = "0.1.1" -source = "git+https://github.com/tracel-ai/cubecl?tag=v0.1.1#2b95a9e245bf4362b497866ee24bec399d1c74fb" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6779f1072d70923758421c6214fd0cd19a6f25b91035a522f9cd9407d03b5cae" dependencies = [ "async-channel", "bytemuck", + "cfg_aliases 0.2.1", "cubecl-common", "cubecl-core", "cubecl-runtime", "derive-new", "hashbrown 0.14.5", "log", - "pollster", + "web-time", "wgpu", ] [[package]] name = "cudarc" -version = "0.11.5" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e395cd01168d63af826749573071f3c5069b338ae473cab355d22db0b2bb5a0d" +checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" dependencies = [ "libloading", ] @@ -1339,6 +1484,12 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "embassy-futures" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f878075b9794c1e4ac788c95b728f26aa6366d32eeb10c7051389f898f7d067" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -1618,6 +1769,19 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-lite" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.31" @@ -2663,7 +2827,6 @@ checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", "allocator-api2", - "serde", ] [[package]] @@ -2675,6 +2838,7 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash", + "serde", ] [[package]] @@ -3153,6 +3317,12 @@ dependencies = [ "web-time", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "inout" version = "0.1.3" @@ -3162,6 +3332,19 @@ dependencies = [ "generic-array", ] +[[package]] +name = "instability" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "894813a444908c0c8c0e221b041771d107c4a21de1d317dc49bcc66e3c9e5b3f" +dependencies = [ + "darling", + "indoc", + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "interpolate_name" version = "0.2.4" @@ -3360,9 +3543,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.28.0" +version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ "cc", "pkg-config", @@ -3532,18 +3715,6 @@ dependencies = [ "simd-adler32", ] -[[package]] -name = "mio" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" -dependencies = [ - "libc", - "log", - "wasi", - "windows-sys 0.48.0", -] - [[package]] name = "mio" version = "1.0.3" @@ -3551,6 +3722,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", + "log", "wasi", "windows-sys 0.52.0", ] @@ -3564,7 +3736,7 @@ dependencies = [ "arrayvec", "bit-set", "bitflags 2.6.0", - "cfg_aliases", + "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", "indexmap", @@ -3604,6 +3776,21 @@ dependencies = [ "num-integer", "num-traits", "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", "rayon", ] @@ -3983,18 +4170,21 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "pollster" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22686f4785f02a4fcc856d3b3bb19bf6c8160d103f7a99cc258bddd0251dc7f2" - [[package]] name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -4118,9 +4308,9 @@ dependencies = [ [[package]] name = "r2d2_sqlite" -version = "0.24.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a982edf65c129796dba72f8775b292ef482b40d035e827a9825b3bc07ccc5f2" +checksum = "eb14dba8247a6a15b7fdbc7d389e2e6f03ee9f184f87117706d509c092dfe846" dependencies = [ "r2d2", "rusqlite", @@ -4175,23 +4365,24 @@ checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" [[package]] name = "ratatui" -version = "0.26.3" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f44c9e68fd46eda15c646fbb85e1040b657a58cdc8c98db1d97a55930d991eef" +checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ "bitflags 2.6.0", "cassowary", "compact_str", "crossterm", - "itertools 0.12.1", + "indoc", + "instability", + "itertools 0.13.0", "lru", "paste", - "stability", "strum", "time", "unicode-segmentation", "unicode-truncate", - "unicode-width 0.1.14", + "unicode-width 0.2.0", ] [[package]] @@ -4485,9 +4676,9 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.31.0" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ "bitflags 2.6.0", "fallible-iterator", @@ -4758,9 +4949,9 @@ dependencies = [ [[package]] name = "serde_rusqlite" -version = "0.35.0" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836d903d9524cecbcd7b75745b6ee0e3f3774b878ea489dfaf2ea749463283d6" +checksum = "b741cc5ef185cd96157e762c3bba743c4e94c8dc6af0edb053c48d2b3c27e691" dependencies = [ "rusqlite", "serde", @@ -4853,7 +5044,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" dependencies = [ "libc", - "mio 0.8.11", + "mio", "signal-hook", ] @@ -4922,6 +5113,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" dependencies = [ "lock_api", + "portable-atomic", ] [[package]] @@ -4933,16 +5125,6 @@ dependencies = [ "bitflags 2.6.0", ] -[[package]] -name = "stability" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d904e7009df136af5297832a3ace3370cd14ff1546a232f4f185036c2736fcac" -dependencies = [ - "quote", - "syn 2.0.90", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -5146,7 +5328,7 @@ dependencies = [ "half", "lazy_static", "libc", - "ndarray", + "ndarray 0.15.6", "rand", "safetensors 0.3.3", "thiserror 1.0.69", @@ -5324,7 +5506,7 @@ dependencies = [ "backtrace", "bytes", "libc", - "mio 1.0.3", + "mio", "pin-project-lite", "socket2", "tokio-macros", @@ -5551,9 +5733,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.21.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" dependencies = [ "byteorder", "bytes", @@ -5561,10 +5743,10 @@ dependencies = [ "http", "httparse", "log", + "native-tls", "rand", "sha1", "thiserror 1.0.69", - "url", "utf-8", ] @@ -5894,7 +6076,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d1c4ba43f80542cf63a0a6ed3134629ae73e8ab51e4b765a67f3aa062eb433" dependencies = [ "arrayvec", - "cfg_aliases", + "cfg_aliases 0.1.1", "document-features", "js-sys", "log", @@ -5921,7 +6103,7 @@ dependencies = [ "arrayvec", "bit-vec", "bitflags 2.6.0", - "cfg_aliases", + "cfg_aliases 0.1.1", "document-features", "indexmap", "log", @@ -5949,7 +6131,7 @@ dependencies = [ "bit-set", "bitflags 2.6.0", "block", - "cfg_aliases", + "cfg_aliases 0.1.1", "core-graphics-types", "d3d12", "glow", diff --git a/Cargo.toml b/Cargo.toml index 2b51d8b..75938b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,11 +4,7 @@ # https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 resolver = "2" -members = [ - "crates/*", - "examples/*", - "xtask", -] +members = ["crates/*", "examples/*", "xtask"] [workspace.package] edition = "2021" @@ -17,22 +13,22 @@ readme = "README.md" license = "MIT OR Apache-2.0" [workspace.dependencies] -burn = { git = "https://github.com/tracel-ai/burn", version = "0.14.0", rev="a72a533" } -# burn = { git = "https://github.com/tracel-ai/burn", tag="v0.13.2", version = "*" } +burn = { version = "0.15.0" } anyhow = "1.0.81" clap = { version = "4.5.4", features = ["derive"] } colored = "2.1.0" derive-new = { version = "0.6.0", default-features = false } -derive_more = { version = "0.99.18", features = ["display"], default-features = false } -dotenv = "0.15.0" +derive_more = { version = "0.99.18", features = [ + "display", +], default-features = false } env_logger = "0.11.3" log = "0.4.21" once_cell = "1.19.0" proc-macro2 = { version = "1.0.86" } quote = "1.0.36" rand = "0.8.5" -reqwest = "0.12.4" +reqwest = "0.12.9" regex = "1.10.5" rmp-serde = "1.3.0" rstest = "0.19.0" @@ -41,10 +37,15 @@ serde = { version = "1.0.204", default-features = false, features = [ "alloc", ] } # alloc is for no_std, derive is needed serde_json = "1.0.64" -strum = {version = "0.26.2", features = ["derive"]} -syn = { version = "2.0.71", features = ["extra-traits","full"] } +strum = { version = "0.26.2", features = ["derive"] } +syn = { version = "2.0.71", features = ["extra-traits", "full"] } thiserror = "1.0.30" -uuid = { version = "1.9.1", features = ["v4","fast-rng","macro-diagnostics", "serde"] } +uuid = { version = "1.9.1", features = [ + "v4", + "fast-rng", + "macro-diagnostics", + "serde", +] } ### For xtask crate ### tracel-xtask = { version = "=1.1.8" } diff --git a/crates/heat-sdk-cli-macros/src/lib.rs b/crates/heat-sdk-cli-macros/src/lib.rs index 0071e47..ac83172 100644 --- a/crates/heat-sdk-cli-macros/src/lib.rs +++ b/crates/heat-sdk-cli-macros/src/lib.rs @@ -1,3 +1,6 @@ +mod name_value; + +use name_value::get_name_value; use proc_macro::TokenStream; use quote::quote; @@ -120,8 +123,30 @@ pub fn heat(args: TokenStream, item: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream { let item = parse_macro_input!(item as ItemFn); - - let module_path = parse_macro_input!(args as Path); // Parse the module path + let args: Punctuated = + parse_macro_input!(args with Punctuated::::parse_terminated); + + let module_path = args + .first() + .expect("Should be able to get first arg.") + .path() + .clone(); + let api_endpoint: Option = get_name_value(&args, "api_endpoint"); + let wss: Option = get_name_value(&args, "wss"); + + let mut config_block = quote! { + let mut config = tracel::heat::cli::config::Config::default(); + }; + if let Some(api_endpoint) = api_endpoint { + config_block.extend(quote! { + config.api_endpoint = #api_endpoint.to_string(); + }); + } + if let Some(wss) = wss { + config_block.extend(quote! { + config.wss = #wss; + }); + } let item_sig = &item.sig; let item_block = &item.block; @@ -147,7 +172,8 @@ pub fn heat_cli_main(args: TokenStream, item: TokenStream) -> TokenStream { } #item_sig { - tracel::heat::cli::cli::cli_main(); + #config_block + tracel::heat::cli::cli::cli_main(config); } }; diff --git a/crates/heat-sdk-cli-macros/src/name_value.rs b/crates/heat-sdk-cli-macros/src/name_value.rs new file mode 100644 index 0000000..147f0d0 --- /dev/null +++ b/crates/heat-sdk-cli-macros/src/name_value.rs @@ -0,0 +1,42 @@ +use syn::{punctuated::Punctuated, Expr, Meta}; + +pub trait LitMatcher { + fn match_type(&self) -> T; +} + +impl LitMatcher for syn::Lit { + fn match_type(&self) -> String { + match self { + syn::Lit::Str(lit) => lit.value(), + _ => panic!("Expected a string literal"), + } + } +} + +impl LitMatcher for syn::Lit { + fn match_type(&self) -> bool { + match self { + syn::Lit::Bool(lit) => lit.value, + _ => panic!("Expected a boolean literal"), + } + } +} + +pub fn get_name_value(args: &Punctuated, name: &str) -> Option +where + syn::Lit: LitMatcher, +{ + args.iter() + .find(|meta| meta.path().is_ident(name)) + .and_then(|meta| { + if let Meta::NameValue(meta) = meta { + if let Expr::Lit(lit) = &meta.value { + Some(lit.lit.match_type()) + } else { + None + } + } else { + None + } + }) +} diff --git a/crates/heat-sdk-cli/src/cli.rs b/crates/heat-sdk-cli/src/cli.rs index 1891b7b..998aecb 100644 --- a/crates/heat-sdk-cli/src/cli.rs +++ b/crates/heat-sdk-cli/src/cli.rs @@ -1,6 +1,7 @@ use clap::{Parser, Subcommand}; use crate::commands::time::format_duration; +use crate::config::Config; use crate::context::HeatCliContext; use crate::{cli_commands, print_err, print_info}; @@ -27,7 +28,7 @@ pub enum Commands { // Logout, } -pub fn cli_main() { +pub fn cli_main(config: Config) { print_info!("Running CLI"); let time_begin = std::time::Instant::now(); let args = CliArgs::try_parse(); @@ -36,10 +37,7 @@ pub fn cli_main() { std::process::exit(1); } - let user_project_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME not set"); - let user_crate_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - - let context = HeatCliContext::new(user_project_name, user_crate_dir.into()).init(); + let context = HeatCliContext::new(&config).init(); let cli_res = match args.unwrap().command { Commands::Run(run_args) => cli_commands::run::handle_command(run_args, context), diff --git a/crates/heat-sdk-cli/src/cli_commands/run/local/local.rs b/crates/heat-sdk-cli/src/cli_commands/run/local/local.rs deleted file mode 100644 index b791c25..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/local/local.rs +++ /dev/null @@ -1,31 +0,0 @@ -use clap::Parser; - -use crate::{ - cli_commands::local::{ - inference::{self, LocalInferenceRunArgs}, - training::{self, LocalTrainingRunArgs}, - }, - context::HeatCliContext, -}; - -/// Run a training or inference locally. -/// Only local training is supported at the moment. -#[derive(Parser, Debug)] -pub enum LocalRunSubcommand { - /// Run a training locally. - Training(LocalTrainingRunArgs), - /// Run an inference locally. - Inference(LocalInferenceRunArgs), -} - -pub(crate) fn handle_command( - args: LocalRunSubcommand, - context: HeatCliContext, -) -> anyhow::Result<()> { - match args { - LocalRunSubcommand::Training(training_args) => { - training::handle_command(training_args, context) - } - LocalRunSubcommand::Inference(inference_args) => inference::handle_command(inference_args), - } -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/remote/remote.rs b/crates/heat-sdk-cli/src/cli_commands/run/remote/remote.rs deleted file mode 100644 index 2260e61..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/remote/remote.rs +++ /dev/null @@ -1,23 +0,0 @@ -use clap::Parser; - -use crate::cli_commands::remote::{ - inference::{self, RemoteInferenceRunArgs}, - training::{self, RemoteTrainingRunArgs}, -}; - -/// Run a training or inference remotely. -/// Not yet supported. -#[derive(Parser, Debug)] -pub enum RemoteRunSubcommand { - /// todo - Training(RemoteTrainingRunArgs), - /// todo - Inference(RemoteInferenceRunArgs), -} - -pub(crate) fn handle_command(args: RemoteRunSubcommand) -> anyhow::Result<()> { - match args { - RemoteRunSubcommand::Training(training_args) => training::handle_command(training_args), - RemoteRunSubcommand::Inference(inference_args) => inference::handle_command(inference_args), - } -} diff --git a/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs b/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs index b2b1ebb..d0cddad 100644 --- a/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs +++ b/crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs @@ -37,14 +37,6 @@ pub struct RemoteTrainingRunArgs { help = " The Heat API key." )] key: String, - /// The Heat API endpoint - #[clap( - short = 'e', - long = "endpoint", - help = "The Heat API endpoint.", - default_value = "http://127.0.0.1:9001" - )] - pub heat_endpoint: String, /// The runner group name #[clap( short = 'r', @@ -55,13 +47,14 @@ pub struct RemoteTrainingRunArgs { pub runner: String, } -fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient { +fn create_heat_client(api_key: &str, url: &str, wss: bool, project_path: &str) -> HeatClient { let creds = HeatCredentials::new(api_key.to_owned()); let client_config = HeatClientConfig::builder( creds, ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."), ) .with_endpoint(url) + .with_wss(wss) .with_num_retries(10) .build(); HeatClient::create(client_config) @@ -72,7 +65,12 @@ pub(crate) fn handle_command( args: RemoteTrainingRunArgs, context: HeatCliContext, ) -> anyhow::Result<()> { - let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path); + let heat_client = create_heat_client( + &args.key, + context.get_api_endpoint().as_str(), + context.get_wss(), + &args.project_path, + ); let crates = crate::util::cargo::package::package( &context.get_artifacts_dir_path(), diff --git a/crates/heat-sdk-cli/src/cli_commands/run/run.rs b/crates/heat-sdk-cli/src/cli_commands/run/run.rs deleted file mode 100644 index 1413ef4..0000000 --- a/crates/heat-sdk-cli/src/cli_commands/run/run.rs +++ /dev/null @@ -1,27 +0,0 @@ -use clap::Parser; - -use crate::context::HeatCliContext; - -use crate::cli_commands::{ - local::{self, LocalRunSubcommand}, - remote::{self, RemoteRunSubcommand}, -}; - -/// Run a training or inference locally or trigger a remote run. -/// Only local training is supported at the moment. -#[derive(Parser, Debug)] -pub enum RunLocationType { - /// {training|inference} : Run a training or inference locally. - #[command(subcommand)] - Local(LocalRunSubcommand), - /// todo - #[command(subcommand)] - Remote(RemoteRunSubcommand), -} - -pub(crate) fn handle_command(args: RunLocationType, context: HeatCliContext) -> anyhow::Result<()> { - match args { - RunLocationType::Local(local_args) => local::handle_command(local_args, context), - RunLocationType::Remote(remote_args) => remote::handle_command(remote_args), - } -} diff --git a/crates/heat-sdk-cli/src/config.rs b/crates/heat-sdk-cli/src/config.rs new file mode 100644 index 0000000..e061e2f --- /dev/null +++ b/crates/heat-sdk-cli/src/config.rs @@ -0,0 +1,14 @@ +#[derive(Debug, Clone)] +pub struct Config { + pub api_endpoint: String, + pub wss: bool, +} + +impl Default for Config { + fn default() -> Self { + Config { + api_endpoint: String::from("https://heat.tracel.ai/api/"), + wss: true, + } + } +} diff --git a/crates/heat-sdk-cli/src/context.rs b/crates/heat-sdk-cli/src/context.rs index 03f6a06..a49a917 100644 --- a/crates/heat-sdk-cli/src/context.rs +++ b/crates/heat-sdk-cli/src/context.rs @@ -1,5 +1,6 @@ use crate::{ commands::{BuildCommand, RunCommand, RunParams}, + config::Config, generation::{FileTree, GeneratedCrate, HeatDir}, print_info, }; @@ -11,10 +12,17 @@ pub struct HeatCliContext { generated_crate_name: Option, build_profile: String, heat_dir: HeatDir, + api_endpoint: url::Url, + wss: bool, } impl HeatCliContext { - pub fn new(user_project_name: String, user_crate_dir: PathBuf) -> Self { + pub fn new(config: &Config) -> Self { + let user_project_name = std::env::var("CARGO_PKG_NAME").expect("CARGO_PKG_NAME not set"); + let user_crate_dir: PathBuf = std::env::var("CARGO_MANIFEST_DIR") + .expect("CARGO_MANIFEST_DIR not set") + .into(); + let heat_dir = match HeatDir::try_from_path(&user_crate_dir) { Ok(heat_dir) => heat_dir, Err(_) => HeatDir::new(), @@ -26,6 +34,11 @@ impl HeatCliContext { generated_crate_name: None, build_profile: "release".to_string(), heat_dir, + api_endpoint: config + .api_endpoint + .parse::() + .expect("API endpoint should be valid"), + wss: config.wss, } } @@ -38,6 +51,14 @@ impl HeatCliContext { self.user_project_name.as_str() } + pub fn get_api_endpoint(&self) -> &url::Url { + &self.api_endpoint + } + + pub fn get_wss(&self) -> bool { + self.wss + } + fn get_generated_crate_path(&self) -> PathBuf { let crate_name = self .generated_crate_name @@ -100,6 +121,8 @@ impl HeatCliContext { .env("HEAT_PROJECT_DIR", &self.user_crate_dir) .args(["--project", project]) .args(["--key", key]) + .args(["--heat-endpoint", self.get_api_endpoint().as_str()]) + .args(["--wss", self.get_wss().to_string().as_str()]) .args(["train", function, config_path]); command } @@ -154,14 +177,7 @@ impl HeatCliContext { .to_str() .unwrap(), ]) - .args(["--message-format", "short"]) - // todo: remove once correct burn version is published - .args([ - "--config", - "patch.crates-io.burn.git='https://github.com/tracel-ai/burn'", - "--config", - "patch.crates-io.burn.rev='a72a533'", - ]); + .args(["--message-format", "short"]); if let Some(target_dir) = new_target_dir { build_command.args(["--target-dir", &target_dir]); } diff --git a/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs b/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs index 3c9ec1c..0181eb7 100644 --- a/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs +++ b/crates/heat-sdk-cli/src/generation/crate_gen/mod.rs @@ -236,7 +236,12 @@ fn generate_clap_cli() -> proc_macro2::TokenStream { .short('e') .long("heat-endpoint") .help("The Heat endpoint") - .default_value("http://127.0.0.1:9001"), + .required(true), + clap::Arg::new("wss") + .short('w') + .long("wss") + .help("Whether to use WSS") + .required(true), ]); command @@ -248,7 +253,6 @@ fn generate_training_function( train_func_match: &proc_macro2::TokenStream, ) -> proc_macro2::TokenStream { quote! { - let client = create_heat_client(&key, &heat_endpoint, &project); let training_config_str = std::fs::read_to_string(&config_path).expect("Config should be read"); let training_config: serde_json::Value = serde_json::from_str(&training_config_str).expect("Config should be deserialized"); @@ -337,10 +341,11 @@ fn generate_main_rs(main_backend: &BackendType) -> String { use tracel::heat::command::train::*; use burn::prelude::*; - fn create_heat_client(api_key: &str, url: &str, project: &str) -> tracel::heat::client::HeatClient { + fn create_heat_client(api_key: &str, url: &str, project: &str, wss: bool) -> tracel::heat::client::HeatClient { let creds = tracel::heat::client::HeatCredentials::new(api_key.to_owned()); let client_config = tracel::heat::client::HeatClientConfig::builder(creds, tracel::heat::schemas::ProjectPath::try_from(project.to_string()).expect("Project path should be valid.")) .with_endpoint(url) + .with_wss(wss) .with_num_retries(10) .build(); tracel::heat::client::HeatClient::create(client_config) @@ -352,12 +357,16 @@ fn generate_main_rs(main_backend: &BackendType) -> String { let device = #backend_default_device; + let key = matches.get_one::("key").expect("key should be set."); + let heat_endpoint = matches.get_one::("heat-endpoint").expect("heat-endpoint should be set."); + let project = matches.get_one::("project").expect("project should be set."); + let wss = matches.get_one::("wss").expect("wss should be set.").parse::().expect("wss should be a boolean."); + + let client = create_heat_client(&key, &heat_endpoint, &project, wss); + if let Some(train_matches) = matches.subcommand_matches("train") { let func = train_matches.get_one::("func").expect("func should be set."); let config_path = train_matches.get_one::("config").expect("config should be set."); - let project = matches.get_one::("project").expect("project should be set."); - let key = matches.get_one::("key").expect("key should be set."); - let heat_endpoint = matches.get_one::("heat-endpoint").expect("heat-endpoint should be set."); #generated_training } diff --git a/crates/heat-sdk-cli/src/lib.rs b/crates/heat-sdk-cli/src/lib.rs index 472e7d5..ae8b96c 100644 --- a/crates/heat-sdk-cli/src/lib.rs +++ b/crates/heat-sdk-cli/src/lib.rs @@ -1,4 +1,5 @@ pub mod cli; +pub mod config; pub mod registry; mod cli_commands; diff --git a/crates/heat-sdk/Cargo.toml b/crates/heat-sdk/Cargo.toml index 0765181..a08b180 100644 --- a/crates/heat-sdk/Cargo.toml +++ b/crates/heat-sdk/Cargo.toml @@ -28,7 +28,7 @@ thiserror = { workspace = true } tracing = { version = "0.1.40" } tracing-core = { version = "0.1.32" } tracing-subscriber = { version = "0.3.18" } -tungstenite = { version = "0.21.0" } +tungstenite = { version = "0.24.0", features = ["native-tls"] } uuid = { workspace = true } regex = { workspace = true } once_cell = { workspace = true } diff --git a/crates/heat-sdk/src/client.rs b/crates/heat-sdk/src/client.rs index d7629cd..6a4d17c 100644 --- a/crates/heat-sdk/src/client.rs +++ b/crates/heat-sdk/src/client.rs @@ -40,6 +40,8 @@ impl From for String { pub struct HeatClientConfig { /// The endpoint of the Heat API pub endpoint: String, + /// Whether to use a secure WebSocket connection + pub wss: bool, /// Heat credential to create a session with the Heat API pub credentials: HeatCredentials, /// The number of retries to attempt when connecting to the Heat API. @@ -70,6 +72,7 @@ impl HeatClientConfigBuilder { HeatClientConfigBuilder { config: HeatClientConfig { endpoint: "http://127.0.0.1:9001".into(), + wss: false, credentials: creds, num_retries: 3, retry_interval: 3, @@ -84,6 +87,13 @@ impl HeatClientConfigBuilder { self } + /// Set whether to use a secure WebSocket connection + /// If this is set to true, the WebSocket connection will use the `wss` protocol instead of `ws`. + pub fn with_wss(mut self, wss: bool) -> HeatClientConfigBuilder { + self.config.wss = wss; + self + } + /// Set the number of retries to attempt when connecting to the Heat API pub fn with_num_retries(mut self, num_retries: u8) -> HeatClientConfigBuilder { self.config.num_retries = num_retries; @@ -115,7 +125,11 @@ pub type HeatClientState = HeatClient; impl HeatClient { fn new(config: HeatClientConfig) -> HeatClient { - let http_client = HttpClient::new(config.endpoint.clone()); + let url = config + .endpoint + .parse() + .expect("Should be able to parse the URL"); + let http_client = HttpClient::new(url, config.wss); HeatClient { config, diff --git a/crates/heat-sdk/src/http/client.rs b/crates/heat-sdk/src/http/client.rs index 2c63310..95a72f8 100644 --- a/crates/heat-sdk/src/http/client.rs +++ b/crates/heat-sdk/src/http/client.rs @@ -49,34 +49,26 @@ impl ResponseExt for reqwest::blocking::Response { #[derive(Debug, Clone)] pub struct HttpClient { http_client: reqwest::blocking::Client, - base_url: String, + base_url: Url, + ws_secure: bool, session_cookie: Option, } impl HttpClient { /// Create a new HttpClient with the given base URL and API key. - pub fn new(base_url: String) -> Self { + pub fn new(base_url: Url, ws_secure: bool) -> Self { Self { http_client: reqwest::blocking::Client::new(), base_url, + ws_secure, session_cookie: None, } } - /// Create a new HttpClient with the given base URL, API key, and session cookie. - #[allow(dead_code)] - pub fn with_session_cookie(base_url: String, session_cookie: String) -> Self { - Self { - http_client: reqwest::blocking::Client::new(), - base_url, - session_cookie: Some(session_cookie), - } - } - /// Check if the Heat server is reachable. #[allow(dead_code)] pub fn health_check(&self) -> Result<(), HeatHttpError> { - let url = format!("{}/health", self.base_url); + let url = self.join("health"); self.http_client.get(url).send()?.map_to_heat_err()?; Ok(()) } @@ -86,9 +78,17 @@ impl HttpClient { self.session_cookie.as_ref() } + /// Join the given path to the base URL. + fn join(&self, path: &str) -> Url { + self.base_url + .join(path) + .expect("Should be able to join url") + } + /// Log in to the Heat server with the given credentials. pub fn login(&mut self, credentials: &HeatCredentials) -> Result<(), HeatHttpError> { - let url = format!("{}/login/api-key", self.base_url); + let url = self.join("login/api-key"); + let res = self .http_client .post(url) @@ -123,16 +123,14 @@ impl HttpClient { project_name: &str, exp_num: i32, ) -> String { - let mut url: Url = self - .base_url - .parse() - .expect("Should be able to parse base url"); - url.set_scheme("ws") + let mut url = self.join(&format!( + "projects/{}/{}/experiments/{}/ws", + owner_name, project_name, exp_num + )); + url.set_scheme(if self.ws_secure { "wss" } else { "ws" }) .expect("Should be able to set ws scheme"); - format!( - "{}projects/{}/{}/experiments/{}/ws", - url, owner_name, project_name, exp_num - ) + + url.to_string() } /// Create a new experiment for the given project. @@ -145,10 +143,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/experiments", - self.base_url, owner_name, project_name - ); + let url = self.join(&format!( + "projects/{}/{}/experiments", + owner_name, project_name + )); // Create a new experiment let experiment_response = self @@ -189,12 +187,14 @@ impl HttpClient { config: serde_json::to_value(config).unwrap(), }; + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/start", + owner_name, project_name, exp_num + )); + // Start the experiment self.http_client - .put(format!( - "{}/projects/{}/{}/experiments/{}/start", - self.base_url, owner_name, project_name, exp_num - )) + .put(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .json(&json) .send()? @@ -215,10 +215,10 @@ impl HttpClient { ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/experiments/{}/end", - self.base_url, owner_name, project_name, exp_num - ); + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/end", + owner_name, project_name, exp_num + )); let end_status: EndExperimentSchema = match end_status { EndExperimentStatus::Success => EndExperimentSchema::Success, @@ -247,10 +247,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url: String = format!( - "{}/projects/{}/{}/experiments/{}/checkpoints/{}", - self.base_url, owner_name, project_name, exp_num, file_name - ); + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/checkpoints/{}", + owner_name, project_name, exp_num, file_name + )); let save_url = self .http_client @@ -276,10 +276,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url: String = format!( - "{}/projects/{}/{}/experiments/{}/checkpoints/{}", - self.base_url, owner_name, project_name, exp_num, file_name - ); + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/checkpoints/{}", + owner_name, project_name, exp_num, file_name + )); let load_url = self .http_client @@ -304,10 +304,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/experiments/{}/save_model", - self.base_url, owner_name, project_name, exp_num - ); + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/save_model", + owner_name, project_name, exp_num + )); let save_url = self .http_client @@ -332,10 +332,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/experiments/{}/logs", - self.base_url, owner_name, project_name, exp_num - ); + let url = self.join(&format!( + "projects/{}/{}/experiments/{}/logs", + owner_name, project_name, exp_num + )); let logs_upload_url = self .http_client @@ -390,10 +390,10 @@ impl HttpClient { ) -> Result { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/code/upload", - self.base_url, owner_name, project_name - ); + let url = self.join(&format!( + "projects/{}/{}/code/upload", + owner_name, project_name + )); let response = self .http_client @@ -421,10 +421,10 @@ impl HttpClient { ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; - let url = format!( - "{}/projects/{}/{}/jobs/queue", - self.base_url, owner_name, project_name - ); + let url = self.join(&format!( + "projects/{}/{}/jobs/queue", + owner_name, project_name + )); let body = RunnerQueueJobParamsSchema { runner_group_name: runner_group_name.to_string(), diff --git a/examples/guide-cli/src/main.rs b/examples/guide-cli/src/main.rs index 1542c29..afdd536 100644 --- a/examples/guide-cli/src/main.rs +++ b/examples/guide-cli/src/main.rs @@ -1,2 +1,2 @@ -#[tracel::heat::macros::heat_cli_main(guide_cli)] -fn main() {} +#[tracel::heat::macros::heat_cli_main(guide_cli)] +fn main() {}