diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml index d2649024..c03fc0df 100644 --- a/embassy/Cargo.toml +++ b/embassy/Cargo.toml @@ -39,6 +39,7 @@ embedded-hal = "0.2.5" cast = { version = "=0.2.3", default-features = false } [dev-dependencies] +embassy = { path = ".", features = ["executor-agnostic"] } futures-executor = { version = "0.3", features = [ "thread-pool" ] } futures-test = "0.3" futures-timer = "0.3" diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 8f1bba76..580c6794 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs @@ -51,6 +51,7 @@ use super::CriticalSectionMutex; use super::Mutex; use super::NoopMutex; use super::ThreadModeMutex; +use super::WakerRegistration; /// Send values to the associated `Receiver`. /// @@ -149,7 +150,7 @@ where Ok(v) => Poll::Ready(Some(v)), Err(TryRecvError::Closed) => Poll::Ready(None), Err(TryRecvError::Empty) => { - self.channel.get().set_receiver_waker(cx.waker().clone()); + self.channel.get().set_receiver_waker(&cx.waker()); Poll::Pending } } @@ -282,10 +283,7 @@ where Ok(..) => Poll::Ready(Ok(())), Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), Err(TrySendError::Full(..)) => { - self.sender - .channel - .get() - .set_senders_waker(cx.waker().clone()); + self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending // Note we leave the existing UnsafeCell contents - they still // contain the original message. We could create another UnsafeCell @@ -312,10 +310,7 @@ where if self.sender.is_closed() { Poll::Ready(()) } else { - self.sender - .channel - .get() - .set_senders_waker(cx.waker().clone()); + self.sender.channel.get().set_senders_waker(&cx.waker()); Poll::Pending } } @@ -400,8 +395,8 @@ struct ChannelState { closed: bool, receiver_registered: bool, senders_registered: u32, - receiver_waker: Option, - senders_waker: Option, + receiver_waker: WakerRegistration, + senders_waker: WakerRegistration, } impl ChannelState { @@ -416,8 +411,8 @@ impl ChannelState { let closed = false; let receiver_registered = false; let senders_registered = 0; - let receiver_waker = None; - let senders_waker = None; + let receiver_waker = WakerRegistration::new(); + let senders_waker = WakerRegistration::new(); ChannelState { buf, read_pos, @@ -534,9 +529,7 @@ where if state.read_pos != state.write_pos || state.full { if state.full { state.full = false; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); } let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() }; @@ -546,9 +539,7 @@ where Err(TryRecvError::Empty) } else { state.closed = true; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); Err(TryRecvError::Closed) } } else { @@ -567,9 +558,7 @@ where if state.write_pos == state.read_pos { state.full = true; } - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); Ok(()) } else { Err(TrySendError::Full(message)) @@ -583,9 +572,7 @@ where fn close(&mut self) { let state = &mut self.state; self.mutex.lock(|_| { - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); state.closing = true; }); } @@ -608,9 +595,7 @@ where self.mutex.lock(|_| { if state.receiver_registered { state.closed = true; - if let Some(w) = state.senders_waker.take() { - w.wake(); - } + state.senders_waker.wake(); } state.receiver_registered = false; }) @@ -629,38 +614,30 @@ where assert!(state.senders_registered > 0); state.senders_registered -= 1; if state.senders_registered == 0 { - if let Some(w) = state.receiver_waker.take() { - w.wake(); - } + state.receiver_waker.wake(); state.closing = true; } }) } - fn set_receiver_waker(&mut self, receiver_waker: Waker) { + fn set_receiver_waker(&mut self, receiver_waker: &Waker) { let state = &mut self.state; self.mutex.lock(|_| { - state.receiver_waker = Some(receiver_waker); + state.receiver_waker.register(receiver_waker); }) } - fn set_senders_waker(&mut self, senders_waker: Waker) { + fn set_senders_waker(&mut self, senders_waker: &Waker) { let state = &mut self.state; self.mutex.lock(|_| { - // Dispose of any existing sender causing them to be polled again. // This could cause a spin given multiple concurrent senders, however given that // most sends only block waiting for the receiver to become active, this should // be a short-lived activity. The upside is a greatly simplified implementation // that avoids the need for intrusive linked-lists and unsafe operations on pinned // pointers. - if let Some(waker) = state.senders_waker.clone() { - if !senders_waker.will_wake(&waker) { - trace!("Waking an an active send waker due to being superseded with a new one. While benign, please report this."); - waker.wake(); - } - } - state.senders_waker = Some(senders_waker); + state.senders_waker.wake(); + state.senders_waker.register(senders_waker); }) } }