Skip to content

Commit

Permalink
Include nmod_mpoly and fmpz_mod_mpoly in factor test
Browse files Browse the repository at this point in the history
  • Loading branch information
oscarbenjamin committed Aug 18, 2024
1 parent 1abfb62 commit 01e6418
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 32 deletions.
66 changes: 37 additions & 29 deletions src/flint/test/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -2481,41 +2481,41 @@ def test_division_matrix():
def _all_polys():
return [
# (poly_type, scalar_type, is_field)
(flint.fmpz_poly, flint.fmpz, False),
(flint.fmpq_poly, flint.fmpq, True),
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True),
(flint.fmpz_poly, flint.fmpz, False, flint.fmpz(0)),
(flint.fmpq_poly, flint.fmpq, True, flint.fmpz(0)),
(lambda *a: flint.nmod_poly(*a, 17), lambda x: flint.nmod(x, 17), True, flint.fmpz(17)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(163)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(163)),
True),
True, flint.fmpz(163)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**127 - 1)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**127 - 1)),
True),
True, flint.fmpz(2**127 - 1)),
(lambda *a: flint.fmpz_mod_poly(*a, flint.fmpz_mod_poly_ctx(2**255 - 19)),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(2**255 - 19)),
True),
True, flint.fmpz(2**255 - 19)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1)),
True),
True, flint.fmpz(2**127 - 1)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(2**127 - 1, 2)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(2**127 - 1, 2)),
True),
True, flint.fmpz(2**127 - 1)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(65537)),
True),
True, flint.fmpz(65537)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(65537, 5)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(65537, 5)),
True),
True, flint.fmpz(65537)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(11)),
True),
True, flint.fmpz(11)),
(lambda *a: flint.fq_default_poly(*a, flint.fq_default_poly_ctx(11, 5)),
lambda x: flint.fq_default(x, flint.fq_default_ctx(11, 5)),
True),
True, flint.fmpz(11)),
]


def test_polys():
for P, S, is_field in _all_polys():
for P, S, is_field, characteristic in _all_polys():
assert P([S(1)]) == P([1]) == P(P([1])) == P(1)

assert raises(lambda: P([None]), TypeError)
Expand Down Expand Up @@ -2750,37 +2750,41 @@ def setbad(obj, i, val):

def _all_mpolys():
return [
(flint.fmpz_mpoly, flint.fmpz_mpoly_ctx.get_context, flint.fmpz, False),
(flint.fmpq_mpoly, flint.fmpq_mpoly_ctx.get_context, flint.fmpq, True),
(flint.fmpz_mpoly, flint.fmpz_mpoly_ctx.get_context, flint.fmpz, False, flint.fmpz(0)),
(flint.fmpq_mpoly, flint.fmpq_mpoly_ctx.get_context, flint.fmpq, True, flint.fmpz(0)),
(
flint.fmpz_mod_mpoly,
lambda *args, **kwargs: flint.fmpz_mod_mpoly_ctx.get_context(*args, **kwargs, modulus=101),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(101)),
True,
flint.fmpz(101),
),
(
flint.fmpz_mod_mpoly,
lambda *args, **kwargs: flint.fmpz_mod_mpoly_ctx.get_context(*args, **kwargs, modulus=100),
lambda x: flint.fmpz_mod(x, flint.fmpz_mod_ctx(100)),
False,
flint.fmpz(100),
),
(
flint.nmod_mpoly,
lambda *args, **kwargs: flint.nmod_mpoly_ctx.get_context(*args, **kwargs, modulus=101),
lambda x: flint.nmod(x, 101),
True,
flint.fmpz(101),
),
(
flint.nmod_mpoly,
lambda *args, **kwargs: flint.nmod_mpoly_ctx.get_context(*args, **kwargs, modulus=100),
lambda x: flint.nmod(x, 100),
False,
flint.fmpz(100),
),
]


def test_mpolys():
for P, get_context, S, is_field in _all_mpolys():
for P, get_context, S, is_field, characteristic in _all_mpolys():

# Division under modulo will raise a flint exception if something is not invertible, crashing the program. We
# can't tell before what is invertible and what is not before hand so we always raise an exception, except for
Expand Down Expand Up @@ -3235,7 +3239,7 @@ def test_fmpz_mpoly_vec():

def _all_polys_mpolys():

for P, S, is_field in _all_polys():
for P, S, is_field, characteristic in _all_polys():
x = P([0, 1])
y = None
assert isinstance(x, (
Expand All @@ -3245,18 +3249,18 @@ def _all_polys_mpolys():
flint.fmpz_mod_poly,
flint.fq_default_poly,
))
characteristic_zero = isinstance(x, (flint.fmpz_poly, flint.fmpq_poly))
yield P, S, [x, y], is_field, characteristic_zero
yield P, S, [x, y], is_field, characteristic

