From 18114f0c6b1076b20e19e6110ebc21f06ca0a94c Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Fri, 28 Jul 2023 08:47:27 +0200 Subject: [PATCH] Refactor refined function logic --- .../src/dotty/tools/dotc/ast/TreeInfo.scala | 2 +- compiler/src/dotty/tools/dotc/ast/tpd.scala | 9 +- .../src/dotty/tools/dotc/cc/CaptureSet.scala | 8 +- .../dotty/tools/dotc/cc/CheckCaptures.scala | 58 +++++----- compiler/src/dotty/tools/dotc/cc/Setup.scala | 16 +-- .../src/dotty/tools/dotc/cc/Synthetics.scala | 4 +- .../dotty/tools/dotc/core/Definitions.scala | 109 +++++++++--------- .../dotty/tools/dotc/core/TypeErasure.scala | 4 +- .../src/dotty/tools/dotc/core/Types.scala | 40 ++++--- .../tools/dotc/printing/PlainPrinter.scala | 6 +- .../dotty/tools/dotc/quoted/Interpreter.scala | 16 +-- .../dotty/tools/dotc/transform/Bridges.scala | 18 +-- .../transform/ContextFunctionResults.scala | 32 +++-- .../tools/dotc/transform/PickleQuotes.scala | 2 +- .../dotc/transform/SpecializeFunctions.scala | 2 +- .../dotty/tools/dotc/transform/Splicing.scala | 4 +- .../tools/dotc/transform/TreeChecker.scala | 4 +- .../dotty/tools/dotc/typer/Applications.scala | 16 +-- .../tools/dotc/typer/ErrorReporting.scala | 5 +- .../src/dotty/tools/dotc/typer/Namer.scala | 10 +- .../dotty/tools/dotc/typer/ProtoTypes.scala | 10 +- .../tools/dotc/typer/QuotesAndSplices.scala | 4 +- .../dotty/tools/dotc/typer/Synthesizer.scala | 2 +- .../src/dotty/tools/dotc/typer/Typer.scala | 39 +++---- .../quoted/runtime/impl/QuotesImpl.scala | 6 +- tests/neg-custom-args/captures/byname.check | 2 +- tests/pos-macros/erasedArgs/Macro_1.scala | 7 ++ tests/pos-macros/erasedArgs/Test_2.scala | 1 + 28 files changed, 226 insertions(+), 210 deletions(-) create mode 100644 tests/pos-macros/erasedArgs/Macro_1.scala create mode 100644 tests/pos-macros/erasedArgs/Test_2.scala diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index 6659818b333e..0f87c6bd7702 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -954,7 +954,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] => def isStructuralTermSelectOrApply(tree: Tree)(using Context): Boolean = { def isStructuralTermSelect(tree: Select) = def hasRefinement(qualtpe: Type): Boolean = qualtpe.dealias match - case defn.PolyFunctionOf(_) => + case defn.FunctionOf(_) => false case RefinedType(parent, rname, rinfo) => rname == tree.name || hasRefinement(parent) diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index ad2676624b0f..85546304d6de 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -1152,13 +1152,10 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def etaExpandCFT(using Context): Tree = def expand(target: Tree, tp: Type)(using Context): Tree = tp match - case defn.ContextFunctionType(argTypes, resType, _) => - val anonFun = newAnonFun( - ctx.owner, - MethodType.companion(isContextual = true)(argTypes, resType), - coord = ctx.owner.coord) + case defn.FunctionOf(mt: MethodType) if mt.isContextualMethod && !mt.isResultDependent => // TODO handle result-dependent functions? + val anonFun = newAnonFun(ctx.owner, mt, coord = ctx.owner.coord) def lambdaBody(refss: List[List[Tree]]) = - expand(target.select(nme.apply).appliedToArgss(refss), resType)( + expand(target.select(nme.apply).appliedToArgss(refss), mt.resType)( using ctx.withOwner(anonFun)) Closure(anonFun, lambdaBody) case _ => diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala index 3f2beaa3ff55..cdca83eda2a8 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureSet.scala @@ -876,13 +876,13 @@ object CaptureSet: empty case CapturingType(parent, refs) => recur(parent) ++ refs - case tpd @ RefinedType(parent, _, rinfo: MethodType) - if followResult && defn.isFunctionNType(tpd) => - ofType(parent, followResult = false) // pick up capture set from parent type + case tpd @ defn.RefinedFunctionOf(rinfo: MethodType) + if followResult => + ofType(tpd.parent, followResult = false) // pick up capture set from parent type ++ (recur(rinfo.resType) // add capture set of result -- CaptureSet(rinfo.paramRefs.filter(_.isTracked)*)) // but disregard bound parameters case tpd @ AppliedType(tycon, args) => - if followResult && defn.isNonRefinedFunction(tpd) then + if followResult && defn.isFunctionNType(tpd) then recur(args.last) // must be (pure) FunctionN type since ImpureFunctions have already // been eliminated in selector's dealias. Use capture set of result. diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index b6b5d569677c..baf823f28630 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -195,7 +195,7 @@ class CheckCaptures extends Recheck, SymTransformer: capt.println(i"solving $t") refs.solve() traverse(parent) - case t @ RefinedType(_, nme.apply, rinfo) if defn.isFunctionType(t) => + case defn.RefinedFunctionOf(rinfo) => traverse(rinfo) case tp: TypeVar => case tp: TypeRef => @@ -302,8 +302,8 @@ class CheckCaptures extends Recheck, SymTransformer: t case _ => val t1 = t match - case t @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(t) => - t.derivedRefinedType(parent, rname, this(rinfo)) + case t @ defn.RefinedFunctionOf(rinfo: MethodType) => + t.derivedRefinedType(t.parent, t.refinedName, this(rinfo)) case _ => mapOver(t) if variance > 0 then t1 @@ -408,10 +408,10 @@ class CheckCaptures extends Recheck, SymTransformer: else if meth == defn.Caps_unsafeUnbox then mapArgUsing(_.forceBoxStatus(false)) else if meth == defn.Caps_unsafeBoxFunArg then - mapArgUsing: - case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) => - defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual) - + mapArgUsing: tp => + val defn.FunctionOf(mt: MethodType) = tp.dealias: @unchecked + mt.derivedLambdaType(resType = mt.resType.forceBoxStatus(true)) + .toFunctionType() else super.recheckApply(tree, pt) match case appType @ CapturingType(appType1, refs) => @@ -502,8 +502,9 @@ class CheckCaptures extends Recheck, SymTransformer: block match case closureDef(mdef) => pt.dealias match - case defn.FunctionOf(ptformals, _, _) - if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) => + case defn.FunctionOf(mt0: MethodType) + if mt0.paramInfos.nonEmpty && mt0.paramInfos.forall(_.captureSet.isAlwaysEmpty) => + val ptformals = mt0.paramInfos // Redo setup of the anonymous function so that formal parameters don't // get capture sets. This is important to avoid false widenings to `cap` // when taking the base type of the actual closures's dependent function @@ -696,10 +697,6 @@ class CheckCaptures extends Recheck, SymTransformer: //println(i"check conforms $actual1 <<< $expected1") super.checkConformsExpr(actual1, expected1, tree) - private def toDepFun(args: List[Type], resultType: Type, isContextual: Boolean)(using Context): Type = - MethodType.companion(isContextual = isContextual)(args, resultType) - .toFunctionType(isJava = false, alwaysDependent = true) - /** Turn `expected` into a dependent function when `actual` is dependent. */ private def alignDependentFunction(expected: Type, actual: Type)(using Context): Type = def recur(expected: Type): Type = expected.dealias match @@ -707,10 +704,12 @@ class CheckCaptures extends Recheck, SymTransformer: val eparent1 = recur(eparent) if eparent1 eq eparent then expected else CapturingType(eparent1, refs, boxed = expected0.isBoxed) - case expected @ defn.FunctionOf(args, resultType, isContextual) - if defn.isNonRefinedFunction(expected) && defn.isFunctionNType(actual) && !defn.isNonRefinedFunction(actual) => - val expected1 = toDepFun(args, resultType, isContextual) - expected1 + case defn.FunctionOf(mt: MethodType) => + actual.dealias match + case defn.FunctionOf(mt2: MethodType) if mt2.isResultDependent => + mt.toFunctionType(alwaysDependent = true) + case _ => + expected case _ => expected recur(expected) @@ -781,9 +780,8 @@ class CheckCaptures extends Recheck, SymTransformer: try val (eargs, eres) = expected.dealias.stripCapturing match - case defn.FunctionOf(eargs, eres, _) => (eargs, eres) case expected: MethodType => (expected.paramInfos, expected.resType) - case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionNType(expected) => (rinfo.paramInfos, rinfo.resType) + case defn.FunctionOf(mt: MethodType) => (mt.paramInfos, mt.resType) case _ => (aargs.map(_ => WildcardType), WildcardType) val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) } val ares1 = adapt(ares, eres, covariant) @@ -808,7 +806,7 @@ class CheckCaptures extends Recheck, SymTransformer: try val eres = expected.dealias.stripCapturing match - case RefinedType(_, _, rinfo: PolyType) => rinfo.resType + case defn.PolyFunctionOf(rinfo: PolyType) => rinfo.resType case expected: PolyType => expected.resType case _ => WildcardType @@ -842,26 +840,26 @@ class CheckCaptures extends Recheck, SymTransformer: // Adapt the inner shape type: get the adapted shape type, and the capture set leaked during adaptation val (styp1, leaked) = styp match { - case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) => + case actual @ AppliedType(tycon, args) if defn.isFunctionNType(actual) => adaptFun(actual, args.init, args.last, expected, covariant, insertBox, (aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1)) - case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) => + case actual @ defn.RefinedFunctionOf(rinfo: MethodType) => // TODO Find a way to combine handling of generic and dependent function types (here and elsewhere) adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox, (aargs1, ares1) => rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1) - .toFunctionType(isJava = false, alwaysDependent = true)) - case actual: MethodType => - adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox, - (aargs1, ares1) => - actual.derivedLambdaType(paramInfos = aargs1, resType = ares1)) - case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionType(actual) => + .toFunctionType(alwaysDependent = true)) + case actual @ defn.RefinedFunctionOf(rinfo: PolyType) => adaptTypeFun(actual, rinfo.resType, expected, covariant, insertBox, ares1 => val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1) - val actual1 = actual.derivedRefinedType(p, nme, rinfo1) + val actual1 = actual.derivedRefinedType(actual.parent, actual.refinedName, rinfo1) actual1 ) + case actual: MethodType => + adaptFun(actual, actual.paramInfos, actual.resType, expected, covariant, insertBox, + (aargs1, ares1) => + actual.derivedLambdaType(paramInfos = aargs1, resType = ares1)) case _ => (styp, CaptureSet()) } @@ -1080,7 +1078,7 @@ class CheckCaptures extends Recheck, SymTransformer: case CapturingType(parent, refs) => healCaptureSet(refs) traverse(parent) - case tp @ RefinedType(parent, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => traverse(rinfo) case tp: TermLambda => val saved = allowed diff --git a/compiler/src/dotty/tools/dotc/cc/Setup.scala b/compiler/src/dotty/tools/dotc/cc/Setup.scala index 4c32c2908635..e861ca95e28b 100644 --- a/compiler/src/dotty/tools/dotc/cc/Setup.scala +++ b/compiler/src/dotty/tools/dotc/cc/Setup.scala @@ -40,7 +40,7 @@ extends tpd.TreeTraverser: MethodType.companion( isContextual = defn.isContextFunctionClass(tycon.classSymbol), )(argTypes, resType) - .toFunctionType(isJava = false, alwaysDependent = true) + .toFunctionType(alwaysDependent = true) /** If `tp` is an unboxed capturing type or a function returning an unboxed capturing type, * convert it to be boxed. @@ -49,15 +49,15 @@ extends tpd.TreeTraverser: def recur(tp: Type): Type = tp.dealias match case tp @ CapturingType(parent, refs) if !tp.isBoxed => tp.boxed - case tp1 @ AppliedType(tycon, args) if defn.isNonRefinedFunction(tp1) => + case tp1 @ AppliedType(tycon, args) if defn.isFunctionNType(tp1) => val res = args.last val boxedRes = recur(res) if boxedRes eq res then tp else tp1.derivedAppliedType(tycon, args.init :+ boxedRes) - case tp1 @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(tp1) => + case defn.RefinedFunctionOf(rinfo: MethodType) => val boxedRinfo = recur(rinfo) if boxedRinfo eq rinfo then tp - else boxedRinfo.toFunctionType(isJava = false, alwaysDependent = true) + else boxedRinfo.toFunctionType(alwaysDependent = true) case tp1: MethodOrPoly => val res = tp1.resType val boxedRes = recur(res) @@ -129,7 +129,7 @@ extends tpd.TreeTraverser: apply(parent) case tp @ AppliedType(tycon, args) => val tycon1 = this(tycon) - if defn.isNonRefinedFunction(tp) then + if defn.isFunctionNType(tp) then // Convert toplevel generic function types to dependent functions if !defn.isFunctionSymbol(tp.typeSymbol) && (tp.dealias ne tp) then // This type is a function after dealiasing, so we dealias and recurse. @@ -149,9 +149,9 @@ extends tpd.TreeTraverser: tp.derivedAppliedType(tycon1, args1 :+ res1) else tp.derivedAppliedType(tycon1, args.mapConserve(arg => this(arg))) - case tp @ RefinedType(core, rname, rinfo: MethodType) if defn.isFunctionType(tp) => + case defn.RefinedFunctionOf(rinfo: MethodType) => val rinfo1 = apply(rinfo) - if rinfo1 ne rinfo then rinfo1.toFunctionType(isJava = false, alwaysDependent = true) + if rinfo1 ne rinfo then rinfo1.toFunctionType(alwaysDependent = true) else tp case tp: MethodType => tp.derivedLambdaType( @@ -197,7 +197,7 @@ extends tpd.TreeTraverser: val mt = ContextualMethodType(paramName :: Nil)( _ => paramType :: Nil, mt => if isLast then res else expandThrowsAlias(res, mt :: encl)) - val fntpe = defn.PolyFunctionOf(mt) + val fntpe = mt.toFunctionType() if !encl.isEmpty && isLast then val cs = CaptureSet(encl.map(_.paramRefs.head)*) CapturingType(fntpe, cs, boxed = false) diff --git a/compiler/src/dotty/tools/dotc/cc/Synthetics.scala b/compiler/src/dotty/tools/dotc/cc/Synthetics.scala index 1e7c8d641238..c4c52513fb49 100644 --- a/compiler/src/dotty/tools/dotc/cc/Synthetics.scala +++ b/compiler/src/dotty/tools/dotc/cc/Synthetics.scala @@ -174,9 +174,9 @@ object Synthetics: val (et: ExprType) = symd.info: @unchecked val (enclThis: ThisType) = symd.owner.thisType: @unchecked def mapFinalResult(tp: Type, f: Type => Type): Type = - val defn.FunctionOf(args, res, isContextual) = tp: @unchecked + val defn.FunctionNOf(args, res, isContextual) = tp: @unchecked if defn.isFunctionNType(res) then - defn.FunctionOf(args, mapFinalResult(res, f), isContextual) + defn.FunctionNOf(args, mapFinalResult(res, f), isContextual) else f(tp) val resType1 = diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index ea48dd2b56fa..e56f0867e012 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -1109,23 +1109,48 @@ class Definitions { sym.owner.linkedClass.typeRef object FunctionOf { + /** Matches a `FunctionN[...]`/`ContextFunctionN[...]` or refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(ft: Type)(using Context): Option[MethodOrPoly] = { + ft match + case RefinedFunctionOf(mt) => Some(mt) + case FunctionNOf(argTypes, resultType, isContextual) => + val methodType = if isContextual then ContextualMethodType else MethodType + Some(methodType(argTypes, resultType)) + case _ => None + } + } + + object RefinedFunctionOf { + /** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`. + * Extracts the method type type and apply info. + */ + def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = { + tpe.refinedInfo match + case mt: MethodOrPoly + if tpe.refinedName == nme.apply + && (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) => + Some(mt) + case _ => None + } + } + + object FunctionNOf { + /** Create a `FunctionN` or `ContextFunctionN` type applied to the arguments and result type */ def apply(args: List[Type], resultType: Type, isContextual: Boolean = false)(using Context): Type = - val mt = MethodType.companion(isContextual, false)(args, resultType) - if mt.hasErasedParams then - RefinedType(PolyFunctionClass.typeRef, nme.apply, mt) - else - FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil) - def unapply(ft: Type)(using Context): Option[(List[Type], Type, Boolean)] = { - ft.dealias match - case PolyFunctionOf(mt: MethodType) => - Some(mt.paramInfos, mt.resType, mt.isContextualMethod) - case dft => - val tsym = dft.typeSymbol - if isFunctionSymbol(tsym) && ft.isRef(tsym) then - val targs = dft.argInfos - if (targs.isEmpty) None - else Some(targs.init, targs.last, tsym.name.isContextFunction) - else None + FunctionType(args.length, isContextual).appliedTo(args ::: resultType :: Nil) + + /** Matches a (possibly aliased) `FunctionN[...]` or `ContextFunctionN[...]`. + * Extracts the list of function argument types, the result type and whether function is contextual. + */ + def unapply(tpe: Type)(using Context): Option[(List[Type], Type, Boolean)] = { + val tsym = tpe.typeSymbol + if isFunctionSymbol(tsym) && tpe.isRef(tsym) then + val targs = tpe.argInfos + if (targs.isEmpty) None + else Some(targs.init, targs.last, tsym.name.isContextFunction) + else None } } @@ -1140,16 +1165,18 @@ class Definitions { * * Pattern: `PolyFunction { def apply: $mt }` */ - def unapply(ft: Type)(using Context): Option[MethodicType] = ft.dealias match - case RefinedType(parent, nme.apply, mt: MethodicType) - if parent.derivesFrom(defn.PolyFunctionClass) => - Some(mt) - case _ => None + def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = + tpe.refinedInfo match + case mt: MethodOrPoly + if tpe.refinedName == nme.apply && tpe.parent.derivesFrom(defn.PolyFunctionClass) => + Some(mt) + case _ => None private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean = def isValidMethodType(info: Type) = info match case info: MethodType => !info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list + && !info.isParamDependent case _ => false info match case info: PolyType => isValidMethodType(info.resType) @@ -1716,26 +1743,20 @@ class Definitions { def isProductSubType(tp: Type)(using Context): Boolean = tp.derivesFrom(ProductClass) - /** Is `tp` (an alias) of either a scala.FunctionN or a scala.ContextFunctionN - * instance? + /** Returns whether `tp` is an instance or a refined instance of: + * - scala.FunctionN + * - scala.ContextFunctionN */ - def isNonRefinedFunction(tp: Type)(using Context): Boolean = - val arity = functionArity(tp) - val sym = tp.dealias.typeSymbol + def isFunctionNType(tp: Type)(using Context): Boolean = + val tp1 = tp.dropDependentRefinement + val arity = functionArity(tp1) + val sym = tp1.dealias.typeSymbol arity >= 0 && isFunctionClass(sym) - && tp.isRef( + && tp1.isRef( FunctionType(arity, sym.name.isContextFunction).typeSymbol, skipRefined = false) - end isNonRefinedFunction - - /** Returns whether `tp` is an instance or a refined instance of: - * - scala.FunctionN - * - scala.ContextFunctionN - */ - def isFunctionNType(tp: Type)(using Context): Boolean = - isNonRefinedFunction(tp.dropDependentRefinement) /** Returns whether `tp` is an instance or a refined instance of: * - scala.FunctionN @@ -1858,24 +1879,6 @@ class Definitions { def isContextFunctionType(tp: Type)(using Context): Boolean = asContextFunctionType(tp).exists - /** An extractor for context function types `As ?=> B`, possibly with - * dependent refinements. Optionally returns a triple consisting of the argument - * types `As`, the result type `B` and a whether the type is an erased context function. - */ - object ContextFunctionType: - def unapply(tp: Type)(using Context): Option[(List[Type], Type, List[Boolean])] = - if ctx.erasedTypes then - atPhase(erasurePhase)(unapply(tp)) - else - asContextFunctionType(tp) match - case PolyFunctionOf(mt: MethodType) => - Some((mt.paramInfos, mt.resType, mt.erasedParams)) - case tp1 if tp1.exists => - val args = tp1.functionArgInfos - val erasedParams = List.fill(functionArity(tp1)) { false } - Some((args.init, args.last, erasedParams)) - case _ => None - /** A whitelist of Scala-2 classes that are known to be pure */ def isAssuredNoInits(sym: Symbol): Boolean = (sym `eq` SomeClass) || isTupleClass(sym) diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 94c7b2993b97..7b03283e40b4 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -567,7 +567,7 @@ object TypeErasure { functionType(info.resultType) case info: MethodType => assert(!info.resultType.isInstanceOf[MethodicType]) - defn.FunctionType(n = info.erasedParams.count(_ == false)) + defn.FunctionType(n = info.nonErasedParamCount) } erasure(functionType(applyInfo)) } @@ -933,7 +933,7 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst case tp: TermRef => sigName(underlyingOfTermRef(tp)) case ExprType(rt) => - sigName(defn.FunctionOf(Nil, rt)) + sigName(defn.FunctionNOf(Nil, rt)) case tp: TypeVar if !tp.isInstantiated => tpnme.Uninstantiated case tp @ defn.PolyFunctionOf(_) => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index fcf9d984bcf4..d60e722a4bfc 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -1514,7 +1514,7 @@ object Types { /** Dealias, and if result is a dependent function type, drop the `apply` refinement. */ final def dropDependentRefinement(using Context): Type = dealias match { - case RefinedType(parent, nme.apply, mt) if defn.isNonRefinedFunction(parent) => parent + case RefinedType(parent, nme.apply, mt) if defn.isFunctionNType(parent) => parent case tp => tp } @@ -1877,7 +1877,7 @@ object Types { * when forming the function type. * @param alwaysDependent if true, always create a dependent function type. */ - def toFunctionType(isJava: Boolean, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match { + def toFunctionType(isJava: Boolean = false, dropLast: Int = 0, alwaysDependent: Boolean = false)(using Context): Type = this match { case mt: MethodType => assert(!mt.isParamDependent) def nonDependentFunType = @@ -1887,19 +1887,30 @@ object Types { case res: MethodType => res.toFunctionType(isJava) case res => res } - defn.FunctionOf( + defn.FunctionNOf( formals1 mapConserve (_.translateFromRepeated(toArray = isJava)), result1, isContextual) if mt.hasErasedParams then - defn.PolyFunctionOf(mt) + assert(isValidPolyFunctionInfo(mt), s"Not a valid PolyFunction refinement: $mt") + RefinedType(defn.PolyFunctionType, nme.apply, mt) else if alwaysDependent || mt.isResultDependent then RefinedType(nonDependentFunType, nme.apply, mt) else nonDependentFunType - case poly @ PolyType(_, mt: MethodType) => - assert(!mt.isParamDependent) - defn.PolyFunctionOf(poly) + case poly: PolyType => + assert(isValidPolyFunctionInfo(poly), s"Not a valid PolyFunction refinement: $poly") + RefinedType(defn.PolyFunctionType, nme.apply, poly) } + private def isValidPolyFunctionInfo(info: Type)(using Context): Boolean = + def isValidMethodType(info: Type) = info match + case info: MethodType => + !info.resType.isInstanceOf[MethodOrPoly] // Has only one parameter list + && !info.isParamDependent + case _ => false + info match + case info: PolyType => isValidMethodType(info.resType) + case _ => isValidMethodType(info) + /** The signature of this type. This is by default NotAMethod, * but is overridden for PolyTypes, MethodTypes, and TermRef types. * (the reason why we deviate from the "final-method-with-pattern-match-in-base-class" @@ -3721,8 +3732,6 @@ object Types { def companion: LambdaTypeCompanion[ThisName, PInfo, This] - def erasedParams(using Context) = List.fill(paramInfos.size)(false) - /** The type `[tparams := paramRefs] tp`, where `tparams` can be * either a list of type parameter symbols or a list of lambda parameters * @@ -4014,13 +4023,18 @@ object Types { final override def isImplicitMethod: Boolean = companion.eq(ImplicitMethodType) || isContextualMethod final override def hasErasedParams(using Context): Boolean = - erasedParams.contains(true) + paramInfos.exists(p => p.hasAnnotation(defn.ErasedParamAnnot)) + final override def isContextualMethod: Boolean = companion.eq(ContextualMethodType) - override def erasedParams(using Context): List[Boolean] = + def erasedParams(using Context): List[Boolean] = paramInfos.map(p => p.hasAnnotation(defn.ErasedParamAnnot)) + def nonErasedParamCount(using Context): Int = + paramInfos.count(p => !p.hasAnnotation(defn.ErasedParamAnnot)) + + protected def prefixString: String = companion.prefixString } @@ -4097,8 +4111,8 @@ object Types { tp.derivedAppliedType(tycon, addInto(args.head) :: Nil) case tp @ AppliedType(tycon, args) if defn.isFunctionNType(tp) => wrapConvertible(tp.derivedAppliedType(tycon, args.init :+ addInto(args.last))) - case tp @ RefinedType(parent, rname, rinfo) if defn.isFunctionType(tp) => - wrapConvertible(tp.derivedRefinedType(parent, rname, addInto(rinfo))) + case tp @ defn.RefinedFunctionOf(rinfo) => + wrapConvertible(tp.derivedRefinedType(tp.parent, tp.refinedName, addInto(rinfo))) case tp: MethodOrPoly => tp.derivedLambdaType(resType = addInto(tp.resType)) case ExprType(resType) => diff --git a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala index 4d87d6406567..b739bcf1b74d 100644 --- a/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala @@ -297,10 +297,10 @@ class PlainPrinter(_ctx: Context) extends Printer { "(" ~ toTextRef(tp) ~ " : " ~ toTextGlobal(tp.underlying) ~ ")" protected def paramsText(lam: LambdaType): Text = { - val erasedParams = lam.erasedParams - def paramText(ref: ParamRef, erased: Boolean) = + def paramText(ref: ParamRef) = + val erased = ref.underlying.hasAnnotation(defn.ErasedParamAnnot) keywordText("erased ").provided(erased) ~ ParamRefNameString(ref) ~ lambdaHash(lam) ~ toTextRHS(ref.underlying, isParameter = true) - Text(lam.paramRefs.lazyZip(erasedParams).map(paramText), ", ") + Text(lam.paramRefs.map(paramText), ", ") } protected def ParamRefNameString(name: Name): String = nameString(name) diff --git a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala index ccf6cd6f995b..dbf4fe91a970 100644 --- a/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala +++ b/compiler/src/dotty/tools/dotc/quoted/Interpreter.scala @@ -126,11 +126,13 @@ class Interpreter(pos: SrcPos, classLoader0: ClassLoader)(using Context): view.toList fnType.dealias match - case fnType: MethodType if fnType.hasErasedParams => interpretArgs(argss, fnType.resType) case fnType: MethodType => val argTypes = fnType.paramInfos assert(argss.head.size == argTypes.size) - interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType) + val nonErasedArgs = argss.head.lazyZip(fnType.erasedParams).collect { case (arg, false) => arg }.toList + val nonErasedArgTypes = fnType.paramInfos.lazyZip(fnType.erasedParams).collect { case (arg, false) => arg }.toList + assert(nonErasedArgs.size == nonErasedArgTypes.size) + interpretArgsGroup(nonErasedArgs, nonErasedArgTypes) ::: interpretArgs(argss.tail, fnType.resType) case fnType: AppliedType if defn.isContextFunctionType(fnType) => val argTypes :+ resType = fnType.args: @unchecked interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType) @@ -328,8 +330,8 @@ object Interpreter: object Call: import tpd._ /** Matches an expression that is either a field access or an application - * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it. - */ + * It returns a TermRef containing field accessed or a method reference and the arguments passed to it. + */ def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) @@ -339,10 +341,8 @@ object Interpreter: Some((fn, args)) case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil)) case fn: Select => Some((fn, Nil)) - case Apply(f @ Call0(fn, args1), args2) => - if (f.tpe.widenDealias.hasErasedParams) Some((fn, args1)) - else Some((fn, args2 :: args1)) - case TypeApply(Call0(fn, args), _) => Some((fn, args)) + case Apply(f @ Call0(fn, argss), args) => Some((fn, args :: argss)) + case TypeApply(Call0(fn, argss), _) => Some((fn, argss)) case _ => None } } diff --git a/compiler/src/dotty/tools/dotc/transform/Bridges.scala b/compiler/src/dotty/tools/dotc/transform/Bridges.scala index 569b16681cde..ee64177a9ef1 100644 --- a/compiler/src/dotty/tools/dotc/transform/Bridges.scala +++ b/compiler/src/dotty/tools/dotc/transform/Bridges.scala @@ -129,25 +129,25 @@ class Bridges(root: ClassSymbol, thisPhase: DenotTransformer)(using Context) { assert(ctx.typer.isInstanceOf[Erasure.Typer]) ctx.typer.typed(untpd.cpy.Apply(ref)(ref, args), member.info.finalResultType) else - val defn.ContextFunctionType(argTypes, resType, erasedParams) = tp: @unchecked - val anonFun = newAnonFun(ctx.owner, - MethodType( - argTypes.zip(erasedParams.padTo(argTypes.length, false)) - .flatMap((t, e) => if e then None else Some(t)), - resType), - coord = ctx.owner.coord) + val mtWithoutErasedParams = atPhase(erasurePhase) { + tp.dealias match + case defn.FunctionOf(mt: MethodType) => + val paramInfos = mt.paramInfos.zip(mt.erasedParams).collect { case (param, false) => param } + mt.derivedLambdaType(paramInfos = paramInfos) + } + val anonFun = newAnonFun(ctx.owner, mtWithoutErasedParams, coord = ctx.owner.coord) anonFun.info = transformInfo(anonFun, anonFun.info) def lambdaBody(refss: List[List[Tree]]) = val refs :: Nil = refss: @unchecked val expandedRefs = refs.map(_.withSpan(ctx.owner.span.endPos)) match case (bunchedParam @ Ident(nme.ALLARGS)) :: Nil => - argTypes.indices.toList.map(n => + mtWithoutErasedParams.paramInfos.indices.toList.map(n => bunchedParam .select(nme.primitive.arrayApply) .appliedTo(Literal(Constant(n)))) case refs1 => refs1 - expand(args ::: expandedRefs, resType, n - 1)(using ctx.withOwner(anonFun)) + expand(args ::: expandedRefs, mtWithoutErasedParams.resType, n - 1)(using ctx.withOwner(anonFun)) val unadapted = Closure(anonFun, lambdaBody) cpy.Block(unadapted)(unadapted.stats, diff --git a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala index b4eb71c541d3..6152e84800c2 100644 --- a/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala +++ b/compiler/src/dotty/tools/dotc/transform/ContextFunctionResults.scala @@ -19,10 +19,10 @@ object ContextFunctionResults: * consists of a string of `n` nested context closures. */ def annotateContextResults(mdef: DefDef)(using Context): Unit = - def contextResultCount(rhs: Tree, tp: Type): Int = tp match - case defn.ContextFunctionType(_, resTpe, _) => + def contextResultCount(rhs: Tree, tp: Type): Int = tp.dealias match + case defn.FunctionOf(mt) if mt.isContextualMethod => rhs match - case closureDef(meth) => 1 + contextResultCount(meth.rhs, resTpe) + case closureDef(meth) => 1 + contextResultCount(meth.rhs, mt.resType) case _ => 0 case _ => 0 @@ -58,7 +58,8 @@ object ContextFunctionResults: */ def contextResultsAreErased(sym: Symbol)(using Context): Boolean = def allErased(tp: Type): Boolean = tp.dealias match - case defn.ContextFunctionType(_, resTpe, erasedParams) => !erasedParams.contains(false) && allErased(resTpe) + case ft @ defn.FunctionOf(mt: MethodType) if mt.isContextualMethod => + !mt.erasedParams.contains(false) && allErased(mt.resType) case _ => true contextResultCount(sym) > 0 && allErased(sym.info.finalResultType) @@ -67,13 +68,13 @@ object ContextFunctionResults: */ def integrateContextResults(tp: Type, crCount: Int)(using Context): Type = if crCount == 0 then tp - else tp match + else tp.dealias match case ExprType(rt) => integrateContextResults(rt, crCount) case tp: MethodOrPoly => tp.derivedLambdaType(resType = integrateContextResults(tp.resType, crCount)) - case defn.ContextFunctionType(argTypes, resType, erasedParams) => - MethodType(argTypes, integrateContextResults(resType, crCount - 1)) + case defn.FunctionOf(mt) if mt.isContextualMethod => + mt.derivedLambdaType(resType = integrateContextResults(mt.resultType, crCount - 1)) /** The total number of parameters of method `sym`, not counting * erased parameters, but including context result parameters. @@ -83,16 +84,11 @@ object ContextFunctionResults: def contextParamCount(tp: Type, crCount: Int): Int = if crCount == 0 then 0 else - val defn.ContextFunctionType(params, resTpe, erasedParams) = tp: @unchecked - val rest = contextParamCount(resTpe, crCount - 1) - if erasedParams.contains(true) then erasedParams.count(_ == false) + rest else params.length + rest + val defn.FunctionOf(mt: MethodType) = tp: @unchecked + mt.nonErasedParamCount + contextParamCount(mt.resType, crCount - 1) def normalParamCount(tp: Type): Int = tp.widenExpr.stripPoly match - case mt @ MethodType(pnames) => - val rest = normalParamCount(mt.resType) - if mt.hasErasedParams then - mt.erasedParams.count(_ == false) + rest - else pnames.length + rest + case mt @ MethodType(pnames) => mt.nonErasedParamCount + normalParamCount(mt.resType) case _ => contextParamCount(tp, contextResultCount(sym)) normalParamCount(sym.info) @@ -103,7 +99,7 @@ object ContextFunctionResults: def recur(tp: Type, n: Int): Type = if n == 0 then tp else tp match - case defn.ContextFunctionType(_, resTpe, _) => recur(resTpe, n - 1) + case defn.FunctionOf(mt) => recur(mt.resType, n - 1) recur(meth.info.finalResultType, depth) /** Should selection `tree` be eliminated since it refers to an `apply` @@ -117,8 +113,8 @@ object ContextFunctionResults: else tree match case Select(qual, name) => if name == nme.apply then - qual.tpe match - case defn.ContextFunctionType(_, _, _) => + qual.tpe.nn.dealias match + case defn.FunctionOf(mt) if mt.isContextualMethod => integrateSelect(qual, n + 1) case _ if defn.isContextFunctionClass(tree.symbol.maybeOwner) => // for TermRefs integrateSelect(qual, n + 1) diff --git a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala index b368e47bf0b3..791d461add7a 100644 --- a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala +++ b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala @@ -326,7 +326,7 @@ object PickleQuotes { defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)), args => val cases = holeContents.zipWithIndex.map { case (splice, idx) => - val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked + val defn.FunctionNOf(argTypes, defn.FunctionNOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked val rhs = { val spliceArgs = argTypes.zipWithIndex.map { (argType, i) => args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType) diff --git a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala index c50eaddd3213..9d757dc9713c 100644 --- a/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala +++ b/compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala @@ -88,7 +88,7 @@ class SpecializeFunctions extends MiniPhase { // Need to cast to regular function, since specialized apply methods // are not members of ContextFunction0. The cast will be eliminated in // erasure. - qual.cast(defn.FunctionOf(Nil, res)) + qual.cast(defn.FunctionNOf(Nil, res)) case _ => qual qual1.select(specializedApply) diff --git a/compiler/src/dotty/tools/dotc/transform/Splicing.scala b/compiler/src/dotty/tools/dotc/transform/Splicing.scala index 51cb716e47ca..dd95d5a9ca1e 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicing.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicing.scala @@ -197,7 +197,7 @@ class Splicing extends MacroTransform: if tree.isTerm then if isCaptured(tree.symbol) then val tpe = tree.tpe.widenTermRefExpr match { - case tpw: MethodicType => tpw.toFunctionType(isJava = false) + case tpw: MethodicType => tpw.toFunctionType() case tpw => tpw } spliced(tpe)(capturedTerm(tree)) @@ -291,7 +291,7 @@ class Splicing extends MacroTransform: private def capturedTerm(tree: Tree)(using Context): Tree = val tpe = tree.tpe.widenTermRefExpr match - case tpw: MethodicType => tpw.toFunctionType(isJava = false) + case tpw: MethodicType => tpw.toFunctionType() case tpw => tpw capturedTerm(tree, tpe) diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index f84f628fc981..dd32dde93f95 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -749,9 +749,9 @@ object TreeChecker { if isTerm then defn.QuotedExprClass.typeRef.appliedTo(tree1.typeOpt) else defn.QuotedTypeClass.typeRef.appliedTo(tree1.typeOpt) val contextualResult = - defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true) + defn.FunctionNOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true) val expectedContentType = - defn.FunctionOf(argQuotedTypes, contextualResult) + defn.FunctionNOf(argQuotedTypes, contextualResult) assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}") tree1 diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 921e3ca86fe4..dd0fefb5bd5a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1724,7 +1724,7 @@ trait Applications extends Compatibility { def apply(t: Type) = t match { case t @ AppliedType(tycon, args) => def mapArg(arg: Type, tparam: TypeParamInfo) = - if (variance > 0 && tparam.paramVarianceSign < 0) defn.FunctionOf(arg :: Nil, defn.UnitType) + if (variance > 0 && tparam.paramVarianceSign < 0) defn.FunctionNOf(arg :: Nil, defn.UnitType) else arg mapOver(t.derivedAppliedType(tycon, args.zipWithConserve(tycon.typeParams)(mapArg))) case _ => mapOver(t) @@ -1951,7 +1951,7 @@ trait Applications extends Compatibility { /** The shape of given tree as a type; cannot handle named arguments. */ def typeShape(tree: untpd.Tree): Type = tree match { case untpd.Function(args, body) => - defn.FunctionOf( + defn.FunctionNOf( args.map(Function.const(defn.AnyType)), typeShape(body), isContextual = untpd.isContextualClosure(tree)) case Match(EmptyTree, _) => @@ -1991,8 +1991,8 @@ trait Applications extends Compatibility { def paramCount(ref: TermRef) = val formals = ref.widen.firstParamTypes if formals.length > idx then - formals(idx) match - case defn.FunctionOf(args, _, _) => args.length + formals(idx).dealias match + case defn.FunctionNOf(args, _, _) => args.length case _ => -1 else -1 @@ -2077,8 +2077,8 @@ trait Applications extends Compatibility { else resolveMapped(alts1, _.widen.appliedTo(targs1.tpes), pt1) case pt => - val compat0 = pt match - case defn.FunctionOf(args, resType, _) => + val compat0 = pt.dealias match + case defn.FunctionNOf(args, resType, _) => narrowByTypes(alts, args, resType) case _ => Nil @@ -2243,7 +2243,7 @@ trait Applications extends Compatibility { val formalsForArg: List[Type] = altFormals.map(_.head) def argTypesOfFormal(formal: Type): List[Type] = formal.dealias match { - case defn.FunctionOf(args, result, isImplicit) => args + case defn.FunctionOf(mt: MethodType) if !mt.isResultDependent => mt.paramInfos // TODO handle result-dependent functions? case defn.PartialFunctionOf(arg, result) => arg :: Nil case _ => Nil } @@ -2266,7 +2266,7 @@ trait Applications extends Compatibility { false val commonFormal = if (isPartial) defn.PartialFunctionOf(commonParamTypes.head, WildcardType) - else defn.FunctionOf(commonParamTypes, WildcardType, isContextual = untpd.isContextualClosure(arg)) + else defn.FunctionNOf(commonParamTypes, WildcardType, isContextual = untpd.isContextualClosure(arg)) overload.println(i"pretype arg $arg with expected type $commonFormal") if (commonParamTypes.forall(isFullyDefined(_, ForceDegree.flipBottom))) withMode(Mode.ImplicitsEnabled) { diff --git a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala index 25cbfdfec600..044601c5471f 100644 --- a/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala +++ b/compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala @@ -166,8 +166,9 @@ object ErrorReporting { val normTp = normalize(tree.tpe, pt) val normPt = normalize(pt, pt) - def contextFunctionCount(tp: Type): Int = tp.stripped match - case defn.ContextFunctionType(_, restp, _) => 1 + contextFunctionCount(restp) + def contextFunctionCount(tp: Type): Int = tp.stripped.dealias match + // TODO handle result-dependent functions? + case defn.FunctionOf(mt) if mt.isContextualMethod && !mt.isResultDependent => 1 + contextFunctionCount(mt.resType) case _ => 0 def strippedTpCount = contextFunctionCount(tree.tpe) - contextFunctionCount(normTp) def strippedPtCount = contextFunctionCount(pt) - contextFunctionCount(normPt) diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index b43240a1fbb1..fc23620253a6 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1892,10 +1892,12 @@ class Namer { typer: Typer => def expectedDefaultArgType = val originalTp = defaultParamType val approxTp = wildApprox(originalTp) - approxTp.stripPoly match - case atp @ defn.ContextFunctionType(_, resType, _) - if !defn.isNonRefinedFunction(atp) // in this case `resType` is lying, gives us only the non-dependent upper bound - || resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) => + approxTp.dealias.stripPoly match + case defn.FunctionOf(mt) + if mt.isContextualMethod && ( + mt.isResultDependent || // in this case `resType` is lying, gives us only the non-dependent upper bound + mt.resType.existsPart(_.isInstanceOf[WildcardType], StopAt.Static, forceLazy = false) + ) => originalTp case _ => approxTp diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 7303124b0cd4..2227f5f89411 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -383,9 +383,9 @@ object ProtoTypes { def allArgTypesAreCurrent()(using Context): Boolean = state.typedArg.size == args.length - private def isUndefined(tp: Type): Boolean = tp match { + private def isUndefined(tp: Type): Boolean = tp.dealias match { case _: WildcardType => true - case defn.FunctionOf(args, result, _) => args.exists(isUndefined) || isUndefined(result) + case defn.FunctionNOf(args, result, _) => args.exists(isUndefined) || isUndefined(result) case _ => false } @@ -424,7 +424,7 @@ object ProtoTypes { case ValDef(_, tpt, _) if !tpt.isEmpty => typer.typedType(tpt).typeOpt case _ => WildcardType } - targ = arg.withType(defn.FunctionOf(paramTypes, WildcardType)) + targ = arg.withType(defn.FunctionNOf(paramTypes, WildcardType)) case Some(_) if !force => targ = arg.withType(WildcardType) case _ => @@ -845,9 +845,9 @@ object ProtoTypes { tp case pt: ApplyingProto => if (rt eq mt.resultType) tp - else mt.derivedLambdaType(mt.paramNames, mt.paramInfos, rt) + else mt.derivedLambdaType(resType = rt) case _ => - val ft = defn.FunctionOf(mt.paramInfos, rt) + val ft = mt.derivedLambdaType(resType = rt).toFunctionType() if mt.paramInfos.nonEmpty || (ft frozen_<:< pt) then ft else rt } } diff --git a/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala b/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala index 28afccd1ca43..a172eb290f7a 100644 --- a/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala +++ b/compiler/src/dotty/tools/dotc/typer/QuotesAndSplices.scala @@ -122,7 +122,7 @@ trait QuotesAndSplices { for arg <- typedArgs if arg.symbol.is(Mutable) do // TODO support these patterns. Possibly using scala.quoted.util.Var report.error("References to `var`s cannot be used in higher-order pattern", arg.srcPos) val argTypes = typedArgs.map(_.tpe.widenTermRefExpr) - val patType = if tree.args.isEmpty then pt else defn.FunctionOf(argTypes, pt) + val patType = if tree.args.isEmpty then pt else defn.FunctionNOf(argTypes, pt) val pat = typedPattern(tree.body, defn.QuotedExprClass.typeRef.appliedTo(patType))(using quotePatternSpliceContext) val baseType = pat.tpe.baseType(defn.QuotedExprClass) val argType = if baseType.exists then baseType.argTypesHi.head else defn.NothingType @@ -148,7 +148,7 @@ trait QuotesAndSplices { if isInBraces then // ${x}(...) match an application val typedArgs = args.map(arg => typedExpr(arg)) val argTypes = typedArgs.map(_.tpe.widenTermRefExpr) - val splice1 = typedSplicePattern(splice, defn.FunctionOf(argTypes, pt)) + val splice1 = typedSplicePattern(splice, defn.FunctionNOf(argTypes, pt)) untpd.cpy.Apply(tree)(splice1.select(nme.apply), typedArgs).withType(pt) else // $x(...) higher-order quasipattern if args.isEmpty then diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index cbb13a841946..c15a6da0b701 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -105,7 +105,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): case AppliedType(_, funArgs @ fun :: tupled :: Nil) => def functionTypeEqual(baseFun: Type, actualArgs: List[Type], actualRet: Type, expected: Type) = - expected =:= defn.FunctionOf(actualArgs, actualRet, + expected =:= defn.FunctionNOf(actualArgs, actualRet, defn.isContextFunctionType(baseFun)) val arity: Int = if defn.isFunctionNType(fun) then diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 09df6614d496..53fad96594aa 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1318,21 +1318,21 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer em"""Implementation restriction: Expected result type $pt1 |is a curried dependent context function type. Such types are not yet supported.""", pos) + def fallbackProto = (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) pt1 match { case tp: TypeParamRef => decomposeProtoFunction(ctx.typerState.constraint.entry(tp).bounds.hi, defaultArity, pos) case _ => pt1.findFunctionType match { - case pt1 if defn.isNonRefinedFunction(pt1) => - // if expected parameter type(s) are wildcards, approximate from below. - // if expected result type is a wildcard, approximate from above. - // this can type the greatest set of admissible closures. - - (pt1.argInfos.init, typeTree(interpolateWildcards(pt1.argInfos.last.hiBound))) - case RefinedType(parent, nme.apply, mt @ MethodTpe(_, formals, restpe)) - if defn.isNonRefinedFunction(parent) && formals.length == defaultArity => - (formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))) - case defn.PolyFunctionOf(mt @ MethodTpe(_, formals, restpe)) if formals.length == defaultArity => - (formals, untpd.InLambdaTypeTree(isResult = true, (_, syms) => restpe.substParams(mt, syms.map(_.termRef)))) + case ft @ defn.FunctionOf(mt: MethodType) => + if !mt.isResultDependent then + // if expected parameter type(s) are wildcards, approximate from below. + // if expected result type is a wildcard, approximate from above. + // this can type the greatest set of admissible closures. + (mt.paramInfos, typeTree(interpolateWildcards(mt.resType.hiBound))) + else if mt.paramInfos.length == defaultArity then + (mt.paramInfos, untpd.InLambdaTypeTree(isResult = true, (_, syms) => mt.resType.substParams(mt, syms.map(_.termRef)))) + else + fallbackProto case SAMType(mt @ MethodTpe(_, formals, _), samParent) => val restpe = mt.resultType match case mt: MethodType => mt.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined)) @@ -1343,7 +1343,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else typeTree(restpe)) case _ => - (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree()) + fallbackProto } } } @@ -1648,10 +1648,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer def typedPolyFunctionValue(tree: untpd.PolyFunction, pt: Type)(using Context): Tree = val untpd.PolyFunction(tparams: List[untpd.TypeDef] @unchecked, fun) = tree: @unchecked val untpd.Function(vparams: List[untpd.ValDef] @unchecked, body) = fun: @unchecked + val dpt = pt.dealias // If the expected type is a polymorphic function with the same number of // type and value parameters, then infer the types of value parameters from the expected type. - val inferredVParams = pt match + val inferredVParams = dpt match case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) if tparams.lengthCompare(poly.paramNames) == 0 && vparams.lengthCompare(mt.paramNames) == 0 => vparams.zipWithConserve(mt.paramInfos): (vparam, formal) => @@ -1667,7 +1668,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer case _ => vparams - val resultTpt = pt.dealias match + val resultTpt = dpt match case defn.PolyFunctionOf(poly @ PolyType(_, mt: MethodType)) => untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) => mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef))) @@ -3199,7 +3200,8 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer tree protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = { - val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked + val defn.FunctionOf(mt: MethodType) = pt.dropDependentRefinement: @unchecked + val formals = mt.paramInfos // The getter of default parameters may reach here. // Given the code below @@ -3227,12 +3229,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer else formals.map(untpd.TypeTree) } - val erasedParams = pt match { - case defn.PolyFunctionOf(mt: MethodType) => mt.erasedParams - case _ => paramTypes.map(_ => false) - } - - val ifun = desugar.makeContextualFunction(paramTypes, tree, erasedParams) + val ifun = desugar.makeContextualFunction(paramTypes, tree, mt.erasedParams) typr.println(i"make contextual function $tree / $pt ---> $ifun") typedFunctionValue(ifun, pt) } diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index bc4f84c147c8..e8c6e1125cad 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -1814,9 +1814,9 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler case PolyType(_, _, mt1) => mt1.hasErasedParams case _ => false def isDependentFunctionType: Boolean = - val tpNoRefinement = self.dropDependentRefinement - tpNoRefinement != self - && dotc.core.Symbols.defn.isNonRefinedFunction(tpNoRefinement) + self match + case dotc.core.Symbols.defn.FunctionOf(mt) => mt.isResultDependent + case _ => false def isTupleN: Boolean = dotc.core.Symbols.defn.isTupleNType(self) def select(sym: Symbol): TypeRepr = self.select(sym) diff --git a/tests/neg-custom-args/captures/byname.check b/tests/neg-custom-args/captures/byname.check index b1d8fb3b5404..297254a9b635 100644 --- a/tests/neg-custom-args/captures/byname.check +++ b/tests/neg-custom-args/captures/byname.check @@ -2,7 +2,7 @@ 10 | h(f2()) // error | ^^^^ | Found: (x$0: Int) ->{cap1} Int - | Required: (x$0: Int) ->{cap2} Int + | Required: Int ->{cap2} Int | | longer explanation available when compiling with `-explain` -- [E007] Type Mismatch Error: tests/neg-custom-args/captures/byname.scala:19:5 ---------------------------------------- diff --git a/tests/pos-macros/erasedArgs/Macro_1.scala b/tests/pos-macros/erasedArgs/Macro_1.scala new file mode 100644 index 000000000000..08706d6110b9 --- /dev/null +++ b/tests/pos-macros/erasedArgs/Macro_1.scala @@ -0,0 +1,7 @@ +import scala.quoted._ +import scala.language.experimental.erasedDefinitions + +transparent inline def mcr: Any = ${ mcrImpl(1, 2d, "abc") } + +def mcrImpl(x: Int, erased y: Double, z: String)(using Quotes): Expr[String] = + Expr(x.toString() + z) diff --git a/tests/pos-macros/erasedArgs/Test_2.scala b/tests/pos-macros/erasedArgs/Test_2.scala new file mode 100644 index 000000000000..19f0364d3f71 --- /dev/null +++ b/tests/pos-macros/erasedArgs/Test_2.scala @@ -0,0 +1 @@ +def test: "1abc" = mcr