Skip to content

Commit

Permalink
Merge branch 'wait_for_new_keccak' of github.com:formosa-crypto/formo…
Browse files Browse the repository at this point in the history
…sa-mlkem into wait_for_new_keccak
mbbarbosa-lectures committed Dec 20, 2024
2 parents 51ae90f + a9671d8 commit bf85cc7
Showing 3 changed files with 75 additions and 91 deletions.
5 changes: 4 additions & 1 deletion proof/correctness/Montgomery.ec
Original file line number Diff line number Diff line change
@@ -230,7 +230,10 @@ lemma SREDCp_corr (a : int):
proof.
move => [#] H H0 [#] H1 H2.
have H3 : (R * R %/ 4 = R %/ 2 * R %/2 ) by smt(dvd2R div_mulr).
have albnd : (- R * R %/4 <= a)by rewrite -mulN1r H3 mulN1r /#.
have albnd : (- R * R %/4 <= a).
rewrite -mulN1r H3 mulN1r; apply ltrW.
apply (ltr_le_trans (-(R%/2 * SignedReductions.q))) => //.
by apply ltr_opp2; rewrite div_mulr /#.
have aubnd : (a < R* R %/4) by smt(ler_pmul dvd2R div_mulr).
rewrite /SREDC /= (smod_div (a * qinv)).
move : (smod_bnd (a * qinv) R _ _); first 2 by smt(gt0_R dvd2R).
22 changes: 21 additions & 1 deletion proof/correctness/avx2/NTT_AVX_j.ec
Original file line number Diff line number Diff line change
@@ -1069,7 +1069,27 @@ lemma wmuls16P n x y _x _y:
Iu16_sb n x _x =>
Iu16_sb n y _y =>
sint32_bnd (-n*n*q*q) (n*n*q*q) (wmuls16 x y).
proof. by move => [??] [??]; rewrite to_sint_wmuls16 /#. qed.
proof.
have ->: - n * n * q * q = -(n * q)*(n*q) by ring.
have ->: n * n * q * q = ( n * q)*(n*q) by ring.
rewrite /Iu16_sb to_sint_wmuls16.
pose k:= n*q.
pose a:= to_sint x.
pose b:= to_sint y.
move=> [?[??]][?[??]].
case: (0 <= a); case: (0 <= b) => C1 C2.
+ smt(ler_pmul).
+ pose bb := -b.
have ?: - (k*k) <= a*bb && a*bb <= k*k;
smt(ler_opp2 ler_pmul).
+ pose aa := -a.
have ?: - (k*k) <= aa*b && aa*b <= k*k;
smt(ler_opp2 ler_pmul).
+ pose aa := -a.
pose bb := -b.
have ?: - (k*k) <= aa*bb && aa*bb <= k*k;
smt(ler_opp2 ler_pmul).
qed.
phoare wmul_16u16_ph n _x _y:
[Jkem_avx2.M(Jkem_avx2.Syscall).__wmul_16u16:
139 changes: 50 additions & 89 deletions proof/eclib/bindings.ec
Original file line number Diff line number Diff line change
@@ -7,6 +7,22 @@ import BitEncoding BS2Int BitChunking.

require import JWord_extra.


(*[size_flatten] (for uniform inner lists) *)
lemma size_flatten' ['a] sz (ss: 'a list list):
(forall x, x\in ss => size x = sz) =>
size (flatten ss) = sz*size ss.
proof.
move=> H; rewrite size_flatten.
rewrite StdBigop.Bigint.sumzE.
rewrite StdBigop.Bigint.BIA.big_map.
rewrite -(StdBigop.Bigint.BIA.eq_big_seq (fun _ => sz)) /=.
by move=> x Hx; rewrite /(\o) /= H.
by rewrite StdBigop.Bigint.big_constz count_predT.
qed.
(* ----------- BEGIN BOOL BINDINGS ---------- *)
op bool2bits (b : bool) : bool list = [b].
op bits2bool (b: bool list) : bool = List.nth false b 0.
@@ -176,54 +192,31 @@ bind op W16.t W16_sub "sub".
realize bvsubP by rewrite /W16_sub => bv1 bv2; rewrite W16.to_uintD to_uintN /= /#.
op sll_16 (w1 w2 : W16.t) : W16.t =
if (16 <= to_uint w2) then W16.zero else w1 `<<` (truncateu8 w2).
w1 `<<<` to_uint w2.
bind op [W16.t] sll_16 "shl".
realize bvshlP.
rewrite /sll_16 => bv1 bv2.
case : (16 <= to_uint bv2); last first.
+ rewrite /(`<<`) W16.to_uint_shl; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _).
smt(W16.to_uint_cmp).
rewrite (pmod_small (to_uint bv2) _).
smt(W16.to_uint_cmp).
done.
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring.
by rewrite exprD_nneg 1,2:/# /= /#.
rewrite /shl_16 => bv1 bv2.
by rewrite W16.to_uint_shl; 1:smt(W16.to_uint_cmp).
qed.

op sra_16 (w1 w2 : W16.t) : W16.t =
W16.sar w1 (to_uint w2).
(*
if (16 <= to_uint w2) then W16.zero else w1 `|>>` (truncateu8 w2).
*)
bind op [W16.t] sra_16 "ashr".
realize bvashrP.
move=> bv1 bv2; rewrite W16_sar_div; smt(W16.to_uint_cmp).
qed.
op srl_16 (w1 w2 : W16.t) : W16.t =
if 16 <= (to_uint w2) then W16.zero else
w1 `>>` (truncateu8 w2).
w1 `>>>` W16.to_uint w2.
bind op [W16.t] srl_16 "shr".
realize bvshrP.
rewrite /srl_16 => bv1 bv2.
case : (16 <= to_uint bv2); last first.
+ rewrite /(`>>`) W16.to_uint_shr; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _); smt(W16.to_uint_cmp).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring.
rewrite exprD_nneg 1,2:/# /=.
smt(StdOrder.IntOrder.expr_gt0 W16.to_uint_cmp pow2_16).
rewrite /shr_16 => bv1 bv2.
by rewrite W16.to_uint_shr; 1:smt(W16.to_uint_cmp).
qed.

