diff --git a/clients/python/src/events/async/handle.rs b/clients/python/src/events/async/handle.rs index 2af619d28..3339bca9e 100644 --- a/clients/python/src/events/async/handle.rs +++ b/clients/python/src/events/async/handle.rs @@ -6,7 +6,7 @@ use crate::util; #[pyclass] pub struct ActorHandle { - pub handle: actor_core_rs::handle::ActorHandle, + pub handle: actor_core_rs::connection::ActorHandle, } #[pymethods] diff --git a/clients/python/src/events/sync/handle.rs b/clients/python/src/events/sync/handle.rs index 0bf9979ee..233a7b5bd 100644 --- a/clients/python/src/events/sync/handle.rs +++ b/clients/python/src/events/sync/handle.rs @@ -5,7 +5,7 @@ use crate::util::{self, SYNC_RUNTIME}; #[pyclass] pub struct ActorHandle { - pub handle: actor_core_rs::handle::ActorHandle, + pub handle: actor_core_rs::connection::ActorHandle, } #[pymethods] diff --git a/clients/python/src/simple/async/handle.rs b/clients/python/src/simple/async/handle.rs index 6c30bdacd..871b02f92 100644 --- a/clients/python/src/simple/async/handle.rs +++ b/clients/python/src/simple/async/handle.rs @@ -14,13 +14,13 @@ struct ActorEvent { #[pyclass] pub struct ActorHandle { - handle: actor_core_rs::handle::ActorHandle, + handle: actor_core_rs::connection::ActorHandle, event_rx: Option>, event_tx: mpsc::Sender, } impl ActorHandle { - pub fn new(handle: actor_core_rs::handle::ActorHandle) -> Self { + pub fn new(handle: actor_core_rs::connection::ActorHandle) -> Self { let (event_tx, event_rx) = mpsc::channel(EVENT_BUFFER_SIZE); Self { diff --git a/clients/python/src/simple/sync/handle.rs b/clients/python/src/simple/sync/handle.rs index 3213c1b88..037e30892 100644 --- a/clients/python/src/simple/sync/handle.rs +++ b/clients/python/src/simple/sync/handle.rs @@ -14,13 +14,13 @@ struct ActorEvent { #[pyclass] pub struct ActorHandle { - handle: actor_core_rs::handle::ActorHandle, + handle: actor_core_rs::connection::ActorHandle, event_rx: Option>, event_tx: mpsc::Sender, } impl ActorHandle { - pub fn new(handle: actor_core_rs::handle::ActorHandle) -> Self { + pub fn new(handle: actor_core_rs::connection::ActorHandle) -> Self { let (event_tx, event_rx) = mpsc::channel(EVENT_BUFFER_SIZE); Self { diff --git a/clients/rust/Cargo.toml b/clients/rust/Cargo.toml index 925eb43ea..598d859b5 100644 --- a/clients/rust/Cargo.toml +++ b/clients/rust/Cargo.toml @@ -24,7 +24,7 @@ tungstenite = "0.26.2" urlencoding = "2.1.3" [dev-dependencies] -tracing-subscriber = "0.3.19" +tracing-subscriber = { version = "0.3.19", features = ["env-filter", "std", "registry"]} tempfile = "3.10.1" tokio-test = "0.4.3" fs_extra = "1.3.0" diff --git a/clients/rust/README.md b/clients/rust/README.md index 0644997ec..cbfc409b9 100644 --- a/clients/rust/README.md +++ b/clients/rust/README.md @@ -24,21 +24,25 @@ rivetkit-client = "0.1.0" ### Step 2: Connect to Actor ```rust -use actor_core_client::{client::{Client, GetOptions}, drivers::TransportKind, encoding::EncodingKind}; +use actor_core_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; use serde_json::json; #[tokio::main] async fn main() -> anyhow::Result<()> { // Create a client connected to your RivetKit manager let client = Client::new( - "http://localhost:6420".to_string(), + "http://localhost:6420", TransportKind::Sse, EncodingKind::Json ); // Connect to a chat room actor - let chat_room = client.get("chat-room", GetOptions::default()).await?; - + let chat_room = client.get_or_create( + "chat-room", + ["keys-here"].into(), + GetOrCreateOptions::default() + )?.connect(); + // Listen for new messages chat_room.on_event("newMessage", |args| { let username = args[0].as_str().unwrap(); @@ -53,7 +57,7 @@ async fn main() -> anyhow::Result<()> { ]).await?; // When finished - chat_room.disconnect().await; + client.disconnect(); Ok(()) } diff --git a/clients/rust/src/client.rs b/clients/rust/src/client.rs index 7f1ec24ea..67e3b6e9f 100644 --- a/clients/rust/src/client.rs +++ b/clients/rust/src/client.rs @@ -1,250 +1,189 @@ -use anyhow::Result; -use serde_json::{json, Value}; - -use crate::drivers::TransportKind; -use crate::encoding::EncodingKind; -use crate::handle::{ActorHandle, ActorHandleInner}; - -type ActorTags = Vec<(String, String)>; - -pub struct CreateRequestMetadata { - pub tags: ActorTags, - pub region: Option, -} +use std::sync::Arc; -impl Default for CreateRequestMetadata { - fn default() -> Self { - Self { - tags: vec![], - region: None, - } - } -} +use anyhow::Result; +use serde_json::{Value as JsonValue}; -pub struct PartialCreateRequestMetadata { - pub tags: Option, - pub region: Option, -} +use crate::{ + common::{resolve_actor_id, ActorKey, EncodingKind, TransportKind}, + handle::ActorHandle, + protocol::query::* +}; +#[derive(Default)] pub struct GetWithIdOptions { - pub params: Option, -} - -impl Default for GetWithIdOptions { - fn default() -> Self { - Self { params: None } - } + pub params: Option, } +#[derive(Default)] pub struct GetOptions { - pub tags: Option, - pub params: Option, - pub no_create: Option, - pub create: Option, + pub params: Option, } -impl Default for GetOptions { - fn default() -> Self { - Self { - tags: None, - params: None, - no_create: None, - create: None, - } - } +#[derive(Default)] +pub struct GetOrCreateOptions { + pub params: Option, + pub create_in_region: Option, + pub create_with_input: Option, } +#[derive(Default)] pub struct CreateOptions { - pub params: Option, - pub create: CreateRequestMetadata, + pub params: Option, + pub region: Option, + pub input: Option, } -impl Default for CreateOptions { - fn default() -> Self { - Self { - params: None, - create: CreateRequestMetadata::default() - } - } -} pub struct Client { manager_endpoint: String, encoding_kind: EncodingKind, transport_kind: TransportKind, + shutdown_tx: Arc>, } impl Client { pub fn new( - manager_endpoint: String, + manager_endpoint: &str, transport_kind: TransportKind, encoding_kind: EncodingKind, ) -> Self { Self { - manager_endpoint, + manager_endpoint: manager_endpoint.to_string(), encoding_kind, transport_kind, + shutdown_tx: Arc::new(tokio::sync::broadcast::channel(1).0) } } - async fn post_manager_endpoint(&self, path: &str, body: Value) -> Result { - let client = reqwest::Client::new(); - let req = client.post(format!("{}{}", self.manager_endpoint, path)); - let req = req.header("Content-Type", "application/json"); - let req = req.body(serde_json::to_string(&body)?); - let res = req.send().await?; - let body = res.text().await?; - - let body: Value = serde_json::from_str(&body)?; + fn create_handle( + &self, + params: Option, + query: ActorQuery + ) -> ActorHandle { + let handle = ActorHandle::new( + &self.manager_endpoint, + params, + query, + self.shutdown_tx.clone(), + self.transport_kind, + self.encoding_kind + ); - Ok(body) + handle } - #[allow(dead_code)] - async fn get_manager_endpoint(&self, path: &str) -> Result { - let client = reqwest::Client::new(); - let req = client.get(format!("{}{}", self.manager_endpoint, path)); - let res = req.send().await?; - let body = res.text().await?; - let body: Value = serde_json::from_str(&body)?; + pub fn get( + &self, + name: &str, + key: ActorKey, + opts: GetOptions + ) -> Result { + let actor_query = ActorQuery::GetForKey { + get_for_key: GetForKeyRequest { + name: name.to_string(), + key, + } + }; + + let handle = self.create_handle( + opts.params, + actor_query + ); - Ok(body) + Ok(handle) } - pub async fn get(&self, name: &str, opts: GetOptions) -> Result { - // Convert tags to a map for JSON - let tags_map: serde_json::Map = opts.tags - .unwrap_or_default() - .into_iter() - .map(|(k, v)| (k, json!(v))) - .collect(); - - // Build create object if no_create is false - let create = if !opts.no_create.unwrap_or(false) { - // Start with create options if provided - if let Some(create_opts) = &opts.create { - // Build tags map - use create.tags if provided, otherwise fall back to query tags - let create_tags = if let Some(tags) = &create_opts.tags { - tags.iter() - .map(|(k, v)| (k.clone(), json!(v.clone()))) - .collect() - } else { - tags_map.clone() - }; - - // Build create object with name, tags, and optional region - let mut create_obj = json!({ - "name": name, - "tags": create_tags - }); - - if let Some(region) = &create_opts.region { - create_obj["region"] = json!(region.clone()); - } - - Some(create_obj) - } else { - // Create with just the name and query tags - Some(json!({ - "name": name, - "tags": tags_map - })) - } - } else { - None - }; - - // Build the request body - let body = json!({ - "query": { - "getOrCreateForTags": { - "name": name, - "tags": tags_map, - "create": create - } + pub fn get_for_id( + &self, + actor_id: &str, + opts: GetOptions + ) -> Result { + let actor_query = ActorQuery::GetForId { + get_for_id: GetForIdRequest { + actor_id: actor_id.to_string(), } - }); - let res_json = self.post_manager_endpoint("/manager/actors", body).await?; - let Some(endpoint) = res_json["endpoint"].as_str() else { - return Err(anyhow::anyhow!( - "No endpoint returned. Request failed? {:?}", - res_json - )); }; - let handle = ActorHandleInner::new( - endpoint.to_string(), - self.transport_kind, - self.encoding_kind, + let handle = self.create_handle( opts.params, - )?; - handle.start_connection().await; + actor_query + ); Ok(handle) } - pub async fn get_with_id(&self, actor_id: &str, opts: GetWithIdOptions) -> Result { - let body = json!({ - "query": { - "getForId": { - "actorId": actor_id, - } - }, - }); - let res_json = self.post_manager_endpoint("/manager/actors", body).await?; - let Some(endpoint) = res_json["endpoint"].as_str() else { - return Err(anyhow::anyhow!( - "No endpoint returned. Request failed? {:?}", - res_json - )); + pub fn get_or_create( + &self, + name: &str, + key: ActorKey, + opts: GetOrCreateOptions + ) -> Result { + let input = opts.create_with_input; + let region = opts.create_in_region; + + let actor_query = ActorQuery::GetOrCreateForKey { + get_or_create_for_key: GetOrCreateRequest { + name: name.to_string(), + key: key, + input, + region + } }; - let handle = ActorHandleInner::new( - endpoint.to_string(), - self.transport_kind, - self.encoding_kind, + let handle = self.create_handle( opts.params, - )?; - handle.start_connection().await; + actor_query, + ); Ok(handle) } - pub async fn create(&self, name: &str, opts: CreateOptions) -> Result { - let mut tag_map = serde_json::Map::new(); - - for (key, value) in opts.create.tags { - tag_map.insert(key, json!(value)); - } + pub async fn create( + &self, + name: &str, + key: ActorKey, + opts: CreateOptions + ) -> Result { + let input = opts.input; + let region = opts.region; + + let create_query = ActorQuery::Create { + create: CreateRequest { + name: name.to_string(), + key, + input, + region + } + }; - let mut req_body = serde_json::Map::new(); - req_body.insert("name".to_string(), json!(name.to_string())); - req_body.insert("tags".to_string(), json!(tag_map)); - if let Some(region) = opts.create.region { - req_body.insert("region".to_string(), json!(region)); - } + let actor_id = resolve_actor_id( + &self.manager_endpoint, + create_query, + self.encoding_kind + ).await?; - let body = json!({ - "query": { - "create": req_body - }, - }); - let res_json = self.post_manager_endpoint("/manager/actors", body).await?; - let Some(endpoint) = res_json["endpoint"].as_str() else { - return Err(anyhow::anyhow!( - "No endpoint returned. Request failed? {:?}", - res_json - )); + let get_query = ActorQuery::GetForId { + get_for_id: GetForIdRequest { + actor_id, + } }; - let handle = ActorHandleInner::new( - endpoint.to_string(), - self.transport_kind, - self.encoding_kind, + let handle = self.create_handle( opts.params, - )?; - handle.start_connection().await; + get_query + ); Ok(handle) } + + pub fn disconnect(self) { + drop(self) + } } + +impl Drop for Client { + fn drop(&mut self) { + // Notify all subscribers to shutdown + let _ = self.shutdown_tx.send(()); + } +} \ No newline at end of file diff --git a/clients/rust/src/common.rs b/clients/rust/src/common.rs new file mode 100644 index 000000000..62c62974c --- /dev/null +++ b/clients/rust/src/common.rs @@ -0,0 +1,199 @@ +use anyhow::Result; +use reqwest::{header::USER_AGENT, RequestBuilder}; +use serde::{de::DeserializeOwned, Serialize}; +use serde_json::{json, Value as JsonValue}; +use tracing::debug; + +use crate::protocol::query::ActorQuery; + +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); +pub const USER_AGENT_VALUE: &str = concat!("ActorClient-Rust/", env!("CARGO_PKG_VERSION")); + +pub const HEADER_ACTOR_QUERY: &str = "X-AC-Query"; +pub const HEADER_ENCODING: &str = "X-AC-Encoding"; +pub const HEADER_CONN_PARAMS: &str = "X-AC-Conn-Params"; +pub const HEADER_ACTOR_ID: &str = "X-AC-Actor"; +pub const HEADER_CONN_ID: &str = "X-AC-Conn"; +pub const HEADER_CONN_TOKEN: &str = "X-AC-Conn-Token"; + +#[derive(Debug, Clone, Copy)] +pub enum TransportKind { + WebSocket, + Sse, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EncodingKind { + Json, + Cbor, +} + +impl EncodingKind { + pub fn as_str(&self) -> &str { + match self { + EncodingKind::Json => "json", + EncodingKind::Cbor => "cbor", + } + } +} + +impl ToString for EncodingKind { + fn to_string(&self) -> String { + self.as_str().to_string() + } +} + + + +// Max size of each entry is 128 bytes +pub type ActorKey = Vec; + +pub struct HttpRequestOptions<'a, T: Serialize> { + pub method: &'a str, + pub url: &'a str, + pub headers: Vec<(&'a str, String)>, + pub body: Option, + pub encoding_kind: EncodingKind +} + +impl<'a, T: Serialize> Default for HttpRequestOptions<'a, T> { + fn default() -> Self { + Self { + method: "GET", + url: "", + headers: Vec::new(), + body: None, + encoding_kind: EncodingKind::Json + } + } +} + +fn build_http_request(opts: &HttpRequestOptions) -> Result +where + RQ: Serialize +{ + let client = reqwest::Client::new(); + let mut req = client.request( + reqwest::Method::from_bytes(opts.method.as_bytes()).unwrap(), + opts.url, + ); + + for (key, value) in &opts.headers { + req = req.header(*key, value); + } + + if opts.method == "POST" || opts.method == "PUT" { + let Some(body) = &opts.body else { + return Err(anyhow::anyhow!("Body is required for POST/PUT requests")); + }; + + match opts.encoding_kind { + EncodingKind::Json => { + req = req.header("Content-Type", "application/json"); + let body = serde_json::to_string(&body)?; + req = req.body(body); + } + EncodingKind::Cbor => { + req = req.header("Content-Type", "application/octet-stream"); + let body =serde_cbor::to_vec(&body)?; + req = req.body(body); + } + } + }; + + req = req.header(USER_AGENT, USER_AGENT_VALUE); + + Ok(req) +} + +async fn send_http_request_raw(req: reqwest::RequestBuilder) -> Result { + let res = req.send().await?; + + if !res.status().is_success() { + // TODO: Decode + /* + let data: Option = match opts.encoding_kind { + EncodingKind::Json => { + let data = res.text().await?; + + serde_json::from_str::(&data).ok() + } + EncodingKind::Cbor => { + let data = res.bytes().await?; + + serde_cbor::from_slice(&data).ok() + } + }; + + match data { + Some(data) => { + return Err(anyhow::anyhow!( + "HTTP request failed with status: {}, error: {}", + res.status(), + data.m + )); + }, + None => { + + } + } + */ + return Err(anyhow::anyhow!( + "HTTP request failed with status: {}", + res.status() + )); + } + + Ok(res) +} + +pub async fn send_http_request<'a, RQ, RS>(opts: HttpRequestOptions<'a, RQ>) -> Result +where + RQ: Serialize, + RS: DeserializeOwned, +{ + let req = build_http_request(&opts)?; + let res = send_http_request_raw(req).await?; + + let res: RS = match opts.encoding_kind { + EncodingKind::Json => { + let data = res.text().await?; + serde_json::from_str(&data)? + } + EncodingKind::Cbor => { + let bytes = res.bytes().await?; + serde_cbor::from_slice(&bytes)? + } + }; + + Ok(res) +} + + +pub async fn resolve_actor_id( + manager_endpoint: &str, + query: ActorQuery, + encoding_kind: EncodingKind +) -> Result { + #[derive(serde::Serialize, serde::Deserialize)] + struct ResolveResponse { + i: String, + } + + let query = serde_json::to_string(&query)?; + + let res = send_http_request::( + HttpRequestOptions { + method: "POST", + url: &format!("{}/actors/resolve", manager_endpoint), + headers: vec![ + (HEADER_ENCODING, encoding_kind.to_string()), + (HEADER_ACTOR_QUERY, query), + ], + body: Some(json!({})), + encoding_kind, + } + ).await?; + + Ok(res.i) +} \ No newline at end of file diff --git a/clients/rust/src/connection.rs b/clients/rust/src/connection.rs new file mode 100644 index 000000000..935059b74 --- /dev/null +++ b/clients/rust/src/connection.rs @@ -0,0 +1,434 @@ +use anyhow::Result; +use futures_util::FutureExt; +use serde_json::Value; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::atomic::{AtomicI64, Ordering}; +use std::time::Duration; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::{broadcast, oneshot, watch, Mutex}; + +use crate::{ + backoff::Backoff, + protocol::{query::ActorQuery, *}, + drivers::*, + EncodingKind, + TransportKind +}; +use tracing::debug; + + +type RpcResponse = Result; +type EventCallback = dyn Fn(&Vec) + Send + Sync; + +struct SendMsgOpts { + ephemeral: bool, +} + +impl Default for SendMsgOpts { + fn default() -> Self { + Self { ephemeral: false } + } +} + +// struct WatchPair { +// tx: watch::Sender, +// rx: watch::Receiver, +// } +type WatchPair = (watch::Sender, watch::Receiver); + +pub type ActorConnection = Arc; + +struct ConnectionAttempt { + did_open: bool, + _task_end_reason: DriverStopReason, +} + +pub struct ActorConnectionInner { + endpoint: String, + transport_kind: TransportKind, + encoding_kind: EncodingKind, + query: ActorQuery, + parameters: Option, + + driver: Mutex>, + msg_queue: Mutex>>, + + rpc_counter: AtomicI64, + in_flight_rpcs: Mutex>>, + + event_subscriptions: Mutex>>>, + + dc_watch: WatchPair, + disconnection_rx: Mutex>>, +} + +impl ActorConnectionInner { + pub(crate) fn new( + endpoint: String, + query: ActorQuery, + transport_kind: TransportKind, + encoding_kind: EncodingKind, + parameters: Option, + ) -> ActorConnection { + Arc::new(Self { + endpoint: endpoint.clone(), + transport_kind, + encoding_kind, + query, + parameters, + driver: Mutex::new(None), + msg_queue: Mutex::new(Vec::new()), + rpc_counter: AtomicI64::new(0), + in_flight_rpcs: Mutex::new(HashMap::new()), + event_subscriptions: Mutex::new(HashMap::new()), + dc_watch: watch::channel(false), + disconnection_rx: Mutex::new(None), + }) + } + + fn is_disconnecting(self: &Arc) -> bool { + *self.dc_watch.1.borrow() == true + } + + async fn try_connect(self: &Arc) -> ConnectionAttempt { + let Ok((driver, mut recver, task)) = connect_driver( + self.transport_kind, + DriverConnectArgs { + endpoint: self.endpoint.clone(), + query: self.query.clone(), + encoding_kind: self.encoding_kind, + parameters: self.parameters.clone(), + } + ).await else { + // Either from immediate disconnect (local device connection refused) + // or from error like invalid URL + return ConnectionAttempt { + did_open: false, + _task_end_reason: DriverStopReason::TaskError, + }; + }; + + { + let mut my_driver = self.driver.lock().await; + *my_driver = Some(driver); + } + + let mut task_end_reason = task.map(|res| match res { + Ok(a) => a, + Err(task_err) => { + if task_err.is_cancelled() { + debug!("Connection task was cancelled"); + DriverStopReason::UserAborted + } else { + DriverStopReason::TaskError + } + } + }); + + let mut did_connection_open = false; + + // spawn listener for rpcs + let task_end_reason = loop { + tokio::select! { + reason = &mut task_end_reason => { + debug!("Connection closed: {:?}", reason); + + break reason; + }, + msg = recver.recv() => { + // If the sender is dropped, break the loop + let Some(msg) = msg else { + // break DriverStopReason::ServerDisconnect; + continue; + }; + + if let to_client::ToClientBody::Init { i: _ } = &msg.b { + did_connection_open = true; + } + + self.on_message(msg).await; + } + } + }; + + 'destroy_driver: { + debug!("Destroying driver"); + let mut d_guard = self.driver.lock().await; + let Some(d) = d_guard.take() else { + // We destroyed the driver already, + // e.g. .disconnect() was called + break 'destroy_driver; + }; + + d.disconnect(); + } + + ConnectionAttempt { + did_open: did_connection_open, + _task_end_reason: task_end_reason, + } + } + + async fn on_open(self: &Arc, init: &to_client::Init) { + debug!("Connected to server: {:?}", init); + + for (event_name, _) in self.event_subscriptions.lock().await.iter() { + self.send_subscription(event_name.clone(), true).await; + } + + // Flush message queue + for msg in self.msg_queue.lock().await.drain(..) { + // If its in the queue, it isn't ephemeral, so we pass + // default SendMsgOpts + self.send_msg(msg, SendMsgOpts::default()).await; + } + } + + async fn on_message(self: &Arc, msg: Arc) { + let body = &msg.b; + + match body { + to_client::ToClientBody::Init { i: init } => { + self.on_open(init).await; + } + to_client::ToClientBody::ActionResponse { ar } => { + let id = ar.i; + let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; + let Some(tx) = in_flight_rpcs.remove(&id) else { + debug!("Unexpected response: rpc id not found"); + return; + }; + if let Err(e) = tx.send(Ok(ar.clone())) { + debug!("{:?}", e); + return; + } + } + to_client::ToClientBody::EventMessage { ev } => { + let listeners = self.event_subscriptions.lock().await; + if let Some(callbacks) = listeners.get(&ev.n) { + for cb in callbacks { + cb(&ev.a); + } + } + } + to_client::ToClientBody::Error { e } => { + if let Some(action_id) = e.ai { + let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; + let Some(tx) = in_flight_rpcs.remove(&action_id) else { + debug!("Unexpected response: rpc id not found"); + return; + }; + if let Err(e) = tx.send(Err(e.clone())) { + debug!("{:?}", e); + return; + } + + return; + } + + + } + } + } + + async fn send_msg(self: &Arc, msg: Arc, opts: SendMsgOpts) { + let guard = self.driver.lock().await; + + 'send_immediately: { + let Some(driver) = guard.deref() else { + break 'send_immediately; + }; + + let Ok(_) = driver.send(msg.clone()).await else { + break 'send_immediately; + }; + + return; + } + + // Otherwise queue + if opts.ephemeral == false { + self.msg_queue.lock().await.push(msg.clone()); + } + + return; + } + + pub async fn action(self: &Arc, method: &str, params: Vec) -> Result { + let id: i64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst); + + let (tx, rx) = oneshot::channel(); + self.in_flight_rpcs.lock().await.insert(id, tx); + + self.send_msg( + Arc::new(to_server::ToServer { + b: to_server::ToServerBody::ActionRequest { + ar: to_server::ActionRequest { + i: id, + n: method.to_string(), + a: params, + }, + }, + }), + SendMsgOpts::default(), + ) + .await; + + let Ok(res) = rx.await else { + // Verbosity + return Err(anyhow::anyhow!("Socket closed during rpc")); + }; + + match res { + Ok(ok) => Ok(ok.o), + Err(err) => { + let metadata = err.md.unwrap_or(Value::Null); + + Err(anyhow::anyhow!( + "RPC Error({}): {:?}, {:#}", + err.c, + err.m, + metadata + )) + } + } + } + + async fn send_subscription(self: &Arc, event_name: String, subscribe: bool) { + self.send_msg( + Arc::new(to_server::ToServer { + b: to_server::ToServerBody::SubscriptionRequest { + sr: to_server::SubscriptionRequest { + e: event_name, + s: subscribe, + }, + }, + }), + SendMsgOpts { ephemeral: true }, + ) + .await; + } + + async fn add_event_subscription( + self: &Arc, + event_name: String, + callback: Box, + ) { + // TODO: Support for once + let mut listeners = self.event_subscriptions.lock().await; + + let is_new_subscription = listeners.contains_key(&event_name) == false; + + listeners + .entry(event_name.clone()) + .or_insert(Vec::new()) + .push(callback); + + if is_new_subscription { + self.send_subscription(event_name, true).await; + } + } + + pub async fn on_event(self: &Arc, event_name: &str, callback: F) + where + F: Fn(&Vec) + Send + Sync + 'static, + { + self.add_event_subscription(event_name.to_string(), Box::new(callback)) + .await + } + + pub async fn disconnect(self: &Arc) { + if self.is_disconnecting() { + // We are already disconnecting + return; + } + + debug!("Disconnecting from actor conn"); + + self.dc_watch.0.send(true).ok(); + + if let Some(d) = self.driver.lock().await.deref() { + d.disconnect(); + } + self.in_flight_rpcs.lock().await.clear(); + self.event_subscriptions.lock().await.clear(); + let Some(rx) = self.disconnection_rx.lock().await.take() else { + return; + }; + + rx.await.ok(); + } +} + + +pub fn start_connection( + conn: &Arc, + mut shutdown_rx: broadcast::Receiver<()> +) { + let (tx, rx) = oneshot::channel(); + + let conn = conn.clone(); + + tokio::spawn(async move { + { + let mut stop_rx = conn.disconnection_rx.lock().await; + if stop_rx.is_some() { + // Already doing connection_with_retry + // - this drops the oneshot + return; + } + + *stop_rx = Some(rx); + } + + 'keepalive: loop { + debug!("Attempting to reconnect"); + let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30)); + let mut retry_attempt = 0; + 'retry: loop { + retry_attempt += 1; + debug!( + "Establish conn: attempt={}, timeout={:?}", + retry_attempt, + backoff.delay() + ); + let attempt = conn.try_connect().await; + + if conn.is_disconnecting() { + break 'keepalive; + } + + if attempt.did_open { + break 'retry; + } + + let mut dc_rx = conn.dc_watch.0.subscribe(); + + tokio::select! { + _ = backoff.tick() => {}, + _ = dc_rx.wait_for(|x| *x == true) => { + break 'keepalive; + } + _ = shutdown_rx.recv() => { + debug!("Received shutdown signal, stopping connection attempts"); + break 'keepalive; + } + } + } + } + + tx.send(()).ok(); + conn.disconnection_rx.lock().await.take(); + }); +} + +impl Debug for ActorConnectionInner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ActorConnection") + .field("endpoint", &self.endpoint) + .field("transport_kind", &self.transport_kind) + .field("encoding_kind", &self.encoding_kind) + .finish() + } +} \ No newline at end of file diff --git a/clients/rust/src/drivers/mod.rs b/clients/rust/src/drivers/mod.rs index 5c9fe83ca..8db37c88a 100644 --- a/clients/rust/src/drivers/mod.rs +++ b/clients/rust/src/drivers/mod.rs @@ -1,18 +1,22 @@ use std::sync::Arc; -use crate::{encoding::EncodingKind, protocol}; +use crate::{ + protocol::{query, to_client, to_server}, + EncodingKind, TransportKind +}; use anyhow::Result; use serde_json::Value; use tokio::{ sync::mpsc, task::{AbortHandle, JoinHandle}, }; -use urlencoding::encode; +use tracing::debug; pub mod sse; pub mod ws; -const MAX_CONN_PARAMS_SIZE: usize = 4096; +pub type MessageToClient = Arc; +pub type MessageToServer = Arc; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum DriverStopReason { @@ -22,11 +26,8 @@ pub enum DriverStopReason { TaskError, } -pub(crate) type MessageToClient = Arc; -pub(crate) type MessageToServer = Arc; - #[derive(Debug)] -pub(crate) struct DriverHandle { +pub struct DriverHandle { abort_handle: AbortHandle, sender: mpsc::Sender, } @@ -39,7 +40,7 @@ impl DriverHandle { } } - pub async fn send(&self, msg: Arc) -> Result<()> { + pub async fn send(&self, msg: Arc) -> Result<()> { self.sender.send(msg).await?; Ok(()) @@ -52,76 +53,32 @@ impl DriverHandle { impl Drop for DriverHandle { fn drop(&mut self) { + debug!("DriverHandle dropped, aborting task"); self.disconnect() } } -#[derive(Debug, Clone, Copy)] -pub enum TransportKind { - WebSocket, - Sse, +pub type DriverConnection = ( + DriverHandle, + mpsc::Receiver, + JoinHandle, +); + +pub struct DriverConnectArgs { + pub endpoint: String, + pub encoding_kind: EncodingKind, + pub query: query::ActorQuery, + pub parameters: Option, } -impl TransportKind { - pub(crate) async fn connect( - &self, - endpoint: String, - encoding_kind: EncodingKind, - parameters: &Option, - ) -> Result<( - DriverHandle, - mpsc::Receiver, - JoinHandle, - )> { - match *self { - TransportKind::WebSocket => ws::connect(endpoint, encoding_kind, parameters).await, - TransportKind::Sse => sse::connect(endpoint, encoding_kind, parameters).await, - } - } -} - -fn build_conn_url( - endpoint: &str, - transport_kind: &TransportKind, - encoding_kind: EncodingKind, - params: &Option, -) -> Result { - let connect_path = { - match transport_kind { - TransportKind::WebSocket => "websocket", - TransportKind::Sse => "sse", - } - }; - - let endpoint = match transport_kind { - TransportKind::WebSocket => endpoint - .to_string() - .replace("http://", "ws://") - .replace("https://", "wss://"), - TransportKind::Sse => endpoint.to_string(), - }; - - let Some(params) = params else { - return Ok(format!( - "{}/connect/{}?encoding={}", - endpoint, - connect_path, - encoding_kind.as_str() - )); +pub async fn connect_driver( + transport_kind: TransportKind, + args: DriverConnectArgs +) -> Result { + let res = match transport_kind { + TransportKind::WebSocket => ws::connect(args).await?, + TransportKind::Sse => sse::connect(args).await?, }; - let params_str = serde_json::to_string(params)?; - if params_str.len() > MAX_CONN_PARAMS_SIZE { - return Err(anyhow::anyhow!("Connection parameters too long")); - } - - let params_str = encode(¶ms_str); - - Ok(format!( - "{}/connect/{}?encoding={}¶ms={}", - endpoint, - connect_path, - encoding_kind.as_str(), - params_str - )) + Ok(res) } diff --git a/clients/rust/src/drivers/sse.rs b/clients/rust/src/drivers/sse.rs index d4feaf581..be314baa7 100644 --- a/clients/rust/src/drivers/sse.rs +++ b/clients/rust/src/drivers/sse.rs @@ -1,50 +1,94 @@ -use anyhow::{Context, Result}; +use anyhow::{Result}; use base64::prelude::*; use eventsource_client::{BoxStream, Client, ClientBuilder, ReconnectOptionsBuilder, SSE}; use futures_util::StreamExt; -use serde_json::Value; +use reqwest::header::USER_AGENT; use std::sync::Arc; use tokio::sync::mpsc; -use tokio::task::JoinHandle; use tracing::debug; -use crate::encoding::EncodingKind; -use crate::protocol::{ToClient, ToClientBody, ToServer}; +use crate::{ + common::{EncodingKind, HEADER_ACTOR_ID, HEADER_ACTOR_QUERY, HEADER_CONN_ID, HEADER_CONN_PARAMS, HEADER_CONN_TOKEN, HEADER_ENCODING, USER_AGENT_VALUE}, + protocol::{to_client, to_server} +}; use super::{ - build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind, + DriverConnectArgs, DriverConnection, DriverHandle, DriverStopReason, MessageToClient, MessageToServer }; #[derive(Debug, Clone, PartialEq, Eq)] struct ConnectionDetails { + actor_id: String, id: String, token: String, } -pub(crate) async fn connect( - endpoint: String, + +struct Context { + conn: ConnectionDetails, encoding_kind: EncodingKind, - parameters: &Option, -) -> Result<( - DriverHandle, - mpsc::Receiver, - JoinHandle, -)> { - let url = build_conn_url(&endpoint, &TransportKind::Sse, encoding_kind, parameters)?; - - let client = ClientBuilder::for_url(&url)? - .reconnect(ReconnectOptionsBuilder::new(false).build()) + endpoint: String, +} + +pub(crate) async fn connect(args: DriverConnectArgs) -> Result { + let endpoint = format!("{}/actors/connect/sse", args.endpoint); + + let params_string = match args.parameters { + Some(p) => Some(serde_json::to_string(&p)).transpose(), + None => Ok(None), + }?; + + let client = ClientBuilder::for_url(&endpoint)? + .header(USER_AGENT.as_str(), USER_AGENT_VALUE)? + .header(HEADER_ENCODING, args.encoding_kind.as_str())? + .header(HEADER_ACTOR_QUERY, serde_json::to_string(&args.query)?.as_str())?; + + let client = match params_string { + Some(p) => client.header(HEADER_CONN_PARAMS, p.as_str())?, + None => client, + }; + let client = client.reconnect(ReconnectOptionsBuilder::new(false).build()) .build(); let (in_tx, in_rx) = mpsc::channel::(32); let (out_tx, out_rx) = mpsc::channel::(32); - let task = tokio::spawn(start(client, endpoint, encoding_kind, in_tx, out_rx)); + let task = tokio::spawn(start(client, args.endpoint, args.encoding_kind, in_tx, out_rx)); let handle = DriverHandle::new(out_tx, task.abort_handle()); Ok((handle, in_rx, task)) } +async fn sse_send_msg(ctx: &Context, msg: MessageToServer) -> Result { + let msg = serialize(ctx.encoding_kind, &msg)?; + + // Add connection ID and token to the request URL + let request_url = format!( + "{}/actors/message", + ctx.endpoint + ); + + let res = reqwest::Client::new() + .post(request_url) + .body(msg) + .header(USER_AGENT, USER_AGENT_VALUE) + .header(HEADER_ENCODING, ctx.encoding_kind.as_str()) + .header(HEADER_ACTOR_ID, ctx.conn.actor_id.as_str()) + .header(HEADER_CONN_ID, ctx.conn.id.as_str()) + .header(HEADER_CONN_TOKEN, ctx.conn.token.as_str()) + .send() + .await?; + + + if !res.status().is_success() { + return Err(anyhow::anyhow!("Failed to send message: {:?}", res)); + } + + let res = res.text().await?; + + Ok(res) +} + async fn start( client: impl Client, endpoint: String, @@ -52,66 +96,40 @@ async fn start( in_tx: mpsc::Sender, mut out_rx: mpsc::Receiver, ) -> DriverStopReason { - let serialize = get_serializer(encoding_kind); - let deserialize = get_deserializer(encoding_kind); - let mut stream = client.stream(); - let conn = match do_handshake(&mut stream, &deserialize, &in_tx).await { - Ok(conn) => conn, - Err(reason) => { - debug!("Failed to connect: {:?}", reason); - return reason; - } + let ctx = Context { + conn: match do_handshake(&mut stream, encoding_kind, &in_tx).await { + Ok(conn) => conn, + Err(reason) => return reason + }, + encoding_kind, + endpoint, }; + debug!("Handshake completed successfully"); + loop { tokio::select! { + // Handle outgoing messages msg = out_rx.recv() => { let Some(msg) = msg else { return DriverStopReason::UserAborted; }; - let msg = match serialize(&msg) { - Ok(msg) => msg, + let res = match sse_send_msg(&ctx, msg).await { + Ok(res) => res, Err(e) => { - debug!("Failed to serialize {:?} {:?}", msg, e); + debug!("Failed to send message: {:?}", e); continue; } }; - // Add connection ID and token to the request URL - let request_url = format!( - "{}/connections/{}/message?encoding={}&connectionToken={}", - endpoint, conn.id, encoding_kind.as_str(), urlencoding::encode(&conn.token) - ); - - // Handle response - let resp = reqwest::Client::new() - .post(request_url) - .body(msg) - .send() - .await; - - match resp { - Ok(resp) => { - if !resp.status().is_success() { - debug!("Failed to send message: {:?}", resp); - } - - if let Ok(t) = resp.text().await { - debug!("Response: {:?}", t); - } - }, - Err(e) => { - debug!("Failed to send message: {:?}", e); - } - } + debug!("Response: {:?}", res); }, - // Handle sse incoming msg = stream.next() => { let Some(msg) = msg else { - debug!("Receiver dropped"); + // Receiver dropped return DriverStopReason::ServerDisconnect; }; @@ -120,8 +138,7 @@ async fn start( SSE::Comment(comment) => debug!("Sse comment: {}", comment), SSE::Connected(_) => debug!("warning: received sse connection past-handshake"), SSE::Event(event) => { - // println!("POST INIT event coming in: {:?}", event.data); - let msg = match deserialize(&event.data) { + let msg = match deserialize(encoding_kind, &event.data) { Ok(msg) => msg, Err(e) => { debug!("Failed to deserialize {:?} {:?}", event, e); @@ -147,7 +164,7 @@ async fn start( async fn do_handshake( stream: &mut BoxStream>, - deserialize: &impl Fn(&str) -> Result, + encoding_kind: EncodingKind, in_tx: &mpsc::Sender, ) -> Result { loop { @@ -164,7 +181,7 @@ async fn do_handshake( SSE::Comment(comment) => debug!("Sse comment {:?}", comment), SSE::Connected(_) => debug!("Connected Sse"), SSE::Event(event) => { - let msg = match deserialize(&event.data) { + let msg = match deserialize(encoding_kind, &event.data) { Ok(msg) => msg, Err(e) => { debug!("Failed to deserialize {:?} {:?}", event, e); @@ -180,17 +197,15 @@ async fn do_handshake( } // Wait until we get an Init packet - let ToClientBody::Init { i } = &msg.b else { + let to_client::ToClientBody::Init { i } = &msg.b else { continue; }; // Mark handshake complete - let conn_id = &i.ci; - let conn_token = &i.ct; - return Ok(ConnectionDetails { - id: conn_id.clone(), - token: conn_token.clone() + actor_id: i.ai.to_string(), + id: i.ci.clone(), + token: i.ct.clone() }) }, } @@ -204,28 +219,25 @@ async fn do_handshake( } } -fn get_serializer(encoding_kind: EncodingKind) -> impl Fn(&ToServer) -> Result> { - encoding_kind.get_default_serializer() -} - -fn get_deserializer(encoding_kind: EncodingKind) -> impl Fn(&str) -> Result { +fn deserialize(encoding_kind: EncodingKind, msg: &str) -> Result { match encoding_kind { - EncodingKind::Json => json_deserialize, - EncodingKind::Cbor => cbor_deserialize, + EncodingKind::Json => { + Ok(serde_json::from_str::(msg)?) + }, + EncodingKind::Cbor => { + let msg = serde_cbor::from_slice::( + &BASE64_STANDARD.decode(msg.as_bytes())? + )?; + + Ok(msg) + } } } -fn json_deserialize(value: &str) -> Result { - let msg = serde_json::from_str::(value)?; - - Ok(msg) +fn serialize(encoding_kind: EncodingKind, msg: &to_server::ToServer) -> Result> { + match encoding_kind { + EncodingKind::Json => Ok(serde_json::to_vec(msg)?), + EncodingKind::Cbor => Ok(serde_cbor::to_vec(msg)?), + } } -fn cbor_deserialize(msg: &str) -> Result { - let msg = BASE64_STANDARD - .decode(msg.as_bytes()) - .context("base64 failure:")?; - let msg = serde_cbor::from_slice::(&msg).context("serde failure:")?; - - Ok(msg) -} diff --git a/clients/rust/src/drivers/ws.rs b/clients/rust/src/drivers/ws.rs index c86649d48..e8d694500 100644 --- a/clients/rust/src/drivers/ws.rs +++ b/clients/rust/src/drivers/ws.rs @@ -1,36 +1,45 @@ use anyhow::{Context, Result}; use futures_util::{SinkExt, StreamExt}; -use serde_json::Value; use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::mpsc; -use tokio::task::JoinHandle; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tracing::debug; -use crate::encoding::EncodingKind; -use crate::protocol::{ToClient, ToServer}; +use crate::{ + protocol::to_server, + protocol::to_client, + EncodingKind +}; use super::{ - build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind, + DriverConnectArgs, DriverConnection, DriverHandle, DriverStopReason, MessageToClient, MessageToServer }; -pub(crate) async fn connect( - endpoint: String, - encoding_kind: EncodingKind, - parameters: &Option, -) -> Result<( - DriverHandle, - mpsc::Receiver, - JoinHandle, -)> { - let url = build_conn_url( - &endpoint, - &TransportKind::WebSocket, - encoding_kind, - parameters, - )?; +fn build_connection_url(args: &DriverConnectArgs) -> Result { + let actor_query_string = serde_json::to_string(&args.query)?; + // TODO: Should replace http:// only at the start of the string + let url = args.endpoint + .to_string() + .replace("http://", "ws://") + .replace("https://", "wss://"); + + let url = format!( + "{}/actors/connect/websocket?encoding={}&query={}", + url, + args.encoding_kind.as_str(), + urlencoding::encode(&actor_query_string) + ); + + Ok(url) +} + + +pub(crate) async fn connect(args: DriverConnectArgs) -> Result { + let url = build_connection_url(&args)?; + + debug!("Connecting to: {}", url); let (ws, _res) = tokio_tungstenite::connect_async(url) .await @@ -38,10 +47,20 @@ pub(crate) async fn connect( let (in_tx, in_rx) = mpsc::channel::(32); let (out_tx, out_rx) = mpsc::channel::(32); - let task = tokio::spawn(start(ws, encoding_kind, in_tx, out_rx)); + let task = tokio::spawn(start(ws, args.encoding_kind, in_tx, out_rx)); let handle = DriverHandle::new(out_tx, task.abort_handle()); + handle.send(Arc::new( + to_server::ToServer { + b: to_server::ToServerBody::Init { + i: to_server::Init { + p: args.parameters + } + }, + } + )).await?; + Ok((handle, in_rx, task)) } @@ -118,21 +137,21 @@ async fn start( } } -fn get_msg_deserializer(encoding_kind: EncodingKind) -> fn(&Message) -> Result { +fn get_msg_deserializer(encoding_kind: EncodingKind) -> fn(&Message) -> Result { match encoding_kind { EncodingKind::Json => json_msg_deserialize, EncodingKind::Cbor => cbor_msg_deserialize, } } -fn get_msg_serializer(encoding_kind: EncodingKind) -> fn(&ToServer) -> Result { +fn get_msg_serializer(encoding_kind: EncodingKind) -> fn(&to_server::ToServer) -> Result { match encoding_kind { EncodingKind::Json => json_msg_serialize, EncodingKind::Cbor => cbor_msg_serialize, } } -fn json_msg_deserialize(value: &Message) -> Result { +fn json_msg_deserialize(value: &Message) -> Result { match value { Message::Text(text) => Ok(serde_json::from_str(text)?), Message::Binary(bin) => Ok(serde_json::from_slice(bin)?), @@ -140,7 +159,7 @@ fn json_msg_deserialize(value: &Message) -> Result { } } -fn cbor_msg_deserialize(value: &Message) -> Result { +fn cbor_msg_deserialize(value: &Message) -> Result { match value { Message::Binary(bin) => Ok(serde_cbor::from_slice(bin)?), Message::Text(text) => Ok(serde_cbor::from_slice(text.as_bytes())?), @@ -148,10 +167,10 @@ fn cbor_msg_deserialize(value: &Message) -> Result { } } -fn json_msg_serialize(value: &ToServer) -> Result { +fn json_msg_serialize(value: &to_server::ToServer) -> Result { Ok(Message::Text(serde_json::to_string(value)?.into())) } -fn cbor_msg_serialize(value: &ToServer) -> Result { +fn cbor_msg_serialize(value: &to_server::ToServer) -> Result { Ok(Message::Binary(serde_cbor::to_vec(value)?.into())) } diff --git a/clients/rust/src/encoding.rs b/clients/rust/src/encoding.rs deleted file mode 100644 index ea503bcab..000000000 --- a/clients/rust/src/encoding.rs +++ /dev/null @@ -1,37 +0,0 @@ -use anyhow::Result; - -use crate::protocol::ToServer; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum EncodingKind { - Json, - Cbor, -} - -impl EncodingKind { - pub fn as_str(&self) -> &str { - match self { - EncodingKind::Json => "json", - EncodingKind::Cbor => "cbor", - } - } - - pub fn get_default_serializer(&self) -> fn(&ToServer) -> Result> { - match self { - EncodingKind::Json => json_serialize, - EncodingKind::Cbor => cbor_serialize, - } - } -} - -fn json_serialize(value: &ToServer) -> Result> { - let msg = serde_json::to_vec(value)?; - - Ok(msg) -} - -fn cbor_serialize(msg: &ToServer) -> Result> { - let msg = serde_cbor::to_vec(msg)?; - - Ok(msg) -} diff --git a/clients/rust/src/handle.rs b/clients/rust/src/handle.rs index 3765ded57..4abda92af 100644 --- a/clients/rust/src/handle.rs +++ b/clients/rust/src/handle.rs @@ -1,414 +1,178 @@ -use anyhow::Result; -use futures_util::FutureExt; -use serde_json::Value; -use std::fmt::Debug; -use std::ops::Deref; -use std::sync::atomic::{AtomicI64, Ordering}; -use std::time::Duration; -use std::{collections::HashMap, sync::Arc}; -use tokio::sync::{oneshot, watch, Mutex}; - -use crate::drivers::{DriverHandle, DriverStopReason, TransportKind}; -use crate::encoding::EncodingKind; -use crate::{backoff::Backoff, protocol::*}; -use tracing::debug; - -use super::protocol; - -type RpcResponse = Result; -type EventCallback = dyn Fn(&Vec) + Send + Sync; - -struct SendMsgOpts { - ephemeral: bool, -} - -impl Default for SendMsgOpts { - fn default() -> Self { - Self { ephemeral: false } - } -} - -// struct WatchPair { -// tx: watch::Sender, -// rx: watch::Receiver, -// } -type WatchPair = (watch::Sender, watch::Receiver); - -pub type ActorHandle = Arc; - -struct ConnectionAttempt { - did_open: bool, - _task_end_reason: DriverStopReason, -} - -pub struct ActorHandleInner { - pub endpoint: String, - transport_kind: TransportKind, +use std::{cell::RefCell, ops::Deref, sync::Arc}; +use serde_json::Value as JsonValue; +use anyhow::{anyhow, Result}; +use urlencoding::encode as url_encode; +use crate::{ + common::{resolve_actor_id, send_http_request, HttpRequestOptions, HEADER_ACTOR_QUERY, HEADER_CONN_PARAMS, HEADER_ENCODING}, + connection::{start_connection, ActorConnection, ActorConnectionInner}, + protocol::query::*, + EncodingKind, + TransportKind +}; + +pub struct ActorHandleStateless { + endpoint: String, + params: Option, encoding_kind: EncodingKind, - parameters: Option, - - driver: Mutex>, - msg_queue: Mutex>>, - - rpc_counter: AtomicI64, - in_flight_rpcs: Mutex>>, - - event_subscriptions: Mutex>>>, - - dc_watch: WatchPair, - disconnection_rx: Mutex>>, + query: RefCell, } -impl ActorHandleInner { - pub(crate) fn new( - endpoint: String, - transport_kind: TransportKind, +impl ActorHandleStateless { + pub fn new( + endpoint: &str, + params: Option, encoding_kind: EncodingKind, - parameters: Option, - ) -> Result { - Ok(Arc::new(Self { - endpoint: endpoint.clone(), - transport_kind, + query: ActorQuery + ) -> Self { + Self { + endpoint: endpoint.to_string(), + params, encoding_kind, - parameters, - driver: Mutex::new(None), - msg_queue: Mutex::new(Vec::new()), - rpc_counter: AtomicI64::new(0), - in_flight_rpcs: Mutex::new(HashMap::new()), - event_subscriptions: Mutex::new(HashMap::new()), - dc_watch: watch::channel(false), - disconnection_rx: Mutex::new(None), - })) - } - - fn is_disconnecting(self: &Arc) -> bool { - *self.dc_watch.1.borrow() == true - } - - async fn try_connect(self: &Arc) -> ConnectionAttempt { - let (driver, mut recver, task) = match self - .transport_kind - .connect(self.endpoint.clone(), self.encoding_kind, &self.parameters) - .await - { - Ok(a) => a, - Err(_) => { - // Either from immediate disconnect (local device connection refused) - // or from error like invalid URL - return ConnectionAttempt { - did_open: false, - _task_end_reason: DriverStopReason::TaskError, - }; - } - }; - - { - let mut my_driver = self.driver.lock().await; - *my_driver = Some(driver); - } - - let mut task_end_reason = task.map(|res| match res { - Ok(a) => a, - Err(task_err) => { - if task_err.is_cancelled() { - DriverStopReason::UserAborted - } else { - DriverStopReason::TaskError - } - } - }); - - let mut did_connection_open = false; - - // spawn listener for rpcs - let task_end_reason = loop { - tokio::select! { - reason = &mut task_end_reason => { - debug!("Connection closed: {:?}", reason); - - break reason; - }, - msg = recver.recv() => { - // If the sender is dropped, break the loop - let Some(msg) = msg else { - // break DriverStopReason::ServerDisconnect; - continue; - }; - - if let ToClientBody::Init { i: _ } = &msg.b { - did_connection_open = true; - } - - self.on_message(msg).await; - } - } - }; - - 'destroy_driver: { - let mut d_guard = self.driver.lock().await; - let Some(d) = d_guard.take() else { - // We destroyed the driver already, - // e.g. .disconnect() was called - break 'destroy_driver; - }; - - d.disconnect(); - } - - ConnectionAttempt { - did_open: did_connection_open, - _task_end_reason: task_end_reason, + query: RefCell::new(query) } } - pub(crate) async fn start_connection(self: &Arc) { - let (tx, rx) = oneshot::channel(); - - { - let mut stop_rx = self.disconnection_rx.lock().await; - if stop_rx.is_some() { - // Already doing connection_with_retry - // - this drops the oneshot - return; - } - - *stop_rx = Some(rx); - } - - let handle = self.clone(); - - tokio::spawn(async move { - 'keepalive: loop { - debug!("Attempting to reconnect"); - let mut backoff = Backoff::new(Duration::from_secs(1), Duration::from_secs(30)); - let mut retry_attempt = 0; - 'retry: loop { - retry_attempt += 1; - debug!( - "Establish conn: attempt={}, timeout={:?}", - retry_attempt, - backoff.delay() - ); - let attempt = handle.try_connect().await; - - if handle.is_disconnecting() { - break 'keepalive; - } - - if attempt.did_open { - break 'retry; - } - - let mut dc_rx = handle.dc_watch.0.subscribe(); - - tokio::select! { - _ = backoff.tick() => {}, - _ = dc_rx.wait_for(|x| *x == true) => { - break 'keepalive; - } - } - } - } - - tx.send(()).ok(); - handle.disconnection_rx.lock().await.take(); - }); - } - - async fn on_open(self: &Arc, init: &protocol::Init) { - debug!("Connected to server: {:?}", init); - - for (event_name, _) in self.event_subscriptions.lock().await.iter() { - self.send_subscription(event_name.clone(), true).await; + pub async fn action(&self, name: &str, args: Vec) -> Result { + #[derive(serde::Serialize)] + struct ActionRequest { + a: Vec, } - - // Flush message queue - for msg in self.msg_queue.lock().await.drain(..) { - // If its in the queue, it isn't ephemeral, so we pass - // default SendMsgOpts - self.send_msg(msg, SendMsgOpts::default()).await; + #[derive(serde::Deserialize)] + struct ActionResponse { + o: JsonValue, } - } - async fn on_message(self: &Arc, msg: Arc) { - let body = &msg.b; + let actor_query = serde_json::to_string(&self.query)?; - match body { - protocol::ToClientBody::Init { i: init } => { - self.on_open(init).await; - } - protocol::ToClientBody::ResponseOk { ro } => { - let id = ro.i; - let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; - let Some(tx) = in_flight_rpcs.remove(&id) else { - debug!("Unexpected response: rpc id not found"); - return; - }; - if let Err(e) = tx.send(Ok(ro.clone())) { - debug!("{:?}", e); - return; - } - } - protocol::ToClientBody::ResponseError { re } => { - let id = re.i; - let mut in_flight_rpcs = self.in_flight_rpcs.lock().await; - let Some(tx) = in_flight_rpcs.remove(&id) else { - debug!("Unexpected response: rpc id not found"); - return; - }; - if let Err(e) = tx.send(Err(re.clone())) { - debug!("{:?}", e); - return; - } - } - protocol::ToClientBody::EventMessage { ev } => { - let listeners = self.event_subscriptions.lock().await; - if let Some(callbacks) = listeners.get(&ev.n) { - for cb in callbacks { - cb(&ev.a); - } - } - } - protocol::ToClientBody::EventError { er } => { - debug!("Event error: {:?}", er); - } - } - } - - async fn send_msg(self: &Arc, msg: Arc, opts: SendMsgOpts) { - let guard = self.driver.lock().await; + // Build headers + let mut headers = vec![ + (HEADER_ENCODING, self.encoding_kind.to_string()), + (HEADER_ACTOR_QUERY, actor_query), + ]; - 'send_immediately: { - let Some(driver) = guard.deref() else { - break 'send_immediately; - }; - - let Ok(_) = driver.send(msg.clone()).await else { - break 'send_immediately; - }; - - return; + if let Some(params) = &self.params { + headers.push((HEADER_CONN_PARAMS, serde_json::to_string(params)?)); } - // Otherwise queue - if opts.ephemeral == false { - self.msg_queue.lock().await.push(msg.clone()); - } + let res = send_http_request::(HttpRequestOptions { + url: &format!( + "{}/actors/actions/{}", + self.endpoint, + url_encode(name) + ), + method: "POST", + headers, + body: Some(ActionRequest { + a: args, + }), + encoding_kind: self.encoding_kind, + }).await?; - return; + Ok(res.o) } - pub async fn action(self: &Arc, method: &str, params: Vec) -> Result { - let id: i64 = self.rpc_counter.fetch_add(1, Ordering::SeqCst); - - let (tx, rx) = oneshot::channel(); - self.in_flight_rpcs.lock().await.insert(id, tx); - - self.send_msg( - Arc::new(protocol::ToServer { - b: protocol::ToServerBody::RpcRequest { - rr: protocol::RpcRequest { - i: id, - n: method.to_string(), - a: params, - }, - }, - }), - SendMsgOpts::default(), - ) - .await; + pub async fn resolve(&self) -> Result { + let query = { + // None of this is async or runs on multithreads, + // it cannot fail given that both borrows are + // well contained, and cannot overlap. + let Ok(query) = self.query.try_borrow() else { + return Err(anyhow!("Failed to borrow actor query")); + }; - // TODO: Support reconnection - let Ok(res) = rx.await else { - // Verbosity - return Err(anyhow::anyhow!("Socket closed during rpc")); + query.clone() }; - match res { - Ok(ok) => Ok(ok.o), - Err(err) => { - let metadata = err.md.unwrap_or(Value::Null); + match query { + ActorQuery::Create { create: _query } => { + Err(anyhow!("actor query cannot be create")) + }, + ActorQuery::GetForId { get_for_id: query } => { + Ok(query.clone().actor_id) + }, + _ => { + let actor_id = resolve_actor_id( + &self.endpoint, + query, + self.encoding_kind + ).await?; + + { + let Ok(mut query) = self.query.try_borrow_mut() else { + // Following code will not run (see prior note) + return Err(anyhow!("Failed to borrow actor query mutably")); + }; + + *query = ActorQuery::GetForId { + get_for_id: GetForIdRequest { + actor_id: actor_id.clone(), + } + }; + } - Err(anyhow::anyhow!( - "RPC Error({}): {:?}, {:#}", - err.c, - err.m, - metadata - )) + Ok(actor_id) } } } +} - async fn send_subscription(self: &Arc, event_name: String, subscribe: bool) { - self.send_msg( - Arc::new(protocol::ToServer { - b: protocol::ToServerBody::SubscriptionRequest { - sr: protocol::SubscriptionRequest { - e: event_name, - s: subscribe, - }, - }, - }), - SendMsgOpts { ephemeral: true }, - ) - .await; - } - - async fn add_event_subscription( - self: &Arc, - event_name: String, - callback: Box, - ) { - // TODO: Support for once - let mut listeners = self.event_subscriptions.lock().await; - - let is_new_subscription = listeners.contains_key(&event_name) == false; - - listeners - .entry(event_name.clone()) - .or_insert(Vec::new()) - .push(callback); +pub struct ActorHandle { + handle: ActorHandleStateless, + endpoint: String, + params: Option, + query: ActorQuery, + client_shutdown_tx: Arc>, + transport_kind: crate::TransportKind, + encoding_kind: EncodingKind, +} - if is_new_subscription { - self.send_subscription(event_name, true).await; +impl ActorHandle { + pub fn new( + endpoint: &str, + params: Option, + query: ActorQuery, + client_shutdown_tx: Arc>, + transport_kind: TransportKind, + encoding_kind: EncodingKind + ) -> Self { + let handle = ActorHandleStateless::new( + endpoint, + params.clone(), + encoding_kind, + query.clone() + ); + + Self { + handle, + endpoint: endpoint.to_string(), + params, + query, + client_shutdown_tx, + transport_kind, + encoding_kind, } } - pub async fn on_event(self: &Arc, event_name: &str, callback: F) - where - F: Fn(&Vec) + Send + Sync + 'static, - { - self.add_event_subscription(event_name.to_string(), Box::new(callback)) - .await - } + pub fn connect(&self) -> ActorConnection { + let conn = ActorConnectionInner::new( + self.endpoint.clone(), + self.query.clone(), + self.transport_kind, + self.encoding_kind, + self.params.clone() + ); - pub async fn disconnect(self: &Arc) { - if self.is_disconnecting() { - // We are already disconnecting - return; - } - - self.dc_watch.0.send(true).ok(); - - if let Some(d) = self.driver.lock().await.deref() { - d.disconnect() - } - self.in_flight_rpcs.lock().await.clear(); - self.event_subscriptions.lock().await.clear(); - let Some(rx) = self.disconnection_rx.lock().await.take() else { - return; - }; + let rx = self.client_shutdown_tx.subscribe(); + start_connection(&conn, rx); - rx.await.ok(); + conn } } -impl Debug for ActorHandleInner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ActorHandle") - .field("endpoint", &self.endpoint) - .field("transport_kind", &self.transport_kind) - .field("encoding_kind", &self.encoding_kind) - .finish() +impl Deref for ActorHandle { + type Target = ActorHandleStateless; + + fn deref(&self) -> &Self::Target { + &self.handle } } \ No newline at end of file diff --git a/clients/rust/src/lib.rs b/clients/rust/src/lib.rs index c4122c154..0bc31def2 100644 --- a/clients/rust/src/lib.rs +++ b/clients/rust/src/lib.rs @@ -1,12 +1,10 @@ -// cargo test -- --nocapture - mod backoff; +mod common; pub mod client; pub mod drivers; -pub mod encoding; +pub mod connection; pub mod handle; pub mod protocol; -pub use client::{Client, CreateOptions, GetOptions, GetWithIdOptions}; -pub use drivers::TransportKind; -pub use encoding::EncodingKind; +pub use client::{Client, CreateOptions, GetOptions, GetOrCreateOptions, GetWithIdOptions}; +pub use common::{TransportKind, EncodingKind}; diff --git a/clients/rust/src/protocol.rs b/clients/rust/src/protocol.rs deleted file mode 100644 index 883496bb7..000000000 --- a/clients/rust/src/protocol.rs +++ /dev/null @@ -1,106 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -// Client-bound messages (ToClient) - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Init { - // Connection id - pub ci: String, - // Connection token - pub ct: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RpcResponseOk { - // Request id - pub i: i64, - // Output value - pub o: Value, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RpcResponseError { - // Request id - pub i: i64, - // Error code - pub c: String, - // Error message - pub m: String, - // Error metadata - pub md: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToClientEvent { - // Event name - pub n: String, - // Event arguments - pub a: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToClientError { - // Error code - pub c: String, - // Error message - pub m: String, - // Error metadata - pub md: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToClientBody { - // Initialize connection - Init { i: Init }, - // RPC response success - ResponseOk { ro: RpcResponseOk }, - // RPC response error - ResponseError { re: RpcResponseError }, - // Event message - EventMessage { ev: ToClientEvent }, - // Error message - EventError { er: ToClientError }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToClient { - // Message body - pub b: ToClientBody, -} - -// Server-bound messages (ToServer) - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RpcRequest { - // Request id - pub i: i64, - // Method name - pub n: String, - // Method arguments - pub a: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SubscriptionRequest { - // Event name - pub e: String, - // Subscribe (true) or unsubscribe (false) - pub s: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToServerBody { - // RPC request - RpcRequest { rr: RpcRequest }, - // Subscription request - SubscriptionRequest { sr: SubscriptionRequest }, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToServer { - // Message body - pub b: ToServerBody, -} diff --git a/clients/rust/src/protocol/mod.rs b/clients/rust/src/protocol/mod.rs new file mode 100644 index 000000000..bb4f0579d --- /dev/null +++ b/clients/rust/src/protocol/mod.rs @@ -0,0 +1,3 @@ +pub mod to_server; +pub mod to_client; +pub mod query; \ No newline at end of file diff --git a/clients/rust/src/protocol/query.rs b/clients/rust/src/protocol/query.rs new file mode 100644 index 000000000..88cd2849a --- /dev/null +++ b/clients/rust/src/protocol/query.rs @@ -0,0 +1,56 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +use crate::common::ActorKey; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateRequest { + pub name: String, + pub key: ActorKey, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub region: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetForKeyRequest { + pub name: String, + pub key: ActorKey, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetForIdRequest { + #[serde(rename = "actorId")] + pub actor_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetOrCreateRequest { + pub name: String, + pub key: ActorKey, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub region: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ActorQuery { + GetForId { + #[serde(rename = "getForId")] + get_for_id: GetForIdRequest, + }, + GetForKey { + #[serde(rename = "getForKey")] + get_for_key: GetForKeyRequest, + }, + GetOrCreateForKey { + #[serde(rename = "getOrCreateForKey")] + get_or_create_for_key: GetOrCreateRequest, + }, + Create { + create: CreateRequest, + }, +} \ No newline at end of file diff --git a/clients/rust/src/protocol/to_client.rs b/clients/rust/src/protocol/to_client.rs new file mode 100644 index 000000000..5fe67d484 --- /dev/null +++ b/clients/rust/src/protocol/to_client.rs @@ -0,0 +1,59 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +// Only called for SSE because we don't need this for WebSockets +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Init { + // Actor ID + pub ai: String, + // Connection ID + pub ci: String, + // Connection token + pub ct: String, +} + +// Used for connection errors (both during initialization and afterwards) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Error { + // Code + pub c: String, + // Message + pub m: String, + // Metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub md: Option, + // Action ID + #[serde(skip_serializing_if = "Option::is_none")] + pub ai: Option +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActionResponse { + // ID + pub i: i64, + // Output + pub o: JsonValue +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Event { + // Event name + pub n: String, + // Event arguments + pub a: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToClientBody { + Init { i: Init }, + Error { e: Error }, + ActionResponse { ar: ActionResponse }, + EventMessage { ev: Event }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToClient { + // Body + pub b: ToClientBody, +} \ No newline at end of file diff --git a/clients/rust/src/protocol/to_server.rs b/clients/rust/src/protocol/to_server.rs new file mode 100644 index 000000000..5f4a3c1d4 --- /dev/null +++ b/clients/rust/src/protocol/to_server.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Init { + // Conn Params + #[serde(skip_serializing_if = "Option::is_none")] + pub p: Option +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ActionRequest { + // ID + pub i: i64, + // Name + pub n: String, + // Args + pub a: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubscriptionRequest { + // Event name + pub e: String, + // Subscribe + pub s: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToServerBody { + Init { i: Init }, + ActionRequest { ar: ActionRequest }, + SubscriptionRequest { sr: SubscriptionRequest }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToServer { + pub b: ToServerBody, +} diff --git a/clients/rust/src/types.rs b/clients/rust/src/types.rs deleted file mode 100644 index e69de29bb..000000000 diff --git a/clients/rust/tests/e2e.rs b/clients/rust/tests/e2e.rs index 20909e849..b3cfe67a1 100644 --- a/clients/rust/tests/e2e.rs +++ b/clients/rust/tests/e2e.rs @@ -1,7 +1,8 @@ -use actor_core_client::{Client, EncodingKind, GetOptions, TransportKind}; +use actor_core_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; use fs_extra; use portpicker; use serde_json::json; +use tracing_subscriber::EnvFilter; use std::process::{Child, Command}; use std::time::Duration; use tempfile; @@ -168,6 +169,7 @@ async fn e2e() { // Configure logging let subscriber = tracing_subscriber::fmt() .with_max_level(tracing::Level::DEBUG) + // .with_env_filter(EnvFilter::new("actor_core_client=trace,hyper=error")) .finish(); let _guard = tracing::subscriber::set_default(subscriber); @@ -178,25 +180,26 @@ async fn e2e() { // Start the mock server let _server = MockServer::start(port).await; - // Wait for server to start info!("Waiting for server to start..."); sleep(Duration::from_secs(2)).await; // Create the client info!("Creating client to endpoint: {}", endpoint); - let client = Client::new(endpoint, TransportKind::WebSocket, EncodingKind::Cbor); - let counter = client.get("counter", GetOptions::default()).await.unwrap(); - counter - .on_event("newCount", |args| { - let new_count = args[0].as_i64().unwrap(); - println!("New count: {:?}", new_count); - }) - .await; + let client = Client::new(&endpoint, TransportKind::WebSocket, EncodingKind::Cbor); + let counter = client.get_or_create("counter", [].into(), GetOrCreateOptions::default()) + .unwrap(); + let conn = counter.connect(); + + conn.on_event("newCount", |x| { + info!("Received newCount event: {:?}", x); + }).await; let out = counter.action("increment", vec![json!(1)]).await.unwrap(); - println!("Action: {:?}", out); + info!("Action 1: {:?}", out); + let out = conn.action("increment", vec![json!(1)]).await.unwrap(); + info!("Action 2: {:?}", out); // Clean up - counter.disconnect().await; + client.disconnect(); } diff --git a/docs/clients/rust.mdx b/docs/clients/rust.mdx index fb97bf5a6..e85bd663e 100644 --- a/docs/clients/rust.mdx +++ b/docs/clients/rust.mdx @@ -4,12 +4,12 @@ icon: rust --- import MvpWarning from "/snippets/mvp-warning.mdx"; -import StepDefineWorker from "/snippets/step-define-worker.mdx"; +import StepDefineActor from "/snippets/step-define-actor.mdx"; import StepRunStudio from "/snippets/step-run-studio.mdx"; import StepDeploy from "/snippets/step-deploy.mdx"; import SetupNextSteps from "/snippets/setup-next-steps.mdx"; -The RivetKit Rust client provides a way to connect to and interact with workers from Rust applications. +The ActorCore Rust client provides a way to connect to and interact with actors from Rust applications. @@ -26,22 +26,22 @@ The RivetKit Rust client provides a way to connect to and interact with workers - Add RivetKit client & related dependencies to your project: + Add ActorCore client & related dependencies to your project: ```sh - cargo add rivetkit-client + cargo add actor-core-client cargo add serde_json cargo add tokio --features full ``` - + - Modify `src/main.rs` to connect to your worker: + Modify `src/main.rs` to connect to your actor: ```rust src/main.rs - use rivetkit_client::{Client, GetOptions, TransportKind, EncodingKind}; + use actor_core_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; use serde_json::json; use std::time::Duration; @@ -49,14 +49,15 @@ The RivetKit Rust client provides a way to connect to and interact with workers async fn main() -> Result<(), Box> { // Replace with your endpoint URL after deployment let client = Client::new( - "http://localhost:6420".to_string(), - TransportKind::WebSocket, - EncodingKind::Cbor, + "http://localhost:6420", + TransportKind::Sse, + EncodingKind::Json ); - // Get or create a worker instance - let options = GetOptions::default(); - let counter = client.get("counter", options).await?; + // Get or create an actor instance + let options = GetOrCreateOptions::default(); + let counter = client.get("counter", [].into(), options)? + .connect(); // Subscribe to events counter.on_event("newCount", |args| { diff --git a/docs/concepts/interacting-with-workers.mdx b/docs/concepts/interacting-with-workers.mdx index 5b3fc4a7d..a833fd0e3 100644 --- a/docs/concepts/interacting-with-workers.mdx +++ b/docs/concepts/interacting-with-workers.mdx @@ -1,17 +1,17 @@ --- -title: Interacting with Workers +title: Interacting with Actors icon: square-code --- -This guide covers how to connect to and interact with workers from client applications. +This guide covers how to connect to and interact with actors from client applications. ## Setting Up the Client -The first step is to create a client that will connect to your worker service: +The first step is to create a client that will connect to your actor service: ```typescript TypeScript -import { createClient } from "rivetkit/client"; +import { createClient } from "actor-core/client"; import type { App } from "../src/index"; // Create a client with the connection address and app type @@ -19,33 +19,33 @@ const client = createClient(/* CONNECTION ADDRESS */); ``` ```rust Rust -use rivetkit_client::{Client, TransportKind, EncodingKind}; +use actor_core_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; // Create a client with connection address and configuration let client = Client::new( - "http://localhost:6420".to_string(), // Connection address + "http://localhost:6420", // Connection address TransportKind::WebSocket, // Transport (WebSocket or SSE) EncodingKind::Cbor, // Encoding (Json or Cbor) ); ``` -```python Python (Callbacks) -from rivetkit_client import AsyncClient as WorkerClient +```python Python +from actor_core_client import AsyncClient as ActorClient # Create a client with the connection address -client = WorkerClient("http://localhost:6420") +client = ActorClient("http://localhost:6420") ``` See the setup guide for your platform for details on how to get the connection address. -## Finding & Connecting to Workers +## Finding & Connecting to Actors -RivetKit provides several methods to connect to workers: +ActorCore provides several methods to connect to actors: ### `get(tags, opts)` - Find or Create -The most common way to connect is with `get()`, which finds an existing worker matching the provided tags or creates a new one: +The most common way to connect is with `get()`, which finds an existing actor matching the provided tags or creates a new one: ```typescript TypeScript @@ -55,12 +55,12 @@ const room = await client.chatRoom.get({ channel: "general" }); -// Now you can call methods on the worker +// Now you can call methods on the actor await room.sendMessage("Alice", "Hello everyone!"); ``` ```rust Rust -use rivetkit_client::GetOptions; +use actor_core_client::GetOrCreateOptions; use serde_json::json; // Connect to a chat room for the "general" channel @@ -69,40 +69,39 @@ let tags = vec![ ("channel".to_string(), "general".to_string()), ]; -let mut options = GetOptions { +let mut options = GetOrCreateOptions { tags: Some(tags), ..Default::default() }; -let room = client.get("chatRoom", options) - .await - .expect("Failed to connect to chat room"); +let room = client.get("chatRoom", options)? + .connect(); -// Now you can call methods on the worker +// Now you can call methods on the actor room.action("sendMessage", vec![json!("Alice"), json!("Hello everyone!")]) .await .expect("Failed to send message"); ``` -```python Python (Callbacks) +```python Python # Connect to a chat room for the "general" channel room = await client.get("chatRoom", tags=[ ("name", "chat_room"), ("channel", "general") ]) -# Now you can call methods on the worker +# Now you can call methods on the actor await room.action("sendMessage", ["Alice", "Hello everyone!"]) ``` ### `create(opts)` - Explicitly Create New -When you specifically want to create a new worker instance: +When you specifically want to create a new actor instance: ```typescript TypeScript -// Create a new document worker +// Create a new document actor const doc = await client.myDocument.create({ create: { tags: { @@ -116,27 +115,21 @@ await doc.initializeDocument("My New Document"); ``` ```rust Rust -use rivetkit_client::{CreateOptions}; -use rivetkit_client::client::CreateRequestMetadata; +use actor_core_client::CreateOptions; +use actor_core_client::client::CreateRequestMetadata; use serde_json::json; -// Create a new document worker +// Create a new document actor let tags = vec![ ("name".to_string(), "my_document".to_string()), ("docId".to_string(), "123".to_string()), ]; -let create_options = CreateOptions { - params: None, - create: CreateRequestMetadata { - tags, - region: None, - }, -}; +let create_options = CreateOptions::default(); -let doc = client.create("myDocument", create_options) - .await - .expect("Failed to create document"); +let doc = client.create("myDocument", ["tags-or-keys"].into(), create_options) + .expect("Failed to create document") + .connect(); // Initialize the document doc.action("initializeDocument", vec![json!("My New Document")]) @@ -144,8 +137,8 @@ doc.action("initializeDocument", vec![json!("My New Document")]) .expect("Failed to initialize document"); ``` -```python Python (Callbacks) -# Create a new document worker +```python Python +# Create a new document actor doc = await client.get("myDocument", tags=[ ("name", "my_document"), ("docId", "123") @@ -157,28 +150,26 @@ await doc.action("initializeDocument", ["My New Document"]) ### `getWithId(id, opts)` - Connect by ID -Connect to a worker using its internal ID: +Connect to an actor using its internal ID: ```typescript TypeScript -// Connect to a specific worker by its ID -const myWorkerId = "55425f42-82f8-451f-82c1-6227c83c9372"; -const doc = await client.myDocument.getWithId(myWorkerId); +// Connect to a specific actor by its ID +const myActorId = "55425f42-82f8-451f-82c1-6227c83c9372"; +const doc = await client.myDocument.getWithId(myActorId); await doc.updateContent("Updated content"); ``` ```rust Rust -use rivetkit_client::GetWithIdOptions; +use actor_core_client::GetWithIdOptions; -// Connect to a specific worker by its ID -let my_worker_id = "55425f42-82f8-451f-82c1-6227c83c9372"; -let options = GetWithIdOptions { - params: None, -}; -let doc = client.get_with_id(my_worker_id, options) - .await - .expect("Failed to connect to document"); +// Connect to a specific actor by its ID +let my_actor_id = "55425f42-82f8-451f-82c1-6227c83c9372"; +let options = GetWithIdOptions::default(); +let doc = client.get_with_id(my_actor_id, [].into(), options) + .expect("Failed to get document") + .connect(); // Update content doc.action("updateContent", vec![json!("Updated content")]) @@ -186,22 +177,22 @@ doc.action("updateContent", vec![json!("Updated content")]) .expect("Failed to update document"); ``` -```python Python (Callbacks) -# Connect to a specific worker by its ID -my_worker_id = "55425f42-82f8-451f-82c1-6227c83c9372" -doc = await client.get_with_id(my_worker_id) +```python Python +# Connect to a specific actor by its ID +my_actor_id = "55425f42-82f8-451f-82c1-6227c83c9372" +doc = await client.get_with_id(my_actor_id) await doc.action("updateContent", ["Updated content"]) ``` -It's usually better to use tags for discovery rather than directly using worker IDs. +It's usually better to use tags for discovery rather than directly using actor IDs. ## Calling Actions -Once connected, calling worker actions are straightforward: +Once connected, calling actor actions are straightforward: ```typescript TypeScript @@ -245,7 +236,7 @@ game_room.action("updateSettings", vec![settings]) .expect("Failed to update settings"); ``` -```python Python (Callbacks) +```python Python # Call an action result = await math_utils.action("multiplyByTwo", [5]) print(result) # 10 @@ -263,12 +254,12 @@ await game_room.action("updateSettings", [{ -All worker action calls are asynchronous and require `await`, even if the worker's action is not async. +All actor action calls are asynchronous and require `await`, even if the actor's action is not async. ## Listening for Events -Workers can send realtime updates to clients using events: +Actors can send realtime updates to clients using events: ### `on(eventName, callback)` - Continuous Listening @@ -309,7 +300,7 @@ game_room.on_event("stateUpdate", move |args| { }).await; ``` -```python Python (Callbacks) +```python Python # Listen for new chat messages def handle_message(message): sender = message["sender"] @@ -335,7 +326,7 @@ For events you only need to hear once: ```typescript TypeScript // Listen for when a request is approved -worker.once("requestApproved", () => { +actor.once("requestApproved", () => { showApprovalNotification(); unlockFeatures(); }); @@ -345,13 +336,13 @@ worker.once("requestApproved", () => { // `once` is not implemented in Rust ``` -```python Python (Callbacks) +```python Python # Listen for when a request is approved def handle_approval(): show_approval_notification() unlock_features() -worker.on_event("requestApproved", handle_approval) +actor.on_event("requestApproved", handle_approval) ``` @@ -374,10 +365,10 @@ const chatRoom = await client.chatRoom.get({ channel: "super-secret" }, { ```rust Rust use serde_json::json; -use rivetkit_client::GetOptions; +use actor_core_client::GetOptions; -let tags = vec![ - ("channel".to_string(), "super-secret".to_string()), +let key = vec![ + "super-secret-channel".to_string(), ]; let params = json!({ @@ -387,18 +378,15 @@ let params = json!({ }); let options = GetOptions { - tags: Some(tags), params: Some(params), - no_create: None, - create: None, }; -let chat_room = client.get("chatRoom", options) - .await - .expect("Failed to connect to chat room"); +let chat_room = client.get("chatRoom", key, options) + .expect("Failed to get chat room") + .connect(); ``` -```python Python (Callbacks) +```python Python chat_room = await client.get( "chatRoom", tags=[("channel", "super-secret")], @@ -411,12 +399,12 @@ chat_room = await client.get( ``` -The worker can access these parameters in the `onBeforeConnect` or `createConnState` hook: +The actor can access these parameters in the `onBeforeConnect` or `createConnState` hook: ```typescript -import { worker } from "rivetkit"; +import { actor } from "actor-core"; -const chatRoom = worker({ +const chatRoom = actor({ state: { messages: [] }, createConnState: (c, { params }) => { @@ -442,7 +430,7 @@ Read more about [connection parameters](/concepts/connections). #### `opts.noCreate` -Connect only if a worker exists, without creating a new one: +Connect only if an actor exists, without creating a new one: ```typescript try { @@ -459,7 +447,7 @@ try { ```typescript TypeScript // Example with all client options const client = createClient( - "https://workers.example.com", + "https://actors.example.com", { // Data serialization format encoding: "cbor", // or "json" @@ -471,11 +459,11 @@ const client = createClient( ``` ```rust Rust -use rivetkit_client::{Client, TransportKind, EncodingKind}; +use actor_core_client::{Client, EncodingKind, GetOrCreateOptions, TransportKind}; // Create client with specific options let client = Client::new( - "https://workers.example.com".to_string(), + "https://actors.example.com", TransportKind::WebSocket, // or TransportKind::Sse EncodingKind::Cbor, // or EncodingKind::Json ); @@ -483,12 +471,12 @@ let client = Client::new( // Rust does not support accepting multiple transports ``` -```python Python (Callbacks) -from rivetkit_client import AsyncClient as WorkerClient +```python Python +from actor_core_client import AsyncClient as ActorClient # Example with all client options -client = WorkerClient( - "https://workers.example.com", +client = ActorClient( + "https://actors.example.com", "websocket" # or "sse" "cbor", # or "json" ) @@ -510,7 +498,7 @@ Specifies the data encoding format used for communication: `("websocket" | "sse")[]` (optional) -Configures which network transport mechanisms the client will use to communicate with workers, sorted by priority: +Configures which network transport mechanisms the client will use to communicate with actors, sorted by priority: - `"websocket"`: Real-time bidirectional communication, best for most applications - `"sse"` (Server-Sent Events): Works in more restricted environments where WebSockets may be blocked @@ -519,7 +507,7 @@ Default is `["websocket", "sse"]`, which automatically negotiates the best avail ## Error Handling -RivetKit provides specific error types to help you handle different failure scenarios: +ActorCore provides specific error types to help you handle different failure scenarios: ### Action Errors @@ -527,7 +515,7 @@ When an action fails, it throws an error with details about the failure: ```typescript try { - await worker.someAction(); + await actor.someAction(); } catch (error) { console.error(`Action failed: ${error.code} - ${error.message}`); // Handle specific error codes @@ -537,12 +525,12 @@ try { } ``` -These errors can be thrown from within the worker with `UserError`: +These errors can be thrown from within the actor with `UserError`: ```typescript -import { worker, UserError } from "rivetkit"; +import { actor, UserError } from "actor-core"; -const documentWorker = worker({ +const documentActor = actor({ state: { content: "" }, actions: { @@ -561,14 +549,14 @@ const documentWorker = worker({ }); ``` -RivetKit doesn't expose internal errors to clients for security, helping to prevent the exposure of sensitive information or internal implementation details. +ActorCore doesn't expose internal errors to clients for security, helping to prevent the exposure of sensitive information or internal implementation details. ### Other Errors Other common errors you might encounter: -- `InternalError`: Error from your worker that's not a subclass of `UserError` -- `ManagerError`: Issues when connecting to or communicating with the worker manager +- `InternalError`: Error from your actor that's not a subclass of `UserError` +- `ManagerError`: Issues when connecting to or communicating with the actor manager ## Disconnecting and Cleanup @@ -578,25 +566,25 @@ If you need to explicitly disconnect: ```typescript TypeScript -// Disconnect from the worker -await worker.dispose(); +// Disconnect from the actor +await actor.dispose(); // Disconnect the entire client await client.dispose(); ``` ```rust Rust -// Disconnect from the worker -worker.disconnect().await; +// Disconnect from the actor +actor.disconnect().await; // The client will be cleaned up automatically when it goes out of scope // Or explicitly drop it with: drop(client); ``` -```python Python (Callbacks) -# Disconnect from the worker -await worker.disconnect() +```python Python +# Disconnect from the actor +await actor.disconnect() ``` @@ -612,13 +600,13 @@ This makes your applications resilient to temporary network failures without any - Learn how state works in workers + Learn how state works in actors - Learn more about worker events + Learn more about actor events - Add security to your workers + Add security to your actors Manage client connections