Skip to content

Commit

Permalink
cleaner way to extract class or tuple proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed May 23, 2022
1 parent 0e09603 commit 7299938
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 53 deletions.
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeErasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,9 @@ class TypeErasure(sourceLanguage: SourceLanguage, semiEraseVCs: Boolean, isConst
}

private def erasePair(tp: Type)(using Context): Type = {
val arity = tp.tupleArity
// NOTE: `tupleArity` does not consider TypeRef(EmptyTuple$) equivalent to EmptyTuple.type,
// we fix this for printers, but type erasure should be preserved.
val arity = tp.tupleArity()
if (arity < 0) defn.ProductClass.typeRef
else if (arity <= Definitions.MaxTupleArity) defn.TupleType(arity).nn
else defn.TupleXXLClass.typeRef
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,23 @@ import scala.annotation.internal.sharable
import scala.annotation.threadUnsafe

import dotty.tools.dotc.transform.SymUtils._
import dotty.tools.dotc.transform.TypeUtils.*

object Types {

@sharable private var nextId = 0

implicit def eqType: CanEqual[Type, Type] = CanEqual.derived

object GenericTupleType:
def unapply(tp: Type)(using Context): Option[List[Type]] = tp match
case tp @ AppliedType(r: TypeRef, _) if r.isRef(defn.PairClass) && tp.tupleArity(relaxEmptyTuple = true) > 0 =>
Some(tp.tupleElementTypes)
case AppliedType(_: TypeRef, args) if defn.isTupleNType(tp) =>
Some(args)
case _ =>
None

/** Main class representing types.
*
* The principal subclasses and sub-objects are as follows:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
val cls = tycon.typeSymbol
if tycon.isRepeatedParam then toTextLocal(args.head) ~ "*"
else if defn.isFunctionClass(cls) then toTextFunction(args, cls.name.isContextFunction, cls.name.isErasedFunction)
else if tp.tupleArity >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
else if tp.tupleArity(relaxEmptyTuple = true) >= 2 && !printDebug then toTextTuple(tp.tupleElementTypes)
else if isInfixType(tp) then
val l :: r :: Nil = args: @unchecked
val opName = tyconName(tycon)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ object GenericSignatures {
case _ => jsig(elemtp)

case RefOrAppliedType(sym, pre, args) =>
if (sym == defn.PairClass && tp.tupleArity > Definitions.MaxTupleArity)
if (sym == defn.PairClass && tp.tupleArity() > Definitions.MaxTupleArity)
jsig(defn.TupleXXLClass.typeRef)
else if (isTypeParameterInSig(sym, sym0)) {
assert(!sym.isAliasType, "Unexpected alias type: " + sym)
Expand Down
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ object TypeUtils {

/** The arity of this tuple type, which can be made up of EmptyTuple, TupleX and `*:` pairs,
* or -1 if this is not a tuple type.
*
* @param relaxEmptyTuple if true then TypeRef(EmptyTuple$) =:= EmptyTuple.type
*/
def tupleArity(using Context): Int = self match {
def tupleArity(relaxEmptyTuple: Boolean = false)(using Context): Int = self match {
case AppliedType(tycon, _ :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
val arity = tl.tupleArity
val arity = tl.tupleArity(relaxEmptyTuple)
if (arity < 0) arity else arity + 1
case self: SingletonType =>
if self.termSymbol == defn.EmptyTupleModule then 0 else -1
case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass =>
case self: TypeRef if relaxEmptyTuple && self.classSymbol == defn.EmptyTupleModule.moduleClass =>
0
case self if defn.isTupleClass(self.classSymbol) =>
self.dealias.argInfos.length
Expand Down
141 changes: 96 additions & 45 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,36 +279,81 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
case t => mapOver(t)
monoMap(mirroredType.resultType)

private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =
private[Synthesizer] enum MirrorSource:
case ClassSymbol(cls: Symbol)
case GenericTuple(tuplArity: Int, tpArgs: List[Type])

def isGenericTuple: Boolean = this.isInstanceOf[GenericTuple]

/** tuple arity, works for TupleN classes and generic tuples */
final def arity(using Context): Int = this match
case GenericTuple(arity, _) => arity
case ClassSymbol(cls) if defn.isTupleClass(cls) => cls.typeParams.length
case _ => -1

def equiv(that: MirrorSource)(using Context): Boolean = (this.arity, that.arity) match
case (n, m) if n > 0 || m > 0 =>
// we shortcut when at least one was a tuple.
// This protects us from comparing classes for two TupleXXL with different arities.
n == m
case _ => this.asClass eq that.asClass // class equality otherwise

def isSub(that: MirrorSource)(using Context): Boolean = (this.arity, that.arity) match
case (n, m) if n > 0 || m > 0 =>
// we shortcut when at least one was a tuple.
// This protects us from comparing classes for two TupleXXL with different arities.
n == m
case _ => this.asClass isSubClass that.asClass

def asClass(using Context): Symbol = this match
case ClassSymbol(cls) => cls
case GenericTuple(arity, _) =>
if arity <= Definitions.MaxTupleArity then defn.TupleType(arity).nn.classSymbol
else defn.TupleXXLClass

object MirrorSource:
def tuple(tps: List[Type]): MirrorSource.GenericTuple = MirrorSource.GenericTuple(tps.size, tps)

end MirrorSource

var isSafeGenericTuple = Option.empty[(Symbol, List[Type])]
private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =

/** do all parts match the class symbol? Or can we extract a generic tuple type out? */
def acceptable(tp: Type, cls: Symbol): Boolean =
var genericTupleParts = List.empty[(Symbol, List[Type])]
extension (msrc: MirrorSource) def isGenericProd(using Context) =
msrc.isGenericTuple || msrc.asClass.isGenericProduct && canAccessCtor(msrc.asClass)

def acceptableGenericTuple(tp: AppliedType): Boolean =
val tupleArgs = tp.tupleElementTypes
val arity = tupleArgs.size
val isOk = arity <= Definitions.MaxTupleArity
if isOk then
genericTupleParts ::= {
val cls = defn.TupleType(arity).nn.classSymbol
(cls, tupleArgs)
}
isOk
/** Follows `classSymbol`, but instead reduces to a proxy of a generic tuple (or a scala.TupleN class).
*
* Does not need to consider AndType, as that is already stripped.
*/
def tupleProxy(tp: Type)(using Context): Option[MirrorSource] = tp match
case tp: TypeRef => if tp.symbol.isClass then None else tupleProxy(tp.superType)
case GenericTupleType(args) => Some(MirrorSource.tuple(args))
case tp: TypeProxy =>
tupleProxy(tp.underlying)
case tp: OrType =>
if tp.tp1.hasClassSymbol(defn.NothingClass) then
tupleProxy(tp.tp2)
else if tp.tp2.hasClassSymbol(defn.NothingClass) then
tupleProxy(tp.tp1)
else tupleProxy(tp.join)
case _ =>
None

def inner(tp: Type, cls: Symbol): Boolean = tp match
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false
case tp @ AppliedType(cons: TypeRef, _) if cons.isRef(defn.PairClass) => acceptableGenericTuple(tp)
case tp: TypeProxy => inner(tp.underlying, cls)
case OrType(tp1, tp2) => inner(tp1, cls) && inner(tp2, cls)
case _ => tp.classSymbol eq cls
def mirrorSource(tp: Type)(using Context): Option[MirrorSource] =
val fromClass = tp.classSymbol
if fromClass.exists then // test if it could be reduced to a generic tuple
if fromClass.isSubClass(defn.TupleClass) && !defn.isTupleClass(fromClass) then tupleProxy(tp)
else Some(MirrorSource.ClassSymbol(fromClass))
else None

val classPartsMatch = inner(tp, cls)
classPartsMatch && genericTupleParts.map((cls, _) => cls).distinct.sizeIs <= 1 &&
{ isSafeGenericTuple = genericTupleParts.headOption ; true }
end acceptable
/** do all parts match the class symbol? */
def acceptable(tp: Type, msrc: MirrorSource): Boolean = tp match
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false
case OrType(tp1, tp2) => acceptable(tp1, msrc) && acceptable(tp2, msrc)
case GenericTupleType(args) if args.size <= Definitions.MaxTupleArity =>
MirrorSource.tuple(args).equiv(msrc)
case tp: TypeProxy => acceptable(tp.underlying, msrc)
case _ => mirrorSource(tp).exists(_.equiv(msrc))

/** for a case class, if it will have an anonymous mirror,
* check that its constructor can be accessed
Expand All @@ -326,13 +371,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
def genAnonyousMirror(cls: Symbol): Boolean =
cls.is(Scala2x) || cls.linkedClass.is(Case)

def makeProductMirror(cls: Symbol): TreeWithErrors =
val mirroredClass = isSafeGenericTuple.fold(cls)((cls, _) => cls)
def makeProductMirror(msrc: MirrorSource): TreeWithErrors =
val mirroredClass = msrc.asClass
val accessors = mirroredClass.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = isSafeGenericTuple.map((_, tps) => TypeOps.nestedPairs(tps)).getOrElse {
TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
}
val nestedPairs = msrc match
case MirrorSource.GenericTuple(_, args) => TypeOps.nestedPairs(args)
case _ => TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
Expand All @@ -342,25 +387,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
val mirrorType =
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name, formal)
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, mirroredClass.name, formal)
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if genAnonyousMirror(mirroredClass) then
anonymousMirror(monoType, ExtendsProductMirror, isSafeGenericTuple.map(_(1).size), span)
val arity = msrc match
case MirrorSource.GenericTuple(arity, _) => Some(arity)
case _ => None
anonymousMirror(monoType, ExtendsProductMirror, arity, span)
else companionPath(mirroredType, span)
withNoErrors(mirrorRef.cast(mirrorType))
end makeProductMirror

def getError(cls: Symbol): String =
def getError(msrc: MirrorSource): String =
val reason =
if !cls.isGenericProduct then
i"because ${cls.whyNotGenericProduct}"
else if !canAccessCtor(cls) then
i"because the constructor of $cls is innaccessible from the calling scope."
if !msrc.isGenericTuple then
if !msrc.asClass.isGenericProduct then
i"because ${msrc.asClass.whyNotGenericProduct}"
else if !canAccessCtor(msrc.asClass) then
i"because the constructor of ${msrc.asClass} is innaccessible from the calling scope."
else ""
else
""
i"$cls is not a generic product $reason"
i"${msrc.asClass} is not a generic product $reason"
end getError

mirroredType match
Expand All @@ -378,13 +428,14 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
val mirrorType = mirrorCore(defn.Mirror_SingletonClass, mirroredType, mirroredType, module.name, formal)
withNoErrors(modulePath.cast(mirrorType))
else
val cls = mirroredType.classSymbol
if acceptable(mirroredType, cls)
&& isSafeGenericTuple.isDefined || (cls.isGenericProduct && canAccessCtor(cls))
then
makeProductMirror(cls)
else
(EmptyTree, List(getError(cls)))
mirrorSource(mirroredType) match
case Some(msrc) =>
if acceptable(mirroredType, msrc) && msrc.isGenericProd then
makeProductMirror(msrc)
else
(EmptyTree, List(getError(msrc)))
case None =>
(EmptyTree, List(i"${mirroredType.show} does not reduce to a class or generic tuple type"))
end productMirror

private def sumMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2766,7 +2766,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
typed(desugar.smallTuple(tree).withSpan(tree.span), pt)
else {
val pts =
if (arity == pt.tupleArity) pt.tupleElementTypes
if (arity == pt.tupleArity()) pt.tupleElementTypes
else List.fill(arity)(defn.AnyType)
val elems = tree.trees.lazyZip(pts).map(
if ctx.mode.is(Mode.Type) then typedType(_, _, mapPatternBounds = true)
Expand Down
5 changes: 4 additions & 1 deletion tests/neg/i14127a.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ import scala.deriving.Mirror

// mixing arities is not supported

val mT23 = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: Int *: EmptyTuple)]] // error
val mT2Or2a = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: EmptyTuple)]] // ok, same arity
val mT2Or2b = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | Tuple2[Int, Int]]] // ok, same arity
val mT2Or2c = summon[Mirror.Of[Tuple2[Int, Int] | (Int *: Int *: EmptyTuple)]] // ok, same arity
val mT2Or3 = summon[Mirror.Of[(Int *: Int *: EmptyTuple) | (Int *: Int *: Int *: EmptyTuple)]] // error

0 comments on commit 7299938

Please sign in to comment.