MIT-licensed sparse() parent method and expert driver #14798
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,7 +335,225 @@ function sparse_IJ_sorted!{Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector | |
return SparseMatrixCSC(m, n, colptr, I, V) | ||
end | ||
|
||
## sparse() can take its inputs in unsorted order (the parent method is now in csparse.jl) | ||
""" | ||
sparse(I, J, V,[ m, n, combine]) | ||
|
||
Create a sparse matrix `S` of dimensions `m x n` such that `S[I[k], J[k]] = V[k]`. The | ||
`combine` function is used to combine duplicates. If `m` and `n` are not specified, they | ||
are set to `maximum(I)` and `maximum(J)` respectively. If the `combine` function is not | ||
supplied, `combine` defaults to `+` unless the elements of `V` are Booleans in which case | ||
`combine` defaults to `|`. All elements of `I` must satisfy `1 <= I[k] <= m`, and all | ||
elements of `J` must satisfy `1 <= J[k] <= n`. Numerical zeros in (`I`, `J`, `V`) are | ||
retained as structural nonzeros. | ||
|
||
For additional documentation and an expert driver, see `Base.SparseArrays.sparse!`. | ||
""" | ||
function sparse{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, V::AbstractVector{Tv}, m::Integer, n::Integer, combine) | ||
coolen = length(I) | ||
if length(J) != coolen || length(V) != coolen | ||
throw(ArgumentError(string("the first three arguments' lengths must match, ", | ||
"length(I) (=$(length(I))) == length(J) (= $(length(J))) == length(V) (= ", | ||
"$(length(V)))"))) | ||
end | ||
|
||
if m == 0 || n == 0 || coolen == 0 | ||
if coolen != 0 | ||
if n == 0 | ||
throw(ArgumentError("column indices J[k] must satisfy 1 <= J[k] <= n")) | ||
elseif m == 0 | ||
throw(ArgumentError("row indices I[k] must satisfy 1 <= I[k] <= m")) | ||
end | ||
end | ||
SparseMatrixCSC(m, n, ones(Ti, n+1), Vector{Ti}(), Vector{Tv}()) | ||
else | ||
# Allocate storage for CSR form | ||
csrrowptr = Vector{Ti}(m+1) | ||
csrcolval = Vector{Ti}(coolen) | ||
csrnzval = Vector{Tv}(coolen) | ||
|
||
# Allocate storage for the CSC form's column pointers and a necessary workspace | ||
csccolptr = Vector{Ti}(n+1) | ||
klasttouch = Vector{Ti}(n) | ||
|
||
# Allocate empty arrays for the CSC form's row and nonzero value arrays | ||
# The parent method called below automagically resizes these arrays | ||
cscrowval = Vector{Ti}() | ||
cscnzval = Vector{Tv}() | ||
|
||
sparse!(I, J, V, m, n, combine, klasttouch, | ||
csrrowptr, csrcolval, csrnzval, | ||
csccolptr, cscrowval, cscnzval ) | ||
end | ||
end | ||
|
||
""" | ||
sparse!{Tv,Ti<:Integer}( | ||
I::AbstractVector{Ti}, J::AbstractVector{Ti}, V::AbstractVector{Tv}, | ||
m::Integer, n::Integer, combine, klasttouch::Vector{Ti}, | ||
csrrowptr::Vector{Ti}, csrcolval::Vector{Ti}, csrnzval::Vector{Tv}, | ||
[csccolptr::Vector{Ti}], [cscrowval::Vector{Ti}, cscnzval::Vector{Tv}] ) | ||
|
||
Parent of and expert driver for `sparse`; see `sparse` for basic usage. This method | ||
allows the user to provide preallocated storage for `sparse`'s intermediate objects and | ||
result as described below. This capability enables more efficient successive construction | ||
of `SparseMatrixCSC`s from coordinate representations, and also enables extraction of an | ||
unsorted-column representation of the result's transpose at no additional cost. | ||
|
||
This method consists of three major steps: (1) Counting-sort the provided coordinate | ||
representation into an unsorted-row CSR form including repeated entries. (2) Sweep through | ||
the CSR form, simultaneously calculating the desired CSC form's column-pointer array, | ||
detecting repeated entries, and repacking the CSR form with repeated entries combined; | ||
this stage yields an unsorted-row CSR form with no repeated entries. (3) Counting-sort the | ||
preceding CSR form into a fully-sorted CSC form with no repeated entries. | ||
|
||
Input arrays `csrrowptr`, `csrcolval`, and `csrnzval` constitute storage for the | ||
intermediate CSR forms and require `length(csrrowptr) >= m + 1`, | ||
`length(csrcolval) >= length(I)`, and `length(csrnzval) >= length(I)`. Input | ||
array `klasttouch`, workspace for the second stage, requires `length(klasttouch) >= n`. | ||
Optional input arrays `csccolptr`, `cscrowval`, and `cscnzval` constitute storage for the | ||
There's another optional here.

The CSC arrays are currently optional; see the method definitions immediately following the main

But there is a method argument style issue with the immediately following definitions. Fixing. Thanks!
||
returned CSC form `S`. `csccolptr` requires `length(csccolptr) >= n + 1`. If necessary, | ||
`cscrowval` and `cscnzval` are automatically resized to satisfy | ||
`length(cscrowval) >= nnz(S)` and `length(cscnzval) >= nnz(S)`; hence, if `nnz(S)` is | ||
unknown at the outset, passing in empty vectors of the appropriate type (`Vector{Ti}()` | ||
and `Vector{Tv}()` respectively) suffices, as does calling the `sparse!` method | ||
that omits `cscrowval` and `cscnzval`. | ||
These arrays are currently not optional.

Fixed, thanks!
||
|
||
On return, `csrrowptr`, `csrcolval`, and `csrnzval` contain an unsorted-column | ||
representation of the result's transpose. | ||
|
||
You may reuse the input arrays' storage (`I`, `J`, `V`) for the output arrays | ||
(`csccolptr`, `cscrowval`, `cscnzval`). For example, you may call | ||
`sparse!(I, J, V, m, n, combine, klasttouch, csrrowptr, csrcolval, csrnzval, I, J, V)`. | ||
|
||
For the sake of efficiency, this method performs no argument checking beyond | ||
`1 <= I[k] <= m` and `1 <= J[k] <= n`. Use with care. Testing with `--check-bounds=yes` | ||
is wise. | ||
|
||
This method runs in `O(m + n + length(I))` time. The HALFPERM algorithm described in | ||
F. Gustavson, "Two fast algorithms for sparse matrices: multiplication and permuted | ||
transposition," ACM TOMS 4(3), 250-269 (1978) inspired this method's use of a pair of | ||
counting sorts. | ||
|
||
Performance note: As of January 2016, `combine` should be a functor for this method to | ||
perform well. This caveat may disappear when the work in `jb/functions` lands. | ||
""" | ||
function sparse!{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, | ||
V::AbstractVector{Tv}, m::Integer, n::Integer, combine, klasttouch::Vector{Ti}, | ||
csrrowptr::Vector{Ti}, csrcolval::Vector{Ti}, csrnzval::Vector{Tv}, | ||
csccolptr::Vector{Ti}, cscrowval::Vector{Ti}, cscnzval::Vector{Tv} ) | ||
|
||
# Compute the CSR form's row counts and store them shifted forward by one in csrrowptr | ||
fill!(csrrowptr, 0) | ||
coolen = length(I) | ||
@inbounds for k in 1:coolen | ||
Ik = I[k] | ||
if 1 > Ik || m < Ik | ||
throw(ArgumentError("row indices I[k] must satisfy 1 <= I[k] <= m")) | ||
end | ||
csrrowptr[Ik+1] += 1 | ||
end | ||
|
||
# Compute the CSR form's rowptrs and store them shifted forward by one in csrrowptr | ||
countsum = 1 | ||
csrrowptr[1] = 1 | ||
@inbounds for i in 2:(m+1) | ||
overwritten = csrrowptr[i] | ||
csrrowptr[i] = countsum | ||
countsum += overwritten | ||
end | ||
|
||
# Counting-sort the column and nonzero values from J and V into csrcolval and csrnzval | ||
# Tracking write positions in csrrowptr corrects the row pointers | ||
@inbounds for k in 1:coolen | ||
Ik, Jk = I[k], J[k] | ||
if 1 > Jk || n < Jk | ||
throw(ArgumentError("column indices J[k] must satisfy 1 <= J[k] <= n")) | ||
end | ||
csrk = csrrowptr[Ik+1] | ||
csrrowptr[Ik+1] = csrk+1 | ||
csrcolval[csrk] = Jk | ||
csrnzval[csrk] = V[k] | ||
end | ||
# This completes the unsorted-row, has-repeats CSR form's construction | ||
|
||
# Sweep through the CSR form, simultaneously (1) calculating the CSC form's column | ||
# counts and storing them shifted forward by one in csccolptr; (2) detecting repeated | ||
# entries; and (3) repacking the CSR form with the repeated entries combined. | ||
# | ||
# Minimizing extraneous communication and nonlocality of reference, primarily by using | ||
# only a single auxiliary array in this step, is the key to this method's performance. | ||
fill!(csccolptr, 0) | ||
fill!(klasttouch, 0) | ||
writek = 1 | ||
newcsrrowptri = 1 | ||
origcsrrowptri = 1 | ||
origcsrrowptrip1 = csrrowptr[2] | ||
@inbounds for i in 1:m | ||
for readk in origcsrrowptri:(origcsrrowptrip1-1) | ||
j = csrcolval[readk] | ||
if klasttouch[j] < newcsrrowptri | ||
klasttouch[j] = writek | ||
if writek != readk | ||
csrcolval[writek] = j | ||
csrnzval[writek] = csrnzval[readk] | ||
end | ||
writek += 1 | ||
csccolptr[j+1] += 1 | ||
else | ||
klt = klasttouch[j] | ||
csrnzval[klt] = combine(csrnzval[klt], csrnzval[readk]) | ||
end | ||
end | ||
newcsrrowptri = writek | ||
origcsrrowptri = origcsrrowptrip1 | ||
origcsrrowptrip1 != writek && (csrrowptr[i+1] = writek) | ||
i < m && (origcsrrowptrip1 = csrrowptr[i+2]) | ||
end | ||
|
||
# Compute the CSC form's colptrs and store them shifted forward by one in csccolptr | ||
countsum = 1 | ||
csccolptr[1] = 1 | ||
@inbounds for j in 2:(n+1) | ||
overwritten = csccolptr[j] | ||
csccolptr[j] = countsum | ||
countsum += overwritten | ||
end | ||
|
||
# Now knowing the CSC form's entry count, resize cscrowval and cscnzval if necessary | ||
cscnnz = countsum - 1 | ||
length(cscrowval) < cscnnz && resize!(cscrowval, cscnnz) | ||
length(cscnzval) < cscnnz && resize!(cscnzval, cscnnz) | ||
|
||
# Finally counting-sort the row and nonzero values from the CSR form into cscrowval and | ||
# cscnzval. Tracking write positions in csccolptr corrects the column pointers. | ||
@inbounds for i in 1:m | ||
for csrk in csrrowptr[i]:(csrrowptr[i+1]-1) | ||
j = csrcolval[csrk] | ||
x = csrnzval[csrk] | ||
csck = csccolptr[j+1] | ||
csccolptr[j+1] = csck+1 | ||
cscrowval[csck] = i | ||
cscnzval[csck] = x | ||
end | ||
end | ||
|
||
SparseMatrixCSC(m, n, csccolptr, cscrowval, cscnzval) | ||
end | ||
function sparse!{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, | ||
V::AbstractVector{Tv}, m::Integer, n::Integer, combine, klasttouch::Vector{Ti}, | ||
csrrowptr::Vector{Ti}, csrcolval::Vector{Ti}, csrnzval::Vector{Tv}, | ||
csccolptr::Vector{Ti} ) | ||
sparse!(I, J, V, m, n, combine, klasttouch, | ||
csrrowptr, csrcolval, csrnzval, | ||
csccolptr, Vector{Ti}(), Vector{Tv}() ) | ||
end | ||
function sparse!{Tv,Ti<:Integer}(I::AbstractVector{Ti}, J::AbstractVector{Ti}, | ||
V::AbstractVector{Tv}, m::Integer, n::Integer, combine, klasttouch::Vector{Ti}, | ||
csrrowptr::Vector{Ti}, csrcolval::Vector{Ti}, csrnzval::Vector{Tv} ) | ||
sparse!(I, J, V, m, n, combine, klasttouch, | ||
csrrowptr, csrcolval, csrnzval, | ||
Vector{Ti}(n+1), Vector{Ti}(), Vector{Tv}() ) | ||
end | ||
|
||
dimlub(I) = isempty(I) ? 0 : Int(maximum(I)) #least upper bound on required sparse matrix dimension | ||
|
||
|
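For readers following the three steps described in the `sparse!` docstring above, here is a minimal stand-alone sketch of the same idea. It is not the PR's code: it uses Julia 1.x syntax (`where`, `undef`) rather than the syntax in the diff, it allocates its own storage instead of accepting preallocated arrays, and the name `coo_to_csc_sketch` and its return convention (a tuple of three plain vectors) are assumptions made for illustration.

```julia
# Hypothetical helper, not part of the PR: restates the three steps with fresh
# allocations and returns (colptr, rowval, nzval) instead of a SparseMatrixCSC.
function coo_to_csc_sketch(I::Vector{Int}, J::Vector{Int}, V::Vector{Tv},
                           m::Int, n::Int, combine) where {Tv}
    coolen = length(I)

    # Step 1: counting-sort the entries by row into a CSR form.
    # Columns within a row stay unsorted and duplicates are still present.
    rowcount = zeros(Int, m)
    for i in I
        rowcount[i] += 1
    end
    rowptr = cumsum(vcat(1, rowcount))      # rowptr[i] is the first slot of row i
    nextslot = rowptr[1:m]                  # copy: next free slot per row
    colval = Vector{Int}(undef, coolen)
    nzval = Vector{Tv}(undef, coolen)
    for k in 1:coolen
        p = nextslot[I[k]]
        nextslot[I[k]] = p + 1
        colval[p] = J[k]
        nzval[p] = V[k]
    end

    # Step 2: sweep the CSR form row by row, combining duplicates in place and
    # counting entries per column. lasttouch[j] records where column j was last
    # written; a value below the current row's start means "not seen in this row".
    lasttouch = zeros(Int, n)
    colcount = zeros(Int, n)
    newrowptr = Vector{Int}(undef, m + 1)
    writek = 1
    for i in 1:m
        rowstart = writek
        newrowptr[i] = rowstart
        for readk in rowptr[i]:(rowptr[i + 1] - 1)
            j = colval[readk]
            if lasttouch[j] < rowstart
                lasttouch[j] = writek
                colval[writek] = j
                nzval[writek] = nzval[readk]
                writek += 1
                colcount[j] += 1
            else
                klt = lasttouch[j]
                nzval[klt] = combine(nzval[klt], nzval[readk])
            end
        end
    end
    newrowptr[m + 1] = writek

    # Step 3: counting-sort the deduplicated CSR form by column into CSC form.
    # Visiting rows in increasing order leaves each column's rows sorted.
    colptr = cumsum(vcat(1, colcount))
    nextslot = colptr[1:n]
    nnzS = writek - 1
    rowval = Vector{Int}(undef, nnzS)
    cscnz = Vector{Tv}(undef, nnzS)
    for i in 1:m, k in newrowptr[i]:(newrowptr[i + 1] - 1)
        j = colval[k]
        p = nextslot[j]
        nextslot[j] = p + 1
        rowval[p] = i
        cscnz[p] = nzval[k]
    end
    return colptr, rowval, cscnz
end

# Entries (2,1)=1.0, (1,2)=2.0, (2,1)=3.0, (1,1)=4.0; the two (2,1) entries are summed.
colptr, rowval, vals = coo_to_csc_sketch([2, 1, 2, 1], [1, 2, 1, 1],
                                         [1.0, 2.0, 3.0, 4.0], 2, 2, +)
println(colptr)   # [1, 3, 4]
println(rowval)   # [1, 2, 1]
println(vals)     # [4.0, 4.0, 2.0]
```

Unlike the method in the diff, this sketch keeps separate count and pointer arrays for readability; the PR instead stores counts shifted forward by one in the pointer arrays so that a single array serves both purposes.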
This should throw a ~~`BoundsError`~~ `ArgumentError` if elements of `I` or `J` are outside of `m`-by-`n`.
edit: my bad, sorry

Fixed in lines 421 and 440 (`ArgumentError` -> `BoundsError`), thanks!

No, if `m == 0 || n == 0` this should not return successfully if the input indices are out of bounds (aka if there are any of them).

Ah, I see, thanks! Fixing now.

Fixed (I think), thanks!

I suppose the first check (`!isempty(I)`) suffices given the earlier check `length(I) == length(J) == length(V)`. But perhaps this is more clear. Thoughts?

Sure, that works. I think the error message can be the same as the non-empty case though, ~~`row values I[k] must satisfy 1 <= I[k] <= m`~~ `row indices I[k] must satisfy 1 <= I[k] <= m` etc. Also always good to add more tests for this kind of corner case.

After a little thought I might advocate for the distinct, explicit error message in now: The `I[k]`-specific error message may be confusing where, for example, `m == 2`, `n == 0`, and `(I,J,V)` = `(2, 1, 1.0)`.

The current version that starts "where ..., any entry is necessarily out-of-bounds," reads to me as too meandering for an error message. I'd do something like this
edit: except it should be `ArgumentError`, whoops

Beautiful. Copied verbatim. Thanks!
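
The corner case discussed in this thread can be exercised in isolation. The sketch below mirrors the guard at the top of `sparse` in the diff (the `m == 0 || n == 0 || coolen == 0` branch) as a stand-alone function; the name `check_empty_dims` is hypothetical and the code is written in Julia 1.x syntax, so treat it as an illustration of the intended behavior and not as the PR's implementation.

```julia
# Hypothetical stand-alone version of the zero-dimension guard from the diff.
# coolen is the number of coordinate entries; m and n are the requested dimensions.
function check_empty_dims(coolen::Integer, m::Integer, n::Integer)
    if (m == 0 || n == 0) && coolen != 0
        # Any entry is out of bounds when a dimension is zero; report the
        # zero dimension, checking columns first as the diff does.
        if n == 0
            throw(ArgumentError("column indices J[k] must satisfy 1 <= J[k] <= n"))
        else
            throw(ArgumentError("row indices I[k] must satisfy 1 <= I[k] <= m"))
        end
    end
    return nothing
end

# The case raised above: m == 2, n == 0, and a single entry (2, 1, 1.0).
err = try
    check_empty_dims(1, 2, 0)
catch e
    e
end
println(err isa ArgumentError)   # true
println(err.msg)                 # column indices J[k] must satisfy 1 <= J[k] <= n
```

With `n == 0` the message names the column indices, which avoids the confusing `I[k]`-specific message for an entry whose row index is actually valid.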