bind op [W16.t & W8.t] W2u8.truncateu8 "truncate".
realize bvtruncateP.
move => mv; rewrite /truncateu8 /W16.w2bits take_mkseq 1:// /= /w2bits.
@@ -243,6 +236,11 @@ by have -> : (2 ^ (8 - i) * 2 ^ i) = 256;
1,2:/# /= -!addrA /= | done ].
qed.
bind op [W8.t & W16.t] W2u8.zeroextu16 "zextend".
realize bvzextendP.
move => bv; rewrite /zeroextu16 /= of_uintK /= modz_small 2://.
apply bound_abs; smt(W8.to_uint_cmp pow2_8).
qed.
(* ----------- BEGIN W32 BINDINGS ---------- *)
@@ -295,50 +293,34 @@ by rewrite !nth_zip /=;1:smt(W32.size_w2bits).
qed.
op sll_32 (w1 w2 : W32.t) : W32.t =
if (32 <= to_uint w2)
then W32.zero
else w1 `<<` (truncateu8 w2).
w1 `<<<` to_uint w2.
bind op [W32.t] sll_32 "shl".
realize bvshlP.
rewrite /sll_32 => bv1 bv2.
case : (32 <= to_uint bv2); last first.
+ rewrite /(`<<`) W32.to_uint_shl; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _); smt(W32.to_uint_cmp).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 32) + 32 by ring.
by rewrite exprD_nneg 1,2:/# /= /#.
rewrite /shl_32 => bv1 bv2.
by rewrite W32.to_uint_shl; 1:smt(W32.to_uint_cmp).
qed.
op srl_32 (w1 w2 : W32.t) : W32.t =
if 32 <= (to_uint w2) then W32.zero else
w1 `>>` (truncateu8 w2).
w1 `>>>` W32.to_uint w2.
bind op [W32.t] srl_32 "shr".
realize bvshrP.
rewrite /srl_32 => bv1 bv2.
case : (32 <= to_uint bv2); last first.
+ rewrite /(`>>`) W32.to_uint_shr; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _); smt(W32.to_uint_cmp).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 32) + 32 by ring.
rewrite exprD_nneg 1,2:/# /=.
smt(StdOrder.IntOrder.expr_gt0 W32.to_uint_cmp pow2_32).
rewrite /shr_32 => bv1 bv2.
by rewrite W32.to_uint_shr; 1:smt(W32.to_uint_cmp).
qed.
op sra_32 (w1 w2 : W32.t) : W32.t =
W32.sar w1 (to_uint w2).
(*
if (32 <= to_uint w2) then W32.zero else w1 `|>>` (truncateu8 w2).
*)
bind op [W32.t] sra_32 "ashr".
realize bvashrP.
move => bv1 bv2; rewrite W32_sar_div; smt(W32.to_uint_cmp).
qed.
bind op [W8.t & W32.t] W4u8.zeroextu32 "zextend".
realize bvzextendP by move => bv; rewrite /zeroextu32 /= of_uintK /=; smt(W8.to_uint_cmp pow2_8).
bind op [W16.t & W32.t] W2u16.zeroextu32 "zextend".
realize bvzextendP by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W16.to_uint_cmp pow2_16).
@@ -406,19 +388,6 @@ by smt(bs2int_range mem_range W64.size_w2bits pow2_64).
qed.
realize touintP by smt().
(*[size_flatten] (for uniform inner lists) *)
lemma size_flatten' ['a] sz (ss: 'a list list):
(forall x, x\in ss => size x = sz) =>
size (flatten ss) = sz*size ss.
proof.
move=> H; rewrite size_flatten.
rewrite StdBigop.Bigint.sumzE.
rewrite StdBigop.Bigint.BIA.big_map.
rewrite -(StdBigop.Bigint.BIA.eq_big_seq (fun _ => sz)) /=.
by move=> x Hx; rewrite /(\o) /= H.
by rewrite StdBigop.Bigint.big_constz count_predT.
qed.
bind op [bool & W64.t] W64.init "init".
realize bvinitP.
move=> f; apply (eq_from_nth false).
@@ -479,36 +448,21 @@ qed.


op srl_64 (w1 w2 : W64.t) : W64.t =
if (64 <= to_uint w2) then W64.zero else w1 `>>` (truncateu8 w2).
w1 `>>>` to_uint w2.

bind op [W64.t] srl_64 "shr".
realize bvshrP.
rewrite /srl_64 => bv1 bv2.
case : (64 <= to_uint bv2); last first.
+ rewrite /(`>>`) W64.to_uint_shr; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _); smt(W64.to_uint_cmp).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring.
rewrite exprD_nneg 1,2:/# /=.
smt(StdOrder.IntOrder.expr_gt0 W64.to_uint_cmp pow2_64).
rewrite /shr_64 => bv1 bv2.
by rewrite W64.to_uint_shr; 1:smt(W64.to_uint_cmp).
qed.


