Skip to content

Commit

Permalink
Merge pull request #49032 from N5N3/subtyping-backports
Browse files Browse the repository at this point in the history
Subtyping backports for 1.9
  • Loading branch information
KristofferC authored Mar 24, 2023
2 parents a4cd8d2 + b860eca commit 38f0e29
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 96 deletions.
222 changes: 129 additions & 93 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,43 @@ static jl_unionall_t *rename_unionall(jl_unionall_t *u)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->more == 0)
return 0;
// reset `used` and let `pick_union_decision` clean the stack.
state->used = state->more;
statestack_set(state, state->used - 1, 1);
return 1;
}

static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
return ui;
}

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
state->more = state->depth; // memorize that this was the deepest available choice
u = ((jl_uniontype_t*)u)->a;
}
else {
if (pick_union_decision(e, R))
u = ((jl_uniontype_t*)u)->b;
}
else
u = ((jl_uniontype_t*)u)->a;
} while (jl_is_uniontype(u));
return u;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);

// subtype for variable bounds consistency check. needs its own forall/exists environment.
static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
Expand All @@ -556,17 +571,7 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

pop_unionstate(&e->Runions, &oldRunions);
int sub = local_forall_exists_subtype(x, y, e, 0, 1);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}
Expand Down Expand Up @@ -1195,15 +1200,9 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// of unions and vars: if matching `typevar <: union`, first try to match the whole
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
// note: for forall var, there's no need to split y if it has no free typevars.
jl_varbinding_t *xx = lookup(e, (jl_tvar_t *)x);
ui = ((xx && xx->right) || jl_has_free_typevars(y)) && pick_union_decision(e, 1);
}
if (ui == 1)
y = pick_union_element(y, e, 1);
Expand Down Expand Up @@ -1355,63 +1354,110 @@ static int is_definite_length_tuple_type(jl_value_t *x)
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
}

static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
if (obviously_egal(x, y)) return 1;
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);

if ((is_indefinite_length_tuple_type(x) && is_definite_length_tuple_type(y)) ||
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
{
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
return 0;
if (jl_is_unionall(x))
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
if (jl_is_datatype(x)) {
jl_datatype_t *xd = (jl_datatype_t *)x;
for (int i = 0; i < jl_nparams(xd); i++) {
jl_value_t *param = jl_tparam(xd, i);
if (jl_is_vararg(param))
param = jl_unwrap_vararg(param);
if (may_contain_union_decision(param, e, log))
return 1;
}
return 0;
}
if (!jl_is_typevar(x))
return 1;
jl_typeenv_t *t = log;
while (t != NULL) {
if (x == (jl_value_t *)t->var)
return 1;
t = t->prev;
}
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
}

jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
{
int16_t oldRmore = e->Runions.more;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Lunions.used = e->Runions.used = 0;
e->Lunions.depth = e->Runions.depth = 0;
e->Lunions.more = e->Runions.more = 0;
int count = 0, noRmore = 0;
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
pop_unionstate(&e->Runions, &oldRunions);
// we should not try the slow path if `forall_exists_subtype` has tested all cases;
// Once limit_slow == 1, also skip it if
// 1) `forall_exists_subtype` return false
// 2) the left `Union` looks big
if (noRmore || (limit_slow && (count > 3 || !sub)))
e->Runions.more = oldRmore;
}
else {
int lastset = 0;
// slow path
e->Lunions.used = 0;
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
sub = subtype(x, y, e, 2);
int set = e->Lunions.more;
if (!sub || !set)
sub = subtype(x, y, e, param);
if (!sub || !next_union_state(e, 0))
break;
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}
}
return sub;
}

static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
if (obviously_egal(x, y)) return 1;

if ((is_indefinite_length_tuple_type(x) && is_definite_length_tuple_type(y)) ||
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

if ((jl_is_uniontype(x) && jl_is_uniontype(y))) {
// For 2 unions, first try a more efficient greedy algorithm that compares the unions
// componentwise. If failed, `exists_subtype` would memorize that this branch should be skipped.
if (pick_union_decision(e, 1) == 0) {
return forall_exists_equal(((jl_uniontype_t *)x)->a, ((jl_uniontype_t *)y)->a, e) &&
forall_exists_equal(((jl_uniontype_t *)x)->b, ((jl_uniontype_t *)y)->b, e);
}
}

jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);

int limit_slow = !jl_has_free_typevars(x) || !jl_has_free_typevars(y);
int sub = local_forall_exists_subtype(x, y, e, 2, limit_slow) &&
local_forall_exists_subtype(y, x, e, 0, 0);

pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
return sub;
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;
if (subtype(x, y, e, param))
return 1;
int set = e->Runions.more;
if (set) {
if (next_union_state(e, 1)) {
// We preserve `envout` here as `subtype_unionall` needs previous assigned env values.
int oldidx = e->envidx;
e->envidx = e->envsz;
Expand All @@ -1422,14 +1468,10 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
restore_env(e, saved, se);
return 0;
}
for (int i = set; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set - 1;
statestack_set(&e->Runions, lastset, 1);
}
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
{
// The depth recursion has the following shape, after simplification:
// ∀₁
Expand All @@ -1441,26 +1483,29 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
save_env(e, &saved, &se);

e->Lunions.used = 0;
int lastset = 0;
int sub;
if (count) *count = 0;
if (noRmore) *noRmore = 1;
while (1) {
sub = exists_subtype(x, y, e, saved, &se, param);
int set = e->Lunions.more;
if (!sub || !set)
if (count) *count = (*count < 4) ? *count + 1 : 4;
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
if (!sub || !next_union_state(e, 0))
break;
free_env(&se);
save_env(e, &saved, &se);
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}

free_env(&se);
JL_GC_POP();
return sub;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
{
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
}

static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
{
e->vars = NULL;
Expand Down Expand Up @@ -3292,39 +3337,30 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_value_t **merged = &is[3];
jl_savedenv_t se, me;
save_env(e, saved, &se);
int lastset = 0, niter = 0, total_iter = 0;
jl_value_t *ii = intersect(x, y, e, 0);
is[0] = ii; // root
int niter = 0, total_iter = 0;
is[0] = intersect(x, y, e, 0); // root
if (is[0] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type)
while (next_union_state(e, 1)) {
if (e->emptiness_only && is[0] != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
statestack_set(&e->Runions, set, 1);
for (int i = set + 1; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set;

is[0] = ii;
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
if (is[0] == jl_bottom_type)
ii = is[1];
else if (is[1] == jl_bottom_type)
ii = is[0];
else {
is[0] = is[1];
else if (is[1] != jl_bottom_type) {
// TODO: the repeated subtype checks in here can get expensive
ii = jl_type_union(is, 2);
is[0] = jl_type_union(is, 2);
}
total_iter++;
if (niter > 4 || total_iter > 400000) {
ii = y;
is[0] = y;
break;
}
}
Expand All @@ -3334,7 +3370,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
free_env(&se);
JL_GC_POP();
return ii;
return is[0];
}

// type intersection entry points
Expand Down
Loading

0 comments on commit 38f0e29

Please sign in to comment.