Skip to content

Commit

Permalink
[red-knot] Implement type narrowing for boolean conditionals (#14037)
Browse files Browse the repository at this point in the history
## Summary

This PR enables red-knot to support type narrowing based on `and` and
`or` conditionals, including nested combinations and their negation (for
`elif` / `else` blocks and for `not` operator). Part of #13694.

In order to address this properly (hopefully 😅), I had to run
`NarrowingConstraintsBuilder` functions recursively. In the first commit
I introduced a minor refactor - instead of mutating `self.constraints`,
the new constraints are now returned as function return values. I also
modified the constraints map to be optional, preventing unnecessary
hashmap allocations.
Thanks @carljm for your support on this :)

The second commit contains the logic and tests for handling boolean ops,
with improvements to intersections handling in `is_subtype_of` .

As I'm still new to Rust and the internals of type checkers, I’d be more
than happy to hear any insights or suggestions.
Thank you!

---------

Co-authored-by: Carl Meyer <[email protected]>
  • Loading branch information
TomerBin and carljm authored Nov 4, 2024
1 parent bb25bd9 commit 6c56a7a
Show file tree
Hide file tree
Showing 4 changed files with 591 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Narrowing for conditionals with boolean expressions

## Narrowing in `and` conditional

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and isinstance(x, B):
reveal_type(x) # revealed: A & B
else:
reveal_type(x) # revealed: B & ~A | A & ~B
```

## Arms might not add narrowing constraints

```py
class A: ...
class B: ...

def bool_instance() -> bool:
return True

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and bool_instance():
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

if bool_instance() and isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

reveal_type(x) # revealed: A | B
```

## Statically known arms

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and True:
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if True and isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if False and isinstance(x, A):
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

if False or isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if True or isinstance(x, A):
reveal_type(x) # revealed: A | B
else:
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: B & ~A

reveal_type(x) # revealed: A | B
```

## The type of multiple symbols can be narrowed down

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()
y = instance()

if isinstance(x, A) and isinstance(y, B):
reveal_type(x) # revealed: A
reveal_type(y) # revealed: B
else:
# No narrowing: Only-one or both checks might have failed
reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B

reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B
```

## Narrowing in `or` conditional

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) or isinstance(x, B):
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: C & ~A & ~B
```

## In `or`, all arms should add constraint in order to narrow

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

def bool_instance() -> bool:
return True

x = instance()

if isinstance(x, A) or isinstance(x, B) or bool_instance():
reveal_type(x) # revealed: A | B | C
else:
reveal_type(x) # revealed: C & ~A & ~B
```

## in `or`, all arms should narrow the same set of symbols

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()
y = instance()

if isinstance(x, A) or isinstance(y, A):
# The predicate might be satisfied by the right side, so the type of `x` can’t be narrowed down here.
reveal_type(x) # revealed: A | B | C
# The same for `y`
reveal_type(y) # revealed: A | B | C
else:
reveal_type(x) # revealed: B & ~A | C & ~A
reveal_type(y) # revealed: B & ~A | C & ~A

if (isinstance(x, A) and isinstance(y, A)) or (isinstance(x, B) and isinstance(y, B)):
# Here, types of `x` and `y` can be narrowd since all `or` arms constraint them.
reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B
else:
reveal_type(x) # revealed: A | B | C
reveal_type(y) # revealed: A | B | C
```

## mixing `and` and `not`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, B) and not isinstance(x, C):
reveal_type(x) # revealed: B & ~C
else:
# ~(B & ~C) -> ~B | C -> (A & ~B) | (C & ~B) | C -> (A & ~B) | C
reveal_type(x) # revealed: A & ~B | C
```

## mixing `or` and `not`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, B) or not isinstance(x, C):
reveal_type(x) # revealed: B | A & ~C
else:
reveal_type(x) # revealed: C & ~B
```

