diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_boolean.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_boolean.md new file mode 100644 index 0000000000000..854e09ff0a50b --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/conditionals_boolean.md @@ -0,0 +1,282 @@ +# Narrowing for conditionals with boolean expressions + +## Narrowing in `and` conditional + +```py +class A: ... +class B: ... + +def instance() -> A | B: + return A() + +x = instance() + +if isinstance(x, A) and isinstance(x, B): + reveal_type(x) # revealed: A & B +else: + reveal_type(x) # revealed: B & ~A | A & ~B +``` + +## Arms might not add narrowing constraints + +```py +class A: ... +class B: ... + +def bool_instance() -> bool: + return True + +def instance() -> A | B: + return A() + +x = instance() + +if isinstance(x, A) and bool_instance(): + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: A | B + +if bool_instance() and isinstance(x, A): + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: A | B + +reveal_type(x) # revealed: A | B +``` + +## Statically known arms + +```py +class A: ... +class B: ... + +def instance() -> A | B: + return A() + +x = instance() + +if isinstance(x, A) and True: + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: B & ~A + +if True and isinstance(x, A): + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: B & ~A + +if False and isinstance(x, A): + # TODO: should emit an `unreachable code` diagnostic + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: A | B + +if False or isinstance(x, A): + reveal_type(x) # revealed: A +else: + reveal_type(x) # revealed: B & ~A + +if True or isinstance(x, A): + reveal_type(x) # revealed: A | B +else: + # TODO: should emit an `unreachable code` diagnostic + reveal_type(x) # revealed: B & ~A + +reveal_type(x) # revealed: A | B +``` + +## The type of multiple symbols can be narrowed down + +```py +class A: ... +class B: ... + +def instance() -> A | B: + return A() + +x = instance() +y = instance() + +if isinstance(x, A) and isinstance(y, B): + reveal_type(x) # revealed: A + reveal_type(y) # revealed: B +else: + # No narrowing: Only-one or both checks might have failed + reveal_type(x) # revealed: A | B + reveal_type(y) # revealed: A | B + +reveal_type(x) # revealed: A | B +reveal_type(y) # revealed: A | B +``` + +## Narrowing in `or` conditional + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() + +if isinstance(x, A) or isinstance(x, B): + reveal_type(x) # revealed: A | B +else: + reveal_type(x) # revealed: C & ~A & ~B +``` + +## In `or`, all arms should add constraint in order to narrow + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +def bool_instance() -> bool: + return True + +x = instance() + +if isinstance(x, A) or isinstance(x, B) or bool_instance(): + reveal_type(x) # revealed: A | B | C +else: + reveal_type(x) # revealed: C & ~A & ~B +``` + +## in `or`, all arms should narrow the same set of symbols + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() +y = instance() + +if isinstance(x, A) or isinstance(y, A): + # The predicate might be satisfied by the right side, so the type of `x` can’t be narrowed down here. + reveal_type(x) # revealed: A | B | C + # The same for `y` + reveal_type(y) # revealed: A | B | C +else: + reveal_type(x) # revealed: B & ~A | C & ~A + reveal_type(y) # revealed: B & ~A | C & ~A + +if (isinstance(x, A) and isinstance(y, A)) or (isinstance(x, B) and isinstance(y, B)): + # Here, types of `x` and `y` can be narrowd since all `or` arms constraint them. + reveal_type(x) # revealed: A | B + reveal_type(y) # revealed: A | B +else: + reveal_type(x) # revealed: A | B | C + reveal_type(y) # revealed: A | B | C +``` + +## mixing `and` and `not` + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() + +if isinstance(x, B) and not isinstance(x, C): + reveal_type(x) # revealed: B & ~C +else: + # ~(B & ~C) -> ~B | C -> (A & ~B) | (C & ~B) | C -> (A & ~B) | C + reveal_type(x) # revealed: A & ~B | C +``` + +## mixing `or` and `not` + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() + +if isinstance(x, B) or not isinstance(x, C): + reveal_type(x) # revealed: B | A & ~C +else: + reveal_type(x) # revealed: C & ~B +``` + +## `or` with nested `and` + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() + +if isinstance(x, A) or (isinstance(x, B) and not isinstance(x, C)): + reveal_type(x) # revealed: A | B & ~C +else: + # ~(A | (B & ~C)) -> ~A & ~(B & ~C) -> ~A & (~B | C) -> (~A & C) | (~A ~ B) + reveal_type(x) # revealed: C & ~A +``` + +## `and` with nested `or` + +```py +class A: ... +class B: ... +class C: ... + +def instance() -> A | B | C: + return A() + +x = instance() + +if isinstance(x, A) and (isinstance(x, B) or not isinstance(x, C)): + # A & (B | ~C) -> (A & B) | (A & ~C) + reveal_type(x) # revealed: A & B | A & ~C +else: + # ~((A & B) | (A & ~C)) -> + # ~(A & B) & ~(A & ~C) -> + # (~A | ~B) & (~A | C) -> + # [(~A | ~B) & ~A] | [(~A | ~B) & C] -> + # ~A | (~A & C) | (~B & C) -> + # ~A | (C & ~B) -> + # ~A | (C & ~B) The positive side of ~A is A | B | C -> + reveal_type(x) # revealed: B & ~A | C & ~A | C & ~B +``` + +## Boolean expression internal narrowing + +```py +def optional_string() -> str | None: + return None + +x = optional_string() +y = optional_string() + +if x is None and y is not x: + reveal_type(y) # revealed: str + +# Neither of the conditions alone is sufficient for narrowing y's type: +if x is None: + reveal_type(y) # revealed: str | None + +if y is not x: + reveal_type(y) # revealed: str | None +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 1703518d0ab0f..8fd3617ee0154 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -528,6 +528,46 @@ impl<'db> Type<'db> { .elements(db) .iter() .any(|&elem_ty| ty.is_subtype_of(db, elem_ty)), + (Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => { + // Check that all target positive values are covered in self positive values + target_intersection + .positive(db) + .iter() + .all(|&target_pos_elem| { + self_intersection + .positive(db) + .iter() + .any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem)) + }) + // Check that all target negative values are excluded in self, either by being + // subtypes of a self negative value or being disjoint from a self positive value. + && target_intersection + .negative(db) + .iter() + .all(|&target_neg_elem| { + // Is target negative value is subtype of a self negative value + self_intersection.negative(db).iter().any(|&self_neg_elem| { + target_neg_elem.is_subtype_of(db, self_neg_elem) + // Is target negative value is disjoint from a self positive value? + }) || self_intersection.positive(db).iter().any(|&self_pos_elem| { + target_neg_elem.is_disjoint_from(db, self_pos_elem) + }) + }) + } + (Type::Intersection(intersection), ty) => intersection + .positive(db) + .iter() + .any(|&elem_ty| elem_ty.is_subtype_of(db, ty)), + (ty, Type::Intersection(intersection)) => { + intersection + .positive(db) + .iter() + .all(|&pos_ty| ty.is_subtype_of(db, pos_ty)) + && intersection + .negative(db) + .iter() + .all(|&neg_ty| neg_ty.is_disjoint_from(db, ty)) + } (Type::Instance(self_class), Type::Instance(target_class)) => { self_class.is_subclass_of(db, target_class) } @@ -2190,6 +2230,11 @@ mod tests { Ty::BuiltinInstance("FloatingPointError"), Ty::BuiltinInstance("Exception") )] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::BuiltinInstance("int"))] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})] + #[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})] + #[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})] fn is_subtype_of(from: Ty, to: Ty) { let db = setup_db(); assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); @@ -2210,6 +2255,11 @@ mod tests { #[test_case(Ty::Tuple(vec![Ty::IntLiteral(42)]), Ty::Tuple(vec![Ty::BuiltinInstance("str")]))] #[test_case(Ty::Tuple(vec![Ty::Todo]), Ty::Tuple(vec![Ty::IntLiteral(2)]))] #[test_case(Ty::Tuple(vec![Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::Todo]))] + #[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(3)]})] + #[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})] + #[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})] + #[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})] + #[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})] fn is_not_subtype_of(from: Ty, to: Ty) { let db = setup_db(); assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db))); @@ -2241,6 +2291,34 @@ mod tests { assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db))); } + #[test] + fn is_subtype_of_intersection_of_class_instances() { + let mut db = setup_db(); + db.write_dedented( + "/src/module.py", + " + class A: ... + a = A() + class B: ... + b = B() + ", + ) + .unwrap(); + let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap(); + + let a_ty = super::global_symbol(&db, module, "a").expect_type(); + let b_ty = super::global_symbol(&db, module, "b").expect_type(); + let intersection = IntersectionBuilder::new(&db) + .add_positive(a_ty) + .add_positive(b_ty) + .build(); + + assert_eq!(intersection.display(&db).to_string(), "A & B"); + assert!(!a_ty.is_subtype_of(&db, b_ty)); + assert!(intersection.is_subtype_of(&db, b_ty)); + assert!(intersection.is_subtype_of(&db, a_ty)); + } + #[test_case( Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]), Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]) diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 2fe9504e6fe6b..5853c44e1d7f0 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -25,12 +25,12 @@ //! * No type in an intersection can be a supertype of any other type in the intersection (just //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. + +use super::KnownClass; use crate::types::{IntersectionType, Type, UnionType}; use crate::{Db, FxOrderSet}; use smallvec::SmallVec; -use super::KnownClass; - pub(crate) struct UnionBuilder<'db> { elements: Vec>, db: &'db dyn Db, @@ -80,7 +80,6 @@ impl<'db> UnionBuilder<'db> { to_remove.push(index); } } - match to_remove[..] { [] => self.elements.push(to_add), [index] => self.elements[index] = to_add, @@ -103,7 +102,6 @@ impl<'db> UnionBuilder<'db> { } } } - self } @@ -386,8 +384,9 @@ mod tests { use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; use crate::stdlib::typing_symbol; - use crate::types::{KnownClass, StringLiteralType, UnionBuilder}; + use crate::types::{global_symbol, KnownClass, StringLiteralType, UnionBuilder}; use crate::ProgramSettings; + use ruff_db::files::system_path_to_file; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; use test_case::test_case; @@ -993,4 +992,66 @@ mod tests { .build(); assert_eq!(result, ty); } + + #[test] + fn build_intersection_of_two_unions_simplify() { + let mut db = setup_db(); + db.write_dedented( + "/src/module.py", + " + class A: ... + class B: ... + a = A() + b = B() + ", + ) + .unwrap(); + + let file = system_path_to_file(&db, "src/module.py").expect("file to exist"); + + let a = global_symbol(&db, file, "a").expect_type(); + let b = global_symbol(&db, file, "b").expect_type(); + let union = UnionBuilder::new(&db).add(a).add(b).build(); + assert_eq!(union.display(&db).to_string(), "A | B"); + let reversed_union = UnionBuilder::new(&db).add(b).add(a).build(); + assert_eq!(reversed_union.display(&db).to_string(), "B | A"); + let intersection = IntersectionBuilder::new(&db) + .add_positive(union) + .add_positive(reversed_union) + .build(); + assert_eq!(intersection.display(&db).to_string(), "B | A"); + } + + #[test] + fn build_union_of_two_intersections_simplify() { + let mut db = setup_db(); + db.write_dedented( + "/src/module.py", + " + class A: ... + class B: ... + a = A() + b = B() + ", + ) + .unwrap(); + + let file = system_path_to_file(&db, "src/module.py").expect("file to exist"); + + let a = global_symbol(&db, file, "a").expect_type(); + let b = global_symbol(&db, file, "b").expect_type(); + let intersection = IntersectionBuilder::new(&db) + .add_positive(a) + .add_positive(b) + .build(); + let reversed_intersection = IntersectionBuilder::new(&db) + .add_positive(b) + .add_positive(a) + .build(); + let union = UnionBuilder::new(&db) + .add(intersection) + .add(reversed_intersection) + .build(); + assert_eq!(union.display(&db).to_string(), "A & B"); + } } diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index f3d276f17236a..e088ab6d2848b 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -5,12 +5,15 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; use crate::types::{ - infer_expression_types, IntersectionBuilder, KnownFunction, Type, UnionBuilder, + infer_expression_types, IntersectionBuilder, KnownClass, KnownFunction, Truthiness, Type, + UnionBuilder, }; use crate::Db; use itertools::Itertools; use ruff_python_ast as ast; +use ruff_python_ast::{BoolOp, ExprBoolOp}; use rustc_hash::FxHashMap; +use std::collections::hash_map::Entry; use std::sync::Arc; /// Return the type constraint that `test` (if true) would place on `definition`, if any. @@ -34,21 +37,20 @@ pub(crate) fn narrowing_constraint<'db>( constraint: Constraint<'db>, definition: Definition<'db>, ) -> Option> { - match constraint.node { + let constraints = match constraint.node { ConstraintNode::Expression(expression) => { if constraint.is_positive { all_narrowing_constraints_for_expression(db, expression) - .get(&definition.symbol(db)) - .copied() } else { all_negative_narrowing_constraints_for_expression(db, expression) - .get(&definition.symbol(db)) - .copied() } } - ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern) - .get(&definition.symbol(db)) - .copied(), + ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern), + }; + if let Some(constraints) = constraints { + constraints.get(&definition.symbol(db)).copied() + } else { + None } } @@ -56,7 +58,7 @@ pub(crate) fn narrowing_constraint<'db>( fn all_narrowing_constraints_for_pattern<'db>( db: &'db dyn Db, pattern: PatternConstraint<'db>, -) -> NarrowingConstraints<'db> { +) -> Option> { NarrowingConstraintsBuilder::new(db, ConstraintNode::Pattern(pattern), true).finish() } @@ -64,7 +66,7 @@ fn all_narrowing_constraints_for_pattern<'db>( fn all_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> NarrowingConstraints<'db> { +) -> Option> { NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), true).finish() } @@ -72,7 +74,7 @@ fn all_narrowing_constraints_for_expression<'db>( fn all_negative_narrowing_constraints_for_expression<'db>( db: &'db dyn Db, expression: Expression<'db>, -) -> NarrowingConstraints<'db> { +) -> Option> { NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish() } @@ -100,11 +102,52 @@ fn generate_isinstance_constraint<'db>( type NarrowingConstraints<'db> = FxHashMap>; +fn merge_constraints_and<'db>( + into: &mut NarrowingConstraints<'db>, + from: NarrowingConstraints<'db>, + db: &'db dyn Db, +) { + for (key, value) in from { + match into.entry(key) { + Entry::Occupied(mut entry) => { + *entry.get_mut() = IntersectionBuilder::new(db) + .add_positive(*entry.get()) + .add_positive(value) + .build(); + } + Entry::Vacant(entry) => { + entry.insert(value); + } + } + } +} + +fn merge_constraints_or<'db>( + into: &mut NarrowingConstraints<'db>, + from: &NarrowingConstraints<'db>, + db: &'db dyn Db, +) { + for (key, value) in from { + match into.entry(*key) { + Entry::Occupied(mut entry) => { + *entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build(); + } + Entry::Vacant(entry) => { + entry.insert(KnownClass::Object.to_instance(db)); + } + } + } + for (key, value) in into.iter_mut() { + if !from.contains_key(key) { + *value = KnownClass::Object.to_instance(db); + } + } +} + struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, constraint: ConstraintNode<'db>, is_positive: bool, - constraints: NarrowingConstraints<'db>, } impl<'db> NarrowingConstraintsBuilder<'db> { @@ -113,24 +156,31 @@ impl<'db> NarrowingConstraintsBuilder<'db> { db, constraint, is_positive, - constraints: NarrowingConstraints::default(), } } - fn finish(mut self) -> NarrowingConstraints<'db> { - match self.constraint { + fn finish(mut self) -> Option> { + let constraints: Option> = match self.constraint { ConstraintNode::Expression(expression) => { - self.evaluate_expression_constraint(expression, self.is_positive); + self.evaluate_expression_constraint(expression, self.is_positive) } ConstraintNode::Pattern(pattern) => self.evaluate_pattern_constraint(pattern), + }; + if let Some(mut constraints) = constraints { + constraints.shrink_to_fit(); + Some(constraints) + } else { + None } - self.constraints.shrink_to_fit(); - self.constraints } - fn evaluate_expression_constraint(&mut self, expression: Expression<'db>, is_positive: bool) { + fn evaluate_expression_constraint( + &mut self, + expression: Expression<'db>, + is_positive: bool, + ) -> Option> { let expression_node = expression.node_ref(self.db).node(); - self.evaluate_expression_node_constraint(expression_node, expression, is_positive); + self.evaluate_expression_node_constraint(expression_node, expression, is_positive) } fn evaluate_expression_node_constraint( @@ -138,52 +188,51 @@ impl<'db> NarrowingConstraintsBuilder<'db> { expression_node: &ruff_python_ast::Expr, expression: Expression<'db>, is_positive: bool, - ) { + ) -> Option> { match expression_node { ast::Expr::Compare(expr_compare) => { - self.add_expr_compare(expr_compare, expression, is_positive); + self.evaluate_expr_compare(expr_compare, expression, is_positive) } ast::Expr::Call(expr_call) => { - self.add_expr_call(expr_call, expression, is_positive); + self.evaluate_expr_call(expr_call, expression, is_positive) } - ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => { - self.evaluate_expression_node_constraint( - &unary_op.operand, - expression, - !is_positive, - ); - } - _ => {} // TODO other test expression kinds + ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => self + .evaluate_expression_node_constraint(&unary_op.operand, expression, !is_positive), + ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive), + _ => None, // TODO other test expression kinds } } - fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) { + fn evaluate_pattern_constraint( + &mut self, + pattern: PatternConstraint<'db>, + ) -> Option> { let subject = pattern.subject(self.db); match pattern.pattern(self.db).node() { ast::Pattern::MatchValue(_) => { - // TODO + None // TODO } ast::Pattern::MatchSingleton(singleton_pattern) => { - self.add_match_pattern_singleton(subject, singleton_pattern); + self.evaluate_match_pattern_singleton(subject, singleton_pattern) } ast::Pattern::MatchSequence(_) => { - // TODO + None // TODO } ast::Pattern::MatchMapping(_) => { - // TODO + None // TODO } ast::Pattern::MatchClass(_) => { - // TODO + None // TODO } ast::Pattern::MatchStar(_) => { - // TODO + None // TODO } ast::Pattern::MatchAs(_) => { - // TODO + None // TODO } ast::Pattern::MatchOr(_) => { - // TODO + None // TODO } } } @@ -199,12 +248,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } - fn add_expr_compare( + fn evaluate_expr_compare( &mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>, is_positive: bool, - ) { + ) -> Option> { let ast::ExprCompare { range: _, left, @@ -214,14 +263,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> { if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) { // If none of the comparators are name expressions, // we have no symbol to narrow down the type of. - return; + return None; } if !is_positive && comparators.len() > 1 { // We can't negate a constraint made by a multi-comparator expression, since we can't // know which comparison part is the one being negated. // For example, the negation of `x is 1 is y is 2`, would be `(x is not 1) or (y is not 1) or (y is not 2)` // and that requires cross-symbol constraints, which we don't support yet. - return; + return None; } let scope = self.scope(); let inference = infer_expression_types(self.db, expression); @@ -229,6 +278,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let comparator_tuples = std::iter::once(&**left) .chain(comparators) .tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>(); + let mut constraints = NarrowingConstraints::default(); for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { if let ast::Expr::Name(ast::ExprName { range: _, @@ -246,20 +296,20 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let ty = IntersectionBuilder::new(self.db) .add_negative(rhs_ty) .build(); - self.constraints.insert(symbol, ty); + constraints.insert(symbol, ty); } else { // Non-singletons cannot be safely narrowed using `is not` } } ast::CmpOp::Is => { - self.constraints.insert(symbol, rhs_ty); + constraints.insert(symbol, rhs_ty); } ast::CmpOp::NotEq => { if rhs_ty.is_single_valued(self.db) { let ty = IntersectionBuilder::new(self.db) .add_negative(rhs_ty) .build(); - self.constraints.insert(symbol, ty); + constraints.insert(symbol, ty); } } _ => { @@ -268,14 +318,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } } + Some(constraints) } - fn add_expr_call( + fn evaluate_expr_call( &mut self, expr_call: &ast::ExprCall, expression: Expression<'db>, is_positive: bool, - ) { + ) -> Option> { let scope = self.scope(); let inference = infer_expression_types(self.db, expression); @@ -299,18 +350,21 @@ impl<'db> NarrowingConstraintsBuilder<'db> { if !is_positive { constraint = constraint.negate(self.db); } - self.constraints.insert(symbol, constraint); + let mut constraints = NarrowingConstraints::default(); + constraints.insert(symbol, constraint); + return Some(constraints); } } } } + None } - fn add_match_pattern_singleton( + fn evaluate_match_pattern_singleton( &mut self, subject: &ast::Expr, pattern: &ast::PatternMatchSingleton, - ) { + ) -> Option> { if let Some(ast::ExprName { id, .. }) = subject.as_name_expr() { // SAFETY: we should always have a symbol for every Name node. let symbol = self.symbols().symbol_id_by_name(id).unwrap(); @@ -320,7 +374,64 @@ impl<'db> NarrowingConstraintsBuilder<'db> { ast::Singleton::True => Type::BooleanLiteral(true), ast::Singleton::False => Type::BooleanLiteral(false), }; - self.constraints.insert(symbol, ty); + let mut constraints = NarrowingConstraints::default(); + constraints.insert(symbol, ty); + Some(constraints) + } else { + None + } + } + + fn evaluate_bool_op( + &mut self, + expr_bool_op: &ExprBoolOp, + expression: Expression<'db>, + is_positive: bool, + ) -> Option> { + let inference = infer_expression_types(self.db, expression); + let scope = self.scope(); + let mut sub_constraints = expr_bool_op + .values + .iter() + // filter our arms with statically known truthiness + .filter(|expr| { + inference + .expression_ty(expr.scoped_ast_id(self.db, scope)) + .bool(self.db) + != match expr_bool_op.op { + BoolOp::And => Truthiness::AlwaysTrue, + BoolOp::Or => Truthiness::AlwaysFalse, + } + }) + .map(|sub_expr| { + self.evaluate_expression_node_constraint(sub_expr, expression, is_positive) + }) + .collect::>(); + match (expr_bool_op.op, is_positive) { + (BoolOp::And, true) | (BoolOp::Or, false) => { + let mut aggregation: Option = None; + for sub_constraint in sub_constraints.into_iter().flatten() { + if let Some(ref mut some_aggregation) = aggregation { + merge_constraints_and(some_aggregation, sub_constraint, self.db); + } else { + aggregation = Some(sub_constraint); + } + } + aggregation + } + (BoolOp::Or, true) | (BoolOp::And, false) => { + let (first, rest) = sub_constraints.split_first_mut()?; + if let Some(ref mut first) = first { + for rest_constraint in rest { + if let Some(rest_constraint) = rest_constraint { + merge_constraints_or(first, rest_constraint, self.db); + } else { + return None; + } + } + } + first.clone() + } } } }