From c5ab3769f517b68fafb766b23c119f7155d4896f Mon Sep 17 00:00:00 2001 From: Jack Wrenn Date: Sat, 20 Jan 2024 17:05:53 -0800 Subject: [PATCH] [derive] Derive TryFromBytes on unions (#800) The implementation itself is very straightforward, but it requires a few prerequisite changes: - The `Ptr::project` method now requires that it only be called on struct and union types - The `AsInitialized` invariant property is changed in a subtle but important way. Previously, a byte offset in a type was required to be initialized if, in every instance of that type, it was initialized. Now, this is required if the byte is initialized in /any/ instance of the type. Without this change, it would be possible to invoke `TryFromBytes::is_bit_valid` on a byte sequence where an uninitialized byte appears which is valid for one union field but not valid for another. Since we check bit validity for all union fields, uninitialized bytes may only appear which are compatible with /all/ union fields. --- src/pointer/ptr.rs | 68 +++++---- zerocopy-derive/src/lib.rs | 62 ++++---- .../tests/struct_try_from_bytes.rs | 18 +-- zerocopy-derive/tests/union_try_from_bytes.rs | 143 ++++++++++++++++++ 4 files changed, 226 insertions(+), 65 deletions(-) create mode 100644 zerocopy-derive/tests/union_try_from_bytes.rs diff --git a/src/pointer/ptr.rs b/src/pointer/ptr.rs index 45f9d6dfd4..7f75e30409 100644 --- a/src/pointer/ptr.rs +++ b/src/pointer/ptr.rs @@ -285,37 +285,37 @@ pub mod invariant { /// The referent is not necessarily initialized. AnyValidity, - /// The byte ranges initialized in `T` are also initialized in the - /// referent. + /// The byte ranges initialized in `T` are also initialized in + /// the referent. /// - /// Formally: uninitialized bytes may only be present in `Ptr`'s - /// referent where it is possible for them to be present in `T`. - /// This is a dynamic property: if, at a particular byte offset, a - /// valid enum discriminant is set, the subsequent bytes may only - /// have uninitialized bytes as specificed by the corresponding - /// enum. + /// Formally: uninitialized bytes may only be present in + /// `Ptr`'s referent where they are guaranteed to be present + /// in `T`. This is a dynamic property: if, at a particular byte + /// offset, a valid enum discriminant is set, the subsequent + /// bytes may only have uninitialized bytes as specificed by the + /// corresponding enum. /// /// Formally, given `len = size_of_val_raw(ptr)`, at every byte /// offset, `b`, in the range `[0, len)`: - /// - If, in all instances `t: T` of length `len`, the byte at - /// offset `b` in `t` is initialized, then the byte at offset `b` - /// within `*ptr` must be initialized. - /// - Let `c` be the contents of the byte range `[0, b)` in `*ptr`. - /// Let `S` be the subset of valid instances of `T` of length - /// `len` which contain `c` in the offset range `[0, b)`. If, for - /// all instances of `t: T` in `S`, the byte at offset `b` in `t` - /// is initialized, then the byte at offset `b` in `*ptr` must be - /// initialized. + /// - If, in any instance `t: T` of length `len`, the byte at + /// offset `b` in `t` is initialized, then the byte at offset + /// `b` within `*ptr` must be initialized. + /// - Let `c` be the contents of the byte range `[0, b)` in + /// `*ptr`. Let `S` be the subset of valid instances of `T` of + /// length `len` which contain `c` in the offset range `[0, + /// b)`. If, in any instance of `t: T` in `S`, the byte at + /// offset `b` in `t` is initialized, then the byte at offset + /// `b` in `*ptr` must be initialized. /// /// Pragmatically, this means that if `*ptr` is guaranteed to /// contain an enum type at a particular offset, and the enum - /// discriminant stored in `*ptr` corresponds to a valid variant - /// of that enum type, then it is guaranteed that the appropriate - /// bytes of `*ptr` are initialized as defined by that variant's - /// bit validity (although note that the variant may contain - /// another enum type, in which case the same rules apply - /// depending on the state of its discriminant, and so on - /// recursively). + /// discriminant stored in `*ptr` corresponds to a valid + /// variant of that enum type, then it is guaranteed that the + /// appropriate bytes of `*ptr` are initialized as defined by + /// that variant's bit validity (although note that the + /// variant may contain another enum type, in which case the + /// same rules apply depending on the state of its + /// discriminant, and so on recursively). AsInitialized, /// The referent is bit-valid for `T`. @@ -785,6 +785,7 @@ mod _project { where T: 'a + ?Sized, I: Invariants, + I::Validity: invariant::at_least::AsInitialized, { /// Projects a field from `self`. /// @@ -796,6 +797,8 @@ mod _project { /// argument. Its argument will be `self` casted to a raw pointer. The /// pointer it returns must reference only a subset of `self`'s bytes. /// + /// The caller also promises that `T` is a struct or union type. + /// /// ## Postconditions /// /// If the preconditions of this function are met, this function will @@ -805,7 +808,7 @@ mod _project { pub unsafe fn project( self, projector: impl FnOnce(*mut T) -> *mut U, - ) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, I::Validity)> { + ) -> Ptr<'a, U, (I::Aliasing, invariant::AnyAlignment, invariant::AsInitialized)> { // SAFETY: `projector` is provided with `self` casted to a raw // pointer. let field = projector(self.as_non_null().as_ptr()); @@ -849,10 +852,17 @@ mod _project { // `ALIASING_INVARIANT` because projection does not impact the // aliasing invariant. // 7. `field`, trivially, conforms to the alignment invariant of - // `AnyAlignment`. - // 8. `field`, conditionally, conforms to the validity invariant of - // `VALIDITY_INVARIANT`. If `field` is projected from data valid - // for `T`, `field` will be valid for `U`. + // `AnyAlignment`. + // 8. By type bound on `I::Validity`, `self` satisfies the + // "as-initialized" property relative to `T`. The returned `Ptr` + // has the validity `AsInitialized`. The caller promises that `T` + // is either a struct type or a union type. Returning a `Ptr` + // with the validity `AsInitialized` is valid in both cases. The + // struct case is self-explanatory, but the union case bears + // explanation. The "as-initialized" property says that a byte + // must be initialized if it is initialized in *any* instance of + // the type. Thus, if `self`'s referent is as-initialized as `T`, + // then it is at least as-initialized as each of its fields. unsafe { Ptr::new(field) } } } diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 21daca4be5..133867dde7 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -271,9 +271,7 @@ pub fn derive_try_from_bytes(ts: proc_macro::TokenStream) -> proc_macro::TokenSt Data::Enum(_) => { Error::new_spanned(&ast, "TryFromBytes not supported on enum types").to_compile_error() } - Data::Union(_) => { - Error::new_spanned(&ast, "TryFromBytes not supported on union types").to_compile_error() - } + Data::Union(unn) => derive_try_from_bytes_union(&ast, unn), } .into() } @@ -346,9 +344,8 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m // `is_bit_valid`. fn is_bit_valid(candidate: zerocopy::Maybe) -> bool { true #(&& { - // SAFETY: `project` is a field projection of `candidate`. - // The projected field will be well-aligned because this - // derive rejects packed types. + // SAFETY: `project` is a field projection of `candidate`, + // and `Self` is a struct type. let field_candidate = unsafe { let project = |slf: *mut Self| ::core::ptr::addr_of_mut!((*slf).#field_names); @@ -356,26 +353,6 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m candidate.project(project) }; - // SAFETY: The below invocation of `is_bit_valid` satisfies - // the safety preconditions of `is_bit_valid`: - // - The memory referenced by `field_candidate` is only - // accessed via reads for the duration of this method - // call. This is ensured by contract on the caller of the - // surrounding `is_bit_valid`. - // - `field_candidate` may not refer to a valid instance of - // its corresponding field type, but it will only have - // `UnsafeCell`s at the offsets at which they may occur in - // that field type. This is ensured both by contract on - // the caller of the surrounding `is_bit_valid`, and by - // the construction of `field_candidiate`, i.e., via - // projection through `candidate`. - // - // Note that it's possible that this call will panic - - // `is_bit_valid` does not promise that it doesn't panic, - // and in practice, we support user-defined validators, - // which could panic. This is sound because we haven't - // violated any safety invariants which we would need to fix - // before returning. <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate) })* } @@ -384,6 +361,39 @@ fn derive_try_from_bytes_struct(ast: &DeriveInput, strct: &DataStruct) -> proc_m impl_block(ast, strct, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras) } +// A union is `TryFromBytes` if: +// - any of its fields are `TryFromBytes` + +fn derive_try_from_bytes_union(ast: &DeriveInput, unn: &DataUnion) -> proc_macro2::TokenStream { + let extras = Some({ + let fields = unn.fields(); + let field_names = fields.iter().map(|(name, _ty)| name); + let field_tys = fields.iter().map(|(_name, ty)| ty); + quote!( + // SAFETY: We use `is_bit_valid` to validate that any field is + // bit-valid; we only return `true` if at least one of them is. The + // bit validity of a union is not yet well defined in Rust, but it + // is guaranteed to be no more strict than this definition. See #696 + // for a more in-depth discussion. + fn is_bit_valid(candidate: zerocopy::Maybe) -> bool { + false #(|| { + // SAFETY: `project` is a field projection of `candidate`, + // and `Self` is a union type. + let field_candidate = unsafe { + let project = |slf: *mut Self| + ::core::ptr::addr_of_mut!((*slf).#field_names); + + candidate.project(project) + }; + + <#field_tys as zerocopy::TryFromBytes>::is_bit_valid(field_candidate) + })* + } + ) + }); + impl_block(ast, unn, Trait::TryFromBytes, RequireBoundedFields::Yes, false, None, extras) +} + const STRUCT_UNION_ALLOWED_REPR_COMBINATIONS: &[&[StructRepr]] = &[ &[StructRepr::C], &[StructRepr::Transparent], diff --git a/zerocopy-derive/tests/struct_try_from_bytes.rs b/zerocopy-derive/tests/struct_try_from_bytes.rs index 18ae218d96..28664f7c77 100644 --- a/zerocopy-derive/tests/struct_try_from_bytes.rs +++ b/zerocopy-derive/tests/struct_try_from_bytes.rs @@ -22,21 +22,17 @@ use crate::util::AU16; // A struct is `TryFromBytes` if: // - all fields are `TryFromBytes` -#[derive(TryFromBytes, FromZeros, FromBytes)] -struct Zst; - -assert_impl_all!(Zst: TryFromBytes); - #[test] fn zst() { // TODO(#5): Use `try_transmute` in this test once it's available. - let candidate = zerocopy::Ptr::from_ref(&Zst); + let candidate = zerocopy::Ptr::from_ref(&()); let candidate = candidate.forget_aligned().forget_valid(); - let is_bit_valid = Zst::is_bit_valid(candidate); + let is_bit_valid = <()>::is_bit_valid(candidate); assert!(is_bit_valid); } #[derive(TryFromBytes, FromZeros, FromBytes)] +#[repr(C)] struct One { a: u8, } @@ -53,9 +49,10 @@ fn one() { } #[derive(TryFromBytes, FromZeros)] +#[repr(C)] struct Two { a: bool, - b: Zst, + b: (), } assert_impl_all!(Two: TryFromBytes); @@ -63,7 +60,7 @@ assert_impl_all!(Two: TryFromBytes); #[test] fn two() { // TODO(#5): Use `try_transmute` in this test once it's available. - let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: Zst }); + let candidate = zerocopy::Ptr::from_ref(&Two { a: false, b: () }); let candidate = candidate.forget_aligned().forget_valid(); let is_bit_valid = Two::is_bit_valid(candidate); assert!(is_bit_valid); @@ -80,7 +77,6 @@ fn two_bad() { // *mut U`. // - The size of the object referenced by the resulting pointer is equal to // the size of the object referenced by `self`. - // - The alignment of `Unsized` is equal to the alignment of `[u8]`. let candidate = unsafe { candidate.cast_unsized(|p| p as *mut Two) }; // SAFETY: `candidate`'s referent is as-initialized as `Two`. @@ -91,6 +87,7 @@ fn two_bad() { } #[derive(TryFromBytes, FromZeros, FromBytes)] +#[repr(C)] struct Unsized { a: [u8], } @@ -118,6 +115,7 @@ fn un_sized() { } #[derive(TryFromBytes, FromZeros, FromBytes)] +#[repr(C)] struct TypeParams<'a, T: ?Sized, I: Iterator> { a: I::Item, b: u8, diff --git a/zerocopy-derive/tests/union_try_from_bytes.rs b/zerocopy-derive/tests/union_try_from_bytes.rs new file mode 100644 index 0000000000..0fae644526 --- /dev/null +++ b/zerocopy-derive/tests/union_try_from_bytes.rs @@ -0,0 +1,143 @@ +// Copyright 2023 The Fuchsia Authors +// +// Licensed under a BSD-style license , Apache License, Version 2.0 +// , or the MIT +// license , at your option. +// This file may not be copied, modified, or distributed except according to +// those terms. + +#![allow(warnings)] + +mod util; + +use std::{marker::PhantomData, option::IntoIter}; + +use { + static_assertions::assert_impl_all, + zerocopy::{FromBytes, FromZeros, KnownLayout, TryFromBytes}, +}; + +use crate::util::AU16; + +// A struct is `TryFromBytes` if: +// - any of its fields are `TryFromBytes` + +#[derive(TryFromBytes, FromZeros, FromBytes)] +union One { + a: u8, +} + +assert_impl_all!(One: TryFromBytes); + +#[test] +fn one() { + // TODO(#5): Use `try_transmute` in this test once it's available. + let candidate = zerocopy::Ptr::from_ref(&One { a: 42 }); + let candidate = candidate.forget_aligned().forget_valid(); + let is_bit_valid = One::is_bit_valid(candidate); + assert!(is_bit_valid); +} + +#[derive(TryFromBytes, FromZeros)] +#[repr(C)] +union Two { + a: bool, + b: bool, +} + +assert_impl_all!(Two: TryFromBytes); + +#[test] +fn two() { + // TODO(#5): Use `try_transmute` in this test once it's available. + let candidate_a = zerocopy::Ptr::from_ref(&Two { a: false }); + let candidate_a = candidate_a.forget_aligned().forget_valid(); + let is_bit_valid = Two::is_bit_valid(candidate_a); + assert!(is_bit_valid); + + let candidate_b = zerocopy::Ptr::from_ref(&Two { b: true }); + let candidate_b = candidate_b.forget_aligned().forget_valid(); + let is_bit_valid = Two::is_bit_valid(candidate_b); + assert!(is_bit_valid); +} + +#[test] +fn two_bad() { + // TODO(#5): Use `try_transmute` in this test once it's available. + let candidate = zerocopy::Ptr::from_ref(&[2u8][..]); + let candidate = candidate.forget_aligned().forget_valid(); + + // SAFETY: + // - The cast `cast(p)` is implemented exactly as follows: `|p: *mut T| p as + // *mut U`. + // - The size of the object referenced by the resulting pointer is equal to + // the size of the object referenced by `self`. + let candidate = unsafe { candidate.cast_unsized(|p| p as *mut Two) }; + + // SAFETY: `candidate`'s referent is as-initialized as `Two`. + let candidate = unsafe { candidate.assume_as_initialized() }; + + let is_bit_valid = Two::is_bit_valid(candidate); + assert!(!is_bit_valid); +} + +#[derive(TryFromBytes, FromZeros)] +#[repr(C)] +union BoolAndZst { + a: bool, + b: (), +} + +#[test] +fn bool_and_zst() { + // TODO(#5): Use `try_transmute` in this test once it's available. + let candidate = zerocopy::Ptr::from_ref(&[2u8][..]); + let candidate = candidate.forget_aligned().forget_valid(); + + // SAFETY: + // - The cast `cast(p)` is implemented exactly as follows: `|p: *mut T| p as + // *mut U`. + // - The size of the object referenced by the resulting pointer is equal to + // the size of the object referenced by `self`. + let candidate = unsafe { candidate.cast_unsized(|p| p as *mut BoolAndZst) }; + + // SAFETY: `candidate`'s referent is as-initialized as `BoolAndZst`. + let candidate = unsafe { candidate.assume_as_initialized() }; + + let is_bit_valid = BoolAndZst::is_bit_valid(candidate); + assert!(is_bit_valid); +} + +#[derive(TryFromBytes, FromZeros, FromBytes)] +#[repr(C)] +union TypeParams<'a, T: Copy, I: Iterator> +where + I::Item: Copy, +{ + a: I::Item, + b: u8, + c: PhantomData<&'a [u8]>, + d: PhantomData<&'static str>, + e: PhantomData, + f: T, +} + +assert_impl_all!(TypeParams<'static, (), IntoIter<()>>: TryFromBytes); +assert_impl_all!(TypeParams<'static, AU16, IntoIter<()>>: TryFromBytes); +assert_impl_all!(TypeParams<'static, [AU16; 2], IntoIter<()>>: TryFromBytes); + +// Deriving `TryFromBytes` should work if the union has bounded parameters. + +#[derive(TryFromBytes, FromZeros, FromBytes)] +#[repr(C)] +union WithParams<'a: 'b, 'b: 'a, const N: usize, T: 'a + 'b + TryFromBytes> +where + 'a: 'b, + 'b: 'a, + T: 'a + 'b + TryFromBytes + Copy, +{ + a: PhantomData<&'a &'b ()>, + b: T, +} + +assert_impl_all!(WithParams<'static, 'static, 42, u8>: TryFromBytes);