Skip to content

Commit

Permalink
fix(ampd): ensure that txs get confirmed when broadcast with an ampd subcommand
Browse files Browse the repository at this point in the history
  • Loading branch information
cgorenflo committed Aug 22, 2024
1 parent e2278d2 commit e73ed1e
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 161 deletions.
109 changes: 26 additions & 83 deletions ampd/src/asyncutil/future.rs
Original file line number Diff line number Diff line change
@@ -1,102 +1,45 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;

use futures::{Future, FutureExt};
use futures::Future;
use tokio::time;

pub fn with_retry<F, Fut, R, Err>(
future: F,
policy: RetryPolicy,
) -> impl Future<Output = Result<R, Err>>
pub async fn with_retry<F, Fut, R, Err>(mut future: F, policy: RetryPolicy) -> Result<R, Err>
where
F: Fn() -> Fut,
F: FnMut() -> Fut,
Fut: Future<Output = Result<R, Err>>,
{
RetriableFuture::new(future, policy)
}

/// Strategy that controls whether and when a failed attempt is repeated by `with_retry`.
pub enum RetryPolicy {
/// Wait for the constant duration `sleep` after a failure before trying again, and give up
/// (returning the last error) once the number of failures reaches `max_attempts`.
RepeatConstant { sleep: Duration, max_attempts: u64 },
}

/// Future that drives an inner future to completion and, when it resolves to an error,
/// builds a fresh inner future from the factory and tries again as dictated by the policy.
struct RetriableFuture<F, Fut, R, Err>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<R, Err>>,
{
// factory that is called to create a new inner future for every attempt
future: F,
// the attempt that is currently being polled
inner: Pin<Box<Fut>>,
// decides how long to wait between attempts and when to stop retrying
policy: RetryPolicy,
// number of attempts that have resolved to an error so far
err_count: u64,
}

// The inner future is stored behind `Pin<Box<_>>`, so the wrapper itself holds no
// address-sensitive state of `Fut`. Declaring it `Unpin` lets the methods that take
// `Pin<&mut Self>` mutate fields (e.g. replace `inner`) without any unsafe code.
// NOTE(review): this also marks the struct `Unpin` when `F` is not; `F` is only ever called
// through a shared reference here, so that looks fine — confirm.
impl<F, Fut, R, Err> Unpin for RetriableFuture<F, Fut, R, Err>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<R, Err>>,
{
}

