From 72999389cafdfe64fabb8f3306ce7e3c7d407bfb Mon Sep 17 00:00:00 2001 From: Jamie Thompson Date: Mon, 23 May 2022 11:57:46 +0200 Subject: [PATCH] cleaner way to extract class or tuple proxy --- .../dotty/tools/dotc/core/TypeErasure.scala | 4 +- .../src/dotty/tools/dotc/core/Types.scala | 10 ++ .../tools/dotc/printing/RefinedPrinter.scala | 2 +- .../dotc/transform/GenericSignatures.scala | 2 +- .../tools/dotc/transform/TypeUtils.scala | 8 +- .../dotty/tools/dotc/typer/Synthesizer.scala | 141 ++++++++++++------ .../src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/neg/i14127a.scala | 5 +- 8 files changed, 121 insertions(+), 53 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala index 2caa639592b3..87a829e51519 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeErasure.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeErasure.scala @@ -689,7 +689,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst } private def erasePair(tp: Type)(using Context): Type = { - val arity = tp.tupleArity + // NOTE: `tupleArity` does not consider TypeRef(EmptyTuple$) equivalent to EmptyTuple.type, + // we fix this for printers, but type erasure should be preserved. + val arity = tp.tupleArity() if (arity < 0) defn.ProductClass.typeRef else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn else defn.TupleXXLClass.typeRef diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 3d2a77a05c82..5dfc7a6aa725 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -40,6 +40,7 @@ import scala.annotation.internal.sharable import scala.annotation.threadUnsafe import dotty.tools.dotc.transform.SymUtils._ +import dotty.tools.dotc.transform.TypeUtils.* object Types { @@ -47,6 +48,15 @@ object Types { implicit def eqType: CanEqual[Type, Type] = CanEqual.derived + object GenericTupleType: + def unapply(tp: Type)(using Context): Option[List[Type]] = tp match + case tp @ AppliedType(r: TypeRef, _) if r.isRef(defn.PairClass) && tp.tupleArity(relaxEmptyTuple = true) > 0 => + Some(tp.tupleElementTypes) + case AppliedType(_: TypeRef, args) if defn.isTupleNType(tp) => + Some(args) + case _ => + None + /** Main class representing types. * * The principal subclasses and sub-objects are as follows: diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index 8bdf1d4822ce..1c52b5acb38a 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -218,7 +218,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { val cls = tycon.typeSymbol if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*" else if defn.isFunctionClass(cls) then toTextFunction(args, cls.name.isContextFunction, cls.name.isErasedFunction) - else if tp.tupleArity >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes) + else if tp.tupleArity(relaxEmptyTuple = true) >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes) else if isInfixType(tp) then val l :: r :: Nil = args: @unchecked val opName = tyconName(tycon) diff --git a/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala b/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala index 9a6ab233e239..6a01c9dc64f9 100644 --- a/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala +++ b/compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala @@ -248,7 +248,7 @@ object GenericSignatures { case _ => jsig(elemtp) case RefOrAppliedType(sym, pre, args) => - if (sym == defn.PairClass && tp.tupleArity > Definitions.MaxTupleArity) + if (sym == defn.PairClass && tp.tupleArity() > Definitions.MaxTupleArity) jsig(defn.TupleXXLClass.typeRef) else if (isTypeParameterInSig(sym, sym0)) { assert(!sym.isAliasType, "Unexpected alias type: " + sym) diff --git a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala index f416c2e9a13e..4458ee85f9ee 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeUtils.scala @@ -51,14 +51,16 @@ object TypeUtils { /** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs, * or -1 if this is not a tuple type. + * + * @param relaxEmptyTuple if true then TypeRef(EmptyTuple$) =:= EmptyTuple.type */ - def tupleArity(using Context): Int = self match { + def tupleArity(relaxEmptyTuple: Boolean = false)(using Context): Int = self match { case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) => - val arity = tl.tupleArity + val arity = tl.tupleArity(relaxEmptyTuple) if (arity < 0) arity else arity + 1 case self: SingletonType => if self.termSymbol == defn.EmptyTupleModule then 0 else -1 - case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass => + case self: TypeRef if relaxEmptyTuple && self.classSymbol == defn.EmptyTupleModule.moduleClass => 0 case self if defn.isTupleClass(self.classSymbol) => self.dealias.argInfos.length diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 58b8c17b21e6..88c04582b86a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -279,36 +279,81 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): case t => mapOver(t) monoMap(mirroredType.resultType) - private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors = + private[Synthesizer] enum MirrorSource: + case ClassSymbol(cls: Symbol) + case GenericTuple(tuplArity: Int, tpArgs: List[Type]) + + def isGenericTuple: Boolean = this.isInstanceOf[GenericTuple] + + /** tuple arity, works for TupleN classes and generic tuples */ + final def arity(using Context): Int = this match + case GenericTuple(arity, _) => arity + case ClassSymbol(cls) if defn.isTupleClass(cls) => cls.typeParams.length + case _ => -1 + + def equiv(that: MirrorSource)(using Context): Boolean = (this.arity, that.arity) match + case (n, m) if n > 0 || m > 0 => + // we shortcut when at least one was a tuple. + // This protects us from comparing classes for two TupleXXL with different arities. + n == m + case _ => this.asClass eq that.asClass // class equality otherwise + + def isSub(that: MirrorSource)(using Context): Boolean = (this.arity, that.arity) match + case (n, m) if n > 0 || m > 0 => + // we shortcut when at least one was a tuple. + // This protects us from comparing classes for two TupleXXL with different arities. + n == m + case _ => this.asClass isSubClass that.asClass + + def asClass(using Context): Symbol = this match + case ClassSymbol(cls) => cls + case GenericTuple(arity, _) => + if arity <= Definitions.MaxTupleArity then defn.TupleType(arity).nn.classSymbol + else defn.TupleXXLClass + + object MirrorSource: + def tuple(tps: List[Type]): MirrorSource.GenericTuple = MirrorSource.GenericTuple(tps.size, tps) + + end MirrorSource - var isSafeGenericTuple = Option.empty[(Symbol, List[Type])] + private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors = - /** do all parts match the class symbol? Or can we extract a generic tuple type out? */ - def acceptable(tp: Type, cls: Symbol): Boolean = - var genericTupleParts = List.empty[(Symbol, List[Type])] + extension (msrc: MirrorSource) def isGenericProd(using Context) = + msrc.isGenericTuple || msrc.asClass.isGenericProduct && canAccessCtor(msrc.asClass) - def acceptableGenericTuple(tp: AppliedType): Boolean = - val tupleArgs = tp.tupleElementTypes - val arity = tupleArgs.size - val isOk = arity <= Definitions.MaxTupleArity - if isOk then - genericTupleParts ::= { - val cls = defn.TupleType(arity).nn.classSymbol - (cls, tupleArgs) - } - isOk + /** Follows `classSymbol`, but instead reduces to a proxy of a generic tuple (or a scala.TupleN class). + * + * Does not need to consider AndType, as that is already stripped. + */ + def tupleProxy(tp: Type)(using Context): Option[MirrorSource] = tp match + case tp: TypeRef => if tp.symbol.isClass then None else tupleProxy(tp.superType) + case GenericTupleType(args) => Some(MirrorSource.tuple(args)) + case tp: TypeProxy => + tupleProxy(tp.underlying) + case tp: OrType => + if tp.tp1.hasClassSymbol(defn.NothingClass) then + tupleProxy(tp.tp2) + else if tp.tp2.hasClassSymbol(defn.NothingClass) then + tupleProxy(tp.tp1) + else tupleProxy(tp.join) + case _ => + None - def inner(tp: Type, cls: Symbol): Boolean = tp match - case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false - case tp @ AppliedType(cons: TypeRef, _) if cons.isRef(defn.PairClass) => acceptableGenericTuple(tp) - case tp: TypeProxy => inner(tp.underlying, cls) - case OrType(tp1, tp2) => inner(tp1, cls) && inner(tp2, cls) - case _ => tp.classSymbol eq cls + def mirrorSource(tp: Type)(using Context): Option[MirrorSource] = + val fromClass = tp.classSymbol + if fromClass.exists then // test if it could be reduced to a generic tuple + if fromClass.isSubClass(defn.TupleClass) && !defn.isTupleClass(fromClass) then tupleProxy(tp) + else Some(MirrorSource.ClassSymbol(fromClass)) + else None - val classPartsMatch = inner(tp, cls) - classPartsMatch && genericTupleParts.map((cls, _) => cls).distinct.sizeIs <= 1 && - { isSafeGenericTuple = genericTupleParts.headOption ; true } - end acceptable + /** do all parts match the class symbol? */ + def acceptable(tp: Type, msrc: MirrorSource): Boolean = tp match + case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false + case OrType(tp1, tp2) => acceptable(tp1, msrc) && acceptable(tp2, msrc) + case GenericTupleType(args) if args.size <= Definitions.MaxTupleArity => + MirrorSource.tuple(args).equiv(msrc) + case tp: TypeProxy => acceptable(tp.underlying, msrc) + case _ => mirrorSource(tp).exists(_.equiv(msrc)) /** for a case class, if it will have an anonymous mirror, * check that its constructor can be accessed @@ -326,13 +371,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): def genAnonyousMirror(cls: Symbol): Boolean = cls.is(Scala2x) || cls.linkedClass.is(Case) - def makeProductMirror(cls: Symbol): TreeWithErrors = - val mirroredClass = isSafeGenericTuple.fold(cls)((cls, _) => cls) + def makeProductMirror(msrc: MirrorSource): TreeWithErrors = + val mirroredClass = msrc.asClass val accessors = mirroredClass.caseAccessors.filterNot(_.isAllOf(PrivateLocal)) val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) - val nestedPairs = isSafeGenericTuple.map((_, tps) => TypeOps.nestedPairs(tps)).getOrElse { - TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) - } + val nestedPairs = msrc match + case MirrorSource.GenericTuple(_, args) => TypeOps.nestedPairs(args) + case _ => TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) val (monoType, elemsType) = mirroredType match case mirroredType: HKTypeLambda => (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs)) @@ -342,25 +387,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span) checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span) val mirrorType = - mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal) + mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, mirroredClass.name, formal) .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels)) val mirrorRef = if genAnonyousMirror(mirroredClass) then - anonymousMirror(monoType, ExtendsProductMirror, isSafeGenericTuple.map(_(1).size), span) + val arity = msrc match + case MirrorSource.GenericTuple(arity, _) => Some(arity) + case _ => None + anonymousMirror(monoType, ExtendsProductMirror, arity, span) else companionPath(mirroredType, span) withNoErrors(mirrorRef.cast(mirrorType)) end makeProductMirror - def getError(cls: Symbol): String = + def getError(msrc: MirrorSource): String = val reason = - if !cls.isGenericProduct then - i"because ${cls.whyNotGenericProduct}" - else if !canAccessCtor(cls) then - i"because the constructor of $cls is innaccessible from the calling scope." + if !msrc.isGenericTuple then + if !msrc.asClass.isGenericProduct then + i"because ${msrc.asClass.whyNotGenericProduct}" + else if !canAccessCtor(msrc.asClass) then + i"because the constructor of ${msrc.asClass} is innaccessible from the calling scope." + else "" else "" - i"$cls is not a generic product $reason" + i"${msrc.asClass} is not a generic product $reason" end getError mirroredType match @@ -378,13 +428,14 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal) withNoErrors(modulePath.cast(mirrorType)) else - val cls = mirroredType.classSymbol - if acceptable(mirroredType, cls) - && isSafeGenericTuple.isDefined || (cls.isGenericProduct && canAccessCtor(cls)) - then - makeProductMirror(cls) - else - (EmptyTree, List(getError(cls))) + mirrorSource(mirroredType) match + case Some(msrc) => + if acceptable(mirroredType, msrc) && msrc.isGenericProd then + makeProductMirror(msrc) + else + (EmptyTree, List(getError(msrc))) + case None => + (EmptyTree, List(i"${mirroredType.show} does not reduce to a class or generic tuple type")) end productMirror private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors = diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 34eb2b7df41f..812874651cb3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -2766,7 +2766,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer typed(desugar.smallTuple(tree).withSpan(tree.span), pt) else { val pts = - if (arity == pt.tupleArity) pt.tupleElementTypes + if (arity == pt.tupleArity()) pt.tupleElementTypes else List.fill(arity)(defn.AnyType) val elems = tree.trees.lazyZip(pts).map( if ctx.mode.is(Mode.Type) then typedType(_, _, mapPatternBounds = true) diff --git a/tests/neg/i14127a.scala b/tests/neg/i14127a.scala index c5e1e88f3922..26611eb988a0 100644 --- a/tests/neg/i14127a.scala +++ b/tests/neg/i14127a.scala @@ -2,4 +2,7 @@ import scala.deriving.Mirror // mixing arities is not supported -val mT23 = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: Int *: EmptyTuple)]] // error +val mT2Or2a = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: EmptyTuple)]] // ok, same arity +val mT2Or2b = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | Tuple2[Int, Int]]] // ok, same arity +val mT2Or2c = summon[Mirror.Of[Tuple2[Int, Int] | (Int *: Int *: EmptyTuple)]] // ok, same arity +val mT2Or3 = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: Int *: EmptyTuple)]] // error