From 2fd143761db64a21dcd1b1efa2d88b0a77c320a4 Mon Sep 17 00:00:00 2001
From: Yu Jiang Tham <yoojbruin@gmail.com>
Date: Fri, 3 Jan 2025 21:18:49 -0500
Subject: [PATCH] Add bls12-381 pairing check hint

---
 .../pairing/guest/src/bls12_381/pairing.rs    | 75 +++++++++++++++++--
 .../pairing/guest/src/bls12_381/tests.rs      | 39 +++++++++-
 2 files changed, 108 insertions(+), 6 deletions(-)

diff --git a/extensions/pairing/guest/src/bls12_381/pairing.rs b/extensions/pairing/guest/src/bls12_381/pairing.rs
index 0d8c4cf7e4..7d27049f1f 100644
--- a/extensions/pairing/guest/src/bls12_381/pairing.rs
+++ b/extensions/pairing/guest/src/bls12_381/pairing.rs
@@ -1,9 +1,10 @@
 use alloc::vec::Vec;
 
 use itertools::izip;
+use num_bigint::BigUint;
 use openvm_algebra_guest::{
     field::{ComplexConjugate, FieldExtension},
-    DivUnsafe, Field,
+    DivUnsafe, ExpBytes, Field,
 };
 use openvm_ecc_guest::AffinePoint;
 #[cfg(target_os = "zkvm")]
@@ -16,9 +17,12 @@ use {
 };
 
 use super::{Bls12_381, Fp, Fp12, Fp2};
