Skip to content

Commit

Permalink
Lift type restrictions (JuliaStats#146)
Browse files Browse the repository at this point in the history
* Relax transform and reconstruct types
  • Loading branch information
Kolaru authored Mar 13, 2021
1 parent 617b5bb commit d24eba5
Show file tree
Hide file tree
Showing 20 changed files with 187 additions and 82 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
docs/build/
Manifest.toml
.vscode
4 changes: 2 additions & 2 deletions src/cca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ correlations(M::CCA) = M.corrs

## use

xtransform(M::CCA, X::AbstractVecOrMat{T}) where T<:Real = transpose(M.xproj) * centralize(X, M.xmean)
ytransform(M::CCA, Y::AbstractVecOrMat{T}) where T<:Real = transpose(M.yproj) * centralize(Y, M.ymean)
"""Apply `centralize(X, M.xmean)` and left-multiply the result by `transpose(M.xproj)`."""
function xtransform(M::CCA, X::AbstractVecOrMat{<:Real})
    Pt = transpose(M.xproj)
    return Pt * centralize(X, M.xmean)
end
"""Apply `centralize(Y, M.ymean)` and left-multiply the result by `transpose(M.yproj)`."""
function ytransform(M::CCA, Y::AbstractVecOrMat{<:Real})
    Pt = transpose(M.yproj)
    return Pt * centralize(Y, M.ymean)
end

## show & dump

Expand Down
2 changes: 1 addition & 1 deletion src/cmds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ eigvals(M::MDS) = M.λ
## use

"""Calculate out-of-sample multidimensional scaling transformation"""
function transform(M::MDS{T}, x::AbstractVector{T}; distances=false) where {T<:Real}
function transform(M::MDS{T}, x::AbstractVector{<:Real}; distances=false) where {T<:Real}
d = if isnan(M.d) # model has only distance matrix
@assert distances "Cannot transform points if model was fitted with a distance matrix. Use point distances."
size(x, 1) != size(M.X, 1) && throw(
Expand Down
11 changes: 5 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ function calcscattermat(Z::DenseMatrix)
end

# calculate pairwise kernel
function pairwise!(K::AbstractVecOrMat{T}, kernel::Function,
X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:AbstractFloat
function pairwise!(K::AbstractVecOrMat{<:Real}, kernel::Function,
X::AbstractVecOrMat{<:Real}, Y::AbstractVecOrMat{<:Real})
n = size(X, 2)
m = size(Y, 2)
for j = 1:m
Expand All @@ -137,15 +137,14 @@ function pairwise!(K::AbstractVecOrMat{T}, kernel::Function,
K
end

pairwise!(K::AbstractVecOrMat{T}, kernel::Function, X::AbstractVecOrMat{T}) where T<:AbstractFloat =
"""Forward to the four-argument `pairwise!`, passing `X` as both data arguments."""
function pairwise!(K::AbstractVecOrMat{<:Real}, kernel::Function,
                   X::AbstractVecOrMat{<:Real})
    return pairwise!(K, kernel, X, X)
end

function pairwise(kernel::Function, X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:AbstractFloat
"""
Allocate an `n × m` result (`n = size(X, 2)`, `m = size(Y, 2)`) and fill it via
`pairwise!(K, kernel, X, Y)`.

The buffer's element type is `float(eltype(X))`: unchanged for floating-point `X`,
and wide enough to hold non-integer kernel values when `X` has an integer element type.
"""
function pairwise(kernel::Function, X::AbstractVecOrMat{<:Real},
                  Y::AbstractVecOrMat{<:Real})
    n = size(X, 2)
    m = size(Y, 2)
    # `similar(X, n, m)` would inherit an integer eltype from `X`, and storing a
    # real-valued kernel result into it would throw an InexactError.
    K = similar(X, float(eltype(X)), n, m)
    return pairwise!(K, kernel, X, Y)
end

pairwise(kernel::Function, X::AbstractVecOrMat{T}) where T<:AbstractFloat =
pairwise(kernel, X, X)
"""Forward to the three-argument `pairwise`, passing `X` as both data arguments."""
function pairwise(kernel::Function, X::AbstractVecOrMat{<:Real})
    return pairwise(kernel, X, X)
end
4 changes: 2 additions & 2 deletions src/fa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ loadings(M::FactorAnalysis) = M.W

## use

function transform(m::FactorAnalysis{T}, x::AbstractVecOrMat{T}) where T<:Real
"""
Compute `(I + W'Ψ⁻¹W)⁻¹ W'Ψ⁻¹ xn`, where `xn = centralize(x, mean(m))`, `W = m.W`
and `Ψ⁻¹` is the diagonal matrix holding the reciprocals of `m.Ψ`.
"""
function transform(m::FactorAnalysis, x::AbstractVecOrMat{<:Real})
    xn = centralize(x, mean(m))
    W = m.W
    # Scaling column j of W' by 1/Ψ[j] equals W' * diagm(0 => 1 ./ Ψ) without
    # materialising the dense d x d diagonal matrix.
    WᵀΨ⁻¹ = W' ./ transpose(m.Ψ)  # (q x d)
    # Solve the (q x q) system instead of forming an explicit inverse.
    return (I + WᵀΨ⁻¹ * W) \ (WᵀΨ⁻¹ * xn)  # (q x q) \ ((q x d) * (d x 1)) = (q x 1)
end

function reconstruct(m::FactorAnalysis{T}, z::AbstractVecOrMat{T}) where T<:Real
function reconstruct(m::FactorAnalysis, z::AbstractVecOrMat{<:Real})
W = m.W
# ΣW(W'W)⁻¹z+μ = ΣW(W'W)⁻¹W'Σ⁻¹(x-μ)+μ = Σ(WW⁻¹)((W')⁻¹W')Σ⁻¹(x-μ)+μ = ΣΣ⁻¹(x-μ)+μ = (x-μ)+μ = x
return cov(m)*W*inv(W'W)*z .+ mean(m)
Expand Down
2 changes: 1 addition & 1 deletion src/ica.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ indim(M::ICA) = size(M.W, 1)
outdim(M::ICA) = size(M.W, 2)
mean(M::ICA) = fullmean(indim(M), M.mean)

transform(M::ICA, x::AbstractVecOrMat) = transpose(M.W) * centralize(x, M.mean)
"""Apply `centralize(x, M.mean)` and left-multiply the result by `transpose(M.W)`."""
function transform(M::ICA, x::AbstractVecOrMat{<:Real})
    Wt = transpose(M.W)
    return Wt * centralize(x, M.mean)
end


#### core algorithm
Expand Down
8 changes: 4 additions & 4 deletions src/kpca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function fit(::Type{KernelCenter}, K::AbstractMatrix{T}) where {T<:Real}
end

"""Center kernel matrix."""
function transform!(C::KernelCenter{T}, K::AbstractMatrix{T}) where {T<:Real}
function transform!(C::KernelCenter, K::AbstractMatrix{<:Real})
r, c = size(K)
tot = C.total
means = mean(K, dims=1)
Expand Down Expand Up @@ -49,7 +49,7 @@ principalvars(M::KernelPCA) = M.λ
## use

"""Calculate transformation to kernel space"""
function transform(M::KernelPCA{T}, x::AbstractVecOrMat{T}) where {T<:Real}
function transform(M::KernelPCA, x::AbstractVecOrMat{<:Real})
k = pairwise(M.ker, M.X, x)
transform!(M.center, k)
return projection(M)'*k
Expand All @@ -58,7 +58,7 @@ end
transform(M::KernelPCA) = sqrt.(M.λ) .* M.α'

"""Calculate inverse transformation to original space"""
function reconstruct(M::KernelPCA{T}, y::AbstractVecOrMat{T}) where {T<:Real}
function reconstruct(M::KernelPCA, y::AbstractVecOrMat{<:Real})
if size(M.inv, 1) == 0
throw(ArgumentError("Inverse transformation coefficients are not available, set `inverse` parameter when fitting data"))
end
Expand All @@ -80,7 +80,7 @@ function fit(::Type{KernelPCA}, X::AbstractMatrix{T};
maxoutdim::Int = min(size(X)...),
remove_zero_eig::Bool = false, atol::Real = 1e-10,
solver::Symbol = :eig,
inverse::Bool = false, β::Real = 1.0,
inverse::Bool = false, β::Real = convert(T, 1.0),
tol::Real = 0.0, maxiter::Real = 300) where {T<:Real}

d, n = size(X)
Expand Down
2 changes: 1 addition & 1 deletion src/lda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ classweights(M::MulticlassLDA) = classweights(M.stats)
withclass_scatter(M::MulticlassLDA) = withclass_scatter(M.stats)
betweenclass_scatter(M::MulticlassLDA) = betweenclass_scatter(M.stats)

transform(M::MulticlassLDA, x::AbstractVecOrMat{T}) where T<:Real = M.proj'x
"""Left-multiply `x` by the adjoint of `M.proj`."""
function transform(M::MulticlassLDA, x::AbstractVecOrMat{<:Real})
    return adjoint(M.proj) * x
end

function fit(::Type{MulticlassLDA}, nc::Int, X::DenseMatrix{T}, y::AbstractVector{Int};
method::Symbol=:gevd,
Expand Down
4 changes: 2 additions & 2 deletions src/pca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ principalratio(M::PCA) = M.tprinvar / M.tvar

## use

transform(M::PCA{T}, x::AbstractVecOrMat{T}) where {T<:Real} = transpose(M.proj) * centralize(x, M.mean)
reconstruct(M::PCA{T}, y::AbstractVecOrMat{T}) where {T<:Real} = decentralize(M.proj * y, M.mean)
"""Apply `centralize(x, M.mean)` and left-multiply the result by `transpose(M.proj)`."""
function transform(M::PCA, x::AbstractVecOrMat{<:Real})
    Pt = transpose(M.proj)
    return Pt * centralize(x, M.mean)
end
"""Multiply `M.proj` by `y` and pass the product to `decentralize` with `M.mean`."""
function reconstruct(M::PCA, y::AbstractVecOrMat{<:Real})
    projected = M.proj * y
    return decentralize(projected, M.mean)
end

## show & dump

Expand Down
8 changes: 4 additions & 4 deletions src/ppca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ loadings(M::PPCA) = M.W

## use

function transform(m::PPCA{T}, x::AbstractVecOrMat{T}) where {T<:Real}
"""
Compute `(W'W + σ²I)⁻¹ W' xn`, where `xn = centralize(x, m.mean)`, `W = m.W`
and `σ² = m.σ²`.
"""
function transform(m::PPCA, x::AbstractVecOrMat{<:Real})
    xn = centralize(x, m.mean)
    W = m.W
    # `I` adapts to the size and element type of W'W, so no explicit
    # `Matrix{T}(I, n, n)` (and no type parameter `T`) is needed.
    M = W'W + m.σ² * I
    # Solve the linear system rather than forming inv(M).
    return M \ (W' * xn)
end

function reconstruct(m::PPCA{T}, z::AbstractVecOrMat{T}) where {T<:Real}
"""
Compute `W (W'W)⁻¹ (W'W + var(m) I) z .+ mean(m)`, where `W = m.W`.
"""
function reconstruct(m::PPCA, z::AbstractVecOrMat{<:Real})
    W = m.W
    WTW = W'W
    # `I` adapts to the size and element type of W'W, so no explicit
    # `Matrix{T}(I, n, n)` (and no type parameter `T`) is needed.
    M = WTW + var(m) * I
    # Solve against W'W rather than forming its inverse.
    return W * (WTW \ (M * z)) .+ mean(m)
end

Expand Down
2 changes: 1 addition & 1 deletion src/whiten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ indim(f::Whitening) = size(f.W, 1)
outdim(f::Whitening) = size(f.W, 2)
mean(f::Whitening) = fullmean(indim(f), f.mean)

transform(f::Whitening, x::AbstractVecOrMat) = transpose(f.W) * centralize(x, f.mean)
"""Apply `centralize(x, f.mean)` and left-multiply the result by `transpose(f.W)`."""
function transform(f::Whitening, x::AbstractVecOrMat{<:Real})
    Wt = transpose(f.W)
    return Wt * centralize(x, f.mean)
end

## Fit whitening to data

Expand Down
27 changes: 17 additions & 10 deletions test/cca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,21 @@ import Random
# Each projection should approximately equal the normalized solve built from the other.
@test Px ≈ MultivariateStats.qnormalize!(Cxx \ (Cxy * Py), Cxx)
@test Py ≈ MultivariateStats.qnormalize!(Cyy \ (Cyx * Px), Cyy)

# different input type
XX = convert(Matrix{Float32}, X)
YY = convert(Matrix{Float32}, Y)
M = fit(CCA, view(XX, :, 1:400), view(YY, :, 1:400); method=:svd, outdim=p)
@test eltype(xmean(M)) == Float32
@test eltype(ymean(M)) == Float32
@test eltype(xprojection(M)) == Float32
@test eltype(yprojection(M)) == Float32
@test eltype(correlations(M)) == Float32

# different input types
XX = convert.(Float32, X)
YY = convert.(Float32, Y)

MM = fit(CCA, view(XX, :, 1:400), view(YY, :, 1:400); method=:svd, outdim=p)

# test that mixing types doesn't error
xtransform(M, XX)
ytransform(M, YY)
xtransform(MM, XX)
ytransform(MM, YY)

# type stability
for func in (xmean, ymean, xprojection, yprojection, correlations)
@test eltype(func(M)) == Float64
@test eltype(func(MM)) == Float32
end
end
20 changes: 20 additions & 0 deletions test/cmds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,24 @@ using Test
# `A` should approximately match the pairwise squared distances of `Y`.
@test A ≈ MultivariateStats.pairwise((x,y)->sum(abs2, x-y), Y)
@test eltype(Y) == Float32

# different input types
d = 3
X = randn(Float64, d, 10)
XX = convert.(Float32, X)

y = randn(Float64, d)
yy = convert.(Float32, y)

M = fit(MDS, X, maxoutdim=3, distances=false)
MM = fit(MDS, XX, maxoutdim=3, distances=false)

# test that mixing types doesn't error
transform(M, yy)
transform(MM, y)

# type stability
for func in (projection, eigvals, stress)
@test eltype(func(M)) == Float64
@test eltype(func(MM)) == Float32
end
end
29 changes: 21 additions & 8 deletions test/fa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,27 @@ import Random
# Log-likelihood expression evaluated from the model covariance `cov(m)` and the sample covariance of `x`.
LL(m, x) = (-size(x,2)/2)*(size(x,1)*log(2π) + log(det(cov(m))) + tr(inv(cov(m))*cov(x, dims=2)))
@test LL(M1, X) ≈ LL(M2, X) # log likelihood

# test that fit works with Float32 values
X2 = convert(Array{Float32,2}, X)
# Float32 input
M = fit(FactorAnalysis, X2; method=:cm, maxoutdim=3)
M = fit(FactorAnalysis, X2; method=:em, maxoutdim=3)
# test with different types
XX = convert.(Float32, X)
YY = convert.(Float32, Y)

for method in (:cm, :em)
MM = fit(FactorAnalysis, XX; method=method, maxoutdim=3)

# mixing types
transform(M, XX)
transform(MM, X)
reconstruct(M, YY)
reconstruct(MM, Y)

# type stability
for func in (mean, projection, cov, var, loadings)
@test eltype(func(M)) == Float64
@test eltype(func(MM)) == Float32
end
end

# views
M = fit(FactorAnalysis, view(X2, :, 1:100), method=:cm, maxoutdim=3)
M = fit(FactorAnalysis, view(X2, :, 1:100), method=:em, maxoutdim=3)

M = fit(FactorAnalysis, view(XX, :, 1:100), method=:cm, maxoutdim=3)
M = fit(FactorAnalysis, view(XX, :, 1:100), method=:em, maxoutdim=3)
end
26 changes: 13 additions & 13 deletions test/ica.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,25 @@ import StatsBase
@test_throws StatsBase.ConvergenceException fit(ICA, X, k; do_whiten=true, tol=1e-8, maxiter=2)

# Use data of different type

XX = convert(Matrix{Float32}, X)

M = fit(ICA, XX, k; do_whiten=true, tol=Inf)
@test eltype(mean(M)) == Float32
@test eltype(M.W) == Float32

M = fit(ICA, XX, k; do_whiten=false, tol=Inf)
@test isa(M, ICA)
@test eltype(mean(M)) == Float32
@test eltype(M.W) == Float32
W = M.W
@test transform(M, X) W' * convert(Matrix{Float32}, (X .- μ))
MM = fit(ICA, XX, k; do_whiten=true, tol=Inf)
@test eltype(mean(MM)) == Float32
@test eltype(MM.W) == Float32

MM = fit(ICA, XX, k; do_whiten=false, tol=Inf)
@test isa(MM, ICA)
@test eltype(mean(MM)) == Float32
@test eltype(MM.W) == Float32
W = MM.W
# Mixed element types: Float32 model on Float64 data and vice versa.
@test transform(MM, X) ≈ W' * convert(Matrix{Float32}, (X .- μ))
@test transform(M, XX) ≈ M.W' * (X .- μ) atol=1e-4
@test W'W ≈ Matrix{Float32}(I, k, k)

# input as view
# Bind the view-fitted model to `MM` so the assertions below check this fit
# rather than the earlier Float32 model; this also leaves `M` untouched.
MM = fit(ICA, view(XX, :, 1:400), k; do_whiten=true, tol=Inf)
@test eltype(mean(MM)) == Float32
@test eltype(MM.W) == Float32
end

end
30 changes: 24 additions & 6 deletions test/kpca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,30 @@ import Random

@test_throws TypeError fit(KernelPCA, rand(1,10), kernel=1)

# fit a Float32 matrix
X = randn(Float32, d, n)
M = fit(KernelPCA, X)
@test indim(M) == d
@test outdim(M) == d
@test eltype(transform(M, X[:,1])) == Float32
# different types
X = randn(Float64, d, n)
XX = convert.(Float32, X)

M = fit(KernelPCA, X ; inverse=true)
MM = fit(KernelPCA, XX ; inverse=true)

Y = randn(Float64, outdim(M))
YY = convert.(Float32, Y)

@test indim(MM) == d
@test outdim(MM) == d
@test eltype(transform(MM, X[:,1])) == Float32

for func in (projection, principalvars)
@test eltype(func(M)) == Float64
@test eltype(func(MM)) == Float32
end

# mixing types should not error
transform(M, XX)
transform(MM, X)
reconstruct(M, YY)
reconstruct(MM, Y)

## fit a sparse matrix
X = SparseArrays.sprandn(100d, n, 0.6)
Expand Down
32 changes: 22 additions & 10 deletions test/pca.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,28 @@ import SparseArrays

# Different data types
# --------------------
# test that fit works with Float32 values
X2 = convert(Array{Float32,2}, X)
# Float32 input, default pratio
M = fit(PCA, X2; maxoutdim=3)
# Float32 input, specified Float64 pratio
M = fit(PCA, X2, pratio=0.85)
# Float32 input, specified Float32 pratio
M = fit(PCA, X2, pratio=0.85f0)
# Float64 input, specified Float32 pratio
M = fit(PCA, X, pratio=0.85f0)

XX = convert.(Float32, X)
YY = convert.(Float32, Y)
p = 0.085
pp = convert(Float32, p)

MM = fit(PCA, XX; maxoutdim=3)

# mix types
fit(PCA, X ; pratio=pp)
fit(PCA, XX ; pratio=p)
fit(PCA, XX ; pratio=pp)
transform(M, XX)
transform(MM, X)
reconstruct(M, YY)
reconstruct(MM, Y)

# type consistency
for func in (mean, projection, principalvars, tprincipalvar, tresidualvar, tvar, principalratio)
@test eltype(func(M)) == Float64
@test eltype(func(MM)) == Float32
end

# views
M = fit(PCA, view(X, :, 1:500), pratio=0.85)
Expand Down
Loading

0 comments on commit d24eba5

Please sign in to comment.