impl<F, Fut, R, Err> RetriableFuture<F, Fut, R, Err>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<R, Err>>,
{
fn new(get_future: F, policy: RetryPolicy) -> Self {
let future = get_future();

Self {
future: get_future,
inner: Box::pin(future),
policy,
err_count: 0,
let mut err_count = 0u64;
loop {
match future().await {
Ok(result) => return Ok(result),
Err(err) => {
err_count = err_count.saturating_add(1);
handle_err(err, err_count, policy).await?
}
}
}
}

fn handle_err(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
error: Err,
) -> Poll<Result<R, Err>> {
self.err_count = self.err_count.saturating_add(1);

match self.policy {
RetryPolicy::RepeatConstant {
sleep,
max_attempts,
} => {
if self.err_count >= max_attempts {
return Poll::Ready(Err(error));
}

self.inner = Box::pin((self.future)());
async fn handle_err<Err>(err: Err, err_count: u64, policy: RetryPolicy) -> Result<(), Err> {
match policy {
RetryPolicy::RepeatConstant {
sleep,
max_attempts,
} => {
if err_count >= max_attempts {
return Err(err);
}

let waker = cx.waker().clone();
tokio::spawn(time::sleep(sleep).then(|_| async {
waker.wake();
}));
time::sleep(sleep).await;

Poll::Pending
}
Ok(())
}
}
}

impl<F, Fut, R, Err> Future for RetriableFuture<F, Fut, R, Err>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<R, Err>>,
{
type Output = Result<R, Err>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.inner.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(result)) => Poll::Ready(Ok(result)),
Poll::Ready(Err(error)) => self.handle_err(cx, error),
}
}
/// Strategy that controls whether and when a failed attempt is repeated by `with_retry`.
///
/// The policy is `Copy`, so it can be handed by value to every retry decision.
#[derive(Copy, Clone)]
pub enum RetryPolicy {
/// Wait for the constant duration `sleep` after a failure before trying again, and give up
/// (returning the last error) once the number of failures reaches `max_attempts`.
RepeatConstant { sleep: Duration, max_attempts: u64 },
}

#[cfg(test)]
Expand Down
116 changes: 60 additions & 56 deletions ampd/src/broadcaster/confirm_tx.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
use std::sync::Arc;
use std::time::Duration;

use axelar_wasm_std::FnExt;
use cosmrs::proto::cosmos::tx::v1beta1::{GetTxRequest, GetTxResponse};
use error_stack::{report, Report, Result};
use error_stack::{bail, Report, Result};
use futures::{StreamExt, TryFutureExt};
use thiserror::Error;
use tokio::sync::{mpsc, Mutex};
use tokio::time;
use tokio_stream::wrappers::ReceiverStream;
use tonic::Status;
use tracing::error;

use super::cosmos;
use crate::asyncutil::future::{with_retry, RetryPolicy};

#[derive(Debug, PartialEq)]
pub enum TxStatus {
Expand Down Expand Up @@ -53,24 +54,17 @@ pub enum Error {
SendTxRes(#[from] Box<mpsc::error::SendError<TxResponse>>),
}

/// Outcome of a single attempt to look up a broadcast transaction by its hash.
enum ConfirmationResult {
/// The query returned a transaction response; it is boxed to keep the enum small.
Confirmed(Box<TxResponse>),
/// The query succeeded but the reply carried no transaction response.
NotFound,
/// The gRPC call itself failed with the given status.
GRPCError(Status),
}

pub struct TxConfirmer<T>
pub struct ConfirmationCtx<T>
where
T: cosmos::BroadcastClient,
{
client: T,
sleep: Duration,
max_attempts: u32,
retry_policy: RetryPolicy,
tx_hash_receiver: mpsc::Receiver<String>,
tx_res_sender: mpsc::Sender<TxResponse>,
}

impl<T> TxConfirmer<T>
impl<T> ConfirmationCtx<T>
where
T: cosmos::BroadcastClient,
{
Expand All @@ -83,8 +77,10 @@ where
) -> Self {
Self {
client,
sleep,
max_attempts,
retry_policy: RetryPolicy::RepeatConstant {
sleep,
max_attempts: max_attempts.into(),
},
tx_hash_receiver,
tx_res_sender,
}
Expand All @@ -93,23 +89,19 @@ where
pub async fn run(self) -> Result<(), Error> {
let Self {
client,
sleep,
max_attempts,
retry_policy,
tx_hash_receiver,
tx_res_sender,
} = self;
let limit = tx_hash_receiver.capacity();
let client = Mutex::new(client);
let client = Arc::new(Mutex::new(client));

let mut tx_hash_stream = ReceiverStream::new(tx_hash_receiver)
.map(|tx_hash| {
confirm_tx(&client, tx_hash, sleep, max_attempts).and_then(|tx| async {
tx_res_sender
.send(tx)
.await
.map_err(Box::new)
.map_err(Into::into)
.map_err(Report::new)
})
// multiple instances of confirm_tx can be spawned due to buffer_unordered,
// so we need to clone the client to avoid a deadlock
confirm_tx(client.clone(), tx_hash, retry_policy)
.and_then(|tx| async { send_response(&tx_res_sender, tx).await })
})
.buffer_unordered(limit);

Expand All @@ -121,50 +113,62 @@ where
}
}

async fn confirm_tx<T>(
client: &Mutex<T>,
async fn confirm_tx(
client: Arc<Mutex<impl cosmos::BroadcastClient>>,
tx_hash: String,
sleep: Duration,
attempts: u32,
) -> Result<TxResponse, Error>
where
T: cosmos::BroadcastClient,
{
for i in 0..attempts {
retry_policy: RetryPolicy,
) -> Result<TxResponse, Error> {
async fn confirm(
client: Arc<Mutex<impl cosmos::BroadcastClient>>,
tx_hash: String,
) -> Result<TxResponse, Error> {
let req = GetTxRequest {
hash: tx_hash.clone(),
};

match client.lock().await.tx(req).await.then(evaluate_tx_response) {
ConfirmationResult::Confirmed(tx) => return Ok(*tx),
ConfirmationResult::NotFound if i == attempts.saturating_sub(1) => {
return Err(report!(Error::Confirmation { tx_hash }))
}
ConfirmationResult::GRPCError(status) if i == attempts.saturating_sub(1) => {
return Err(report!(Error::Grpc { status, tx_hash }))
}
_ => time::sleep(sleep).await,
}
client
.lock()
.await
.tx(req)
.await
.then(evaluate_tx_response(tx_hash))
}

unreachable!("confirmation loop should have returned by now")
with_retry(|| confirm(client.clone(), tx_hash.clone()), retry_policy).await
}

fn evaluate_tx_response(
response: core::result::Result<GetTxResponse, Status>,
) -> ConfirmationResult {
match response {
Err(status) => ConfirmationResult::GRPCError(status),
tx_hash: String,
) -> impl Fn(core::result::Result<GetTxResponse, Status>) -> Result<TxResponse, Error> {
move |response| match response {
Err(status) => bail!(Error::Grpc {
status,
tx_hash: tx_hash.clone()
}),
Ok(GetTxResponse {
tx_response: None, ..
}) => ConfirmationResult::NotFound,
}) => bail!(Error::Confirmation {
tx_hash: tx_hash.clone()
}),
Ok(GetTxResponse {
tx_response: Some(response),
..
}) => ConfirmationResult::Confirmed(Box::new(response.into())),
}) => Ok(response.into()),
}
}

/// Forwards a confirmed transaction response to the receiving end of `tx_res_sender`.
///
/// # Errors
///
/// If the channel rejects the message, the send error is boxed, converted into this
/// module's `Error` and wrapped in a `Report`.
async fn send_response(
tx_res_sender: &mpsc::Sender<TxResponse>,
tx: TxResponse,
) -> Result<(), Error> {
tx_res_sender
.send(tx)
.await

Check warning on line 166 in ampd/src/broadcaster/confirm_tx.rs

View check run for this annotation

Codecov / codecov/patch

ampd/src/broadcaster/confirm_tx.rs#L166

Added line #L166 was not covered by tests
.map_err(Box::new)
.map_err(Into::into)
.map_err(Report::new)
}

#[cfg(test)]
mod test {
use std::time::Duration;
Expand All @@ -174,7 +178,7 @@ mod test {
use tokio::sync::mpsc;
use tokio::test;

use super::{Error, TxConfirmer, TxResponse, TxStatus};
use super::{ConfirmationCtx, Error, TxResponse, TxStatus};
use crate::broadcaster::cosmos::MockBroadcastClient;

#[test]
Expand Down Expand Up @@ -203,7 +207,7 @@ mod test {
let (tx_confirmer_sender, tx_confirmer_receiver) = mpsc::channel(100);
let (tx_res_sender, mut tx_res_receiver) = mpsc::channel(100);

let tx_confirmer = TxConfirmer::new(
let tx_confirmer = ConfirmationCtx::new(
client,
sleep,
max_attempts,
Expand Down Expand Up @@ -250,7 +254,7 @@ mod test {
let (tx_confirmer_sender, tx_confirmer_receiver) = mpsc::channel(100);
let (tx_res_sender, mut tx_res_receiver) = mpsc::channel(100);

let tx_confirmer = TxConfirmer::new(
let tx_confirmer = ConfirmationCtx::new(
client,
sleep,
max_attempts,
Expand Down Expand Up @@ -289,7 +293,7 @@ mod test {
let (tx_confirmer_sender, tx_confirmer_receiver) = mpsc::channel(100);
let (tx_res_sender, _tx_res_receiver) = mpsc::channel(100);

let tx_confirmer = TxConfirmer::new(
let tx_confirmer = ConfirmationCtx::new(
client,
sleep,
max_attempts,
Expand Down Expand Up @@ -328,7 +332,7 @@ mod test {
let (tx_confirmer_sender, tx_confirmer_receiver) = mpsc::channel(100);
let (tx_res_sender, _tx_res_receiver) = mpsc::channel(100);

let tx_confirmer = TxConfirmer::new(
let tx_confirmer = ConfirmationCtx::new(
client,
sleep,
max_attempts,
Expand Down
Loading

0 comments on commit e73ed1e

Please sign in to comment.