[red-knot] Narrowing For Truthiness Checks (if x or if not x) (#14687)

## Summary

Fixes #14550.

Add `AlwaysTruthy` and `AlwaysFalsy` types, representing the set of objects whose `__bool__` method can only ever return `True` or `False`, respectively, and narrow `if x` and `if not x` accordingly.


## Test Plan

- New Markdown test for truthiness narrowing `narrow/truthiness.md`
- Unit tests in `types.rs` and `builders.rs` (`cargo test --package red_knot_python_semantic --lib -- types`)
cake-monotone authored Dec 17, 2024
1 parent c3b6139 commit f463fa7
Showing 7 changed files with 343 additions and 25 deletions.
221 changes: 221 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md
@@ -0,0 +1,221 @@
# Narrowing For Truthiness Checks (`if x` or `if not x`)

## Value Literals

```py
def foo() -> Literal[0, -1, True, False, "", "foo", b"", b"bar", None] | tuple[()]:
    return 0

x = foo()

if x:
    reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"] | Literal[b"bar"]
else:
    reveal_type(x) # revealed: Literal[0] | Literal[False] | Literal[""] | Literal[b""] | None | tuple[()]

if not x:
    reveal_type(x) # revealed: Literal[0] | Literal[False] | Literal[""] | Literal[b""] | None | tuple[()]
else:
    reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"] | Literal[b"bar"]

if x and not x:
    reveal_type(x) # revealed: Never
else:
    reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["", "foo"] | Literal[b"", b"bar"] | None | tuple[()]

if not (x and not x):
    reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["", "foo"] | Literal[b"", b"bar"] | None | tuple[()]
else:
    reveal_type(x) # revealed: Never

if x or not x:
    reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["foo", ""] | Literal[b"bar", b""] | None | tuple[()]
else:
    reveal_type(x) # revealed: Never

if not (x or not x):
    reveal_type(x) # revealed: Never
else:
    reveal_type(x) # revealed: Literal[-1, 0] | bool | Literal["foo", ""] | Literal[b"bar", b""] | None | tuple[()]

if (isinstance(x, int) or isinstance(x, str)) and x:
    reveal_type(x) # revealed: Literal[-1] | Literal[True] | Literal["foo"]
else:
    reveal_type(x) # revealed: Literal[b"", b"bar"] | None | tuple[()] | Literal[0] | Literal[False] | Literal[""]
```
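
As a rough runtime counterpart to the revealed types above, the same literal values can be partitioned with `bool()`. This is only an illustration of the expected split, not how red-knot computes it; the helper name `partition_by_truthiness` is hypothetical.

```python
def partition_by_truthiness(values):
    """Split values into those that pass `if x` and those that pass `if not x`."""
    truthy = [v for v in values if v]
    falsy = [v for v in values if not v]
    return truthy, falsy


# The same literals as in the return annotation of `foo` above.
candidates = [0, -1, True, False, "", "foo", b"", b"bar", None, ()]
truthy, falsy = partition_by_truthiness(candidates)

print(truthy)
print(falsy)
```

The truthy list corresponds to the `if x` branch and the falsy list to the `else` branch of the first test case.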

## Function Literals

Function objects are always truthy.

```py
def flag() -> bool:
    return True

def foo(hello: int) -> bytes:
    return b""

def bar(world: str, *args, **kwargs) -> float:
    return 0.0

x = foo if flag() else bar

if x:
    reveal_type(x) # revealed: Literal[foo, bar]
else:
    reveal_type(x) # revealed: Never
```
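
A quick runtime check of why this holds in CPython, as a sketch: the function type defines neither `__bool__` nor `__len__`, so truth testing falls back to the default of "truthy". The variable names below are illustrative only.

```python
def foo(hello: int) -> bytes:
    return b""


# Function objects define neither `__bool__` nor `__len__`,
# so Python's default truth value (True) applies.
function_type = type(foo)
defines_bool = hasattr(function_type, "__bool__")
defines_len = hasattr(function_type, "__len__")
is_truthy = bool(foo)
```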

## Mutable Truthiness

### Truthiness of Instances

The boolean value of an instance is not always consistent. For example, `__bool__` can be customized
to return random values, or, in the case of a `list`, the result depends on the number of elements
in the list. Therefore, these types cannot be narrowed to a single branch by `if x` or `if not x`;
the checks only exclude `AlwaysFalsy` or `AlwaysTruthy`, respectively.

```py
class A: ...
class B: ...

def f(x: A | B):
    if x:
        reveal_type(x) # revealed: A & ~AlwaysFalsy | B & ~AlwaysFalsy
    else:
        reveal_type(x) # revealed: A & ~AlwaysTruthy | B & ~AlwaysTruthy

    if x and not x:
        reveal_type(x) # revealed: A & ~AlwaysFalsy & ~AlwaysTruthy | B & ~AlwaysFalsy & ~AlwaysTruthy
    else:
        reveal_type(x) # revealed: A & ~AlwaysTruthy | B & ~AlwaysTruthy | A & ~AlwaysFalsy | B & ~AlwaysFalsy

    if x or not x:
        reveal_type(x) # revealed: A & ~AlwaysFalsy | B & ~AlwaysFalsy | A & ~AlwaysTruthy | B & ~AlwaysTruthy
    else:
        reveal_type(x) # revealed: A & ~AlwaysTruthy & ~AlwaysFalsy | B & ~AlwaysTruthy & ~AlwaysFalsy
```
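
To see why the `if x and not x` branch above is not `Never`, consider a sketch of an instance whose truthiness changes between checks. The `Toggle` class is a hypothetical example, not part of the test suite.

```python
class Toggle:
    """Hypothetical class whose truth value flips on every check."""

    def __init__(self) -> None:
        self._state = False

    def __bool__(self) -> bool:
        self._state = not self._state
        return self._state


t = Toggle()
first, second = bool(t), bool(t)  # two consecutive checks disagree

# `x and not x` evaluates `__bool__` twice, so the branch is reachable.
u = Toggle()
reached = False
if u and not u:
    reached = True

# A list's truthiness depends on its current contents.
items = []
empty_truth = bool(items)
items.append(1)
filled_truth = bool(items)
```

Because such objects exist, the narrowed type keeps the instance and only subtracts `AlwaysFalsy` or `AlwaysTruthy`.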

### Truthiness of Types

Class objects are not guaranteed to be truthy either, because `__bool__` can be customized via a
metaclass. This is rare in practice, but metaclass checks may be added in the future to handle it
more precisely.

```py
def flag() -> bool:
    return True

x = int if flag() else str
reveal_type(x) # revealed: Literal[int, str]

if x:
    reveal_type(x) # revealed: Literal[int] & ~AlwaysFalsy | Literal[str] & ~AlwaysFalsy
else:
    reveal_type(x) # revealed: Literal[int] & ~AlwaysTruthy | Literal[str] & ~AlwaysTruthy
```
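
The metaclass case mentioned above can be reproduced at runtime. The following is a minimal sketch; `FalsyMeta` and `Weird` are hypothetical names chosen for illustration.

```python
class FalsyMeta(type):
    """Hypothetical metaclass that makes its classes falsy."""

    def __bool__(cls) -> bool:
        return False


class Weird(metaclass=FalsyMeta):
    pass


# An ordinary class object is truthy, but `Weird` is not:
# `bool(Weird)` looks up `__bool__` on `type(Weird)`, i.e. `FalsyMeta`.
int_truth = bool(int)
weird_truth = bool(Weird)
```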

## Determined Truthiness

Some custom classes can have a boolean value that is consistently determined as either `True` or
`False`, regardless of the instance's state. This is achieved by defining a `__bool__` method that
always returns a fixed value.

These types can always be fully narrowed in boolean contexts, as shown below:

```py
class T:
    def __bool__(self) -> Literal[True]:
        return True

class F:
    def __bool__(self) -> Literal[False]:
        return False

t = T()

if t:
    reveal_type(t) # revealed: T
else:
    reveal_type(t) # revealed: Never

f = F()

if f:
    reveal_type(f) # revealed: Never
else:
    reveal_type(f) # revealed: F
```
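
For comparison, here is a runnable version of the same two classes with the import they need at runtime. It is a sketch showing that the declared `Literal` return type is visible on the method itself; whether red-knot reads exactly this annotation is an assumption based on the test above, and the `branch` helper is hypothetical.

```python
from typing import Literal, get_type_hints


class T:
    def __bool__(self) -> Literal[True]:
        return True


class F:
    def __bool__(self) -> Literal[False]:
        return False


def branch(value: object) -> str:
    """Report which branch of `if value:` is taken."""
    return "if" if value else "else"


# The fixed return type is part of the method's annotations.
t_hint = get_type_hints(T.__bool__)["return"]
f_hint = get_type_hints(F.__bool__)["return"]
```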

## Narrowing Complex Intersection and Union

```py
class A: ...
class B: ...

def flag() -> bool:
    return True

def instance() -> A | B:
    return A()

def literals() -> Literal[0, 42, "", "hello"]:
    return 42

x = instance()
y = literals()

if isinstance(x, str) and not isinstance(x, B):
    reveal_type(x) # revealed: A & str & ~B
    reveal_type(y) # revealed: Literal[0, 42] | Literal["", "hello"]

    z = x if flag() else y

    reveal_type(z) # revealed: A & str & ~B | Literal[0, 42] | Literal["", "hello"]

    if z:
        reveal_type(z) # revealed: A & str & ~B & ~AlwaysFalsy | Literal[42] | Literal["hello"]
    else:
        reveal_type(z) # revealed: A & str & ~B & ~AlwaysTruthy | Literal[0] | Literal[""]
```

## Narrowing Multiple Variables

```py
def f(x: Literal[0, 1], y: Literal["", "hello"]):
    if x and y and not x and not y:
        reveal_type(x) # revealed: Never
        reveal_type(y) # revealed: Never
    else:
        # ~(x and not x) or ~(y and not y)
        reveal_type(x) # revealed: Literal[0, 1]
        reveal_type(y) # revealed: Literal["", "hello"]

    if (x or not x) and (y and not y):
        reveal_type(x) # revealed: Literal[0, 1]
        reveal_type(y) # revealed: Never
    else:
        # ~(x or not x) or ~(y and not y)
        reveal_type(x) # revealed: Literal[0, 1]
        reveal_type(y) # revealed: Literal["", "hello"]
```

## Control Flow Merging

After merging control flows, when we take the union of all constraints applied in each branch, we
should return to the original state.

```py
class A: ...

x = A()

if x and not x:
    y = x
    reveal_type(y) # revealed: A & ~AlwaysFalsy & ~AlwaysTruthy
else:
    y = x
    reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy

# TODO: It should be A. We should improve UnionBuilder or IntersectionBuilder. (issue #15023)
reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy
```
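
The reason the merged type should simplify to `A` can be checked with a small set model. This is a sketch under the assumption that types behave like sets of objects and that `AlwaysTruthy` and `AlwaysFalsy` do not overlap; the element names are made up.

```python
# Model `A` as a finite set of instances, tagged by how their `__bool__` behaves.
A = frozenset({"a_always_truthy", "a_always_falsy", "a_ambiguous"})
ALWAYS_TRUTHY = frozenset({"a_always_truthy"})
ALWAYS_FALSY = frozenset({"a_always_falsy"})

# if x and not x:  ->  A & ~AlwaysFalsy & ~AlwaysTruthy
if_branch = A - ALWAYS_FALSY - ALWAYS_TRUTHY

# else:            ->  A & ~AlwaysTruthy | A & ~AlwaysFalsy
else_branch = (A - ALWAYS_TRUTHY) | (A - ALWAYS_FALSY)

# Merging both branches recovers every element of A.
merged = if_branch | else_branch
```

In this model `else_branch` alone already equals `A`, which is the simplification the TODO above asks the builders to perform.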
61 changes: 58 additions & 3 deletions crates/red_knot_python_semantic/src/types.rs
@@ -431,6 +431,11 @@ pub enum Type<'db> {
Union(UnionType<'db>),
/// The set of objects in all of the types in the intersection
Intersection(IntersectionType<'db>),
/// Represents objects whose `__bool__` method is deterministic:
/// - `AlwaysTruthy`: `__bool__` always returns `True`
/// - `AlwaysFalsy`: `__bool__` always returns `False`
AlwaysTruthy,
AlwaysFalsy,
/// An integer literal
IntLiteral(i64),
/// A boolean literal, either `True` or `False`.
@@ -717,6 +722,15 @@ impl<'db> Type<'db> {
.all(|&neg_ty| self.is_disjoint_from(db, neg_ty))
}

// Note that the definition of `Type::AlwaysFalsy` depends on the return value of `__bool__`.
// If `__bool__` always returns True or False, it can be treated as a subtype of `AlwaysTruthy` or `AlwaysFalsy`, respectively.
(left, Type::AlwaysFalsy) => matches!(left.bool(db), Truthiness::AlwaysFalse),
(left, Type::AlwaysTruthy) => matches!(left.bool(db), Truthiness::AlwaysTrue),
// Currently, the only supertype of `AlwaysFalsy` and `AlwaysTruthy` is the universal set (object instance).
(Type::AlwaysFalsy | Type::AlwaysTruthy, _) => {
target.is_equivalent_to(db, KnownClass::Object.to_instance(db))
}

// All `StringLiteral` types are a subtype of `LiteralString`.
(Type::StringLiteral(_), Type::LiteralString) => true,

@@ -1105,6 +1119,16 @@ impl<'db> Type<'db> {
false
}

(Type::AlwaysTruthy, ty) | (ty, Type::AlwaysTruthy) => {
// `Truthiness::Ambiguous` may include `AlwaysTrue` as a subset, so it's not guaranteed to be disjoint.
// Thus, they are only disjoint if `ty.bool() == AlwaysFalse`.
matches!(ty.bool(db), Truthiness::AlwaysFalse)
}
(Type::AlwaysFalsy, ty) | (ty, Type::AlwaysFalsy) => {
// Similarly, they are only disjoint if `ty.bool() == AlwaysTrue`.
matches!(ty.bool(db), Truthiness::AlwaysTrue)
}

(Type::KnownInstance(left), right) => {
left.instance_fallback(db).is_disjoint_from(db, right)
}
@@ -1238,7 +1262,9 @@ impl<'db> Type<'db> {
| Type::LiteralString
| Type::BytesLiteral(_)
| Type::SliceLiteral(_)
- | Type::KnownInstance(_) => true,
+ | Type::KnownInstance(_)
+ | Type::AlwaysFalsy
+ | Type::AlwaysTruthy => true,
Type::SubclassOf(SubclassOfType { base }) => matches!(base, ClassBase::Class(_)),
Type::ClassLiteral(_) | Type::Instance(_) => {
// TODO: Ideally, we would iterate over the MRO of the class, check if all
@@ -1340,6 +1366,7 @@ impl<'db> Type<'db> {
//
false
}
Type::AlwaysTruthy | Type::AlwaysFalsy => false,
}
}

@@ -1410,7 +1437,9 @@ impl<'db> Type<'db> {
| Type::Todo(_)
| Type::Union(..)
| Type::Intersection(..)
- | Type::LiteralString => false,
+ | Type::LiteralString
+ | Type::AlwaysTruthy
+ | Type::AlwaysFalsy => false,
}
}

@@ -1578,6 +1607,10 @@ impl<'db> Type<'db> {
// TODO: implement tuple methods
todo_type!().into()
}
Type::AlwaysTruthy | Type::AlwaysFalsy => {
// TODO return `Callable[[], Literal[True/False]]` for `__bool__` access
KnownClass::Object.to_instance(db).member(db, name)
}
&todo @ Type::Todo(_) => todo.into(),
}
}
@@ -1600,6 +1633,8 @@ impl<'db> Type<'db> {
// TODO: see above
Truthiness::Ambiguous
}
Type::AlwaysTruthy => Truthiness::AlwaysTrue,
Type::AlwaysFalsy => Truthiness::AlwaysFalse,
instance_ty @ Type::Instance(InstanceType { class }) => {
if class.is_known(db, KnownClass::NoneType) {
Truthiness::AlwaysFalse
@@ -1912,7 +1947,9 @@ impl<'db> Type<'db> {
| Type::StringLiteral(_)
| Type::SliceLiteral(_)
| Type::Tuple(_)
- | Type::LiteralString => Type::Unknown,
+ | Type::LiteralString
+ | Type::AlwaysTruthy
+ | Type::AlwaysFalsy => Type::Unknown,
}
}

@@ -2074,6 +2111,7 @@ impl<'db> Type<'db> {
ClassBase::try_from_ty(db, todo_type!("Intersection meta-type"))
.expect("Type::Todo should be a valid ClassBase"),
),
Type::AlwaysTruthy | Type::AlwaysFalsy => KnownClass::Type.to_instance(db),
Type::Todo(todo) => Type::subclass_of_base(ClassBase::Todo(*todo)),
}
}
@@ -3558,6 +3596,8 @@ pub(crate) mod tests {
SubclassOfAbcClass(&'static str),
StdlibModule(CoreStdlibModule),
SliceLiteral(i32, i32, i32),
AlwaysTruthy,
AlwaysFalsy,
}

impl Ty {
@@ -3625,6 +3665,8 @@ pub(crate) mod tests {
Some(stop),
Some(step),
)),
Ty::AlwaysTruthy => Type::AlwaysTruthy,
Ty::AlwaysFalsy => Type::AlwaysFalsy,
}
}
}
@@ -3763,6 +3805,12 @@ pub(crate) mod tests {
)]
#[test_case(Ty::SliceLiteral(1, 2, 3), Ty::BuiltinInstance("slice"))]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::Intersection{pos: vec![], neg: vec![Ty::None]})]
#[test_case(Ty::IntLiteral(1), Ty::AlwaysTruthy)]
#[test_case(Ty::IntLiteral(0), Ty::AlwaysFalsy)]
#[test_case(Ty::AlwaysTruthy, Ty::BuiltinInstance("object"))]
#[test_case(Ty::AlwaysFalsy, Ty::BuiltinInstance("object"))]
#[test_case(Ty::Never, Ty::AlwaysTruthy)]
#[test_case(Ty::Never, Ty::AlwaysFalsy)]
fn is_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
@@ -3797,6 +3845,10 @@ pub(crate) mod tests {
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::SubclassOfAny)]
#[test_case(Ty::AbcInstance("ABCMeta"), Ty::SubclassOfBuiltinClass("type"))]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::BuiltinClassLiteral("str"))]
#[test_case(Ty::IntLiteral(1), Ty::AlwaysFalsy)]
#[test_case(Ty::IntLiteral(0), Ty::AlwaysTruthy)]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysTruthy)]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysFalsy)]
fn is_not_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
@@ -3931,6 +3983,7 @@ pub(crate) mod tests {
#[test_case(Ty::Tuple(vec![]), Ty::BuiltinClassLiteral("object"))]
#[test_case(Ty::SubclassOfBuiltinClass("object"), Ty::None)]
#[test_case(Ty::SubclassOfBuiltinClass("str"), Ty::LiteralString)]
#[test_case(Ty::AlwaysFalsy, Ty::AlwaysTruthy)]
fn is_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
@@ -3961,6 +4014,8 @@ pub(crate) mod tests {
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::BuiltinInstance("type"))]
#[test_case(Ty::BuiltinClassLiteral("str"), Ty::SubclassOfAny)]
#[test_case(Ty::AbcClassLiteral("ABC"), Ty::AbcInstance("ABCMeta"))]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysTruthy)]
#[test_case(Ty::BuiltinInstance("str"), Ty::AlwaysFalsy)]
fn is_not_disjoint_from(a: Ty, b: Ty) {
let db = setup_db();
let a = a.into_type(&db);
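
The subtype and disjointness arms added to `types.rs` can be paraphrased in Python. This is only a sketch of the rules as they read in the diff above, not a port of the Rust code: the `Truthiness` enum mirrors the Rust one, while the four helper functions and the `SAMPLES` table are hypothetical.

```python
from enum import Enum


class Truthiness(Enum):
    ALWAYS_TRUE = "always-true"
    ALWAYS_FALSE = "always-false"
    AMBIGUOUS = "ambiguous"


def is_subtype_of_always_truthy(truthiness: Truthiness) -> bool:
    # A type is a subtype of AlwaysTruthy only if its bool() is always True.
    return truthiness is Truthiness.ALWAYS_TRUE


def is_subtype_of_always_falsy(truthiness: Truthiness) -> bool:
    # A type is a subtype of AlwaysFalsy only if its bool() is always False.
    return truthiness is Truthiness.ALWAYS_FALSE


def is_disjoint_from_always_truthy(truthiness: Truthiness) -> bool:
    # An ambiguous type may still contain truthy objects, so it is not disjoint.
    return truthiness is Truthiness.ALWAYS_FALSE


def is_disjoint_from_always_falsy(truthiness: Truthiness) -> bool:
    # Symmetric rule for AlwaysFalsy.
    return truthiness is Truthiness.ALWAYS_TRUE


# Assumed truthiness of a few types that appear in the unit tests above.
SAMPLES = {
    "Literal[1]": Truthiness.ALWAYS_TRUE,
    "Literal[0]": Truthiness.ALWAYS_FALSE,
    "str": Truthiness.AMBIGUOUS,
}
```

With this table, `str` is neither a subtype of nor disjoint from `AlwaysTruthy` and `AlwaysFalsy`, matching the `is_not_subtype_of` and `is_not_disjoint_from` cases in the diff.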