use crate::RcLike;use crate::Resource;use futures::future::FusedFuture;use futures::future::Future;use futures::future::TryFuture;use futures::task::Context;use futures::task::Poll;use pin_project::pin_project;use std::any::type_name;use std::borrow::Cow;use std::error::Error;use std::fmt;use std::fmt::Display;use std::fmt::Formatter;use std::io;use std::pin::Pin;use std::rc::Rc;
use self::internal as i;
/// Cancellation handle shared (via `Rc`) between the code that cancels and the
/// futures wrapped with [`CancelFuture::or_cancel`] or
/// [`CancelTryFuture::try_or_cancel`].
#[derive(Default, Debug)]
pub struct CancelHandle {
    /// Head node of the intrusive list that registered futures link into.
    node: i::Node,
}
impl CancelHandle {
    /// Creates a new handle that has not been canceled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new handle that is already wrapped in an `Rc`. This is a
    /// convenient form to pass (by reference or by value) to `or_cancel` /
    /// `try_or_cancel`.
    pub fn new_rc() -> Rc<Self> {
        let handle = Self::new();
        Rc::new(handle)
    }

    /// Cancels every future registered with this handle and wakes each of
    /// them. Futures wrapped with this handle afterwards resolve to
    /// `Err(Canceled)` when first polled, without polling the inner future.
    pub fn cancel(&self) {
        self.node.cancel();
    }

    /// Returns `true` once [`CancelHandle::cancel`] has been called.
    pub fn is_canceled(&self) -> bool {
        self.node.is_canceled()
    }
}
/// Future returned by [`CancelFuture::or_cancel`].
///
/// It wraps `future` and resolves to `Err(Canceled)` if the associated
/// `CancelHandle` is found canceled when polled. That check happens before the
/// inner future is polled.
#[pin_project(project = CancelableProjection)]
#[derive(Debug)]
pub enum Cancelable<F> {
    /// The wrapped future has not produced a result yet.
    Pending {
        /// The wrapped future.
        #[pin]
        future: F,
        /// Link to the cancel handle. It is registered lazily, the first time
        /// the inner future returns `Poll::Pending`.
        #[pin]
        registration: i::Registration,
    },
    /// A result has already been returned; polling again panics.
    Terminated,
}
impl<F: Future> Future for Cancelable<F> {
    type Output = Result<F::Output, Canceled>;

    /// Polls the wrapped future, unless the handle is canceled. As soon as a
    /// result is produced, switches to `Terminated`.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let outcome = match self.as_mut().project() {
            CancelableProjection::Terminated => {
                panic!("{}::poll() called after completion", type_name::<Self>())
            }
            CancelableProjection::Pending {
                future,
                registration,
            } => Self::poll_pending(future, registration, cx),
        };
        if outcome.is_ready() {
            // Drops the inner future together with its registration.
            self.set(Cancelable::Terminated);
        }
        outcome
    }
}
impl<F: Future> FusedFuture for Cancelable<F> {
    /// A `Cancelable` is terminated once it has yielded its result.
    fn is_terminated(&self) -> bool {
        match self {
            Self::Pending { .. } => false,
            Self::Terminated => true,
        }
    }
}
impl Resource for CancelHandle { fn name(&self) -> Cow<str> { "cancellation".into() }
fn close(self: Rc<Self>) { self.cancel(); }}
/// Future returned by [`CancelTryFuture::try_or_cancel`].
///
/// It behaves like [`Cancelable`], but reports cancellation through the
/// wrapped future's own error type (`Canceled: Into<E>`) instead of nesting
/// `Result`s.
#[pin_project(project = TryCancelableProjection)]
#[derive(Debug)]
pub struct TryCancelable<F> {
    #[pin]
    inner: Cancelable<F>,
}
impl<F, T, E> Future for TryCancelable<F>where F: Future<Output = Result<T, E>>, Canceled: Into<E>,{ type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { let TryCancelableProjection { inner } = self.project(); match inner.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(result)) => Poll::Ready(result), Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), } }}
impl<F, T, E> FusedFuture for TryCancelable<F>where F: Future<Output = Result<T, E>>, Canceled: Into<E>,{ fn is_terminated(&self) -> bool { self.inner.is_terminated() }}
pub trait CancelFuturewhere Self: Future + Sized,{ fn or_cancel<H: RcLike<CancelHandle>>( self, cancel_handle: H, ) -> Cancelable<Self> { Cancelable::new(self, cancel_handle.into()) }}
impl<F> CancelFuture for F where F: Future {}
pub trait CancelTryFuturewhere Self: TryFuture + Sized, Canceled: Into<Self::Error>,{ fn try_or_cancel<H: RcLike<CancelHandle>>( self, cancel_handle: H, ) -> TryCancelable<Self> { TryCancelable::new(self, cancel_handle.into()) }}
impl<F> CancelTryFuture for Fwhere F: TryFuture, Canceled: Into<F::Error>,{}
/// Error produced by a cancelable future whose cancel handle was canceled.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Canceled;

