Skip to content

Commit a8f8a4b

Browse files
committed
Eliminate locking and improve safety with Pagers, Pollers
Resolves #3294
1 parent 1bef1b5 commit a8f8a4b

File tree

1 file changed

+190
-86
lines changed

1 file changed

+190
-86
lines changed

sdk/core/azure_core/src/http/pager.rs

Lines changed: 190 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
//! Types and methods for pageable responses.
55
6+
// TODO: Remove once tests are re-enabled!
7+
#![allow(missing_docs, unexpected_cfgs)]
8+
69
use crate::{
710
error::ErrorKind,
811
http::{
@@ -16,6 +19,7 @@ use futures::{stream::unfold, FutureExt, Stream};
1619
use std::{
1720
fmt,
1821
future::Future,
22+
marker::PhantomData,
1923
ops::Deref,
2024
pin::Pin,
2125
str::FromStr,
@@ -457,45 +461,24 @@ impl<P: Page> ItemIterator<P> {
457461
}
458462
}
459463

460-
/// Creates a [`ItemIterator<P>`] from a raw stream of [`Result<P>`](crate::Result<P>) values.
461-
///
462-
/// This constructor is used when you are implementing a completely custom stream and want to use it as a pager.
463-
pub fn from_stream<
464-
// This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds.
465-
#[cfg(not(target_arch = "wasm32"))] S: Stream<Item = crate::Result<P>> + Send + 'static,
466-
#[cfg(target_arch = "wasm32")] S: Stream<Item = crate::Result<P>> + 'static,
467-
>(
468-
stream: S,
469-
) -> Self {
470-
Self {
471-
stream: Box::pin(stream),
472-
continuation_token: None,
473-
next_token: Default::default(),
474-
current: None,
475-
}
476-
}
477-
478-
/// Gets a [`PageIterator<P>`] to iterate over a collection of pages from a service.
479-
///
480-
/// You can use this to asynchronously iterate pages returned by a collection request to a service.
481-
/// This allows you to get the individual pages' [`Response<P>`], from which you can iterate items in each page
482-
/// or deserialize the raw response as appropriate.
483-
///
484-
/// The returned `PageIterator` resumes from the current page until _after_ all items are processed.
485-
/// It does not continue on the next page until you call `next()` after the last item in the current page
486-
/// because of how iterators are implemented. This may yield duplicates but will reduce the likelihood of skipping items instead.
487-
pub fn into_pages(self) -> PageIterator<P> {
488-
// Attempt to start paging from the current page so that we don't skip items,
489-
// assuming the service collection hasn't changed (most services don't create ephemeral snapshots).
490-
if let Ok(mut token) = self.next_token.lock() {
491-
*token = self.continuation_token;
492-
}
493-
494-
PageIterator {
495-
stream: self.stream,
496-
continuation_token: self.next_token,
497-
}
498-
}
464+
// /// Gets a [`PageIterator<P>`] to iterate over a collection of pages from a service.
465+
// ///
466+
// /// You can use this to asynchronously iterate pages returned by a collection request to a service.
467+
// /// This allows you to get the individual pages' [`Response<P>`], from which you can iterate items in each page
468+
// /// or deserialize the raw response as appropriate.
469+
// ///
470+
// /// The returned `PageIterator` resumes from the current page until _after_ all items are processed.
471+
// /// It does not continue on the next page until you call `next()` after the last item in the current page
472+
// /// because of how iterators are implemented. This may yield duplicates but will reduce the likelihood of skipping items instead.
473+
// pub fn into_pages(self) -> impl PagePager<P> {
474+
// // Attempt to start paging from the current page so that we don't skip items,
475+
// // assuming the service collection hasn't changed (most services don't create ephemeral snapshots).
476+
// if let Ok(mut token) = self.next_token.lock() {
477+
// *token = self.continuation_token;
478+
// }
479+
480+
// todo!()
481+
// }
499482

500483
/// Gets the continuation token for the current page.
501484
///
@@ -594,19 +577,48 @@ impl<P: Page> fmt::Debug for ItemIterator<P> {
594577
/// }
595578
/// # Ok(()) }
596579
/// ```
580+
#[must_use = "streams do nothing unless polled"]
597581
#[pin_project::pin_project]
598-
pub struct PageIterator<P> {
582+
pub struct PageIterator<'a, P, C, F, Fut>
583+
where
584+
C: AsRef<str> + FromStr + ConditionalSend,
585+
F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + ConditionalSend,
586+
Fut: Future<Output = crate::Result<PagerResult<P, C>>> + ConditionalSend,
587+
<C as FromStr>::Err: std::error::Error,
588+
{
599589
#[pin]
600-
stream: Pin<BoxedStream<P>>,
601-
continuation_token: Arc<Mutex<Option<String>>>,
590+
make_request: Pin<Box<F>>,
591+
continuation_token: Option<String>,
592+
options: PagerOptions<'a>,
593+
state: State<C>,
594+
added_span: bool,
595+
phantom: PhantomData<P>,
602596
}
603597

604-
impl<P> PageIterator<P> {
605-
/// Creates a [`PageIterator<P>`] from a callback that will be called repeatedly to request each page.
598+
#[cfg(not(target_arch = "wasm32"))]
599+
pub trait ConditionalSend: Send {}
600+
601+
#[cfg(not(target_arch = "wasm32"))]
602+
impl<T> ConditionalSend for T where T: Send {}
603+
604+
#[cfg(target_arch = "wasm32")]
605+
pub trait ConditionalSend {}
606+
607+
#[cfg(target_arch = "wasm32")]
608+
impl<T> ConditionalSend for T {}
609+
610+
impl<'a, P, C, F, Fut> PageIterator<'a, P, C, F, Fut>
611+
where
612+
C: AsRef<str> + FromStr + ConditionalSend,
613+
F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + ConditionalSend,
614+
Fut: Future<Output = crate::Result<PagerResult<P, C>>> + ConditionalSend,
615+
<C as FromStr>::Err: std::error::Error,
616+
{
617+
/// Creates a [`PageIterator`] from a callback that will be called repeatedly to request each page.
606618
///
607-
/// This method expect a callback that accepts a single [`PagerState<C>`] parameter, and returns a [`PagerResult<T, C>`] value asynchronously.
619+
/// This method expects a callback that accepts a single [`PagerState`] parameter, and returns a [`PagerResult`] value asynchronously.
608620
/// The `C` type parameter is the type of the next link/continuation token. It may be any [`Send`]able type.
609-
/// The result will be an asynchronous stream of [`Result<T>`](crate::Result<T>) values.
621+
/// The result will be an asynchronous stream of [`Result`](crate::Result) values.
610622
///
611623
/// The first time your callback is called, it will be called with [`PagerState::Initial`], indicating no next link/continuation token is present.
612624
///
@@ -691,76 +703,168 @@ impl<P> PageIterator<P> {
691703
/// }
692704
/// }, None);
693705
/// ```
694-
pub fn from_callback<
695-
// This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds.
696-
#[cfg(not(target_arch = "wasm32"))] C: AsRef<str> + FromStr + Send + 'static,
697-
#[cfg(not(target_arch = "wasm32"))] F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + Send + 'static,
698-
#[cfg(not(target_arch = "wasm32"))] Fut: Future<Output = crate::Result<PagerResult<P, C>>> + Send + 'static,
699-
#[cfg(target_arch = "wasm32")] C: AsRef<str> + FromStr + 'static,
700-
#[cfg(target_arch = "wasm32")] F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + 'static,
701-
#[cfg(target_arch = "wasm32")] Fut: Future<Output = crate::Result<PagerResult<P, C>>> + 'static,
702-
>(
703-
make_request: F,
704-
options: Option<PagerOptions<'static>>,
705-
) -> Self
706+
pub fn from_callback(make_request: F, options: Option<PagerOptions<'static>>) -> Self
706707
where
707708
<C as FromStr>::Err: std::error::Error,
708709
{
710+
// TODO: We'll want to delete this whole function and define a module function that returns an `ItemIterator<..>` since declaring the right type will be difficult.
709711
let options = options.unwrap_or_default();
710712

711713
// Start from the optional `PagerOptions::continuation_token`.
712-
let continuation_token = Arc::new(Mutex::new(options.continuation_token.clone()));
713-
let stream = iter_from_callback(make_request, options, continuation_token.clone());
714+
let continuation_token = options.continuation_token.clone();
714715

715716
Self {
716-
stream: Box::pin(stream),
717+
make_request: Box::pin(make_request),
717718
continuation_token,
719+
options,
720+
state: State::Init,
721+
added_span: false,
722+
phantom: PhantomData,
718723
}
719724
}
725+
}
720726

721-
/// Creates a [`PageIterator<P>`] from a raw stream of [`Result<P>`](crate::Result<P>) values.
722-
///
723-
/// This constructor is used when you are implementing a completely custom stream and want to use it as a pager.
724-
pub fn from_stream<
725-
// This is a bit gnarly, but the only thing that differs between the WASM/non-WASM configs is the presence of Send bounds.
726-
#[cfg(not(target_arch = "wasm32"))] S: Stream<Item = crate::Result<P>> + Send + 'static,
727-
#[cfg(target_arch = "wasm32")] S: Stream<Item = crate::Result<P>> + 'static,
728-
>(
729-
stream: S,
730-
) -> Self {
731-
Self {
732-
stream: Box::pin(stream),
733-
continuation_token: Default::default(),
734-
}
735-
}
736-
727+
// TODO: Rename this.
728+
pub trait PagePager<P>: futures::Stream<Item = crate::Result<P>> {
737729
/// Gets the continuation token for the current page.
738730
///
739731
/// Pass this to [`PagerOptions::continuation_token`] to create a `PageIterator` that, when first iterated,
740732
/// will return the next page. You can use this to page results across separate processes.
741-
pub fn continuation_token(&self) -> Option<String> {
742-
if let Ok(token) = self.continuation_token.lock() {
743-
return token.clone();
744-
}
733+
fn continuation_token(&self) -> Option<&str>;
734+
}
745735

746-
None
736+
impl<'a, P, C, F, Fut> PagePager<P> for PageIterator<'a, P, C, F, Fut>
737+
where
738+
C: AsRef<str> + FromStr + ConditionalSend,
739+
F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + ConditionalSend,
740+
Fut: Future<Output = crate::Result<PagerResult<P, C>>> + ConditionalSend,
741+
<C as FromStr>::Err: std::error::Error,
742+
{
743+
fn continuation_token(&self) -> Option<&str> {
744+
self.continuation_token.as_deref()
747745
}
748746
}
749747

750-
impl<P> futures::Stream for PageIterator<P> {
748+
impl<'a, P, C, F, Fut> futures::Stream for PageIterator<'a, P, C, F, Fut>
749+
where
750+
C: AsRef<str> + FromStr + ConditionalSend,
751+
F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + ConditionalSend,
752+
Fut: Future<Output = crate::Result<PagerResult<P, C>>> + ConditionalSend,
753+
<C as FromStr>::Err: std::error::Error,
754+
{
751755
type Item = crate::Result<P>;
752756

753757
fn poll_next(
754758
self: Pin<&mut Self>,
755759
cx: &mut std::task::Context<'_>,
756760
) -> std::task::Poll<Option<Self::Item>> {
757-
self.project().stream.poll_next(cx)
761+
let this = self.project();
762+
763+
// When in the "Init" state, we are either starting fresh or resuming from a continuation token.
764+
// In either case, attach a span to the context for the entire paging operation.
765+
if *this.state == State::Init {
766+
tracing::debug!("establish a public API span for new pager.");
767+
768+
// At the very start of polling, create a span for the entire request, and attach it to the context
769+
let span = create_public_api_span(&this.options.context, None, None);
770+
if let Some(s) = span {
771+
*this.added_span = true;
772+
let old_context = std::mem::replace(this.options, PagerOptions::default());
773+
*this.options = PagerOptions {
774+
context: old_context.context.with_value(s),
775+
continuation_token: old_context.continuation_token,
776+
};
777+
}
778+
}
779+
780+
// Get the `continuation_token` to pick up where we left off, or None for the initial page,
781+
// but don't override the terminal `State::Done`.
782+
if *this.state != State::Done {
783+
let next_state = match this.continuation_token.as_deref() {
784+
Some(n) => match n.parse() {
785+
Ok(s) => State::More(s),
786+
Err(err) => {
787+
let error =
788+
crate::Error::with_message_fn(ErrorKind::DataConversion, || {
789+
format!("invalid continuation token: {err}")
790+
});
791+
*this.state = State::Done;
792+
return std::task::Poll::Ready(Some(Err(error)));
793+
}
794+
},
795+
// Restart the pager if `continuation_token` is None, indicating we resumed from before or within the first page.
796+
None => State::Init,
797+
};
798+
*this.state = next_state;
799+
}
800+
801+
// Poll based on current state
802+
match *this.state {
803+
State::Init => {
804+
tracing::debug!("initial page request");
805+
let mut fut = (this.make_request)(PagerState::Initial, this.options.clone());
806+
match fut.poll(cx) {
807+
std::task::Poll::Ready(result) => {
808+
let (item, next_state) =
809+
Self::handle_result(result, this.added_span, &this.options.context);
810+
*this.state = next_state.unwrap_or(State::Done);
811+
if let Ok(ref response) = item {
812+
if let State::More(ref token) = *this.state {
813+
*this.continuation_token = Some(token.as_ref().into());
814+
} else {
815+
*this.continuation_token = None;
816+
}
817+
}
818+
std::task::Poll::Ready(Some(item))
819+
}
820+
std::task::Poll::Pending => std::task::Poll::Pending,
821+
}
822+
}
823+
State::More(ref n) => {
824+
tracing::debug!("subsequent page request to {:?}", AsRef::<str>::as_ref(n));
825+
let mut fut =
826+
(this.make_request)(PagerState::More(n.clone()), this.options.clone());
827+
match Pin::new(&mut fut).poll(cx) {
828+
std::task::Poll::Ready(result) => {
829+
let (item, next_state) =
830+
Self::handle_result(result, this.added_span, &this.options.context);
831+
*this.state = next_state.unwrap_or(State::Done);
832+
if let Ok(ref response) = item {
833+
if let State::More(ref token) = *this.state {
834+
*this.continuation_token = Some(token.as_ref().into());
835+
} else {
836+
*this.continuation_token = None;
837+
}
838+
}
839+
std::task::Poll::Ready(Some(item))
840+
}
841+
std::task::Poll::Pending => std::task::Poll::Pending,
842+
}
843+
}
844+
State::Done => {
845+
tracing::debug!("done");
846+
// Set the `continuation_token` to None now that we are done.
847+
*this.continuation_token = None;
848+
std::task::Poll::Ready(None)
849+
}
850+
}
758851
}
759852
}
760853

761-
impl<P> fmt::Debug for PageIterator<P> {
854+
impl<'a, P, C, F, Fut> fmt::Debug for PageIterator<'a, P, C, F, Fut>
855+
where
856+
C: AsRef<str> + FromStr + ConditionalSend,
857+
F: Fn(PagerState<C>, PagerOptions<'static>) -> Fut + ConditionalSend,
858+
Fut: Future<Output = crate::Result<PagerResult<P, C>>> + ConditionalSend,
859+
<C as FromStr>::Err: std::error::Error,
860+
{
762861
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
763-
f.debug_struct("PageIterator").finish_non_exhaustive()
862+
f.debug_struct("PageIterator")
863+
.field("continuation_token", &self.continuation_token)
864+
.field("options", &self.options)
865+
.field("state", &self.state)
866+
.field("added_span", &self.added_span)
867+
.finish_non_exhaustive()
764868
}
765869
}
766870

0 commit comments

Comments
 (0)