Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: extend linalg and array operations to work on Numbers when it makes sense #1871

Merged
merged 1 commit into from
Jan 2, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ end

## find ##

function nnz(a::StridedArray)
function nnz(a)
n = 0
for i = 1:numel(a)
n += bool(a[i]) ? 1 : 0
Expand All @@ -1101,7 +1101,7 @@ function nnz(a::StridedArray)
end

# returns the index of the first non-zero element, or 0 if all zeros
function findfirst(A::StridedArray)
function findfirst(A)
for i = 1:length(A)
if A[i] != 0
return i
Expand All @@ -1111,7 +1111,7 @@ function findfirst(A::StridedArray)
end

# returns the index of the first matching element
function findfirst(A::StridedArray, v)
function findfirst(A, v)
for i = 1:length(A)
if A[i] == v
return i
Expand All @@ -1121,7 +1121,7 @@ function findfirst(A::StridedArray, v)
end

# returns the index of the first element for which the function returns true
function findfirst(testf::Function, A::StridedArray)
function findfirst(testf::Function, A)
for i = 1:length(A)
if testf(A[i])
return i
Expand Down Expand Up @@ -1157,6 +1157,10 @@ function find(A::StridedArray)
return I
end

find(x::Number) = x == 0 ? Array(Int,0) : [1]
find(x::Bool) = x ? [1] : Array(Int,0)
find(testf::Function, x) = find(testf(x))

findn(A::StridedVector) = find(A)

function findn(A::StridedMatrix)
Expand Down Expand Up @@ -1241,7 +1245,9 @@ function nonzeros{T}(A::StridedArray{T})
return V
end

function findmax(a::StridedArray)
nonzeros(x::Number) = x == 0 ? Array(typeof(x),0) : [x]

