From a7b8cc08f09dfad0bd30ded79852ae45cc24a6c1 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 10 Sep 2024 18:41:45 -0400 Subject: [PATCH] [red-knot] Fix `.to_instance()` for union types (#13319) --- crates/red_knot_python_semantic/src/types.rs | 20 ++++++++++++++-- .../src/types/infer.rs | 24 ++++++++++--------- 2 files changed, 31 insertions(+), 13 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e61a0f4843fee..093dd205ecb0d 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -445,12 +445,28 @@ impl<'db> Type<'db> { } #[must_use] - pub fn to_instance(&self) -> Type<'db> { + pub fn to_instance(&self, db: &'db dyn Db) -> Type<'db> { match self { Type::Any => Type::Any, Type::Unknown => Type::Unknown, + Type::Unbound => Type::Unknown, + Type::Never => Type::Never, Type::Class(class) => Type::Instance(*class), - _ => Type::Unknown, // TODO type errors + Type::Union(union) => union.map(db, |element| element.to_instance(db)), + // TODO: we can probably do better here: --Alex + Type::Intersection(_) => Type::Unknown, + // TODO: calling `.to_instance()` on any of these should result in a diagnostic, + // since they already indicate that the object is an instance of some kind: + Type::BooleanLiteral(_) + | Type::BytesLiteral(_) + | Type::Function(_) + | Type::Instance(_) + | Type::Module(_) + | Type::IntLiteral(_) + | Type::StringLiteral(_) + | Type::Tuple(_) + | Type::LiteralString + | Type::None => Type::Unknown, } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 3dd1a4d1ff757..30afbafc3c06b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1457,9 +1457,11 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Number::Int(n) => n .as_i64() .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), - ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(), - ast::Number::Complex { .. } => builtins_symbol_ty(self.db, "complex").to_instance(), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)), + ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(self.db), + ast::Number::Complex { .. } => { + builtins_symbol_ty(self.db, "complex").to_instance(self.db) + } } } @@ -1573,7 +1575,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "list").to_instance() + builtins_symbol_ty(self.db, "list").to_instance(self.db) } fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { @@ -1584,7 +1586,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "set").to_instance() + builtins_symbol_ty(self.db, "set").to_instance(self.db) } fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { @@ -1596,7 +1598,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty(self.db, "dict").to_instance() + builtins_symbol_ty(self.db, "dict").to_instance(self.db) } /// Infer the type of the `iter` expression of the first comprehension. @@ -2067,22 +2069,22 @@ impl<'db> TypeInferenceBuilder<'db> { (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n .checked_add(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n .checked_sub(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n .checked_mul(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n .checked_div(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance(self.db)), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n .checked_rem(m) @@ -2311,7 +2313,7 @@ impl<'db> TypeInferenceBuilder<'db> { name.ctx ); - self.infer_name_expression(name).to_instance() + self.infer_name_expression(name).to_instance(self.db) } ast::Expr::NoneLiteral(_literal) => Type::None,