From 4b54d39ce755e48f415cc6778ba804f07622e761 Mon Sep 17 00:00:00 2001 From: Zachary P Christensen Date: Thu, 3 Feb 2022 21:20:54 -0500 Subject: [PATCH] Fix `find_first_eq` (#38) This fixes a problem where `find_first_eq` was failing to find values when there were a mix of dynamic and static values. --- Project.toml | 2 +- src/tuples.jl | 12 ++++++------ test/runtests.jl | 3 ++- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 5b206d6..fa482c7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Static" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" authors = ["chriselrod", "ChrisRackauckas", "Tokazama"] -version = "0.5.1" +version = "0.5.2" [deps] IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" diff --git a/src/tuples.jl b/src/tuples.jl index aff56fb..a1af269 100644 --- a/src/tuples.jl +++ b/src/tuples.jl @@ -65,7 +65,6 @@ eachop_tuple(op, itr, arg, args...) = _eachop_tuple(op, itr, arg, args) Expr(:block, Expr(:meta, :inline), t) end - #= find_first_eq(x, collection::Tuple) @@ -74,12 +73,13 @@ If `x` and `collection` are static (`is_static`) and `x` is in `collection` then value is a `StaticInt`. =# @generated function find_first_eq(x::X, itr::I) where {X,N,I<:Tuple{Vararg{Any,N}}} - if (is_static(X) & is_static(I)) === True() - return Expr(:block, Expr(:meta, :inline), - :(Base.Cartesian.@nif $(N + 1) d->(x === getfield(itr, d)) d->(static(d)) d->(nothing))) + # we avoid incidental code gen when evaluated a tuple of known values by iterating + # through `I.parameters` instead of `known(I)`. + index = ifelse(known(X) === missing, nothing, findfirst(==(X), I.parameters)) + if index === nothing + :(Base.Cartesian.@nif $(N + 1) d->(x == getfield(itr, d)) d->(d) d->(nothing)) else - return Expr(:block, Expr(:meta, :inline), - :(Base.Cartesian.@nif $(N + 1) d->(x === getfield(itr, d)) d->(d) d->(nothing))) + :($(static(index))) end end diff --git a/test/runtests.jl b/test/runtests.jl index e5b78e5..6c1408e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -296,13 +296,14 @@ using Test @test @inferred(Static.permute(x, y)) === y @test @inferred(Static.eachop(getindex, x)) === x - @test Static.field_type(typeof((x = 1, y = 2)), :x) <: Int @test Static.field_type(typeof((x = 1, y = 2)), static(:x)) <: Int get_tuple_add(::Type{T}, ::Type{X}, dim::StaticInt) where {T,X} = Tuple{Static.field_type(T, dim),X} @test @inferred(Static.eachop_tuple(Static.field_type, y, T)) === Tuple{String,Float64,Int} @test @inferred(Static.eachop_tuple(get_tuple_add, y, T, String)) === Tuple{Tuple{String,String},Tuple{Float64,String},Tuple{Int,String}} @test @inferred(Static.find_first_eq(static(1), y)) === static(3) + @test @inferred(Static.find_first_eq(static(2), (1, static(2)))) === static(2) + @test @inferred(Static.find_first_eq(static(2), (1, static(2), 3))) === static(2) # inferred is Union{Int,Nothing} @test Static.find_first_eq(1, map(Int, y)) === 3