Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tables.jl support #3104

Merged
merged 18 commits into from
Oct 27, 2022
Merged
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
MathOptInterface = "1.3.0"
Expand Down
3 changes: 3 additions & 0 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,9 @@ include("callbacks.jl")
include("file_formats.jl")
include("feasibility_checker.jl")

using Tables
odow marked this conversation as resolved.
Show resolved Hide resolved
include("tables.jl")

# MOI contains a number of Enums that are often accessed by users such as
# `MOI.OPTIMAL`. This piece of code re-exports them from JuMP so that users can
# use: `MOI.OPTIMAL`, `JuMP.OPTIMAL`, or `using JuMP; OPTIMAL`.
Expand Down
164 changes: 164 additions & 0 deletions src/tables.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
abstract type _SolutionTable end

Tables.istable(::Type{<:_SolutionTable}) = true
Tables.rowaccess(::Type{<:_SolutionTable}) = true

_column_names(t::_SolutionTable) = getfield(t, :column_names)
_lookup(t::_SolutionTable) = getfield(t, :lookup)

Base.eltype(::_SolutionTable) = _SolutionRow
Base.length(t::_SolutionTable) = length(t.var)

struct _SolutionRow <: Tables.AbstractRow
index_vals::Any
sol_val::Number
source::_SolutionTable
end

function Tables.getcolumn(s::_SolutionRow, i::Int)
if i > length(getfield(s, :index_vals))
return getfield(s, :sol_val)
end
return getfield(s, :index_vals)[i]
end

function Tables.getcolumn(s::_SolutionRow, nm::Symbol)
i = _lookup(getfield(s, :source))[nm]
if i > length(getfield(s, :index_vals))
return getfield(s, :sol_val)
end
return getfield(s, :index_vals)[i]
end

Tables.columnnames(s::_SolutionRow) = _column_names(getfield(s, :source))

struct _SolutionTableDense{C} <: _SolutionTable
column_names::Vector{Symbol}
lookup::Dict{Symbol,Int}
index_lookup::Dict
var::C
end

function Base.iterate(t::_SolutionTableDense, state = nothing)
next =
isnothing(state) ? iterate(CartesianIndices(t.var)) :
iterate(CartesianIndices(t.var), state)
next === nothing && return nothing
index = next[1]
index_vals = [t.index_lookup[i][index[i]] for i in 1:length(index)]
return _SolutionRow(index_vals, JuMP.value(t.var[next[1]]), t), next[2]
end

function _SolutionTableDense(
v::Containers.DenseAxisArray{T,N,Ax,L},
name,
colnames...,
) where {T<:AbstractVariableRef,N,Ax,L}
if length(colnames) < N
error("Not enough column names provided")
end
if length(v) > 0 && !has_values(owner_model(first(v)))
error("No solution values available for variable")
end
all_names = vcat(colnames..., name)
lookup = Dict(nm => i for (i, nm) in enumerate(all_names))
index_lookup = Dict()
for (i, ax) in enumerate(axes(v))
index_lookup[i] = collect(ax)
end
return _SolutionTableDense(all_names, lookup, index_lookup, v)
end

"""
solution_table(var::DenseAxisArray, name, colnames...)

Returns the solution values of the variable container `var` as a table
that implements the `Tables.jl` interface.

The table will have one column for each index and a column with the
corresponding solution value. The name of the column with the solution
value is provided by `name`, while `colnames` provides the name of the
index columns.

## Example
```julia
model = Model()
@variable(model, x[1:10, 2000:2020] >= 0)
[...]
optimize!(model)
tbl = solution_table(x, :value, :car, :year)
```
"""
function solution_table(
var::Containers.DenseAxisArray{T,N,Ax,L},
name,
colnames...,
) where {T<:AbstractVariableRef,N,Ax,L}
return _SolutionTableDense(var, name, colnames...)
end

function _SolutionTableDense(
v::Array{T},
name,
colnames...,
) where {T<:AbstractVariableRef}
if length(colnames) < length(axes(v))
error("Not enough column names provided")
end
if length(v) > 0 && !has_values(owner_model(first(v)))
error("No solution values available for variable")
end
all_names = vcat(colnames..., name)
lookup = Dict(nm => i for (i, nm) in enumerate(all_names))
index_lookup = Dict()
for (i, ax) in enumerate(axes(v))
index_lookup[i] = collect(ax)
end
return _SolutionTableDense(all_names, lookup, index_lookup, v)
end

function solution_table(
var::Array{T},
name,
colnames...,
) where {T<:AbstractVariableRef}
return _SolutionTableDense(var, name, colnames...)
end

struct _SolutionTableSparse <: _SolutionTable
column_names::Vector{Symbol}
lookup::Dict{Symbol,Int}
var::Containers.SparseAxisArray
end

function _SolutionTableSparse(
v::Containers.SparseAxisArray{T,N,K},
name,
colnames...,
) where {T<:AbstractVariableRef,N,K}
if length(colnames) < N
error("Not enough column names provided")
end
if length(v) > 0 && !has_values(first(v).model)
error("No solution values available for variable")
end
all_names = vcat(colnames..., name)
lookup = Dict(nm => i for (i, nm) in enumerate(all_names))
return _SolutionTableSparse(all_names, lookup, v)
end

function Base.iterate(t::_SolutionTableSparse, state = nothing)
next =
isnothing(state) ? iterate(eachindex(t.var)) :
iterate(eachindex(t.var), state)
next === nothing && return nothing
return _SolutionRow(next[1], JuMP.value(t.var[next[1]]), t), next[2]
end

function solution_table(
var::Containers.SparseAxisArray{T,N,K},
name,
colnames...,
) where {T<:AbstractVariableRef,N,K}
return _SolutionTableSparse(var, name, colnames...)
end
Loading