impl Display for Canceled {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("operation canceled")
    }
}

impl Error for Canceled {}

/// Cancellation surfaces as an `io::Error` of kind `Interrupted`, with
/// `Canceled` as its inner error.
impl From<Canceled> for io::Error {
    fn from(canceled: Canceled) -> Self {
        Self::new(io::ErrorKind::Interrupted, canceled)
    }
}
mod internal {
    use super::CancelHandle;
    use super::Cancelable;
    use super::Canceled;
    use super::TryCancelable;
    use crate::RcRef;
    use futures::future::Future;
    use futures::task::Context;
    use futures::task::Poll;
    use futures::task::Waker;
    use pin_project::pin_project;
    use std::any::Any;
    use std::cell::UnsafeCell;
    use std::marker::PhantomPinned;
    use std::mem::replace;
    use std::pin::Pin;
    use std::ptr::NonNull;
    use std::rc::Rc;
    use std::rc::Weak;

    impl<F: Future> Cancelable<F> {
        /// Creates a `Cancelable` that registers with `cancel_handle`'s head
        /// node the first time the inner future returns `Poll::Pending`.
        pub(super) fn new(future: F, cancel_handle: RcRef<CancelHandle>) -> Self {
            let head_node = RcRef::map(cancel_handle, |r| &r.node);
            let registration = Registration::WillRegister { head_node };
            Self::Pending {
                future,
                registration,
            }
        }

        /// Polls a `Cancelable` that is still in the `Pending` state.
        ///
        /// Steps:
        /// 1. If the relevant node is already canceled, return
        ///    `Poll::Ready(Err(Canceled))` without polling the inner future.
        ///    Before registration the relevant node is the handle's head node;
        ///    after registration it is this future's own item node.
        /// 2. Poll the inner future. If it is ready, return its output.
        /// 3. Otherwise, on the first pending poll, link this future into the
        ///    handle's list. On later polls, refresh the stored waker. If the
        ///    handle was canceled in the meantime (for example by the inner
        ///    future itself), return `Poll::Ready(Err(Canceled))`.
        pub(super) fn poll_pending(
            future: Pin<&mut F>,
            mut registration: Pin<&mut Registration>,
            cx: &mut Context,
        ) -> Poll<Result<F::Output, Canceled>> {
            let node = match &*registration {
                Registration::WillRegister { head_node } => &*head_node,
                Registration::Registered { node } => node,
            };
            if node.is_canceled() {
                return Poll::Ready(Err(Canceled));
            }

            match future.poll(cx) {
                Poll::Ready(res) => return Poll::Ready(Ok(res)),
                Poll::Pending => {}
            }

            // On the first pending poll, replace `WillRegister` in place with
            // a fresh `Registered` node. This takes ownership of the head
            // `RcRef` so it can be handed to `Node::register`.
            let head_node = match &*registration {
                Registration::WillRegister { .. } => {
                    match registration.as_mut().project_replace(Default::default()) {
                        RegistrationProjectionOwned::WillRegister { head_node } => {
                            Some(head_node)
                        }
                        _ => unreachable!(),
                    }
                }
                _ => None,
            };
            let node = match registration.project() {
                RegistrationProjection::Registered { node } => node,
                _ => unreachable!(),
            };
            node.register(cx.waker(), head_node)?;

            Poll::Pending
        }
    }

