Skip to content

Commit

Permalink
fix: add __radd__ etc to fmpz
Browse files Browse the repository at this point in the history
This is needed for Cython 3.x compatibility:

https://cython.readthedocs.io/en/latest/src/userguide/special_methods.html#arithmetic-methods

Similar methods should be added to all python-flint types that have
arithmetic operations.
  • Loading branch information
oscarbenjamin committed Dec 14, 2022
1 parent 7f44df2 commit 3aa4a76
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions src/flint/fmpz.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __radd__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
u = fmpz.__new__(fmpz)
fmpz_add((<fmpz>u).val, tval, (<fmpz>s).val)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __sub__(s, t):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
Expand All @@ -205,6 +216,17 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __rsub__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
u = fmpz.__new__(fmpz)
fmpz_sub((<fmpz>u).val, tval, (<fmpz>s).val)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __mul__(s, t):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
Expand All @@ -221,6 +243,17 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __rmul__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
u = fmpz.__new__(fmpz)
fmpz_mul((<fmpz>u).val, tval, (<fmpz>s).val)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __floordiv__(s, t):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
Expand All @@ -241,6 +274,21 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __rfloordiv__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
if fmpz_is_zero((<fmpz>s).val):
if ttype == FMPZ_TMP:
fmpz_clear(tval)
raise ZeroDivisionError("fmpz division by zero")
u = fmpz.__new__(fmpz)
fmpz_fdiv_q((<fmpz>u).val, tval, (<fmpz>s).val)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __mod__(s, t):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
Expand All @@ -261,6 +309,21 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __rmod__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
if fmpz_is_zero((<fmpz>s).val):
if ttype == FMPZ_TMP:
fmpz_clear(tval)
raise ZeroDivisionError("fmpz division by zero")
u = fmpz.__new__(fmpz)
fmpz_fdiv_r((<fmpz>u).val, tval, (<fmpz>s).val)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __divmod__(s, t):
cdef fmpz_struct sval[1]
cdef fmpz_struct tval[1]
Expand All @@ -283,6 +346,23 @@ cdef class fmpz(flint_scalar):
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __rdivmod__(s, t):
cdef fmpz_struct tval[1]
cdef int ttype = FMPZ_UNKNOWN
u = NotImplemented
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
if fmpz_is_zero((<fmpz>s).val):
if ttype == FMPZ_TMP:
fmpz_clear(tval)
raise ZeroDivisionError("fmpz division by zero")
u1 = fmpz.__new__(fmpz)
u2 = fmpz.__new__(fmpz)
fmpz_fdiv_qr((<fmpz>u1).val, (<fmpz>u2).val, tval, (<fmpz>s).val)
u = u1, u2
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def __pow__(s, t, m):
cdef fmpz_struct sval[1]
cdef int stype = FMPZ_UNKNOWN
Expand All @@ -298,6 +378,21 @@ cdef class fmpz(flint_scalar):
if stype == FMPZ_TMP: fmpz_clear(sval)
return u

def __rpow__(s, t, m):
cdef fmpz_struct tval[1]
cdef int stype = FMPZ_UNKNOWN
cdef ulong exp
u = NotImplemented
if m is not None:
raise NotImplementedError("modular exponentiation")
ttype = fmpz_set_any_ref(tval, t)
if ttype != FMPZ_UNKNOWN:
u = fmpz.__new__(fmpz)
s_ulong = fmpz_get_ui(s.val)
fmpz_pow_ui((<fmpz>u).val, tval, s_ulong)
if ttype == FMPZ_TMP: fmpz_clear(tval)
return u

def gcd(self, other):
"""
Returns the greatest common divisor of self and other.
Expand Down

0 comments on commit 3aa4a76

Please sign in to comment.