Skip to content

Commit

Permalink
backing out generics
Browse files Browse the repository at this point in the history
  • Loading branch information
integritychain committed Feb 22, 2024
1 parent 9144179 commit 74e5ec5
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 173 deletions.
20 changes: 11 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ license = "MIT OR Apache-2.0"
description = "FIPS 204 (draft): Module-Lattice-Based Digital Signature"
authors = ["Eric Schorn <[email protected]>"]
repository = "https://github.com/integritychain/fips204"
rust-version = "1.73"
rust-version = "1.72"


[dependencies]
sha3 = { version = "0.10.8", default-features = false }
rand_core = { version = "0.6.4", default-features = false }
zeroize = { version = "1.6.0", features = ["zeroize_derive"] }
zeroize = { version = "1.7.0", features = ["zeroize_derive"] }


[features]
Expand All @@ -25,23 +25,25 @@ ml-dsa-87 = []

[dev-dependencies]
rand = "0.8.5"
regex = "1.10.2"
regex = "1.10.3"
hex = "0.4.3"
rand_chacha = "0.3.1"
criterion = "0.5.1"
criterion = "0.4.0" # for MSRV 1.72, otherwise newest is "0.5.1"


[[bench]]
name = "benchmark"
harness = false


[profile.release]
debug = true


[profile.bench]
debug = true
#debug-assertions = false
#incremental = false
#lto = true
#opt-level = 3
#overflow-checks = false
debug-assertions = false
incremental = false
lto = true
opt-level = 3
overflow-checks = false
16 changes: 8 additions & 8 deletions src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,16 @@ pub(crate) fn coef_from_three_bytes(bbb: [u8; 3]) -> Result<u32, &'static str> {
///
/// # Errors
/// Returns an error `⊥` on misconfigured or illegal input.
pub(crate) fn coef_from_half_byte<const ETA: usize>(b: u8) -> Result<i32, &'static str> {
pub(crate) fn coef_from_half_byte(eta: u32, b: u8) -> Result<i32, &'static str> {
ensure!(b <= 15, "Algorithm 9: input b must be <= 15");

