Skip to content

Commit

Permalink
feat(bitfield): packing spec type safety (#174)
Browse files Browse the repository at this point in the history
This branch changes bitfield packing specs to ensure that a field of a
given bitfield type may not be used to pack into/from a bitfield of
another type.

Signed-off-by: Eliza Weisman <[email protected]>
  • Loading branch information
hawkw authored May 29, 2022
1 parent 5e5d0f4 commit 12ea600
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 41 deletions.
60 changes: 44 additions & 16 deletions bitfield/src/bitfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,34 @@
/// assert_eq!(formatted, expected);
/// ```
///
/// Packing specs from one bitfield type may *not* be used with a different
/// bitfield type's `get`, `set`, or `with` methods. For example, the following
/// is a type error:
///
/// ```compile_fail
/// use mycelium_bitfield::bitfield;
///
/// bitfield! {
/// struct Bitfield1<u8> {
/// pub const FOO: bool;
/// pub const BAR: bool;
/// pub const BAZ = 6;
/// }
/// }
///
/// bitfield! {
/// struct Bitfield2<u8> {
/// pub const ALICE = 2;
/// pub const BOB = 4;
/// pub const CHARLIE = 2;
/// }
/// }
///
///
/// // This is a *type error*, because `Bitfield2`'s field `ALICE` cannot be
/// // used with a `Bitfield2` value:
/// let bits = Bitfield1::new().with(Bitfield2::ALICE, 0b11);
/// ```
/// [`fmt::Debug`]: core::fmt::Debug
/// [`fmt::Display`]: core::fmt::Display
/// [`fmt::Binary`]: core::fmt::Binary
Expand Down Expand Up @@ -227,7 +255,7 @@ macro_rules! bitfield {
)+
}

const FIELDS: &'static [(&'static str, $crate::bitfield! { @t $T, $T })] = &[$(
const FIELDS: &'static [(&'static str, $crate::bitfield! { @t $T, $T, Self })] = &[$(
(stringify!($Field), Self::$Field.typed())
),+];

