Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to emit responses #17

Merged
merged 5 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ tokio-tungstenite = ["dep:tokio-tungstenite"]
serde_json = ["dep:serde_json"]

[dependencies]
async-stream = "0.3.6"
axum = { version = "0.8.1", features = ["ws"] }
futures = "0.3.31"
serde_json = { version = "1.0.133", optional = true }
Expand Down
138 changes: 120 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
//! API slightly based off wiremock in that you start a server
use crate::responder::{pending, ResponseStream};
use crate::responder::{pending, MapResponder, ResponseStream, StreamResponse};
use crate::utils::*;
use axum::{
extract::{
ws::{Message as AxumMessage, WebSocket, WebSocketUpgrade},
ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage, WebSocket, WebSocketUpgrade},
Path, Query,
},
http::header::HeaderMap,
response::Response,
routing::any,
Extension, Router,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::collections::HashMap;
use std::future::IntoFuture;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use std::time::Instant;
use tokio::sync::{oneshot, RwLock};
use tracing::{debug, Instrument};
use tokio::sync::{broadcast, oneshot, RwLock};
use tracing::{debug, error, Instrument};
use tungstenite::{
protocol::{frame::Utf8Bytes, CloseFrame},
Message,
Expand All @@ -32,6 +33,7 @@ pub mod utils;
pub mod prelude {
pub use super::*;
pub use crate::matchers::*;
pub use crate::responder::*;
}

type MockList = Arc<RwLock<Vec<Mock>>>;
Expand Down Expand Up @@ -69,10 +71,34 @@ enum MatchStatus {
///
/// With that in mind how do we select which mock is active when we have matchers that act on the
/// initial request parameters and ones which act on the received messages? Our matching parameter
/// is already more complicated being a `Option<bool>` to handle the case where there's no
/// path/header matching and only body matching.
/// is already more complicated to handle the case where there's no path/header matching and only
/// body matching.
///
/// ## The Plan
///
/// The plan is to move matching from a simple "yes/no/undetermined" to a 4 state system:
///
/// 1. Mismatch: One or more matchers is false
/// 2. Potential: No matchers match but none have rejected the request
/// 3. Partial: Some matchers match the request
/// 4. Full: All matchers match the request - this is unambiguous
///
/// If we're in state 3 or 4 at the request point we'll pick the combination of most complete
/// status and highest priority. Otherwise, the list of potential matchers is used for the request
/// checking and each message we check and if we find a partial or higher we'll select the one with
/// the highest priority.
///
/// Once a mock has been selected as the active mock, we'll then start passing the messages into
/// the responder and the mock server will start sending messages back (if it's not a silent
/// responder).
///
/// ## Drawbacks
///
/// I'm not currently keeping track of which matchers are triggered and which aren't. This is maybe
/// a mistake and potentially I should be making sure that at some point during the request they've
/// all matched. I'm also not tracking which levels matchesr match at:
/// request/message/all-of-the-above. These two things may have to change in order to be actually
/// useful but I'm kicking the can on that decision down the road a bit.
#[derive(Clone)]
pub struct Mock {
matcher: Vec<Arc<dyn Match + Send + Sync + 'static>>,
Expand Down Expand Up @@ -117,6 +143,19 @@ impl Mock {
self
}

pub fn set_responder(mut self, responder: impl ResponseStream + Send + Sync + 'static) -> Self {
self.responder = Arc::new(responder);
self
}

pub fn one_to_one_response<F>(mut self, map_fn: F) -> Self
where
F: Fn(Message) -> Message + Send + Sync + 'static,
{
self.responder = Arc::new(MapResponder::new(map_fn));
self
}

/// You can use this to verify the mock separately to the one you put into the server (if
/// you've cloned it).
pub fn verify(&self) -> bool {
Expand Down Expand Up @@ -206,7 +245,7 @@ async fn ws_handler(
let mocks = mocks.read().await.clone();
for (index, mock) in mocks.iter().enumerate() {
let mock_status = mock.check_request(&path, &headers, &params);

debug!("Mock status: {:?}", mock_status);
if mock_status != MatchStatus::Mismatch {
active_mocks.push(index);
}
Expand Down Expand Up @@ -250,22 +289,64 @@ fn convert_message(msg: AxumMessage) -> Message {
}
}

fn unconvert_message(msg: Message) -> AxumMessage {
match msg {
Message::Text(t) => AxumMessage::Text(t.as_str().into()),
Message::Binary(b) => AxumMessage::Binary(b.into()),
Message::Ping(p) => AxumMessage::Ping(p.into()),
Message::Pong(p) => AxumMessage::Pong(p.into()),
Message::Close(cf) => AxumMessage::Close(cf.map(|cf| AxumCloseFrame {
code: cf.code.into(),
reason: cf.reason.as_str().into(),
})),
Message::Frame(_) => unreachable!(),
}
}

async fn handle_socket(mut socket: WebSocket, mocks: MockList, mut active_mocks: Vec<usize>) {
debug!("Active mock indexes are: {:?}", active_mocks);
// Clone the mocks present when the connection comes in
let mocks: Vec<Mock> = mocks.read().await.clone();
let mut active_mocks = active_mocks
.iter()
.filter_map(|m| mocks.get(*m))
.collect::<Vec<&Mock>>();
while let Some(msg) = socket.recv().await {

let (sender, mut receiver) = socket.split();
let (mut msg_tx, msg_rx) = broadcast::channel(128);

let mut receiver_holder = Some(msg_rx);
let mut sender_holder = Some(sender);

let mut sender_task = if active_mocks.len() == 1 {
let stream = active_mocks[0]
.responder
.handle(receiver_holder.take().unwrap());
let sender = sender_holder.take().unwrap();
let handle = tokio::task::spawn(async move {
stream
.map(|x| Ok(unconvert_message(x)))
.forward(sender)
.await
});
debug!("Spawned responder task");
Some(handle)
} else {
debug!("Ambiguous matching, responder launch pending");
None
};

while let Some(msg) = receiver.next().await {
if let Ok(msg) = msg {
let msg = convert_message(msg);
if let Err(e) = msg_tx.send(msg.clone()) {
error!("Dropping messages");
}
debug!("Checking: {:?}", msg);
if active_mocks.len() == 1 {
if matches!(
active_mocks[0].check_message(&msg),
MatchStatus::Full | MatchStatus::Partial
) {
let status = active_mocks[0].check_message(&msg);
debug!("Active mock status: {:?}", status);
if matches!(status, MatchStatus::Full | MatchStatus::Partial) {
active_mocks[0].register_hit();
}
} else {
Expand Down Expand Up @@ -305,6 +386,18 @@ async fn handle_socket(mut socket: WebSocket, mocks: MockList, mut active_mocks:
let active_mock = active_mocks.remove(index);
active_mocks = vec![active_mock];
mocks[index].register_hit();
let stream = active_mocks[0]
.responder
.handle(receiver_holder.take().unwrap());
let sender = sender_holder.take().unwrap();
let handle = tokio::task::spawn(async move {
stream
.map(|x| Ok(unconvert_message(x)))
.forward(sender)
.await
});
debug!("Spawned responder task");
sender_task = Some(handle);
}
None => {
continue;
Expand All @@ -315,6 +408,9 @@ async fn handle_socket(mut socket: WebSocket, mocks: MockList, mut active_mocks:
}
}
}
if let Some(hnd) = sender_task {
hnd.await.unwrap().unwrap();
}
}

impl MockServer {
Expand Down Expand Up @@ -373,18 +469,24 @@ impl MockServer {
}

pub async fn verify(&self) {
assert!(self.mocks_pass().await);
}

pub async fn mocks_pass(&self) -> bool {
let mut res = true;
for (index, mock) in self.mocks.read().await.iter().enumerate() {
println!("Checking {:?} {:?}", mock.expected_calls, mock.calls);
let mock_res = mock.verify();
match &mock.name {
None => debug!("Checking mock[{}]", index),
Some(name) => debug!("Checking mock: {}", name),
}
assert!(mock.verify())
debug!(
"Expected {:?} Actual {:?}: {}",
mock.expected_calls, mock.calls, mock_res
);
res &= mock_res;
}
}

pub async fn mocks_pass(&self) -> bool {
self.mocks.read().await.iter().all(|x| x.verify())
res
}
}

Expand Down
73 changes: 61 additions & 12 deletions src/responder.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
use async_stream::stream;
use futures::{
stream::{self, BoxStream},
Stream, StreamExt,
};
use std::future::ready;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time::sleep;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::wrappers::BroadcastStream;
use tracing::warn;
use tungstenite::Message;

// Design thoughts I want:
Expand All @@ -25,20 +30,64 @@ use tungstenite::Message;
// responder take the last client message as an output.

pub trait ResponseStream {
fn handle(self, input: mpsc::Receiver<Message>) -> BoxStream<'static, Message>;
fn handle(&self, input: broadcast::Receiver<Message>) -> BoxStream<'static, Message>;
}

impl<S> ResponseStream for S
where
S: Stream<Item = Message> + Send + Sync + 'static,
{
fn handle(self, _: mpsc::Receiver<Message>) -> BoxStream<'static, Message> {
self.boxed()
pub struct StreamResponse {
stream_ctor: Arc<dyn Fn() -> BoxStream<'static, Message> + Send + Sync + 'static>,
}

impl StreamResponse {
pub fn new<F, S>(ctor: F) -> Self
where
F: Fn() -> S + Send + Sync + 'static,
S: Stream<Item = Message> + Send + Sync + 'static,
{
let stream_ctor = Arc::new(move || ctor().boxed());
Self { stream_ctor }
}
}

impl ResponseStream for StreamResponse {
fn handle(&self, _: broadcast::Receiver<Message>) -> BoxStream<'static, Message> {
(self.stream_ctor)()
}
}

pub fn pending() -> StreamResponse {
StreamResponse::new(stream::pending)
}

pub struct MapResponder {
map: Arc<dyn Fn(Message) -> Message + Send + Sync + 'static>,
}

impl MapResponder {
pub fn new<F: Fn(Message) -> Message + Send + Sync + 'static>(f: F) -> Self {
Self { map: Arc::new(f) }
}
}

pub fn pending() -> impl ResponseStream + Send + Sync + 'static {
stream::pending()
impl ResponseStream for MapResponder {
fn handle(&self, input: broadcast::Receiver<Message>) -> BoxStream<'static, Message> {
let map_fn = Arc::clone(&self.map);

let mut input = BroadcastStream::new(input);

let stream = stream! {
for await value in input {
match value {
Ok(v) => yield map_fn(v),
Err(e) => {
warn!("Broadcast error: {}", e);
}
}
}
};
stream.boxed()
}
}

// TODO we need rate throttling UX at some point
pub fn echo_response() -> MapResponder {
MapResponder::new(|msg| msg)
}
6 changes: 6 additions & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl Default for TimesEnum {
}
}

impl From<RangeFull> for Times {
fn from(r: RangeFull) -> Self {
Times(TimesEnum::Unbounded(r))
}
}

impl From<u64> for Times {
fn from(x: u64) -> Self {
Times(TimesEnum::Exact(x))
Expand Down
34 changes: 33 additions & 1 deletion tests/simple.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bytes::Bytes;
use futures::SinkExt;
use futures::{SinkExt, StreamExt};
use serde_json::json;
use std::time::Duration;
use tokio::time::sleep;
Expand Down Expand Up @@ -253,3 +253,35 @@ async fn combine_request_and_content_matchers() {

assert!(server.mocks_pass().await);
}

#[tokio::test]
#[traced_test]
async fn echo_response_test() {
let server = MockServer::start().await;

let responder = echo_response();

server
.register(
Mock::given(path("api/stream"))
.add_matcher(ValidJsonMatcher)
.set_responder(responder)
.expect(..),
)
.await;

let (mut stream, response) = connect_async(format!("{}/api/stream", server.uri()))
.await
.unwrap();

// Send a message just to show it doesn't change anything.
let val = json!({"hello": "world"});
let sent_message = Message::text(val.to_string());
stream.send(sent_message.clone()).await.unwrap();

let echoed = stream.next().await.unwrap().unwrap();

assert_eq!(sent_message, echoed);

assert!(server.mocks_pass().await);
}