for P, ctx_type, S, is_field in _all_mpolys():
ctx = ctx_type(2, flint.Ordering.lex, ["x", "y"])
for P, get_context, S, is_field, characteristic in _all_mpolys():
ctx = get_context(2, flint.Ordering.lex, nametup=("x", "y"))
x, y = ctx.gens()
assert isinstance(x, (
flint.fmpz_mpoly,
flint.fmpq_mpoly,
flint.nmod_mpoly,
flint.fmpz_mod_mpoly,
))
characteristic_zero = isinstance(x, (flint.fmpz_mpoly, flint.fmpq_mpoly))
yield P, S, [x, y], is_field, characteristic_zero
yield P, S, [x, y], is_field, characteristic


def test_factor_poly_mpoly():
Expand All @@ -3273,7 +3277,11 @@ def factor(p):
assert type(m) is int
return coeff, sorted(factors, key=lambda p: (p[1], str(p[0])))

for P, S, [x, y], is_field, characteristic_zero in _all_polys_mpolys():
for P, S, [x, y], is_field, characteristic in _all_polys_mpolys():

if characteristic != 0 and not characteristic.is_prime():
assert raises(lambda: x.factor(), DomainError)
continue

assert factor(0*x) == (S(0), [])
assert factor(0*x + 1) == (S(1), [])
Expand All @@ -3284,15 +3292,15 @@ def factor(p):
assert factor(x*(x+1)) == (S(1), [(x, 1), (x+1, 1)])
assert factor(2*(x+1)) == (S(2), [(x+1, 1)])

if characteristic_zero:
if characteristic == 0:
# primitive factors over Z for Z and Q.
assert factor(2*x+1) == (S(1), [(2*x+1, 1)])
else:
# monic factors over Z/pZ and GF(p^d)
assert factor(2*x+1) == (S(2), [(x+S(1)/2, 1)])

if is_field:
if characteristic_zero:
if characteristic == 0:
assert factor((2*x+1)/7) == (S(1)/7, [(2*x+1, 1)])
else:
assert factor((2*x+1)/7) == (S(2)/7, [(x+S(1)/2, 1)])
Expand All @@ -3302,13 +3310,13 @@ def factor(p):
assert factor(x*y+1) == (S(1), [(x*y+1, 1)])
assert factor(x*y) == (S(1), [(x, 1), (y, 1)])

if characteristic_zero:
if characteristic == 0:
assert factor(2*x + y) == (S(1), [(2*x + y, 1)])
else:
assert factor(2*x + y) == (S(1)/2, [(x + y/2, 1)])
assert factor(2*x + y) == (S(2), [(x + y/2, 1)])

if is_field:
if characteristic_zero:
if characteristic == 0:
assert factor((2*x+y)/7) == (S(1)/7, [(2*x+y, 1)])
else:
assert factor((2*x+y)/7) == (S(2)/7, [(x+y/2, 1)])
Expand Down
8 changes: 7 additions & 1 deletion src/flint/types/fmpz_mod_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,12 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
fmpz_mod_mpoly_total_degree_fmpz((<fmpz> res).val, self.val, self.ctx.val)
return res

def leading_coefficient(self):
if fmpz_mod_mpoly_is_zero(self.val, self.ctx.val):
return fmpz(0)
else:
return self.coefficient(0)

def repr(self):
return f"{self.ctx}.from_dict({self.to_dict()})"

Expand Down Expand Up @@ -846,7 +852,7 @@ cdef class fmpz_mod_mpoly(flint_mpoly):
c = fmpz.__new__(fmpz)
fmpz_set((<fmpz>c).val, &fac.exp[i])

res[i] = (u, c)
res[i] = (u, int(c))

c = fmpz.__new__(fmpz)
fmpz_set((<fmpz>c).val, fac.constant)
Expand Down
10 changes: 8 additions & 2 deletions src/flint/types/nmod_mpoly.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,12 @@ cdef class nmod_mpoly(flint_mpoly):
nmod_mpoly_total_degree_fmpz((<fmpz> res).val, self.val, self.ctx.val)
return res

def leading_coefficient(self):
if nmod_mpoly_is_zero(self.val, self.ctx.val):
return nmod(0, self.ctx.modulus())
else:
return nmod(self.coefficient(0), self.ctx.modulus())

def repr(self):
return f"{self.ctx}.from_dict({self.to_dict()})"

Expand Down Expand Up @@ -822,9 +828,9 @@ cdef class nmod_mpoly(flint_mpoly):
c = fmpz.__new__(fmpz)
fmpz_set((<fmpz>c).val, &fac.exp[i])

res[i] = (u, c)
res[i] = (u, int(c))

constant = fac.constant
constant = nmod(fac.constant, self.ctx.modulus())
nmod_mpoly_factor_clear(fac, self.ctx.val)
return constant, res

Expand Down

0 comments on commit 01e6418

Please sign in to comment.