Expand All @@ -243,7 +271,7 @@ macro_rules! bitfield {

/// Packs the bit representation of `value` into `self` at the bit
/// range designated by `field`, returning a new bitfield.
$vis fn with<T>(self, field: $crate::bitfield! { @t $T, T }, value: T) -> Self
$vis fn with<T>(self, field: $crate::bitfield! { @t $T, T, Self }, value: T) -> Self
where
T: $crate::FromBits<$T>,
{
Expand All @@ -253,7 +281,7 @@ macro_rules! bitfield {

/// Packs the bit representation of `value` into `self` at the range
/// designated by `field`, mutating `self` in place.
$vis fn set<T>(&mut self, field: $crate::bitfield! { @t $T, T }, value: T) -> &mut Self
$vis fn set<T>(&mut self, field: $crate::bitfield! { @t $T, T, Self }, value: T) -> &mut Self
where
T: $crate::FromBits<$T>,
{
Expand All @@ -269,7 +297,7 @@ macro_rules! bitfield {
/// This method panics if `self` does not contain a valid bit
/// pattern for a `T`-typed value, as determined by `T`'s
/// [`mycelium_bitfield::FromBits::try_from_bits`] implementation.
$vis fn get<T>(self, field: $crate::bitfield! { @t $T, T }) -> T
$vis fn get<T>(self, field: $crate::bitfield! { @t $T, T, Self }) -> T
where
T: $crate::FromBits<$T>,
{
Expand All @@ -286,7 +314,7 @@ macro_rules! bitfield {
/// - `Err(T::Error)` if `src` does not contain a valid bit
/// pattern for a `T`-typed value, as determined by `T`'s
/// [`mycelium_bitfield::FromBits::try_from_bits`] implementation.
$vis fn try_get<T>(self, field: $crate::bitfield! { @t $T, T }) -> Result<T, T::Error>
$vis fn try_get<T>(self, field: $crate::bitfield! { @t $T, T, Self }) -> Result<T, T::Error>
where
T: $crate::FromBits<$T>,
{
Expand All @@ -297,7 +325,7 @@ macro_rules! bitfield {
///
/// This is intended to be used in unit tests.
$vis fn assert_valid() {
<$crate::bitfield! { @t $T, $T }>::assert_all_valid(&Self::FIELDS);
<$crate::bitfield! { @t $T, $T, Self }>::assert_all_valid(&Self::FIELDS);
}
}

Expand Down Expand Up @@ -420,7 +448,7 @@ macro_rules! bitfield {
$($rest:tt)*
) => {
$(#[$meta])*
$vis const $Field: $crate::bitfield!{ @t $T, $T } = Self::$Prev.next($value);
$vis const $Field: $crate::bitfield!{ @t $T, $T, Self } = Self::$Prev.next($value);
$crate::bitfield!{ @field<$T>, prev: $Field: $($rest)* }
};

Expand All @@ -430,7 +458,7 @@ macro_rules! bitfield {
$($rest:tt)*
) => {
$(#[$meta])*
$vis const $Field: $crate::bitfield!{ @t $T, $Val } = Self::$Prev.then::<$Val>();
$vis const $Field: $crate::bitfield!{ @t $T, $Val, Self } = Self::$Prev.then::<$Val>();
$crate::bitfield!{ @field<$T>, prev: $Field: $($rest)* }
};

Expand All @@ -441,7 +469,7 @@ macro_rules! bitfield {
$($rest:tt)*
) => {
$(#[$meta])*
$vis const $Field: $crate::bitfield!{ @t $T, $T } = <$crate::bitfield!{ @t $T, $T }>::least_significant($value);
$vis const $Field: $crate::bitfield!{ @t $T, $T, Self } = <$crate::bitfield!{ @t $T, $T, () }>::least_significant($value).typed();
$crate::bitfield!{ @field<$T>, prev: $Field: $($rest)* }
};

Expand All @@ -451,7 +479,7 @@ macro_rules! bitfield {
$($rest:tt)*
) => {
$(#[$meta])*
$vis const $Field: $crate::bitfield!{ @t $T, $Val } = <$crate::bitfield!{ @t $T, $Val } >::first();
$vis const $Field: $crate::bitfield!{ @t $T, $Val, Self } = <$crate::bitfield!{ @t $T, $Val, Self } >::first();
$crate::bitfield!{ @field<$T>, prev: $Field: $($rest)* }
};

Expand Down Expand Up @@ -528,12 +556,12 @@ macro_rules! bitfield {
// $crate::bitfield! { @process_derives $vis struct $Name<$T> { $Next, $($Before),* } { $($rest)* } }
// };

(@t usize, $V:ty) => { $crate::PackUsize<$V> };
(@t u64, $V:ty) => { $crate::Pack64<$V> };
(@t u32, $V:ty) => { $crate::Pack32<$V> };
(@t u16, $V:ty) => { $crate::Pack16<$V> };
(@t u8, $V:ty) => { $crate::Pack8<$V> };
(@t $T:ty, $V:ty) => { compile_error!(concat!("unsupported bitfield type `", stringify!($T), "`; expected one of `usize`, `u64`, `u32`, `u16`, or `u8`")) }
(@t usize, $V:ty, $F:ty) => { $crate::PackUsize<$V, $F> };
(@t u64, $V:ty, $F:ty) => { $crate::Pack64<$V, $F> };
(@t u32, $V:ty, $F:ty) => { $crate::Pack32<$V, $F> };
(@t u16, $V:ty, $F:ty) => { $crate::Pack16<$V, $F> };
(@t u8, $V:ty, $F:ty) => { $crate::Pack8<$V, $F> };
(@t $T:ty, $V:ty, $F:ty) => { compile_error!(concat!("unsupported bitfield type `", stringify!($T), "`; expected one of `usize`, `u64`, `u32`, `u16`, or `u8`")) }
}

#[cfg(test)]
Expand Down
50 changes: 25 additions & 25 deletions bitfield/src/pack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ macro_rules! make_packers {
stringify!($Bits),
"`] values."
)]
pub struct $Pack<T = $Bits> {
pub struct $Pack<T = $Bits, F = ()> {
mask: $Bits,
shift: u32,
_dst_ty: PhantomData<fn(&T)>,
_dst_ty: PhantomData<fn(&T, &F)>,
}

#[doc = concat!(
Expand All @@ -38,7 +38,7 @@ macro_rules! make_packers {
dst_shr: $Bits,
}

impl $Pack {
impl $Pack<$Bits> {
#[doc = concat!(
"Wrap a [`",
stringify!($Bits),
Expand Down Expand Up @@ -133,7 +133,7 @@ macro_rules! make_packers {
}
}

impl<T> $Pack<T> {
impl<T, F> $Pack<T, F> {
// XXX(eliza): why is this always `u32`? ask the stdlib i guess...
const SIZE_BITS: u32 = <$Bits>::MAX.leading_ones();

Expand All @@ -152,7 +152,7 @@ macro_rules! make_packers {
}

#[doc(hidden)]
pub const fn typed<T2>(self) -> $Pack<T2>
pub const fn typed<T2, F2>(self) -> $Pack<T2, F2>
where
T2: FromBits<$Bits>
{
Expand Down Expand Up @@ -215,7 +215,7 @@ macro_rules! make_packers {
base
}

pub const fn then<T2>(&self) -> $Pack<T2>
pub const fn then<T2>(&self) -> $Pack<T2, F>
where
T2: FromBits<$Bits>
{
Expand All @@ -224,15 +224,15 @@ macro_rules! make_packers {

/// Returns a packer for packing a value into the next more-significant
/// `n` from `self`.
pub const fn next(&self, n: u32) -> $Pack {
pub const fn next(&self, n: u32) -> $Pack<$Bits, F> {
let shift = self.shift_next();
let mask = Self::mk_mask(n) << shift;
$Pack { mask, shift, _dst_ty: core::marker::PhantomData, }
}

/// Returns a packer for packing a value into all the remaining
/// more-significant bits after `self`.
pub const fn remaining(&self) -> $Pack {
pub const fn remaining(&self) -> $Pack<$Bits, F> {
let shift = self.shift_next();
let n = Self::SIZE_BITS - shift;
let mask = Self::mk_mask(n) << shift;
Expand Down Expand Up @@ -405,14 +405,14 @@ macro_rules! make_packers {
}
}

impl<T> $Pack<T>
impl<T, F> $Pack<T, F>
where
T: FromBits<$Bits>,
{
/// Returns a packing spec for packing a `T`-typed value in the
/// first [`T::BITS`](FromBits::BITS) least-significant bits.
pub const fn first() -> Self {
$Pack::least_significant(T::BITS).typed()
$Pack::<$Bits, ()>::least_significant(T::BITS).typed()
}

/// Returns a pair type for packing bits from the range
Expand All @@ -422,7 +422,7 @@ macro_rules! make_packers {
/// The packing pair can be used to pack bits from one location
/// into another location, and vice versa.
pub const fn pair_at(&self, at: u32) -> $Pair<T> {
let dst = $Pack::starting_at(at, self.bits()).typed();
let dst = $Pack::<$Bits, ()>::starting_at(at, self.bits()).typed();
let at = at.saturating_sub(1);
// TODO(eliza): validate that `at + self.bits() < N_BITS` in
// const fn somehow lol
Expand All @@ -435,7 +435,7 @@ macro_rules! make_packers {
(0, (self.shift - at) as $Bits)
};
$Pair {
src: *self,
src: self.typed(),
dst,
dst_shl,
dst_shr,
Expand Down Expand Up @@ -521,7 +521,7 @@ macro_rules! make_packers {
}
}

impl<T> Clone for $Pack<T> {
impl<T, F> Clone for $Pack<T, F> {
fn clone(&self) -> Self {
Self {
mask: self.mask,
Expand All @@ -531,9 +531,9 @@ macro_rules! make_packers {
}
}

impl<T> Copy for $Pack<T> {}
impl<T, F> Copy for $Pack<T, F> {}

impl<T> fmt::Debug for $Pack<T> {
impl<T, F> fmt::Debug for $Pack<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(stringify!($Pack))
.field("mask", &format_args!("{:#b}", self.mask))
Expand All @@ -543,7 +543,7 @@ macro_rules! make_packers {
}
}

impl<T> fmt::UpperHex for $Pack<T> {
impl<T, F> fmt::UpperHex for $Pack<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(stringify!($Pack))
.field("mask", &format_args!("{:#X}", self.mask))
Expand All @@ -553,7 +553,7 @@ macro_rules! make_packers {
}
}

impl<T> fmt::LowerHex for $Pack<T> {
impl<T, F> fmt::LowerHex for $Pack<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(stringify!($Pack))
.field("mask", &format_args!("{:#x}", self.mask))
Expand All @@ -563,7 +563,7 @@ macro_rules! make_packers {
}
}

impl<T> fmt::Binary for $Pack<T> {
impl<T, F> fmt::Binary for $Pack<T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct(stringify!($Pack))
.field("mask", &format_args!("{:#b}", self.mask))
Expand All @@ -579,28 +579,28 @@ macro_rules! make_packers {
}
}

impl<A, B> PartialEq<$Pack<B>> for $Pack<A> {
impl<A, B, F> PartialEq<$Pack<B, F>> for $Pack<A, F> {
#[inline]
fn eq(&self, other: &$Pack<B>) -> bool {
fn eq(&self, other: &$Pack<B, F>) -> bool {
self.mask == other.mask && self.shift == other.shift
}
}

impl<A, B> PartialEq<&'_ $Pack<B>> for $Pack<A> {
impl<A, B, F> PartialEq<&'_ $Pack<B, F>> for $Pack<A, F> {
#[inline]
fn eq(&self, other: &&'_ $Pack<B>) -> bool {
fn eq(&self, other: &&'_ $Pack<B, F>) -> bool {
self.eq(*other)
}
}

impl<A, B> PartialEq<$Pack<B>> for &'_ $Pack<A> {
impl<A, B, F> PartialEq<$Pack<B, F>> for &'_ $Pack<A, F> {
#[inline]
fn eq(&self, other: &$Pack<B>) -> bool {
fn eq(&self, other: &$Pack<B, F>) -> bool {
(*self).eq(other)
}
}

impl<T> Eq for $Pack<T> {}
impl<T, F> Eq for $Pack<T, F> {}

// === packing type ===

Expand Down

0 comments on commit 12ea600

Please sign in to comment.