Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefer cdef to def in mpoly operands #192

Merged
merged 2 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/flint/flint_base/flint_base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,37 @@ cdef class flint_mpoly_context(flint_elem):
cdef const char ** c_names

cdef class flint_mpoly(flint_elem):
pass
cdef _add_scalar_(self, other)
cdef _sub_scalar_(self, other)
cdef _mul_scalar_(self, other)

cdef _add_mpoly_(self, other)
cdef _sub_mpoly_(self, other)
cdef _mul_mpoly_(self, other)

cdef _divmod_mpoly_(self, other)
cdef _floordiv_mpoly_(self, other)
cdef _truediv_mpoly_(self, other)
cdef _mod_mpoly_(self, other)

cdef _rsub_scalar_(self, other)
cdef _rsub_mpoly_(self, other)

cdef _rdivmod_mpoly_(self, other)
cdef _rfloordiv_mpoly_(self, other)
cdef _rtruediv_mpoly_(self, other)
cdef _rmod_mpoly_(self, other)

cdef _pow_(self, other)

cdef _iadd_scalar_(self, other)
cdef _isub_scalar_(self, other)
cdef _imul_scalar_(self, other)

cdef _iadd_mpoly_(self, other)
cdef _isub_mpoly_(self, other)
cdef _imul_mpoly_(self, other)


cdef class flint_mat(flint_elem):
pass
Expand Down
218 changes: 131 additions & 87 deletions src/flint/flint_base/flint_base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ cdef class flint_scalar(flint_elem):
return self._invert_()



cdef class flint_poly(flint_elem):
"""
Base class for polynomials.
Expand Down Expand Up @@ -405,52 +404,73 @@ cdef class flint_mpoly(flint_elem):
if not other:
raise ZeroDivisionError("nmod_mpoly division by zero")

def _add_scalar_(self, other):
cdef _add_scalar_(self, other):
return NotImplemented

cdef _sub_scalar_(self, other):
return NotImplemented

cdef _mul_scalar_(self, other):
return NotImplemented

cdef _add_mpoly_(self, other):
return NotImplemented

cdef _sub_mpoly_(self, other):
return NotImplemented

cdef _mul_mpoly_(self, other):
return NotImplemented

def _add_mpoly_(self, other):
cdef _divmod_mpoly_(self, other):
return NotImplemented

def _iadd_scalar_(self, other):
cdef _floordiv_mpoly_(self, other):
return NotImplemented

def _iadd_mpoly_(self, other):
cdef _truediv_mpoly_(self, other):
return NotImplemented

def _sub_scalar_(self, other):
cdef _mod_mpoly_(self, other):
return NotImplemented

def _sub_mpoly_(self, other):
cdef _rsub_scalar_(self, other):
return NotImplemented

def _isub_scalar_(self, other):
cdef _rsub_mpoly_(self, other):
return NotImplemented

def _isub_mpoly_(self, other):
cdef _rdivmod_mpoly_(self, other):
return NotImplemented

def _mul_scalar_(self, other):
cdef _rfloordiv_mpoly_(self, other):
return NotImplemented

def _imul_mpoly_(self, other):
cdef _rtruediv_mpoly_(self, other):
return NotImplemented

def _imul_scalar_(self, other):
cdef _rmod_mpoly_(self, other):
return NotImplemented

def _mul_mpoly_(self, other):
cdef _pow_(self, other):
return NotImplemented

def _pow_(self, other):
cdef _iadd_scalar_(self, other):
return NotImplemented

def _divmod_mpoly_(self, other):
cdef _isub_scalar_(self, other):
return NotImplemented

def _floordiv_mpoly_(self, other):
cdef _imul_scalar_(self, other):
return NotImplemented

def _truediv_mpoly_(self, other):
cdef _iadd_mpoly_(self, other):
return NotImplemented

cdef _isub_mpoly_(self, other):
return NotImplemented

cdef _imul_mpoly_(self, other):
return NotImplemented

def __add__(self, other):
Expand All @@ -465,32 +485,15 @@ cdef class flint_mpoly(flint_elem):
return self._add_scalar_(other)

def __radd__(self, other):
return self.__add__(other)

def iadd(self, other):
"""
In-place addition, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return
return self._add_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._iadd_scalar_(other_scalar)
return self._add_scalar_(other)

def __sub__(self, other):
if typecheck(other, type(self)):
Expand All @@ -504,32 +507,15 @@ cdef class flint_mpoly(flint_elem):
return self._sub_scalar_(other)

def __rsub__(self, other):
return -self.__sub__(other)

def isub(self, other):
"""
In-place subtraction, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return
return self._rsub_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._isub_scalar_(other_scalar)
return self._rsub_scalar_(other)

def __mul__(self, other):
if typecheck(other, type(self)):
Expand All @@ -543,32 +529,15 @@ cdef class flint_mpoly(flint_elem):
return self._mul_scalar_(other)

def __rmul__(self, other):
return self.__mul__(other)

def imul(self, other):
"""
In-place multiplication, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return
return self._mul_mpoly_(other)

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")
other = self.context().any_as_scalar(other)
if other is NotImplemented:
return NotImplemented

self._imul_scalar_(other_scalar)
return self._mul_scalar_(other)

def __pow__(self, other, modulus):
if modulus is not None:
Expand Down Expand Up @@ -605,7 +574,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._divmod_mpoly_(self)
return self._rdivmod_mpoly_(other)

def __truediv__(self, other):
if typecheck(other, type(self)):
Expand All @@ -628,7 +597,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._truediv_mpoly_(self)
return self._rtruediv_mpoly_(other)

def __floordiv__(self, other):
if typecheck(other, type(self)):
Expand All @@ -651,7 +620,7 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._floordiv_mpoly_(self)
return self._rfloordiv_mpoly_(other)

def __mod__(self, other):
if typecheck(other, type(self)):
Expand All @@ -674,7 +643,82 @@ cdef class flint_mpoly(flint_elem):

other = self.context().scalar_as_mpoly(other)
other._division_check(self)
return other._mod_mpoly_(self)
return self._rmod_mpoly_(other)

def iadd(self, other):
"""
In-place addition, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.iadd(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 + 5

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._iadd_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot add {type(self)} and {type(other)}")

self._iadd_scalar_(other_scalar)

def isub(self, other):
"""
In-place subtraction, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.isub(5)
>>> f
4*x0*x1 + 2*x0 + 3*x1 - 5

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._isub_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot subtract {type(self)} and {type(other)}")

self._isub_scalar_(other_scalar)

def imul(self, other):
"""
In-place multiplication, mutates self.

>>> from flint import Ordering, fmpz_mpoly_ctx
>>> ctx = fmpz_mpoly_ctx.get_context(2, Ordering.lex, 'x')
>>> f = ctx.from_dict({(1, 0): 2, (0, 1): 3, (1, 1): 4})
>>> f
4*x0*x1 + 2*x0 + 3*x1
>>> f.imul(2)
>>> f
8*x0*x1 + 4*x0 + 6*x1

"""
if typecheck(other, type(self)):
self.context().compatible_context_check(other.context())
self._imul_mpoly_(other)
return

other_scalar = self.context().any_as_scalar(other)
if other_scalar is NotImplemented:
raise NotImplementedError(f"cannot multiply {type(self)} and {type(other)}")

self._imul_scalar_(other_scalar)

def __contains__(self, x):
"""
Expand Down
Loading
Loading