diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index b9a8c11bac586..6606cb3640641 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -3429,18 +3429,18 @@ impl KnownFunction { } /// Return the [`ParameterExpectations`] for this function. - const fn takes_type_expression_arguments(self) -> ParameterExpectations { + const fn parameter_expectations(self) -> ParameterExpectations { match self { Self::IsFullyStatic | Self::IsSingleton | Self::IsSingleValued => { - ParameterExpectations::SINGLE_TYPE_EXPRESSION + ParameterExpectations::SingleTypeExpression } Self::IsEquivalentTo | Self::IsSubtypeOf | Self::IsAssignableTo - | Self::IsDisjointFrom => ParameterExpectations::TWO_TYPE_EXPRESSIONS, + | Self::IsDisjointFrom => ParameterExpectations::TwoTypeExpressions, - Self::AssertType => ParameterExpectations::VALUE_EXPRESSION_AND_TYPE_EXPRESSION, + Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression, Self::ConstraintFunction(_) | Self::Len @@ -3461,28 +3461,41 @@ enum ParameterExpectations { /// All parameters in the function expect value expressions #[default] AllValueExpressions, - /// The function is special-cased by the type system: one or more parameters expect type expressions - /// rather than value expressions. - SpecialCased(&'static [ParameterExpectation]), + /// The first parameter in the function expects a type expression + SingleTypeExpression, + /// The first two parameters in the function expect type expressions + TwoTypeExpressions, + /// The first parameter in the function expects a value expression, + /// and the second expects a type expression + ValueExpressionAndTypeExpression, } impl ParameterExpectations { - const SINGLE_TYPE_EXPRESSION: Self = Self::SpecialCased(&[ParameterExpectation::Type]); - - const TWO_TYPE_EXPRESSIONS: Self = - Self::SpecialCased(&[ParameterExpectation::Type, ParameterExpectation::Type]); - - const VALUE_EXPRESSION_AND_TYPE_EXPRESSION: Self = - Self::SpecialCased(&[ParameterExpectation::Value, ParameterExpectation::Type]); - /// Query whether the parameter at `parameter_index` expects a value expression or a type expression fn expectation_at_index(self, parameter_index: usize) -> ParameterExpectation { match self { - Self::AllValueExpressions => ParameterExpectation::Value, - Self::SpecialCased(expectations) => expectations - .get(parameter_index) - .copied() - .unwrap_or_default(), + Self::AllValueExpressions => ParameterExpectation::ValueExpression, + Self::SingleTypeExpression => { + if parameter_index == 0 { + ParameterExpectation::TypeExpression + } else { + ParameterExpectation::ValueExpression + } + } + Self::TwoTypeExpressions => { + if parameter_index < 2 { + ParameterExpectation::TypeExpression + } else { + ParameterExpectation::ValueExpression + } + } + Self::ValueExpressionAndTypeExpression => { + if parameter_index == 1 { + ParameterExpectation::TypeExpression + } else { + ParameterExpectation::ValueExpression + } + } } } } @@ -3494,9 +3507,9 @@ impl ParameterExpectations { enum ParameterExpectation { /// The parameter expects a value expression #[default] - Value, + ValueExpression, /// The parameter expects a type expression - Type, + TypeExpression, } #[salsa::interned] diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 466954fbf9a24..7932777c66e59 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2609,8 +2609,8 @@ impl<'db> TypeInferenceBuilder<'db> { .enumerate() .map(|(index, arg_or_keyword)| { let infer_argument_type = match parameter_expectations.expectation_at_index(index) { - ParameterExpectation::Type => Self::infer_type_expression, - ParameterExpectation::Value => Self::infer_expression, + ParameterExpectation::TypeExpression => Self::infer_type_expression, + ParameterExpectation::ValueExpression => Self::infer_expression, }; match arg_or_keyword { @@ -3155,13 +3155,13 @@ impl<'db> TypeInferenceBuilder<'db> { let function_type = self.infer_expression(func); - let infer_arguments_as_type_expressions = function_type + let parameter_expectations = function_type .into_function_literal() .and_then(|f| f.known(self.db())) - .map(KnownFunction::takes_type_expression_arguments) + .map(KnownFunction::parameter_expectations) .unwrap_or_default(); - let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions); + let call_arguments = self.infer_arguments(arguments, parameter_expectations); function_type .call(self.db(), &call_arguments) .unwrap_with_diagnostic(&self.context, call_expression.into())