Skip to content

Commit

Permalink
synthesize mirrors for small generic tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
bishabosha committed May 20, 2022
1 parent ee9cc8f commit 8f55301
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ object SyntheticMembers {

/** Attachment recording that an anonymous class should extend Mirror.Sum */
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey

/** Attachment recording that an anonymous class should extend Mirror.Sum */
val GenericTupleArity: Property.StickyKey[Int] = new Property.StickyKey
}

/** Synthetic method implementations for case classes, case objects,
Expand Down Expand Up @@ -601,7 +604,11 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
makeSingletonMirror()
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
makeProductMirror(monoType.typeRef.dealias.classSymbol)
val tupleArity = impl.removeAttachment(GenericTupleArity)
val cls = tupleArity match
case Some(n) => defn.TupleType(n).nn.classSymbol
case _ => monoType.typeRef.dealias.classSymbol
makeProductMirror(cls)
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
makeSumMirror(monoType.typeRef.dealias.classSymbol)

Expand Down
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/TypeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ object TypeUtils {
if (arity < 0) arity else arity + 1
case self: SingletonType =>
if self.termSymbol == defn.EmptyTupleModule then 0 else -1
case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass =>
0
case self if defn.isTupleClass(self.classSymbol) =>
self.dealias.argInfos.length
case _ =>
Expand All @@ -69,12 +71,14 @@ object TypeUtils {
case AppliedType(tycon, hd :: tl :: Nil) if tycon.isRef(defn.PairClass) =>
hd :: tl.tupleElementTypes
case self: SingletonType =>
assert(self.termSymbol == defn.EmptyTupleModule, "not a tuple")
assert(self.termSymbol == defn.EmptyTupleModule, i"not a tuple `$self`")
Nil
case self: TypeRef if self.classSymbol == defn.EmptyTupleModule.moduleClass =>
Nil
case self if defn.isTupleClass(self.classSymbol) =>
self.dealias.argInfos
case _ =>
throw new AssertionError("not a tuple")
case tp =>
throw new AssertionError(i"not a tuple `$tp`")
}

/** The `*:` equivalent of an instance of a Tuple class */
Expand Down
58 changes: 37 additions & 21 deletions compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
/** Handlers to synthesize implicits for special types */
type SpecialHandler = (Type, Span) => Context ?=> TreeWithErrors
private type SpecialHandlers = List[(ClassSymbol, SpecialHandler)]

