Skip to content

Commit

Permalink
musig-spec: Update reference implementation to match spec
Browse files Browse the repository at this point in the history
  • Loading branch information
robot-dreams committed Apr 4, 2022
1 parent 9265429 commit ff70d50
Showing 1 changed file with 164 additions and 62 deletions.
226 changes: 164 additions & 62 deletions doc/musig-reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,60 +102,90 @@ def pointc(x: bytes) -> Point:
return point_negate(P)
assert False

SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'msg'])

def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]:
(aggnonce, pubkeys, msg) = session_ctx
Q = key_agg_internal(pubkeys)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
R = point_add(R_1, point_mul(R_2, b))
assert not is_infinite(R)
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
return (Q, b, R, e)

def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int:
(_, pubkeys, _) = session_ctx
return key_agg_coeff(pubkeys, bytes_from_point(P))

def key_agg(pubkeys: List[bytes]) -> bytes:
Q = key_agg_internal(pubkeys)
def key_agg(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> bytes:
Q, _, _ = key_agg_internal(pubkeys, tweaks, is_xonly)
return bytes_from_point(Q)

def key_agg_internal(pubkeys: List[bytes]) -> Point:
def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool]) -> Point:
pk2 = get_second_key(pubkeys)
u = len(pubkeys)
Q = infinity
for i in range(u):
a_i = key_agg_coeff(pubkeys, pubkeys[i])
P_i = lift_x(pubkeys[i])
a_i = key_agg_coeff_(pubkeys, pubkeys[i], pk2)
Q = point_add(Q, point_mul(P_i, a_i))
assert not is_infinite(Q)
return Q
gacc = 1
tacc = 0
v = len(tweaks)
for i in range(v):
Q, gacc, tacc = tweak(Q, gacc, tacc, tweaks[i], is_xonly[i])
return Q, gacc, tacc

def hash_keys(pubkeys: List[bytes]) -> bytes:
return tagged_hash('KeyAgg list', b''.join(pubkeys))

def is_second(pubkeys: List[bytes], pk: bytes) -> bool:
def get_second_key(pubkeys: List[bytes]) -> bytes:
u = len(pubkeys)
for j in range(u):
for j in range(1, u):
if pubkeys[j] != pubkeys[0]:
return pubkeys[j] == pk
return False
return pubkeys[j]
return bytes_from_int(0)

def key_agg_coeff(pubkeys: List[bytes], pk: bytes) -> int:
if is_second(pubkeys, pk):
def key_agg_coeff(pubkeys: List[bytes], pk_: bytes) -> int:
pk2 = get_second_key(pubkeys)
return key_agg_coeff_(pubkeys, pk_, pk2)

def key_agg_coeff_(pubkeys: List[bytes], pk_: bytes, pk2: bytes) -> int:
L = hash_keys(pubkeys)
if pk_ == pk2:
return 1
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n

def tweak(Q: Point, gacc: int, tacc: int, tweak_i: bytes, is_xonly_i: bool) -> Tuple[Point, int, int]:
if is_xonly_i and not has_even_y(Q):
g = n - 1
else:
g = 1
t_i = int_from_bytes(tweak_i)
assert t_i < n
Q_i = point_add(point_mul(Q, g), point_mul(G, t_i))
assert not is_infinite(Q_i)
gacc_i = g * gacc % n
tacc_i = (t_i + g * tacc) % n
return Q_i, gacc_i, tacc_i

def bytes_xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b))

def nonce_hash(rand: bytes, aggpk: bytes, i: int, msg: bytes, extra_in: bytes) -> bytes:
buf = b''
buf += rand
buf += len(aggpk).to_bytes(1, 'big')
buf += aggpk
buf += i.to_bytes(1, 'big')
buf += len(msg).to_bytes(1, 'big')
buf += msg
buf += len(extra_in).to_bytes(4, 'big')
buf += extra_in
return int_from_bytes(tagged_hash('MuSig/nonce', buf))

def nonce_gen(sk: bytes, aggpk: bytes, msg: bytes, extra_in: bytes) -> Tuple[bytes, bytes]:
assert len(sk) in (0, 32)
assert len(aggpk) in (0, 32)
assert len(msg) in (0, 32)
rand_ = secrets.token_bytes(32)
if len(sk) > 0:
rand = bytes_xor(sk, tagged_hash('MuSig/aux', rand_))
else:
L = hash_keys(pubkeys)
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk)) % n

def nonce_gen() -> Tuple[bytes, bytes]:
k_1 = 1 + secrets.randbelow(n - 2)
k_2 = 1 + secrets.randbelow(n - 2)
R_1 = point_mul(G, k_1)
R_2 = point_mul(G, k_2)
pubnonce = cbytes(R_1) + cbytes(R_2)
rand = rand_
k_1 = nonce_hash(rand, aggpk, 1, msg, extra_in)
k_2 = nonce_hash(rand, aggpk, 2, msg, extra_in)
assert k_1 != 0
assert k_2 != 0
R_1_ = point_mul(G, k_1)
R_2_ = point_mul(G, k_2)
pubnonce = cbytes(R_1_) + cbytes(R_2_)
secnonce = bytes_from_int(k_1) + bytes_from_int(k_2)
return secnonce, pubnonce