    impl<F: Future> TryCancelable<F> {
        /// Creates a `TryCancelable` that wraps a `Cancelable` for `future`.
        pub(super) fn new(future: F, cancel_handle: RcRef<CancelHandle>) -> Self {
            Self {
                inner: Cancelable::new(future, cancel_handle),
            }
        }
    }

    /// Registration state of a pending `Cancelable`.
    #[pin_project(project = RegistrationProjection, project_replace = RegistrationProjectionOwned)]
    #[derive(Debug)]
    pub enum Registration {
        /// Not registered yet. Holds a reference to the cancel handle's head
        /// node.
        WillRegister { head_node: RcRef<Node> },
        /// Holds this future's own item node, which lives in place and is
        /// pinned.
        Registered {
            #[pin]
            node: Node,
        },
    }

    impl Default for Registration {
        /// A `Registered` state holding a fresh, unlinked node.
        fn default() -> Self {
            Self::Registered {
                node: Default::default(),
            }
        }
    }

    /// A node in a circular, doubly-linked intrusive list.
    ///
    /// The list contains one head node, owned by a `CancelHandle`, and one
    /// item node per registered `Cancelable`. Linked nodes hold raw pointers
    /// to each other's `NodeInner`. `PhantomPinned` keeps `Node` `!Unpin` so
    /// that it is not moved while pinned.
    #[derive(Debug)]
    pub struct Node {
        inner: UnsafeCell<NodeInner>,
        _pin: PhantomPinned,
    }

    impl Node {
        /// Registers a `Cancelable`'s item node.
        ///
        /// With `head_rc = Some(head)` (first registration), `self` is linked
        /// into the head's list and a clone of `waker` is stored. With `None`
        /// (already registered), only the stored waker is refreshed.
        ///
        /// # Errors
        /// Returns `Err(Canceled)` if the list has already been canceled.
        ///
        /// # Panics
        /// Panics if `self` is the head node itself, or if on first
        /// registration `self` is not in the `Unlinked` state.
        pub fn register(
            &self,
            waker: &Waker,
            head_rc: Option<RcRef<Node>>,
        ) -> Result<(), Canceled> {
            match head_rc.as_ref().map(RcRef::split) {
                Some((head, rc)) => {
                    assert_ne!(self, head);
                    // SAFETY: `Node` is neither `Send` nor `Sync` (it contains
                    // `NonNull`s inside an `UnsafeCell`), so only this thread
                    // can reach the cell. `self` and `head` are distinct
                    // (asserted above), so the two exclusive borrows do not
                    // overlap. The intended invariant is that no other
                    // reference into either cell is live during this call.
                    let self_inner = unsafe { &mut *self.inner.get() };
                    // SAFETY: same reasoning as for `self_inner` above.
                    let head_inner = unsafe { &mut *head.inner.get() };
                    self_inner.link(waker, head_inner, rc)
                }
                None => {
                    // SAFETY: single-threaded access (`Node` is `!Sync`), and
                    // no other reference into this cell is live during the
                    // call.
                    let inner = unsafe { &mut *self.inner.get() };
                    inner.update_waker(waker)
                }
            }
        }

        /// Cancels the list headed by this node.
        pub fn cancel(&self) {
            // SAFETY: single-threaded access (`Node` is `!Sync`), and no other
            // reference into this cell is live when cancellation starts.
            // NOTE(review): `NodeInner::cancel` wakes stored wakers while this
            // borrow is active. This assumes no waker synchronously re-enters
            // this handle.
            let inner = unsafe { &mut *self.inner.get() };
            inner.cancel();
        }

