Skip to content

Commit

Permalink
[derive] Support TryFromBytes on field-less enums (#803)
Browse files Browse the repository at this point in the history
Makes progress on #5
  • Loading branch information
joshlf committed Jan 22, 2024
1 parent dd107ad commit 100ba98
Show file tree
Hide file tree
Showing 10 changed files with 654 additions and 562 deletions.
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,8 @@ pub unsafe trait TryFromBytes {
// SAFETY: `candidate` has no uninitialized sub-ranges because it
// derived from `bytes: &[u8]`, and is therefore at least as-initialized
// as `Self`.
let candidate = unsafe { candidate.assume_as_initialized() };
let candidate =
unsafe { candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };

// This call may panic. If that happens, it doesn't cause any soundness
// issues, as we have not generated any invalid state which we need to
Expand Down
4 changes: 2 additions & 2 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ macro_rules! unsafe_impl {

// SAFETY: The caller has promised that the referenced memory region
// will contain a valid `$repr`.
let $candidate = unsafe { candidate.assume_valid() };
let $candidate = unsafe { candidate.assume_validity::<crate::pointer::invariant::Valid>() };
$is_bit_valid
}
};
Expand All @@ -166,7 +166,7 @@ macro_rules! unsafe_impl {

// SAFETY: The caller has promised that `$repr` is as-initialized as
// `Self`.
let $candidate = unsafe { $candidate.assume_as_initialized() };
let $candidate = unsafe { $candidate.assume_validity::<crate::pointer::invariant::AsInitialized>() };

$is_bit_valid
}
Expand Down
4 changes: 2 additions & 2 deletions src/pointer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where
if T::is_bit_valid(self.forget_aligned()) {
// SAFETY: If `T::is_bit_valid`, code may assume that `self`
// contains a bit-valid instance of `Self`.
Some(unsafe { self.assume_valid() })
Some(unsafe { self.assume_validity::<invariant::Valid>() })
} else {
None
}
Expand Down Expand Up @@ -82,7 +82,7 @@ where
{
// SAFETY: The alignment of `T` is 1 and thus is always aligned
// because `T: Unaligned`.
let ptr = unsafe { self.assume_aligned() };
let ptr = unsafe { self.assume_alignment::<invariant::Aligned>() };
ptr.as_ref()
}
}
56 changes: 37 additions & 19 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,51 +458,69 @@ mod _transitions {
T: 'a + ?Sized,
I: Invariants,
{
/// Assumes that `Ptr`'s referent is validly-aligned for `T`.
/// Assumes that `Ptr`'s referent is validly-aligned for `T` if required
/// by `A`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the alignment
/// invariant of `T`.
/// invariant of `T` if required by `A`.
#[inline]
pub(crate) unsafe fn assume_aligned(
pub(crate) unsafe fn assume_alignment<A: invariant::Alignment>(
self,
) -> Ptr<'a, T, (I::Aliasing, invariant::Aligned, I::Validity)> {
) -> Ptr<'a, T, (I::Aliasing, A, I::Validity)> {
// SAFETY: The caller promises that `self`'s referent is
// well-aligned for `T`.
// well-aligned for `T` if required by `A` .
unsafe { Ptr::from_ptr(self) }
}

/// Assumes that `Ptr`'s referent is as-initialized as `T`.
/// Assumes that `Ptr`'s referent conforms to the validity requirement
/// of `V`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the
/// [`invariant::AsInitialized`] invariant (see documentation there).
/// The caller promises that `Ptr`'s referent conforms to the validity
/// requirement of `V`.
#[doc(hidden)]
#[inline]
pub unsafe fn assume_validity<V: invariant::Validity>(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, V)> {
// SAFETY: The caller promises that `self`'s referent conforms to
// the validity requirement of `V`.
unsafe { Ptr::from_ptr(self) }
}

/// A shorthand for `self.assume_validity<invariant::AsInitialized>()`.
///
/// # Safety
///
/// The caller promises to uphold the safety preconditions of
/// `self.assume_validity<invariant::AsInitialized>()`.
#[doc(hidden)]
#[inline]
pub unsafe fn assume_as_initialized(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::AsInitialized)> {
// SAFETY: The caller promises that `self`'s referent only contains
// uninitialized bytes in a subset of the uninitialized ranges in
// `T`. for `T`.
unsafe { Ptr::from_ptr(self) }
// SAFETY: The caller has promised to uphold the safety
// preconditions.
unsafe { self.assume_validity::<invariant::AsInitialized>() }
}

/// Assumes that `Ptr`'s referent is validly initialized for `T`.
/// A shorthand for `self.assume_validity<invariant::Valid>()`.
///
/// # Safety
///
/// The caller promises that `Ptr`'s referent conforms to the
/// bit validity invariants on `T`.
/// The caller promises to uphold the safety preconditions of
/// `self.assume_validity<invariant::Valid>()`.
#[doc(hidden)]
#[inline]
pub(crate) unsafe fn assume_valid(
pub unsafe fn assume_valid(
self,
) -> Ptr<'a, T, (I::Aliasing, I::Alignment, invariant::Valid)> {
// SAFETY: The caller promises that `self`'s referent is bit-valid
// for `T`.
unsafe { Ptr::from_ptr(self) }
// SAFETY: The caller has promised to uphold the safety
// preconditions.
unsafe { self.assume_validity::<invariant::Valid>() }
}

/// Forgets that `Ptr`'s referent is validly-aligned for `T`.
Expand Down
147 changes: 143 additions & 4 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use {
quote::quote,
syn::{
parse_quote, Data, DataEnum, DataStruct, DataUnion, DeriveInput, Error, Expr, ExprLit,
GenericParam, Ident, Lit,
GenericParam, Ident, Index, Lit,
},
};

Expand Down Expand Up @@ -268,9 +268,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
let ast = syn::parse_macro_input!(ts as DeriveInput);
match &ast.data {
Data::Struct(strct) => derive_try_from_bytes_struct(&ast, strct),
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Enum(enm) => derive_try_from_bytes_enum(&ast, enm),
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
}
.into()
Expand Down Expand Up @@ -401,6 +399,147 @@ const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C, StructRepr::Packed],
];

