diff --git a/embassy/src/util/select.rs b/embassy/src/util/select.rs index b3b03d81..ccc50f11 100644 --- a/embassy/src/util/select.rs +++ b/embassy/src/util/select.rs @@ -2,6 +2,175 @@ use core::future::Future; use core::pin::Pin; use core::task::{Context, Poll}; +#[derive(Debug, Clone)] +pub enum Either { + First(A), + Second(B), +} + +/// Wait for one of two futures to complete. +/// +/// This function returns a new future which polls all the futures. +/// When one of them completes, it will complete with its result value. +/// +/// The other future is dropped. +pub fn select(a: A, b: B) -> Select +where + A: Future, + B: Future, +{ + Select { a, b } +} + +/// Future for the [`select`] function. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Select { + a: A, + b: B, +} + +impl Unpin for Select {} + +impl Future for Select +where + A: Future, + B: Future, +{ + type Output = Either; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let a = unsafe { Pin::new_unchecked(&mut this.a) }; + let b = unsafe { Pin::new_unchecked(&mut this.b) }; + if let Poll::Ready(x) = a.poll(cx) { + return Poll::Ready(Either::First(x)); + } + if let Poll::Ready(x) = b.poll(cx) { + return Poll::Ready(Either::Second(x)); + } + Poll::Pending + } +} + +// ==================================================================== + +#[derive(Debug, Clone)] +pub enum Either3 { + First(A), + Second(B), + Third(C), +} + +/// Same as [`select`], but with more futures. +pub fn select3(a: A, b: B, c: C) -> Select3 +where + A: Future, + B: Future, + C: Future, +{ + Select3 { a, b, c } +} + +/// Future for the [`select3`] function. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Select3 { + a: A, + b: B, + c: C, +} + +impl Future for Select3 +where + A: Future, + B: Future, + C: Future, +{ + type Output = Either3; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let a = unsafe { Pin::new_unchecked(&mut this.a) }; + let b = unsafe { Pin::new_unchecked(&mut this.b) }; + let c = unsafe { Pin::new_unchecked(&mut this.c) }; + if let Poll::Ready(x) = a.poll(cx) { + return Poll::Ready(Either3::First(x)); + } + if let Poll::Ready(x) = b.poll(cx) { + return Poll::Ready(Either3::Second(x)); + } + if let Poll::Ready(x) = c.poll(cx) { + return Poll::Ready(Either3::Third(x)); + } + Poll::Pending + } +} + +// ==================================================================== + +#[derive(Debug, Clone)] +pub enum Either4 { + First(A), + Second(B), + Third(C), + Fourth(D), +} + +/// Same as [`select`], but with more futures. +pub fn select4(a: A, b: B, c: C, d: D) -> Select4 +where + A: Future, + B: Future, + C: Future, + D: Future, +{ + Select4 { a, b, c, d } +} + +/// Future for the [`select4`] function. +#[derive(Debug)] +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct Select4 { + a: A, + b: B, + c: C, + d: D, +} + +impl Future for Select4 +where + A: Future, + B: Future, + C: Future, + D: Future, +{ + type Output = Either4; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let a = unsafe { Pin::new_unchecked(&mut this.a) }; + let b = unsafe { Pin::new_unchecked(&mut this.b) }; + let c = unsafe { Pin::new_unchecked(&mut this.c) }; + let d = unsafe { Pin::new_unchecked(&mut this.d) }; + if let Poll::Ready(x) = a.poll(cx) { + return Poll::Ready(Either4::First(x)); + } + if let Poll::Ready(x) = b.poll(cx) { + return Poll::Ready(Either4::Second(x)); + } + if let Poll::Ready(x) = c.poll(cx) { + return Poll::Ready(Either4::Third(x)); + } + if let Poll::Ready(x) = d.poll(cx) { + return Poll::Ready(Either4::Fourth(x)); + } + Poll::Pending + } +} + +// ==================================================================== + /// Future for the [`select_all`] function. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"]