Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions engine/src/workers/stream/adapters/kv_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use crate::{
engine::Engine,
workers::stream::{
StreamMetadata, StreamWrapperMessage,
adapters::{StreamAdapter, StreamConnection},
adapters::{
StreamAdapter, StreamConnection, parse_stream_storage_key, stream_storage_key,
stream_storage_prefix,
},
registry::{StreamAdapterFuture, StreamAdapterRegistration},
},
};
Expand All @@ -47,7 +50,7 @@ impl BuiltinKvStoreAdapter {
}

/// Builds the storage key for a `stream_name`/`group_id` pair by
/// delegating to the shared helper, so key encoding stays consistent
/// with the other stream adapters.
fn gen_key(&self, stream_name: &str, group_id: &str) -> String {
    stream_storage_key(stream_name, group_id)
}
}

Expand Down Expand Up @@ -123,7 +126,7 @@ impl StreamAdapter for BuiltinKvStoreAdapter {
}

async fn list_groups(&self, stream_name: &str) -> anyhow::Result<Vec<String>> {
let prefix = self.gen_key(stream_name, "");
let prefix = stream_storage_prefix(stream_name);

Ok(self
.storage
Expand All @@ -145,12 +148,7 @@ impl StreamAdapter for BuiltinKvStoreAdapter {

// Parse keys to extract stream names and groups
for key in all_keys {
let parts: Vec<&str> = key.split(':').collect();
// Ensure key follows format "stream:<stream_name>:<group_id>"
if parts.len() >= 3 && parts[0] == "stream" {
let stream_name = parts[1].to_string();
let group_id = parts[2].to_string();

if let Some((stream_name, group_id)) = parse_stream_storage_key(&key) {
stream_map.entry(stream_name).or_default().insert(group_id);
}
}
Expand Down Expand Up @@ -268,6 +266,34 @@ mod tests {
assert_eq!(metadata[1].groups, vec!["default".to_string()]);
}

#[tokio::test]
async fn list_all_stream_preserves_stream_names_with_colons() {
    // A stream whose name contains ':' must round-trip through the
    // encoded storage keys without being split at the separator.
    let adapter = BuiltinKvStoreAdapter::new(None);
    let seed = [("region:us", "item-1", 1), ("region:eu", "item-2", 2)];
    for (group, item, value) in seed {
        adapter
            .set("orders:v2", group, item, json!({ "value": value }))
            .await
            .unwrap();
    }

    let expected_groups = vec!["region:eu".to_string(), "region:us".to_string()];

    let mut groups = adapter.list_groups("orders:v2").await.unwrap();
    groups.sort();
    assert_eq!(groups, expected_groups);

    let metadata = adapter.list_all_stream().await.unwrap();
    assert_eq!(metadata.len(), 1);
    assert_eq!(metadata[0].id, "orders:v2");
    assert_eq!(metadata[0].groups, expected_groups);
}

#[tokio::test]
async fn subscribe_emit_event_and_unsubscribe_round_trip() {
let adapter = Arc::new(BuiltinKvStoreAdapter::new(None));
Expand Down
81 changes: 81 additions & 0 deletions engine/src/workers/stream/adapters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod redis_adapter;
use std::sync::Arc;

use async_trait::async_trait;
use base64::Engine;
use iii_sdk::{
UpdateOp, UpdateResult,
types::{DeleteResult, SetResult},
Expand All @@ -22,6 +23,53 @@ use crate::{
workers::stream::{StreamMetadata, StreamWrapperMessage},
};

// Leading marker shared by every stream storage key: `stream:<name>:<group>`.
const STREAM_KEY_PREFIX: &str = "stream:";
// Marker prepended to base64-encoded stream-name segments; names that
// contain ':' (or that start with this marker) are encoded to keep the
// key layout unambiguous.
const ENCODED_STREAM_NAME_PREFIX: &str = "b64~";

/// Composes the full storage key `stream:<encoded-name>:<group_id>` for a
/// stream/group pair. The stream-name segment is escaped first so that
/// names containing ':' cannot be confused with the group separator.
pub(super) fn stream_storage_key(stream_name: &str, group_id: &str) -> String {
    let name_segment = encode_stream_name_segment(stream_name);
    let mut key =
        String::with_capacity(STREAM_KEY_PREFIX.len() + name_segment.len() + 1 + group_id.len());
    key.push_str(STREAM_KEY_PREFIX);
    key.push_str(&name_segment);
    key.push(':');
    key.push_str(group_id);
    key
}

/// Returns the key prefix covering every group of `stream_name`,
/// i.e. `stream:<encoded-name>:` — suitable for prefix scans.
pub(super) fn stream_storage_prefix(stream_name: &str) -> String {
    let name_segment = encode_stream_name_segment(stream_name);
    format!("{STREAM_KEY_PREFIX}{name_segment}:")
}

/// Inverse of `stream_storage_key`: splits a storage key back into
/// `(stream_name, group_id)`. Returns `None` for keys that do not follow
/// the expected layout or whose name segment fails to decode. Only the
/// first ':' after the name segment splits name from group, so group ids
/// may themselves contain ':'.
pub(super) fn parse_stream_storage_key(key: &str) -> Option<(String, String)> {
    let remainder = key.strip_prefix(STREAM_KEY_PREFIX)?;
    let (encoded_name, group_id) = remainder.split_once(':')?;
    decode_stream_name_segment(encoded_name).map(|name| (name, group_id.to_owned()))
}

/// Escapes a stream name for use as a key segment. Names containing the
/// ':' separator — or names that would collide with the escape marker
/// itself — are base64url-encoded (no padding) and tagged with `b64~`;
/// all other names pass through unchanged.
fn encode_stream_name_segment(stream_name: &str) -> String {
    let needs_encoding =
        stream_name.contains(':') || stream_name.starts_with(ENCODED_STREAM_NAME_PREFIX);
    if !needs_encoding {
        return stream_name.to_owned();
    }
    let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(stream_name);
    format!("{ENCODED_STREAM_NAME_PREFIX}{encoded}")
}

/// Undoes `encode_stream_name_segment`. Plain segments are returned
/// as-is; `b64~`-tagged segments are base64url-decoded. Returns `None`
/// when the payload is not valid base64 or not valid UTF-8.
fn decode_stream_name_segment(segment: &str) -> Option<String> {
    match segment.strip_prefix(ENCODED_STREAM_NAME_PREFIX) {
        None => Some(segment.to_owned()),
        Some(encoded) => {
            let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
                .decode(encoded)
                .ok()?;
            String::from_utf8(bytes).ok()
        }
    }
}

#[async_trait]
pub trait StreamAdapter: Send + Sync {
async fn set(
Expand Down Expand Up @@ -84,3 +132,36 @@ pub trait StreamConnection: Subscriber + Send + Sync {
/// This is the optimized path - deserialize once, call many times.
async fn handle_stream_message(&self, msg: &StreamWrapperMessage) -> anyhow::Result<()>;
}

#[cfg(test)]
mod tests {
    use super::{parse_stream_storage_key, stream_storage_key, stream_storage_prefix};

    // Group ids may contain ':' — only the first separator after the
    // stream-name segment splits name from group.
    #[test]
    fn stream_storage_key_preserves_group_ids_with_colons() {
        let key = stream_storage_key("orders", "region:us");
        let (name, group) = parse_stream_storage_key(&key).unwrap();
        assert_eq!(name, "orders");
        assert_eq!(group, "region:us");
    }

    // Stream names containing ':' are base64url-encoded in the key and
    // restored on parse; the scan prefix uses the encoded form too.
    #[test]
    fn stream_storage_key_encodes_stream_names_with_colons() {
        let key = stream_storage_key("orders:v2", "region:us");
        let (name, group) = parse_stream_storage_key(&key).unwrap();
        assert_eq!(name, "orders:v2");
        assert_eq!(group, "region:us");
        assert_eq!(
            stream_storage_prefix("orders:v2"),
            "stream:b64~b3JkZXJzOnYy:"
        );
    }

    // Names that already start with the reserved `b64~` marker must be
    // encoded as well, so they cannot be mistaken for an encoded segment.
    #[test]
    fn stream_storage_key_encodes_names_that_start_with_reserved_prefix() {
        let key = stream_storage_key("b64~orders", "default");
        let (name, group) = parse_stream_storage_key(&key).unwrap();
        assert_eq!(name, "b64~orders");
        assert_eq!(group, "default");
    }
}
28 changes: 13 additions & 15 deletions engine/src/workers/stream/adapters/redis_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use crate::{
redis::DEFAULT_REDIS_CONNECTION_TIMEOUT,
stream::{
StreamMetadata, StreamWrapperMessage,
adapters::{StreamAdapter, StreamConnection},
adapters::{
StreamAdapter, StreamConnection, parse_stream_storage_key, stream_storage_key,
stream_storage_prefix,
},
registry::{StreamAdapterFuture, StreamAdapterRegistration},
},
},
Expand Down Expand Up @@ -78,7 +81,7 @@ impl StreamAdapter for RedisAdapter {
ops: Vec<UpdateOp>,
) -> anyhow::Result<UpdateResult> {
let mut conn = self.publisher.lock().await;
let key = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(stream_name, group_id);

// Serialize operations to JSON
let ops_json = serde_json::to_string(&ops)
Expand Down Expand Up @@ -270,7 +273,7 @@ impl StreamAdapter for RedisAdapter {
item_id: &str,
data: Value,
) -> anyhow::Result<SetResult> {
let key: String = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(stream_name, group_id);
let mut conn = self.publisher.lock().await;
let value = serde_json::to_string(&data).unwrap_or_default();

Expand Down Expand Up @@ -321,7 +324,7 @@ impl StreamAdapter for RedisAdapter {
group_id: &str,
item_id: &str,
) -> anyhow::Result<Option<Value>> {
let key = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(stream_name, group_id);
let mut conn = self.publisher.lock().await;

match conn.hget::<_, _, Option<String>>(&key, &item_id).await {
Expand All @@ -343,7 +346,7 @@ impl StreamAdapter for RedisAdapter {
let group_id = group_id.to_string();
let item_id = item_id.to_string();

let key = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(&stream_name, &group_id);

// Use Lua script for atomic get-and-delete operation
// This script atomically gets the old value and deletes the field
Expand Down Expand Up @@ -380,7 +383,7 @@ impl StreamAdapter for RedisAdapter {
}

async fn get_group(&self, stream_name: &str, group_id: &str) -> anyhow::Result<Vec<Value>> {
let key = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(stream_name, group_id);
let mut conn = self.publisher.lock().await;

match conn.hgetall::<String, HashMap<String, String>>(key).await {
Expand All @@ -405,8 +408,8 @@ impl StreamAdapter for RedisAdapter {

async fn list_groups(&self, stream_name: &str) -> anyhow::Result<Vec<String>> {
let mut conn = self.publisher.lock().await;
let pattern = format!("stream:{}:*", stream_name);
let prefix = format!("stream:{}:", stream_name);
let prefix = stream_storage_prefix(stream_name);
let pattern = format!("{prefix}*");

match conn.keys::<_, Vec<String>>(pattern).await {
Ok(keys) => Ok(keys
Expand Down Expand Up @@ -437,13 +440,8 @@ impl StreamAdapter for RedisAdapter {

let (next_cursor, keys) = result;

// Parse keys: stream:<stream_name>:<group_id>
for key in keys {
let parts: Vec<&str> = key.split(':').collect();
if parts.len() >= 3 && parts[0] == "stream" {
let stream_name = parts[1].to_string();
let group_id = parts[2].to_string();

if let Some((stream_name, group_id)) = parse_stream_storage_key(&key) {
stream_map.entry(stream_name).or_default().insert(group_id);
}
}
Expand Down Expand Up @@ -565,7 +563,7 @@ impl RedisAdapter {
ops: Vec<UpdateOp>,
) -> anyhow::Result<UpdateResult> {
let mut conn = self.publisher.lock().await;
let key = format!("stream:{}:{}", stream_name, group_id);
let key = stream_storage_key(stream_name, group_id);

// Simple atomic get-and-set approach
// Get old value
Expand Down
Loading