Skip to content

Commit

Permalink
Enable returning classes from MacroAnnotations (part 3) (#16534)
Browse files Browse the repository at this point in the history
Enable the addition of classes from a `MacroAnnotation`:
* Can add new `class`/`object` definitions next to the annotated
definition

Special cases:
* An annotated top-level `def`, `val`, `var`, `lazy val` can return a
`class`/`object`
   definition that is owned by the package or package object.

Related PRs:
 * Follows #16454
  • Loading branch information
smarter authored Jan 12, 2023
2 parents 80e8365 + 6c6dc77 commit be10bc6
Show file tree
Hide file tree
Showing 32 changed files with 653 additions and 31 deletions.
26 changes: 26 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,32 @@ object Symbols {
owner.thisType, modcls, parents, decls, TermRef(owner.thisType, module)),
privateWithin, coord, assocFile)

/** Same as `newCompleteModuleSymbol` except that `parents` can be a list of arbitrary
* types which get normalized into type refs and parameter bindings.
*/
def newNormalizedModuleSymbol(
owner: Symbol,
name: TermName,
modFlags: FlagSet,
clsFlags: FlagSet,
parentTypes: List[Type],
decls: Scope,
privateWithin: Symbol = NoSymbol,
coord: Coord = NoCoord,
assocFile: AbstractFile | Null = null)(using Context): TermSymbol = {
def completer(module: Symbol) = new LazyType {
def complete(denot: SymDenotation)(using Context): Unit = {
val cls = denot.asClass.classSymbol
val decls = newScope
denot.info = ClassInfo(owner.thisType, cls, parentTypes.map(_.dealias), decls, TermRef(owner.thisType, module))
}
}
newModuleSymbol(
owner, name, modFlags, clsFlags,
(module, modcls) => completer(module),
privateWithin, coord, assocFile)
}

/** Create a package symbol with associated package class
* from its non-info fields and a lazy type for loading the package's members.
*/
Expand Down
33 changes: 29 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import dotty.tools.dotc.inlines.Inlines
import dotty.tools.dotc.ast.TreeMapWithImplicits
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer

import scala.collection.mutable.ListBuffer

/** Inlines all calls to inline methods that are not in an inline method or a quote */
class Inlining extends MacroTransform with IdentityDenotTransformer {
thisPhase =>
class Inlining extends MacroTransform {

import tpd._

Expand Down Expand Up @@ -63,6 +63,12 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
}

private class InliningTreeMap extends TreeMapWithImplicits {

/** List of top level classes added by macro annotation in a package object.
* These are added to the PackageDef that owns this particular package object.
*/
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()

override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
Expand All @@ -73,8 +79,17 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
&& StagingContext.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = new MacroAnnotations(thisPhase).expandAnnotations(tree)
flatTree(trees.map(super.transform))
val trees = (new MacroAnnotations).expandAnnotations(tree)
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
if topClasses.nonEmpty then
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses

flatTree(trees2)
else super.transform(tree)
case _: Typed | _: Block =>
super.transform(tree)
Expand All @@ -86,6 +101,16 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
super.transform(tree)(using StagingContext.quoteContext)
case _: GenericApply if tree.symbol.isExprSplice =>
super.transform(tree)(using StagingContext.spliceContext)
case _: PackageDef =>
super.transform(tree) match
case tree1: PackageDef =>
newTopClasses.get(tree.symbol.moduleClass) match
case Some(topClasses) =>
newTopClasses.remove(tree.symbol.moduleClass)
val newStats = tree1.stats ::: topClasses.result()
cpy.PackageDef(tree1)(tree1.pid, newStats)
case _ => tree1
case tree1 => tree1
case _ =>
super.transform(tree)
}
Expand Down
18 changes: 7 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.control.NonFatal

import java.lang.reflect.InvocationTargetException

class MacroAnnotations(thisPhase: DenotTransformer):
class MacroAnnotations:
import tpd.*
import MacroAnnotations.*

Expand Down Expand Up @@ -82,8 +82,8 @@ class MacroAnnotations(thisPhase: DenotTransformer):
case (prefixed, newTree :: suffixed) =>
allTrees ++= prefixed
insertedAfter = suffixed :: insertedAfter
prefixed.foreach(checkAndEnter(_, tree.symbol, annot))
suffixed.foreach(checkAndEnter(_, tree.symbol, annot))
prefixed.foreach(checkMacroDef(_, tree.symbol, annot))
suffixed.foreach(checkMacroDef(_, tree.symbol, annot))
newTree
case (Nil, Nil) =>
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
Expand Down Expand Up @@ -118,19 +118,15 @@ class MacroAnnotations(thisPhase: DenotTransformer):
val quotes = QuotesImpl()(using SpliceScope.contextWithNewSpliceScope(tree.symbol.sourcePos)(using MacroExpansion.context(tree)).withOwner(tree.symbol.owner))
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])