op sll_64 (w1 w2 : W64.t) : W64.t =
if (64 <= to_uint w2) then W64.zero else w1 `<<` (truncateu8 w2).
w1 `<<<` to_uint w2.

bind op [W64.t] sll_64 "shl".
realize bvshlP.
proof.
rewrite /sll_64 => bv1 bv2.
case : (64 <= to_uint bv2); last first.
+ rewrite /(`<<`) W64.to_uint_shl; 1: by smt(W8.to_uint_cmp).
rewrite /truncateu8 => bv2bnd />.
rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp).
move => *.
have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring.
by rewrite exprD_nneg 1,2:/# /= /#.
realize bvshlP.
rewrite /shl_64 => bv1 bv2.
by rewrite W64.to_uint_shl; 1:smt(W64.to_uint_cmp).
qed.

op rol_64 (w1 w2 : W64.t): W64.t =
@@ -520,11 +474,18 @@ rewrite /rol_64=> bv1 bv2 i Hi.
by rewrite !get_w2bits rolE initiE.
qed.

bind op [W8.t & W64.t] W8u8.zeroextu64 "zextend".
realize bvzextendP
by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W8.to_uint_cmp pow2_8).

bind op [W16.t & W64.t] W4u16.zeroextu64 "zextend".
realize bvzextendP
by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W16.to_uint_cmp pow2_16).

bind op [W32.t & W64.t] W2u32.zeroextu64 "zextend".
realize bvzextendP
by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W32.to_uint_cmp pow2_32).

bind op [W64.t & W16.t] W4u16.truncateu16 "truncate".
realize bvtruncateP.
move => mv; rewrite /truncateu16 /W64.w2bits take_mkseq 1:// /= /w2bits.

0 comments on commit bf85cc7

Please sign in to comment.