diff --git a/engine/src/workers/stream/adapters/kv_store.rs b/engine/src/workers/stream/adapters/kv_store.rs index 6e58e15ea..9dcebbbed 100644 --- a/engine/src/workers/stream/adapters/kv_store.rs +++ b/engine/src/workers/stream/adapters/kv_store.rs @@ -20,7 +20,10 @@ use crate::{ engine::Engine, workers::stream::{ StreamMetadata, StreamWrapperMessage, - adapters::{StreamAdapter, StreamConnection}, + adapters::{ + StreamAdapter, StreamConnection, parse_stream_storage_key, stream_storage_key, + stream_storage_prefix, + }, registry::{StreamAdapterFuture, StreamAdapterRegistration}, }, }; @@ -47,7 +50,7 @@ impl BuiltinKvStoreAdapter { } fn gen_key(&self, stream_name: &str, group_id: &str) -> String { - format!("stream:{}:{}", stream_name, group_id) + stream_storage_key(stream_name, group_id) } } @@ -123,7 +126,7 @@ impl StreamAdapter for BuiltinKvStoreAdapter { } async fn list_groups(&self, stream_name: &str) -> anyhow::Result<Vec<String>> { - let prefix = self.gen_key(stream_name, ""); + let prefix = stream_storage_prefix(stream_name); Ok(self .storage @@ -145,13 +148,10 @@ impl StreamAdapter for BuiltinKvStoreAdapter { // Parse keys to extract stream names and groups for key in all_keys { - let parts: Vec<&str> = key.split(':').collect(); - // Ensure key follows format "stream:<stream_name>:<group_id>" - if parts.len() >= 3 && parts[0] == "stream" { - let stream_name = parts[1].to_string(); - let group_id = parts[2].to_string(); - + if let Some((stream_name, group_id)) = parse_stream_storage_key(&key) { stream_map.entry(stream_name).or_default().insert(group_id); + } else if key.starts_with("stream:") { + tracing::warn!(key = %key, "Skipping unparseable stream storage key"); } } @@ -268,6 +268,34 @@ mod tests { assert_eq!(metadata[1].groups, vec!["default".to_string()]); } + #[tokio::test] + async fn list_all_stream_preserves_stream_names_with_colons() { + let adapter = BuiltinKvStoreAdapter::new(None); + adapter + .set("orders:v2", "region:us", "item-1", json!({ "value": 1 })) + .await + .unwrap(); + 
adapter + .set("orders:v2", "region:eu", "item-2", json!({ "value": 2 })) + .await + .unwrap(); + + let mut groups = adapter.list_groups("orders:v2").await.unwrap(); + groups.sort(); + assert_eq!( + groups, + vec!["region:eu".to_string(), "region:us".to_string()] + ); + + let metadata = adapter.list_all_stream().await.unwrap(); + assert_eq!(metadata.len(), 1); + assert_eq!(metadata[0].id, "orders:v2"); + assert_eq!( + metadata[0].groups, + vec!["region:eu".to_string(), "region:us".to_string()] + ); + } + #[tokio::test] async fn subscribe_emit_event_and_unsubscribe_round_trip() { let adapter = Arc::new(BuiltinKvStoreAdapter::new(None)); diff --git a/engine/src/workers/stream/adapters/mod.rs b/engine/src/workers/stream/adapters/mod.rs index 542b518eb..e63f56298 100644 --- a/engine/src/workers/stream/adapters/mod.rs +++ b/engine/src/workers/stream/adapters/mod.rs @@ -11,6 +11,7 @@ pub mod redis_adapter; use std::sync::Arc; use async_trait::async_trait; +use base64::Engine; use iii_sdk::{ UpdateOp, UpdateResult, types::{DeleteResult, SetResult}, @@ -22,6 +23,53 @@ use crate::{ workers::stream::{StreamMetadata, StreamWrapperMessage}, }; +const STREAM_KEY_PREFIX: &str = "stream:"; +const ENCODED_STREAM_NAME_PREFIX: &str = "b64~"; + +pub(super) fn stream_storage_key(stream_name: &str, group_id: &str) -> String { + format!( + "{}{}:{}", + STREAM_KEY_PREFIX, + encode_stream_name_segment(stream_name), + group_id + ) +} + +pub(super) fn stream_storage_prefix(stream_name: &str) -> String { + format!( + "{}{}:", + STREAM_KEY_PREFIX, + encode_stream_name_segment(stream_name) + ) +} + +pub(super) fn parse_stream_storage_key(key: &str) -> Option<(String, String)> { + let rest = key.strip_prefix(STREAM_KEY_PREFIX)?; + let (stream_name, group_id) = rest.split_once(':')?; + let stream_name = decode_stream_name_segment(stream_name)?; + Some((stream_name, group_id.to_string())) +} + +fn encode_stream_name_segment(stream_name: &str) -> String { + if stream_name.contains(':') || 
stream_name.starts_with(ENCODED_STREAM_NAME_PREFIX) { + let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(stream_name); + format!("{ENCODED_STREAM_NAME_PREFIX}{encoded}") + } else { + stream_name.to_string() + } +} + +fn decode_stream_name_segment(segment: &str) -> Option<String> { + if let Some(encoded) = segment.strip_prefix(ENCODED_STREAM_NAME_PREFIX) { + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(encoded) + .ok()?; + String::from_utf8(decoded).ok() + } else { + Some(segment.to_string()) + } +} + #[async_trait] pub trait StreamAdapter: Send + Sync { async fn set( @@ -84,3 +132,36 @@ pub trait StreamConnection: Subscriber + Send + Sync { /// This is the optimized path - deserialize once, call many times. async fn handle_stream_message(&self, msg: &StreamWrapperMessage) -> anyhow::Result<()>; } + +#[cfg(test)] +mod tests { + use super::{parse_stream_storage_key, stream_storage_key, stream_storage_prefix}; + + #[test] + fn stream_storage_key_preserves_group_ids_with_colons() { + let key = stream_storage_key("orders", "region:us"); + let parsed = parse_stream_storage_key(&key).unwrap(); + assert_eq!(parsed.0, "orders"); + assert_eq!(parsed.1, "region:us"); + } + + #[test] + fn stream_storage_key_encodes_stream_names_with_colons() { + let key = stream_storage_key("orders:v2", "region:us"); + let parsed = parse_stream_storage_key(&key).unwrap(); + assert_eq!(parsed.0, "orders:v2"); + assert_eq!(parsed.1, "region:us"); + assert_eq!( + stream_storage_prefix("orders:v2"), + "stream:b64~b3JkZXJzOnYy:" + ); + } + + #[test] + fn stream_storage_key_encodes_names_that_start_with_reserved_prefix() { + let key = stream_storage_key("b64~orders", "default"); + let parsed = parse_stream_storage_key(&key).unwrap(); + assert_eq!(parsed.0, "b64~orders"); + assert_eq!(parsed.1, "default"); + } +} diff --git a/engine/src/workers/stream/adapters/redis_adapter.rs b/engine/src/workers/stream/adapters/redis_adapter.rs index 05151302f..ec06d3976 
100644 --- a/engine/src/workers/stream/adapters/redis_adapter.rs +++ b/engine/src/workers/stream/adapters/redis_adapter.rs @@ -21,7 +21,10 @@ use crate::{ redis::DEFAULT_REDIS_CONNECTION_TIMEOUT, stream::{ StreamMetadata, StreamWrapperMessage, - adapters::{StreamAdapter, StreamConnection}, + adapters::{ + StreamAdapter, StreamConnection, parse_stream_storage_key, stream_storage_key, + stream_storage_prefix, + }, registry::{StreamAdapterFuture, StreamAdapterRegistration}, }, }, @@ -78,7 +81,7 @@ impl StreamAdapter for RedisAdapter { ops: Vec<UpdateOp>, ) -> anyhow::Result<UpdateResult> { let mut conn = self.publisher.lock().await; - let key = format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(stream_name, group_id); // Serialize operations to JSON let ops_json = serde_json::to_string(&ops) @@ -270,7 +273,7 @@ impl StreamAdapter for RedisAdapter { item_id: &str, data: Value, ) -> anyhow::Result<SetResult> { - let key: String = format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(stream_name, group_id); let mut conn = self.publisher.lock().await; let value = serde_json::to_string(&data).unwrap_or_default(); @@ -321,7 +324,7 @@ impl StreamAdapter for RedisAdapter { group_id: &str, item_id: &str, ) -> anyhow::Result<Option<Value>> { - let key = format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(stream_name, group_id); let mut conn = self.publisher.lock().await; match conn.hget::<_, _, Option<String>>(&key, &item_id).await { @@ -343,7 +346,7 @@ impl StreamAdapter for RedisAdapter { let group_id = group_id.to_string(); let item_id = item_id.to_string(); - let key = format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(&stream_name, &group_id); // Use Lua script for atomic get-and-delete operation // This script atomically gets the old value and deletes the field @@ -380,7 +383,7 @@ impl StreamAdapter for RedisAdapter { } async fn get_group(&self, stream_name: &str, group_id: &str) -> anyhow::Result<Vec<Value>> { - let key = 
format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(stream_name, group_id); let mut conn = self.publisher.lock().await; match conn.hgetall::<HashMap<String, String>>(key).await { @@ -405,8 +408,8 @@ impl StreamAdapter for RedisAdapter { async fn list_groups(&self, stream_name: &str) -> anyhow::Result<Vec<String>> { let mut conn = self.publisher.lock().await; - let pattern = format!("stream:{}:*", stream_name); - let prefix = format!("stream:{}:", stream_name); + let prefix = stream_storage_prefix(stream_name); + let pattern = format!("{prefix}*"); match conn.keys::<_, Vec<String>>(pattern).await { Ok(keys) => Ok(keys @@ -437,14 +440,11 @@ impl StreamAdapter for RedisAdapter { let (next_cursor, keys) = result; - // Parse keys: stream:<stream_name>:<group_id> for key in keys { - let parts: Vec<&str> = key.split(':').collect(); - if parts.len() >= 3 && parts[0] == "stream" { - let stream_name = parts[1].to_string(); - let group_id = parts[2].to_string(); - + if let Some((stream_name, group_id)) = parse_stream_storage_key(&key) { stream_map.entry(stream_name).or_default().insert(group_id); + } else if key.starts_with("stream:") { + tracing::warn!(key = %key, "Skipping unparseable stream storage key"); } } @@ -565,7 +565,7 @@ impl RedisAdapter { ops: Vec<UpdateOp>, ) -> anyhow::Result<UpdateResult> { let mut conn = self.publisher.lock().await; - let key = format!("stream:{}:{}", stream_name, group_id); + let key = stream_storage_key(stream_name, group_id); // Simple atomic get-and-set approach // Get old value @@ -676,3 +676,49 @@ fn make_adapter(_engine: Arc<Engine>, config: Option<Value>) -> StreamAdapterFut } crate::register_adapter!( name: "redis", make_adapter); + +#[cfg(test)] +mod tests { + use serde_json::json; + use uuid::Uuid; + + use super::*; + + async fn setup_test_adapter() -> RedisAdapter { + let redis_url = "redis://localhost:6379".to_string(); + RedisAdapter::new(redis_url).await.unwrap() + } + + #[tokio::test] + #[ignore = "Requires Redis running"] + async fn list_all_stream_preserves_stream_names_with_colons_redis() { + let 
adapter = setup_test_adapter().await; + let stream_name = format!("orders:{}:v2", Uuid::new_v4()); + + let _ = adapter.delete(&stream_name, "region:us", "item-1").await; + let _ = adapter.delete(&stream_name, "region:eu", "item-2").await; + + adapter + .set(&stream_name, "region:us", "item-1", json!({ "value": 1 })) + .await + .unwrap(); + adapter + .set(&stream_name, "region:eu", "item-2", json!({ "value": 2 })) + .await + .unwrap(); + + let metadata = adapter.list_all_stream().await.unwrap(); + let entry = metadata + .into_iter() + .find(|stream| stream.id == stream_name) + .expect("stream with ':' in the name should be listed"); + + assert_eq!( + entry.groups, + vec!["region:eu".to_string(), "region:us".to_string()] + ); + + let _ = adapter.delete(&stream_name, "region:us", "item-1").await; + let _ = adapter.delete(&stream_name, "region:eu", "item-2").await; + } +}