Skip to content

Commit

Permalink
Code refactoring of initialization checker (#16066)
Browse files Browse the repository at this point in the history
Code refactoring of initialization checker
  • Loading branch information
liufengyun authored Oct 13, 2022
2 parents 2844c2b + 9050560 commit 2746428
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 89 deletions.
44 changes: 17 additions & 27 deletions compiler/src/dotty/tools/dotc/transform/init/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import StdNames._
import dotty.tools.dotc.transform._
import Phases._

import scala.collection.mutable

import Semantic._

class Checker extends Phase {
class Checker extends Phase:

override def phaseName: String = Checker.name

Expand All @@ -31,17 +32,23 @@ class Checker extends Phase {

override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
val checkCtx = ctx.fresh.setPhase(this.start)
Semantic.checkTasks(using checkCtx) {
val traverser = new InitTreeTraverser()
units.foreach { unit => traverser.traverse(unit.tpdTree) }
}
val traverser = new InitTreeTraverser()
units.foreach { unit => traverser.traverse(unit.tpdTree) }
val classes = traverser.getClasses()

Semantic.checkClasses(classes)(using checkCtx)

units

def run(using Context): Unit = {
def run(using Context): Unit =
// ignore, we already called `Semantic.check()` in `runOn`
}
()

class InitTreeTraverser extends TreeTraverser:
private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer

def getClasses(): List[ClassSymbol] = classes.toList

class InitTreeTraverser(using WorkList) extends TreeTraverser {
override def traverse(tree: Tree)(using Context): Unit =
traverseChildren(tree)
tree match {
Expand All @@ -53,29 +60,12 @@ class Checker extends Phase {
mdef match
case tdef: TypeDef if tdef.isClassDef =>
val cls = tdef.symbol.asClass
val thisRef = ThisRef(cls)
if shouldCheckClass(cls) then Semantic.addTask(thisRef)
classes.append(cls)
case _ =>

case _ =>
}
}

private def shouldCheckClass(cls: ClassSymbol)(using Context) = {
val instantiable: Boolean =
cls.is(Flags.Module) ||
!cls.isOneOf(Flags.AbstractOrTrait) && {
// see `Checking.checkInstantiable` in typer
val tp = cls.appliedRef
val stp = SkolemType(tp)
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
!selfType.exists || stp <:< selfType
}

// A concrete class may not be instantiated if the self type is not satisfied
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
}
}
end InitTreeTraverser

object Checker:
val name: String = "initChecker"
Expand Down
145 changes: 83 additions & 62 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1206,72 +1206,49 @@ object Semantic:
cls == defn.AnyValClass ||
cls == defn.ObjectClass

// ----- Work list ---------------------------------------------------
case class Task(value: ThisRef)

class WorkList private[Semantic]():
private val pendingTasks: mutable.ArrayBuffer[Task] = new mutable.ArrayBuffer

def addTask(task: Task): Unit =
if !pendingTasks.contains(task) then pendingTasks.append(task)

/** Process the worklist until done */
final def work()(using Cache, Context): Unit =
for task <- pendingTasks
do doTask(task)

/** Check an individual class
*
* This method should only be called from the work list scheduler.
*/
private def doTask(task: Task)(using Cache, Context): Unit =
val thisRef = task.value
val tpl = thisRef.klass.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

@tailrec
def iterate(): Unit = {
given Promoted = Promoted.empty(thisRef.klass)
given Trace = Trace.empty.add(thisRef.klass.defTree)
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter
// ----- API --------------------------------

thisRef.ensureFresh()
/** Check an individual class
*
* The class to be checked must be an instantiable concrete class.
*/
private def checkClass(classSym: ClassSymbol)(using Cache, Context): Unit =
val thisRef = ThisRef(classSym)
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]

// set up constructor parameters
for param <- tpl.constr.termParamss.flatten do
thisRef.updateField(param.symbol, Hot)
@tailrec
def iterate(): Unit = {
given Promoted = Promoted.empty(classSym)
given Trace = Trace.empty.add(classSym.defTree)
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter

log("checking " + task) { eval(tpl, thisRef, thisRef.klass) }
reporter.errors.foreach(_.issue)
thisRef.ensureFresh()

if cache.hasChanged && reporter.errors.isEmpty then
// code to prepare cache and heap for next iteration
cache.prepareForNextIteration()
iterate()
else
cache.prepareForNextClass()
}
// set up constructor parameters
for param <- tpl.constr.termParamss.flatten do
thisRef.updateField(param.symbol, Hot)

iterate()
end doTask
end WorkList
inline def workList(using wl: WorkList): WorkList = wl
log("checking " + classSym) { eval(tpl, thisRef, classSym) }
reporter.errors.foreach(_.issue)

// ----- API --------------------------------
if cache.hasChanged && reporter.errors.isEmpty then
// code to prepare cache and heap for next iteration
cache.prepareForNextIteration()
iterate()
else
cache.prepareForNextClass()
}

/** Add a checking task to the work list */
def addTask(thisRef: ThisRef)(using WorkList) = workList.addTask(Task(thisRef))
iterate()
end checkClass

/** Check the specified tasks
*
* Semantic.checkTasks {
* Semantic.addTask(...)
* }
/**
* Check the specified concrete classes
*/
def checkTasks(using Context)(taskBuilder: WorkList ?=> Unit): Unit =
val workList = new WorkList
val cache = new Cache
taskBuilder(using workList)
workList.work()(using cache, ctx)
def checkClasses(classes: List[ClassSymbol])(using Context): Unit =
given Cache()
for classSym <- classes if isConcreteClass(classSym) do
checkClass(classSym)

// ----- Semantic definition --------------------------------

Expand All @@ -1296,7 +1273,10 @@ object Semantic:
*
* This method only handles cache logic and delegates the work to `cases`.
*
* The parameter `cacheResult` is used to reduce the size of the cache.
* @param expr The expression to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the expression is located.
* @param cacheResult It is used to reduce the size of the cache.
*/
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
cache.get(thisV, expr) match
Expand Down Expand Up @@ -1326,6 +1306,10 @@ object Semantic:
/** Handles the evaluation of different expressions
*
* Note: Recursive call should go to `eval` instead of `cases`.
*
* @param expr The expression to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the expression `expr` is located.
*/
def cases(expr: Tree, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
val trace2 = trace.add(expr)
Expand Down Expand Up @@ -1503,7 +1487,14 @@ object Semantic:
report.error("[Internal error] unexpected tree" + Trace.show, expr)
Hot

/** Handle semantics of leaf nodes */
/** Handle semantics of leaf nodes
*
* For leaf nodes, their semantics is determined by their types.
*
* @param tp The type to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `tp` is located.
*/
def cases(tp: Type, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
tp match
case _: ConstantType =>
Expand Down Expand Up @@ -1541,7 +1532,12 @@ object Semantic:
Hot
}

/** Resolve C.this that appear in `klass` */
/** Resolve C.this that appear in `klass`
*
* @param target The class symbol for `C` for which `C.this` is to be resolved.
* @param thisV The value for `D.this` where `D` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `C.this` is located.
*/
def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol): Contextual[Value] = log("resolving " + target.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
if target == klass then thisV
else if target.is(Flags.Package) then Hot
Expand All @@ -1566,7 +1562,12 @@ object Semantic:

}

/** Compute the outer value that correspond to `tref.prefix` */
/** Compute the outer value that correspond to `tref.prefix`
*
* @param tref The type whose prefix is to be evaluated.
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
* @param klass The enclosing class where the type `tref` is located.
*/
def outerValue(tref: TypeRef, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
val cls = tref.classSymbol.asClass
if tref.prefix == NoPrefix then
Expand All @@ -1577,7 +1578,12 @@ object Semantic:
if cls.isAllOf(Flags.JavaInterface) then Hot
else cases(tref.prefix, thisV, klass)

/** Initialize part of an abstract object in `klass` of the inheritance chain */
/** Initialize part of an abstract object in `klass` of the inheritance chain
*
* @param tpl The class body to be evaluated.
* @param thisV The value of the current object to be initialized.
* @param klass The class to which the template belongs.
*/
def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer, (_: Value).show) {
val paramsMap = tpl.constr.termParamss.flatten.map { vdef =>
vdef.name -> thisV.objekt.field(vdef.symbol)
Expand Down Expand Up @@ -1782,3 +1788,18 @@ object Semantic:
if (sym.isEffectivelyFinal || sym.isConstructor) sym
else sym.matchingMember(cls.appliedRef)
}

private def isConcreteClass(cls: ClassSymbol)(using Context) = {
val instantiable: Boolean =
cls.is(Flags.Module) ||
!cls.isOneOf(Flags.AbstractOrTrait) && {
// see `Checking.checkInstantiable` in typer
val tp = cls.appliedRef
val stp = SkolemType(tp)
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
!selfType.exists || stp <:< selfType
}

// A concrete class may not be instantiated if the self type is not satisfied
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
}

0 comments on commit 2746428

Please sign in to comment.