diff --git a/packages/pyright-internal/src/analyzer/enums.ts b/packages/pyright-internal/src/analyzer/enums.ts index d5b325e6b2a5..6b8c8c9f70c7 100644 --- a/packages/pyright-internal/src/analyzer/enums.ts +++ b/packages/pyright-internal/src/analyzer/enums.ts @@ -37,6 +37,15 @@ import { maxTypeRecursionCount, } from './types'; +interface EnumEvalStackEntry { + classType: ClassType; + memberName: string; +} + +// This stack is used to prevent infinite recursion when evaluating +// enum members that refer to other enum members. +const enumEvalStack: EnumEvalStackEntry[] = []; + // Determines whether the class is an Enum metaclass or a subclass thereof. export function isEnumMetaclass(classType: ClassType) { return classType.shared.mro.some( @@ -307,205 +316,223 @@ export function transformTypeForEnumMember( ignoreAnnotation = false, recursionCount = 0 ): Type | undefined { - if (recursionCount > maxTypeRecursionCount) { - return undefined; - } - recursionCount++; - if (!ClassType.isEnumClass(classType)) { return undefined; } - const memberInfo = lookUpClassMember(classType, memberName); - if (!memberInfo || !isClass(memberInfo.classType) || !ClassType.isEnumClass(memberInfo.classType)) { + if (recursionCount > maxTypeRecursionCount) { return undefined; } + recursionCount++; - const decls = memberInfo.symbol.getDeclarations(); - if (decls.length < 1) { + // Avoid infinite recursion. + if ( + enumEvalStack.find( + (entry) => ClassType.isSameGenericClass(entry.classType, classType) && entry.memberName === memberName + ) + ) { return undefined; } - const primaryDecl = decls[0]; + enumEvalStack.push({ classType, memberName }); - let isMemberOfEnumeration = false; - let isUnpackedTuple = false; - let valueTypeExprNode: ExpressionNode | undefined; - let declaredTypeNode: ExpressionNode | undefined; - let nameNode: NameNode | undefined; + try { + const memberInfo = lookUpClassMember(classType, memberName); + if (!memberInfo || !isClass(memberInfo.classType) || !ClassType.isEnumClass(memberInfo.classType)) { + return undefined; + } - if (primaryDecl.node.nodeType === ParseNodeType.Name) { - nameNode = primaryDecl.node; - } else if ( - primaryDecl.node.nodeType === ParseNodeType.Function || - primaryDecl.node.nodeType === ParseNodeType.Class - ) { - // Handle the case where a method or class is decorated with @enum.member. - nameNode = primaryDecl.node.d.name; - } else { - return undefined; - } + const decls = memberInfo.symbol.getDeclarations(); + if (decls.length < 1) { + return undefined; + } - if (nameNode.parent?.nodeType === ParseNodeType.Assignment && nameNode.parent.d.leftExpr === nameNode) { - isMemberOfEnumeration = true; - valueTypeExprNode = nameNode.parent.d.rightExpr; - } else if ( - nameNode.parent?.nodeType === ParseNodeType.Tuple && - nameNode.parent.parent?.nodeType === ParseNodeType.Assignment - ) { - isMemberOfEnumeration = true; - isUnpackedTuple = true; - valueTypeExprNode = nameNode.parent.parent.d.rightExpr; - } else if (nameNode.parent?.nodeType === ParseNodeType.TypeAnnotation && nameNode.parent.d.valueExpr === nameNode) { - if (ignoreAnnotation) { + const primaryDecl = decls[0]; + + let isMemberOfEnumeration = false; + let isUnpackedTuple = false; + let valueTypeExprNode: ExpressionNode | undefined; + let declaredTypeNode: ExpressionNode | undefined; + let nameNode: NameNode | undefined; + + if (primaryDecl.node.nodeType === ParseNodeType.Name) { + nameNode = primaryDecl.node; + } else if ( + primaryDecl.node.nodeType === ParseNodeType.Function || + primaryDecl.node.nodeType === ParseNodeType.Class + ) { + // Handle the case where a method or class is decorated with @enum.member. + nameNode = primaryDecl.node.d.name; + } else { + return undefined; + } + + if (nameNode.parent?.nodeType === ParseNodeType.Assignment && nameNode.parent.d.leftExpr === nameNode) { + isMemberOfEnumeration = true; + valueTypeExprNode = nameNode.parent.d.rightExpr; + } else if ( + nameNode.parent?.nodeType === ParseNodeType.Tuple && + nameNode.parent.parent?.nodeType === ParseNodeType.Assignment + ) { isMemberOfEnumeration = true; + isUnpackedTuple = true; + valueTypeExprNode = nameNode.parent.parent.d.rightExpr; + } else if ( + nameNode.parent?.nodeType === ParseNodeType.TypeAnnotation && + nameNode.parent.d.valueExpr === nameNode + ) { + if (ignoreAnnotation) { + isMemberOfEnumeration = true; + } + declaredTypeNode = nameNode.parent.d.annotation; } - declaredTypeNode = nameNode.parent.d.annotation; - } - // The spec specifically excludes names that start and end with a single underscore. - // This also includes dunder names. - if (isSingleDunderName(memberName)) { - return undefined; - } + // The spec specifically excludes names that start and end with a single underscore. + // This also includes dunder names. + if (isSingleDunderName(memberName)) { + return undefined; + } - // Specifically exclude "value" and "name". These are reserved by the enum metaclass. - if (memberName === 'name' || memberName === 'value') { - return undefined; - } + // Specifically exclude "value" and "name". These are reserved by the enum metaclass. + if (memberName === 'name' || memberName === 'value') { + return undefined; + } - const declaredType = declaredTypeNode ? evaluator.getTypeOfAnnotation(declaredTypeNode) : undefined; - let assignedType: Type | undefined; + const declaredType = declaredTypeNode ? evaluator.getTypeOfAnnotation(declaredTypeNode) : undefined; + let assignedType: Type | undefined; - if (valueTypeExprNode) { - const evalFlags = getFileInfo(valueTypeExprNode).isStubFile ? EvalFlags.ConvertEllipsisToAny : undefined; - assignedType = evaluator.getTypeOfExpression(valueTypeExprNode, evalFlags).type; - } + if (valueTypeExprNode) { + const evalFlags = getFileInfo(valueTypeExprNode).isStubFile ? EvalFlags.ConvertEllipsisToAny : undefined; + assignedType = evaluator.getTypeOfExpression(valueTypeExprNode, evalFlags).type; + } - // Handle aliases to other enum members within the same enum. - if (valueTypeExprNode?.nodeType === ParseNodeType.Name && valueTypeExprNode.d.value !== memberName) { - const aliasedEnumType = transformTypeForEnumMember( - evaluator, - classType, - valueTypeExprNode.d.value, - /* ignoreAnnotation */ false, - recursionCount - ); + // Handle aliases to other enum members within the same enum. + if (valueTypeExprNode?.nodeType === ParseNodeType.Name && valueTypeExprNode.d.value !== memberName) { + const aliasedEnumType = transformTypeForEnumMember( + evaluator, + classType, + valueTypeExprNode.d.value, + /* ignoreAnnotation */ false, + recursionCount + ); - if ( - aliasedEnumType && - isClassInstance(aliasedEnumType) && - ClassType.isSameGenericClass(aliasedEnumType, ClassType.cloneAsInstance(memberInfo.classType)) && - aliasedEnumType.priv.literalValue !== undefined - ) { - return aliasedEnumType; + if ( + aliasedEnumType && + isClassInstance(aliasedEnumType) && + ClassType.isSameGenericClass(aliasedEnumType, ClassType.cloneAsInstance(memberInfo.classType)) && + aliasedEnumType.priv.literalValue !== undefined + ) { + return aliasedEnumType; + } } - } - if (primaryDecl.node.nodeType === ParseNodeType.Function) { - const functionTypeInfo = evaluator.getTypeOfFunction(primaryDecl.node); - if (functionTypeInfo) { - assignedType = functionTypeInfo.decoratedType; - } - } else if (primaryDecl.node.nodeType === ParseNodeType.Class) { - const classTypeInfo = evaluator.getTypeOfClass(primaryDecl.node); - if (classTypeInfo) { - assignedType = classTypeInfo.decoratedType; - - // If the class is not marked as a member or a non-member, the behavior - // depends on the version of Python. In versions prior to 3.13, classes - // are treated as members. - if (isInstantiableClass(assignedType)) { - const fileInfo = getFileInfo(primaryDecl.node); - isMemberOfEnumeration = PythonVersion.isLessThan( - fileInfo.executionEnvironment.pythonVersion, - pythonVersion3_13 - ); + if (primaryDecl.node.nodeType === ParseNodeType.Function) { + const functionTypeInfo = evaluator.getTypeOfFunction(primaryDecl.node); + if (functionTypeInfo) { + assignedType = functionTypeInfo.decoratedType; + } + } else if (primaryDecl.node.nodeType === ParseNodeType.Class) { + const classTypeInfo = evaluator.getTypeOfClass(primaryDecl.node); + if (classTypeInfo) { + assignedType = classTypeInfo.decoratedType; + + // If the class is not marked as a member or a non-member, the behavior + // depends on the version of Python. In versions prior to 3.13, classes + // are treated as members. + if (isInstantiableClass(assignedType)) { + const fileInfo = getFileInfo(primaryDecl.node); + isMemberOfEnumeration = PythonVersion.isLessThan( + fileInfo.executionEnvironment.pythonVersion, + pythonVersion3_13 + ); + } } } - } - let valueType = declaredType ?? assignedType ?? UnknownType.create(); - - // If the LHS is an unpacked tuple, we need to handle this as - // a special case. - if (isUnpackedTuple) { - valueType = - evaluator.getTypeOfIterator( - { type: valueType }, - /* isAsync */ false, - nameNode, - /* emitNotIterableError */ false - )?.type ?? UnknownType.create(); - } + let valueType = declaredType ?? assignedType ?? UnknownType.create(); - // The spec excludes descriptors. - if (isClassInstance(valueType) && ClassType.getSymbolTable(valueType).get('__get__')) { - return undefined; - } + // If the LHS is an unpacked tuple, we need to handle this as + // a special case. + if (isUnpackedTuple) { + valueType = + evaluator.getTypeOfIterator( + { type: valueType }, + /* isAsync */ false, + nameNode, + /* emitNotIterableError */ false + )?.type ?? UnknownType.create(); + } - // The spec excludes private (mangled) names. - if (isPrivateName(memberName)) { - return undefined; - } + // The spec excludes descriptors. + if (isClassInstance(valueType) && ClassType.getSymbolTable(valueType).get('__get__')) { + return undefined; + } - // The enum spec doesn't explicitly specify this, but it - // appears that callables are excluded. - if (!findSubtype(valueType, (subtype) => !isFunction(subtype) && !isOverloaded(subtype))) { - return undefined; - } + // The spec excludes private (mangled) names. + if (isPrivateName(memberName)) { + return undefined; + } - if ( - !assignedType && - nameNode.parent?.nodeType === ParseNodeType.Assignment && - nameNode.parent.d.leftExpr === nameNode - ) { - assignedType = evaluator.getTypeOfExpression( - nameNode.parent.d.rightExpr, - /* flags */ undefined, - makeInferenceContext(declaredType) - ).type; - } + // The enum spec doesn't explicitly specify this, but it + // appears that callables are excluded. + if (!findSubtype(valueType, (subtype) => !isFunction(subtype) && !isOverloaded(subtype))) { + return undefined; + } + + if ( + !assignedType && + nameNode.parent?.nodeType === ParseNodeType.Assignment && + nameNode.parent.d.leftExpr === nameNode + ) { + assignedType = evaluator.getTypeOfExpression( + nameNode.parent.d.rightExpr, + /* flags */ undefined, + makeInferenceContext(declaredType) + ).type; + } + + // Handle the Python 3.11 "enum.member()" and "enum.nonmember()" features. + if (assignedType && isClassInstance(assignedType) && ClassType.isBuiltIn(assignedType)) { + if (assignedType.shared.fullName === 'enum.nonmember') { + const nonMemberType = + assignedType.priv.typeArgs && assignedType.priv.typeArgs.length > 0 + ? assignedType.priv.typeArgs[0] + : UnknownType.create(); + + // If the type of the nonmember is declared and the assigned value has + // a compatible type, use the declared type. + if (declaredType && evaluator.assignType(declaredType, nonMemberType)) { + return declaredType; + } - // Handle the Python 3.11 "enum.member()" and "enum.nonmember()" features. - if (assignedType && isClassInstance(assignedType) && ClassType.isBuiltIn(assignedType)) { - if (assignedType.shared.fullName === 'enum.nonmember') { - const nonMemberType = - assignedType.priv.typeArgs && assignedType.priv.typeArgs.length > 0 - ? assignedType.priv.typeArgs[0] - : UnknownType.create(); - - // If the type of the nonmember is declared and the assigned value has - // a compatible type, use the declared type. - if (declaredType && evaluator.assignType(declaredType, nonMemberType)) { - return declaredType; + return nonMemberType; } - return nonMemberType; + if (assignedType.shared.fullName === 'enum.member') { + valueType = + assignedType.priv.typeArgs && assignedType.priv.typeArgs.length > 0 + ? assignedType.priv.typeArgs[0] + : UnknownType.create(); + isMemberOfEnumeration = true; + } } - if (assignedType.shared.fullName === 'enum.member') { - valueType = - assignedType.priv.typeArgs && assignedType.priv.typeArgs.length > 0 - ? assignedType.priv.typeArgs[0] - : UnknownType.create(); - isMemberOfEnumeration = true; + if (!isMemberOfEnumeration) { + return undefined; } - } - if (!isMemberOfEnumeration) { - return undefined; - } - - const enumLiteral = new EnumLiteral( - memberInfo.classType.shared.fullName, - memberInfo.classType.shared.name, - memberName, - valueType, - isReprEnumClass(classType) - ); + const enumLiteral = new EnumLiteral( + memberInfo.classType.shared.fullName, + memberInfo.classType.shared.name, + memberName, + valueType, + isReprEnumClass(classType) + ); - return ClassType.cloneAsInstance(ClassType.cloneWithLiteral(memberInfo.classType, enumLiteral)); + return ClassType.cloneAsInstance(ClassType.cloneWithLiteral(memberInfo.classType, enumLiteral)); + } finally { + enumEvalStack.pop(); + } } export function isDeclInEnumClass(evaluator: TypeEvaluator, decl: VariableDeclaration): boolean { diff --git a/packages/pyright-internal/src/tests/samples/enum14.py b/packages/pyright-internal/src/tests/samples/enum14.py new file mode 100644 index 000000000000..1bba055f3467 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/enum14.py @@ -0,0 +1,16 @@ +# This sample tests certain error conditions that previously caused +# an infinite recursion condition in the type evaluator. + +from __future__ import annotations +from enum import Enum +from typing import Literal + + +class A(Enum): + # This should generate two errors. + x: Literal[A.x] + + +class B(Enum): + # This should generate an error. + x: B.x diff --git a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts index a166267712b5..fc89d94b3315 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator3.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator3.test.ts @@ -1082,6 +1082,12 @@ test('Enum13', () => { TestUtils.validateResults(analysisResults, 3); }); +test('Enum14', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['enum14.py']); + + TestUtils.validateResults(analysisResults, 3); +}); + test('EnumAuto1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['enumAuto1.py']);