Skip to content

Commit

Permalink
sync: wrap state in helper struct (#3922)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn committed Jul 7, 2021
1 parent 4818c2e commit be26ca7
Showing 1 changed file with 80 additions and 16 deletions.
96 changes: 80 additions & 16 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
use crate::sync::notify::Notify;

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst};
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
use std::ops;

Expand All @@ -74,7 +74,7 @@ pub struct Receiver<T> {
shared: Arc<Shared<T>>,

/// Last observed version
version: usize,
version: Version,
}

/// Sends values to the associated [`Receiver`](struct@Receiver).
Expand Down Expand Up @@ -104,7 +104,7 @@ struct Shared<T> {
///
/// The lowest bit represents a "closed" state. The rest of the bits
/// represent the current version.
version: AtomicUsize,
state: AtomicState,

/// Tracks the number of `Receiver` instances
ref_count_rx: AtomicUsize,
Expand Down Expand Up @@ -152,7 +152,69 @@ pub mod error {
impl std::error::Error for RecvError {}
}

const CLOSED: usize = 1;
use self::state::{AtomicState, Version};
mod state {
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::SeqCst;

const CLOSED: usize = 1;

/// The version part of the state. The lowest bit is always zero.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub(super) struct Version(usize);

/// Snapshot of the state. The first bit is used as the CLOSED bit.
/// The remaining bits are used as the version.
#[derive(Copy, Clone, Debug)]
pub(super) struct StateSnapshot(usize);

/// The state stored in an atomic integer.
#[derive(Debug)]
pub(super) struct AtomicState(AtomicUsize);

impl Version {
/// Get the initial version when creating the channel.
pub(super) fn initial() -> Self {
Version(0)
}
}

impl StateSnapshot {
/// Extract the version from the state.
pub(super) fn version(self) -> Version {
Version(self.0 & !CLOSED)
}

/// Is the closed bit set?
pub(super) fn is_closed(self) -> bool {
(self.0 & CLOSED) == CLOSED
}
}

impl AtomicState {
/// Create a new `AtomicState` that is not closed and which has the
/// version set to `Version::initial()`.
pub(super) fn new() -> Self {
AtomicState(AtomicUsize::new(0))
}

/// Load the current value of the state.
pub(super) fn load(&self) -> StateSnapshot {
StateSnapshot(self.0.load(SeqCst))
}

/// Increment the version counter.
pub(super) fn increment_version(&self) {
// Increment by two to avoid touching the CLOSED bit.
self.0.fetch_add(2, SeqCst);
}

/// Set the closed bit in the state.
pub(super) fn set_closed(&self) {
self.0.fetch_or(CLOSED, SeqCst);
}
}
}

/// Creates a new watch channel, returning the "send" and "receive" handles.
///
Expand Down Expand Up @@ -184,7 +246,7 @@ const CLOSED: usize = 1;
pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
let shared = Arc::new(Shared {
value: RwLock::new(init),
version: AtomicUsize::new(0),
state: AtomicState::new(),
ref_count_rx: AtomicUsize::new(1),
notify_rx: Notify::new(),
notify_tx: Notify::new(),
Expand All @@ -194,13 +256,16 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
shared: shared.clone(),
};

let rx = Receiver { shared, version: 0 };
let rx = Receiver {
shared,
version: Version::initial(),
};

(tx, rx)
}

impl<T> Receiver<T> {
fn from_shared(version: usize, shared: Arc<Shared<T>>) -> Self {
fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self {
// No synchronization necessary as this is only used as a counter and
// not memory access.
shared.ref_count_rx.fetch_add(1, Relaxed);
Expand Down Expand Up @@ -247,7 +312,7 @@ impl<T> Receiver<T> {
/// [`changed`]: Receiver::changed
pub fn borrow_and_update(&mut self) -> Ref<'_, T> {
let inner = self.shared.value.read().unwrap();
self.version = self.shared.version.load(SeqCst) & !CLOSED;
self.version = self.shared.state.load().version();
Ref { inner }
}

Expand Down Expand Up @@ -315,19 +380,19 @@ impl<T> Receiver<T> {

fn maybe_changed<T>(
shared: &Shared<T>,
version: &mut usize,
version: &mut Version,
) -> Option<Result<(), error::RecvError>> {
// Load the version from the state
let state = shared.version.load(SeqCst);
let new_version = state & !CLOSED;
let state = shared.state.load();
let new_version = state.version();

if *version != new_version {
// Observe the new version and return
*version = new_version;
return Some(Ok(()));
}

if CLOSED == state & CLOSED {
if state.is_closed() {
// All receivers have dropped.
return Some(Err(error::RecvError(())));
}
Expand Down Expand Up @@ -368,8 +433,7 @@ impl<T> Sender<T> {
let mut lock = self.shared.value.write().unwrap();
*lock = value;

// Update the version. 2 is used so that the CLOSED bit is not set.
self.shared.version.fetch_add(2, SeqCst);
self.shared.state.increment_version();

// Release the write lock.
//
Expand Down Expand Up @@ -463,7 +527,7 @@ impl<T> Sender<T> {
cfg_signal_internal! {
pub(crate) fn subscribe(&self) -> Receiver<T> {
let shared = self.shared.clone();
let version = shared.version.load(SeqCst);
let version = shared.state.load().version();

Receiver::from_shared(version, shared)
}
Expand Down Expand Up @@ -494,7 +558,7 @@ impl<T> Sender<T> {

impl<T> Drop for Sender<T> {
fn drop(&mut self) {
self.shared.version.fetch_or(CLOSED, SeqCst);
self.shared.state.set_closed();
self.shared.notify_rx.notify_waiters();
}
}
Expand Down

0 comments on commit be26ca7

Please sign in to comment.