Skip to content

Commit

Permalink
give wider/safer intersection result for vars used in both invariant …
Browse files Browse the repository at this point in the history
…and covariant position

fixes #41738
  • Loading branch information
JeffBezanson committed Aug 23, 2021
1 parent ddb7fff commit e3e550e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 33 deletions.
64 changes: 34 additions & 30 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ typedef struct jl_stenv_t {
int ignore_free; // treat free vars as black boxes; used during intersection
int intersection; // true iff subtype is being called from intersection
int emptiness_only; // true iff intersection only needs to test for emptiness
int triangular; // when intersecting Ref{X} with Ref{<:Y}
} jl_stenv_t;

// state manipulation utilities
Expand Down Expand Up @@ -1412,6 +1413,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->ignore_free = 0;
e->intersection = 0;
e->emptiness_only = 0;
e->triangular = 0;
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
}
Expand Down Expand Up @@ -2170,7 +2172,7 @@ static void set_bound(jl_value_t **bound, jl_value_t *val, jl_tvar_t *v, jl_sten
return;
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)v && btemp->ub == (jl_value_t*)v &&
if ((btemp->lb == (jl_value_t*)v || btemp->ub == (jl_value_t*)v) &&
in_union(val, (jl_value_t*)btemp->var))
return;
btemp = btemp->prev;
Expand Down Expand Up @@ -2222,6 +2224,21 @@ static int reachable_var(jl_value_t *x, jl_tvar_t *y, jl_stenv_t *e)
return reachable_var(xv->ub, y, e) || reachable_var(xv->lb, y, e);
}

// check whether setting v == t implies v == SomeType{v}, which is unsatisfiable.
static int check_unsat_bound(jl_value_t *t, jl_tvar_t *v, jl_stenv_t *e) JL_NOTSAFEPOINT
{
if (var_occurs_inside(t, v, 0, 0))
return 1;
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)v && btemp->ub == (jl_value_t*)v &&
var_occurs_inside(t, btemp->var, 0, 0))
return 1;
btemp = btemp->prev;
}
return 0;
}

