From 5c2bf3981ed81a7e2ff6300056c671733c810566 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Thu, 18 Mar 2021 00:20:02 +0100 Subject: [PATCH] Move Task into raw --- embassy-macros/src/lib.rs | 7 +-- embassy/src/executor/mod.rs | 64 +---------------------- embassy/src/executor/raw.rs | 79 ++++++++++++++++++++++++++--- embassy/src/executor/run_queue.rs | 10 ++-- embassy/src/executor/timer_queue.rs | 16 +++--- embassy/src/executor/waker.rs | 10 ++-- embassy/src/util/waker.rs | 6 +-- 7 files changed, 99 insertions(+), 93 deletions(-) diff --git a/embassy-macros/src/lib.rs b/embassy-macros/src/lib.rs index 710c5a15..f207497d 100644 --- a/embassy-macros/src/lib.rs +++ b/embassy-macros/src/lib.rs @@ -100,11 +100,12 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream { let result = quote! { #visibility fn #name(#args) -> ::embassy::executor::SpawnToken<#impl_ty> { + use ::embassy::executor::raw::Task; #task_fn type F = #impl_ty; - const NEW_TASK: ::embassy::executor::Task = ::embassy::executor::Task::new(); - static POOL: [::embassy::executor::Task; #pool_size] = [NEW_TASK; #pool_size]; - unsafe { ::embassy::executor::Task::spawn(&POOL, move || task(#arg_names)) } + const NEW_TASK: Task = Task::new(); + static POOL: [Task; #pool_size] = [NEW_TASK; #pool_size]; + unsafe { Task::spawn(&POOL, move || task(#arg_names)) } } }; result.into() diff --git a/embassy/src/executor/mod.rs b/embassy/src/executor/mod.rs index 10e54330..8b23264d 100644 --- a/embassy/src/executor/mod.rs +++ b/embassy/src/executor/mod.rs @@ -20,71 +20,9 @@ use crate::fmt::panic; use crate::interrupt::{Interrupt, InterruptExt}; use crate::time::Alarm; -// repr(C) is needed to guarantee that the raw::Task is located at offset 0 -// This makes it safe to cast between raw::Task and Task pointers. -#[repr(C)] -pub struct Task { - raw: raw::Task, - future: UninitCell, // Valid if STATE_SPAWNED -} - -impl Task { - pub const fn new() -> Self { - Self { - raw: raw::Task::new(), - future: UninitCell::uninit(), - } - } - - pub unsafe fn spawn(pool: &'static [Self], future: impl FnOnce() -> F) -> SpawnToken { - for task in pool { - let state = raw::STATE_SPAWNED | raw::STATE_RUN_QUEUED; - if task - .raw - .state - .compare_exchange(0, state, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - { - // Initialize the task - task.raw.poll_fn.write(Self::poll); - task.future.write(future()); - - return SpawnToken { - raw_task: Some(NonNull::new_unchecked(&task.raw as *const raw::Task as _)), - phantom: PhantomData, - }; - } - } - - SpawnToken { - raw_task: None, - phantom: PhantomData, - } - } - - unsafe fn poll(p: NonNull) { - let this = &*(p.as_ptr() as *const Task); - - let future = Pin::new_unchecked(this.future.as_mut()); - let waker = waker::from_task(p); - let mut cx = Context::from_waker(&waker); - match future.poll(&mut cx) { - Poll::Ready(_) => { - this.future.drop_in_place(); - this.raw - .state - .fetch_and(!raw::STATE_SPAWNED, Ordering::AcqRel); - } - Poll::Pending => {} - } - } -} - -unsafe impl Sync for Task {} - #[must_use = "Calling a task function does nothing on its own. You must pass the returned SpawnToken to Executor::spawn()"] pub struct SpawnToken { - raw_task: Option>, + raw_task: Option>, phantom: PhantomData<*mut F>, } diff --git a/embassy/src/executor/raw.rs b/embassy/src/executor/raw.rs index 84e171df..edc6d805 100644 --- a/embassy/src/executor/raw.rs +++ b/embassy/src/executor/raw.rs @@ -1,15 +1,18 @@ use atomic_polyfill::{AtomicU32, Ordering}; use core::cell::Cell; use core::cmp::min; +use core::future::Future; use core::marker::PhantomData; +use core::pin::Pin; use core::ptr; use core::ptr::NonNull; -use core::task::Waker; +use core::task::{Context, Poll, Waker}; use super::run_queue::{RunQueue, RunQueueItem}; use super::timer_queue::{TimerQueue, TimerQueueItem}; use super::util::UninitCell; use super::waker; +use super::SpawnToken; use crate::time::{Alarm, Instant}; /// Task is spawned (has a future) @@ -19,16 +22,16 @@ pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1; /// Task is in the executor timer queue pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2; -pub struct Task { +pub struct TaskHeader { pub(crate) state: AtomicU32, pub(crate) run_queue_item: RunQueueItem, pub(crate) expires_at: Cell, pub(crate) timer_queue_item: TimerQueueItem, pub(crate) executor: Cell<*const Executor>, // Valid if state != 0 - pub(crate) poll_fn: UninitCell)>, // Valid if STATE_SPAWNED + pub(crate) poll_fn: UninitCell)>, // Valid if STATE_SPAWNED } -impl Task { +impl TaskHeader { pub(crate) const fn new() -> Self { Self { state: AtomicU32::new(0), @@ -64,10 +67,70 @@ impl Task { // We have just marked the task as scheduled, so enqueue it. let executor = &*self.executor.get(); - executor.enqueue(self as *const Task as *mut Task); + executor.enqueue(self as *const TaskHeader as *mut TaskHeader); } } +// repr(C) is needed to guarantee that the Task is located at offset 0 +// This makes it safe to cast between Task and Task pointers. +#[repr(C)] +pub struct Task { + raw: TaskHeader, + future: UninitCell, // Valid if STATE_SPAWNED +} + +impl Task { + pub const fn new() -> Self { + Self { + raw: TaskHeader::new(), + future: UninitCell::uninit(), + } + } + + pub unsafe fn spawn(pool: &'static [Self], future: impl FnOnce() -> F) -> SpawnToken { + for task in pool { + let state = STATE_SPAWNED | STATE_RUN_QUEUED; + if task + .raw + .state + .compare_exchange(0, state, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + // Initialize the task + task.raw.poll_fn.write(Self::poll); + task.future.write(future()); + + return SpawnToken { + raw_task: Some(NonNull::new_unchecked(&task.raw as *const TaskHeader as _)), + phantom: PhantomData, + }; + } + } + + SpawnToken { + raw_task: None, + phantom: PhantomData, + } + } + + unsafe fn poll(p: NonNull) { + let this = &*(p.as_ptr() as *const Task); + + let future = Pin::new_unchecked(this.future.as_mut()); + let waker = waker::from_task(p); + let mut cx = Context::from_waker(&waker); + match future.poll(&mut cx) { + Poll::Ready(_) => { + this.future.drop_in_place(); + this.raw.state.fetch_and(!STATE_SPAWNED, Ordering::AcqRel); + } + Poll::Pending => {} + } + } +} + +unsafe impl Sync for Task {} + pub struct Executor { run_queue: RunQueue, timer_queue: TimerQueue, @@ -95,13 +158,13 @@ impl Executor { self.signal_ctx = signal_ctx; } - unsafe fn enqueue(&self, item: *mut Task) { + unsafe fn enqueue(&self, item: *mut TaskHeader) { if self.run_queue.enqueue(item) { (self.signal_fn)(self.signal_ctx) } } - pub unsafe fn spawn(&'static self, task: NonNull) { + pub unsafe fn spawn(&'static self, task: NonNull) { let task = task.as_ref(); task.executor.set(self); self.enqueue(task as *const _ as _); @@ -154,7 +217,7 @@ impl Executor { pub use super::waker::task_from_waker; -pub unsafe fn wake_task(task: NonNull) { +pub unsafe fn wake_task(task: NonNull) { task.as_ref().enqueue(); } diff --git a/embassy/src/executor/run_queue.rs b/embassy/src/executor/run_queue.rs index 1d1023e5..08391613 100644 --- a/embassy/src/executor/run_queue.rs +++ b/embassy/src/executor/run_queue.rs @@ -2,10 +2,10 @@ use atomic_polyfill::{AtomicPtr, Ordering}; use core::ptr; use core::ptr::NonNull; -use super::raw::Task; +use super::raw::TaskHeader; pub(crate) struct RunQueueItem { - next: AtomicPtr, + next: AtomicPtr, } impl RunQueueItem { @@ -28,7 +28,7 @@ impl RunQueueItem { /// current batch is completely processed, so even if a task enqueues itself instantly (for example /// by waking its own waker) can't prevent other tasks from running. pub(crate) struct RunQueue { - head: AtomicPtr, + head: AtomicPtr, } impl RunQueue { @@ -39,7 +39,7 @@ impl RunQueue { } /// Enqueues an item. Returns true if the queue was empty. - pub(crate) unsafe fn enqueue(&self, item: *mut Task) -> bool { + pub(crate) unsafe fn enqueue(&self, item: *mut TaskHeader) -> bool { let mut prev = self.head.load(Ordering::Acquire); loop { (*item).run_queue_item.next.store(prev, Ordering::Relaxed); @@ -55,7 +55,7 @@ impl RunQueue { prev.is_null() } - pub(crate) unsafe fn dequeue_all(&self, on_task: impl Fn(NonNull)) { + pub(crate) unsafe fn dequeue_all(&self, on_task: impl Fn(NonNull)) { let mut task = self.head.swap(ptr::null_mut(), Ordering::AcqRel); while !task.is_null() { diff --git a/embassy/src/executor/timer_queue.rs b/embassy/src/executor/timer_queue.rs index d72eb93b..a6939f11 100644 --- a/embassy/src/executor/timer_queue.rs +++ b/embassy/src/executor/timer_queue.rs @@ -4,11 +4,11 @@ use core::cmp::min; use core::ptr; use core::ptr::NonNull; -use super::raw::{Task, STATE_TIMER_QUEUED}; +use super::raw::{TaskHeader, STATE_TIMER_QUEUED}; use crate::time::Instant; pub(crate) struct TimerQueueItem { - next: Cell<*mut Task>, + next: Cell<*mut TaskHeader>, } impl TimerQueueItem { @@ -20,7 +20,7 @@ impl TimerQueueItem { } pub(crate) struct TimerQueue { - head: Cell<*mut Task>, + head: Cell<*mut TaskHeader>, } impl TimerQueue { @@ -30,7 +30,7 @@ impl TimerQueue { } } - pub(crate) unsafe fn update(&self, p: NonNull) { + pub(crate) unsafe fn update(&self, p: NonNull) { let task = p.as_ref(); if task.expires_at.get() != Instant::MAX { let old_state = task.state.fetch_or(STATE_TIMER_QUEUED, Ordering::AcqRel); @@ -54,7 +54,11 @@ impl TimerQueue { res } - pub(crate) unsafe fn dequeue_expired(&self, now: Instant, on_task: impl Fn(NonNull)) { + pub(crate) unsafe fn dequeue_expired( + &self, + now: Instant, + on_task: impl Fn(NonNull), + ) { self.retain(|p| { let task = p.as_ref(); if task.expires_at.get() <= now { @@ -66,7 +70,7 @@ impl TimerQueue { }); } - pub(crate) unsafe fn retain(&self, mut f: impl FnMut(NonNull) -> bool) { + pub(crate) unsafe fn retain(&self, mut f: impl FnMut(NonNull) -> bool) { let mut prev = &self.head; while !prev.get().is_null() { let p = NonNull::new_unchecked(prev.get()); diff --git a/embassy/src/executor/waker.rs b/embassy/src/executor/waker.rs index bc02c51d..050f6a1c 100644 --- a/embassy/src/executor/waker.rs +++ b/embassy/src/executor/waker.rs @@ -2,7 +2,7 @@ use core::mem; use core::ptr::NonNull; use core::task::{RawWaker, RawWakerVTable, Waker}; -use super::raw::Task; +use super::raw::TaskHeader; const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop); @@ -11,21 +11,21 @@ unsafe fn clone(p: *const ()) -> RawWaker { } unsafe fn wake(p: *const ()) { - (*(p as *mut Task)).enqueue() + (*(p as *mut TaskHeader)).enqueue() } unsafe fn drop(_: *const ()) { // nop } -pub(crate) unsafe fn from_task(p: NonNull) -> Waker { +pub(crate) unsafe fn from_task(p: NonNull) -> Waker { Waker::from_raw(RawWaker::new(p.as_ptr() as _, &VTABLE)) } -pub unsafe fn task_from_waker(waker: &Waker) -> NonNull { +pub unsafe fn task_from_waker(waker: &Waker) -> NonNull { let hack: &WakerHack = mem::transmute(waker); assert_eq!(hack.vtable, &VTABLE); - NonNull::new_unchecked(hack.data as *mut Task) + NonNull::new_unchecked(hack.data as *mut TaskHeader) } struct WakerHack { diff --git a/embassy/src/util/waker.rs b/embassy/src/util/waker.rs index 1f2d3a77..2b72fd56 100644 --- a/embassy/src/util/waker.rs +++ b/embassy/src/util/waker.rs @@ -3,12 +3,12 @@ use core::task::Waker; use atomic_polyfill::{AtomicPtr, Ordering}; -use crate::executor::raw::{task_from_waker, wake_task, Task}; +use crate::executor::raw::{task_from_waker, wake_task, TaskHeader}; /// Utility struct to register and wake a waker. #[derive(Debug)] pub struct WakerRegistration { - waker: Option>, + waker: Option>, } impl WakerRegistration { @@ -49,7 +49,7 @@ impl WakerRegistration { } pub struct AtomicWakerRegistration { - waker: AtomicPtr, + waker: AtomicPtr, } impl AtomicWakerRegistration {