## `or` with nested `and`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) or (isinstance(x, B) and not isinstance(x, C)):
reveal_type(x) # revealed: A | B & ~C
else:
# ~(A | (B & ~C)) -> ~A & ~(B & ~C) -> ~A & (~B | C) -> (~A & C) | (~A ~ B)
reveal_type(x) # revealed: C & ~A
```

## `and` with nested `or`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) and (isinstance(x, B) or not isinstance(x, C)):
# A & (B | ~C) -> (A & B) | (A & ~C)
reveal_type(x) # revealed: A & B | A & ~C
else:
# ~((A & B) | (A & ~C)) ->
# ~(A & B) & ~(A & ~C) ->
# (~A | ~B) & (~A | C) ->
# [(~A | ~B) & ~A] | [(~A | ~B) & C] ->
# ~A | (~A & C) | (~B & C) ->
# ~A | (C & ~B) ->
# ~A | (C & ~B) The positive side of ~A is A | B | C ->
reveal_type(x) # revealed: B & ~A | C & ~A | C & ~B
```

## Boolean expression internal narrowing

```py
def optional_string() -> str | None:
return None

x = optional_string()
y = optional_string()

if x is None and y is not x:
reveal_type(y) # revealed: str

# Neither of the conditions alone is sufficient for narrowing y's type:
if x is None:
reveal_type(y) # revealed: str | None

if y is not x:
reveal_type(y) # revealed: str | None
```
78 changes: 78 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,46 @@ impl<'db> Type<'db> {
.elements(db)
.iter()
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
// Check that all target positive values are covered in self positive values
target_intersection
.positive(db)
.iter()
.all(|&target_pos_elem| {
self_intersection
.positive(db)
.iter()
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
})
// Check that all target negative values are excluded in self, either by being
// subtypes of a self negative value or being disjoint from a self positive value.
&& target_intersection
.negative(db)
.iter()
.all(|&target_neg_elem| {
// Is target negative value is subtype of a self negative value
self_intersection.negative(db).iter().any(|&self_neg_elem| {
target_neg_elem.is_subtype_of(db, self_neg_elem)
// Is target negative value is disjoint from a self positive value?
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
target_neg_elem.is_disjoint_from(db, self_pos_elem)
})
})
}
(Type::Intersection(intersection), ty) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, ty)),
(ty, Type::Intersection(intersection)) => {
intersection
.positive(db)
.iter()
.all(|&pos_ty| ty.is_subtype_of(db, pos_ty))
&& intersection
.negative(db)
.iter()
.all(|&neg_ty| neg_ty.is_disjoint_from(db, ty))
}
(Type::Instance(self_class), Type::Instance(target_class)) => {
self_class.is_subclass_of(db, target_class)
}
Expand Down Expand Up @@ -2190,6 +2230,11 @@ mod tests {
Ty::BuiltinInstance("FloatingPointError"),
Ty::BuiltinInstance("Exception")
)]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::BuiltinInstance("int"))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
fn is_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand All @@ -2210,6 +2255,11 @@ mod tests {
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(42)]), Ty::Tuple(vec![Ty::BuiltinInstance("str")]))]
#[test_case(Ty::Tuple(vec![Ty::Todo]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::Todo]))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})]
#[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})]
fn is_not_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand Down Expand Up @@ -2241,6 +2291,34 @@ mod tests {
assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db)));
}

#[test]
fn is_subtype_of_intersection_of_class_instances() {
let mut db = setup_db();
db.write_dedented(
"/src/module.py",
"
class A: ...
a = A()
class B: ...
b = B()
",
)
.unwrap();
let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap();

let a_ty = super::global_symbol(&db, module, "a").expect_type();
let b_ty = super::global_symbol(&db, module, "b").expect_type();
let intersection = IntersectionBuilder::new(&db)
.add_positive(a_ty)
.add_positive(b_ty)
.build();

assert_eq!(intersection.display(&db).to_string(), "A & B");
assert!(!a_ty.is_subtype_of(&db, b_ty));
assert!(intersection.is_subtype_of(&db, b_ty));
assert!(intersection.is_subtype_of(&db, a_ty));
}

#[test_case(
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]),
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)])
Expand Down
Loading

0 comments on commit 6c56a7a

Please sign in to comment.