        /// Returns `true` if this node has been marked canceled.
        pub fn is_canceled(&self) -> bool {
            // SAFETY: read-only query, so a shared borrow is enough.
            // Unlike an exclusive borrow, it does not assert uniqueness over
            // a cell that neighbouring nodes point into. No exclusive borrow
            // of this cell is live while it is read.
            let inner = unsafe { &*self.inner.get() };
            inner.is_canceled()
        }
    }

    impl Default for Node {
        fn default() -> Self {
            Self {
                inner: UnsafeCell::new(NodeInner::Unlinked),
                _pin: PhantomPinned,
            }
        }
    }

    impl Drop for Node {
        /// Splices this node out of its list, if it is linked.
        fn drop(&mut self) {
            // `&mut self` already guarantees exclusive access, so the safe
            // `UnsafeCell::get_mut` suffices; no raw-pointer deref is needed.
            self.inner.get_mut().unlink();
        }
    }

    /// Nodes compare by identity (address), not by contents.
    impl PartialEq for Node {
        fn eq(&self, other: &Self) -> bool {
            std::ptr::eq(self, other)
        }
    }

    /// State of a `Node`.
    #[derive(Debug)]
    enum NodeInner {
        /// Not part of any list.
        Unlinked,
        /// Part of a circular list. `prev` and `next` point at the neighbours'
        /// `NodeInner`s. In a two-node list (head plus one item),
        /// `prev == next`.
        Linked {
            kind: NodeKind,
            prev: NonNull<NodeInner>,
            next: NonNull<NodeInner>,
        },
        /// Set by `cancel`. `is_canceled` reports `true` only in this state.
        Canceled,
    }

    impl NodeInner {
        /// Returns a raw pointer to `self`, used to store links in other
        /// nodes.
        fn as_non_null(&mut self) -> NonNull<Self> {
            NonNull::from(self)
        }

        /// Links `self` (an item node) into the list headed by `head` and
        /// stores a clone of `waker`.
        ///
        /// If `head` is unlinked, a new two-node list is formed and `head`
        /// keeps a `Weak` derived from `rc_pin`. Otherwise, `self` is inserted
        /// just before `head`, i.e. at the tail of the list.
        ///
        /// # Errors
        /// Returns `Err(Canceled)` if `head` has been canceled.
        ///
        /// # Panics
        /// Panics if `self` is not `Unlinked`.
        fn link(
            &mut self,
            waker: &Waker,
            head: &mut Self,
            rc_pin: &Rc<dyn Any>,
        ) -> Result<(), Canceled> {
            // A future is linked to a cancel handle at most once.
            assert!(matches!(self, NodeInner::Unlinked));

            match head {
                NodeInner::Unlinked => {
                    *head = NodeInner::Linked {
                        kind: NodeKind::head(rc_pin),
                        prev: self.as_non_null(),
                        next: self.as_non_null(),
                    };
                    *self = NodeInner::Linked {
                        kind: NodeKind::item(waker),
                        prev: head.as_non_null(),
                        next: head.as_non_null(),
                    };
                    Ok(())
                }
                NodeInner::Linked {
                    kind: NodeKind::Head { .. },
                    prev: next_prev_nn,
                    ..
                } => {
                    // SAFETY: `head` is linked, so its `prev` points at the
                    // last item of a live list (nodes unlink themselves on
                    // drop). That item is neither `head` nor `self`, because
                    // `self` is still unlinked.
                    let prev = unsafe { &mut *next_prev_nn.as_ptr() };
                    match prev {
                        NodeInner::Linked {
                            kind: NodeKind::Item { .. },
                            next: prev_next_nn,
                            ..
                        } => {
                            *self = NodeInner::Linked {
                                kind: NodeKind::item(waker),
                                prev: replace(next_prev_nn, self.as_non_null()),
                                next: replace(prev_next_nn, self.as_non_null()),
                            };
                            Ok(())
                        }
                        _ => unreachable!(),
                    }
                }
                NodeInner::Canceled => Err(Canceled),
                _ => unreachable!(),
            }
        }

