Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement an iterative FFT algorithm #21

Merged
merged 1 commit into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ thiserror = "1.0"

[dev-dependencies]
assert_matches = "1.5.0"
criterion = "0.3"
modinverse = "0.1.0"
num-bigint = "0.4.0"

[[bench]]
name = "fft"
harness = false

[[example]]
name = "sum"
33 changes: 33 additions & 0 deletions benches/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MPL-2.0

use criterion::{criterion_group, criterion_main, Criterion};

use prio::benchmarked::{benchmarked_iterative_fft, benchmarked_recursive_fft};
use prio::finite_field::{Field, FieldElement};

pub fn fft(c: &mut Criterion) {
let test_sizes = [16, 256, 1024, 4096];
for size in test_sizes.iter() {
let mut rng = rand::thread_rng();
let mut inp = vec![Field::zero(); *size];
let mut outp = vec![Field::zero(); *size];
for i in 0..*size {
inp[i] = Field::rand(&mut rng);
}

c.bench_function(&format!("iterative/{}", *size), |b| {
b.iter(|| {
benchmarked_iterative_fft(&mut outp, &inp);
})
});

c.bench_function(&format!("recursive/{}", *size), |b| {
b.iter(|| {
benchmarked_recursive_fft(&mut outp, &inp);
})
});
}
}

criterion_group!(benches, fft);
criterion_main!(benches);
26 changes: 26 additions & 0 deletions src/benchmarked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MPL-2.0

//! This package provides wrappers around internal components of this crate that we want to
//! benchmark, but which we don't want to expose in the public API.

use crate::fft::discrete_fourier_transform;
use crate::finite_field::{Field, FieldElement};
use crate::polynomial::{poly_fft, PolyAuxMemory};

/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm.
pub fn benchmarked_iterative_fft<F: FieldElement>(outp: &mut [F], inp: &[F]) {
discrete_fourier_transform(outp, inp).expect("encountered unexpected error");
}

