Skip to content

Commit

Permalink
Move internal classes that are dependent on q inside a container.
Browse files Browse the repository at this point in the history
Scala 3.3.0 does not allow the `extends` clause of a class to
refer to constructor parameters. It was shown to be unsound.
See scala/scala3#16270

However, we can depend on paths that enclose the class definition.
So we introduce a container class `TreeMaps` that takes the `q: Q`
as parameter, and move the classes that extend `q.reflect.TreeMap`
inside that container.
  • Loading branch information
sjrd committed Aug 9, 2023
1 parent 027d078 commit e33ac5c
Showing 1 changed file with 80 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,105 +245,101 @@ object InlineHKDGeneric:
arr
}

private class ProductElementIdExactExpander[Q <: Quotes, Fields <: Tuple: Type](using val q: Q)
extends q.reflect.TreeMap {
private class TreeMaps[Q <: Quotes](using val q: Q) {
import q.reflect.*

override def transformTerm(tree: Term)(owner: Symbol): Term =
try {
tree.asExpr match {
case '{ InlineHKDGeneric.productElementIdExact[a2, elemTop]($a, $idx) } =>
transformExact(a, idx)
case _ =>
super.transformTerm(tree)(owner)
class ProductElementIdExactExpander[Fields <: Tuple: Type] extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term =
try {
tree.asExpr match {
case '{ InlineHKDGeneric.productElementIdExact[a2, elemTop]($a, $idx) } =>
transformExact(a, idx)
case _ =>
super.transformTerm(tree)(owner)
}
} catch {
case e: Exception =>
// Tried to convert partially applied type to Expr. Ignoring it
tree
}
} catch {
case e: Exception =>
// Tried to convert partially applied type to Expr. Ignoring it
tree
}

def transformExact[A](a: Expr[A], idxExpr: Expr[Int]): Term = {
def findConstantIdx(tpe: TypeRepr): Option[Int] = tpe match {
case AndType(a, b) => findConstantIdx(a).orElse(findConstantIdx(b))
case ConstantType(IntConstant(i)) => Some(i)
case Refinement(a, _, _) => findConstantIdx(a)
case t => None
}
def transformExact[A](a: Expr[A], idxExpr: Expr[Int]): Term = {
def findConstantIdx(tpe: TypeRepr): Option[Int] = tpe match {
case AndType(a, b) => findConstantIdx(a).orElse(findConstantIdx(b))
case ConstantType(IntConstant(i)) => Some(i)
case Refinement(a, _, _) => findConstantIdx(a)
case t => None
}

val idx = findConstantIdx(idxExpr.asTerm.tpe.widenTermRefByName).getOrElse(idxExpr.valueOrAbort)
val idx = findConstantIdx(idxExpr.asTerm.tpe.widenTermRefByName).getOrElse(idxExpr.valueOrAbort)

val field = Helpers.valuesOfConstantTuple(TypeRepr.of[Fields], Nil) match {
case Some(seq) => seq(idx).asExprOf[String].valueOrAbort
case None => report.errorAndAbort("productElementIdExact called with non constant fields type")
}
val field = Helpers.valuesOfConstantTuple(TypeRepr.of[Fields], Nil) match {
case Some(seq) => seq(idx).asExprOf[String].valueOrAbort
case None => report.errorAndAbort("productElementIdExact called with non constant fields type")
}

Select.unique(a.asTerm, field)
Select.unique(a.asTerm, field)
}
}
}

private class RefReplacer[Q <: Quotes](using val q: Q)(oldRef: q.reflect.Symbol, newRef: q.reflect.Ref)
extends q.reflect.TreeMap {
import q.reflect.*
class RefReplacer(oldRef: q.reflect.Symbol, newRef: q.reflect.Ref) extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term =
tree match {
case Ident(id) if id == oldRef.name => newRef
case _ => super.transformTerm(tree)(owner)
}
}

override def transformTerm(tree: Term)(owner: Symbol): Term =
tree match {
case Ident(id) if id == oldRef.name => newRef
case _ => super.transformTerm(tree)(owner)
class LateInlineMatchExpander extends TreeMap {
override def transformTerm(tree: Term)(owner: Symbol): Term = {
tree.asExpr match {
case '{ InlineHKDGeneric.lateInlineMatch[a]($a) } =>
transformMatch(a, owner)
case _ =>
try {
super.transformTerm(tree)(owner)
} catch {
case _: Exception =>
// FIXME: Have no idea why this happens. Just ignoring it for now.
tree
}
}
}
}

private class LateInlineMatchExpander[Q <: Quotes]()(using val q: Q) extends q.reflect.TreeMap {
import q.reflect.*
def transformMatch[A: Type](aExpr: Expr[A], owner: Symbol): Term = aExpr.asTerm match {
case m @ Match(scrutinee, cases) =>
val tpe = scrutinee.tpe.widenTermRefByName

override def transformTerm(tree: Term)(owner: Symbol): Term = {
tree.asExpr match {
case '{ InlineHKDGeneric.lateInlineMatch[a]($a) } =>
transformMatch(a, owner)
case _ =>
try {
super.transformTerm(tree)(owner)
} catch {
case _: Exception =>
// FIXME: Have no idea why this happens. Just ignoring it for now.
tree
cases.foreach {
case CaseDef(_, Some(_), _) => report.errorAndAbort("Cases in match can not have guards")
case CaseDef(Bind(_, Typed(Ident(_), _)), _, _) =>
case CaseDef(Bind(_, Ident(_)), _, _) =>
case caseDef => report.errorAndAbort("Invalid case in match inside lateInlineMatch", caseDef.pos)
}
}
}

def transformMatch[A: Type](aExpr: Expr[A], owner: Symbol): Term = aExpr.asTerm match {
case m @ Match(scrutinee, cases) =>
val tpe = scrutinee.tpe.widenTermRefByName

cases.foreach {
case CaseDef(_, Some(_), _) => report.errorAndAbort("Cases in match can not have guards")
case CaseDef(Bind(_, Typed(Ident(_), _)), _, _) =>
case CaseDef(Bind(_, Ident(_)), _, _) =>
case caseDef => report.errorAndAbort("Invalid case in match inside lateInlineMatch", caseDef.pos)
}
val hasDefaultCase = cases.exists {
case CaseDef(Bind(_, Ident(_)), _, _) => true
case _ => false
}
if !hasDefaultCase then report.errorAndAbort("Match must have a default case", m.pos)

val hasDefaultCase = cases.exists {
case CaseDef(Bind(_, Ident(_)), _, _) => true
case _ => false
}
if !hasDefaultCase then report.errorAndAbort("Match must have a default case", m.pos)
val (bind, rhs) = cases
.collectFirst {
case CaseDef(bind @ Bind(_, typed @ Typed(Ident(_), _)), _, rhs) if tpe <:< typed.symbol.typeRef =>
(bind, rhs)
}
.getOrElse {
cases.collectFirst { case CaseDef(bind @ Bind(_, Ident(_)), _, rhs) =>
(bind, rhs)
}.get
}

val (bind, rhs) = cases
.collectFirst {
case CaseDef(bind @ Bind(_, typed @ Typed(Ident(_), _)), _, rhs) if tpe <:< typed.symbol.typeRef =>
(bind, rhs)
}
.getOrElse {
cases.collectFirst { case CaseDef(bind @ Bind(_, Ident(_)), _, rhs) =>
(bind, rhs)
}.get
ValDef.let(owner, bind.name + "Replaced", scrutinee) { newRef =>
val replacer = new RefReplacer(bind.symbol, newRef)
replacer.transformTerm(rhs)(owner)
}

ValDef.let(owner, bind.name + "Replaced", scrutinee) { newRef =>
val replacer = new RefReplacer[q.type](bind.symbol, newRef)
replacer.transformTerm(rhs)(owner)
}
case _ => report.errorAndAbort("Body of lateInlineMatch must be a match", aExpr)
case _ => report.errorAndAbort("Body of lateInlineMatch must be a match", aExpr)
}
}
}

Expand All @@ -352,8 +348,9 @@ object InlineHKDGeneric:
): Expr[A] =
import q.reflect.*

val productElementIdExactExpander = new ProductElementIdExactExpander[q.type, Fields]()
val lateInlineMatchExpander = new LateInlineMatchExpander[q.type]()
val treeMaps = new TreeMaps[q.type]()
val productElementIdExactExpander = new treeMaps.ProductElementIdExactExpander[Fields]()
val lateInlineMatchExpander = new treeMaps.LateInlineMatchExpander()

val r1 = productElementIdExactExpander.transformTerm(e.asTerm)(Symbol.spliceOwner)
val r2 = lateInlineMatchExpander.transformTerm(r1)(Symbol.spliceOwner)
Expand Down

0 comments on commit e33ac5c

Please sign in to comment.