        /// Replaces the stored waker, unless the current one would already
        /// wake the same task.
        ///
        /// # Errors
        /// Returns `Err(Canceled)` if this node has been canceled.
        fn update_waker(&mut self, new_waker: &Waker) -> Result<(), Canceled> {
            match self {
                NodeInner::Unlinked => Ok(()),
                NodeInner::Linked {
                    kind: NodeKind::Item { waker },
                    ..
                } => {
                    if !waker.will_wake(new_waker) {
                        *waker = new_waker.clone();
                    }
                    Ok(())
                }
                NodeInner::Canceled => Err(Canceled),
                _ => unreachable!(),
            }
        }

        /// Replaces `self` with `Unlinked`. If it was linked, it is spliced
        /// out of its list. If only one other node remains, that node becomes
        /// `Unlinked` as well.
        fn unlink(&mut self) {
            if let NodeInner::Linked {
                prev: mut prev_nn,
                next: mut next_nn,
                ..
            } = replace(self, NodeInner::Unlinked)
            {
                if prev_nn == next_nn {
                    // There were only two nodes in this list. After unlinking
                    // ourselves, the other node is no longer linked either.
                    // SAFETY: neighbours of a linked node are live (nodes
                    // unlink themselves on drop) and distinct from `self`.
                    let other = unsafe { prev_nn.as_mut() };
                    *other = NodeInner::Unlinked;
                } else {
                    // SAFETY: `prev_nn` points at a live neighbour distinct
                    // from `self` (see above).
                    match unsafe { prev_nn.as_mut() } {
                        NodeInner::Linked {
                            next: prev_next_nn, ..
                        } => {
                            *prev_next_nn = next_nn;
                        }
                        _ => unreachable!(),
                    }
                    // SAFETY: `next_nn` points at a live neighbour distinct
                    // from `self` and from `prev_nn`'s target, since the list
                    // has more than two nodes.
                    match unsafe { next_nn.as_mut() } {
                        NodeInner::Linked {
                            prev: next_prev_nn, ..
                        } => {
                            *next_prev_nn = prev_nn;
                        }
                        _ => unreachable!(),
                    }
                }
            }
        }

        /// Cancels the list headed by `self`.
        ///
        /// Marks `self` `Canceled` (even if it had no items), then walks the
        /// circular list, marking every item `Canceled` and waking its stored
        /// waker. Does nothing more if `self` was already canceled. Must only
        /// be called on a head node; it panics on a linked item node.
        fn cancel(&mut self) {
            let mut head_nn = NonNull::from(self);

            // SAFETY: `head_nn` was just derived from `&mut self`, which is
            // not used again.
            let mut item_nn = match replace(unsafe { head_nn.as_mut() }, NodeInner::Canceled) {
                NodeInner::Linked {
                    kind: NodeKind::Head { .. },
                    next: next_nn,
                    ..
                } => next_nn,
                NodeInner::Unlinked | NodeInner::Canceled => return,
                _ => unreachable!(),
            };

            // Cancel all item nodes in the chain, waking each stored `Waker`.
            while item_nn != head_nn {
                // SAFETY: `item_nn` was read from a linked node's `next`
                // field, so it points at a live item node distinct from the
                // head.
                match replace(unsafe { item_nn.as_mut() }, NodeInner::Canceled) {
                    NodeInner::Linked {
                        kind: NodeKind::Item { waker },
                        next: next_nn,
                        ..
                    } => {
                        waker.wake();
                        item_nn = next_nn;
                    }
                    _ => unreachable!(),
                }
            }
        }

        /// Returns `true` only in the `Canceled` state.
        fn is_canceled(&self) -> bool {
            match self {
                NodeInner::Unlinked | NodeInner::Linked { .. } => false,
                NodeInner::Canceled => true,
            }
        }
    }

    /// Role of a linked node.
    #[derive(Debug)]
    enum NodeKind {
        /// The cancel handle's node. It holds a `Weak` derived from the
        /// `Rc<dyn Any>` obtained via `RcRef::split` (presumably the handle's
        /// allocation). The `cancel_handle_pinning` test shows this makes
        /// `Rc::get_mut` fail while futures are registered.
        Head { _weak_pin: Weak<dyn Any> },
        /// A registered future's node, holding the waker to wake on cancel.
        Item { waker: Waker },
    }