// 1: if η = 2 and b < 15 then return 2 − (b mod 5)
if (ETA == 2) & (b < 15) {
if (eta == 2) & (b < 15) {
Ok(2 - (b % 5) as i32)
// 2: else
} else {
// 3: if η = 4 and b < 9 then return 4 − b
if (ETA == 4) & (b < 9) {
if (eta == 4) & (b < 9) {
Ok(4 - b as i32)
// 4: else return ⊥
} else {
Expand Down Expand Up @@ -332,35 +332,35 @@ mod tests {
#[test]
fn test_coef_from_half_byte1() {
let inp = 3;
let res = coef_from_half_byte::<2>(inp).unwrap();
let res = coef_from_half_byte(2, inp).unwrap();
assert_eq!(-1, res);
}

#[test]
fn test_coef_from_half_byte2() {
let inp = 8;
let res = coef_from_half_byte::<4>(inp).unwrap();
let res = coef_from_half_byte(4, inp).unwrap();
assert_eq!(-4, res);
}

#[test]
fn test_coef_from_half_byte_validation1() {
let inp = 22;
let res = coef_from_half_byte::<2>(inp);
let res = coef_from_half_byte(2, inp);
assert!(res.is_err());
}

#[test]
fn test_coef_from_half_byte_validation2() {
let inp = 15;
let res = coef_from_half_byte::<2>(inp);
let res = coef_from_half_byte(2, inp);
assert!(res.is_err());
}

#[test]
fn test_coef_from_half_byte_validation3() {
let inp = 10;
let res = coef_from_half_byte::<4>(inp);
let res = coef_from_half_byte(4, inp);
assert!(res.is_err());
}

Expand Down
84 changes: 27 additions & 57 deletions src/encodings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,57 +89,35 @@ pub(crate) fn pk_decode<const K: usize, const PK_LEN: usize>(
///
/// # Errors
/// Returns an error ........ TKTK
pub fn sk_encode<
const D: usize,
const ETA: usize,
const K: usize,
const L: usize,
const SK_LEN: usize,
>(
rho: &[u8; 32], k: &[u8; 32], tr: &[u8; 64], s1: &[R; L], s2: &[R; K], t0: &[R; K],
pub fn sk_encode<const D: usize, const K: usize, const L: usize, const SK_LEN: usize>(
eta: u32, rho: &[u8; 32], k: &[u8; 32], tr: &[u8; 64], s1: &[R; L], s2: &[R; K], t0: &[R; K],
) -> Result<[u8; SK_LEN], &'static str> {
ensure!(
s1.iter().all(|x| is_in_range(x, ETA as u32, ETA as u32)),
"Algorithm 18: s1 out of range"
);
ensure!(
s2.iter().all(|x| is_in_range(x, ETA as u32, ETA as u32)),
"Algorithm 18: s2 out of range"
);
ensure!(s1.iter().all(|x| is_in_range(x, eta, eta)), "Algorithm 18: s1 out of range");
ensure!(s2.iter().all(|x| is_in_range(x, eta, eta)), "Algorithm 18: s2 out of range");
ensure!(
t0.iter().all(|x| is_in_range(x, 2u32.pow(D as u32 - 1) + 1, 2u32.pow(D as u32 - 1))),
"Algorithm 18: t0 out of range"
);
let mut sk = [0u8; SK_LEN];
debug_assert_eq!(sk.len(), 32 + 32 + 64 + 32 * ((K + L) * bitlen(2 * ETA) + D * K)); // Useful nonsense? ;)
debug_assert_eq!(sk.len(), 32 + 32 + 64 + 32 * ((K + L) * bitlen(2 * eta as usize) + D * K)); // Useful nonsense? ;)

// 1: sk ← BitsToBytes(ρ) || BitsToBytes(K) || BitsToBytes(tr)
sk[0..32].copy_from_slice(rho);
sk[32..64].copy_from_slice(k);
sk[64..128].copy_from_slice(tr);
// 2: for i from 0 to ℓ − 1 do
let start = 128;
let step = 32 * bitlen(2 * ETA);
let step = 32 * bitlen(2 * eta as usize);
for i in 0..L {
// 3: sk ← sk || BitPack (s1[i], η, η)
bit_pack(
&s1[i],
ETA as u32,
ETA as u32,
&mut sk[start + i * step..start + (i + 1) * step],
)?;
bit_pack(&s1[i], eta, eta, &mut sk[start + i * step..start + (i + 1) * step])?;
// 4: end for
}
// 5: for i from 0 to k − 1 do
let start = start + L * step;
for i in 0..K {
// 6: sk ← sk || BitPack (s2[i], η, η)
bit_pack(
&s2[i],
ETA as u32,
ETA as u32,
&mut sk[start + i * step..start + (i + 1) * step],
)?;
bit_pack(&s2[i], eta, eta, &mut sk[start + i * step..start + (i + 1) * step])?;
// 7: end for
}
// 8: for i from 0 to k − 1 do
Expand Down Expand Up @@ -171,16 +149,10 @@ pub fn sk_encode<
/// # Errors
/// Returns an error ........ TKTK
#[allow(clippy::similar_names, clippy::type_complexity)]
pub(crate) fn sk_decode<
const D: usize,
const ETA: usize,
const K: usize,
const L: usize,
const SK_LEN: usize,
>(
sk: &[u8; SK_LEN],
pub(crate) fn sk_decode<const D: usize, const K: usize, const L: usize, const SK_LEN: usize>(
eta: u32, sk: &[u8; SK_LEN],
) -> Result<([u8; 32], [u8; 32], [u8; 64], [R; L], [R; K], [R; K]), &'static str> {
let bl = bitlen(2 * ETA);
let bl = bitlen(2 * eta as usize);
ensure!(sk.len() == 32 + 32 + 64 + 32 * ((L + K) * bl + D * K), "Algorithm 19: asdf asdf");
let (mut rho, mut k, mut tr) = ([0u8; 32], [0u8; 32], [0u8; 64]);
let (mut s1, mut s2, mut t0) = ([R::zero(); L], [R::zero(); K], [R::zero(); K]);
Expand All @@ -199,14 +171,14 @@ pub(crate) fn sk_decode<
let step = 32 * bl;
for i in 0..L {
// 6: s1[i] ← BitUnpack(yi, η, η) ▷ This may lie outside [−η, η], if input is malformed
s1[i] = bit_unpack(&sk[start + i * step..start + (i + 1) * step], ETA as u32, ETA as u32)?;
s1[i] = bit_unpack(&sk[start + i * step..start + (i + 1) * step], eta, eta)?;
// 7: end for
}
// 8: for i from 0 to k − 1 do
let start = start + L * step;
for i in 0..K {
// 9: s2[i] ← BitUnpack(zi, η, η) ▷ This may lie outside [−η, η], if input is malformed
s2[i] = bit_unpack(&sk[start + i * step..start + (i + 1) * step], ETA as u32, ETA as u32)?;
s2[i] = bit_unpack(&sk[start + i * step..start + (i + 1) * step], eta, eta)?;
// 10: end for
}
// 11: for i from 0 to k − 1 do
Expand All @@ -225,8 +197,8 @@ pub(crate) fn sk_decode<
ensure!(start + K * step == sk.len(), "Algorithm 19: asdf asdf ");

// Note spec is not consistent on the range constraints for s1 and s2; this is tighter
let s1_ok = s1.iter().all(|r| is_in_range(r, ETA as u32, ETA as u32));
let s2_ok = s2.iter().all(|r| is_in_range(r, ETA as u32, ETA as u32));
let s1_ok = s1.iter().all(|r| is_in_range(r, eta, eta));
let s2_ok = s2.iter().all(|r| is_in_range(r, eta, eta));
let t0_ok =
t0.iter().all(|r| is_in_range(r, 2u32.pow(D as u32 - 1) + 1, 2u32.pow(D as u32 - 1) + 1));
if s1_ok & s2_ok & t0_ok {
Expand All @@ -246,18 +218,17 @@ pub(crate) fn sk_decode<
/// # Errors
/// Returns an error ........ TKTK
pub(crate) fn sig_encode<
const GAMMA1: usize,
const K: usize,
const L: usize,
const LAMBDA: usize,
const OMEGA: usize,
const SIG_LEN: usize,
>(
c_tilde: &[u8], z: &[R; L], h: &[R; K],
gamma1: u32, c_tilde: &[u8], z: &[R; L], h: &[R; K],
) -> Result<[u8; SIG_LEN], &'static str> {
let mut sigma = [0u8; SIG_LEN];
ensure!(c_tilde.len() == 2 * LAMBDA / 8, "Algoirthm 20: asdf asdf");
let bl = bitlen(GAMMA1 - 1);
let bl = bitlen(gamma1 as usize - 1);
ensure!(sigma.len() == LAMBDA / 4 + L * 32 * (1 + bl) + OMEGA + K, "Algorithm 20: qwer qwer");

// 1: σ ← BitsToBytes(c_tilde)
Expand All @@ -269,8 +240,8 @@ pub(crate) fn sig_encode<
// 3: σ ← σ || BitPack (z[i], γ_1 − 1, γ_1)
bit_pack(
&z[i],
GAMMA1 as u32 - 1,
GAMMA1 as u32,
gamma1 - 1,
gamma1,
&mut sigma[start + i * step..start + (i + 1) * step],
)?;
// 4: end for
Expand All @@ -292,15 +263,14 @@ pub(crate) fn sig_encode<
/// Returns an error ........ TKTK
#[allow(clippy::type_complexity)]
pub(crate) fn sig_decode<
const GAMMA1: usize,
const K: usize,
const L: usize,
const LAMBDA: usize,
const OMEGA: usize,
>(
>(gamma1: u32,
sigma: &[u8],
) -> Result<([u8; 64], [R; L], Option<[R; K]>), &'static str> {
let bl = bitlen(GAMMA1 - 1);
let bl = bitlen(gamma1 as usize - 1);
// let mut c_tilde = vec![0u8; LAMBDA / 4];
let mut c_tilde = [0u8; 64]; // TODO: 'optimize'
let mut z: [R; L] = [R::zero(); L];
Expand All @@ -315,8 +285,8 @@ pub(crate) fn sig_decode<
// 4: z[i] ← BitUnpack(xi, γ1 − 1, γ1) ▷ This is always in the correct range, as γ1 is a power of 2
z[i] = bit_unpack(
&sigma[start + i * step..start + (i + 1) * step],
GAMMA1 as u32 - 1,
GAMMA1 as u32,
gamma1 - 1,
gamma1,
)?;
// 5: end for
}
Expand Down Expand Up @@ -424,8 +394,8 @@ mod tests {
get_vec(2u32.pow(11)),
];
//let mut sk = [0u8; 2560];
let sk = sk_encode::<13, 2, 4, 4, 2560>(&rho, &k, &tr, &s1, &s2, &t0).unwrap();
let res = sk_decode::<13, 2, 4, 4, 2560>(&sk);
let sk = sk_encode::<13, 4, 4, 2560>(2, &rho, &k, &tr, &s1, &s2, &t0).unwrap();
let res = sk_decode::<13, 4, 4, 2560>(2, &sk);
assert!(res.is_ok());
#[allow(clippy::similar_names)]
let (rho_test, k_test, tr_test, s1_test, s2_test, t0_test) = res.unwrap();
Expand All @@ -448,12 +418,12 @@ mod tests {
let h = [get_vec(1), get_vec(1), get_vec(1), get_vec(1)];
//let mut sigma = [0u8; 2420];
let sigma =
sig_encode::<{ 2usize.pow(17) }, 4, 4, 128, 80, 2420>(&c_tilde, &z, &h).unwrap();
sig_encode::<4, 4, 128, 80, 2420>(2u32.pow(17), &c_tilde, &z, &h).unwrap();
// let mut c_test = [0u8; 2 * 128 / 8];
// let mut z_test = [[0i32; 256]; 4];
// let mut h_test = [[0i32; 256]; 4];
let (c_test, z_test, h_test) =
sig_decode::<{ 2usize.pow(17) }, 4, 4, 128, 80>(&sigma).unwrap();
sig_decode::<4, 4, 128, 80>(2u32.pow(17), &sigma).unwrap();
// assert!(res.is_ok());
assert_eq!(c_tilde[0..8], c_test[0..8]);
assert_eq!(z, z_test);
Expand Down
Loading

0 comments on commit 74e5ec5

Please sign in to comment.