From 0b2a08ec7995b3480b63ebb29681789c4b4e325c Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Sun, 24 Nov 2024 00:37:07 +0300 Subject: [PATCH 01/18] fix: remove changelog from cabal --- zkfold-prover.cabal | 1 - 1 file changed, 1 deletion(-) diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index 747e41e..e26123f 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -7,7 +7,6 @@ author: ZkFold maintainer: info@zkfold.io category: Network build-type: Custom -extra-doc-files: CHANGELOG.md custom-setup setup-depends: From e6f4011d97b4d12788ece1c04c058051f812d85d Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Sun, 24 Nov 2024 00:42:15 +0300 Subject: [PATCH 02/18] fix: rename to `zkfold-prover` --- zkfold-prover.cabal | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index e26123f..c5bd5e9 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -1,5 +1,5 @@ cabal-version: 3.0 -name: haskell-wrapper +name: zkfold-prover version: 0.1.0.0 license: BSD-3-Clause license-file: LICENSE From 6f35322e907648a9a4c314a54a9af358421668f1 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Sun, 24 Nov 2024 01:09:59 +0300 Subject: [PATCH 03/18] fix: cabal --- zkfold-prover.cabal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index c5bd5e9..1c9873d 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -104,7 +104,7 @@ test-suite wrapper-test build-depends: , base >= 4.16 , zkfold-base - , haskell-wrapper + , zkfold-prover , hspec < 2.12 , QuickCheck , vector @@ -124,7 +124,7 @@ benchmark msm vector, zkfold-base, bytestring, - haskell-wrapper + zkfold-prover benchmark prove import: options @@ -141,7 +141,7 @@ benchmark prove vector, zkfold-base, bytestring, - haskell-wrapper + zkfold-prover benchmark poly-mul import: options @@ -158,6 +158,6 @@ benchmark poly-mul vector, zkfold-base, bytestring, - haskell-wrapper, + zkfold-prover, random, deepseq \ No newline at end of file From b3506a418458eb898f0e19eab565c3c3cf090bb5 Mon Sep 17 00:00:00 2001 From: Evgeniy Samodelov Date: Fri, 29 Nov 2024 12:52:29 +0300 Subject: [PATCH 04/18] fix: full path to Cargo.toml --- Setup.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Setup.hs b/Setup.hs index 542afbf..cd5d13a 100644 --- a/Setup.hs +++ b/Setup.hs @@ -16,7 +16,9 @@ main = defaultMainWithHooks simpleUserHooks buildRustLib :: Args -> a -> IO HookedBuildInfo buildRustLib _ flags = do - buildResult <- system "cargo +nightly cbuild --release --manifest-path rust-wrapper/Cargo.toml" + dir <- getCurrentDirectory + + buildResult <- system ("cargo +nightly cbuild --release --manifest-path " ++ dir ++ "/rust-wrapper/Cargo.toml") case buildResult of ExitSuccess -> return () ExitFailure exitCode -> throwIO $ userError $ "Build rust library failed with exit code " <> show exitCode From 46a31d1836b07197946660342b5e720949a44609 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Fri, 13 Dec 2024 18:51:48 +0300 Subject: [PATCH 05/18] fix: update to symbolic-base --- bench/BenchMSM.hs | 4 ++-- bench/BenchProve.hs | 13 +++++++++---- cabal.project | 3 ++- haskell-wrapper/src/RustFunctions.hs | 3 ++- zkfold-prover.cabal | 10 +++++----- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/bench/BenchMSM.hs b/bench/BenchMSM.hs index 42f024d..212601d 100644 --- a/bench/BenchMSM.hs +++ b/bench/BenchMSM.hs @@ -24,8 +24,8 @@ main = do (Vector s) <- generate arbitrary :: IO (Vector Length (ScalarField BLS12_381_G1)) let - points = V.fromList p - scalars = toPolyVec @(ScalarField BLS12_381_G1) @Length $ V.fromList s + points = p + scalars = toPolyVec @(ScalarField BLS12_381_G1) @Length s defaultMain [ diff --git a/bench/BenchProve.hs b/bench/BenchProve.hs index b2394f5..e22a04c 100644 --- a/bench/BenchProve.hs +++ b/bench/BenchProve.hs @@ -1,21 +1,26 @@ +{-# OPTIONS_GHC -Wno-orphans #-} module Main where import qualified Data.ByteString as BS +import GHC.Generics (U1 (U1)) import Prelude hiding (Num (..), length, sum, take, (-)) import RustFunctions (RustCore) import Test.QuickCheck (Arbitrary (arbitrary), generate) import Test.Tasty.Bench import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) -import ZkFold.Base.Protocol.ARK.Plonk (Plonk) import ZkFold.Base.Protocol.NonInteractiveProof +import ZkFold.Base.Protocol.Plonk (Plonk) -type PlonkSizeBS = 128 -type PlonkBS n = Plonk PlonkSizeBS n BLS12_381_G1 BLS12_381_G2 BS.ByteString +type PlonkBS n = Plonk U1 1 32 n BLS12_381_G1 BLS12_381_G2 BS.ByteString + +instance Arbitrary (U1 a) where + arbitrary = return U1 main :: IO () main = do - (TestData a w) <- generate arbitrary :: IO (NonInteractiveProofTestData (PlonkBS 2) HaskellCore) + a <- generate arbitrary :: IO (PlonkBS 2) + w <- generate arbitrary :: IO (Witness (PlonkBS 2)) let spHaskell = setupProve @(PlonkBS 2) @HaskellCore a spRust = setupProve @(PlonkBS 2) @RustCore a diff --git a/cabal.project b/cabal.project index d5dbc56..ee30590 100644 --- a/cabal.project +++ b/cabal.project @@ -8,4 +8,5 @@ source-repository-package source-repository-package type: git location: https://github.com/zkFold/zkfold-base.git - tag: 204a983ce39dd683c1776a0c54a6fb02e53305f6 + tag: e34d8999d45761481e69deac17bbd550a36acf2b + subdir: symbolic-base \ No newline at end of file diff --git a/haskell-wrapper/src/RustFunctions.hs b/haskell-wrapper/src/RustFunctions.hs index 30dbe7c..1a0d7d4 100644 --- a/haskell-wrapper/src/RustFunctions.hs +++ b/haskell-wrapper/src/RustFunctions.hs @@ -28,7 +28,7 @@ import System.Posix.DynamicLinker import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Algebra.EllipticCurve.Class -import ZkFold.Base.Algebra.Polynomials.Univariate (fromPolyVec) +import ZkFold.Base.Algebra.Polynomials.Univariate (fromPoly, fromPolyVec, toPoly) import ZkFold.Base.Data.ByteString import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction (..), msm) @@ -181,6 +181,7 @@ instance CoreFunction BLS12_381_G1 RustCore where (a:rs1, b:rs2) zipAndUnzip _ _ = ([],[]) + polyMul x y = toPoly (rustMulFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y)) rustMulFft :: forall f . Storable f => V.Vector f -> V.Vector f -> V.Vector f rustMulFft l r = if lByteLength * rByteLength == 0 then V.empty else unsafePerformIO runFFT diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index 1c9873d..3f590db 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -86,7 +86,7 @@ library build-depends: , base >= 4.16 , bytestring - , zkfold-base + , symbolic-base , vector , binary , unix @@ -103,7 +103,7 @@ test-suite wrapper-test FFT build-depends: , base >= 4.16 - , zkfold-base + , symbolic-base , zkfold-prover , hspec < 2.12 , QuickCheck @@ -122,7 +122,7 @@ benchmark msm tasty-bench, QuickCheck, vector, - zkfold-base, + symbolic-base, bytestring, zkfold-prover @@ -139,7 +139,7 @@ benchmark prove tasty-bench, QuickCheck, vector, - zkfold-base, + symbolic-base, bytestring, zkfold-prover @@ -156,7 +156,7 @@ benchmark poly-mul tasty-bench, QuickCheck, vector, - zkfold-base, + symbolic-base, bytestring, zkfold-prover, random, From a3c14178a9c955f0d274fb0d4cf6a1ef4b836b99 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 11:20:59 +0300 Subject: [PATCH 06/18] feat: add rust polynomial division --- Setup.hs | 42 +++++++----- bench/BenchPolyDiv.hs | 51 +++++++++++++++ bench/BenchProve.hs | 7 +- cabal.project | 7 +- haskell-wrapper/src/RustFunctions.hs | 95 +++++++++++++++++++++++----- haskell-wrapper/tests/FFT.hs | 30 +++++++-- rust-wrapper/Cargo.toml | 25 +++++--- rust-wrapper/benches/div_bench.rs | 48 ++++++++++++++ rust-wrapper/src/fft.rs | 56 ++++++++++++++++ zkfold-prover.cabal | 23 ++++++- 10 files changed, 328 insertions(+), 56 deletions(-) create mode 100644 bench/BenchPolyDiv.hs create mode 100644 rust-wrapper/benches/div_bench.rs diff --git a/Setup.hs b/Setup.hs index cd5d13a..d71f16b 100644 --- a/Setup.hs +++ b/Setup.hs @@ -1,38 +1,48 @@ +{-# LANGUAGE TemplateHaskell #-} import Control.Exception (throwIO) import Control.Monad import Data.Char (isSpace) -import Data.List (dropWhile, isPrefixOf) +import Data.Maybe (fromJust) +import Data.List (dropWhile, isPrefixOf, tails, findIndex, find) import Distribution.Simple import Distribution.Types.HookedBuildInfo +import PseudoMacros import System.Directory import System.Exit import System.Process (readProcess, system) main :: IO () main = defaultMainWithHooks simpleUserHooks - { preConf = buildRustLib + { + preConf = buildRustLib } buildRustLib :: Args -> a -> IO HookedBuildInfo buildRustLib _ flags = do - dir <- getCurrentDirectory + let file = $__FILE__ + let pathToDistNewstyle = take (fromJust $ findIndex (isPrefixOf "dist-newstyle") (tails file)) file - buildResult <- system ("cargo +nightly cbuild --release --manifest-path " ++ dir ++ "/rust-wrapper/Cargo.toml") - case buildResult of - ExitSuccess -> return () - ExitFailure exitCode -> throwIO $ userError $ "Build rust library failed with exit code " <> show exitCode + isNotDependency <- doesFileExist (pathToDistNewstyle ++ "rust-wrapper/Cargo.toml") + + pathToRustWrapper <- if isNotDependency + then return pathToDistNewstyle + else do + contents <- listDirectory (pathToDistNewstyle ++ "src/") + let depLib = fromJust $ find (isPrefixOf "zkfold-pr") contents + + return $ pathToDistNewstyle ++ "src/" ++ depLib - output <- readProcess "rustc" ["--version", "--verbose"] "" - case filter ("host: " `isPrefixOf`) (lines output) of - [line] -> do - let host = dropWhile isSpace $ drop 5 line - pathToLib = "rust-wrapper/target/" <> host <> "/release/librust_wrapper.so" + putStrLn $ pathToRustWrapper - libExist <- doesFileExist pathToLib - unless libExist $ throwIO $ userError "Can't find rust library" + buildResult <- system ("cargo +nightly build --release " ++ + "--manifest-path " ++ pathToRustWrapper ++ "rust-wrapper/Cargo.toml " ++ + "--artifact-dir=" ++ pathToRustWrapper ++ "libs/ -Z unstable-options" + ) - copyFile pathToLib "./lib.so" + case buildResult of + ExitSuccess -> return () + ExitFailure exitCode -> do + throwIO $ userError $ "Build rust library failed with exit code " <> show exitCode - _ -> throwIO $ userError "Can't find default rust target" return emptyHookedBuildInfo diff --git a/bench/BenchPolyDiv.hs b/bench/BenchPolyDiv.hs new file mode 100644 index 0000000..64fd122 --- /dev/null +++ b/bench/BenchPolyDiv.hs @@ -0,0 +1,51 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Control.DeepSeq (force) +import Control.Exception (evaluate) +import Control.Monad (replicateM) +import Data.Tuple.Extra +import qualified Data.Vector as V +import Foreign +import Prelude hiding (sum, (*), (+), (-), (/), (^)) +import qualified Prelude as P +import RustFunctions (rustDivFft) +import System.Random (randomIO) +import Test.Tasty.Bench + +import ZkFold.Base.Algebra.Basic.Class +import ZkFold.Base.Algebra.Basic.Field +import ZkFold.Base.Algebra.Basic.Number (Prime) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 +import ZkFold.Base.Algebra.Polynomials.Univariate +-- | Generate random polynomials of given size +-- +polynomials :: forall a. Prime a => Int -> IO (Poly (Zp a), Poly (Zp a)) +polynomials size = do + coeffs1 <- replicateM size (toZp @a <$> randomIO) + coeffs2 <- replicateM size (toZp @a <$> randomIO) + evaluatedCoeffs1 <- evaluate . force . V.fromList $ coeffs1 + evaluatedCoeffs2 <- evaluate . force . V.fromList $ coeffs2 + pure (toPoly evaluatedCoeffs1, toPoly evaluatedCoeffs2) + +sizes :: [Int] +sizes = ((2 :: Int) P.^) <$> [10 .. 14 :: Int] + +ops :: forall a . (Eq a, Field a, Storable a) => [(String, Poly a -> Poly a -> (Poly a, Poly a))] +ops = [ ("Haskell division", qr) + , ("Rust division", \x y -> both toPoly (rustDivFft @a (fromPoly x) (fromPoly y))) + ] + +benchOps :: Prime a => Int -> [(String, Poly (Zp a) -> Poly (Zp a) -> (Poly (Zp a), Poly (Zp a)) )] -> Benchmark +benchOps size testOps = env (polynomials size) $ \ ~(p1, p2) -> + bgroup ("Multiplying polynomials of size " <> show size) $ + flip fmap testOps $ \(desc, op) -> bench desc $ nf (uncurry op) (p1, p2) + +main :: IO () +main = do + defaultMain + [ bgroup "Field without roots of unity" $ flip fmap sizes $ \s -> benchOps @BLS12_381_Base s ops + , bgroup "Field with roots of unity" $ flip fmap sizes $ \s -> benchOps @BLS12_381_Scalar s ops + ] diff --git a/bench/BenchProve.hs b/bench/BenchProve.hs index e22a04c..e6bb7a4 100644 --- a/bench/BenchProve.hs +++ b/bench/BenchProve.hs @@ -6,17 +6,22 @@ import GHC.Generics (U1 (U1)) import Prelude hiding (Num (..), length, sum, take, (-)) import RustFunctions (RustCore) import Test.QuickCheck (Arbitrary (arbitrary), generate) +import Test.QuickCheck.Arbitrary (Arbitrary1 (liftArbitrary)) import Test.Tasty.Bench import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_G2) +import ZkFold.Base.Data.Vector (Vector) import ZkFold.Base.Protocol.NonInteractiveProof import ZkFold.Base.Protocol.Plonk (Plonk) -type PlonkBS n = Plonk U1 1 32 n BLS12_381_G1 BLS12_381_G2 BS.ByteString +type PlonkBS n = Plonk U1 (Vector 1) 32 (Vector n) BLS12_381_G1 BLS12_381_G2 BS.ByteString instance Arbitrary (U1 a) where arbitrary = return U1 +instance Arbitrary1 U1 where + liftArbitrary _ = return U1 + main :: IO () main = do a <- generate arbitrary :: IO (PlonkBS 2) diff --git a/cabal.project b/cabal.project index ee30590..93f7aa1 100644 --- a/cabal.project +++ b/cabal.project @@ -1,12 +1,7 @@ packages: . -source-repository-package - type: git - location: https://github.com/BeFunctional/haskell-foreign-rust.git - tag: 90b1c210ae4e753c39481a5f3b141b74e6b6d96e - source-repository-package type: git location: https://github.com/zkFold/zkfold-base.git - tag: e34d8999d45761481e69deac17bbd550a36acf2b + tag: 544062ce7bb41da6f0e9fb58de53eb3377f9df22 subdir: symbolic-base \ No newline at end of file diff --git a/haskell-wrapper/src/RustFunctions.hs b/haskell-wrapper/src/RustFunctions.hs index 1a0d7d4..966023a 100644 --- a/haskell-wrapper/src/RustFunctions.hs +++ b/haskell-wrapper/src/RustFunctions.hs @@ -7,6 +7,7 @@ module RustFunctions ( rustMultiScalarMultiplicationWithoutSerialization , rustMulFft + , rustDivFft , RustCore ) where @@ -22,9 +23,10 @@ import GHC.Num.Integer (integerToInt#) import GHC.Num.Natural (naturalFromAddr, naturalToAddr) import GHC.Ptr (Ptr (..)) import GHC.TypeNats (KnownNat) -import Prelude hiding (sum) +import Prelude hiding (rem, sum) import System.Posix.DynamicLinker +import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid (zero)) import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Algebra.EllipticCurve.Class @@ -33,13 +35,19 @@ import ZkFold.Base.Data.ByteString import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction (..), msm) libPath :: FilePath -libPath = "./lib.so" +libPath = "libs/librust_wrapper.so" -type FunFFT = +type FunMulFFT = + CString -> Int -> CString -> Int -> Int -> CString -> IO () + +type FunDivFFT = CString -> Int -> CString -> Int -> Int -> CString -> IO () foreign import ccall "dynamic" - mkFunFFT :: FunPtr FunFFT -> FunFFT + mkFunMulFFT :: FunPtr FunMulFFT -> FunMulFFT + +foreign import ccall "dynamic" + mkFunDivFFT :: FunPtr FunDivFFT -> FunDivFFT type FunMSM = CString -> Int -> CString -> Int -> Int -> CString -> IO () @@ -143,27 +151,27 @@ instance Storable Fq where infByteStringRepr :: [Word8] infByteStringRepr = replicate 47 0 <> (bit 6 : replicate 48 0) -instance Storable (Point BLS12_381_G1) where +instance forall f . (ScalarField f ~ Zp BLS12_381_Scalar, BaseField f ~ Zp BLS12_381_Base, BooleanOf f ~ Bool) => Storable (Point f) where - sizeOf :: Point BLS12_381_G1 -> Int + sizeOf :: Point f -> Int sizeOf _ = 96 - alignment :: Point BLS12_381_G1 -> Int + alignment :: Point f -> Int alignment _ = alignment @Fq undefined - peek :: Ptr (Point BLS12_381_G1) -> IO (Point BLS12_381_G1) + peek :: Ptr (Point f) -> IO (Point f) peek ptr = do - a <- BS.packCStringLen (castPtr ptr, sizeOf @(Point BLS12_381_G1) undefined) + a <- BS.packCStringLen (castPtr ptr, sizeOf @(Point f) undefined) if BS.pack infByteStringRepr == a - then return Inf + then return $ Point zero zero True else do x <- peek @Fq (castPtr ptr) y <- peek @Fq (ptr `plusPtr` sizeOf @Fq undefined) - return $ Point x y + return $ Point x y False - poke :: Ptr (Point BLS12_381_G1) -> Point BLS12_381_G1 -> IO () - poke ptr Inf = pokeArray (castPtr ptr) infByteStringRepr - poke ptr (Point x y) = do + poke :: Ptr (Point f) -> Point f -> IO () + poke ptr (Point _ _ True) = pokeArray (castPtr ptr) infByteStringRepr + poke ptr (Point x y False) = do poke (castPtr ptr) x poke (castPtr ptr `plusPtr` sizeOf @Fq undefined) y @@ -182,6 +190,10 @@ instance CoreFunction BLS12_381_G1 RustCore where zipAndUnzip _ _ = ([],[]) polyMul x y = toPoly (rustMulFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y)) + polyQr x y = both toPoly $ rustDivFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y) + +both :: (t -> b) -> (t, t) -> (b, b) +both f (x, y) = (f x, f y) rustMulFft :: forall f . Storable f => V.Vector f -> V.Vector f -> V.Vector f rustMulFft l r = if lByteLength * rByteLength == 0 then V.empty else unsafePerformIO runFFT @@ -195,7 +207,7 @@ rustMulFft l r = if lByteLength * rByteLength == 0 then V.empty else unsafePerfo runFFT = do dl <- dlopen libPath [RTLD_NOW] fftPtr <- dlsym dl "rust_wrapper_mul_fft" - let !fft = mkFunFFT $ castFunPtr fftPtr + let !fft = mkFunMulFFT $ castFunPtr fftPtr ptrL <- callocBytes @f lByteLength ptrR <- callocBytes @f rByteLength @@ -210,6 +222,57 @@ rustMulFft l r = if lByteLength * rByteLength == 0 then V.empty else unsafePerfo (castPtr ptrL) lByteLength (castPtr ptrR) rByteLength outLen out + free ptrL + free ptrR dlclose dl - V.fromList <$> peekArray @f (V.length l + V.length r - 1) (castPtr out) + !res <- V.fromList <$> peekArray @f (V.length l + V.length r - 1) (castPtr out) + + free out + + return res + +-- Should be without leading zeroes +rustDivFft :: forall f . Storable f => V.Vector f -> V.Vector f -> (V.Vector f, V.Vector f) +rustDivFft l r | rByteLength == 0 = error "Polynomial division by zero" + | lByteLength < rByteLength = (V.empty, l) + | otherwise = unsafePerformIO runFFT + where + scalarSize = sizeOf (undefined :: f) + + lByteLength = scalarSize * V.length l + rByteLength = scalarSize * V.length r + + runFFT :: IO (V.Vector f, V.Vector f) + runFFT = do + dl <- dlopen libPath [RTLD_NOW] + fftPtr <- dlsym dl "rust_wrapper_div_fft" + let !fft = mkFunDivFFT $ castFunPtr fftPtr + + ptrL <- callocBytes @f lByteLength + ptrR <- callocBytes @f rByteLength + + pokeArray ptrL (V.toList l) + pokeArray ptrR (V.toList r) + + let outLen = (V.length l + 1) * scalarSize + out <- callocBytes outLen + + !_ <- fft + (castPtr ptrL) lByteLength + (castPtr ptrR) rByteLength + outLen out + + free ptrL + free ptrR + dlclose dl + + + + !quo <- V.fromList <$> peekArray @f (V.length l - V.length r + 1) (castPtr out) + !rem <- V.fromList <$> peekArray @f (V.length r) (castPtr out `plusPtr` ((V.length l - V.length r + 1 ) * scalarSize)) + + + free out + + return (quo, rem) diff --git a/haskell-wrapper/tests/FFT.hs b/haskell-wrapper/tests/FFT.hs index 2fed926..c495166 100644 --- a/haskell-wrapper/tests/FFT.hs +++ b/haskell-wrapper/tests/FFT.hs @@ -1,18 +1,19 @@ module FFT (testFFT) where +import Control.Monad (unless) import qualified Data.Vector as V import Prelude hiding (Num (..), sum, take) -import RustFunctions (rustMulFft) +import RustFunctions (rustMulFft, rustDivFft) import Test.Hspec (describe, hspec, it, shouldBe) import Test.QuickCheck (Testable (property)) -import ZkFold.Base.Algebra.Basic.Class (MultiplicativeSemigroup ((*))) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) +import ZkFold.Base.Algebra.Basic.Class (MultiplicativeSemigroup ((*)), MultiplicativeMonoid (one), AdditiveMonoid (zero)) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_GT (BLS12_381_GT)) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (ScalarField)) -import ZkFold.Base.Algebra.Polynomials.Univariate (toPoly) +import ZkFold.Base.Algebra.Polynomials.Univariate (toPoly, qr, deg) -specFFT :: IO () -specFFT = hspec $ do +specMulFFT :: IO () +specMulFFT = hspec $ do describe "Rust FFT multiplication specification" $ do it "should be equal to haskell" $ do property $ @@ -27,7 +28,22 @@ specFFT = hspec $ do `shouldBe` left * right +specDivFFT :: IO () +specDivFFT = hspec $ do + describe "Rust FFT division specification" $ do + it "should be equal to haskell" $ do + property $ + \ + (l :: [ScalarField BLS12_381_G1]) + (r :: [ScalarField BLS12_381_G1]) + -> let + left = toPoly $ V.fromList l + right = toPoly $ V.fromList r + (ll, rr) = (rustDivFft @(ScalarField BLS12_381_G1) (V.fromList l) (V.fromList r)) + in (unless (deg right == - 1) $ (toPoly ll, toPoly rr) `shouldBe` left `qr` right) + testFFT :: IO () testFFT = do - specFFT + specDivFFT + specMulFFT diff --git a/rust-wrapper/Cargo.toml b/rust-wrapper/Cargo.toml index 1d9b4aa..8556aec 100644 --- a/rust-wrapper/Cargo.toml +++ b/rust-wrapper/Cargo.toml @@ -7,26 +7,33 @@ edition = "2021" # Crate arkmsm uses arkworks 0.3.0, it is not compatible with version 0.4.0. # We are using an unofficial fork of the repository, which updates the version to 0.4.0. # We need to check the status of arkmsm to move to a stable version. -ark-msm = { git = "https://github.com/TalDerei/arkmsm.git", rev="bc95ea3784983d8ced03a642d765bbfdd91faa9b"} +ark-msm = { git = "https://github.com/TalDerei/arkmsm.git", rev = "bc95ea3784983d8ced03a642d765bbfdd91faa9b" } -ark-ff = { version= "0.4.0", default-features = false } -ark-ec = { version= "0.4.0" } +ark-ff = { version = "0.4.0", default-features = false } +ark-ec = { version = "0.4.0" } ark-std = { version = "0.4.0", default-features = false } ark-serialize = { version = "0.4.0", default-features = false } ark-bls12-381 = { version = "0.4.0" } -der = "0.7" -rand = "0.8" -num-bigint = "0.4.6" -num-traits = "0.2.19" -libc = "0.2.158" -ark-poly = "0.4.2" +der = "0.7" +rand = "0.8" +num-bigint = "0.4.6" +num-traits = "0.2.19" +libc = "0.2.158" +ark-poly = "0.4.2" ark-test-curves = "0.4.2" criterion = { version = "0.5", features = ["html_reports"] } +[lib] +crate-type = ["lib", "cdylib"] + [[bench]] name = "msm_bench" harness = false +[[bench]] +name = "div_bench" +harness = false + [[bench]] name = "fft_bench" harness = false diff --git a/rust-wrapper/benches/div_bench.rs b/rust-wrapper/benches/div_bench.rs new file mode 100644 index 0000000..9ca1f5b --- /dev/null +++ b/rust-wrapper/benches/div_bench.rs @@ -0,0 +1,48 @@ +use ark_bls12_381::Fr as ScalarField; +use ark_ff::{BigInt, BigInteger}; +use ark_poly::univariate::DensePolynomial; +use ark_poly::DenseUVPolynomial; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use rust_wrapper::fft::div_fft; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("fft"); + + for size in 14..=19 { + let degree = 1 << size; + + let mut rng = &mut ark_std::test_rng(); + + let l: DensePolynomial = + DensePolynomial::::rand(degree, &mut rng); + let r: DensePolynomial = + DensePolynomial::::rand(degree, &mut rng); + + let r_bytes_vec: Vec = r + .coeffs + .iter() + .flat_map(|i| (BigInt::from(*i)).to_bytes_le()) + .collect(); + + let l_bytes_vec: Vec = l + .coeffs + .iter() + .flat_map(|i| (BigInt::from(*i)).to_bytes_le()) + .collect(); + + group.bench_with_input( + BenchmarkId::new("FFT Multiplication", degree), + °ree, + |b, _size| { + b.iter(|| { + div_fft(&l_bytes_vec, &r_bytes_vec); + }) + }, + ); + } + + group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/rust-wrapper/src/fft.rs b/rust-wrapper/src/fft.rs index 4ff82c9..1bfecd4 100644 --- a/rust-wrapper/src/fft.rs +++ b/rust-wrapper/src/fft.rs @@ -5,7 +5,9 @@ use ark_poly::DenseUVPolynomial; use core::slice; use libc; use num_bigint::BigUint; +use std::ops::Div; use std::ops::Mul; +use std::ops::Sub; use crate::utils::deserialize_vector_scalar_field; @@ -24,6 +26,60 @@ pub fn mul_fft(l: &[u8], r: &[u8]) -> Vec { .collect() } +pub fn div_fft(l: &[u8], r: &[u8]) -> Vec { + let l = DensePolynomial::from_coefficients_vec(deserialize_vector_scalar_field(l)); + let r = DensePolynomial::from_coefficients_vec(deserialize_vector_scalar_field(r)); + + let quo = l.div(&r); + let rem = l.sub(&quo.mul(&r)); + + let mut quo_vec: Vec = quo + .to_vec() + .iter() + .flat_map(|x| { + let mut v = BigUint::from(x.into_bigint()).to_bytes_le(); + v.resize(std::mem::size_of::(), 0); + v + }) + .collect(); + + let mut rem_vec: Vec = rem + .to_vec() + .iter() + .flat_map(|x| { + let mut v = BigUint::from(x.into_bigint()).to_bytes_le(); + v.resize(std::mem::size_of::(), 0); + v + }) + .collect(); + + rem_vec.resize(std::mem::size_of::() * r.len(), 0); + quo_vec.append(&mut rem_vec); + + quo_vec +} + +/// +/// # Safety +/// The caller must ensure that valid pointers and sizes are passed. +/// . +#[no_mangle] +pub unsafe extern "C" fn rust_wrapper_div_fft( + l_var: *const libc::c_char, + l_len: usize, + r_var: *const libc::c_char, + r_len: usize, + _out_len: usize, + out: *mut libc::c_char, +) { + let l = slice::from_raw_parts(l_var as *const u8, l_len); + let r = slice::from_raw_parts(r_var as *const u8, r_len); + + let res = div_fft(l, r); + + std::ptr::copy(res.as_ptr(), out as *mut u8, res.len()); +} + /// /// # Safety /// The caller must ensure that valid pointers and sizes are passed. diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index 3f590db..c39a7c6 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -14,6 +14,7 @@ custom-setup , Cabal , process , directory + , pseudomacros Flag Pedantic Description: Enable pedantic build with -Werror @@ -160,4 +161,24 @@ benchmark poly-mul bytestring, zkfold-prover, random, - deepseq \ No newline at end of file + deepseq + +benchmark poly-div + import: options + main-is: BenchPolyDiv.hs + hs-source-dirs: bench + type: exitcode-stdio-1.0 + ghc-options: + -rtsopts + -fprof-auto + build-depends: + base, + tasty-bench, + QuickCheck, + vector, + symbolic-base, + bytestring, + zkfold-prover, + random, + deepseq, + extra \ No newline at end of file From 314c0a4d22bdb917ded61c875442df60454e8611 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 11:34:17 +0300 Subject: [PATCH 07/18] break cabal build --- Setup.hs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Setup.hs b/Setup.hs index d71f16b..846fdbc 100644 --- a/Setup.hs +++ b/Setup.hs @@ -17,6 +17,9 @@ main = defaultMainWithHooks simpleUserHooks preConf = buildRustLib } +infi = do + infi + buildRustLib :: Args -> a -> IO HookedBuildInfo buildRustLib _ flags = do @@ -25,6 +28,7 @@ buildRustLib _ flags = do isNotDependency <- doesFileExist (pathToDistNewstyle ++ "rust-wrapper/Cargo.toml") + infi pathToRustWrapper <- if isNotDependency then return pathToDistNewstyle else do From 6af46e322e681ffa6a81d04df44d627ae81eb1e3 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 11:45:36 +0300 Subject: [PATCH 08/18] draft: debug info --- Setup.hs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Setup.hs b/Setup.hs index 846fdbc..cb49944 100644 --- a/Setup.hs +++ b/Setup.hs @@ -28,13 +28,17 @@ buildRustLib _ flags = do isNotDependency <- doesFileExist (pathToDistNewstyle ++ "rust-wrapper/Cargo.toml") - infi + -- infi + print $ file + print $ pathToDistNewstyle + pathToRustWrapper <- if isNotDependency then return pathToDistNewstyle else do contents <- listDirectory (pathToDistNewstyle ++ "src/") + print $ contents let depLib = fromJust $ find (isPrefixOf "zkfold-pr") contents - + print $ depLib return $ pathToDistNewstyle ++ "src/" ++ depLib putStrLn $ pathToRustWrapper From 93a012b23bdbd5cd6707f3e2766c9843f4c2d970 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 11:47:38 +0300 Subject: [PATCH 09/18] fix: path in dependency build --- Setup.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Setup.hs b/Setup.hs index cb49944..7529b84 100644 --- a/Setup.hs +++ b/Setup.hs @@ -35,7 +35,7 @@ buildRustLib _ flags = do pathToRustWrapper <- if isNotDependency then return pathToDistNewstyle else do - contents <- listDirectory (pathToDistNewstyle ++ "src/") + contents <- listDirectory (pathToDistNewstyle ++ "dist-newstyle/src/") print $ contents let depLib = fromJust $ find (isPrefixOf "zkfold-pr") contents print $ depLib From 88d39f89a129c29992f319488830631145e640c3 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 12:03:46 +0300 Subject: [PATCH 10/18] fix: filter directories --- Setup.hs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Setup.hs b/Setup.hs index 7529b84..123839c 100644 --- a/Setup.hs +++ b/Setup.hs @@ -37,9 +37,12 @@ buildRustLib _ flags = do else do contents <- listDirectory (pathToDistNewstyle ++ "dist-newstyle/src/") print $ contents - let depLib = fromJust $ find (isPrefixOf "zkfold-pr") contents - print $ depLib - return $ pathToDistNewstyle ++ "src/" ++ depLib + depLibs <- filterM (\p -> do + let prefixCond = isPrefixOf "zkfold-pr" p + dirCond <- doesDirectoryExist (pathToDistNewstyle ++ "dist-newstyle/src/" ++ p) + return $ dirCond && prefixCond) contents + print $ depLibs + return $ pathToDistNewstyle ++ "src/" ++ (head depLibs) putStrLn $ pathToRustWrapper From a19a7258983130c873fce87d9997c9ed43cfa318 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 12:06:46 +0300 Subject: [PATCH 11/18] fix: path to rust-wrapper --- Setup.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Setup.hs b/Setup.hs index 123839c..bf096c7 100644 --- a/Setup.hs +++ b/Setup.hs @@ -42,7 +42,7 @@ buildRustLib _ flags = do dirCond <- doesDirectoryExist (pathToDistNewstyle ++ "dist-newstyle/src/" ++ p) return $ dirCond && prefixCond) contents print $ depLibs - return $ pathToDistNewstyle ++ "src/" ++ (head depLibs) + return $ pathToDistNewstyle ++ "src/" ++ (head depLibs) ++ "/" putStrLn $ pathToRustWrapper From 5bd9446e7cff535cc254d705f689dd0a92ec002c Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 12:11:41 +0300 Subject: [PATCH 12/18] fix: path --- Setup.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Setup.hs b/Setup.hs index bf096c7..53c5f0f 100644 --- a/Setup.hs +++ b/Setup.hs @@ -42,7 +42,7 @@ buildRustLib _ flags = do dirCond <- doesDirectoryExist (pathToDistNewstyle ++ "dist-newstyle/src/" ++ p) return $ dirCond && prefixCond) contents print $ depLibs - return $ pathToDistNewstyle ++ "src/" ++ (head depLibs) ++ "/" + return $ pathToDistNewstyle ++ "dist-newstyle/src/" ++ (head depLibs) ++ "/" putStrLn $ pathToRustWrapper From d91d46792a9ae4a99ada65506e2fd15ae17bb70a Mon Sep 17 00:00:00 2001 From: diS3e Date: Mon, 20 Jan 2025 09:32:52 +0000 Subject: [PATCH 13/18] stylish-haskell auto-commit --- Setup.hs | 10 +++++----- haskell-wrapper/tests/FFT.hs | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Setup.hs b/Setup.hs index 53c5f0f..6762dd0 100644 --- a/Setup.hs +++ b/Setup.hs @@ -2,8 +2,8 @@ import Control.Exception (throwIO) import Control.Monad import Data.Char (isSpace) +import Data.List (dropWhile, find, findIndex, isPrefixOf, tails) import Data.Maybe (fromJust) -import Data.List (dropWhile, isPrefixOf, tails, findIndex, find) import Distribution.Simple import Distribution.Types.HookedBuildInfo import PseudoMacros @@ -13,7 +13,7 @@ import System.Process (readProcess, system) main :: IO () main = defaultMainWithHooks simpleUserHooks - { + { preConf = buildRustLib } @@ -27,7 +27,7 @@ buildRustLib _ flags = do let pathToDistNewstyle = take (fromJust $ findIndex (isPrefixOf "dist-newstyle") (tails file)) file isNotDependency <- doesFileExist (pathToDistNewstyle ++ "rust-wrapper/Cargo.toml") - + -- infi print $ file print $ pathToDistNewstyle @@ -46,8 +46,8 @@ buildRustLib _ flags = do putStrLn $ pathToRustWrapper - buildResult <- system ("cargo +nightly build --release " ++ - "--manifest-path " ++ pathToRustWrapper ++ "rust-wrapper/Cargo.toml " ++ + buildResult <- system ("cargo +nightly build --release " ++ + "--manifest-path " ++ pathToRustWrapper ++ "rust-wrapper/Cargo.toml " ++ "--artifact-dir=" ++ pathToRustWrapper ++ "libs/ -Z unstable-options" ) diff --git a/haskell-wrapper/tests/FFT.hs b/haskell-wrapper/tests/FFT.hs index c495166..39ff0e6 100644 --- a/haskell-wrapper/tests/FFT.hs +++ b/haskell-wrapper/tests/FFT.hs @@ -3,14 +3,15 @@ module FFT (testFFT) where import Control.Monad (unless) import qualified Data.Vector as V import Prelude hiding (Num (..), sum, take) -import RustFunctions (rustMulFft, rustDivFft) +import RustFunctions (rustDivFft, rustMulFft) import Test.Hspec (describe, hspec, it, shouldBe) import Test.QuickCheck (Testable (property)) -import ZkFold.Base.Algebra.Basic.Class (MultiplicativeSemigroup ((*)), MultiplicativeMonoid (one), AdditiveMonoid (zero)) +import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid (zero), MultiplicativeMonoid (one), + MultiplicativeSemigroup ((*))) import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_GT (BLS12_381_GT)) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (ScalarField)) -import ZkFold.Base.Algebra.Polynomials.Univariate (toPoly, qr, deg) +import ZkFold.Base.Algebra.Polynomials.Univariate (deg, qr, toPoly) specMulFFT :: IO () specMulFFT = hspec $ do From f6df7c73fdb673c2f3b5fee8bd70e710ba316d11 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 14:59:29 +0300 Subject: [PATCH 14/18] fix: artifact-dir path --- Setup.hs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/Setup.hs b/Setup.hs index 6762dd0..b1983af 100644 --- a/Setup.hs +++ b/Setup.hs @@ -17,9 +17,6 @@ main = defaultMainWithHooks simpleUserHooks preConf = buildRustLib } -infi = do - infi - buildRustLib :: Args -> a -> IO HookedBuildInfo buildRustLib _ flags = do @@ -28,10 +25,6 @@ buildRustLib _ flags = do isNotDependency <- doesFileExist (pathToDistNewstyle ++ "rust-wrapper/Cargo.toml") - -- infi - print $ file - print $ pathToDistNewstyle - pathToRustWrapper <- if isNotDependency then return pathToDistNewstyle else do @@ -44,11 +37,9 @@ buildRustLib _ flags = do print $ depLibs return $ pathToDistNewstyle ++ "dist-newstyle/src/" ++ (head depLibs) ++ "/" - putStrLn $ pathToRustWrapper - buildResult <- system ("cargo +nightly build --release " ++ "--manifest-path " ++ pathToRustWrapper ++ "rust-wrapper/Cargo.toml " ++ - "--artifact-dir=" ++ pathToRustWrapper ++ "libs/ -Z unstable-options" + "--artifact-dir=" ++ pathToDistNewstyle ++ "libs/ -Z unstable-options" ) case buildResult of From d1fad6302fe359053c90563c40bb4783dc777293 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 16:49:38 +0300 Subject: [PATCH 15/18] fix: update gitignore --- .gitignore | 4 +++- cabal.project | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index d100958..b5234c0 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,8 @@ cabal.project.local~ .ghc.environment.* rust-wrapper/target/ rust-wrapper/Cargo.lock -lib.so +*.so *.data *.svg +*.nix +*.rlib diff --git a/cabal.project b/cabal.project index 93f7aa1..18cb355 100644 --- a/cabal.project +++ b/cabal.project @@ -3,5 +3,5 @@ packages: . source-repository-package type: git location: https://github.com/zkFold/zkfold-base.git - tag: 544062ce7bb41da6f0e9fb58de53eb3377f9df22 + tag: c8d4695e7f5e0140b76d819f82c4d41d1510b385 subdir: symbolic-base \ No newline at end of file From e79ab6cebec9d631d348dd8eebda90f716903a43 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 20 Jan 2025 16:52:42 +0300 Subject: [PATCH 16/18] fix: unused imports --- haskell-wrapper/tests/FFT.hs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/haskell-wrapper/tests/FFT.hs b/haskell-wrapper/tests/FFT.hs index 39ff0e6..4165efa 100644 --- a/haskell-wrapper/tests/FFT.hs +++ b/haskell-wrapper/tests/FFT.hs @@ -7,9 +7,8 @@ import RustFunctions (rustDivFft, rustMu import Test.Hspec (describe, hspec, it, shouldBe) import Test.QuickCheck (Testable (property)) -import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid (zero), MultiplicativeMonoid (one), - MultiplicativeSemigroup ((*))) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS12_381_GT (BLS12_381_GT)) +import ZkFold.Base.Algebra.Basic.Class (MultiplicativeSemigroup ((*))) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1) import ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (ScalarField)) import ZkFold.Base.Algebra.Polynomials.Univariate (deg, qr, toPoly) From fec99277980b62fec9276a488bc2f9e8e8790370 Mon Sep 17 00:00:00 2001 From: Eugene Samodelov Date: Mon, 3 Feb 2025 12:19:55 +0300 Subject: [PATCH 17/18] draft: add RustBLS --- bench/BenchProve.hs | 3 +- haskell-wrapper/src/RustBLS.hs | 279 +++++++++++++++++++++++++++ haskell-wrapper/src/RustFunctions.hs | 61 +++++- rust-wrapper/src/msm.rs | 51 ++++- zkfold-prover.cabal | 3 + 5 files changed, 390 insertions(+), 7 deletions(-) create mode 100644 haskell-wrapper/src/RustBLS.hs diff --git a/bench/BenchProve.hs b/bench/BenchProve.hs index e6bb7a4..b8cbc64 100644 --- a/bench/BenchProve.hs +++ b/bench/BenchProve.hs @@ -13,8 +13,9 @@ import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS1 import ZkFold.Base.Data.Vector (Vector) import ZkFold.Base.Protocol.NonInteractiveProof import ZkFold.Base.Protocol.Plonk (Plonk) +import RustBLS (RustBLS12_381_G1, RustBLS12_381_G2) -type PlonkBS n = Plonk U1 (Vector 1) 32 (Vector n) BLS12_381_G1 BLS12_381_G2 BS.ByteString +type PlonkBS n = Plonk U1 (Vector 1) 32 (Vector n) RustBLS12_381_G1 RustBLS12_381_G2 BS.ByteString instance Arbitrary (U1 a) where arbitrary = return U1 diff --git a/haskell-wrapper/src/RustBLS.hs b/haskell-wrapper/src/RustBLS.hs new file mode 100644 index 0000000..f1183a7 --- /dev/null +++ b/haskell-wrapper/src/RustBLS.hs @@ -0,0 +1,279 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} + +{-# OPTIONS_GHC -Wno-orphans #-} + +module RustBLS where + +import Control.DeepSeq (NFData) +import Control.Monad +import Data.Bits +import Data.Foldable hiding (sum) +import Data.Word +import GHC.Generics (Generic) +import Prelude hiding (Num (..), (/), (^), Eq, sum) + +import ZkFold.Base.Algebra.Basic.Class hiding (sum) +import ZkFold.Base.Algebra.Basic.Field +import ZkFold.Base.Algebra.Basic.Number +import ZkFold.Base.Algebra.EllipticCurve.Class +import ZkFold.Base.Algebra.EllipticCurve.Pairing +import ZkFold.Base.Algebra.Polynomials.Univariate +import ZkFold.Base.Data.ByteString +import qualified Data.Vector as V + +import RustFunctions (rustMulFft, rustMulPoint, rustMultiScalarMultiplicationWithoutSerialization, RustCore, both, rustDivFft) +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 hiding (Fq, Fr) +import ZkFold.Symbolic.Data.Bool (BoolType) +import ZkFold.Symbolic.Data.Conditional (Conditional) +import Data.Kind (Type) +import qualified Prelude as P +import ZkFold.Symbolic.Data.Eq (Eq) +import Test.QuickCheck (Arbitrary) +import ZkFold.Base.Protocol.NonInteractiveProof.Internal (HaskellCore, CoreFunction(..)) +import ZkFold.Base.Algebra.Basic.Class (sum) + +type Fr = Zp BLS12_381_Scalar +type Fq = Zp BLS12_381_Base + +data RustBLS12_381_G1 + deriving (Generic, NFData) + +instance {-# OVERLAPPING #-} EllipticCurve RustBLS12_381_G1 where + type ScalarField RustBLS12_381_G1 = Fr + + type BaseField RustBLS12_381_G1 = Fq + + pointGen = pointXY + 0x17f1d3a73197d7942695638c4fa9ac0fc3688c4f9774b905a14e3a3f171bac586c55e83ff97a1aeffb3af00adb22c6bb + 0x8b3f481e3aaa0f1a09e30ed741d8ae4fcf5e095d5d00af600db18cb2c04b3edd03cc744a2888ae40caa232946c5e7e1 + + add = addPoints + + mul scalar point = rustMulPoint point scalar + +instance WeierstrassCurve RustBLS12_381_G1 where + weierstrassA = zero + + weierstrassB = fromConstant (4 :: Natural) + +instance {-# OVERLAPPING #-} MultiplicativeSemigroup (Poly ((Zp BLS12_381_Scalar))) where + -- | If it is possible to calculate a primitive root of unity in the field, proceed with FFT multiplication. + -- Otherwise default to Karatsuba multiplication for polynomials of degree higher than 64 or use naive multiplication otherwise. + -- 64 is a threshold determined by benchmarking. + l * r = toPoly (rustMulFft (fromPoly l) (fromPoly r)) + +instance {-# OVERLAPPING #-} MultiplicativeMonoid (Poly ((Zp BLS12_381_Scalar))) where + one = toPoly $ V.singleton one + +instance {-# OVERLAPPING #-} (KnownNat size) => MultiplicativeSemigroup (PolyVec ( (Zp BLS12_381_Scalar)) size) where + l * r = poly2vec $ toPoly $ rustMulFft (fromPoly $ vec2poly l) (fromPoly $ vec2poly r) + +instance {-# OVERLAPPING #-} (KnownNat size) => MultiplicativeMonoid (PolyVec ( (Zp BLS12_381_Scalar)) size) where + one = toPolyVec $ V.singleton one V.++ V.replicate (fromIntegral (value @size -! 1)) zero + +instance {-# OVERLAPPING #-} (KnownNat size) => Semiring (PolyVec ( (Zp BLS12_381_Scalar)) size) + +instance {-# OVERLAPPING #-} (KnownNat size) => Ring (PolyVec ( (Zp BLS12_381_Scalar)) size) + +------------------------------------ BLS12-381 G2 ------------------------------------ + +data RustBLS12_381_G2 + deriving (Generic, NFData) + +instance EllipticCurve RustBLS12_381_G2 where + + type ScalarField RustBLS12_381_G2 = Fr + + type BaseField RustBLS12_381_G2 = Fq2 + + pointGen = pointXY + (Ext2 + 0x24aa2b2f08f0a91260805272dc51051c6e47ad4fa403b02b4510b647ae3d1770bac0326a805bbefd48056c8c121bdb8 + 0x13e02b6052719f607dacd3a088274f65596bd0d09920b61ab5da61bbdc7f5049334cf11213945d57e5ac7d055d042b7e) + (Ext2 + 0xce5d527727d6e118cc9cdc6da2e351aadfd9baa8cbdd3a76d429a695160d12c923ac9cc3baca289e193548608b82801 + 0x606c4a02ea734cc32acd2b02bc28b99cb3e287e85a763af267492ab572e99ab3f370d275cec1da1aaa9075ff05f79be) + + add = addPoints + + mul = pointMul + +instance WeierstrassCurve RustBLS12_381_G2 where + weierstrassA = zero + + weierstrassB = fromConstant (4 :: Natural) + + +instance Binary (Point RustBLS12_381_G1) where + put (Point x y isInf) = + if isInf then foldMap putWord8 (bitReverse8 (bit 1) : replicate 95 0) + else foldMap putWord8 (bytesOf 48 x <> bytesOf 48 y) + get = do + byte <- bitReverse8 <$> getWord8 + let compressed = testBit byte 0 + infinite = testBit byte 1 + if infinite then do + skip (if compressed then 47 else 95) + return pointInf + else do + let byteXhead = bitReverse8 $ clearBit (clearBit (clearBit byte 0) 1) 2 + bytesXtail <- replicateM 47 getWord8 + let x = ofBytes (byteXhead:bytesXtail) + bigY = testBit byte 2 + if compressed then return (decompress (pointCompressed x bigY)) + else do + bytesY <- replicateM 48 getWord8 + let y = ofBytes bytesY + return (pointXY x y) + +instance Binary (CompressedPoint RustBLS12_381_G1) where + put (CompressedPoint x bigY isInf) = + if isInf then foldMap putWord8 (bitReverse8 (bit 0 .|. bit 1) : replicate 47 0) else + let + flags = bitReverse8 $ if bigY then bit 0 .|. bit 2 else bit 0 + bytes = bytesOf 48 x + in foldMap putWord8 ((flags .|. head bytes) : tail bytes) + get = do + byte <- bitReverse8 <$> getWord8 + let compressed = testBit byte 0 + infinite = testBit byte 1 + if infinite then do + skip (if compressed then 47 else 95) + return pointInf + else do + let byteXhead = bitReverse8 $ clearBit (clearBit (clearBit byte 0) 1) 2 + bytesXtail <- replicateM 47 getWord8 + let x = ofBytes (byteXhead:bytesXtail) + bigY = testBit byte 2 + if compressed then return (pointCompressed x bigY) + else do + bytesY <- replicateM 48 getWord8 + let y :: Fq = ofBytes bytesY + bigY' = y > negate y + return (pointCompressed x bigY') + +instance Binary (Point RustBLS12_381_G2) where + put (Point (Ext2 x0 x1) (Ext2 y0 y1) isInf) = + if isInf then foldMap putWord8 (bitReverse8 (bit 1) : replicate 191 0) else + let + bytes = bytesOf 48 x1 + <> bytesOf 48 x0 + <> bytesOf 48 y1 + <> bytesOf 48 y0 + in + foldMap putWord8 bytes + get = do + byte <- bitReverse8 <$> getWord8 + let compressed = testBit byte 0 + infinite = testBit byte 1 + if infinite then do + skip (if compressed then 95 else 191) + return pointInf + else do + let byteX1head = bitReverse8 $ clearBit (clearBit (clearBit byte 0) 1) 2 + bytesX1tail <- replicateM 47 getWord8 + bytesX0 <- replicateM 48 getWord8 + let x1 = ofBytes (byteX1head:bytesX1tail) + x0 = ofBytes bytesX0 + bigY = testBit byte 2 + if compressed then return (decompress (pointCompressed (Ext2 x0 x1) bigY)) + else do + bytesY1 <- replicateM 48 getWord8 + bytesY0 <- replicateM 48 getWord8 + let y0 = ofBytes bytesY0 + y1 = ofBytes bytesY1 + return (pointXY (Ext2 x0 x1) (Ext2 y0 y1)) + +instance Binary (CompressedPoint RustBLS12_381_G2) where + put (CompressedPoint (Ext2 x0 x1) bigY isInf) = + if isInf then foldMap putWord8 (bitReverse8 (bit 0 .|. bit 1) : replicate 95 0) else + let + flags = bitReverse8 $ if bigY then bit 0 .|. bit 2 else bit 0 + bytes = bytesOf 48 x1 <> bytesOf 48 x0 + in + foldMap putWord8 ((flags .|. head bytes) : tail bytes) + get = do + byte <- bitReverse8 <$> getWord8 + let compressed = testBit byte 0 + infinite = testBit byte 1 + if infinite then do + skip (if compressed then 95 else 191) + return pointInf + else do + let byteX1head = bitReverse8 $ clearBit (clearBit (clearBit byte 0) 1) 2 + bytesX1tail <- replicateM 47 getWord8 + bytesX0 <- replicateM 48 getWord8 + let x1 = ofBytes (byteX1head:bytesX1tail) + x0 = ofBytes bytesX0 + x = Ext2 x0 x1 + bigY = testBit byte 2 + if compressed then return (pointCompressed (Ext2 x0 x1) bigY) + else do + bytesY1 <- replicateM 48 getWord8 + bytesY0 <- replicateM 48 getWord8 + let y0 = ofBytes bytesY0 + y1 = ofBytes bytesY1 + y :: Fq2 = Ext2 y0 y1 + bigY' = y > negate y + return (pointCompressed x bigY') + +-- --------------------------------------- Pairing --------------------------------------- + +-- -- | An image of a pairing is a cyclic multiplicative subgroup of @'Fq12'@ +-- -- of order @'BLS12_381_Scalar'@. +newtype RustBLS12_381_GT = RustBLS12_381_GT Fq12 + deriving newtype (P.Eq, Show, MultiplicativeSemigroup, MultiplicativeMonoid) + +instance Exponent RustBLS12_381_GT Natural where + RustBLS12_381_GT a ^ p = RustBLS12_381_GT (a ^ p) + +instance Exponent RustBLS12_381_GT Integer where + RustBLS12_381_GT a ^ p = RustBLS12_381_GT (a ^ p) + +deriving via (NonZero Fq12) instance MultiplicativeGroup RustBLS12_381_GT + +instance Finite RustBLS12_381_GT where + type Order RustBLS12_381_GT = BLS12_381_Scalar + +instance Pairing RustBLS12_381_G1 RustBLS12_381_G2 where + type TargetGroup RustBLS12_381_G1 RustBLS12_381_G2 = RustBLS12_381_GT + pairing a b + = RustBLS12_381_GT + $ finalExponentiation @BLS12_381_G2 + $ millerAlgorithmBLS12 param a b + where + param = [-1 + ,-1, 0,-1, 0, 0,-1, 0, 0, 0, 0, 0, 0, 0, 0,-1, 0 + , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,-1, 0 + , 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + ] + +instance {-# OVERLAPPING #-} CoreFunction RustBLS12_381_G1 HaskellCore where + msm gs f = sum $ V.zipWith scale (fromPolyVec f) gs + polyMul = (*) + polyQr = qr + +instance {-# OVERLAPPING #-} CoreFunction RustBLS12_381_G1 RustCore where + msm gs f = uncurry rustMultiScalarMultiplicationWithoutSerialization (zipAndUnzip points scalars) + where + points = V.toList gs + + scalars = V.toList $ fromPolyVec f + + zipAndUnzip :: [a] -> [b] -> ([a],[b]) + zipAndUnzip (a:as) (b:bs) + = let (rs1, rs2) = zipAndUnzip as bs + in + (a:rs1, b:rs2) + zipAndUnzip _ _ = ([],[]) + + polyMul x y = toPoly (rustMulFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y)) + polyQr x y = both toPoly $ rustDivFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y) \ No newline at end of file diff --git a/haskell-wrapper/src/RustFunctions.hs b/haskell-wrapper/src/RustFunctions.hs index 966023a..f234e55 100644 --- a/haskell-wrapper/src/RustFunctions.hs +++ b/haskell-wrapper/src/RustFunctions.hs @@ -8,7 +8,9 @@ module RustFunctions ( rustMultiScalarMultiplicationWithoutSerialization , rustMulFft , rustDivFft + , rustMulPoint , RustCore + , both ) where import qualified Data.ByteString as BS @@ -30,7 +32,7 @@ import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid (ze import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Algebra.EllipticCurve.Class -import ZkFold.Base.Algebra.Polynomials.Univariate (fromPoly, fromPolyVec, toPoly) +import ZkFold.Base.Algebra.Polynomials.Univariate (fromPoly, fromPolyVec, toPoly, qr) import ZkFold.Base.Data.ByteString import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction (..), msm) @@ -55,6 +57,61 @@ type FunMSM = foreign import ccall "dynamic" mkFunMSM :: FunPtr FunMSM -> FunMSM +type FunMul = + CString -> Int -> CString -> Int -> Int -> CString -> IO () + +foreign import ccall "dynamic" + mkFunMul :: FunPtr FunMul -> FunMul + + +rustMulPoint + :: forall + (a :: Type) + (b :: Natural) + . (ScalarField a ~ Zp b + , Binary (Point a) + , Storable (ScalarField a) + , Storable (Point a) + ) + => Point a + -> ScalarField a + -> Point a +rustMulPoint point scalar = unsafePerformIO runMSM + where + + pointSize = sizeOf (undefined :: Point a) + scalarSize = sizeOf (undefined :: ScalarField a) + + runMSM :: IO (Point a) + runMSM = do + dl <- dlopen libPath [RTLD_NOW] + mulPtr <- dlsym dl "rust_wrapper_mul" + let !mulf = mkFunMul $ castFunPtr mulPtr + + ptrScalar <- callocBytes @(ScalarField a) scalarSize + ptrPoint <- callocBytes @(Point a) pointSize + + poke ptrScalar scalar + poke ptrPoint point + + out <- mallocBytes pointSize + + !_ <- mulf + (castPtr ptrPoint) pointSize + (castPtr ptrScalar) scalarSize + pointSize out + + dlclose dl + + res <- BS.packCStringLen (out, pointSize) + + free ptrScalar + free ptrPoint + free out + + return $ fromJust $ fromByteString @(Point a) res + + rustMultiScalarMultiplicationWithoutSerialization :: forall (a :: Type) @@ -80,7 +137,7 @@ rustMultiScalarMultiplicationWithoutSerialization points scalars = unsafePerform runMSM :: IO (Point a) runMSM = do dl <- dlopen libPath [RTLD_NOW] - msmPtr <- dlsym dl "rust_wrapper_multi_scalar_multiplication_without_serialization" + msmPtr <- dlsym dl "rust_wrapper_msm" let !msmf = mkFunMSM $ castFunPtr msmPtr ptrScalars <- callocBytes @(ScalarField a) scalarsByteLength diff --git a/rust-wrapper/src/msm.rs b/rust-wrapper/src/msm.rs index 796399a..cf43e56 100644 --- a/rust-wrapper/src/msm.rs +++ b/rust-wrapper/src/msm.rs @@ -1,10 +1,12 @@ use std::slice; use ark_bls12_381::G1Affine as GAffine; +use ark_ec::CurveGroup; use ark_ff::PrimeField; use ark_msm::msm::VariableBaseMSM; -use ark_serialize::CanonicalSerialize; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::log2; +use ark_bls12_381::Fr as ScalarField; use crate::utils::{deserialize_vector_points, deserialize_vector_scalar_field}; @@ -21,7 +23,7 @@ const fn get_opt_window_size(k: u32) -> u32 { } } -pub fn multi_scalar_multiplication_without_serialization( +pub fn msm( scalar_buffer: &[u8], point_buffer: &[u8], ) -> Vec { @@ -48,12 +50,53 @@ pub fn multi_scalar_multiplication_without_serialization( res } +pub fn mul( + scalar_buffer: &[u8], + point_buffer: &[u8], +) -> Vec { + let scalar: ScalarField = PrimeField::from_le_bytes_mod_order(scalar_buffer); + + let mut bytes: Vec = point_buffer.to_vec(); + let points_size = bytes.len(); + bytes[0..(points_size >> 1)].reverse(); + bytes[(points_size >> 1)..points_size].reverse(); + let point = GAffine::deserialize_uncompressed_unchecked(&*bytes).unwrap(); + + let r: GAffine = (point * scalar).into_affine(); + print!("Rust mul"); + let mut res = Vec::new(); + r.serialize_uncompressed(&mut res).unwrap(); + res +} + +/// +/// # Safety +/// The caller must ensure that valid pointers and sizes are passed. +/// . +#[no_mangle] +pub unsafe extern "C" fn rust_wrapper_msm( + points_var: *const libc::c_char, + points_len: usize, + scalars_var: *const libc::c_char, + scalars_len: usize, + out_len: usize, + out: *mut libc::c_char, +) { + let scalar_buffer = slice::from_raw_parts(scalars_var as *const u8, scalars_len); + let point_buffer = slice::from_raw_parts(points_var as *const u8, points_len); + + let res = msm(scalar_buffer, point_buffer); + + std::ptr::copy(res.as_ptr(), out as *mut u8, out_len); +} + + /// /// # Safety /// The caller must ensure that valid pointers and sizes are passed. /// . #[no_mangle] -pub unsafe extern "C" fn rust_wrapper_multi_scalar_multiplication_without_serialization( +pub unsafe extern "C" fn rust_wrapper_mul( points_var: *const libc::c_char, points_len: usize, scalars_var: *const libc::c_char, @@ -64,7 +107,7 @@ pub unsafe extern "C" fn rust_wrapper_multi_scalar_multiplication_without_serial let scalar_buffer = slice::from_raw_parts(scalars_var as *const u8, scalars_len); let point_buffer = slice::from_raw_parts(points_var as *const u8, points_len); - let res = multi_scalar_multiplication_without_serialization(scalar_buffer, point_buffer); + let res = mul(scalar_buffer, point_buffer); std::ptr::copy(res.as_ptr(), out as *mut u8, out_len); } diff --git a/zkfold-prover.cabal b/zkfold-prover.cabal index c39a7c6..1d35f29 100644 --- a/zkfold-prover.cabal +++ b/zkfold-prover.cabal @@ -82,6 +82,7 @@ library import: options exposed-modules: RustFunctions + , RustBLS hs-source-dirs: haskell-wrapper/src build-depends: @@ -91,6 +92,8 @@ library , vector , binary , unix + , deepseq + , QuickCheck ghc-options: -O2 From 0d545efec7e2e7016f5a8d3c9d386f22c53a47ef Mon Sep 17 00:00:00 2001 From: diS3e Date: Mon, 3 Feb 2025 09:41:34 +0000 Subject: [PATCH 18/18] stylish-haskell auto-commit --- bench/BenchProve.hs | 2 +- haskell-wrapper/src/RustBLS.hs | 43 ++++++++++++++-------------- haskell-wrapper/src/RustFunctions.hs | 2 +- rust-wrapper/src/msm.rs | 13 ++------- 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/bench/BenchProve.hs b/bench/BenchProve.hs index b8cbc64..757159a 100644 --- a/bench/BenchProve.hs +++ b/bench/BenchProve.hs @@ -4,6 +4,7 @@ module Main where import qualified Data.ByteString as BS import GHC.Generics (U1 (U1)) import Prelude hiding (Num (..), length, sum, take, (-)) +import RustBLS (RustBLS12_381_G1, RustBLS12_381_G2) import RustFunctions (RustCore) import Test.QuickCheck (Arbitrary (arbitrary), generate) import Test.QuickCheck.Arbitrary (Arbitrary1 (liftArbitrary)) @@ -13,7 +14,6 @@ import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 (BLS12_381_G1, BLS1 import ZkFold.Base.Data.Vector (Vector) import ZkFold.Base.Protocol.NonInteractiveProof import ZkFold.Base.Protocol.Plonk (Plonk) -import RustBLS (RustBLS12_381_G1, RustBLS12_381_G2) type PlonkBS n = Plonk U1 (Vector 1) 32 (Vector n) RustBLS12_381_G1 RustBLS12_381_G2 BS.ByteString diff --git a/haskell-wrapper/src/RustBLS.hs b/haskell-wrapper/src/RustBLS.hs index f1183a7..de939a7 100644 --- a/haskell-wrapper/src/RustBLS.hs +++ b/haskell-wrapper/src/RustBLS.hs @@ -1,43 +1,44 @@ {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DerivingVia #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} module RustBLS where -import Control.DeepSeq (NFData) +import Control.DeepSeq (NFData) import Control.Monad import Data.Bits -import Data.Foldable hiding (sum) +import Data.Foldable hiding (sum) +import Data.Kind (Type) +import qualified Data.Vector as V import Data.Word -import GHC.Generics (Generic) -import Prelude hiding (Num (..), (/), (^), Eq, sum) - -import ZkFold.Base.Algebra.Basic.Class hiding (sum) +import GHC.Generics (Generic) +import Prelude hiding (Eq, Num (..), sum, (/), (^)) +import qualified Prelude as P +import RustFunctions (RustCore, both, rustDivFft, rustMulFft, + rustMulPoint, + rustMultiScalarMultiplicationWithoutSerialization) +import Test.QuickCheck (Arbitrary) + +import ZkFold.Base.Algebra.Basic.Class (sum) +import ZkFold.Base.Algebra.Basic.Class hiding (sum) import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.Basic.Number +import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 hiding (Fq, Fr) import ZkFold.Base.Algebra.EllipticCurve.Class import ZkFold.Base.Algebra.EllipticCurve.Pairing import ZkFold.Base.Algebra.Polynomials.Univariate import ZkFold.Base.Data.ByteString -import qualified Data.Vector as V - -import RustFunctions (rustMulFft, rustMulPoint, rustMultiScalarMultiplicationWithoutSerialization, RustCore, both, rustDivFft) -import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 hiding (Fq, Fr) -import ZkFold.Symbolic.Data.Bool (BoolType) -import ZkFold.Symbolic.Data.Conditional (Conditional) -import Data.Kind (Type) -import qualified Prelude as P -import ZkFold.Symbolic.Data.Eq (Eq) -import Test.QuickCheck (Arbitrary) -import ZkFold.Base.Protocol.NonInteractiveProof.Internal (HaskellCore, CoreFunction(..)) -import ZkFold.Base.Algebra.Basic.Class (sum) +import ZkFold.Base.Protocol.NonInteractiveProof.Internal (CoreFunction (..), HaskellCore) +import ZkFold.Symbolic.Data.Bool (BoolType) +import ZkFold.Symbolic.Data.Conditional (Conditional) +import ZkFold.Symbolic.Data.Eq (Eq) type Fr = Zp BLS12_381_Scalar type Fq = Zp BLS12_381_Base @@ -276,4 +277,4 @@ instance {-# OVERLAPPING #-} CoreFunction RustBLS12_381_G1 RustCore where zipAndUnzip _ _ = ([],[]) polyMul x y = toPoly (rustMulFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y)) - polyQr x y = both toPoly $ rustDivFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y) \ No newline at end of file + polyQr x y = both toPoly $ rustDivFft @(ScalarField BLS12_381_G1) (fromPoly x) (fromPoly y) diff --git a/haskell-wrapper/src/RustFunctions.hs b/haskell-wrapper/src/RustFunctions.hs index f234e55..cec7b11 100644 --- a/haskell-wrapper/src/RustFunctions.hs +++ b/haskell-wrapper/src/RustFunctions.hs @@ -32,7 +32,7 @@ import ZkFold.Base.Algebra.Basic.Class (AdditiveMonoid (ze import ZkFold.Base.Algebra.Basic.Field import ZkFold.Base.Algebra.EllipticCurve.BLS12_381 import ZkFold.Base.Algebra.EllipticCurve.Class -import ZkFold.Base.Algebra.Polynomials.Univariate (fromPoly, fromPolyVec, toPoly, qr) +import ZkFold.Base.Algebra.Polynomials.Univariate (fromPoly, fromPolyVec, qr, toPoly) import ZkFold.Base.Data.ByteString import ZkFold.Base.Protocol.NonInteractiveProof (CoreFunction (..), msm) diff --git a/rust-wrapper/src/msm.rs b/rust-wrapper/src/msm.rs index cf43e56..99b2d44 100644 --- a/rust-wrapper/src/msm.rs +++ b/rust-wrapper/src/msm.rs @@ -1,12 +1,12 @@ use std::slice; +use ark_bls12_381::Fr as ScalarField; use ark_bls12_381::G1Affine as GAffine; use ark_ec::CurveGroup; use ark_ff::PrimeField; use ark_msm::msm::VariableBaseMSM; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::log2; -use ark_bls12_381::Fr as ScalarField; use crate::utils::{deserialize_vector_points, deserialize_vector_scalar_field}; @@ -23,10 +23,7 @@ const fn get_opt_window_size(k: u32) -> u32 { } } -pub fn msm( - scalar_buffer: &[u8], - point_buffer: &[u8], -) -> Vec { +pub fn msm(scalar_buffer: &[u8], point_buffer: &[u8]) -> Vec { let scalars: Vec<_> = deserialize_vector_scalar_field(scalar_buffer) .iter() .map(|i| i.into_bigint()) @@ -50,10 +47,7 @@ pub fn msm( res } -pub fn mul( - scalar_buffer: &[u8], - point_buffer: &[u8], -) -> Vec { +pub fn mul(scalar_buffer: &[u8], point_buffer: &[u8]) -> Vec { let scalar: ScalarField = PrimeField::from_le_bytes_mod_order(scalar_buffer); let mut bytes: Vec = point_buffer.to_vec(); @@ -90,7 +84,6 @@ pub unsafe extern "C" fn rust_wrapper_msm( std::ptr::copy(res.as_ptr(), out as *mut u8, out_len); } - /// /// # Safety /// The caller must ensure that valid pointers and sizes are passed.