Skip to content

Commit

Permalink
Fix duplicate error when using generator in Dict (#53151)
Browse files Browse the repository at this point in the history
Fixes: #33147
Replaces/Closes: #40445

The difference here, compared to past implementations, is that we use
the zero-cost `isiterable` check on every intermediate step, instead of
wrapping the call in a try/catch and then trying to re-approximate the
`isiterable` afterwards. Some samples:

```julia
julia> Dict(i for i in 1:3)                                                        
ERROR: ArgumentError: AbstractDict(kv): kv needs to be an iterator of 2-tuples or pairs                                                                               
Stacktrace:                                                                                                                                                           
 [1] _throw_dict_kv_error()                                                        
   @ Base ./dict.jl:118                                                                                                                                               
 [2] grow_to!                                                                      
   @ ./dict.jl:132 [inlined]                                                       
 [3] dict_with_eltype                                                                                                                                                 
   @ ./abstractdict.jl:592 [inlined]                                                                                                                                  
 [4] Dict(kv::Base.Generator{UnitRange{Int64}, typeof(identity)})                                                                                                     
   @ Base ./dict.jl:120                                                            
 [5] top-level scope                                                                                                                                                  
   @ REPL[1]:1                                                                     
                                                                                                                                                                      
julia> Dict(i => error("$i") for i in 1:3)                                         
ERROR: 1                                                                                                                                                              
Stacktrace:                                                                                                                                                           
 [1] error(s::String)                                                                                                                                                 
   @ Base ./error.jl:35                                                                                                                                               
 [2] (::var"#3#4")(i::Int64)                                                       
   @ Main ./none:0                                                                 
 [3] iterate                                                                                                                                                          
   @ ./generator.jl:48 [inlined]                                                   
 [4] grow_to!                                                                      
   @ ./dict.jl:124 [inlined]                                                       
 [5] dict_with_eltype                                                              
   @ ./abstractdict.jl:592 [inlined]                                               
 [6] Dict(kv::Base.Generator{UnitRange{Int64}, var"#3#4"})                                                                                                            
   @ Base ./dict.jl:120                                                                                                                                               
 [7] top-level scope                                                               
   @ REPL[2]:1                                                                                                                                                        
```

The other unrelated change here is that `dest = empty(dest, typeof(k),
typeof(v))` is made conditional, so we do not unconditionally construct
an empty Dict in order to discard it and allocate an exact duplicate of
it, but only do so if inference wasn't precise originally.

Co-authored-by: Curtis Vogt <[email protected]>
  • Loading branch information
vtjnash and omus authored Feb 10, 2024
1 parent b43edb7 commit 5547305
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 62 deletions.
49 changes: 49 additions & 0 deletions base/abstractdict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,55 @@ _tablesz(x::T) where T <: Integer = x < 16 ? T(16) : one(T)<<(top_set_bit(x-one(

TP{K,V} = Union{Type{Tuple{K,V}},Type{Pair{K,V}}}

# This error is thrown if `grow_to!` cannot validate the contents of the iterator argument to it, which it does by testing the iteration protocol (isiterable) on it each time it is about to start iteration on it
_throw_dict_kv_error() = throw(ArgumentError("AbstractDict(kv): kv needs to be an iterator of 2-tuples or pairs"))

function grow_to!(dest::AbstractDict, itr)
applicable(iterate, itr) || _throw_dict_kv_error()
y = iterate(itr)
y === nothing && return dest
kv, st = y
applicable(iterate, kv) || _throw_dict_kv_error()
k = iterate(kv)
k === nothing && _throw_dict_kv_error()
k, kvst = k
v = iterate(kv, kvst)
v === nothing && _throw_dict_kv_error()
v, kvst = v
iterate(kv, kvst) === nothing || _throw_dict_kv_error()
if !(dest isa AbstractDict{typeof(k), typeof(v)})
dest = empty(dest, typeof(k), typeof(v))
end
dest[k] = v
return grow_to!(dest, itr, st)
end

function grow_to!(dest::AbstractDict{K,V}, itr, st) where {K, V}
y = iterate(itr, st)
while y !== nothing
kv, st = y
applicable(iterate, kv) || _throw_dict_kv_error()
kst = iterate(kv)
kst === nothing && _throw_dict_kv_error()
k, kvst = kst
vst = iterate(kv, kvst)
vst === nothing && _throw_dict_kv_error()
v, kvst = vst
iterate(kv, kvst) === nothing || _throw_dict_kv_error()
if isa(k, K) && isa(v, V)
dest[k] = v
else
new = empty(dest, promote_typejoin(K, typeof(k)), promote_typejoin(V, typeof(v)))
merge!(new, dest)
new[k] = v
return grow_to!(new, itr, st)
end
y = iterate(itr, st)
end
return dest
end


dict_with_eltype(DT_apply, kv, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv)
dict_with_eltype(DT_apply, kv::Generator, ::TP{K,V}) where {K,V} = DT_apply(K, V)(kv)
dict_with_eltype(DT_apply, ::Type{Pair{K,V}}) where {K,V} = DT_apply(K, V)()
Expand Down
40 changes: 1 addition & 39 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,45 +114,7 @@ const AnyDict = Dict{Any,Any}
Dict(ps::Pair{K,V}...) where {K,V} = Dict{K,V}(ps)
Dict(ps::Pair...) = Dict(ps)

function Dict(kv)
try
dict_with_eltype((K, V) -> Dict{K, V}, kv, eltype(kv))
catch
if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError("Dict(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow()
end
end
end

function grow_to!(dest::AbstractDict{K, V}, itr) where V where K
y = iterate(itr)
y === nothing && return dest
((k,v), st) = y
dest2 = empty(dest, typeof(k), typeof(v))
dest2[k] = v
grow_to!(dest2, itr, st)
end

# this is a special case due to (1) allowing both Pairs and Tuples as elements,
# and (2) Pair being invariant. a bit annoying.
function grow_to!(dest::AbstractDict{K,V}, itr, st) where V where K
y = iterate(itr, st)
while y !== nothing
(k,v), st = y
if isa(k,K) && isa(v,V)
dest[k] = v
else
new = empty(dest, promote_typejoin(K,typeof(k)), promote_typejoin(V,typeof(v)))
merge!(new, dest)
new[k] = v
return grow_to!(new, itr, st)
end
y = iterate(itr, st)
end
return dest
end
Dict(kv) = dict_with_eltype((K, V) -> Dict{K, V}, kv, eltype(kv))

empty(a::AbstractDict, ::Type{K}, ::Type{V}) where {K, V} = Dict{K, V}()

Expand Down
13 changes: 1 addition & 12 deletions base/iddict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,7 @@ IdDict(ps::Pair{K}...) where {K} = IdDict{K,Any}(ps)
IdDict(ps::(Pair{K,V} where K)...) where {V} = IdDict{Any,V}(ps)
IdDict(ps::Pair...) = IdDict{Any,Any}(ps)

function IdDict(kv)
try
dict_with_eltype((K, V) -> IdDict{K, V}, kv, eltype(kv))
catch
if !applicable(iterate, kv) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError(
"IdDict(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow()
end
end
end
IdDict(kv) = dict_with_eltype((K, V) -> IdDict{K, V}, kv, eltype(kv))

empty(d::IdDict, ::Type{K}, ::Type{V}) where {K, V} = IdDict{K,V}()

Expand Down
12 changes: 1 addition & 11 deletions base/weakkeydict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,7 @@ WeakKeyDict(ps::Pair{K}...) where {K} = WeakKeyDict{K,Any}(ps)
WeakKeyDict(ps::(Pair{K,V} where K)...) where {V} = WeakKeyDict{Any,V}(ps)
WeakKeyDict(ps::Pair...) = WeakKeyDict{Any,Any}(ps)

function WeakKeyDict(kv)
try
Base.dict_with_eltype((K, V) -> WeakKeyDict{K, V}, kv, eltype(kv))
catch
if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError("WeakKeyDict(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow()
end
end
end
WeakKeyDict(kv) = Base.dict_with_eltype((K, V) -> WeakKeyDict{K, V}, kv, eltype(kv))

function _cleanup_locked(h::WeakKeyDict)
if h.dirty
Expand Down
24 changes: 24 additions & 0 deletions test/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,30 @@ end

# issue #39117
@test Dict(t[1]=>t[2] for t in zip((1,"2"), (2,"2"))) == Dict{Any,Any}(1=>2, "2"=>"2")

@testset "issue #33147" begin
expected = try; Base._throw_dict_kv_error(); catch e; e; end
@test_throws expected Dict(i for i in 1:2)
@test_throws expected Dict(nothing for i in 1:2)
@test_throws expected Dict(() for i in 1:2)
@test_throws expected Dict((i, i, i) for i in 1:2)
@test_throws expected Dict(nothing)
@test_throws expected Dict((1,))
@test_throws expected Dict(1:2)
@test_throws expected Dict(((),))
@test_throws expected IdDict(((),))
@test_throws expected WeakKeyDict(((),))
@test_throws expected IdDict(nothing)
@test_throws expected WeakKeyDict(nothing)
@test Dict(1:0) isa Dict
@test Dict(()) isa Dict
try
Dict(i => error("$i") for i in 1:3)
catch ex
@test ex isa ErrorException
@test length(Base.current_exceptions()) == 1
end
end
end

@testset "empty tuple ctor" begin
Expand Down

0 comments on commit 5547305

Please sign in to comment.