-use crate::pairing::{
-    Evaluatable, EvaluatedLine, FromLineMType, LineMulMType, MillerStep, MultiMillerLoop,
-    PairingCheck, PairingCheckError, PairingIntrinsics, UnevaluatedLine,
+use crate::{
+    curve_const::bls12_381::{FINAL_EXP_FACTOR, LAMBDA, POLY_FACTOR},
+    pairing::{
+        Evaluatable, EvaluatedLine, FromLineMType, LineMulMType, MillerStep, MultiMillerLoop,
+        PairingCheck, PairingCheckError, PairingIntrinsics, UnevaluatedLine,
+    },
 };
 
 // TODO[jpw]: make macro
@@ -275,7 +279,68 @@ impl PairingCheck for Bls12_381 {
     ) -> (Self::Fp12, Self::Fp12) {
         #[cfg(not(target_os = "zkvm"))]
         {
-            todo!()
+            let f = Self::multi_miller_loop(P, Q);
+
+            // 1. get p-th root inverse
+            let mut exp = FINAL_EXP_FACTOR.clone() * BigUint::from(27u32);
+            let mut root = f.exp_bytes(true, &exp.to_bytes_be());
+            let root_pth_inv: Fp12;
+            if root == Fp12::ONE {
+                root_pth_inv = Fp12::ONE;
+            } else {
+                let exp_inv = exp.modinv(&POLY_FACTOR.clone()).unwrap();
+                exp = exp_inv % POLY_FACTOR.clone();
+                root_pth_inv = root.exp_bytes(false, &exp.to_bytes_be());
+            }
+
+            // 2.1. get order of 3rd primitive root
+            let three = BigUint::from(3u32);
+            let mut order_3rd_power: u32 = 0;
+            exp = POLY_FACTOR.clone() * FINAL_EXP_FACTOR.clone();
+
+            root = f.exp_bytes(true, &exp.to_bytes_be());
+            let three_be = three.to_bytes_be();
+            // NOTE[yj]: we can probably remove this first check as an optimization since we initizlize order_3rd_power to 0
+            if root == Fp12::ONE {
+                order_3rd_power = 0;
+            }
+            root = root.exp_bytes(true, &three_be);
+            if root == Fp12::ONE {
+                order_3rd_power = 1;
+            }
+            root = root.exp_bytes(true, &three_be);
+            if root == Fp12::ONE {
+                order_3rd_power = 2;
+            }
+            root = root.exp_bytes(true, &three_be);
+            if root == Fp12::ONE {
+                order_3rd_power = 3;
+            }
+
+            // 2.2. get 27th root inverse
+            let root_27th_inv: Fp12;
+            if order_3rd_power == 0 {
+                root_27th_inv = Fp12::ONE;
+            } else {
+                let order_3rd = three.pow(order_3rd_power);
+                exp = POLY_FACTOR.clone() * FINAL_EXP_FACTOR.clone();
+                root = f.exp_bytes(true, &exp.to_bytes_be());
+                let exp_inv = exp.modinv(&order_3rd).unwrap();
+                exp = exp_inv % order_3rd;
+                root_27th_inv = root.exp_bytes(false, &exp.to_bytes_be());
+            }
+
+            // 2.3. shift the Miller loop result so that millerLoop * scalingFactor
+            // is of order finalExpFactor
+            let s = root_pth_inv * root_27th_inv;
+            let f = f * s.clone();
+
+            // 3. get the witness residue
+            // lambda = q - u, the optimal exponent
+            exp = LAMBDA.clone().modinv(&FINAL_EXP_FACTOR.clone()).unwrap();
+            let c = f.exp_bytes(true, &exp.to_bytes_be());
+
+            (c, s)
         }
         #[cfg(target_os = "zkvm")]
         {
diff --git a/extensions/pairing/guest/src/bls12_381/tests.rs b/extensions/pairing/guest/src/bls12_381/tests.rs
index 60963e19b1..1294bb0d7b 100644
--- a/extensions/pairing/guest/src/bls12_381/tests.rs
+++ b/extensions/pairing/guest/src/bls12_381/tests.rs
@@ -12,7 +12,8 @@ use super::{Fp, Fp12, Fp2};
 use crate::{
     bls12_381::{Bls12_381, G2Affine as OpenVmG2Affine},
     pairing::{
-        fp2_invert_assign, fp6_invert_assign, fp6_square_assign, MultiMillerLoop, PairingIntrinsics,
+        fp2_invert_assign, fp6_invert_assign, fp6_square_assign, FinalExp, MultiMillerLoop,
+        PairingCheck, PairingIntrinsics,
     },
 };
 
@@ -300,3 +301,39 @@ fn test_bls12381_g2_affine() {
         }
     }
 }
+
+#[test]
+fn test_bls12381_pairing_check_hint_host() {
+    let mut rng = StdRng::seed_from_u64(83);
+    let h2c_p = G1Affine::random(&mut rng);
+    let h2c_q = G2Affine::random(&mut rng);
+
+    let p = AffinePoint {
+        x: convert_bls12381_halo2_fq_to_fp(h2c_p.x),
+        y: convert_bls12381_halo2_fq_to_fp(h2c_p.y),
+    };
+    let q = AffinePoint {
+        x: convert_bls12381_halo2_fq2_to_fp2(h2c_q.x),
+        y: convert_bls12381_halo2_fq2_to_fp2(h2c_q.y),
+    };
+
+    let (c, s) = Bls12_381::pairing_check_hint(&[p], &[q]);
+
+    let p_cmp = AffinePoint {
+        x: h2c_p.x,
+        y: h2c_p.y,
+    };
+    let q_cmp = AffinePoint {
+        x: h2c_q.x,
+        y: h2c_q.y,
+    };
+
+    let f_cmp =
+        crate::halo2curves_shims::bls12_381::Bls12_381::multi_miller_loop(&[p_cmp], &[q_cmp]);
+    let (c_cmp, s_cmp) = crate::halo2curves_shims::bls12_381::Bls12_381::final_exp_hint(&f_cmp);
+    let c_cmp = convert_bls12381_halo2_fq12_to_fp12(c_cmp);
+    let s_cmp = convert_bls12381_halo2_fq12_to_fp12(s_cmp);
+
+    assert_eq!(c, c_cmp);
+    assert_eq!(s, s_cmp);
+}