From 7849edce1b49b8f0d8c68c2e9009920d0f59cccf Mon Sep 17 00:00:00 2001 From: Devdutt Shenoi Date: Wed, 19 Jun 2024 21:29:17 +0530 Subject: [PATCH] fix: action cancellation state loss (#351) * refactor: don't reallocate * fix: cancellation state loss on done response * test: action cancellation * refactor: inline function * doc: comment for posterity sake --- uplink/src/base/bridge/actions_lane.rs | 330 +++++++++++++++++++++---- 1 file changed, 278 insertions(+), 52 deletions(-) diff --git a/uplink/src/base/bridge/actions_lane.rs b/uplink/src/base/bridge/actions_lane.rs index 2adbcf3b..f26e6b9b 100644 --- a/uplink/src/base/bridge/actions_lane.rs +++ b/uplink/src/base/bridge/actions_lane.rs @@ -148,10 +148,6 @@ impl ActionsBridge { CtrlTx { inner: self.ctrl_tx.clone() } } - fn clear_current_action(&mut self) { - self.current_action.take(); - } - pub async fn start(&mut self) -> Result<(), Error> { let mut metrics_timeout = interval(self.config.stream_metrics.timeout); let mut end: Pin> = Box::pin(time::sleep(Duration::from_secs(u64::MAX))); @@ -168,7 +164,6 @@ impl ActionsBridge { } self.handle_action(action).await; - } response = self.status_rx.recv_async() => { @@ -192,7 +187,7 @@ impl ActionsBridge { self.forward_action_error(&action_id, Error::ActionTimeout).await; // Remove action because it timedout - self.clear_current_action(); + self.current_action.take(); continue; } @@ -206,7 +201,7 @@ impl ActionsBridge { if route.try_send(cancel_action).is_err() { error!("Couldn't cancel action ({}) on timeout", cancellation.action_id); // Remove action anyways - self.clear_current_action(); + self.current_action.take(); continue; } @@ -276,7 +271,7 @@ impl ActionsBridge { }; // Remove action because it couldn't be routed - self.clear_current_action(); + self.current_action.take(); // Ignore sending failure status to backend. This makes // backend retry action. @@ -367,10 +362,9 @@ impl ActionsBridge { /// Handle received actions fn try_route_action(&mut self, action: Action) -> Result<(), Error> { - let route = self - .action_routes - .get(&action.name) - .ok_or_else(|| Error::NoRoute(action.name.clone()))?; + let Some(route) = self.action_routes.get(&action.name) else { + return Err(Error::NoRoute(action.name)); + }; let deadline = route.try_send(action.clone()).map_err(|_| Error::UnresponsiveReceiver)?; // current action left unchanged in case of new tunshell action @@ -423,6 +417,7 @@ impl ActionsBridge { self.current_action.take() { if response.is_failed() { + // NOTE: action need not actually have been cancelled let response = ActionResponse::success(&cancel_action); self.streams.forward(response).await; } else { @@ -430,56 +425,46 @@ impl ActionsBridge { self.forward_action_error(&cancel_action, Error::FailedCancellation).await } } + return; } // Forward actions included in the config to the appropriate forward route, when // they have reached 100% progress but haven't been marked as "Completed"/"Finished". - if response.is_done() { - let mut action = self.current_action.take().unwrap().action; - + if response.is_done() && self.current_action.is_some() { + let CurrentAction { action, .. } = self.current_action.as_mut().unwrap(); if let Some(a) = response.done_response.take() { - action = a; + *action = a; } - match self.redirect_action(&mut action).await { - Ok(_) => (), - Err(Error::NoRoute(_)) => { - // NOTE: send success reponse for actions that don't have redirections configured - warn!("Action redirection is not configured for: {:?}", action); - let response = ActionResponse::success(&action.action_id); - self.streams.forward(response).await; - - if let Some(CurrentAction { cancelled_by: Some(cancel_action), .. }) = - self.current_action.take() - { - // Marks the cancellation as a failure as action has reached completion without being cancelled - self.forward_action_error(&cancel_action, Error::FailedCancellation).await - } - } - Err(Error::Cancelled(cancel_action)) => { - let response = ActionResponse::success(&cancel_action); - self.streams.forward(response).await; - - self.forward_action_error(&action.action_id, Error::Cancelled(cancel_action)) - .await; - } - Err(e) => self.forward_action_error(&action.action_id, e).await, - } + self.redirect_current_action().await; } } - async fn redirect_action(&mut self, action: &mut Action) -> Result<(), Error> { - let fwd_name = self - .action_redirections - .get(&action.name) - .ok_or_else(|| Error::NoRoute(action.name.clone()))?; + async fn redirect_current_action(&mut self) { + let CurrentAction { mut action, cancelled_by, .. } = self.current_action.take().unwrap(); + + let Some(fwd_name) = self.action_redirections.get(&action.name) else { + // NOTE: send success reponse for actions that don't have redirections configured + warn!("Action redirection is not configured for: {:?}", action); + let response = ActionResponse::success(&action.action_id); + self.streams.forward(response).await; + + if let Some(cancel_action) = cancelled_by { + // Marks the cancellation as a failure as action has reached completion without being cancelled + self.forward_action_error(&cancel_action, Error::FailedCancellation).await + } + return; + }; // Cancelled action should not be redirected - if let Some(CurrentAction { cancelled_by: Some(cancel_action), .. }) = - self.current_action.take() - { - return Err(Error::Cancelled(cancel_action)); + if let Some(cancel_action) = cancelled_by { + let response = ActionResponse::success(&cancel_action); + self.streams.forward(response).await; + + self.forward_action_error(&action.action_id, Error::Cancelled(cancel_action)).await; + + return; } debug!( @@ -488,9 +473,9 @@ impl ActionsBridge { ); fwd_name.clone_into(&mut action.name); - self.try_route_action(action.clone())?; - - Ok(()) + if let Err(e) = self.try_route_action(action.clone()) { + self.forward_action_error(&action.action_id, e).await + } } async fn forward_action_error(&mut self, action_id: &str, error: Error) { @@ -1104,4 +1089,245 @@ mod tests { assert_eq!(status.action_id, "1"); assert!(status.is_completed()); } + + #[tokio::test] + async fn cancel_action() { + let tmpdir = tempdir::TempDir::new("bridge").unwrap(); + std::env::set_current_dir(&tmpdir).unwrap(); + let config = default_config(); + let (mut bridge, actions_tx, data_rx) = create_bridge(Arc::new(config)); + + let bridge_tx_1 = bridge.status_tx(); + let (route_tx, action_rx_1) = bounded(1); + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: true, + }; + bridge.register_action_route(test_route, route_tx).unwrap(); + + spawn_bridge(bridge); + + std::thread::spawn(move || { + let rt = Runtime::new().unwrap(); + let action = action_rx_1.recv().unwrap(); + assert_eq!(action.action_id, "1"); + let response = ActionResponse::progress(&action.action_id, "Running", 0); + rt.block_on(bridge_tx_1.send_action_response(response)); + let cancel_action = action_rx_1.recv().unwrap(); + assert_eq!(cancel_action.action_id, "2"); + assert_eq!(cancel_action.name, "cancel_action"); + let response = ActionResponse::failure(&action.action_id, "Cancelled"); + rt.block_on(bridge_tx_1.send_action_response(response)); + }); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "1".to_string(), + name: "test".to_string(), + payload: "test".to_string(), + }; + actions_tx.send(action).unwrap(); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "2".to_string(), + name: "cancel_action".to_string(), + payload: r#"{"action_id": "1", "name": "test"}"#.to_string(), + }; + actions_tx.send(action).unwrap(); + + let mut responses = Responses { rx: data_rx, responses: vec![] }; + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Received"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Running"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "2"); + assert_eq!(state, "Received"); + + let status = responses.next(); + assert_eq!(status.action_id, "1"); + assert!(status.is_failed()); + assert_eq!(status.errors, ["Cancelled"]); + + let status = responses.next(); + assert_eq!(status.action_id, "2"); + assert!(status.is_completed()); + } + + #[tokio::test] + async fn cancel_action_failure() { + let tmpdir = tempdir::TempDir::new("bridge").unwrap(); + std::env::set_current_dir(&tmpdir).unwrap(); + let config = default_config(); + let (mut bridge, actions_tx, data_rx) = create_bridge(Arc::new(config)); + + let bridge_tx_1 = bridge.status_tx(); + let (route_tx, action_rx_1) = bounded(1); + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; + bridge.register_action_route(test_route, route_tx).unwrap(); + + spawn_bridge(bridge); + + std::thread::spawn(move || { + let rt = Runtime::new().unwrap(); + let action = action_rx_1.recv().unwrap(); + assert_eq!(action.action_id, "1"); + let response = ActionResponse::progress(&action.action_id, "Running", 0); + rt.block_on(bridge_tx_1.send_action_response(response)); + + std::thread::sleep(Duration::from_secs(1)); + let response = ActionResponse::success(&action.action_id); + rt.block_on(bridge_tx_1.send_action_response(response)); + }); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "1".to_string(), + name: "test".to_string(), + payload: "test".to_string(), + }; + actions_tx.send(action).unwrap(); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "2".to_string(), + name: "cancel_action".to_string(), + payload: r#"{"action_id": "1", "name": "test"}"#.to_string(), + }; + actions_tx.send(action).unwrap(); + + let mut responses = Responses { rx: data_rx, responses: vec![] }; + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Received"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Running"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "2"); + assert_eq!(state, "Received"); + + let status = responses.next(); + assert_eq!(status.action_id, "1"); + assert!(status.is_completed()); + + let status = responses.next(); + assert_eq!(status.action_id, "2"); + assert!(status.is_failed()); + assert_eq!(status.errors, ["Cancellation request failed as action completed execution!"]); + } + + #[tokio::test] + async fn cancel_action_between_redirect() { + let tmpdir = tempdir::TempDir::new("bridge").unwrap(); + std::env::set_current_dir(&tmpdir).unwrap(); + let mut config = default_config(); + config.action_redirections.insert("test".to_string(), "redirect".to_string()); + let (mut bridge, actions_tx, data_rx) = create_bridge(Arc::new(config)); + + let bridge_tx_1 = bridge.status_tx(); + let (route_tx, action_rx_1) = bounded(1); + let test_route = ActionRoute { + name: "test".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; + bridge.register_action_route(test_route, route_tx).unwrap(); + + let bridge_tx_2 = bridge.status_tx(); + let (route_tx, action_rx_2) = bounded(1); + let test_route = ActionRoute { + name: "redirect".to_string(), + timeout: Duration::from_secs(30), + cancellable: false, + }; + bridge.register_action_route(test_route, route_tx).unwrap(); + + spawn_bridge(bridge); + + std::thread::spawn(move || { + let rt = Runtime::new().unwrap(); + let action = action_rx_2.recv().unwrap(); + assert_eq!(action.action_id, "1"); + let response = ActionResponse::progress(&action.action_id, "Running", 0); + rt.block_on(bridge_tx_1.send_action_response(response)); + std::thread::sleep(Duration::from_secs(3)); + let response = ActionResponse::progress(&action.action_id, "Finished", 100); + rt.block_on(bridge_tx_1.send_action_response(response)); + }); + + std::thread::spawn(move || { + let rt = Runtime::new().unwrap(); + let action = action_rx_1.recv().unwrap(); + assert_eq!(action.action_id, "1"); + let response = ActionResponse::progress(&action.action_id, "Running", 0); + rt.block_on(bridge_tx_2.send_action_response(response)); + std::thread::sleep(Duration::from_secs(3)); + let response = ActionResponse::progress(&action.action_id, "Finished", 100); + rt.block_on(bridge_tx_2.send_action_response(response)); + }); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "1".to_string(), + name: "test".to_string(), + payload: "test".to_string(), + }; + actions_tx.send(action).unwrap(); + + std::thread::sleep(Duration::from_secs(1)); + + let action = Action { + action_id: "2".to_string(), + name: "cancel_action".to_string(), + payload: r#"{"action_id": "1", "name": "test"}"#.to_string(), + }; + actions_tx.send(action).unwrap(); + + let mut responses = Responses { rx: data_rx, responses: vec![] }; + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Received"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "1"); + assert_eq!(state, "Running"); + + let ActionResponse { action_id, state, .. } = responses.next(); + assert_eq!(action_id, "2"); + assert_eq!(state, "Received"); + + let status = responses.next(); + assert_eq!(status.action_id, "1"); + assert!(status.is_done()); + + let status = responses.next(); + assert_eq!(status.action_id, "2"); + assert!(status.is_completed()); + + let status = responses.next(); + assert_eq!(status.action_id, "1"); + assert!(status.is_failed()); + assert_eq!(status.errors, ["Action cancelled by action_id: 2"]); + } }