From 6c6dc775e67f4473a313ed8c0cb95aec6016a42b Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Thu, 12 Jan 2023 15:39:47 +0100 Subject: [PATCH] Handle top-level class insertion using a MutableSymbolMap --- .../dotty/tools/dotc/transform/Inlining.scala | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/Inlining.scala b/compiler/src/dotty/tools/dotc/transform/Inlining.scala index 87924c7932eb..f0ed7026ee91 100644 --- a/compiler/src/dotty/tools/dotc/transform/Inlining.scala +++ b/compiler/src/dotty/tools/dotc/transform/Inlining.scala @@ -14,6 +14,7 @@ import dotty.tools.dotc.inlines.Inlines import dotty.tools.dotc.ast.TreeMapWithImplicits import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer +import scala.collection.mutable.ListBuffer /** Inlines all calls to inline methods that are not in an inline method or a quote */ class Inlining extends MacroTransform { @@ -66,7 +67,7 @@ class Inlining extends MacroTransform { /** List of top level classes added by macro annotation in a package object. * These are added to the PackageDef that owns this particular package object. */ - private val topClasses = new collection.mutable.ListBuffer[Tree] + private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]() override def transform(tree: Tree)(using Context): Tree = { tree match @@ -82,10 +83,11 @@ class Inlining extends MacroTransform { val trees1 = trees.map(super.transform) // Find classes added to the top level from a package object - val (topClasses0, trees2) = + val (topClasses, trees2) = if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner) else (Nil, trees1) - topClasses ++= topClasses0 + if topClasses.nonEmpty then + newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses flatTree(trees2) else super.transform(tree) @@ -101,10 +103,13 @@ class Inlining extends MacroTransform { super.transform(tree)(using StagingContext.spliceContext) case _: PackageDef => super.transform(tree) match - case tree1: PackageDef if !topClasses.isEmpty => - val newStats = tree1.stats ::: topClasses.result() - topClasses.clear() - cpy.PackageDef(tree1)(tree1.pid, newStats) + case tree1: PackageDef => + newTopClasses.get(tree.symbol.moduleClass) match + case Some(topClasses) => + newTopClasses.remove(tree.symbol.moduleClass) + val newStats = tree1.stats ::: topClasses.result() + cpy.PackageDef(tree1)(tree1.pid, newStats) + case _ => tree1 case tree1 => tree1 case _ => super.transform(tree)