/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm.
pub fn benchmarked_recursive_fft(outp: &mut [Field], inp: &[Field]) {
let mut mem = PolyAuxMemory::new(inp.len() / 2);
poly_fft(
outp,
inp,
&mem.roots_2n,
inp.len(),
false,
&mut mem.fft_memory,
)
}
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl Client {
pub fn new(dimension: usize, public_key1: PublicKey, public_key2: PublicKey) -> Option<Self> {
let n = (dimension + 1).next_power_of_two();

if 2 * n > Field::num_roots() as usize {
if 2 * n > Field::generator_order() as usize {
// too many elements for this field, not enough roots of unity
return None;
}
Expand Down
172 changes: 172 additions & 0 deletions src/fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MPL-2.0

//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier
//! Transform (DFT) over a slice of field elements.
use crate::finite_field::FieldElement;
use crate::fp::{log2, MAX_ROOTS};

use std::convert::TryFrom;

/// An error returned by DFT or DFT inverse computation.
#[derive(Debug, thiserror::Error)]
pub enum FftError {
/// The output is too small.
#[error("output slice is smaller than the input")]
OutputTooSmall,
/// The input is too large.
#[error("input slice is larger than than maximum permitted")]
InputTooLarge,
/// The input length is not a power of 2.
#[error("input size is not a power of 2")]
InputSizeInvalid,
}

/// Sets `outp` to the DFT of `inp`.
pub fn discrete_fourier_transform<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
let n = inp.len();
let d = usize::try_from(log2(n as u128)).unwrap();

if n > outp.len() {
return Err(FftError::OutputTooSmall);
}

if n > 1 << MAX_ROOTS {
return Err(FftError::InputTooLarge);
}

if n != 1 << d {
return Err(FftError::InputSizeInvalid);
}

for i in 0..n {
outp[i] = inp[bitrev(d, i)];
cjpatton marked this conversation as resolved.
Show resolved Hide resolved
}

let mut w: F;
for l in 1..d + 1 {
w = F::root(0).unwrap(); // one
let r = F::root(l).unwrap();
Comment on lines +51 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
w = F::root(0).unwrap(); // one
let r = F::root(l).unwrap();
let mut w = F::root(0)?; // one
let r = F::root(l)?;

Couple thoughts here:

  1. I think it's more clear to declare w here rather than outside the for loop.
  2. We should gracefully propagate errors rather than panicking. This function should return Result<(), FftError>, or Result<[F], FftError> if you end up deciding to return the result rather than mutating a passed-in buffer. Naturally this means you'll have to define an FftError enum in this module, akin to encrypt::EncryptError and finite_field::FiniteFieldError.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This answers the question above. I'll propagate an error here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I need to do .ok_or_else(< some FftError>)?.

However, I think panicking is more appropriate here. These errors don't depends on the input, they only happen if root() isn't implemented properly.

let y = 1 << (l - 1);
for i in 0..y {
for j in 0..(n / y) >> 1 {
let x = (1 << l) * j + i;
let u = outp[x];
let v = w * outp[x + y];
outp[x] = u + v;
outp[x + y] = u - v;
}
w *= r;
}
}

Ok(())
}

/// Sets `outp` to the inverse of the DFT of `inp`.
#[allow(dead_code)]
pub fn discrete_fourier_transform_inv<F: FieldElement>(
outp: &mut [F],
inp: &[F],
) -> Result<(), FftError> {
discrete_fourier_transform(outp, inp)?;
let n = inp.len();
let m = F::from(F::Integer::try_from(n).unwrap()).inv();
let mut tmp: F;

outp[0] *= m;
outp[n >> 1] *= m;
for i in 1..n >> 1 {
tmp = outp[i] * m;
outp[i] = outp[n - i] * m;
outp[n - i] = tmp;
}

Ok(())
}

// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109)
fn bitrev(d: usize, x: usize) -> usize {
let mut y = 0;
for i in 0..d {
y += ((x >> i) & 1) << (d - i);
}
y >> 1
}

#[cfg(test)]
mod tests {
use super::*;
use crate::finite_field::{Field, Field126, Field64, Field80};
use crate::polynomial::{poly_fft, PolyAuxMemory};

fn discrete_fourier_transform_then_inv_test<F: FieldElement>() -> Result<(), FftError> {
let mut rng = rand::thread_rng();
let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048];

for size in test_sizes.iter() {
let mut want = vec![F::zero(); *size];
let mut tmp = vec![F::zero(); *size];
let mut got = vec![F::zero(); *size];
for i in 0..*size {
want[i] = F::rand(&mut rng);
}

discrete_fourier_transform(&mut tmp, &want)?;
discrete_fourier_transform_inv(&mut got, &tmp)?;
assert_eq!(got, want);
}

Ok(())
}

#[test]
fn test_field32() {
discrete_fourier_transform_then_inv_test::<Field>().expect("unexpected error");
}

#[test]
fn test_field64() {
discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error");
}

#[test]
fn test_field80() {
discrete_fourier_transform_then_inv_test::<Field80>().expect("unexpected error");
}

#[test]
fn test_field126() {
discrete_fourier_transform_then_inv_test::<Field126>().expect("unexpected error");
}

#[test]
fn test_recursive_fft() {
let size = 128;
let mut rng = rand::thread_rng();
let mut mem = PolyAuxMemory::new(size / 2);

let mut inp = vec![Field::zero(); size];
let mut want = vec![Field::zero(); size];
let mut got = vec![Field::zero(); size];
for i in 0..size {
inp[i] = Field::rand(&mut rng);
}

discrete_fourier_transform::<Field>(&mut want, &inp).expect("unexpected error");

poly_fft(
&mut got,
&inp,
&mem.roots_2n,
size,
false,
&mut mem.fft_memory,
);

assert_eq!(got, want);
}
}
67 changes: 53 additions & 14 deletions src/finite_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
use crate::fp::{FP126, FP32, FP64, FP80};
use std::{
cmp::min,
convert::TryFrom,
fmt::{Display, Formatter},
fmt::{Debug, Display, Formatter},
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};

use rand::Rng;

/// Possible errors from finite field operations.
#[derive(Debug, thiserror::Error)]
pub enum FiniteFieldError {
Expand All @@ -21,36 +24,55 @@ pub enum FiniteFieldError {
/// Objects with this trait represent an element of `GF(p)` for some prime `p`.
pub trait FieldElement:
Sized
+ Debug
+ Copy
+ PartialEq
+ Eq
+ Add
+ Add<Output = Self>
+ AddAssign
+ Sub
+ Sub<Output = Self>
+ SubAssign
+ Mul
+ Mul<Output = Self>
+ MulAssign
+ Div
+ Div<Output = Self>
+ DivAssign
+ Neg
+ Neg<Output = Self>
+ Display
+ From<<Self as FieldElement>::Integer>
{
/// The error returned if converting `usize` to an `Int` fails.
type IntegerTryFromError: std::fmt::Debug;

/// The integer representation of the field element.
type Integer;
type Integer: Copy
+ Debug
+ Sub<Output = <Self as FieldElement>::Integer>
+ TryFrom<usize, Error = Self::IntegerTryFromError>;

/// Modular exponentation, i.e., `self^exp (mod p)`.
fn pow(&self, exp: Self) -> Self;
fn pow(&self, exp: Self) -> Self; // TODO(cjpatton) exp should have type Self::Integer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Because I might want to raise self to some exponent that's greater than the field prime?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if one would ever want to do that. The reason is more mathematically pedantic: Raising an element of a group (or field) to the power of an element of that group (or field) is not always well-defined. In general, a^x only makes sense if x is an integer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that I plan to solve this TODO in the next PR.


/// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined.
fn inv(&self) -> Self;

/// Returns the prime modulus `p`.
fn modulus() -> Self::Integer;

/// Returns a generator of the multiplicative subgroup of size `FieldElement::num_roots()`.
/// Returns the size of the multiplicative subgroup generated by `generator()`.
fn generator_order() -> Self::Integer;

/// Returns the generator of the multiplicative subgroup of size `generator_order()`.
fn generator() -> Self;

/// Returns the size of the multiplicative subgroup generated by `FieldElement::generator()`.
fn num_roots() -> Self::Integer;
/// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th
/// prinicpal root of unity is 1 by definition.
fn root(l: usize) -> Option<Self>;

/// Returns a random field element distributed uniformly over all field elements.
fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self;

/// Returns the additive identity.
fn zero() -> Self;
}

macro_rules! make_field {
Expand Down Expand Up @@ -190,6 +212,7 @@ macro_rules! make_field {

impl FieldElement for $elem {
type Integer = $int;
type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error;

fn pow(&self, exp: Self) -> Self {
Self($fp.pow(self.0, $fp.from_elem(exp.0)))
Expand All @@ -207,8 +230,24 @@ macro_rules! make_field {
Self($fp.g)
}

fn num_roots() -> Self::Integer {
$fp.num_roots as Self::Integer
fn generator_order() -> Self::Integer {
1 << (Self::Integer::try_from($fp.num_roots).unwrap())
}

fn root(l: usize) -> Option<Self> {
if l < min($fp.roots.len(), $fp.num_roots+1) {
Some(Self($fp.roots[l]))
} else {
None
}
}

fn rand<R: Rng + ?Sized>(rng: &mut R) -> Self {
Self($fp.rand_elem(rng))
}

fn zero() -> Self {
Self(0)
}
}
};
Expand Down Expand Up @@ -245,7 +284,7 @@ make_field!(

#[test]
fn test_arithmetic() {
// TODO(cjpatton) Add tests for Field64, Field80, and Field126.
// TODO(cjpatton) Add tests for the other fields.
use rand::prelude::*;

let modulus = Field::modulus();
Expand Down
Loading