From d33e05cef83d287d24438065c935c3f38aac33d6 Mon Sep 17 00:00:00 2001 From: Truls Flatberg <75753981+trulsf@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:13:12 +0200 Subject: [PATCH] Add Tables.jl support to containers (#3104) --- docs/src/manual/containers.md | 100 ++++++++++++++++++++++++ docs/src/reference/containers.md | 1 + docs/src/tutorials/linear/diet.jl | 10 +++ src/Containers/Containers.jl | 1 + src/Containers/tables.jl | 68 ++++++++++++++++ test/Containers/tables.jl | 126 ++++++++++++++++++++++++++++++ 6 files changed, 306 insertions(+) create mode 100644 src/Containers/tables.jl create mode 100644 test/Containers/tables.jl diff --git a/docs/src/manual/containers.md b/docs/src/manual/containers.md index ce47643cc0c..40c144572d8 100644 --- a/docs/src/manual/containers.md +++ b/docs/src/manual/containers.md @@ -102,6 +102,42 @@ julia> swap.(x) (1, 2) (2, 2) (3, 2) ``` +### Tables + +Use [`Containers.rowtable`](@ref) to convert the `Array` into a +[Tables.jl](https://github.com/JuliaData/Tables.jl) compatible +`Vector{<:NamedTuple}`: + +```jldoctest containers_array +julia> table = Containers.rowtable(x; header = [:I, :J, :value]) +6-element Vector{NamedTuple{(:I, :J, :value), Tuple{Int64, Int64, Tuple{Int64, Int64}}}}: + (I = 1, J = 1, value = (1, 1)) + (I = 2, J = 1, value = (2, 1)) + (I = 1, J = 2, value = (1, 2)) + (I = 2, J = 2, value = (2, 2)) + (I = 1, J = 3, value = (1, 3)) + (I = 2, J = 3, value = (2, 3)) +``` + +Because it supports the [Tables.jl](https://github.com/JuliaData/Tables.jl) +interface, you can pass it to any function which accepts a table as input: + +```jldoctest containers_array +julia> import DataFrames; + +julia> DataFrames.DataFrame(table) +6×3 DataFrame + Row │ I J value + │ Int64 Int64 Tuple… +─────┼────────────────────── + 1 │ 1 1 (1, 1) + 2 │ 2 1 (2, 1) + 3 │ 1 2 (1, 2) + 4 │ 2 2 (2, 2) + 5 │ 1 3 (1, 3) + 6 │ 2 3 (2, 3) +``` + ## DenseAxisArray A [`Containers.DenseAxisArray`](@ref) is created when the index sets are @@ -191,6 +227,38 @@ julia> x.data (2, :A) (2, :B) ``` +### Tables + +Use [`Containers.rowtable`](@ref) to convert the `DenseAxisArray` into a +[Tables.jl](https://github.com/JuliaData/Tables.jl) compatible +`Vector{<:NamedTuple}`: + +```jldoctest containers_dense +julia> table = Containers.rowtable(x; header = [:I, :J, :value]) +4-element Vector{NamedTuple{(:I, :J, :value), Tuple{Int64, Symbol, Tuple{Int64, Symbol}}}}: + (I = 1, J = :A, value = (1, :A)) + (I = 2, J = :A, value = (2, :A)) + (I = 1, J = :B, value = (1, :B)) + (I = 2, J = :B, value = (2, :B)) +``` + +Because it supports the [Tables.jl](https://github.com/JuliaData/Tables.jl) +interface, you can pass it to any function which accepts a table as input: + +```jldoctest containers_dense +julia> import DataFrames; + +julia> DataFrames.DataFrame(table) +4×3 DataFrame + Row │ I J value + │ Int64 Symbol Tuple… +─────┼──────────────────────── + 1 │ 1 A (1, :A) + 2 │ 2 A (2, :A) + 3 │ 1 B (1, :B) + 4 │ 2 B (2, :B) +``` + ## SparseAxisArray A [`Containers.SparseAxisArray`](@ref) is created when the index sets are @@ -252,6 +320,38 @@ JuMP.Containers.SparseAxisArray{Tuple{Symbol, Int64}, 1, Tuple{Int64}} with 2 en [3] = (:B, 3) ``` +### Tables + +Use [`Containers.rowtable`](@ref) to convert the `SparseAxisArray` into a +[Tables.jl](https://github.com/JuliaData/Tables.jl) compatible +`Vector{<:NamedTuple}`: + +```jldoctest containers_sparse +julia> table = Containers.rowtable(x; header = [:I, :J, :value]) +4-element Vector{NamedTuple{(:I, :J, :value), Tuple{Int64, Symbol, Tuple{Int64, Symbol}}}}: + (I = 3, J = :B, value = (3, :B)) + (I = 2, J = :A, value = (2, :A)) + (I = 2, J = :B, value = (2, :B)) + (I = 3, J = :A, value = (3, :A)) +``` + +Because it supports the [Tables.jl](https://github.com/JuliaData/Tables.jl) +interface, you can pass it to any function which accepts a table as input: + +```jldoctest containers_sparse +julia> import DataFrames; + +julia> DataFrames.DataFrame(table) +4×3 DataFrame + Row │ I J value + │ Int64 Symbol Tuple… +─────┼──────────────────────── + 1 │ 3 B (3, :B) + 2 │ 2 A (2, :A) + 3 │ 2 B (2, :B) + 4 │ 3 A (3, :A) +``` + ## Forcing the container type Pass `container = T` to use `T` as the container. For example: diff --git a/docs/src/reference/containers.md b/docs/src/reference/containers.md index e2679c6f9d4..12c7bc16142 100644 --- a/docs/src/reference/containers.md +++ b/docs/src/reference/containers.md @@ -7,6 +7,7 @@ Containers Containers.DenseAxisArray Containers.SparseAxisArray Containers.container +Containers.rowtable Containers.default_container Containers.@container Containers.VectorizedProductIterator diff --git a/docs/src/tutorials/linear/diet.jl b/docs/src/tutorials/linear/diet.jl index 80f07a1ba01..8b876b6bb50 100644 --- a/docs/src/tutorials/linear/diet.jl +++ b/docs/src/tutorials/linear/diet.jl @@ -139,6 +139,16 @@ end # That's a lot of milk and ice cream! And sadly, we only get `0.6` of a # hamburger. +# We can also use the function [`Containers.rowtable`](@ref) to easily convert +# the result into a DataFrame: + +table = Containers.rowtable(value, x; header = [:food, :quantity]) +solution = DataFrames.DataFrame(table) + +# This makes it easy to perform analyses our solution: + +filter!(row -> row.quantity > 0.0, solution) + # ## Problem modification # JuMP makes it easy to take an existing model and modify it by adding extra diff --git a/src/Containers/Containers.jl b/src/Containers/Containers.jl index d42ceacb223..dbc1aa74f3e 100644 --- a/src/Containers/Containers.jl +++ b/src/Containers/Containers.jl @@ -54,5 +54,6 @@ include("nested_iterator.jl") include("no_duplicate_dict.jl") include("container.jl") include("macro.jl") +include("tables.jl") end diff --git a/src/Containers/tables.jl b/src/Containers/tables.jl new file mode 100644 index 00000000000..c8b0ad536e0 --- /dev/null +++ b/src/Containers/tables.jl @@ -0,0 +1,68 @@ +# Copyright 2017, Iain Dunning, Joey Huchette, Miles Lubin, and contributors +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. + +_rows(x::Array) = zip(eachindex(x), Iterators.product(axes(x)...)) + +_rows(x::DenseAxisArray) = zip(vec(eachindex(x)), Iterators.product(axes(x)...)) + +_rows(x::SparseAxisArray) = zip(eachindex(x.data), keys(x.data)) + +""" + rowtable([f::Function=identity,] x; [header::Vector{Symbol} = Symbol[]]) + +Applies the function `f` to all elements of the variable container `x`, +returning the result as a `Vector` of `NamedTuple`s, where `header` is a vector +containing the corresponding axis names. + +If `x` is an `N`-dimensional array, there must be `N+1` names, so that the last +name corresponds to the result of `f(x[i])`. + +If `header` is left empty, then the default header is `[:x1, :x2, ..., :xN, :y]`. + +!!! info + A `Vector` of `NamedTuple`s implements the [Tables.jl](https://github.com/JuliaData/Tables.jl) + interface, and so the result can be used as input for any function + that consumes a 'Tables.jl' compatible source. + +## Example + +```jldoctest; setup=:(using JuMP) +julia> model = Model(); + +julia> @variable(model, x[i=1:2, j=i:2] >= 0, start = i+j); + +julia> Containers.rowtable(start_value, x; header = [:i, :j, :start]) +3-element Vector{NamedTuple{(:i, :j, :start), Tuple{Int64, Int64, Float64}}}: + (i = 1, j = 2, start = 3.0) + (i = 1, j = 1, start = 2.0) + (i = 2, j = 2, start = 4.0) + +julia> Containers.rowtable(x) +3-element Vector{NamedTuple{(:x1, :x2, :y), Tuple{Int64, Int64, VariableRef}}}: + (x1 = 1, x2 = 2, y = x[1,2]) + (x1 = 1, x2 = 1, y = x[1,1]) + (x1 = 2, x2 = 2, y = x[2,2]) +``` +""" +function rowtable( + f::Function, + x::Union{Array,DenseAxisArray,SparseAxisArray}; + header::Vector{Symbol} = Symbol[], +) + if isempty(header) + header = Symbol[Symbol("x$i") for i in 1:ndims(x)] + push!(header, :y) + end + got, want = length(header), ndims(x) + 1 + if got != want + error( + "Invalid number of column names provided: Got $got, expected $want.", + ) + end + names = tuple(header...) + return [NamedTuple{names}((args..., f(x[i]))) for (i, args) in _rows(x)] +end + +rowtable(x; kwargs...) = rowtable(identity, x; kwargs...) diff --git a/test/Containers/tables.jl b/test/Containers/tables.jl new file mode 100644 index 00000000000..98ee4a3a57d --- /dev/null +++ b/test/Containers/tables.jl @@ -0,0 +1,126 @@ +# Copyright 2017, Iain Dunning, Joey Huchette, Miles Lubin, and contributors +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. + +module TestTableInterface + +using JuMP +using Test + +function runtests() + for name in names(@__MODULE__; all = true) + if startswith("$(name)", "test_") + @testset "$(name)" begin + getfield(@__MODULE__, name)() + end + end + end + return +end + +function test_denseaxisarray() + model = Model() + @variable(model, x[i = 4:10, j = 2002:2022] >= 0, start = 0.0) + @test typeof(x) <: Containers.DenseAxisArray + start_table = Containers.rowtable(start_value, x; header = [:i1, :i2, :i3]) + T = NamedTuple{(:i1, :i2, :i3),Tuple{Int,Int,Float64}} + @test start_table isa Vector{T} + @test length(start_table) == length(x) + row = first(start_table) + @test row == (i1 = 4, i2 = 2002, i3 = 0.0) + x_table = Containers.rowtable(x; header = [:i1, :i2, :i3]) + @test x_table[1] == (i1 = 4, i2 = 2002, i3 = x[4, 2002]) + return +end + +function test_array() + model = Model() + @variable(model, x[1:10, 1:5] >= 0, start = 0.0) + @test typeof(x) <: Array{VariableRef} + start_table = Containers.rowtable(start_value, x; header = [:i1, :i2, :i3]) + T = NamedTuple{(:i1, :i2, :i3),Tuple{Int,Int,Float64}} + @test start_table isa Vector{T} + @test length(start_table) == length(x) + row = first(start_table) + @test row == (i1 = 1, i2 = 1, i3 = 0.0) + x_table = Containers.rowtable(x; header = [:i1, :i2, :i3]) + @test x_table[1] == (i1 = 1, i2 = 1, i3 = x[1, 1]) + return +end + +function test_sparseaxisarray() + model = Model() + @variable(model, x[i = 1:10, j = 1:5; i + j <= 8] >= 0, start = 0) + @test typeof(x) <: Containers.SparseAxisArray + start_table = Containers.rowtable(start_value, x; header = [:i1, :i2, :i3]) + T = NamedTuple{(:i1, :i2, :i3),Tuple{Int,Int,Float64}} + @test start_table isa Vector{T} + @test length(start_table) == length(x) + @test (i1 = 1, i2 = 1, i3 = 0.0) in start_table + x_table = Containers.rowtable(x; header = [:i1, :i2, :i3]) + @test (i1 = 1, i2 = 1, i3 = x[1, 1]) in x_table + return +end + +function test_col_name_error() + model = Model() + @variable(model, x[1:2, 1:2]) + @test_throws ErrorException Containers.rowtable(x; header = [:y, :a]) + @test_throws( + ErrorException, + Containers.rowtable(x; header = [:y, :a, :b, :c]), + ) + @test Containers.rowtable(x; header = [:y, :a, :b]) isa Vector{<:NamedTuple} + return +end + +# Mockup of custom variable type +struct _MockVariable <: JuMP.AbstractVariable + var::JuMP.ScalarVariable +end + +struct _MockVariableRef <: JuMP.AbstractVariableRef + vref::VariableRef +end + +JuMP.name(v::_MockVariableRef) = JuMP.name(v.vref) + +JuMP.owner_model(v::_MockVariableRef) = JuMP.owner_model(v.vref) + +JuMP.start_value(v::_MockVariableRef) = JuMP.start_value(v.vref) + +struct _Mock end + +function JuMP.build_variable(::Function, info::JuMP.VariableInfo, _::_Mock) + return _MockVariable(JuMP.ScalarVariable(info)) +end + +function JuMP.add_variable(model::Model, x::_MockVariable, name::String) + variable = JuMP.add_variable(model, x.var, name) + return _MockVariableRef(variable) +end + +function test_custom_variable() + model = Model() + @variable( + model, + x[i = 1:3, j = 100:102] >= 0, + _Mock(), + container = Containers.DenseAxisArray, + start = 0.0, + ) + @test typeof(x) <: Containers.DenseAxisArray + start_table = Containers.rowtable(start_value, x) + T = NamedTuple{(:x1, :x2, :y),Tuple{Int,Int,Float64}} + @test start_table isa Vector{T} + @test length(start_table) == length(x) + @test (x1 = 1, x2 = 100, y = 0.0) in start_table + x_table = Containers.rowtable(x) + @test (x1 = 1, x2 = 100, y = x[1, 100]) in x_table + return +end + +end + +TestTableInterface.runtests()