Expand All @@ -170,8 +200,25 @@ def nonce_agg(pubnonces: List[bytes]) -> bytes:
aggnonce += cbytes(R_i)
return aggnonce

SessionContext = namedtuple('SessionContext', ['aggnonce', 'pubkeys', 'tweaks', 'is_xonly', 'msg'])

def get_session_values(session_ctx: SessionContext) -> tuple[bytes, List[bytes], bytes]:
(aggnonce, pubkeys, tweaks, is_xonly, msg) = session_ctx
Q, gacc_v, tacc_v = key_agg_internal(pubkeys, tweaks, is_xonly)
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n
R_1 = pointc(aggnonce[0:33])
R_2 = pointc(aggnonce[33:66])
R = point_add(R_1, point_mul(R_2, b))
assert not is_infinite(R)
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n
return (Q, gacc_v, tacc_v, b, R, e)

def get_session_key_agg_coeff(session_ctx: SessionContext, P: Point) -> int:
(_, pubkeys, _, _, _) = session_ctx
return key_agg_coeff(pubkeys, bytes_from_point(P))

def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
(Q, b, R, e) = get_session_values(session_ctx)
(Q, gacc_v, _, b, R, e) = get_session_values(session_ctx)
k_1_ = int_from_bytes(secnonce[0:32])
k_2_ = int_from_bytes(secnonce[32:64])
assert 0 < k_1_ < n
Expand All @@ -181,31 +228,34 @@ def sign(secnonce: bytes, sk: bytes, session_ctx: SessionContext) -> bytes:
d_ = int_from_bytes(sk)
assert 0 < d_ < n
P = point_mul(G, d_)
mu = get_session_key_agg_coeff(session_ctx, P)
d = n - d_ if has_even_y(P) != has_even_y(Q) else d_
s = (k_1 + b * k_2 + e * mu * d) % n
a = get_session_key_agg_coeff(session_ctx, P)
gp = 1 if has_even_y(P) else n - 1
g_v = 1 if has_even_y(Q) else n - 1
d = g_v * gacc_v * gp * d_
s = (k_1 + b * k_2 + e * a * d) % n
psig = bytes_from_int(s)
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_))
assert partial_sig_verify_internal(psig, pubnonce, bytes_from_point(P), session_ctx)
return psig

def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool:
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[bool], msg: bytes, i: int) -> bool:
aggnonce = nonce_agg(pubnonces)
session_ctx = SessionContext(aggnonce, pubkeys, msg)
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
return partial_sig_verify_internal(psig, pubnonces[i], pubkeys[i], session_ctx)

def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk: bytes, session_ctx: SessionContext) -> bool:
(Q, b, R, e) = get_session_values(session_ctx)
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, pk_: bytes, session_ctx: SessionContext) -> bool:
(Q, gacc_v, _, b, R, e) = get_session_values(session_ctx)
s = int_from_bytes(psig)
assert s < n
R_1_ = pointc(pubnonce[0:33])
R_2_ = pointc(pubnonce[33:66])
R__ = point_add(R_1_, point_mul(R_2_, b))
R_ = R__ if has_even_y(R) else point_negate(R__)
P_ = lift_x(pk)
P = P_ if has_even_y(Q) else point_negate(P_)
mu = get_session_key_agg_coeff(session_ctx, P)
return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n))
g_v = 1 if has_even_y(Q) else n - 1
g_ = g_v * gacc_v % n
P = point_mul(lift_x(pk_), g_)
a = get_session_key_agg_coeff(session_ctx, P)
return point_mul(G, s) == point_add(R_, point_mul(P, e * a % n))

#
# The following code is only used for testing.
Expand All @@ -230,10 +280,10 @@ def test_key_agg_vectors():
'2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B',
])

assert key_agg([X[0], X[1], X[2]]) == expected[0]
assert key_agg([X[2], X[1], X[0]]) == expected[1]
assert key_agg([X[0], X[0], X[0]]) == expected[2]
assert key_agg([X[0], X[0], X[1], X[1]]) == expected[3]
assert key_agg([X[0], X[1], X[2]], [], []) == expected[0]
assert key_agg([X[2], X[1], X[0]], [], []) == expected[1]
assert key_agg([X[0], X[0], X[0]], [], []) == expected[2]
assert key_agg([X[0], X[0], X[1], X[1]], [], []) == expected[3]

def test_sign_vectors():
X = fromhex_all([
Expand All @@ -260,13 +310,13 @@ def test_sign_vectors():

pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))

session_ctx = SessionContext(aggnonce, [pk, X[0], X[1]], msg)
session_ctx = SessionContext(aggnonce, [pk, X[0], X[1]], [], [], msg)
assert sign(secnonce, sk, session_ctx) == expected[0]

