Skip to content

Commit

Permalink
hermite normal form fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
KlausC committed Sep 10, 2024
1 parent 7f462d2 commit f723ec8
Showing 1 changed file with 121 additions and 9 deletions.
130 changes: 121 additions & 9 deletions src/determinant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ function det_MV(a::AbstractMatrix{D}) where D<:Union{Ring,Integer}
end


#TODO fix algorithm. normal forms to dedicated file
#TODO normal forms to dedicated file
"""
hermite_normal_form(A::AbstractMatrix; column_style=true, round=RoundUp) -> H, U
Expand All @@ -349,7 +349,11 @@ See [Wiki](https://en.wikipedia.org/wiki/Hermite_normal_form)
The algorithm was generalized to matrices of arbitrary rank and shape.
"""
function hermite_normal_form(a::AbstractMatrix{R}; column_style::Bool=true, round::RoundingMode=RoundUp) where R<:Union{Ring,Integer}
function hermite_normal_form(
a::AbstractMatrix{R};
column_style::Bool = true,
round::RoundingMode = RoundUp,
) where R<:Union{Ring,Integer}
m, n = size(a)

if column_style
Expand All @@ -362,7 +366,11 @@ function hermite_normal_form(a::AbstractMatrix{R}; column_style::Bool=true, roun
end
end

function hermite_normal_form!(a::AbstractMatrix{R}, u::AbstractMatrix{R}, round::RoundingMode) where R
function hermite_normal_form!(
a::AbstractMatrix{R},
u::AbstractMatrix{R},
round::RoundingMode,
) where R
m, n = size(a)
piv = something.(findfirst.(!iszero, eachcol(a)), m + 1)
n = something(findlast(x -> x <= m, piv), 0)
Expand All @@ -386,11 +394,7 @@ function hermite_normal_form!(a::AbstractMatrix{R}, u::AbstractMatrix{R}, round:
elseif pi == pj && pi <= m
ajj = a[pj, j]
aji = a[pj, i]
r, p, q = gcdx(ajj, aji)
@assert !iszero(r)
#@assert max(abs(p), abs(q)) <= max(abs(ajj), abs(aji))
pp = -div(aji, r)
qq = div(ajj, r)
r, q, p, qq, pp = gcdex(aji, ajj)
for k = 1:m
akj = a[k, j]
aki = a[k, i]
Expand All @@ -399,7 +403,7 @@ function hermite_normal_form!(a::AbstractMatrix{R}, u::AbstractMatrix{R}, round:
a[k, j] = bkj
a[k, i] = bki
end
for k = 1:n
for k = 1:size(u, 1)
akj = u[k, j]
aki = u[k, i]
bkj = akj * p + aki * q
Expand All @@ -418,6 +422,13 @@ function hermite_normal_form!(a::AbstractMatrix{R}, u::AbstractMatrix{R}, round:
a, u
end

function gcdex(a, b)
r, p, q = gcdx(a, b)
pp = -div(b, r)
qq = div(a, r)
r, p, q, pp, qq
end

function swap(a::AbstractMatrix, i, j)
m = size(a, 1)
for k = 1:m
Expand Down Expand Up @@ -448,3 +459,104 @@ end

Base.div(a::T, b::T, round::RoundingMode) where T<:ZZ = T(div(value(a), value(b), round))
Base.div(a::T, b::T, ::RoundingMode) where T<:Ring = div(a, b)

#
# u, v unimodular, d diagonal with mod(d[i+1,i+1], d[i,i]) == 0
# u * a * v = d
function smith_normal!(a::AbstractMatrix, u::AbstractMatrix, v::AbstractMatrix)
b = copy(a)
m, n = size(a)
m == size(u, 1) || throw(DimensionMismatch("left square matrix"))
n == size(v, 2) || throw(DimensionMismatch("right square matrix"))
i = 1
while i <= min(m, n)
A = view(a, i:m, i:n)
U = view(u, i:m, 1:size(u, 2))
V = view(v, 1:size(v, 1), i:n)

hermite_normal_form!(A, V, RoundUp)
swap_zero_rows!(A, U)
stop = false
while !stop
hermite_left!(A, U)
hermite_normal_form!(A, V, RoundUp)
stop = is_unit_column(A)
if stop
k = first_non_multiple_column(A)
stop = k == 0
if k > 1
A[:, 1] .+= A[:, k]
V[:, 1] .+= V[:, k]
end
end
end
i += 1
end
u, a, v
end

function hermite_left!(a, u)
m, n = size(a)
j = 1
for i = 2:m
aij = a[i, j]
iszero(aij) && continue
ajj = a[j, j]
g, q, p, qq, pp = gcdex(aij, ajj)
for k = 1:n
akj = a[j, k]
aki = a[i, k]
bkj = akj * p + aki * q
bki = akj * pp + aki * qq
a[j, k] = bkj
a[i, k] = bki
end
for k = 1:size(u, 2)
akj = u[j, k]
aki = u[i, k]
bkj = akj * p + aki * q
bki = akj * pp + aki * qq
u[j, k] = bkj
u[i, k] = bki
end
end
end

function swap_zero_rows!(A, U)
m, n = size(A)
k = 1
while k <= m && iszero(A[k, 1])
k += 1
end
if k > 1
for j = k:m
A[j-k+1, :] .= A[j, :]
A[j, :] .= 0
for i = 1:size(U, 2)
ukk = U[j-k+1, i]
U[j-k+1, i] = U[j, i]
U[j, i] = ukk
end
end
end
end

function is_unit_column(A)
iszero(A[2:size(A, 1), 1])
end

function isunit(a::Integer)
abs(a) == 1
end

function first_non_multiple_column(A)
m, n = size(A)
akk = A[1, 1]
for j = 1:n
for i = j:m
aij = A[i, j]
iszero(aij) || iszero(mod(aij, akk)) || return j
end
end
return 0
end

0 comments on commit f723ec8

Please sign in to comment.