/** Check that this tree can be added by the macro annotation and enter it if needed */
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
/** Check that this tree can be added by the macro annotation */
private def checkMacroDef(newTree: DefTree, annotated: Symbol, annot: Annotation)(using Context) =
val sym = newTree.symbol
if sym.isClass then
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
else if sym.isType then
if sym.isType && !sym.isClass then
report.error(i"macro annotation cannot return a `type`. $annot tried to add $sym", annot.tree)
else if sym.owner != annotated.owner then
else if sym.owner != annotated.owner && !(annotated.owner.isPackageObject && (sym.isClass || sym.is(Module)) && sym.owner == annotated.owner.owner) then
report.error(i"macro annotation $annot added $sym with an inconsistent owner. Expected it to be owned by ${annotated.owner} but was owned by ${sym.owner}.", annot.tree)
else if annotated.isClass && annotated.owner.is(Package) /*&& !sym.isClass*/ then
report.error(i"macro annotation can not add top-level ${sym.showKind}. $annot tried to add $sym.", annot.tree)
else
sym.enteredAfter(thisPhase)

object MacroAnnotations:

Expand Down
34 changes: 29 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Annotations
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.quoted.reflect._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.NoCompilationUnit

import dotty.tools.dotc.quoted.{MacroExpansion, PickledQuotes}
import dotty.tools.dotc.quoted.MacroExpansion
import dotty.tools.dotc.quoted.PickledQuotes
import dotty.tools.dotc.quoted.reflect._

import scala.quoted.runtime.{QuoteUnpickler, QuoteMatching}
import scala.quoted.runtime.impl.printers._
Expand Down Expand Up @@ -242,6 +243,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def unapply(cdef: ClassDef): (String, DefDef, List[Tree /* Term | TypeTree */], Option[ValDef], List[Statement]) =
val rhs = cdef.rhs.asInstanceOf[tpd.Template]
(cdef.name.toString, cdef.constructor, cdef.parents, cdef.self, rhs.body)

def module(module: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): (ValDef, ClassDef) = {
val cls = module.moduleClass
val clsDef = ClassDef(cls, parents, body)
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
val modVal = ValDef(module, Some(newCls))
(modVal, clsDef)
}
end ClassDef

given ClassDefMethods: ClassDefMethods with
Expand Down Expand Up @@ -2481,6 +2490,21 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for sym <- decls(cls) do cls.enter(sym)
cls

def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
val mod = dotc.core.Symbols.newNormalizedModuleSymbol(
owner,
name.toTermName,
modFlags | dotc.core.Flags.ModuleValCreationFlags,
clsFlags | dotc.core.Flags.ModuleClassCreationFlags,
parents,
dotc.core.Scopes.newScope,
privateWithin)
val cls = mod.moduleClass.asClass
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, Nil, Nil))
for sym <- decls(cls) do cls.enter(sym)
mod

def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
Expand Down
44 changes: 34 additions & 10 deletions library/src/scala/annotation/MacroAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ package annotation

import scala.quoted._

/** Base trait for macro annotation that will transform a definition */
/** Base trait for macro annotation implementation.
* Macro annotations can transform definitions and add new definitions.
*
* See: `MacroAnnotation.transform`
*
* @syntax markdown
*/
@experimental
trait MacroAnnotation extends StaticAnnotation:

/** Transform the `tree` definition and add other definitions
/** Transform the `tree` definition and add new definitions
*
* This method takes as argument the annotated definition.
* It returns a non-empty list containing the modified version of the annotated definition.
* The new tree for the definition must use the original symbol.
* New definitions can be added to the list before or after the transformed definitions, this order
* will be retained.
* will be retained. New definitions will not be visible from outside the macro expansion.
*
* All definitions in the result must have the same owner. The owner can be recovered from `tree.symbol.owner`.
* #### Restrictions
* - All definitions in the result must have the same owner. The owner can be recovered from `Symbol.spliceOwner`.
* - Special case: an annotated top-level `def`, `val`, `var`, `lazy val` can return a `class`/`object`
definition that is owned by the package or package object.
* - Can not return a `type`.
* - Annotated top-level `class`/`object` can not return top-level `def`, `val`, `var`, `lazy val`.
* - Can not see new definition in user written code.
*
* The result cannot add new `class`, `object` or `type` definition. This limitation will be relaxed in the future.
* #### Good practices
* - Make your new definitions private if you can.
* - New definitions added as class members should use a fresh name (`Symbol.freshName`) to avoid collisions.
* - New top-level definitions should use a fresh name (`Symbol.freshName`) that includes the name of the annotated
* member as a prefix to avoid collisions of definitions added in other files.
*
* IMPORTANT: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
* **IMPORTANT**: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
*
* Example 1:
* #### Example 1
* This example shows how to modify a `def` and add a `val` next to it using a macro annotation.
* ```scala
* import scala.quoted.*
Expand Down Expand Up @@ -54,7 +70,10 @@ trait MacroAnnotation extends StaticAnnotation:
* List(tree)
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class memoize extends scala.annotation.StaticAnnotation
* //}
* @memoize
* def fib(n: Int): Int =
* println(s"compute fib of $n")
Expand All @@ -74,7 +93,7 @@ trait MacroAnnotation extends StaticAnnotation:
* )
* ```
*
* Example 2:
* #### Example 2
* This example shows how to modify a `class` using a macro annotation.
* It shows how to override inherited members and add new ones.
* ```scala
Expand Down Expand Up @@ -164,7 +183,10 @@ trait MacroAnnotation extends StaticAnnotation:
* }
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class equals extends scala.annotation.StaticAnnotation
* //}
* @equals class User(val name: String, val id: Int)
* ```
* and the macro will modify the class definition to generate the following code
Expand All @@ -184,5 +206,7 @@ trait MacroAnnotation extends StaticAnnotation:
*
* @param Quotes Implicit instance of Quotes used for tree reflection
* @param tree Tree that will be transformed
*
* @syntax markdown
*/
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition]
85 changes: 84 additions & 1 deletion library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,33 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* otherwise the can be `Term` containing the `New` applied to the parameters of the extended class.
* @param body List of members of the class. The members must align with the members of `cls`.
*/
// TODO add selfOpt: Option[ValDef]?
@experimental def apply(cls: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): ClassDef
def copy(original: Tree)(name: String, constr: DefDef, parents: List[Tree /* Term | TypeTree */], selfOpt: Option[ValDef], body: List[Statement]): ClassDef
def unapply(cdef: ClassDef): (String, DefDef, List[Tree /* Term | TypeTree */], Option[ValDef], List[Statement])


