Skip to content

Commit

Permalink
[derive] Derive TryFromBytes on unions (#800)
Browse files Browse the repository at this point in the history
The implementation itself is very straightforward, but it requires a few
prerequisite changes:
- The `Ptr::project` method now requires that it only be called on
  struct and union types
- The `AsInitialized` invariant property is changed in a subtle but
  important way. Previously, a byte offset in a type was required to be
  initialized if, in every instance of that type, it was initialized.
  Now, this is required if the byte is initialized in /any/ instance of
  the type. Without this change, it would be possible to invoke
  `TryFromBytes::is_bit_valid` on a byte sequence where an uninitialized
  byte appears which is valid for one union field but not valid for
  another. Since we check bit validity for all union fields,
  uninitialized bytes may only appear which are compatible with /all/
  union fields.
  • Loading branch information
jswrenn committed Jan 21, 2024
1 parent a8a828c commit c5ab376
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 65 deletions.
68 changes: 39 additions & 29 deletions src/pointer/ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,37 +285,37 @@ pub mod invariant {
/// The referent is not necessarily initialized.
AnyValidity,

/// The byte ranges initialized in `T` are also initialized in the
/// referent.
/// The byte ranges initialized in `T` are also initialized in
/// the referent.
///
/// Formally: uninitialized bytes may only be present in `Ptr<T>`'s
/// referent where it is possible for them to be present in `T`.
/// This is a dynamic property: if, at a particular byte offset, a
/// valid enum discriminant is set, the subsequent bytes may only
/// have uninitialized bytes as specificed by the corresponding
/// enum.
/// Formally: uninitialized bytes may only be present in
/// `Ptr<T>`'s referent where they are guaranteed to be present
/// in `T`. This is a dynamic property: if, at a particular byte
/// offset, a valid enum discriminant is set, the subsequent
/// bytes may only have uninitialized bytes as specificed by the
/// corresponding enum.
///
/// Formally, given `len = size_of_val_raw(ptr)`, at every byte
/// offset, `b`, in the range `[0, len)`:
/// - If, in all instances `t: T` of length `len`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset `b`
/// within `*ptr` must be initialized.
/// - Let `c` be the contents of the byte range `[0, b)` in `*ptr`.
/// Let `S` be the subset of valid instances of `T` of length
/// `len` which contain `c` in the offset range `[0, b)`. If, for
/// all instances of `t: T` in `S`, the byte at offset `b` in `t`
/// is initialized, then the byte at offset `b` in `*ptr` must be
/// initialized.
/// - If, in any instance `t: T` of length `len`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset
/// `b` within `*ptr` must be initialized.
/// - Let `c` be the contents of the byte range `[0, b)` in
/// `*ptr`. Let `S` be the subset of valid instances of `T` of
/// length `len` which contain `c` in the offset range `[0,
/// b)`. If, in any instance of `t: T` in `S`, the byte at
/// offset `b` in `t` is initialized, then the byte at offset
/// `b` in `*ptr` must be initialized.
///
/// Pragmatically, this means that if `*ptr` is guaranteed to
/// contain an enum type at a particular offset, and the enum
/// discriminant stored in `*ptr` corresponds to a valid variant
/// of that enum type, then it is guaranteed that the appropriate
/// bytes of `*ptr` are initialized as defined by that variant's
/// bit validity (although note that the variant may contain
/// another enum type, in which case the same rules apply
/// depending on the state of its discriminant, and so on
/// recursively).
/// discriminant stored in `*ptr` corresponds to a valid
/// variant of that enum type, then it is guaranteed that the
/// appropriate bytes of `*ptr` are initialized as defined by
/// that variant's bit validity (although note that the
/// variant may contain another enum type, in which case the
/// same rules apply depending on the state of its
/// discriminant, and so on recursively).
AsInitialized,

/// The referent is bit-valid for `T`.
Expand Down Expand Up @@ -785,6 +785,7 @@ mod _project {
where
T: 'a + ?Sized,
I: Invariants,
I::Validity: invariant::at_least::AsInitialized,
{
/// Projects a field from `self`.
///
Expand All @@ -796,6 +797,8 @@ mod _project {
/// argument. Its argument will be `self` casted to a raw pointer. The
/// pointer it returns must reference only a subset of `self`'s bytes.
///
/// The caller also promises that `T` is a struct or union type.
///
/// ## Postconditions
///
/// If the preconditions of this function are met, this function will
Expand All @@ -805,7 +808,7 @@ mod _project {
pub unsafe fn project<U: 'a + ?Sized>(
self,
projector: impl FnOnce(*mut T) -> *mut U,
) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, I::Validity)> {
) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, invariant::AsInitialized)> {
// SAFETY: `projector` is provided with `self` casted to a raw
// pointer.
let field = projector(self.as_non_null().as_ptr());
Expand Down Expand Up @@ -849,10 +852,17 @@ mod _project {
// `ALIASING_INVARIANT` because projection does not impact the
// aliasing invariant.
// 7. `field`, trivially, conforms to the alignment invariant of
// `AnyAlignment`.
// 8. `field`, conditionally, conforms to the validity invariant of
// `VALIDITY_INVARIANT`. If `field` is projected from data valid
// for `T`, `field` will be valid for `U`.
// `AnyAlignment`.
// 8. By type bound on `I::Validity`, `self` satisfies the
// "as-initialized" property relative to `T`. The returned `Ptr`
// has the validity `AsInitialized`. The caller promises that `T`
// is either a struct type or a union type. Returning a `Ptr`
// with the validity `AsInitialized` is valid in both cases. The
// struct case is self-explanatory, but the union case bears
// explanation. The "as-initialized" property says that a byte
// must be initialized if it is initialized in *any* instance of
// the type. Thus, if `self`'s referent is as-initialized as `T`,
// then it is at least as-initialized as each of its fields.
unsafe { Ptr::new(field) }
}
}
Expand Down
62 changes: 36 additions & 26 deletions zerocopy-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt
Data::Enum(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error()
}
Data::Union(_) => {
Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error()
}
Data::Union(unn) => derive_try_from_bytes_union(&ast, unn),
}
.into()
}
Expand Down Expand Up @@ -346,36 +344,15 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m
// `is_bit_valid`.
fn is_bit_valid(candidate: zerocopy::Maybe<Self>) -> bool {
true #(&& {
// SAFETY: `project` is a field projection of `candidate`.
// The projected field will be well-aligned because this
// derive rejects packed types.
// SAFETY: `project` is a field projection of `candidate`,
// and `Self` is a struct type.
let field_candidate = unsafe {
let project = |slf: *mut Self|
::core::ptr::addr_of_mut!((*slf).#field_names);

candidate.project(project)
};

// SAFETY: The below invocation of `is_bit_valid` satisfies
// the safety preconditions of `is_bit_valid`:
// - The memory referenced by `field_candidate` is only
// accessed via reads for the duration of this method
// call. This is ensured by contract on the caller of the
// surrounding `is_bit_valid`.
// - `field_candidate` may not refer to a valid instance of
// its corresponding field type, but it will only have
// `UnsafeCell`s at the offsets at which they may occur in
// that field type. This is ensured both by contract on
// the caller of the surrounding `is_bit_valid`, and by
// the construction of `field_candidiate`, i.e., via
// projection through `candidate`.
//
// Note that it's possible that this call will panic -
// `is_bit_valid` does not promise that it doesn't panic,
// and in practice, we support user-defined validators,
// which could panic. This is sound because we haven't
// violated any safety invariants which we would need to fix
// before returning.
<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate)
})*
}
Expand All @@ -384,6 +361,39 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m
impl_block(ast, strct, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

// A union is `TryFromBytes` if:
// - any of its fields are `TryFromBytes`

fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream {
let extras = Some({
let fields = unn.fields();
let field_names = fields.iter().map(|(name, _ty)| name);
let field_tys = fields.iter().map(|(_name, ty)| ty);
quote!(
// SAFETY: We use `is_bit_valid` to validate that any field is
// bit-valid; we only return `true` if at least one of them is. The
// bit validity of a union is not yet well defined in Rust, but it
// is guaranteed to be no more strict than this definition. See #696
// for a more in-depth discussion.
fn is_bit_valid(candidate: zerocopy::Maybe<Self>) -> bool {
false #(|| {
// SAFETY: `project` is a field projection of `candidate`,
// and `Self` is a union type.
let field_candidate = unsafe {
let project = |slf: *mut Self|
::core::ptr::addr_of_mut!((*slf).#field_names);

candidate.project(project)
};

<#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate)
})*
}
)
});
impl_block(ast, unn, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras)
}