fn derive_try_from_bytes_enum(ast: &DeriveInput, enm: &DataEnum) -> proc_macro2::TokenStream {
if !enm.is_fieldless() {
return Error::new_spanned(ast, "only field-less enums can implement TryFromBytes")
.to_compile_error();
}

let reprs = try_or_print!(ENUM_TRY_FROM_BYTES_CFG.validate_reprs(ast));
let discriminant_type = match reprs.as_slice() {
[EnumRepr::U8] => quote!(u8),
[EnumRepr::U16] => quote!(u16),
[EnumRepr::U32] => quote!(u32),
[EnumRepr::U64] => quote!(u64),
[EnumRepr::Usize] => quote!(usize),
[EnumRepr::I8] => quote!(i8),
[EnumRepr::I16] => quote!(i16),
[EnumRepr::I32] => quote!(i32),
[EnumRepr::I64] => quote!(i64),
[EnumRepr::Isize] => quote!(isize),
// `validate_reprs` has already validated that it's one of the preceding
// patterns.
_ => unreachable!(),
};

let discriminant_exprs = enm.variants.iter().scan(Discriminant::default(), |disc, var| {
Some(disc.update_and_generate_expr(&var.discriminant))
});
let extras = Some(quote!(
// SAFETY: We use `is_bit_valid` to validate that the bit pattern
// corresponds to one of the C-like enum's variant discriminants.
// Thus, this is a sound implementation of `is_bit_valid`.
fn is_bit_valid(
candidate: ::zerocopy::Ptr<
'_,
Self,
(
::zerocopy::pointer::invariant::Shared,
::zerocopy::pointer::invariant::AnyAlignment,
::zerocopy::pointer::invariant::AsInitialized,
),
>,
) -> bool {
// SAFETY:
// - `cast` is implemented as required.
// - Since we cast to the type specified by `Self`'s repr, `p`'s
// referent and the referent of the returned pointer have the
// same size.
let discriminant = unsafe { candidate.cast_unsized(|p: *mut Self| p as *mut ::zerocopy::macro_util::core_reexport::primitive::#discriminant_type) };
// SAFETY: Since `candidate` has the invariant `AsInitialized`,
// we know that `candidate`'s referent (and thus
// `discriminant`'s referent) is as-initialized as `Self`. Since
// `Self`'s repr is the same type as `discriminant`, we know
// that `discriminant`'s referent satisfies the as-initialized
// property.
let discriminant = unsafe { discriminant.assume_valid() };
let discriminant = discriminant.read_unaligned();

false #(|| (discriminant == (#discriminant_exprs)))*
}
));
impl_block(ast, enm, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// Enum variant discriminants can be manually set not only as literal values,
// but as arbitrary const expressions. In order to handle this, we keep track of
// the most-recently-seen expression and a count of how many variants have been
// encountered since then.
//
// #[repr(u8)]
// enum Foo {
// A, // 0
// B = 5, // 5
// C, // 6
// D = 1 + 1, // 2
// E, // 3
// }
//
// Note: Default::default does the right thing (initializes to { None, 0 }).
#[derive(Default, Copy, Clone)]
struct Discriminant<'a> {
// The most-recently-set explicit discriminant.
previous: Option<&'a Expr>,
// When the next variant is encountered, what offset should be used compared
// to `previous` to determine the variant's discriminant?
next_offset: usize,
}

impl<'a> Discriminant<'a> {
/// Called when encountering a variant with discriminant set to `ast`.
/// Updates `self` in preparation for the next variant and generates an
/// expression which will evaluate to the numeric value this variant's
/// discriminant.
fn update_and_generate_expr(
&mut self,
ast: &'a Option<(syn::token::Eq, Expr)>,
) -> proc_macro2::TokenStream {
match ast.as_ref().map(|(_eq, expr)| expr) {
Some(expr) => {
self.previous = Some(expr);
self.next_offset = 1;
quote!(#expr)
}
None => {
let previous = self.previous.iter();
// Use `Index` instead of `usize` so that the number is
// formatted just as `0` rather than as `0usize`; the latter
// syntax is only valid if the repr is `usize`; otherwise,
// comparison will result in a type mismatch.
let offset = Index::from(self.next_offset);
let tokens = quote!(#(#previous +)* #offset);

self.next_offset += 1;
tokens
}
}
}
}

#[rustfmt::skip]
const ENUM_TRY_FROM_BYTES_CFG: Config<EnumRepr> = {
use EnumRepr::*;
Config {
allowed_combinations_message: r#"TryFromBytes requires repr of "u8", "u16", "u32", "u64", "usize", "i8", or "i16", "i32", "i64", or "isize""#,
derive_unaligned: false,
allowed_combinations: &[
&[U8],
&[U16],
&[U32],
&[U64],
&[Usize],
&[I8],
&[I16],
&[I32],
&[I64],
&[Isize],
],
disallowed_but_legal_combinations: &[
&[C],
],
}
};

// A struct is `FromZeros` if:
// - all fields are `FromZeros`

Expand Down
Loading

0 comments on commit 100ba98

Please sign in to comment.