    impl NodeKind {
        /// Builds a `Head` kind holding a `Weak` to `rc_pin`.
        fn head(rc_pin: &Rc<dyn Any>) -> Self {
            let _weak_pin = Rc::downgrade(rc_pin);
            Self::Head { _weak_pin }
        }

        /// Builds an `Item` kind holding a clone of `waker`.
        fn item(waker: &Waker) -> Self {
            let waker = waker.clone();
            Self::Item { waker }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Error;
    use futures::future::pending;
    use futures::future::poll_fn;
    use futures::future::ready;
    use futures::future::FutureExt;
    use futures::future::TryFutureExt;
    use futures::pending;
    use futures::pin_mut;
    use futures::select;
    use futures::task::noop_waker_ref;
    use futures::task::Context;
    use futures::task::Poll;
    use std::convert::Infallible as Never;
    use std::io;
    use tokio::net::TcpStream;
    use tokio::spawn;
    use tokio::task::yield_now;

    /// Boxes and pins a `FusedFuture`, so that futures of different concrete
    /// types can be stored in one `Vec`.
    fn box_fused<'a, F: FusedFuture + 'a>(
        future: F,
    ) -> Pin<Box<dyn FusedFuture<Output = F::Output> + 'a>> {
        Box::pin(future)
    }

    /// Resolves to `name` after returning `Poll::Pending` exactly `count`
    /// times. It never stores the waker, so it relies on being re-polled.
    async fn ready_in_n(name: &str, count: usize) -> &str {
        let mut remaining = count as isize;
        poll_fn(|_| {
            assert!(remaining >= 0);
            if remaining == 0 {
                Poll::Ready(name)
            } else {
                remaining -= 1;
                Poll::Pending
            }
        })
        .await
    }

    #[test]
    fn cancel_future() {
        let cancel_now = CancelHandle::new_rc();
        let cancel_at_0 = CancelHandle::new_rc();
        let cancel_at_1 = CancelHandle::new_rc();
        let cancel_at_4 = CancelHandle::new_rc();
        let cancel_never = CancelHandle::new_rc();

        cancel_now.cancel();

        let mut futures = vec![
            box_fused(ready("A").or_cancel(&cancel_now)),
            box_fused(ready("B").or_cancel(&cancel_at_0)),
            box_fused(ready("C").or_cancel(&cancel_at_1)),
            box_fused(
                ready_in_n("D", 0)
                    .or_cancel(&cancel_never)
                    .try_or_cancel(&cancel_now),
            ),
            box_fused(
                ready_in_n("E", 1)
                    .or_cancel(&cancel_at_1)
                    .try_or_cancel(&cancel_at_1),
            ),
            box_fused(ready_in_n("F", 2).or_cancel(&cancel_at_1)),
            box_fused(ready_in_n("G", 3).or_cancel(&cancel_at_4)),
            box_fused(ready_in_n("H", 4).or_cancel(&cancel_at_4)),
            box_fused(ready_in_n("I", 5).or_cancel(&cancel_at_4)),
            box_fused(ready_in_n("J", 5).map(Ok)),
            box_fused(ready_in_n("K", 5).or_cancel(cancel_never)),
        ];

        let mut cx = Context::from_waker(noop_waker_ref());

        for i in 0..=5 {
            match i {
                0 => cancel_at_0.cancel(),
                1 => cancel_at_1.cancel(),
                4 => cancel_at_4.cancel(),
                2 | 3 | 5 => {}
                _ => unreachable!(),
            }

            // Poll every future that has not finished yet, and collect the
            // results of those that completed in this round.
            let results = futures
                .iter_mut()
                .filter(|fut| !fut.is_terminated())
                .filter_map(|fut| match fut.poll_unpin(&mut cx) {
                    Poll::Pending => None,
                    Poll::Ready(res) => Some(res),
                })
                .collect::<Vec<_>>();

            // The cancellation check precedes polling the inner future. That
            // is why "A", "B" and "D" are canceled even though their inner
            // futures are immediately ready.
            match i {
                0 => assert_eq!(
                    results,
                    [Err(Canceled), Err(Canceled), Ok("C"), Err(Canceled)]
                ),
                1 => assert_eq!(results, [Err(Canceled), Err(Canceled)]),
                2 => assert_eq!(results, []),
                3 => assert_eq!(results, [Ok("G")]),
                4 => assert_eq!(results, [Err(Canceled), Err(Canceled)]),
                5 => assert_eq!(results, [Ok("J"), Ok("K")]),
                _ => unreachable!(),
            }
        }

        assert!(futures.into_iter().all(|fut| fut.is_terminated()));

        let cancel_handles = [cancel_now, cancel_at_0, cancel_at_1, cancel_at_4];
        assert!(cancel_handles.iter().all(|c| c.is_canceled()));
    }