/** Create the ValDef and ClassDef of a module (equivalent to an `object` declaration in source code).
*
* Equivalent to
* ```
* def module(module: Symbol, parents: List[Tree], body: List[Statement]): (ValDef, ClassDef) =
* val modCls = module.moduleClass
* val modClassDef = ClassDef(modCls, parents, body)
* val modValDef = ValDef(module, Some(Apply(Select(New(TypeIdent(modCls)), cls.primaryConstructor), Nil)))
* List(modValDef, modClassDef)
* ```
*
* @param module the module symbol (created using `Symbol.newModule`)
* @param parents parents of the module class
* @param body body of the module class
* @return The module lazy val definition and module class definition.
* These should be added one after the other (in that order) in the body of a class or statements of a block.
*
* @syntax markdown
*/
// TODO add selfOpt: Option[ValDef]?
@experimental def module(module: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): (ValDef, ClassDef)
}

/** Makes extension methods on `ClassDef` available without any imports */
Expand Down Expand Up @@ -3638,8 +3662,67 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*/
// TODO: add flags and privateWithin
@experimental def newClass(parent: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol

/** Generates a new module symbol with an associated module class symbol,
* this is equivalent to an `object` declaration in source code.
* This method returns the module symbol. The module class can be accessed calling `moduleClass` on this symbol.
*
* Example usage:
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* val moduleName: String = Symbol.freshName("MyModule")
* val parents = List(TypeTree.of[Object])
* def decls(cls: Symbol): List[Symbol] =
* List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.EmptyFlags, Symbol.noSymbol))
*
* val mod = Symbol.newModule(Symbol.spliceOwner, moduleName, Flags.EmptyFlags, Flags.EmptyFlags, parents.map(_.tpe), decls, Symbol.noSymbol)
* val cls = mod.moduleClass
* val runSym = cls.declaredMethod("run").head
*
* val runDef = DefDef(runSym, _ => Some('{ println("run") }.asTerm))
* val modDef = ClassDef.module(mod, parents, body = List(runDef))
*
* val callRun = Apply(Select(Ref(mod), runSym), Nil)
*
* Block(modDef.toList, callRun)
* ```
* constructs the equivalent to
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* '{
* object MyModule$macro$1 extends Object:
* def run(): Unit = println("run")
* MyModule$macro$1.run()
* }
* ```
*
* @param parent The owner of the class
* @param name The name of the class
* @param modFlags extra flags with which the module symbol should be constructed
* @param clsFlags extra flags with which the module class symbol should be constructed
* @param parents The parent classes of the class. The first parent must not be a trait.
* @param decls A function that takes the symbol of the module class as input and return the symbols of its declared members
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
*
* This symbol starts without an accompanying definition.
* It is the meta-programmer's responsibility to provide exactly one corresponding definition by passing
* this symbol to `ClassDef.module`.
*
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*
* @syntax markdown
*/
@experimental def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol

/** Generates a new method symbol with the given parent, name and type.
*
* To define a member method of a class, use the `newMethod` within the `decls` function of `newClass`.
Expand Down Expand Up @@ -4217,7 +4300,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
// FLAGS //
///////////////

/** FlagSet of a Symbol */
/** Flags of a Symbol */
type Flags

/** Module object of `type Flags` */
Expand Down
Loading

0 comments on commit be10bc6

Please sign in to comment.