Skip to content

Commit

Permalink
Refactoring atomic waker
Browse files Browse the repository at this point in the history
  • Loading branch information
DoumanAsh committed Feb 10, 2024
1 parent 40f2ec1 commit 86ecb39
Showing 1 changed file with 60 additions and 54 deletions.
114 changes: 60 additions & 54 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//!State module

use core::{ptr, task};
use core::{ptr, task, hint, mem};
use core::cell::UnsafeCell;
use core::sync::atomic::{AtomicBool, AtomicU8, Ordering};

Expand Down Expand Up @@ -58,52 +58,26 @@ pub struct AtomicWaker {
waker: UnsafeCell<task::Waker>,
}

impl AtomicWaker {
fn new() -> Self {
Self {
state: AtomicU8::new(WAITING),
waker: UnsafeCell::new(noop::waker()),
}
struct StateRestore<F: Fn()>(F);
impl<F: Fn()> Drop for StateRestore<F> {
fn drop(&mut self) {
(self.0)()
}
}

///This is the same function as `register` but working with owned version.
fn register_owned(&self, waker: task::Waker) {
match self.state.compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire).unwrap_or_else(|err| err) {
macro_rules! impl_register {
($this:ident($waker:ident) { $($impl:tt)+ }) => {
match $this.state.compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire).unwrap_or_else(|err| err) {
WAITING => {
unsafe {
//unconditionally store since we already have ownership
*self.waker.get() = waker;

match self.state.compare_exchange(REGISTERING, WAITING, Ordering::AcqRel, Ordering::Acquire) {
Ok(_) => {}
Err(actual) => {
debug_assert_eq!(actual, REGISTERING | WAKING);
//Make sure we do not stuck in REGISTERING state
let state_guard = StateRestore(|| {
$this.state.store(WAITING, Ordering::Release);
});

let mut waker = noop::waker();
ptr::swap(self.waker.get(), &mut waker);

self.state.swap(WAITING, Ordering::AcqRel);
waker.wake();
}
}
}
}
WAKING => {
waker.wake();
}
state => debug_assert!(state == REGISTERING || state == REGISTERING | WAKING),
}
}

fn register(&self, waker: &task::Waker) {
match self.state.compare_exchange(WAITING, REGISTERING, Ordering::Acquire, Ordering::Acquire).unwrap_or_else(|err| err) {
WAITING => {
unsafe {
// Lock acquired, update the waker cell
if !(*self.waker.get()).will_wake(waker) {
//Clone new waker if it is definitely not the same as old one
*self.waker.get() = waker.clone();
}
$(
$impl
)+

// Release the lock. If the state transitioned to include
// the `WAKING` bit, this means that a wake has been
Expand All @@ -112,8 +86,10 @@ impl AtomicWaker {
//
// Start by assuming that the state is `REGISTERING` as this
// is what we jut set it to.
match self.state.compare_exchange(REGISTERING, WAITING, Ordering::AcqRel, Ordering::Acquire) {
Ok(_) => {}
match $this.state.compare_exchange(REGISTERING, WAITING, Ordering::AcqRel, Ordering::Acquire) {
Ok(_) => {
mem::forget(state_guard);
}
Err(actual) => {
// This branch can only be reached if a
// concurrent thread called `wake`. In this
Expand All @@ -122,10 +98,11 @@ impl AtomicWaker {
debug_assert_eq!(actual, REGISTERING | WAKING);

let mut waker = noop::waker();
ptr::swap(self.waker.get(), &mut waker);
ptr::swap($this.waker.get(), &mut waker);

// Just swap, because no one could change state while state == `REGISTERING` | `WAKING`.
self.state.swap(WAITING, Ordering::AcqRel);
// Just restore state,
// because no one could change state while state == `REGISTERING` | `WAKING`.
drop(state_guard);
waker.wake();
}
}
Expand All @@ -135,7 +112,8 @@ impl AtomicWaker {
// Currently in the process of waking the task, i.e.,
// `wake` is currently being called on the old task handle.
// So, we call wake on the new waker
waker.wake_by_ref();
$waker.wake_by_ref();
hint::spin_loop();
}
state => {
// In this case, a concurrent thread is holding the
Expand All @@ -147,9 +125,37 @@ impl AtomicWaker {
// call to `register`.
debug_assert!(
state == REGISTERING ||
state == REGISTERING | WAKING);
state == REGISTERING | WAKING
);
}
}
};
}

impl AtomicWaker {
fn new() -> Self {
Self {
state: AtomicU8::new(WAITING),
waker: UnsafeCell::new(noop::waker()),
}
}

///This is the same function as `register` but working with owned version.
fn register(&self, waker: task::Waker) {
impl_register!(self(waker) {
//unconditionally store since we already have ownership
*self.waker.get() = waker;
});
}

fn register_ref(&self, waker: &task::Waker) {
impl_register!(self(waker) {
// Lock acquired, update the waker cell
if !(*self.waker.get()).will_wake(waker) {
//Clone new waker if it is definitely not the same as old one
*self.waker.get() = waker.clone();
}
});
}

fn wake(&self) {
Expand All @@ -175,11 +181,11 @@ impl AtomicWaker {
// Nothing more to do as the `WAKING` bit has been set. It
// doesn't matter if there are concurrent registering threads or
// not.
//
debug_assert!(
state == REGISTERING ||
state == REGISTERING | WAKING ||
state == WAKING);
state == WAKING
);
}
}
}
Expand Down Expand Up @@ -256,19 +262,19 @@ pub trait Callback {
impl<'a> Callback for &'a task::Waker {
#[inline(always)]
fn register(self, waker: &AtomicWaker) {
waker.register(self)
waker.register_ref(self)
}
}

impl Callback for task::Waker {
#[inline(always)]
fn register(self, waker: &AtomicWaker) {
waker.register_owned(self)
waker.register(self)
}
}

impl Callback for fn() {
fn register(self, waker: &AtomicWaker) {
waker.register_owned(plain_fn::waker(self));
waker.register(plain_fn::waker(self));
}
}

0 comments on commit 86ecb39

Please sign in to comment.