Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

struct CaseSet: Optimize by matching on len once per set_ctx #1283

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 105 additions & 25 deletions src/ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,55 +52,121 @@ use std::iter::zip;
/// This optimizes for the common cases where `buf.len()` is a small power of 2,
/// where the array write is optimized into as few and as large stores as possible.
#[inline]
pub fn small_memset<T: Clone + Copy, const UP_TO: usize, const WITH_DEFAULT: bool>(
pub fn small_memset<T: Clone + Copy, const N: usize, const WITH_DEFAULT: bool>(
buf: &mut [T],
val: T,
) {
fn as_array<T: Clone + Copy, const N: usize>(buf: &mut [T]) -> &mut [T; N] {
buf.try_into().unwrap()
}
match buf.len() {
01 if UP_TO >= 01 => *as_array(buf) = [val; 01],
02 if UP_TO >= 02 => *as_array(buf) = [val; 02],
04 if UP_TO >= 04 => *as_array(buf) = [val; 04],
08 if UP_TO >= 08 => *as_array(buf) = [val; 08],
16 if UP_TO >= 16 => *as_array(buf) = [val; 16],
32 if UP_TO >= 32 => *as_array(buf) = [val; 32],
64 if UP_TO >= 64 => *as_array(buf) = [val; 64],
_ => {
if WITH_DEFAULT {
buf.fill(val)
}
if N == 0 {
if WITH_DEFAULT {
buf.fill(val)
}
} else {
assert!(buf.len() == N); // Meant to be optimized out.
*as_array(buf) = [val; N];
}
}

pub struct CaseSetter<const UP_TO: usize, const WITH_DEFAULT: bool> {
/// Abstraction over "fill a pre-selected range of a buffer with one value".
///
/// The method bounds are generic so that a single implementor can be applied
/// to buffers of different element types. The implementor in this file,
/// `CaseSetterN`, writes `val` to `buf[offset..][..len]`.
pub trait CaseSetter {
/// Writes `val` into the implementor's range of the plain slice `buf`.
fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T);

/// Writes `val` into the implementor's range of a [`DisjointMut`] buffer.
///
/// # Safety
///
/// Caller must ensure that no elements of the written range are concurrently
/// borrowed (immutably or mutably) at all during the call to `set_disjoint`.
//
// NOTE(review): this method is not declared `unsafe fn` even though it
// documents a safety contract — confirm that `DisjointMut::index_mut`
// is what enforces (or checks) the requirement above.
fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
where
T: AsMutPtr<Target = V>,
V: Clone + Copy;
}

/// A [`CaseSetter`] whose write length is the compile-time constant `N`,
/// or the run-time `len` field when `N == 0`.
///
/// `WITH_DEFAULT` is forwarded unchanged to `small_memset`.
pub struct CaseSetterN<const N: usize, const WITH_DEFAULT: bool> {
/// Index into the target buffer at which the written range starts.
offset: usize,
/// Run-time length of the written range; only consulted when `N == 0`.
len: usize,
}

impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSetter<UP_TO, WITH_DEFAULT> {
impl<const N: usize, const WITH_DEFAULT: bool> CaseSetterN<N, WITH_DEFAULT> {
    /// Number of elements this setter writes.
    ///
    /// `N == 0` is the "length not known at compile time" marker, in which
    /// case the stored `len` field is returned. For every other `N` the
    /// constant itself is returned and the field is ignored.
    const fn len(&self) -> usize {
        match N {
            0 => self.len,
            fixed => fixed,
        }
    }
}

impl<const N: usize, const WITH_DEFAULT: bool> CaseSetter for CaseSetterN<N, WITH_DEFAULT> {
#[inline]
pub fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T) {
small_memset::<T, UP_TO, WITH_DEFAULT>(&mut buf[self.offset..][..self.len], val);
fn set<T: Clone + Copy>(&self, buf: &mut [T], val: T) {
small_memset::<_, N, WITH_DEFAULT>(&mut buf[self.offset..][..self.len()], val);
}

/// # Safety
///
/// Caller must ensure that no elements of the written range are concurrently
/// borrowed (immutably or mutably) at all during the call to `set_disjoint`.
#[inline]
pub fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
fn set_disjoint<T, V>(&self, buf: &DisjointMut<T>, val: V)
where
T: AsMutPtr<Target = V>,
V: Clone + Copy,
{
let mut buf = buf.index_mut(self.offset..self.offset + self.len);
small_memset::<V, UP_TO, WITH_DEFAULT>(&mut *buf, val);
let mut buf = buf.index_mut((self.offset.., ..self.len()));
small_memset::<_, N, WITH_DEFAULT>(&mut *buf, val);
}
}

/// Rank-2 polymorphic closures aren't a thing in Rust yet,
/// so we need to emulate this through a generic trait with a generic method.
/// Unfortunately, this means we have to write the closure sugar manually.
///
/// `T` is the type of the context value passed to each invocation.
pub trait SetCtx<T> {
/// Runs the emulated closure body once with the given `case` setter and `ctx`.
///
/// The method is generic over the concrete [`CaseSetter`] type `S`, so one
/// `SetCtx` value can be invoked with differently-typed setters.
/// It consumes `self` and returns it again so that the caller can keep
/// the (possibly updated) captured state and invoke it another time.
fn call<S: CaseSetter>(self, case: &S, ctx: T) -> Self;
}

/// Emulate a closure for a [`SetCtx`] `impl`.
///
/// Accepted invocation shape (bracketed parts are optional):
///
/// ```text
/// set_ctx!(||['a,] case, ctx: T, var: Ty [= init], ... || { body })
/// ```
///
/// The expansion is a block expression that:
///
/// 1. declares a local struct `F` with one field per listed `var: Ty`
///    (the "captured" state), optionally generic over the given lifetime;
/// 2. implements `SetCtx<T>` for `F`, using `body` as the body of `call`
///    with `case`, `ctx` and every `var` bound as locals;
/// 3. evaluates to an `F` value, where each field is initialized from
///    `init` if one was given, or otherwise from a binding of the same
///    name via field-init shorthand.
macro_rules! set_ctx {
(
// `||` is used instead of just `|` due to this bug: <https://github.com/rust-lang/rustfmt/issues/6228>.
||
$($lifetime:lifetime,)?
$case:ident,
$ctx:ident: $T:ty,
// Note that the required trailing `,` is so `:expr` can precede `|`.
$($up_var:ident: $up_var_ty:ty$( = $up_var_val:expr)?,)*
|| $body:block
) => {{
use $crate::src::ctx::SetCtx;
use $crate::src::ctx::CaseSetter;

// The captured state of the emulated closure.
struct F$(<$lifetime>)? {
$($up_var: $up_var_ty,)*
}

impl$(<$lifetime>)? SetCtx<$T> for F$(<$lifetime>)? {
fn call<S: CaseSetter>(self, $case: &S, $ctx: $T) -> Self {
// Move every captured field into a local of the same name
// so that `$body` can refer to it directly.
let Self {
$($up_var,)*
} = self;
$body
// We destructure and re-structure `Self` so that we
// can move out of refs without using `ref`/`ref mut`,
// which I don't know how to match on in a macro.
Self {
$($up_var,)*
}
}
}

// The value of the whole macro invocation: the closure object itself.
F {
$($up_var$(: $up_var_val)?,)*
}
}};
}

pub(crate) use set_ctx;

/// The entrypoint to the [`CaseSet`] API.
///
/// `UP_TO` and `WITH_DEFAULT` are made const generic parameters rather than have multiple `case_set*` `fn`s,
Expand All @@ -117,11 +183,25 @@ impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSet<UP_TO, WITH_DEFAULT>
/// The `len` and `offset` are supplied here and
/// applied to each `buf` passed to [`CaseSetter::set`] in `set_ctx`.
#[inline]
pub fn one<T, F>(ctx: T, len: usize, offset: usize, mut set_ctx: F)
pub fn one<T, F>(ctx: T, len: usize, offset: usize, set_ctx: F) -> F
where
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
F: SetCtx<T>,
{
set_ctx(&CaseSetter { offset, len }, ctx);
macro_rules! set_ctx {
($N:literal) => {
set_ctx.call(&CaseSetterN::<$N, WITH_DEFAULT> { offset, len }, ctx)
};
}
match len {
01 if UP_TO >= 01 => set_ctx!(01),
02 if UP_TO >= 02 => set_ctx!(02),
04 if UP_TO >= 04 => set_ctx!(04),
08 if UP_TO >= 08 => set_ctx!(08),
16 if UP_TO >= 16 => set_ctx!(16),
32 if UP_TO >= 32 => set_ctx!(32),
64 if UP_TO >= 64 => set_ctx!(64),
_ => set_ctx!(0),
}
}

/// Perform many case sets in one call.
Expand All @@ -138,10 +218,10 @@ impl<const UP_TO: usize, const WITH_DEFAULT: bool> CaseSet<UP_TO, WITH_DEFAULT>
offsets: [usize; N],
mut set_ctx: F,
) where
F: FnMut(&CaseSetter<UP_TO, WITH_DEFAULT>, T),
F: SetCtx<T>,
{
for (dir, (len, offset)) in zip(dirs, zip(lens, offsets)) {
Self::one(dir, len, offset, &mut set_ctx);
set_ctx = Self::one(dir, len, offset, set_ctx);
}
}
}
Loading
Loading