Skip to content

Commit

Permalink
cut stack usage
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Oct 22, 2024
1 parent abebea1 commit e9553c2
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 34 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.4.4 (2024-10-XX)

- Significantly reduced the required stack size
- Internal-only refactoring and polishing

## 0.4.3 (2024-10-16)

- Adapted ExpandedPrivateKey into PrivateKey and ExpandedPublicKey into PublicKey, and removed the former two types
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ workspace = { exclude = ["ct_cm4", "dudect", "fuzz", "wasm"] }

[package]
name = "fips204"
version = "0.4.3"
version = "0.4.4"
authors = ["Eric Schorn <[email protected]>"]
description = "FIPS 204: Module-Lattice-Based Digital Signature"
categories = ["cryptography", "no-std"]
Expand Down
21 changes: 13 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@


// TODO Roadmap
// 0. Code clean-up, more carefully shrink stack
// 1. Improve docs on first/last few algorithms
// 2. Several outstanding refactors (mostly down below in this file)
// 3. Always more testing...
// 2. Always more testing...


// Implements FIPS 204 Module-Lattice-Based Digital Signature Standard.
Expand Down Expand Up @@ -277,9 +277,11 @@ macro_rules! functionality {
use crate::high_low::power2round;
use crate::helpers::to_mont;
use crate::D;
use crate::hashing::expand_a;

// TODO: refactor
let PrivateKey {rho, cap_k: _, tr, s_hat_1_mont, s_hat_2_mont, t_hat_0_mont, cap_a_hat} = &self;
let PrivateKey {rho, cap_k: _, tr, s_hat_1_mont, s_hat_2_mont, t_hat_0_mont} = &self;
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(&rho);
let s_1: [R; L] = inv_ntt(&core::array::from_fn(|l| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_1_mont[l].0[n] as i64))))));
let s_1: [R; L] = core::array::from_fn(|l| R(core::array::from_fn(|n| if s_1[l].0[n] > (Q >> 2) {s_1[l].0[n] - Q} else {s_1[l].0[n]})));
let s_2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(s_hat_2_mont[k].0[n] as i64))))));
Expand Down Expand Up @@ -307,7 +309,8 @@ macro_rules! functionality {
let t1_d2_hat_mont: [T; K] = to_mont(&core::array::from_fn(|k| {
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
}));
let pk = PublicKey { rho: *rho, cap_a_hat: cap_a_hat.clone(), tr: *tr, t1_d2_hat_mont};
//let pk = PublicKey { rho: *rho, cap_a_hat: cap_a_hat.clone(), tr: *tr, t1_d2_hat_mont};
let pk = PublicKey { rho: *rho, tr: *tr, t1_d2_hat_mont};

// 10: return pk
pk
Expand All @@ -320,7 +323,7 @@ macro_rules! functionality {

// Algorithm 3 in Verifier trait.
fn verify(&self, message: &[u8], sig: &Self::Signature, ctx: &[u8]) -> bool {
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
BETA, GAMMA1, GAMMA2, OMEGA, TAU, &self, &message, &sig, ctx, &[], &[], false
) else {
return false;
Expand All @@ -335,7 +338,7 @@ macro_rules! functionality {
};
let mut phm = [0u8; 64]; // hashers don't all play well with each other
let (oid, phm_len) = hash_message(message, ph, &mut phm);
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
BETA, GAMMA1, GAMMA2, OMEGA, TAU, &self, &message, &sig, ctx, &oid, &phm[0..phm_len], false
) else {
return false;
Expand Down Expand Up @@ -396,9 +399,11 @@ macro_rules! functionality {
use crate::helpers::full_reduce32;
use crate::ntt::inv_ntt;
use crate::D;
use crate::hashing::expand_a;

// TODO: refactor
let PublicKey {rho, cap_a_hat, tr, t1_d2_hat_mont} = &self;
let PublicKey {rho, tr, t1_d2_hat_mont} = &self;
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(&rho);
let (_, _, _, _) = (rho, cap_a_hat, tr, t1_d2_hat_mont);
let t1_d2: [R; K] = inv_ntt(&core::array::from_fn(|k| T(core::array::from_fn(|n| full_reduce32(mont_reduce(t1_d2_hat_mont[k].0[n] as i64))))));
let t1: [R; K] = core::array::from_fn(|k| R(core::array::from_fn(|n| t1_d2[k].0[n] >> D)));
Expand Down Expand Up @@ -486,7 +491,7 @@ macro_rules! functionality {
if ctx.len() > 255 {
return false;
};
let Ok(res) = ml_dsa::verify::<K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
let Ok(res) = ml_dsa::verify::<false, K, L, LAMBDA_DIV4, PK_LEN, SIG_LEN, W1_LEN>(
BETA, GAMMA1, GAMMA2, OMEGA, TAU, pk, &message, &sig, ctx, &[], &[], true
) else {
return false;
Expand Down
55 changes: 32 additions & 23 deletions src/ml_dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ pub(crate) fn sign<

// Extract from expand_private()
let PrivateKey {
rho: _,
rho,
cap_k,
tr,
s_hat_1_mont,
s_hat_2_mont,
t_hat_0_mont,
cap_a_hat,
//cap_a_hat,
} = esk;

let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);

// 6: 𝜇 ← H(BytesToBits(𝑡𝑟)||𝑀 , 64) ▷ Compute message representative µ
// We may have arrived from 3 different paths
let mut h6 = if nist {
Expand Down Expand Up @@ -134,7 +136,7 @@ pub(crate) fn sign<
// 12: w ← NTT−1(cap_a_hat ◦ NTT(y))
let w: [R; K] = {
let y_hat: [T; L] = ntt(&y);
let ay_hat: [T; K] = mat_vec_mul(cap_a_hat, &y_hat);
let ay_hat: [T; K] = mat_vec_mul(&cap_a_hat, &y_hat);
inv_ntt(&ay_hat)
};

Expand Down Expand Up @@ -259,6 +261,7 @@ pub(crate) fn sign<
/// Continuation of `verify_start()`. The `lib.rs` wrapper around this will convert any `Err` to `false`.
#[allow(clippy::too_many_arguments, clippy::similar_names)]
pub(crate) fn verify<
const CTEST: bool,
const K: usize,
const L: usize,
const LAMBDA_DIV4: usize,
Expand All @@ -270,7 +273,8 @@ pub(crate) fn verify<
sig: &[u8; SIG_LEN], ctx: &[u8], oid: &[u8], phm: &[u8], nist: bool,
) -> Result<bool, &'static str> {
//
let PublicKey { rho: _, cap_a_hat, tr, t1_d2_hat_mont } = epk;
//let PublicKey { rho: _, cap_a_hat, tr, t1_d2_hat_mont } = epk;
let PublicKey { rho, tr, t1_d2_hat_mont } = epk;

// 1: (ρ, t_1) ← pkDecode(pk)
// --> calculated in expand_public()
Expand Down Expand Up @@ -314,8 +318,10 @@ pub(crate) fn verify<

// 9: w′_Approx ← invNTT(cap_A_hat ◦ NTT(z) - NTT(c) ◦ NTT(t_1 · 2^d) ▷ w′_Approx = Az − ct1·2^d
let wp_approx: [R; K] = {
// CTEST is forwarded as-is; callers in `lib.rs` pass `false` since everything is public here
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
let z_hat: [T; L] = ntt(&z);
let az_hat: [T; K] = mat_vec_mul(cap_a_hat, &z_hat);
let az_hat: [T; K] = mat_vec_mul(&cap_a_hat, &z_hat);
// NTT(t_1 · 2^d) --> calculated in expand_public()
let c_hat: &T = &ntt(&[c])[0];
inv_ntt(&core::array::from_fn(|k| {
Expand Down Expand Up @@ -378,22 +384,22 @@ pub(crate) fn key_gen_internal<

// There is effectively no step 2 due to a formatting error in the spec

// 3: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(&rho);

// 4: (s_1, s_2) ← ExpandS(ρ′)
let (s_1, s_2): ([R; L], [R; K]) = expand_s::<CTEST, K, L>(eta, &rho_prime);

// 3: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
// 5: t ← NTT−1(cap_a_hat ◦ NTT(s_1)) + s_2 ▷ Compute t = As1 + s2
//let t: [R; K]
let s_1_hat: [T; L] = ntt(&s_1);
let as1_hat: [T; K] = mat_vec_mul(&cap_a_hat, &s_1_hat);
let t_not_reduced: [R; K] = add_vector_ntt(&inv_ntt(&as1_hat), &s_2);
let t: [R; K] =
core::array::from_fn(|k| R(core::array::from_fn(|n| full_reduce32(t_not_reduced[k].0[n]))));

// 6: (t_1, t_0) ← Power2Round(t, d) ▷ Compress t
let (t_1, t_0): ([R; K], [R; K]) = power2round(&t);

let (t_1, t_0): ([R; K], [R; K]) = {
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(&rho);
let s_1_hat: [T; L] = ntt(&s_1);
let as1_hat: [T; K] = mat_vec_mul(&cap_a_hat, &s_1_hat);
let t_not_reduced: [R; K] = add_vector_ntt(&inv_ntt(&as1_hat), &s_2);
let t: [R; K] =
core::array::from_fn(|k| R(core::array::from_fn(|n| full_reduce32(t_not_reduced[k].0[n]))));
power2round(&t)
};

// There is effectively no step 7 due to a formatting error in the spec

Expand All @@ -414,10 +420,12 @@ pub(crate) fn key_gen_internal<
let t1_d2_hat_mont: [T; K] = to_mont(&core::array::from_fn(|k| {
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
}));
let pk = PublicKey { rho, cap_a_hat: cap_a_hat.clone(), tr, t1_d2_hat_mont };
//let pk = PublicKey { rho, cap_a_hat: cap_a_hat.clone(), tr, t1_d2_hat_mont };
let pk = PublicKey { rho, tr, t1_d2_hat_mont };

// 2: s_hat_1 ← NTT(s_1)
let s_hat_1_mont: [T; L] = to_mont(&s_1_hat); //ntt(&s_1));
//let s_hat_1_mont: [T; L] = to_mont(&s_1_hat); //ntt(&s_1));
let s_hat_1_mont: [T; L] = to_mont(&ntt(&s_1));
// 3: s_hat_2 ← NTT(s_2)
let s_hat_2_mont: [T; K] = to_mont(&ntt(&s_2));
// 4: t_hat_0 ← NTT(t_0)
Expand All @@ -429,7 +437,7 @@ pub(crate) fn key_gen_internal<
s_hat_1_mont,
s_hat_2_mont,
t_hat_0_mont,
cap_a_hat,
// cap_a_hat,
};

// 11: return (pk, sk)
Expand Down Expand Up @@ -463,7 +471,7 @@ pub(crate) fn expand_private<
let t_hat_0_mont: [T; K] = to_mont(&ntt(&t_0));

// 5: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as Â
let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);
//let cap_a_hat: [[T; L]; K] = expand_a::<CTEST, K, L>(rho);

Ok(PrivateKey {
rho: *rho,
Expand All @@ -472,7 +480,7 @@ pub(crate) fn expand_private<
s_hat_1_mont,
s_hat_2_mont,
t_hat_0_mont,
cap_a_hat,
//cap_a_hat,
})
}

Expand All @@ -489,7 +497,7 @@ pub(crate) fn expand_public<const K: usize, const L: usize, const PK_LEN: usize>
let (rho, t_1): (&[u8; 32], [R; K]) = pk_decode(pk)?;

// 5: cap_a_hat ← ExpandA(ρ) ▷ A is generated and stored in NTT representation as cap_A_hat
let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(rho);
//let cap_a_hat: [[T; L]; K] = expand_a::<false, K, L>(rho);

// 6: tr ← H(pk, 64)
let mut h6 = h256_xof(&[pk]);
Expand All @@ -503,5 +511,6 @@ pub(crate) fn expand_public<const K: usize, const L: usize, const PK_LEN: usize>
T(core::array::from_fn(|n| mont_reduce(i64::from(t1_hat_mont[k].0[n]) << D)))
}));

Ok(PublicKey { rho: *rho, cap_a_hat, tr, t1_d2_hat_mont })
//Ok(PublicKey { rho: *rho, cap_a_hat, tr, t1_d2_hat_mont })
Ok(PublicKey { rho: *rho, tr, t1_d2_hat_mont })
}
4 changes: 2 additions & 2 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub struct PrivateKey<const K: usize, const L: usize> {
pub(crate) s_hat_1_mont: [T; L],
pub(crate) s_hat_2_mont: [T; K],
pub(crate) t_hat_0_mont: [T; K],
pub(crate) cap_a_hat: [[T; L]; K],
// pub(crate) cap_a_hat: [[T; L]; K],
}


Expand All @@ -37,7 +37,7 @@ pub struct PrivateKey<const K: usize, const L: usize> {
#[repr(align(8))]
pub struct PublicKey<const K: usize, const L: usize> {
pub(crate) rho: [u8; 32],
pub(crate) cap_a_hat: [[T; L]; K],
// pub(crate) cap_a_hat: [[T; L]; K],
pub(crate) tr: [u8; 64],
pub(crate) t1_d2_hat_mont: [T; K],
}
Expand Down

0 comments on commit e9553c2

Please sign in to comment.