Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subtyping backports for 1.9 #49032

Merged
merged 5 commits into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 129 additions & 93 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,43 @@ static jl_unionall_t *rename_unionall(jl_unionall_t *u)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->more == 0)
return 0;
// reset `used` and let `pick_union_decision` clean the stack.
state->used = state->more;
statestack_set(state, state->used - 1, 1);
return 1;
}

static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
return ui;
}

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
state->more = state->depth; // memorize that this was the deepest available choice
u = ((jl_uniontype_t*)u)->a;
}
else {
if (pick_union_decision(e, R))
u = ((jl_uniontype_t*)u)->b;
}
else
u = ((jl_uniontype_t*)u)->a;
} while (jl_is_uniontype(u));
return u;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);

// subtype for variable bounds consistency check. needs its own forall/exists environment.
static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
Expand All @@ -556,17 +571,7 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

pop_unionstate(&e->Runions, &oldRunions);
int sub = local_forall_exists_subtype(x, y, e, 0, 1);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}
Expand Down Expand Up @@ -1195,15 +1200,9 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// of unions and vars: if matching `typevar <: union`, first try to match the whole
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
// note: for forall var, there's no need to split y if it has no free typevars.
jl_varbinding_t *xx = lookup(e, (jl_tvar_t *)x);
ui = ((xx && xx->right) || jl_has_free_typevars(y)) && pick_union_decision(e, 1);
}
if (ui == 1)
y = pick_union_element(y, e, 1);
Expand Down Expand Up @@ -1355,63 +1354,110 @@ static int is_definite_length_tuple_type(jl_value_t *x)
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
}

static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
if (obviously_egal(x, y)) return 1;
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);

if ((is_indefinite_length_tuple_type(x) && is_definite_length_tuple_type(y)) ||
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
{
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
return 0;
if (jl_is_unionall(x))
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
if (jl_is_datatype(x)) {
jl_datatype_t *xd = (jl_datatype_t *)x;
for (int i = 0; i < jl_nparams(xd); i++) {
jl_value_t *param = jl_tparam(xd, i);
if (jl_is_vararg(param))
param = jl_unwrap_vararg(param);
if (may_contain_union_decision(param, e, log))
return 1;
}
return 0;
}
if (!jl_is_typevar(x))
return 1;
jl_typeenv_t *t = log;
while (t != NULL) {
if (x == (jl_value_t *)t->var)
return 1;
t = t->prev;
}
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
}

jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
{
int16_t oldRmore = e->Runions.more;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

e->Lunions.used = e->Runions.used = 0;
e->Lunions.depth = e->Runions.depth = 0;
e->Lunions.more = e->Runions.more = 0;
int count = 0, noRmore = 0;
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
pop_unionstate(&e->Runions, &oldRunions);
// we should not try the slow path if `forall_exists_subtype` has tested all cases;
// Once limit_slow == 1, also skip it if
// 1) `forall_exists_subtype` return false
// 2) the left `Union` looks big
if (noRmore || (limit_slow && (count > 3 || !sub)))
e->Runions.more = oldRmore;
}
else {
int lastset = 0;
// slow path
e->Lunions.used = 0;
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
sub = subtype(x, y, e, 2);
int set = e->Lunions.more;
if (!sub || !set)
sub = subtype(x, y, e, param);
if (!sub || !next_union_state(e, 0))
break;
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}
}
return sub;
}

static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
if (obviously_egal(x, y)) return 1;

if ((is_indefinite_length_tuple_type(x) && is_definite_length_tuple_type(y)) ||
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
return 0;

if ((jl_is_uniontype(x) && jl_is_uniontype(y))) {
// For 2 unions, first try a more efficient greedy algorithm that compares the unions
// componentwise. If failed, `exists_subtype` would memorize that this branch should be skipped.
if (pick_union_decision(e, 1) == 0) {
return forall_exists_equal(((jl_uniontype_t *)x)->a, ((jl_uniontype_t *)y)->a, e) &&
forall_exists_equal(((jl_uniontype_t *)x)->b, ((jl_uniontype_t *)y)->b, e);
}
}

jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);

int limit_slow = !jl_has_free_typevars(x) || !jl_has_free_typevars(y);
int sub = local_forall_exists_subtype(x, y, e, 2, limit_slow) &&
local_forall_exists_subtype(y, x, e, 0, 0);

pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
return sub;
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;
if (subtype(x, y, e, param))
return 1;
int set = e->Runions.more;
if (set) {
if (next_union_state(e, 1)) {
// We preserve `envout` here as `subtype_unionall` needs previous assigned env values.
int oldidx = e->envidx;
e->envidx = e->envsz;
Expand All @@ -1422,14 +1468,10 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
restore_env(e, saved, se);
return 0;
}
for (int i = set; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set - 1;
statestack_set(&e->Runions, lastset, 1);
}
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
{
// The depth recursion has the following shape, after simplification:
// ∀₁
Expand All @@ -1441,26 +1483,29 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
save_env(e, &saved, &se);

e->Lunions.used = 0;
int lastset = 0;
int sub;
if (count) *count = 0;
if (noRmore) *noRmore = 1;
while (1) {
sub = exists_subtype(x, y, e, saved, &se, param);
int set = e->Lunions.more;
if (!sub || !set)
if (count) *count = (*count < 4) ? *count + 1 : 4;
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
if (!sub || !next_union_state(e, 0))
break;
free_env(&se);
save_env(e, &saved, &se);
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}

free_env(&se);
JL_GC_POP();
return sub;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
{
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
}

static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
{
e->vars = NULL;
Expand Down Expand Up @@ -3292,39 +3337,30 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_value_t **merged = &is[3];
jl_savedenv_t se, me;
save_env(e, saved, &se);
int lastset = 0, niter = 0, total_iter = 0;
jl_value_t *ii = intersect(x, y, e, 0);
is[0] = ii; // root
int niter = 0, total_iter = 0;
is[0] = intersect(x, y, e, 0); // root
if (is[0] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type)
while (next_union_state(e, 1)) {
if (e->emptiness_only && is[0] != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
statestack_set(&e->Runions, set, 1);
for (int i = set + 1; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set;

is[0] = ii;
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
if (is[0] == jl_bottom_type)
ii = is[1];
else if (is[1] == jl_bottom_type)
ii = is[0];
else {
is[0] = is[1];
else if (is[1] != jl_bottom_type) {
// TODO: the repeated subtype checks in here can get expensive
ii = jl_type_union(is, 2);
is[0] = jl_type_union(is, 2);
}
total_iter++;
if (niter > 4 || total_iter > 400000) {
ii = y;
is[0] = y;
break;
}
}
Expand All @@ -3334,7 +3370,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
free_env(&se);
JL_GC_POP();
return ii;
return is[0];
}

// type intersection entry points
Expand Down
Loading