const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[
&[StructRepr::C],
&[StructRepr::Transparent],
Expand Down
18 changes: 8 additions & 10 deletions zerocopy-derive/tests/struct_try_from_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@ use crate::util::AU16;
// A struct is `TryFromBytes` if:
// - all fields are `TryFromBytes`

#[derive(TryFromBytes, FromZeros, FromBytes)]
struct Zst;

assert_impl_all!(Zst: TryFromBytes);

#[test]
fn zst() {
// TODO(#5): Use `try_transmute` in this test once it's available.
let candidate = zerocopy::Ptr::from_ref(&Zst);
let candidate = zerocopy::Ptr::from_ref(&());
let candidate = candidate.forget_aligned().forget_valid();
let is_bit_valid = Zst::is_bit_valid(candidate);
let is_bit_valid = <()>::is_bit_valid(candidate);
assert!(is_bit_valid);
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct One {
a: u8,
}
Expand All @@ -53,17 +49,18 @@ fn one() {
}

#[derive(TryFromBytes, FromZeros)]
#[repr(C)]
struct Two {
a: bool,
b: Zst,
b: (),
}

assert_impl_all!(Two: TryFromBytes);

#[test]
fn two() {
// TODO(#5): Use `try_transmute` in this test once it's available.
let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: Zst });
let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: () });
let candidate = candidate.forget_aligned().forget_valid();
let is_bit_valid = Two::is_bit_valid(candidate);
assert!(is_bit_valid);
Expand All @@ -80,7 +77,6 @@ fn two_bad() {
// *mut U`.
// - The size of the object referenced by the resulting pointer is equal to
// the size of the object referenced by `self`.
// - The alignment of `Unsized` is equal to the alignment of `[u8]`.
let candidate = unsafe { candidate.cast_unsized(|p| p as *mut Two) };

// SAFETY: `candidate`'s referent is as-initialized as `Two`.
Expand All @@ -91,6 +87,7 @@ fn two_bad() {
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct Unsized {
a: [u8],
}
Expand Down Expand Up @@ -118,6 +115,7 @@ fn un_sized() {
}

#[derive(TryFromBytes, FromZeros, FromBytes)]
#[repr(C)]
struct TypeParams<'a, T: ?Sized, I: Iterator> {
a: I::Item,
b: u8,
Expand Down
Loading

0 comments on commit c5ab376

Please sign in to comment.