function findmax(a)
m = typemin(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1253,7 +1259,7 @@ function findmax(a::StridedArray)
return (m, mi)
end

function findmin(a::StridedArray)
function findmin(a)
m = typemax(eltype(a))
mi = 0
for i=1:length(a)
Expand All @@ -1264,8 +1270,9 @@ function findmin(a::StridedArray)
end
return (m, mi)
end
indmax(a::StridedArray) = findmax(a)[2]
indmin(a::StridedArray) = findmin(a)[2]

indmax(a) = findmax(a)[2]
indmin(a) = findmin(a)[2]

## Reductions ##

Expand Down
15 changes: 15 additions & 0 deletions base/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,21 @@ function norm(A::AbstractMatrix, p)
end

norm(A::AbstractMatrix) = norm(A, 2)

norm(x::Number) = abs(x)
norm(x::Number, p) = abs(x)

rank(A::AbstractMatrix, tol::Real) = sum(svdvals(A) .> tol)
function rank(A::AbstractMatrix)
m,n = size(A)
if m == 0 || n == 0; return 0; end
sv = svdvals(A)
sum(sv .> max(size(A,1),size(A,2))*eps(sv[1]))
end
rank(x::Number) = x == 0 ? 0 : 1

trace(A::AbstractMatrix) = sum(diag(A))
trace(x::Number) = x

#kron(a::AbstractVector, b::AbstractVector)
#kron{T,S}(a::AbstractMatrix{T}, b::AbstractMatrix{S})
Expand All @@ -109,6 +115,9 @@ function cond(a::AbstractMatrix, p)
end
end

cond(x::Number) = x == 0 ? Inf : 1
cond(x::Number, p) = cond(x)

function issym(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -120,6 +129,8 @@ function issym(A::AbstractMatrix)
return true
end

issym(x::Number) = true

function ishermitian(A::AbstractMatrix)
m, n = size(A)
if m != n; error("matrix must be square, got $(m)x$(n)"); end
Expand All @@ -131,6 +142,8 @@ function ishermitian(A::AbstractMatrix)
return true
end

ishermitian(x::Number) = isreal(x)

function istriu(A::AbstractMatrix)
m, n = size(A)
for j = 1:min(n,m-1), i = j+1:m
Expand All @@ -151,6 +164,8 @@ function istril(A::AbstractMatrix)
return true
end

istriu(x::Number) = true
istril(x::Number) = true

function linreg{T<:Number}(X::StridedVecOrMat{T}, y::Vector{T})
[ones(T, size(X,1)) X] \ y
Expand Down
29 changes: 25 additions & 4 deletions base/linalg_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ isposdef{T<:BlasFloat}(A::Matrix{T}) = isposdef!(copy(A))
isposdef{T<:Number}(A::Matrix{T}, upper::Bool) = isposdef!(float64(A), upper)
isposdef{T<:Number}(A::Matrix{T}) = isposdef!(float64(A))

isposdef(x::Number) = isreal(x) && x > 0

norm{T<:BlasFloat}(x::Vector{T}) = BLAS.nrm2(length(x), x, 1)

function norm{T<:BlasFloat, TI<:Integer}(x::Vector{T}, rx::Union(Range1{TI},Range{TI}))
Expand Down Expand Up @@ -137,6 +139,8 @@ end

diagm(v) = diagm(v, 0)

diagm(x::Number) = (X = Array(typeof(x),1,1); X[1,1] = x; X)

function trace{T}(A::Matrix{T})
t = zero(T)
for i=1:min(size(A))
Expand Down Expand Up @@ -165,6 +169,12 @@ function kron{T,S}(a::Matrix{T}, b::Matrix{S})
R
end

kron(a::Number, b::Number) = a * b
kron(a::Vector, b::Number) = a * b
kron(a::Number, b::Vector) = a * b
kron(a::Matrix, b::Number) = a * b
kron(a::Number, b::Matrix) = a * b

function randsym(n)
a = randn(n,n)
for j=1:n-1, i=j+1:n
Expand Down Expand Up @@ -235,6 +245,8 @@ function rref{T}(A::Matrix{T})
return U
end

rref(x::Number) = one(x)

## Destructive matrix exponential using algorithm from Higham, 2008,
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down Expand Up @@ -377,6 +389,7 @@ end
# Matrix exponential
expm{T<:Union(Float32,Float64,Complex64,Complex128)}(A::StridedMatrix{T}) = expm!(copy(A))
expm{T<:Integer}(A::StridedMatrix{T}) = expm!(float(A))
expm(x::Number) = exp(x)

## Matrix factorizations and decompositions

Expand Down Expand Up @@ -451,6 +464,7 @@ chold{T<:Number}(A::Matrix{T}) = chold(A, true)

## Matlab (and R) compatible
chol{T<:Number}(A::Matrix{T}) = factors(chold(A))
chol(x::Number) = isreal(x) && x >= 0 ? sqrt(x) : error("argument not positive-definite")

type CholeskyDensePivoted{T<:BlasFloat} <: Factorization{T}
LR::Matrix{T}
Expand Down Expand Up @@ -537,6 +551,7 @@ lud{T<:Number}(A::Matrix{T}) = lud(float64(A))

## Matlab-compatible
lu{T<:Number}(A::Matrix{T}) = factors(lud(A))
lu(x::Number) = (one(x), x)

function det{T}(lu::LUDense{T})
m, n = size(lu)
Expand All @@ -551,6 +566,8 @@ function det(A::Matrix)
return det(lud(A))
end

det(x::Number) = x

function (\){T<:BlasFloat}(lu::LUDense{T}, B::StridedVecOrMat{T})
if lu.info > 0; throw(LAPACK.SingularException(info)); end
LAPACK.getrs!('N', lu.lu, lu.ipiv, copy(B))
Expand Down Expand Up @@ -585,6 +602,7 @@ function factors{T<:BlasFloat}(qrd::QRDense{T})
end

qr{T<:Number}(x::StridedMatrix{T}) = factors(qrd(x))
qr(x::Number) = (one(x), x)

## Multiplication by Q from the QR decomposition
(*){T<:BlasFloat}(A::QRDense{T}, B::StridedVecOrMat{T}) =
Expand Down Expand Up @@ -688,8 +706,9 @@ function eig{T<:BlasFloat}(A::StridedMatrix{T}, vecs::Bool)
end

eig{T<:Integer}(x::StridedMatrix{T}, vecs::Bool) = eig(float64(x), vecs)
eig(x::StridedMatrix) = eig(x, true)
eigvals(x::StridedMatrix) = eig(x, false)
eig(x::Number, vecs::Bool) = vecs ? (x, one(x)) : x
eig(x) = eig(x, true)
eigvals(x) = eig(x, false)

# This is the svd based on the LAPACK GESVD, which is slower, but takes
# lesser memory. It should be made available through a keyword argument
Expand Down Expand Up @@ -721,8 +740,9 @@ function svd{T<:BlasFloat}(A::StridedMatrix{T},vecs::Bool,thin::Bool)
end

svd{T<:Integer}(x::StridedMatrix{T},vecs,thin) = svd(float64(x),vecs,thin)
svd(A::StridedMatrix) = svd(A,true,false)
svd(A::StridedMatrix, thin::Bool) = svd(A,true,thin)
svd(x::Number,vecs::Bool,thin::Bool) = vecs ? (x==0?one(x):x/abs(x),abs(x),one(x)) : ([],abs(x),[])
svd(A) = svd(A,true,false)
svd(A, thin::Bool) = svd(A,true,thin)
svdvals(A) = svd(A,false,true)[2]

function (\){T<:BlasFloat}(A::StridedMatrix{T}, B::StridedVecOrMat{T})
Expand Down Expand Up @@ -776,6 +796,7 @@ function pinv{T<:BlasFloat}(A::StridedMatrix{T})
end
pinv{T<:Integer}(A::StridedMatrix{T}) = pinv(float(A))
pinv(a::StridedVector) = pinv(reshape(a, length(a), 1))
pinv(x::Number) = one(x)/x

## Basis for null space
function null{T<:BlasFloat}(A::StridedMatrix{T})
Expand Down
2 changes: 2 additions & 0 deletions base/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ function dot(x::Vector, y::Vector)
s
end

dot(x::Number, y::Number) = conj(x) * y

# Matrix-vector multiplication

function (*){T<:BlasFloat}(A::StridedMatrix{T},
Expand Down
1 change: 0 additions & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ ctranspose(x::Number) = conj(x)
inv(x::Number) = one(x)/x
angle(z::Real) = atan2(zero(z), z)

# TODO: should we really treat numbers as iterable?
start(a::Real) = a
next(a::Real, i) = (a, a+1)
done(a::Real, i) = (i > a)
Expand Down