Skip to content

Commit

Permalink
Add implementation of Pallas
Browse files Browse the repository at this point in the history
  • Loading branch information
str4d committed Feb 25, 2021
1 parent 0db0553 commit e5cf685
Showing 1 changed file with 216 additions and 0 deletions.
216 changes: 216 additions & 0 deletions orchard_pallas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#!/usr/bin/env python3
import sys; assert sys.version_info[0] >= 3, "Python 3 required."

from sapling_jubjub import FieldElement

p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001

pm1d2 = 0x2000000000000000000000000000000011234c7e04a67c8dcc96987680000000
assert (p - 1) // 2 == pm1d2

S = 32
T = 0x40000000000000000000000000000000224698fc094cf91b992d30ed
assert (p - 1) == (1 << S) * T

tm1d2 = 0x2000000000000000000000000000000011234c7e04a67c8dcc969876
assert (T - 1) // 2 == tm1d2

ROOT_OF_UNITY = 0x2bce74deac30ebda362120830561f81aea322bf2b7bb7584bdad6fabd87ea32f


#
# Field arithmetic
#

class Fp(FieldElement):
@staticmethod
def from_bytes(buf):
return Fp(leos2ip(buf), strict=True)

def __init__(self, s, strict=False):
FieldElement.__init__(self, Fp, s, p, strict=strict)

def __str__(self):
return 'Fp(%s)' % self.s

def sqrt(self):
# Tonelli-Shank's algorithm for p mod 16 = 1
# https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
a = self.exp(pm1d2)
if a == self.ONE:
# z <- c^t
c = Fp(ROOT_OF_UNITY)
# x <- a \omega
x = self.exp(tm1d2 + 1)
# b <- x \omega = a \omega^2
b = self.exp(T)
y = S

# 7: while b != 1 do
while b != self.ONE:
# 8: Find least integer k >= 0 such that b^(2^k) == 1
k = 1
b2k = b * b
while b2k != self.ONE:
b2k = b2k * b2k
k += 1
assert k < y

# 9:
# w <- z^(2^(y-k-1))
for _ in range(0, y - k - 1):
c = c * c
# x <- xw
x = x * c
# z <- w^2
c = c * c
# b <- bz
b = b * c
# y <- k
y = k
assert x * x == self
return x
elif a == self.MINUS_ONE:
return None
return self.ZERO


class Scalar(FieldElement):
def __init__(self, s, strict=False):
FieldElement.__init__(self, Scalar, s, q, strict=strict)

def __str__(self):
return 'Scalar(%s)' % self.s

Fp.ZERO = Fp(0)
Fp.ONE = Fp(1)
Fp.MINUS_ONE = Fp(-1)

assert Fp.ZERO + Fp.ZERO == Fp.ZERO
assert Fp.ZERO + Fp.ONE == Fp.ONE
assert Fp.ONE + Fp.ZERO == Fp.ONE
assert Fp.ZERO - Fp.ONE == Fp.MINUS_ONE
assert Fp.ZERO * Fp.ONE == Fp.ZERO
assert Fp.ONE * Fp.ZERO == Fp.ZERO


#
# Point arithmetic
#

PALLAS_B = Fp(5)

class Point(object):
@staticmethod
def rand(rand):
while True:
data = rand.b(32)
p = Point.from_bytes(data)
if p is not None:
return p

@staticmethod
def from_bytes(buf):
assert len(buf) == 32
y_sign = buf[31] >> 7
buf = buf[:31] + bytes([buf[31] & 0b01111111])
try:
x = Fp.from_bytes(buf)
except ValueError:
return None

x3 = x * x * x
y2 = x3 + PALLAS_B

y = y2.sqrt()
if y is None:
return None

if y.s % 2 != y_sign:
y = Fp.ZERO - y

return Point(x, y)

def __init__(self, x, y):
self.x = x
self.y = y
self.is_identity = False

def identity():
p = Point(Fp.ZERO, Fp.ZERO)
p.is_identity = True
return p

def __add__(self, a):
if self.is_identity:
return a
elif a.is_identity:
return self
else:
(x1, y1) = (self.x, self.y)
(x2, y2) = (a.x, a.y)

if x1 == x2:
if y1 == y2 and y1 != Fp.ZERO:
return self.double()
return Point.identity()

s = (y2 - y1) / (x2 - x1)
x3 = s*s - x1 - x2
y3 = y1 + s*(x3 - x1)
return Point(x3, y3)

def __sub__(self, a):
if a.is_identity:
neg_a = a
else:
neg_a = Point(Fp(a.x.s), Fp.ZERO - Fp(a.y.s))
return self + neg_a

def double(self):
if self.is_identity:
return self

s = (Fp(3) * self.x * self.x) / (self.y + self.y)
x = s*s - self.x - self.x
y = self.y + s*(x - self.x)
return Point(x, y)

def __mul__(self, s):
s = format(s.s, '0256b')
ret = self.ZERO
for c in s:
ret = ret.double()
if int(c):
ret = ret + self
return ret

def __bytes__(self):
if self.is_identity:
return bytes([0] * 32)

buf = bytes(self.y)
if self.x.s % 2 == 1:
buf = buf[:31] + bytes([buf[31] | (1 << 7)])
return buf

def __eq__(self, a):
if not (self.is_identity or a.is_identity):
return self.x == a.x and self.y == a.y
else:
return self.is_identity == a.is_identity

def __str__(self):
if self.is_identity:
return 'Point(identity)'
else:
return 'Point(%s, %s)' % (self.x, self.y)


Point.ZERO = Point.identity()
Point.GENERATOR = Point(Fp.MINUS_ONE, Fp(2))

assert Point.ZERO + Point.ZERO == Point.ZERO
assert Point.GENERATOR - Point.GENERATOR == Point.ZERO
assert Point.GENERATOR + Point.GENERATOR + Point.GENERATOR == Point.GENERATOR * Scalar(3)

0 comments on commit e5cf685

Please sign in to comment.