static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int8_t R, int param)
{
jl_varbinding_t *bb = lookup(e, b);
Expand Down Expand Up @@ -2251,7 +2268,9 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
ub = a;
}
else {
e->triangular++;
ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
e->triangular--;
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, ub, e, 0, d);
restore_env(e, root, &se);
Expand All @@ -2263,20 +2282,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
if (ub != (jl_value_t*)b) {
if (jl_has_free_typevars(ub)) {
// constraint X == Ref{X} is unsatisfiable. also check variables set equal to X.
if (var_occurs_inside(ub, b, 0, 0)) {
if (check_unsat_bound(ub, b, e)) {
JL_GC_POP();
return jl_bottom_type;
}
jl_varbinding_t *btemp = e->vars;
while (btemp != NULL) {
if (btemp->lb == (jl_value_t*)b && btemp->ub == (jl_value_t*)b &&
var_occurs_inside(ub, btemp->var, 0, 0)) {
JL_GC_POP();
return jl_bottom_type;
}
btemp = btemp->prev;
}
}
bb->ub = ub;
bb->lb = ub;
Expand All @@ -2287,7 +2296,13 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (bb->constraintkind == 0) {
if (bb->constraintkind == 1 || e->triangular) {
if (e->triangular && check_unsat_bound(ub, b, e))
return jl_bottom_type;
set_bound(&bb->ub, ub, b, e);
return (jl_value_t*)b;
}
else if (bb->constraintkind == 0) {
JL_GC_PUSH1(&ub);
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, 0, d)) {
JL_GC_POP();
Expand All @@ -2296,10 +2311,6 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
JL_GC_POP();
return ub;
}
else if (bb->constraintkind == 1) {
set_bound(&bb->ub, ub, b, e);
return (jl_value_t*)b;
}
assert(bb->constraintkind == 2);
if (!jl_is_typevar(a)) {
if (ub == a && bb->lb != jl_bottom_type)
Expand Down Expand Up @@ -2565,11 +2576,11 @@ static jl_value_t *intersect_unionall_(jl_value_t *t, jl_unionall_t *u, jl_stenv

static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
jl_savedenv_t se, se2;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0
jl_value_t *res=NULL, *save=NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
if (vb.limited) {
Expand All @@ -2584,18 +2595,11 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
vb.constraintkind = vb.concrete ? 1 : 2;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
else if (vb.occurs_cov) {
save_env(e, &save2, &se2);
else if (vb.occurs_cov && !var_occurs_invariant(u->body, u->var, 0)) {
restore_env(e, save, &se);
vb.occurs_cov = vb.occurs_inv = 0;
vb.lb = u->var->lb; vb.ub = u->var->ub;
vb.constraintkind = 1;
res2 = intersect_unionall_(t, u, e, R, param, &vb);
if (res2 != jl_bottom_type)
res = res2;
else
restore_env(e, save2, &se2);
free_env(&se2);
res = intersect_unionall_(t, u, e, R, param, &vb);
}
}
free_env(&se);
Expand Down
18 changes: 15 additions & 3 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,7 @@ f31082(::Pair{B, C}, ::C, ::C) where {B, C} = 1
Tuple{Type{Val{T}},Int,T} where T)
@testintersect(Tuple{Type{Val{T}},Integer,T} where T,
Tuple{Type,Int,Integer},
Tuple{Type{Val{T}},Int,T} where T<:Integer)
Tuple{Type{Val{T}},Int,Integer} where T)
@testintersect(Tuple{Type{Val{T}},Integer,T} where T>:Integer,
Tuple{Type,Int,Integer},
Tuple{Type{Val{T}},Int,Integer} where T>:Integer)
Expand Down Expand Up @@ -1866,7 +1866,7 @@ let A = Tuple{Type{T} where T<:Ref, Ref, Union{T, Union{Ref{T}, T}} where T<:Ref
I = typeintersect(A,B)
# this was a case where <: disagreed with === (due to a badly-normalized type)
@test I == typeintersect(A,B)
@test I == Tuple{Type{T}, Ref{T}, Union{Ref{T}, T}} where T<:Ref
@test I == Tuple{Type{T}, Ref{T}, Ref} where T<:Ref
end

# issue #39218
Expand Down Expand Up @@ -1946,7 +1946,7 @@ let A = Tuple{UnionAll, Vector{Any}},
B = Tuple{Type{T}, T} where T<:AbstractArray,
I = typeintersect(A, B)
@test !isconcretetype(I)
@test_broken I == Tuple{Type{T}, Vector{Any}} where T<:AbstractArray
@test I == Tuple{Type{T}, Vector{Any}} where T<:AbstractArray
end

@testintersect(Tuple{Type{Vector{<:T}}, T} where {T<:Integer},
Expand All @@ -1959,3 +1959,15 @@ end
@testintersect(Tuple{Type{S40{_A, _B, _C, _D, _E, _F, _G, _H, _I, _J, _K, _L, _M, _N, _O, _P, _Q, _R, _S, _T, _U, _V, _W, _X, _Y, _Z, _Z1, _Z2, _Z3, _Z4, _Z5, _Z6, _Z7, _Z8, _Z9, _Z10, _Z11, _Z12, _Z13, _Z14}} where _Z14 where _Z13 where _Z12 where _Z11 where _Z10 where _Z9 where _Z8 where _Z7 where _Z6 where _Z5 where _Z4 where _Z3 where _Z2 where _Z1 where _Z where _Y where _X where _W where _V where _U where _T where _S where _R where _Q where _P where _O where _N where _M where _L where _K where _J where _I where _H where _G where _F where _E where _D where _C where _B where _A, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any},
Tuple{Type{S40{A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, A31, A32, A33, A34, A35, A36, A37, A38, A39, A40} where A40 where A39 where A38 where A37 where A36 where A35 where A34 where A33 where A32 where A31 where A30 where A29 where A28 where A27 where A26 where A25 where A24 where A23 where A22 where A21 where A20 where A19 where A18 where A17 where A16 where A15 where A14 where A13 where A12 where A11 where A10 where A9 where A8 where A7 where A6 where A5 where A4 where A3 where A2 where A1}, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, A23, A24, A25, A26, A27, A28, A29, A30, A31, A32, A33, A34, A35, A36, A37, A38, A39, A40} where A40 where A39 where A38 where A37 where A36 where A35 where A34 where A33 where A32 where A31 where A30 where A29 where A28 where A27 where A26 where A25 where A24 where A23 where A22 where A21 where A20 where A19 where A18 where A17 where A16 where A15 where A14 where A13 where A12 where A11 where A10 where A9 where A8 where A7 where A6 where A5 where A4 where A3 where A2 where A1,
Bottom)

let A = Tuple{Any, Type{Ref{_A}} where _A},
B = Tuple{Type{T}, Type{<:Union{Ref{T}, T}}} where T,
I = typeintersect(A, B)
@test I != Union{}
# TODO: this intersection result is still too narrow
@test_broken Tuple{Type{Ref{Integer}}, Type{Ref{Integer}}} <: I
end

@testintersect(Tuple{Type{T}, T} where T<:(Tuple{Vararg{_A, _B}} where _B where _A),
Tuple{Type{Tuple{Vararg{_A, N}} where _A<:F}, Pair{N, F}} where F where N,
Bottom)

0 comments on commit e3e550e

Please sign in to comment.