val synthesizedClassTag: SpecialHandler = (formal, span) =>
formal.argInfos match
case arg :: Nil =>
Expand Down Expand Up @@ -223,16 +223,19 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
/** Create an anonymous class `new Object { type MirroredMonoType = ... }`
* and mark it with given attachment so that it is made into a mirror at PostTyper.
*/
private def anonymousMirror(monoType: Type, attachment: Property.StickyKey[Unit], span: Span)(using Context) =
private def anonymousMirror(monoType: Type, attachment: Property.StickyKey[Unit], tupleArity: Option[Int], span: Span)(using Context) =
if ctx.isAfterTyper then ctx.compilationUnit.needsMirrorSupport = true
val monoTypeDef = untpd.TypeDef(tpnme.MirroredMonoType, untpd.TypeTree(monoType))
val newImpl = untpd.Template(
var newImpl = untpd.Template(
constr = untpd.emptyConstructor,
parents = untpd.TypeTree(defn.ObjectType) :: Nil,
derived = Nil,
self = EmptyValDef,
body = monoTypeDef :: Nil
).withAttachment(attachment, ())
tupleArity.foreach { n =>
newImpl = newImpl.withAttachment(GenericTupleArity, n)
}
typer.typed(untpd.New(newImpl).withSpan(span))

/** The mirror type
Expand Down Expand Up @@ -278,9 +281,18 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):

private def productMirror(mirroredType: Type, formal: Type, span: Span)(using Context): TreeWithErrors =

var isSafeGenericTuple = Option.empty[List[Type]]

def illegalGenericTuple(tp: AppliedType): Boolean =
val tupleArgs = tp.tupleElementTypes
val isTooLarge = tupleArgs.length > Definitions.MaxTupleArity
isSafeGenericTuple = Option.when(!isTooLarge)(tupleArgs)
isTooLarge

/** do all parts match the class symbol? */
def acceptable(tp: Type, cls: Symbol): Boolean = tp match
case tp: HKTypeLambda if tp.resultType.isInstanceOf[HKTypeLambda] => false
case tp @ AppliedType(cons: TypeRef, _) if cons.isRef(defn.PairClass) && illegalGenericTuple(tp) => false
case tp: TypeProxy => acceptable(tp.underlying, cls)
case OrType(tp1, tp2) => acceptable(tp1, cls) && acceptable(tp2, cls)
case _ => tp.classSymbol eq cls
Expand All @@ -302,9 +314,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
cls.is(Scala2x) || cls.linkedClass.is(Case)

def makeProductMirror(cls: Symbol): TreeWithErrors =
val accessors = cls.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val mirroredClass = isSafeGenericTuple.fold(cls)(tps => defn.TupleType(tps.size).nn.classSymbol)
val accessors = mirroredClass.caseAccessors.filterNot(_.isAllOf(PrivateLocal))
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
val nestedPairs = TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
val nestedPairs = isSafeGenericTuple.map(TypeOps.nestedPairs).getOrElse {
TypeOps.nestedPairs(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
}
val (monoType, elemsType) = mirroredType match
case mirroredType: HKTypeLambda =>
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
Expand All @@ -318,18 +333,20 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
val mirrorRef =
if (genAnonyousMirror(cls)) anonymousMirror(monoType, ExtendsProductMirror, span)
if genAnonyousMirror(mirroredClass) then
anonymousMirror(monoType, ExtendsProductMirror, isSafeGenericTuple.map(_.size), span)
else companionPath(mirroredType, span)
withNoErrors(mirrorRef.cast(mirrorType))
end makeProductMirror

def getError(cls: Symbol): String =
val reason = if !cls.isGenericProduct then
i"because ${cls.whyNotGenericProduct}"
else if !canAccessCtor(cls) then
i"because the constructor of $cls is innaccessible from the calling scope."
else
""
def getError(cls: Symbol): String =
val reason =
if !cls.isGenericProduct then
i"because ${cls.whyNotGenericProduct}"
else if !canAccessCtor(cls) then
i"because the constructor of $cls is innaccessible from the calling scope."
else
""
i"$cls is not a generic product $reason"
end getError

Expand All @@ -350,8 +367,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
else
val cls = mirroredType.classSymbol
if acceptable(mirroredType, cls)
&& cls.isGenericProduct
&& canAccessCtor(cls)
&& isSafeGenericTuple.isDefined || (cls.isGenericProduct && canAccessCtor(cls))
then
makeProductMirror(cls)
else
Expand Down Expand Up @@ -424,11 +440,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
val mirrorRef =
if useCompanion then companionPath(mirroredType, span)
else anonymousMirror(monoType, ExtendsSumMirror, span)
else anonymousMirror(monoType, ExtendsSumMirror, tupleArity = None, span)
withNoErrors(mirrorRef.cast(mirrorType))
else if !clsIsGenericSum then
(EmptyTree, List(i"$cls is not a generic sum because ${cls.whyNotGenericSum(declScope)}"))
else
else
EmptyTreeNoError
end sumMirror

Expand Down Expand Up @@ -595,7 +611,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
tp.baseType(cls)
val base = baseWithRefinements(formal)
val result =
if (base <:< formal.widenExpr)
if (base <:< formal.widenExpr)
// With the subtype test we enforce that the searched type `formal` is of the right form
handler(base, span)
else EmptyTreeNoError
Expand All @@ -609,19 +625,19 @@ end Synthesizer

object Synthesizer:

/** Tuple used to store the synthesis result with a list of errors. */
/** Tuple used to store the synthesis result with a list of errors. */
type TreeWithErrors = (Tree, List[String])
private def withNoErrors(tree: Tree): TreeWithErrors = (tree, List.empty)

private val EmptyTreeNoError: TreeWithErrors = withNoErrors(EmptyTree)

private def orElse(treeWithErrors1: TreeWithErrors, treeWithErrors2: => TreeWithErrors): TreeWithErrors = treeWithErrors1 match
case (tree, errors) if tree eq genericEmptyTree =>
case (tree, errors) if tree eq genericEmptyTree =>
val (tree2, errors2) = treeWithErrors2
(tree2, errors ::: errors2)
case _ => treeWithErrors1

private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
private def clearErrorsIfNotEmpty(treeWithErrors: TreeWithErrors) = treeWithErrors match
case (tree, _) if tree eq genericEmptyTree => treeWithErrors
case (tree, _) => withNoErrors(tree)

Expand Down
6 changes: 6 additions & 0 deletions tests/neg/i14127.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.deriving.Mirror

val mT23 = summon[Mirror.Of[(
Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
*: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
*: Int *: Int *: Int *: Int *: Int *: EmptyTuple)]] // error
14 changes: 14 additions & 0 deletions tests/run/i14127.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import scala.deriving.Mirror

@main def Test =
val mISB = summon[Mirror.Of[Int *: String *: Boolean *: EmptyTuple]]
assert(mISB.fromProduct((1, "foo", true)) == (1, "foo", true))

val mT22 = summon[Mirror.Of[(
Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
*: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int *: Int
*: Int *: Int *: Int *: Int *: EmptyTuple)]]

// tuple of 22 elements
val t22 = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)
assert(mT22.fromProduct(t22) == t22)
18 changes: 18 additions & 0 deletions tests/run/i7079.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.deriving._

case class Foo(x: Int, y: String)

def toTuple[T <: Product](x: T)(using m: Mirror.ProductOf[T], mt: Mirror.ProductOf[m.MirroredElemTypes]) =
mt.fromProduct(x)

@main def Test = {
val m = summon[Mirror.ProductOf[Foo]]
val mt1 = summon[Mirror.ProductOf[(Int, String)]]
type R = (Int, String)
val mt2 = summon[Mirror.ProductOf[R]]
val mt3 = summon[Mirror.ProductOf[m.MirroredElemTypes]]

val f = Foo(1, "foo")
val g: (Int, String) = toTuple(f)// (using m, mt1)
assert(g == (1, "foo"))
}

0 comments on commit 8f55301

Please sign in to comment.