Skip to content

Commit

Permalink
fixup: Clean up asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
robot-dreams committed Apr 5, 2022
1 parent c88245f commit be9becb
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions doc/musig-reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import secrets
import time

# WARNING: Implementers should be aware that some inputs could
# trigger assertion errors, and proceed with caution. For example,
# an assertion error raised in one of the functions below should not
# cause a server process to crash.

#
# The following helper functions were copied from the BIP-340 reference implementation:
# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py
Expand Down Expand Up @@ -115,14 +120,16 @@ def point_negate(P: Optional[Point]) -> Optional[Point]:

def pointc(x: bytes) -> Point:
P = lift_x(x[1:33])
assert P is not None
if P is None:
raise ValueError('x is not a valid compressed point.')
if x[0] == 2:
return P
elif x[0] == 3:
P = point_negate(P)
assert P is not None
return P
assert False
else:
raise ValueError('x is not a valid compressed point.')

def key_agg(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> bytes:
Q, _, _ = key_agg_internal(pubkeys, tweaks, is_xonly)
Expand All @@ -136,7 +143,8 @@ def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[b
P_i = lift_x(pubkeys[i])
a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2)
Q = point_add(Q, point_mul(P_i, a_i))
assert Q is not None
if Q is None:
raise ValueError('The aggregate public key cannot be infinity.')
gacc = 1
tacc = 0
v = len(tweaks)
Expand Down Expand Up @@ -165,14 +173,18 @@ def key_agg_coeff_internal(pubkeys: List[bytes], pk_: bytes, pk2: bytes) -> int:
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n

def apply_tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool) -> Tuple[Point, int, int]:
if len(tweak_i) != 32:
raise ValueError('The tweak must be a 32-byte array.')
if is_xonly_i and not has_even_y(Q):
g = n - 1
else:
g = 1
t_i = int_from_bytes(tweak_i)
assert t_i < n
if t_i >= n:
raise ValueError('The tweak must be less than n.')
Q_i = point_add(point_mul(Q, g), point_mul(G, t_i))
assert Q_i is not None
if Q_i is None:
raise ValueError('The result of tweaking cannot be infinity.')
gacc_i = g * gacc % n
tacc_i = (t_i + g * tacc) % n
return Q_i, gacc_i, tacc_i
Expand All @@ -193,16 +205,20 @@ def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -
return int_from_bytes(tagged_hash('MuSig/nonce', buf))

def nonce_gen(sk: bytes, aggpk: bytes, msg: bytes, extra_in: bytes) -> Tuple[bytes, bytes]:
assert len(sk) in (0, 32)
assert len(aggpk) in (0, 32)
assert len(msg) in (0, 32)
if len(sk) not in (0, 32):
raise ValueError('The optional byte array sk must have length 0 or 32.')
if len(aggpk) not in (0, 32):
raise ValueError('The optional byte array aggpk must have length 0 or 32.')
if len(msg) not in (0, 32):
raise ValueError('The optional byte array msg must have length 0 or 32.')
rand_ = secrets.token_bytes(32)
if len(sk) > 0:
rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_))
else:
rand = rand_
k_1 = nonce_hash(rand, aggpk, 1, msg, extra_in)
k_2 = nonce_hash(rand, aggpk, 2, msg, extra_in)
# k_1 == 0 or k_2 == 0 cannot occur except with negligible probability.
assert k_1 != 0
assert k_2 != 0
R_1_ = point_mul(G, k_1)
Expand Down Expand Up @@ -234,6 +250,7 @@ def get_session_values(session_ctx: SessionContext) -> tuple[Point, int, int, in
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
R = point_add(R_1, point_mul(R_2, b))
# The aggregate public nonce cannot be infinity except with negligible probability.
assert R is not None
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
return (Q, gacc_v, tacc_v, b, R, e)
Expand All @@ -247,12 +264,15 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
(Q, gacc_v, _, b, R, e) = get_session_values(session_ctx)
k_1_ = int_from_bytes(secnonce[0:32])
k_2_ = int_from_bytes(secnonce[32:64])
assert 0 < k_1_ < n
assert 0 < k_2_ < n
if not 0 < k_1_ < n:
raise ValueError('first secnonce value is out of range.')
if not 0 < k_2_ < n:
raise ValueError('second secnonce value is out of range.')
k_1 = k_1_ if has_even_y(R) else n - k_1_
k_2 = k_2_ if has_even_y(R) else n - k_2_
d_ = int_from_bytes(sk)
assert 0 < d_ < n
if not 0 < d_ < n:
raise ValueError('secret key value is out of range.')
P = point_mul(G, d_)
assert P is not None
a = get_session_key_agg_coeff(session_ctx, P)
Expand All @@ -266,6 +286,7 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
assert R_1_ is not None
assert R_2_ is not None
pubnonce = cbytes(R_1_) + cbytes(R_2_)
# Optional correctness check. The result of signing should pass signature verification.
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
return psig

Expand Down

0 comments on commit be9becb

Please sign in to comment.