Skip to content

Commit

Permalink
Trial: New ElimByName phase
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Jan 20, 2022
1 parent 5a328dd commit 81cbe32
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class Compiler {
new InlineVals, // Check right hand-sides of an `inline val`s
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
new ElimRepeated) :: // Rewrite vararg parameters and arguments
List(new ElimByNameParams) ::
List(new init.Checker) :: // Check initialization of objects
List(new ProtectedAccessors, // Add accessors for protected members
new ExtensionMethods, // Expand methods of value classes with extension methods
Expand Down
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,24 @@ class Definitions {
}
}

object ByNameFunction:
def apply(tp: Type)(using Context): Type =
defn.ContextFunction0.typeRef.appliedTo(tp :: Nil)
def unapply(tp: Type)(using Context): Option[Type] = tp match
case tp @ AppliedType(tycon, arg :: Nil) if defn.isByNameFunctionClass(tycon.typeSymbol) =>
Some(arg)
case tp @ AnnotatedType(parent, _) =>
unapply(parent)
case _ =>
None

final def isByNameFunctionClass(sym: Symbol): Boolean =
sym eq ContextFunction0

def isByNameFunction(tp: Type)(using Context): Boolean = tp match
case ByNameFunction(_) => true
case _ => false

final def isCompiletime_S(sym: Symbol)(using Context): Boolean =
sym.name == tpnme.S && sym.owner == CompiletimeOpsIntModuleClass

Expand Down Expand Up @@ -1295,10 +1313,12 @@ class Definitions {
).symbol.asClass

@tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply)
@tu lazy val ContextFunction0_apply: Symbol = ContextFunction0.requiredMethod(nme.apply)

@tu lazy val Function0: Symbol = FunctionClass(0)
@tu lazy val Function1: Symbol = FunctionClass(1)
@tu lazy val Function2: Symbol = FunctionClass(2)
@tu lazy val ContextFunction0: Symbol = FunctionClass(0, isContextual = true)

