Skip to content

Commit

Permalink
Make portable implementation const
Browse files Browse the repository at this point in the history
  • Loading branch information
nazar-pc committed Jan 5, 2025
1 parent e51bcac commit b3eb262
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 86 deletions.
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,12 @@ const DERIVE_KEY_CONTEXT: u8 = 1 << 5;
const DERIVE_KEY_MATERIAL: u8 = 1 << 6;

#[inline]
fn counter_low(counter: u64) -> u32 {
const fn counter_low(counter: u64) -> u32 {
counter as u32
}

#[inline]
fn counter_high(counter: u64) -> u32 {
const fn counter_high(counter: u64) -> u32 {
(counter >> 32) as u32
}

Expand Down Expand Up @@ -623,7 +623,7 @@ pub enum IncrementCounter {

impl IncrementCounter {
#[inline]
fn yes(&self) -> bool {
const fn yes(&self) -> bool {
match self {
IncrementCounter::Yes => true,
IncrementCounter::No => false,
Expand Down
125 changes: 73 additions & 52 deletions src/platform.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::{portable, CVWords, IncrementCounter, BLOCK_LEN};
use arrayref::{array_mut_ref, array_ref};

cfg_if::cfg_if! {
if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
Expand Down Expand Up @@ -473,74 +472,96 @@ pub fn sse2_detected() -> bool {
false
}

/// Reads one little-endian `u32` out of a byte sequence.
///
/// * `$src` - identifier of the indexable byte container to read from.
/// * `$chunk_index` - literal index of the 4-byte group to decode; the bytes
///   read are `$src[$chunk_index * 4]` through `$src[$chunk_index * 4 + 3]`.
///
/// Expands to a `u32::from_le_bytes` call over four individually indexed
/// bytes. The `const fn` conversion helpers in this file call it.
// NOTE(review): presumably written with plain indexing (rather than
// `array_ref!`) so that the callers can be `const fn` — confirm.
macro_rules! extract_u32_from_byte_chunks {
($src:ident, $chunk_index:literal) => {
u32::from_le_bytes([
$src[$chunk_index * 4 + 0],
$src[$chunk_index * 4 + 1],
$src[$chunk_index * 4 + 2],
$src[$chunk_index * 4 + 3],
])
};
}

/// Writes one word of `$src` into `$dst` as four little-endian bytes.
///
/// * `$src` - identifier of the indexable word container to read from.
/// * `$dst` - identifier of the indexable byte container to write into.
/// * `$chunk_index` - literal index of the word in `$src`; the bytes written
///   are `$dst[$chunk_index * 4]` through `$dst[$chunk_index * 4 + 3]`.
///
/// Expands to a destructuring assignment of `to_le_bytes()` into four
/// individually indexed destination bytes.
// NOTE(review): the name looks like it was meant to read "to_byte_chunks";
// it is left as-is here because the call sites use this spelling.
macro_rules! store_u32_to_by_chunks {
($src:ident, $dst:ident, $chunk_index:literal) => {
[
$dst[$chunk_index * 4 + 0],
$dst[$chunk_index * 4 + 1],
$dst[$chunk_index * 4 + 2],
$dst[$chunk_index * 4 + 3],
] = $src[$chunk_index].to_le_bytes();
};
}

/// Converts 32 bytes into eight `u32` words, decoding each consecutive
/// 4-byte group as little-endian.
///
/// Returns the words in the same order as the groups appear in `bytes`
/// (word `i` comes from bytes `i * 4 .. i * 4 + 4`).
// NOTE(review): unlike `words_from_le_bytes_64` and the `le_bytes_from_words_*`
// helpers, this signature is not `const` — confirm whether that is intentional.
#[inline(always)]
pub fn words_from_le_bytes_32(bytes: &[u8; 32]) -> [u32; 8] {
let mut out = [0; 8];
// NOTE(review): each `out[i]` is assigned twice in this listing, first via
// `array_ref!` and then via `extract_u32_from_byte_chunks!`. Both forms decode
// the same four bytes; this looks like removed and added diff lines shown
// together — confirm against the actual file.
out[0] = u32::from_le_bytes(*array_ref!(bytes, 0 * 4, 4));
out[1] = u32::from_le_bytes(*array_ref!(bytes, 1 * 4, 4));
out[2] = u32::from_le_bytes(*array_ref!(bytes, 2 * 4, 4));
out[3] = u32::from_le_bytes(*array_ref!(bytes, 3 * 4, 4));
out[4] = u32::from_le_bytes(*array_ref!(bytes, 4 * 4, 4));
out[5] = u32::from_le_bytes(*array_ref!(bytes, 5 * 4, 4));
out[6] = u32::from_le_bytes(*array_ref!(bytes, 6 * 4, 4));
out[7] = u32::from_le_bytes(*array_ref!(bytes, 7 * 4, 4));
out[0] = extract_u32_from_byte_chunks!(bytes, 0);
out[1] = extract_u32_from_byte_chunks!(bytes, 1);
out[2] = extract_u32_from_byte_chunks!(bytes, 2);
out[3] = extract_u32_from_byte_chunks!(bytes, 3);
out[4] = extract_u32_from_byte_chunks!(bytes, 4);
out[5] = extract_u32_from_byte_chunks!(bytes, 5);
out[6] = extract_u32_from_byte_chunks!(bytes, 6);
out[7] = extract_u32_from_byte_chunks!(bytes, 7);
out
}

#[inline(always)]
pub fn words_from_le_bytes_64(bytes: &[u8; 64]) -> [u32; 16] {
pub const fn words_from_le_bytes_64(bytes: &[u8; 64]) -> [u32; 16] {
let mut out = [0; 16];
out[0] = u32::from_le_bytes(*array_ref!(bytes, 0 * 4, 4));
out[1] = u32::from_le_bytes(*array_ref!(bytes, 1 * 4, 4));
out[2] = u32::from_le_bytes(*array_ref!(bytes, 2 * 4, 4));
out[3] = u32::from_le_bytes(*array_ref!(bytes, 3 * 4, 4));
out[4] = u32::from_le_bytes(*array_ref!(bytes, 4 * 4, 4));
out[5] = u32::from_le_bytes(*array_ref!(bytes, 5 * 4, 4));
out[6] = u32::from_le_bytes(*array_ref!(bytes, 6 * 4, 4));
out[7] = u32::from_le_bytes(*array_ref!(bytes, 7 * 4, 4));
out[8] = u32::from_le_bytes(*array_ref!(bytes, 8 * 4, 4));
out[9] = u32::from_le_bytes(*array_ref!(bytes, 9 * 4, 4));
out[10] = u32::from_le_bytes(*array_ref!(bytes, 10 * 4, 4));
out[11] = u32::from_le_bytes(*array_ref!(bytes, 11 * 4, 4));
out[12] = u32::from_le_bytes(*array_ref!(bytes, 12 * 4, 4));
out[13] = u32::from_le_bytes(*array_ref!(bytes, 13 * 4, 4));
out[14] = u32::from_le_bytes(*array_ref!(bytes, 14 * 4, 4));
out[15] = u32::from_le_bytes(*array_ref!(bytes, 15 * 4, 4));
out[0] = extract_u32_from_byte_chunks!(bytes, 0);
out[1] = extract_u32_from_byte_chunks!(bytes, 1);
out[2] = extract_u32_from_byte_chunks!(bytes, 2);
out[3] = extract_u32_from_byte_chunks!(bytes, 3);
out[4] = extract_u32_from_byte_chunks!(bytes, 4);
out[5] = extract_u32_from_byte_chunks!(bytes, 5);
out[6] = extract_u32_from_byte_chunks!(bytes, 6);
out[7] = extract_u32_from_byte_chunks!(bytes, 7);
out[8] = extract_u32_from_byte_chunks!(bytes, 8);
out[9] = extract_u32_from_byte_chunks!(bytes, 9);
out[10] = extract_u32_from_byte_chunks!(bytes, 10);
out[11] = extract_u32_from_byte_chunks!(bytes, 11);
out[12] = extract_u32_from_byte_chunks!(bytes, 12);
out[13] = extract_u32_from_byte_chunks!(bytes, 13);
out[14] = extract_u32_from_byte_chunks!(bytes, 14);
out[15] = extract_u32_from_byte_chunks!(bytes, 15);
out
}

#[inline(always)]
pub fn le_bytes_from_words_32(words: &[u32; 8]) -> [u8; 32] {
pub const fn le_bytes_from_words_32(words: &[u32; 8]) -> [u8; 32] {
let mut out = [0; 32];
*array_mut_ref!(out, 0 * 4, 4) = words[0].to_le_bytes();
*array_mut_ref!(out, 1 * 4, 4) = words[1].to_le_bytes();
*array_mut_ref!(out, 2 * 4, 4) = words[2].to_le_bytes();
*array_mut_ref!(out, 3 * 4, 4) = words[3].to_le_bytes();
*array_mut_ref!(out, 4 * 4, 4) = words[4].to_le_bytes();
*array_mut_ref!(out, 5 * 4, 4) = words[5].to_le_bytes();
*array_mut_ref!(out, 6 * 4, 4) = words[6].to_le_bytes();
*array_mut_ref!(out, 7 * 4, 4) = words[7].to_le_bytes();
store_u32_to_by_chunks!(words, out, 0);
store_u32_to_by_chunks!(words, out, 1);
store_u32_to_by_chunks!(words, out, 2);
store_u32_to_by_chunks!(words, out, 3);
store_u32_to_by_chunks!(words, out, 4);
store_u32_to_by_chunks!(words, out, 5);
store_u32_to_by_chunks!(words, out, 6);
store_u32_to_by_chunks!(words, out, 7);
out
}

#[inline(always)]
pub fn le_bytes_from_words_64(words: &[u32; 16]) -> [u8; 64] {
pub const fn le_bytes_from_words_64(words: &[u32; 16]) -> [u8; 64] {
let mut out = [0; 64];
*array_mut_ref!(out, 0 * 4, 4) = words[0].to_le_bytes();
*array_mut_ref!(out, 1 * 4, 4) = words[1].to_le_bytes();
*array_mut_ref!(out, 2 * 4, 4) = words[2].to_le_bytes();
*array_mut_ref!(out, 3 * 4, 4) = words[3].to_le_bytes();
*array_mut_ref!(out, 4 * 4, 4) = words[4].to_le_bytes();
*array_mut_ref!(out, 5 * 4, 4) = words[5].to_le_bytes();
*array_mut_ref!(out, 6 * 4, 4) = words[6].to_le_bytes();
*array_mut_ref!(out, 7 * 4, 4) = words[7].to_le_bytes();
*array_mut_ref!(out, 8 * 4, 4) = words[8].to_le_bytes();
*array_mut_ref!(out, 9 * 4, 4) = words[9].to_le_bytes();
*array_mut_ref!(out, 10 * 4, 4) = words[10].to_le_bytes();
*array_mut_ref!(out, 11 * 4, 4) = words[11].to_le_bytes();
*array_mut_ref!(out, 12 * 4, 4) = words[12].to_le_bytes();
*array_mut_ref!(out, 13 * 4, 4) = words[13].to_le_bytes();
*array_mut_ref!(out, 14 * 4, 4) = words[14].to_le_bytes();
*array_mut_ref!(out, 15 * 4, 4) = words[15].to_le_bytes();
store_u32_to_by_chunks!(words, out, 0);
store_u32_to_by_chunks!(words, out, 1);
store_u32_to_by_chunks!(words, out, 2);
store_u32_to_by_chunks!(words, out, 3);
store_u32_to_by_chunks!(words, out, 4);
store_u32_to_by_chunks!(words, out, 5);
store_u32_to_by_chunks!(words, out, 6);
store_u32_to_by_chunks!(words, out, 7);
store_u32_to_by_chunks!(words, out, 8);
store_u32_to_by_chunks!(words, out, 9);
store_u32_to_by_chunks!(words, out, 10);
store_u32_to_by_chunks!(words, out, 11);
store_u32_to_by_chunks!(words, out, 12);
store_u32_to_by_chunks!(words, out, 13);
store_u32_to_by_chunks!(words, out, 14);
store_u32_to_by_chunks!(words, out, 15);
out
}
64 changes: 33 additions & 31 deletions src/portable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ use crate::{
counter_high, counter_low, CVBytes, CVWords, IncrementCounter, BLOCK_LEN, IV, MSG_SCHEDULE,
OUT_LEN,
};
use arrayref::{array_mut_ref, array_ref};

#[inline(always)]
fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) {
const fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) {
state[a] = state[a].wrapping_add(state[b]).wrapping_add(x);
state[d] = (state[d] ^ state[a]).rotate_right(16);
state[c] = state[c].wrapping_add(state[d]);
Expand All @@ -17,7 +16,7 @@ fn g(state: &mut [u32; 16], a: usize, b: usize, c: usize, d: usize, x: u32, y: u
}

#[inline(always)]
fn round(state: &mut [u32; 16], msg: &[u32; 16], round: usize) {
const fn round(state: &mut [u32; 16], msg: &[u32; 16], round: usize) {
// Select the message schedule based on the round.
let schedule = MSG_SCHEDULE[round];

Expand All @@ -35,7 +34,7 @@ fn round(state: &mut [u32; 16], msg: &[u32; 16], round: usize) {
}

#[inline(always)]
fn compress_pre(
const fn compress_pre(
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
Expand Down Expand Up @@ -74,7 +73,7 @@ fn compress_pre(
state
}

pub fn compress_in_place(
pub const fn compress_in_place(
cv: &mut CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
Expand All @@ -93,7 +92,7 @@ pub fn compress_in_place(
cv[7] = state[7] ^ state[15];
}

pub fn compress_xof(
pub const fn compress_xof(
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
Expand All @@ -120,7 +119,7 @@ pub fn compress_xof(
crate::platform::le_bytes_from_words_64(&state)
}

fn hash1<const N: usize>(
const fn hash1<const N: usize>(
input: &[u8; N],
key: &CVWords,
counter: u64,
Expand All @@ -129,48 +128,51 @@ fn hash1<const N: usize>(
flags_end: u8,
out: &mut CVBytes,
) {
debug_assert_eq!(N % BLOCK_LEN, 0, "uneven blocks");
debug_assert!(N % BLOCK_LEN == 0, "uneven blocks");
let mut cv = *key;
let mut block_flags = flags | flags_start;
let mut slice = &input[..];
let mut slice = input.as_slice();
while slice.len() >= BLOCK_LEN {
if slice.len() == BLOCK_LEN {
let block;
(block, slice) = slice.split_at(BLOCK_LEN);
if slice.is_empty() {
block_flags |= flags_end;
}
compress_in_place(
&mut cv,
array_ref!(slice, 0, BLOCK_LEN),
BLOCK_LEN as u8,
counter,
block_flags,
);
let block = {
let ptr = block.as_ptr() as *const [u8; BLOCK_LEN];
// SAFETY: Sliced off correct length above
unsafe { &*ptr }
};

compress_in_place(&mut cv, block, BLOCK_LEN as u8, counter, block_flags);
block_flags = flags;
slice = &slice[BLOCK_LEN..];
}
*out = crate::platform::le_bytes_from_words_32(&cv);
}

pub fn hash_many<const N: usize>(
inputs: &[&[u8; N]],
pub const fn hash_many<const N: usize>(
mut inputs: &[&[u8; N]],
key: &CVWords,
mut counter: u64,
increment_counter: IncrementCounter,
flags: u8,
flags_start: u8,
flags_end: u8,
out: &mut [u8],
mut out: &mut [u8],
) {
debug_assert!(out.len() >= inputs.len() * OUT_LEN, "out too short");
for (&input, output) in inputs.iter().zip(out.chunks_exact_mut(OUT_LEN)) {
hash1(
input,
key,
counter,
flags,
flags_start,
flags_end,
array_mut_ref!(output, 0, OUT_LEN),
);
while !inputs.is_empty() {
let input;
(input, inputs) = inputs.split_first().expect("Not empty; qed");
let o;
(o, out) = out.split_at_mut(OUT_LEN);
let o = {
let ptr = o.as_mut_ptr() as *mut [u8; OUT_LEN];
// SAFETY: Sliced off correct length above
unsafe { &mut *ptr }
};

hash1(input, key, counter, flags, flags_start, flags_end, o);
if increment_counter.yes() {
counter += 1;
}
Expand Down

0 comments on commit b3eb262

Please sign in to comment.