    #[tokio::test]
    async fn cancel_try_future() {
        {
            // Cancellation before the first poll: the inner future never runs.
            let cancel_handle = Rc::new(CancelHandle::new());
            let future = spawn(async { panic!("the task should not be spawned") })
                .map_err(Error::from)
                .try_or_cancel(&cancel_handle);
            cancel_handle.cancel();
            let error = future.await.unwrap_err();
            assert!(error.downcast_ref::<Canceled>().is_some());
            assert_eq!(error.to_string().as_str(), "operation canceled");
        }

        {
            // Cancellation after the future has been polled (and registered)
            // surfaces as an `io::Error` of kind `Interrupted`.
            let cancel_handle = Rc::new(CancelHandle::new());
            let result = loop {
                select! {
                    r = TcpStream::connect("1.2.3.4:12345")
                        .try_or_cancel(&cancel_handle) => break r,
                    default => cancel_handle.cancel(),
                };
            };
            let error = result.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::Interrupted);
            assert_eq!(error.to_string().as_str(), "operation canceled");
        }
    }

    #[tokio::test]
    async fn future_cancels_itself_before_completion() {
        let cancel_handle = CancelHandle::new_rc();
        let result = async {
            cancel_handle.cancel();
            yield_now().await;
            unreachable!();
        }
        .or_cancel(&cancel_handle)
        .await;
        assert_eq!(result.unwrap_err(), Canceled);
    }

    #[tokio::test]
    async fn future_cancels_itself_and_hangs() {
        let cancel_handle = CancelHandle::new_rc();
        let result = async {
            yield_now().await;
            cancel_handle.cancel();
            pending!();
            unreachable!();
        }
        .or_cancel(&cancel_handle)
        .await;
        assert_eq!(result.unwrap_err(), Canceled);
    }

    #[tokio::test]
    async fn future_cancels_itself_and_completes() {
        let cancel_handle = CancelHandle::new_rc();
        let result = async {
            yield_now().await;
            cancel_handle.cancel();
            Ok::<_, io::Error>("done")
        }
        .try_or_cancel(&cancel_handle)
        .await;
        assert_eq!(result.unwrap(), "done");
    }

    #[test]
    fn cancel_handle_pinning() {
        let mut cancel_handle = CancelHandle::new_rc();

        // Only one reference to the handle exists so far.
        assert!(Rc::get_mut(&mut cancel_handle).is_some());

        // The unpolled future holds a strong reference to the handle.
        let future = pending::<Never>().or_cancel(&cancel_handle);
        pin_mut!(future);

        assert!(Rc::get_mut(&mut cancel_handle).is_none());

        let mut cx = Context::from_waker(noop_waker_ref());
        assert!(future.poll(&mut cx).is_pending());

        // After registration, the head node keeps a `Weak` to the handle, so
        // the handle is still not uniquely owned.
        assert!(Rc::get_mut(&mut cancel_handle).is_none());

        cancel_handle.cancel();

        // Canceling drops the head node's `Weak`.
        assert!(Rc::get_mut(&mut cancel_handle).is_some());
    }
}