Skip to content

Commit

Permalink
[red-knot] Refactor KnownFunction::takes_expression_arguments() (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored Jan 10, 2025
1 parent 12f86f3 commit c82932e
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 37 deletions.
103 changes: 78 additions & 25 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3428,37 +3428,90 @@ impl KnownFunction {
}
}

/// Returns a `u32` bitmask specifying whether or not
/// arguments given to a particular function
/// should be interpreted as type expressions or value expressions.
///
/// The argument is treated as a type expression
/// when the corresponding bit is `1`.
/// The least-significant (right-most) bit corresponds to
/// the argument at the index 0 and so on.
///
/// For example, `assert_type()` has the bitmask value of `0b10`.
/// This means the second argument is a type expression and the first a value expression.
const fn takes_type_expression_arguments(self) -> u32 {
const ALL_VALUES: u32 = 0b0;
const SINGLE_TYPE: u32 = 0b1;
const TYPE_TYPE: u32 = 0b11;
const VALUE_TYPE: u32 = 0b10;
/// Return the [`ParameterExpectations`] for this function.
const fn parameter_expectations(self) -> ParameterExpectations {
match self {
Self::IsFullyStatic | Self::IsSingleton | Self::IsSingleValued => {
ParameterExpectations::SingleTypeExpression
}

Self::IsEquivalentTo
| Self::IsSubtypeOf
| Self::IsAssignableTo
| Self::IsDisjointFrom => ParameterExpectations::TwoTypeExpressions,

Self::AssertType => ParameterExpectations::ValueExpressionAndTypeExpression,

Self::ConstraintFunction(_)
| Self::Len
| Self::Final
| Self::NoTypeCheck
| Self::RevealType
| Self::StaticAssert => ParameterExpectations::AllValueExpressions,
}
}
}

/// Describes whether the parameters in a function expect value expressions or type expressions.
///
/// Whether a specific parameter in the function expects a type expression can be queried
/// using [`ParameterExpectations::expectation_at_index`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
enum ParameterExpectations {
/// All parameters in the function expect value expressions
#[default]
AllValueExpressions,
/// The first parameter in the function expects a type expression
SingleTypeExpression,
/// The first two parameters in the function expect type expressions
TwoTypeExpressions,
/// The first parameter in the function expects a value expression,
/// and the second expects a type expression
ValueExpressionAndTypeExpression,
}

impl ParameterExpectations {
/// Query whether the parameter at `parameter_index` expects a value expression or a type expression
fn expectation_at_index(self, parameter_index: usize) -> ParameterExpectation {
match self {
KnownFunction::IsEquivalentTo => TYPE_TYPE,
KnownFunction::IsSubtypeOf => TYPE_TYPE,
KnownFunction::IsAssignableTo => TYPE_TYPE,
KnownFunction::IsDisjointFrom => TYPE_TYPE,
KnownFunction::IsFullyStatic => SINGLE_TYPE,
KnownFunction::IsSingleton => SINGLE_TYPE,
KnownFunction::IsSingleValued => SINGLE_TYPE,
KnownFunction::AssertType => VALUE_TYPE,
_ => ALL_VALUES,
Self::AllValueExpressions => ParameterExpectation::ValueExpression,
Self::SingleTypeExpression => {
if parameter_index == 0 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
Self::TwoTypeExpressions => {
if parameter_index < 2 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
Self::ValueExpressionAndTypeExpression => {
if parameter_index == 1 {
ParameterExpectation::TypeExpression
} else {
ParameterExpectation::ValueExpression
}
}
}
}
}

/// Whether a single parameter in a given function expects a value expression or a [type expression]
///
/// [type expression]: https://typing.readthedocs.io/en/latest/spec/annotations.html#type-and-annotation-expressions
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
enum ParameterExpectation {
/// The parameter expects a value expression
#[default]
ValueExpression,
/// The parameter expects a type expression
TypeExpression,
}

#[salsa::interned]
pub struct ModuleLiteralType<'db> {
/// The file in which this module was imported.
Expand Down
22 changes: 10 additions & 12 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ use super::slots::check_class_slots;
use super::string_annotation::{
parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION,
};
use super::{ParameterExpectation, ParameterExpectations};

/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope.
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the
Expand Down Expand Up @@ -956,7 +957,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_type_parameters(type_params);

if let Some(arguments) = class.arguments.as_deref() {
self.infer_arguments(arguments, 0b0);
self.infer_arguments(arguments, ParameterExpectations::default());
}
}

Expand Down Expand Up @@ -2601,18 +2602,15 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_arguments<'a>(
&mut self,
arguments: &'a ast::Arguments,
infer_as_type_expressions: u32,
parameter_expectations: ParameterExpectations,
) -> CallArguments<'a, 'db> {
arguments
.arguments_source_order()
.enumerate()
.map(|(index, arg_or_keyword)| {
let infer_argument_type = if index < u32::BITS as usize
&& infer_as_type_expressions & (1 << index) != 0
{
Self::infer_type_expression
} else {
Self::infer_expression
let infer_argument_type = match parameter_expectations.expectation_at_index(index) {
ParameterExpectation::TypeExpression => Self::infer_type_expression,
ParameterExpectation::ValueExpression => Self::infer_expression,
};

match arg_or_keyword {
Expand Down Expand Up @@ -3157,13 +3155,13 @@ impl<'db> TypeInferenceBuilder<'db> {

let function_type = self.infer_expression(func);

let infer_arguments_as_type_expressions = function_type
let parameter_expectations = function_type
.into_function_literal()
.and_then(|f| f.known(self.db()))
.map(KnownFunction::takes_type_expression_arguments)
.unwrap_or(0b0);
.map(KnownFunction::parameter_expectations)
.unwrap_or_default();

let call_arguments = self.infer_arguments(arguments, infer_arguments_as_type_expressions);
let call_arguments = self.infer_arguments(arguments, parameter_expectations);
function_type
.call(self.db(), &call_arguments)
.unwrap_with_diagnostic(&self.context, call_expression.into())
Expand Down

0 comments on commit c82932e

Please sign in to comment.