session_ctx = SessionContext(aggnonce, [X[0], pk, X[1]], msg)
session_ctx = SessionContext(aggnonce, [X[0], pk, X[1]], [], [], msg)
assert sign(secnonce, sk, session_ctx) == expected[1]

session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], msg)
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], [], [], msg)
assert sign(secnonce, sk, session_ctx) == expected[2]

def test_sign_and_verify_random(iters):
Expand All @@ -277,24 +327,76 @@ def test_sign_and_verify_random(iters):
pk_2 = bytes_from_point(point_mul(G, int_from_bytes(sk_2)))
pubkeys = [pk_1, pk_2]

secnonce_1, pubnonce_1 = nonce_gen()
secnonce_2, pubnonce_2 = nonce_gen()
secnonce_1, pubnonce_1 = nonce_gen(sk_1, b'', b'', b'')
secnonce_2, pubnonce_2 = nonce_gen(sk_2, b'', b'', b'')
pubnonces = [pubnonce_1, pubnonce_2]
aggnonce = nonce_agg(pubnonces)
tweaks = []
is_xonly = []

msg = secrets.token_bytes(32)

session_ctx = SessionContext(aggnonce, pubkeys, msg)
session_ctx = SessionContext(aggnonce, pubkeys, tweaks, is_xonly, msg)
psig = sign(secnonce_1, sk_1, session_ctx)
assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0)
assert partial_sig_verify(psig, pubnonces, pubkeys, tweaks, is_xonly, msg, 0)

# Wrong signer index
assert not partial_sig_verify(psig, pubnonces, pubkeys, msg, 1)
assert not partial_sig_verify(psig, pubnonces, pubkeys, tweaks, is_xonly, msg, 1)

# Wrong message
assert not partial_sig_verify(psig, pubnonces, pubkeys, secrets.token_bytes(32), 0)
assert not partial_sig_verify(psig, pubnonces, pubkeys, tweaks, is_xonly, secrets.token_bytes(32), 0)

def test_tweak_vectors():
X = fromhex_all([
'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9',
'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659',
])

secnonce = bytes.fromhex(
'508B81A611F100A6B2B6B29656590898AF488BCF2E1F55CF22E5CFB84421FE61' +
'FA27FD49B1D50085B481285E1CA205D55C82CC1B31FF5CD54A489829355901F7')

aggnonce = bytes.fromhex(
'028465FCF0BBDBCF443AABCCE533D42B4B5A10966AC09A49655E8C42DAAB8FCD61' +
'037496A3CC86926D452CAFCFD55D25972CA1675D549310DE296BFF42F72EEEA8C9')

sk = bytes.fromhex('7FB9E0E687ADA1EEBF7ECFE2F21E73EBDB51A7D450948DFE8D76D7F2D1007671')
msg = bytes.fromhex('F95466D086770E689964664219266FE5ED215C92AE20BAB5C9D79ADDDDF3C0CF')

tweaks = fromhex_all([
'E8F791FF9225A2AF0102AFFF4A9A723D9612A682A25EBE79802B263CDFCD83BB',
'AE2EA797CC0FE72AC5B97B97F3C6957D7E4199A167A58EB08BCAFFDA70AC0455',
'F52ECBC565B3D8BEA2DFD5B75A4F457E54369809322E4120831626F290FA87E0',
'1969AD73CC177FA0B4FCED6DF1F7BF9907E665FDE9BA196A74FED0A3CF5AEF9D',
])

expected = fromhex_all([
'5E24C7496B565DEBC3B9639E6F1304A21597F9603D3AB05B4913641775E1375B',
'78408DDCAB4813D1394C97D493EF1084195C1D4B52E63ECD7BC5991644E44DDD',
'C3A829A81480E36EC3AB052964509A94EBF34210403D16B226A6F16EC85B7357',
'8C4473C6A382BD3C4AD7BE59818DA5ED7CF8CEC4BC21996CFDA08BB4316B8BC7',
])

pk = bytes_from_point(point_mul(G, int_from_bytes(sk)))

# A single x-only tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [True], msg)
assert sign(secnonce, sk, session_ctx) == expected[0]

# A single ordinary tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:1], [False], msg)
assert sign(secnonce, sk, session_ctx) == expected[1]

# An ordinary tweak followed by an x-only tweak
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:2], [False, True], msg)
assert sign(secnonce, sk, session_ctx) == expected[2]

# Four tweaks: x-only, ordinary, x-only, ordinary
session_ctx = SessionContext(aggnonce, [X[0], X[1], pk], tweaks[:4], [True, False, True, False], msg)
assert sign(secnonce, sk, session_ctx) == expected[3]

if __name__ == '__main__':
test_key_agg_vectors()
test_sign_vectors()
test_sign_and_verify_random(4)
test_tweak_vectors()

0 comments on commit ff70d50

Please sign in to comment.