def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef =
FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ object Phases {
private var myRefChecksPhase: Phase = _
private var myPatmatPhase: Phase = _
private var myElimRepeatedPhase: Phase = _
private var myElimByNamePhase: Phase = _
private var myExtensionMethodsPhase: Phase = _
private var myExplicitOuterPhase: Phase = _
private var myGettersPhase: Phase = _
Expand All @@ -229,6 +230,7 @@ object Phases {
final def refchecksPhase: Phase = myRefChecksPhase
final def patmatPhase: Phase = myPatmatPhase
final def elimRepeatedPhase: Phase = myElimRepeatedPhase
final def elimByNamePhase: Phase = myElimByNamePhase
final def extensionMethodsPhase: Phase = myExtensionMethodsPhase
final def explicitOuterPhase: Phase = myExplicitOuterPhase
final def gettersPhase: Phase = myGettersPhase
Expand All @@ -253,6 +255,7 @@ object Phases {
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
myElimByNamePhase = phaseOfClass(classOf[ElimByNameParams])
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
myErasurePhase = phaseOfClass(classOf[Erasure])
myElimErasedValueTypePhase = phaseOfClass(classOf[ElimErasedValueType])
Expand Down Expand Up @@ -427,6 +430,7 @@ object Phases {
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
def elimRepeatedPhase(using Context): Phase = ctx.base.elimRepeatedPhase
def elimByNamePhase(using Context): Phase = ctx.base.elimByNamePhase
def extensionMethodsPhase(using Context): Phase = ctx.base.extensionMethodsPhase
def explicitOuterPhase(using Context): Phase = ctx.base.explicitOuterPhase
def gettersPhase(using Context): Phase = ctx.base.gettersPhase
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ object StdNames {
// ----- Type names -----------------------------------------

final val BYNAME_PARAM_CLASS: N = "<byname>"
final val BYNAME_PARAM_FUN: N = "<function0-byname>"
final val EQUALS_PATTERN: N = "<equals>"
final val LOCAL_CHILD: N = "<local child>"
final val REPEATED_PARAM_CLASS: N = "<repeated>"
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case _ => tp2.isAnyRef
}
compareJavaArray
case tp1: ExprType if ctx.phase.id > gettersPhase.id =>
case tp1: ExprType if ctx.phaseId > gettersPhase.id =>
// getters might have converted T to => T, need to compensate.
recur(tp1.widenExpr, tp2)
case _ =>
Expand Down Expand Up @@ -1510,6 +1510,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case _ => arg1
}
arg2.contains(arg1norm)
case ExprType(arg2res)
if ctx.phaseId > ctx.base.elimByNamePhase.id && !ctx.erasedTypes
&& defn.isByNameFunction(arg1) =>
// ElimByName maps `=> T` to `()? => T`, but only in method parameters. It leaves
// embedded `=> T` alone. This clause needs to compensate for that.
isSubArg(arg1.argInfos.head, arg2res)
case _ =>
arg1 match {
case arg1: TypeBounds =>
Expand Down
135 changes: 135 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/ElimByNameParams.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package dotty.tools
package dotc
package transform

import core._
import Contexts._
import Symbols._
import Types._
import Flags._
import SymDenotations.*
import DenotTransformers.InfoTransformer
import NameKinds.SuperArgName
import core.StdNames.nme
import MegaPhase.*
import Decorators.*
import reporting.trace

/** This phase translates arguments to call-by-name parameters, using the rules
*
* x ==> x if x is a => parameter
* e.apply() ==> <cbn-arg>(e) if e is pure
* e ==> <cbn-arg>(() => e) for all other arguments
*
* where
*
* <cbn-arg>: [T](() => T): T
*
* is a synthetic method defined in Definitions. Erasure will later strip the <cbn-arg> wrappers.
*/
class ElimByNameParams extends MiniPhase, InfoTransformer:
thisPhase =>

import ast.tpd._

override def phaseName: String = ElimByNameParams.name

override def runsAfterGroupsOf: Set[String] = Set(ExpandSAMs.name, ElimRepeated.name)
// - ExpanSAMs applied to partial functions creates methods that need
// to be fully defined before converting. Test case is pos/i9391.scala.
// - ByNameLambda needs to run in a group after ElimRepeated since ElimRepeated
// works on simple arguments but not converted closures, and it sees the arguments
// after transformations by subsequent miniphases in the same group.

override def changesParents: Boolean = true
// Expr types in parent type arguments are changed to function types.

/** If denotation had an ExprType before, it now gets a function type */
private def exprBecomesFunction(symd: SymDenotation)(using Context): Boolean =
symd.is(Param) || symd.is(ParamAccessor, butNot = Method)

def transformInfo(tp: Type, sym: Symbol)(using Context): Type = tp match {
case ExprType(rt) if exprBecomesFunction(sym) =>
defn.ByNameFunction(rt)
case tp: MethodType =>
def exprToFun(tp: Type) = tp match
case ExprType(rt) => defn.ByNameFunction(rt)
case tp => tp
tp.derivedLambdaType(
paramInfos = tp.paramInfos.mapConserve(exprToFun),
resType = transformInfo(tp.resType, sym))
case tp: PolyType =>
tp.derivedLambdaType(resType = transformInfo(tp.resType, sym))
case _ => tp
}

override def infoMayChange(sym: Symbol)(using Context): Boolean =
sym.is(Method) || exprBecomesFunction(sym)

def byNameClosure(arg: Tree, argType: Type)(using Context): Tree =
val meth = newAnonFun(ctx.owner, MethodType(Nil, argType), coord = arg.span)
Closure(meth,
_ => arg.changeOwnerAfter(ctx.owner, meth, thisPhase),
targetType = defn.ByNameFunction(argType)
).withSpan(arg.span)

private def isByNameRef(tree: Tree)(using Context): Boolean =
defn.isByNameFunction(tree.tpe.widen)

/** Map `tree` to `tree.apply()` is `tree` is of type `() ?=> T` */
private def applyIfFunction(tree: Tree)(using Context) =
if isByNameRef(tree) then
val tree0 = transformFollowing(tree)
atPhase(next) { tree0.select(defn.ContextFunction0_apply).appliedToNone }
else tree

override def transformIdent(tree: Ident)(using Context): Tree =
applyIfFunction(tree)

override def transformSelect(tree: Select)(using Context): Tree =
applyIfFunction(tree)

override def transformTypeApply(tree: TypeApply)(using Context): Tree = tree match {
case TypeApply(Select(_, nme.asInstanceOf_), arg :: Nil) =>
// tree might be of form e.asInstanceOf[x.type] where x becomes a function.
// See pos/t296.scala
applyIfFunction(tree)
case _ => tree
}

override def transformApply(tree: Apply)(using Context): Tree =
trace(s"transforming ${tree.show} at phase ${ctx.phase}", show = true) {

def transformArg(arg: Tree, formal: Type): Tree = formal match
case defn.ByNameFunction(formalResult) =>
def stripTyped(t: Tree): Tree = t match
case Typed(expr, _) => stripTyped(expr)
case _ => t
stripTyped(arg) match
case Apply(Select(qual, nme.apply), Nil)
if isByNameRef(qual) && (isPureExpr(qual) || qual.symbol.isAllOf(InlineParam)) =>
qual
case _ =>
if isByNameRef(arg) || arg.symbol.name.is(SuperArgName)
then arg
else
var argType = arg.tpe.widenIfUnstable
if argType.isBottomType then argType = formalResult
byNameClosure(arg, argType)
case _ =>
arg

val mt @ MethodType(_) = tree.fun.tpe.widen
val args1 = tree.args.zipWithConserve(mt.paramInfos)(transformArg)
cpy.Apply(tree)(tree.fun, args1)
}

override def transformValDef(tree: ValDef)(using Context): Tree =
atPhase(next) {
if exprBecomesFunction(tree.symbol) then
cpy.ValDef(tree)(tpt = tree.tpt.withType(tree.symbol.info))
else tree
}

object ElimByNameParams:
val name: String = "elimByNameParams"

0 comments on commit 81cbe32

Please sign in to comment.