Move Task into raw

This commit is contained in:
Dario Nieuwenhuis 2021-03-18 00:20:02 +01:00
parent 0cc2c67194
commit 5c2bf3981e
7 changed files with 99 additions and 93 deletions

View File

@ -100,11 +100,12 @@ pub fn task(args: TokenStream, item: TokenStream) -> TokenStream {
let result = quote! {
#visibility fn #name(#args) -> ::embassy::executor::SpawnToken<#impl_ty> {
use ::embassy::executor::raw::Task;
#task_fn
type F = #impl_ty;
const NEW_TASK: ::embassy::executor::Task<F> = ::embassy::executor::Task::new();
static POOL: [::embassy::executor::Task<F>; #pool_size] = [NEW_TASK; #pool_size];
unsafe { ::embassy::executor::Task::spawn(&POOL, move || task(#arg_names)) }
const NEW_TASK: Task<F> = Task::new();
static POOL: [Task<F>; #pool_size] = [NEW_TASK; #pool_size];
unsafe { Task::spawn(&POOL, move || task(#arg_names)) }
}
};
result.into()

View File

@ -20,71 +20,9 @@ use crate::fmt::panic;
use crate::interrupt::{Interrupt, InterruptExt};
use crate::time::Alarm;
// repr(C) is needed to guarantee that the raw::Task is located at offset 0
// This makes it safe to cast between raw::Task and Task pointers.
#[repr(C)]
pub struct Task<F: Future + 'static> {
raw: raw::Task,
future: UninitCell<F>, // Valid if STATE_SPAWNED
}
impl<F: Future + 'static> Task<F> {
pub const fn new() -> Self {
Self {
raw: raw::Task::new(),
future: UninitCell::uninit(),
}
}
pub unsafe fn spawn(pool: &'static [Self], future: impl FnOnce() -> F) -> SpawnToken<F> {
for task in pool {
let state = raw::STATE_SPAWNED | raw::STATE_RUN_QUEUED;
if task
.raw
.state
.compare_exchange(0, state, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
// Initialize the task
task.raw.poll_fn.write(Self::poll);
task.future.write(future());
return SpawnToken {
raw_task: Some(NonNull::new_unchecked(&task.raw as *const raw::Task as _)),
phantom: PhantomData,
};
}
}
SpawnToken {
raw_task: None,
phantom: PhantomData,
}
}
unsafe fn poll(p: NonNull<raw::Task>) {
let this = &*(p.as_ptr() as *const Task<F>);
let future = Pin::new_unchecked(this.future.as_mut());
let waker = waker::from_task(p);
let mut cx = Context::from_waker(&waker);
match future.poll(&mut cx) {
Poll::Ready(_) => {
this.future.drop_in_place();
this.raw
.state
.fetch_and(!raw::STATE_SPAWNED, Ordering::AcqRel);
}
Poll::Pending => {}
}
}
}
unsafe impl<F: Future + 'static> Sync for Task<F> {}
#[must_use = "Calling a task function does nothing on its own. You must pass the returned SpawnToken to Executor::spawn()"]
pub struct SpawnToken<F> {
raw_task: Option<NonNull<raw::Task>>,
raw_task: Option<NonNull<raw::TaskHeader>>,
phantom: PhantomData<*mut F>,
}

View File

@ -1,15 +1,18 @@
use atomic_polyfill::{AtomicU32, Ordering};
use core::cell::Cell;
use core::cmp::min;
use core::future::Future;
use core::marker::PhantomData;
use core::pin::Pin;
use core::ptr;
use core::ptr::NonNull;
use core::task::Waker;
use core::task::{Context, Poll, Waker};
use super::run_queue::{RunQueue, RunQueueItem};
use super::timer_queue::{TimerQueue, TimerQueueItem};
use super::util::UninitCell;
use super::waker;
use super::SpawnToken;
use crate::time::{Alarm, Instant};
/// Task is spawned (has a future)
@ -19,16 +22,16 @@ pub(crate) const STATE_RUN_QUEUED: u32 = 1 << 1;
/// Task is in the executor timer queue
pub(crate) const STATE_TIMER_QUEUED: u32 = 1 << 2;
pub struct Task {
pub struct TaskHeader {
pub(crate) state: AtomicU32,
pub(crate) run_queue_item: RunQueueItem,
pub(crate) expires_at: Cell<Instant>,
pub(crate) timer_queue_item: TimerQueueItem,
pub(crate) executor: Cell<*const Executor>, // Valid if state != 0
pub(crate) poll_fn: UninitCell<unsafe fn(NonNull<Task>)>, // Valid if STATE_SPAWNED
pub(crate) poll_fn: UninitCell<unsafe fn(NonNull<TaskHeader>)>, // Valid if STATE_SPAWNED
}
impl Task {
impl TaskHeader {
pub(crate) const fn new() -> Self {
Self {
state: AtomicU32::new(0),
@ -64,10 +67,70 @@ impl Task {
// We have just marked the task as scheduled, so enqueue it.
let executor = &*self.executor.get();
executor.enqueue(self as *const Task as *mut Task);
executor.enqueue(self as *const TaskHeader as *mut TaskHeader);
}
}
// repr(C) is needed to guarantee that the Task is located at offset 0
// This makes it safe to cast between Task and Task pointers.
#[repr(C)]
pub struct Task<F: Future + 'static> {
raw: TaskHeader,
future: UninitCell<F>, // Valid if STATE_SPAWNED
}
impl<F: Future + 'static> Task<F> {
pub const fn new() -> Self {
Self {
raw: TaskHeader::new(),
future: UninitCell::uninit(),
}
}
pub unsafe fn spawn(pool: &'static [Self], future: impl FnOnce() -> F) -> SpawnToken<F> {
for task in pool {
let state = STATE_SPAWNED | STATE_RUN_QUEUED;
if task
.raw
.state
.compare_exchange(0, state, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
// Initialize the task
task.raw.poll_fn.write(Self::poll);
task.future.write(future());
return SpawnToken {
raw_task: Some(NonNull::new_unchecked(&task.raw as *const TaskHeader as _)),
phantom: PhantomData,
};
}
}
SpawnToken {
raw_task: None,
phantom: PhantomData,
}
}
unsafe fn poll(p: NonNull<TaskHeader>) {
let this = &*(p.as_ptr() as *const Task<F>);
let future = Pin::new_unchecked(this.future.as_mut());
let waker = waker::from_task(p);
let mut cx = Context::from_waker(&waker);
match future.poll(&mut cx) {
Poll::Ready(_) => {
this.future.drop_in_place();
this.raw.state.fetch_and(!STATE_SPAWNED, Ordering::AcqRel);
}
Poll::Pending => {}
}
}
}
unsafe impl<F: Future + 'static> Sync for Task<F> {}
pub struct Executor {
run_queue: RunQueue,
timer_queue: TimerQueue,
@ -95,13 +158,13 @@ impl Executor {
self.signal_ctx = signal_ctx;
}
unsafe fn enqueue(&self, item: *mut Task) {
unsafe fn enqueue(&self, item: *mut TaskHeader) {
if self.run_queue.enqueue(item) {
(self.signal_fn)(self.signal_ctx)
}
}
pub unsafe fn spawn(&'static self, task: NonNull<Task>) {
pub unsafe fn spawn(&'static self, task: NonNull<TaskHeader>) {
let task = task.as_ref();
task.executor.set(self);
self.enqueue(task as *const _ as _);
@ -154,7 +217,7 @@ impl Executor {
pub use super::waker::task_from_waker;
pub unsafe fn wake_task(task: NonNull<Task>) {
pub unsafe fn wake_task(task: NonNull<TaskHeader>) {
task.as_ref().enqueue();
}

View File

@ -2,10 +2,10 @@ use atomic_polyfill::{AtomicPtr, Ordering};
use core::ptr;
use core::ptr::NonNull;
use super::raw::Task;
use super::raw::TaskHeader;
pub(crate) struct RunQueueItem {
next: AtomicPtr<Task>,
next: AtomicPtr<TaskHeader>,
}
impl RunQueueItem {
@ -28,7 +28,7 @@ impl RunQueueItem {
/// current batch is completely processed, so even if a task enqueues itself instantly (for example
/// by waking its own waker) can't prevent other tasks from running.
pub(crate) struct RunQueue {
head: AtomicPtr<Task>,
head: AtomicPtr<TaskHeader>,
}
impl RunQueue {
@ -39,7 +39,7 @@ impl RunQueue {
}
/// Enqueues an item. Returns true if the queue was empty.
pub(crate) unsafe fn enqueue(&self, item: *mut Task) -> bool {
pub(crate) unsafe fn enqueue(&self, item: *mut TaskHeader) -> bool {
let mut prev = self.head.load(Ordering::Acquire);
loop {
(*item).run_queue_item.next.store(prev, Ordering::Relaxed);
@ -55,7 +55,7 @@ impl RunQueue {
prev.is_null()
}
pub(crate) unsafe fn dequeue_all(&self, on_task: impl Fn(NonNull<Task>)) {
pub(crate) unsafe fn dequeue_all(&self, on_task: impl Fn(NonNull<TaskHeader>)) {
let mut task = self.head.swap(ptr::null_mut(), Ordering::AcqRel);
while !task.is_null() {

View File

@ -4,11 +4,11 @@ use core::cmp::min;
use core::ptr;
use core::ptr::NonNull;
use super::raw::{Task, STATE_TIMER_QUEUED};
use super::raw::{TaskHeader, STATE_TIMER_QUEUED};
use crate::time::Instant;
pub(crate) struct TimerQueueItem {
next: Cell<*mut Task>,
next: Cell<*mut TaskHeader>,
}
impl TimerQueueItem {
@ -20,7 +20,7 @@ impl TimerQueueItem {
}
pub(crate) struct TimerQueue {
head: Cell<*mut Task>,
head: Cell<*mut TaskHeader>,
}
impl TimerQueue {
@ -30,7 +30,7 @@ impl TimerQueue {
}
}
pub(crate) unsafe fn update(&self, p: NonNull<Task>) {
pub(crate) unsafe fn update(&self, p: NonNull<TaskHeader>) {
let task = p.as_ref();
if task.expires_at.get() != Instant::MAX {
let old_state = task.state.fetch_or(STATE_TIMER_QUEUED, Ordering::AcqRel);
@ -54,7 +54,11 @@ impl TimerQueue {
res
}
pub(crate) unsafe fn dequeue_expired(&self, now: Instant, on_task: impl Fn(NonNull<Task>)) {
pub(crate) unsafe fn dequeue_expired(
&self,
now: Instant,
on_task: impl Fn(NonNull<TaskHeader>),
) {
self.retain(|p| {
let task = p.as_ref();
if task.expires_at.get() <= now {
@ -66,7 +70,7 @@ impl TimerQueue {
});
}
pub(crate) unsafe fn retain(&self, mut f: impl FnMut(NonNull<Task>) -> bool) {
pub(crate) unsafe fn retain(&self, mut f: impl FnMut(NonNull<TaskHeader>) -> bool) {
let mut prev = &self.head;
while !prev.get().is_null() {
let p = NonNull::new_unchecked(prev.get());

View File

@ -2,7 +2,7 @@ use core::mem;
use core::ptr::NonNull;
use core::task::{RawWaker, RawWakerVTable, Waker};
use super::raw::Task;
use super::raw::TaskHeader;
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake, drop);
@ -11,21 +11,21 @@ unsafe fn clone(p: *const ()) -> RawWaker {
}
unsafe fn wake(p: *const ()) {
(*(p as *mut Task)).enqueue()
(*(p as *mut TaskHeader)).enqueue()
}
unsafe fn drop(_: *const ()) {
// nop
}
pub(crate) unsafe fn from_task(p: NonNull<Task>) -> Waker {
pub(crate) unsafe fn from_task(p: NonNull<TaskHeader>) -> Waker {
Waker::from_raw(RawWaker::new(p.as_ptr() as _, &VTABLE))
}
pub unsafe fn task_from_waker(waker: &Waker) -> NonNull<Task> {
pub unsafe fn task_from_waker(waker: &Waker) -> NonNull<TaskHeader> {
let hack: &WakerHack = mem::transmute(waker);
assert_eq!(hack.vtable, &VTABLE);
NonNull::new_unchecked(hack.data as *mut Task)
NonNull::new_unchecked(hack.data as *mut TaskHeader)
}
struct WakerHack {

View File

@ -3,12 +3,12 @@ use core::task::Waker;
use atomic_polyfill::{AtomicPtr, Ordering};
use crate::executor::raw::{task_from_waker, wake_task, Task};
use crate::executor::raw::{task_from_waker, wake_task, TaskHeader};
/// Utility struct to register and wake a waker.
#[derive(Debug)]
pub struct WakerRegistration {
waker: Option<NonNull<Task>>,
waker: Option<NonNull<TaskHeader>>,
}
impl WakerRegistration {
@ -49,7 +49,7 @@ impl WakerRegistration {
}
pub struct AtomicWakerRegistration {
waker: AtomicPtr<Task>,
waker: AtomicPtr<TaskHeader>,
}
impl AtomicWakerRegistration {