Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable returning classes from MacroAnnotations (part 3) #16534

Merged
merged 13 commits into from
Jan 12, 2023
Merged
26 changes: 26 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,32 @@ object Symbols {
owner.thisType, modcls, parents, decls, TermRef(owner.thisType, module)),
privateWithin, coord, assocFile)

/** Same as `newCompleteModuleSymbol` except that `parents` can be a list of arbitrary
* types which get normalized into type refs and parameter bindings.
*/
def newNormalizedModuleSymbol(
owner: Symbol,
name: TermName,
modFlags: FlagSet,
clsFlags: FlagSet,
parentTypes: List[Type],
decls: Scope,
privateWithin: Symbol = NoSymbol,
coord: Coord = NoCoord,
assocFile: AbstractFile | Null = null)(using Context): TermSymbol = {
def completer(module: Symbol) = new LazyType {
def complete(denot: SymDenotation)(using Context): Unit = {
val cls = denot.asClass.classSymbol
val decls = newScope
denot.info = ClassInfo(owner.thisType, cls, parentTypes.map(_.dealias), decls, TermRef(owner.thisType, module))
}
}
newModuleSymbol(
owner, name, modFlags, clsFlags,
(module, modcls) => completer(module),
privateWithin, coord, assocFile)
}

/** Create a package symbol with associated package class
* from its non-info fields and a lazy type for loading the package's members.
*/
Expand Down
28 changes: 24 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/Inlining.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer


/** Inlines all calls to inline methods that are not in an inline method or a quote */
class Inlining extends MacroTransform with IdentityDenotTransformer {
thisPhase =>
class Inlining extends MacroTransform {

import tpd._

Expand Down Expand Up @@ -63,6 +62,12 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
}

private class InliningTreeMap extends TreeMapWithImplicits {

/** List of top level classes added by macro annotation in a package object.
* These are added to the PackageDef that owns this particular package object.
*/
private val topClasses = new collection.mutable.ListBuffer[Tree]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I think this isn't sufficient because package objects can be nested:

package foo {
  val x = 1
  package bar {
    val y = 2
  }
}

Instead, maybe the MemberDef case of transform should return a Thicket with the top-level classes, and we should add an extra case to transform to handle the package object module class itself, where we should also return a Thicket with the top-level classes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This use case was considered and works. I added tests for it in tests/run-macros/annot-add-global-class.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that after after post typer the tree is

package foo {
  package bar {
    val y = 2
  }
  val x = 1
}

This implies that nested classes are processed first and the buffer never overlaps and is emptied just after transforming the nested package.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implies that nested classes are processed first and the buffer never overlaps and is emptied just after transforming the nested package.

This is subtle, so this precondition should be documented in the code (and ideally checked somewhere, in case it breaks)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found cases where this precondition does not hold. I updated the implementation to handle such cases.


override def transform(tree: Tree)(using Context): Tree = {
tree match
case tree: MemberDef =>
Expand All @@ -73,8 +78,16 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
&& StagingContext.level == 0
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
then
val trees = new MacroAnnotations(thisPhase).expandAnnotations(tree)
flatTree(trees.map(super.transform))
val trees = (new MacroAnnotations).expandAnnotations(tree)
val trees1 = trees.map(super.transform)

// Find classes added to the top level from a package object
val (topClasses0, trees2) =
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
else (Nil, trees1)
topClasses ++= topClasses0

flatTree(trees2)
else super.transform(tree)
case _: Typed | _: Block =>
super.transform(tree)
Expand All @@ -86,6 +99,13 @@ class Inlining extends MacroTransform with IdentityDenotTransformer {
super.transform(tree)(using StagingContext.quoteContext)
case _: GenericApply if tree.symbol.isExprSplice =>
super.transform(tree)(using StagingContext.spliceContext)
case _: PackageDef =>
super.transform(tree) match
case tree1: PackageDef if !topClasses.isEmpty =>
val newStats = tree1.stats ::: topClasses.result()
topClasses.clear()
cpy.PackageDef(tree1)(tree1.pid, newStats)
case tree1 => tree1
case _ =>
super.transform(tree)
}
Expand Down
18 changes: 7 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import scala.util.control.NonFatal

import java.lang.reflect.InvocationTargetException

class MacroAnnotations(thisPhase: DenotTransformer):
class MacroAnnotations:
import tpd.*
import MacroAnnotations.*

Expand Down Expand Up @@ -82,8 +82,8 @@ class MacroAnnotations(thisPhase: DenotTransformer):
case (prefixed, newTree :: suffixed) =>
allTrees ++= prefixed
insertedAfter = suffixed :: insertedAfter
prefixed.foreach(checkAndEnter(_, tree.symbol, annot))
suffixed.foreach(checkAndEnter(_, tree.symbol, annot))
prefixed.foreach(checkMacroDef(_, tree.symbol, annot))
suffixed.foreach(checkMacroDef(_, tree.symbol, annot))
newTree
case (Nil, Nil) =>
report.error(i"Unexpected `Nil` returned by `(${annot.tree}).transform(..)` during macro expansion", annot.tree.srcPos)
Expand Down Expand Up @@ -118,19 +118,15 @@ class MacroAnnotations(thisPhase: DenotTransformer):
val quotes = QuotesImpl()(using SpliceScope.contextWithNewSpliceScope(tree.symbol.sourcePos)(using MacroExpansion.context(tree)).withOwner(tree.symbol.owner))
annotInstance.transform(using quotes)(tree.asInstanceOf[quotes.reflect.Definition])

/** Check that this tree can be added by the macro annotation and enter it if needed */
private def checkAndEnter(newTree: Tree, annotated: Symbol, annot: Annotation)(using Context) =
/** Check that this tree can be added by the macro annotation */
private def checkMacroDef(newTree: DefTree, annotated: Symbol, annot: Annotation)(using Context) =
val sym = newTree.symbol
if sym.isClass then
report.error(i"macro annotation returning a `class` is not yet supported. $annot tried to add $sym", annot.tree)
else if sym.isType then
if sym.isType && !sym.isClass then
report.error(i"macro annotation cannot return a `type`. $annot tried to add $sym", annot.tree)
else if sym.owner != annotated.owner then
else if sym.owner != annotated.owner && !(annotated.owner.isPackageObject && (sym.isClass || sym.is(Module)) && sym.owner == annotated.owner.owner) then
report.error(i"macro annotation $annot added $sym with an inconsistent owner. Expected it to be owned by ${annotated.owner} but was owned by ${sym.owner}.", annot.tree)
else if annotated.isClass && annotated.owner.is(Package) /*&& !sym.isClass*/ then
report.error(i"macro annotation can not add top-level ${sym.showKind}. $annot tried to add $sym.", annot.tree)
else
sym.enteredAfter(thisPhase)

object MacroAnnotations:

Expand Down
34 changes: 29 additions & 5 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.untpd
import dotty.tools.dotc.core.Annotations
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds
import dotty.tools.dotc.core.NameOps._
import dotty.tools.dotc.core.StdNames._
import dotty.tools.dotc.quoted.reflect._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Types
import dotty.tools.dotc.NoCompilationUnit

import dotty.tools.dotc.quoted.{MacroExpansion, PickledQuotes}
import dotty.tools.dotc.quoted.MacroExpansion
import dotty.tools.dotc.quoted.PickledQuotes
import dotty.tools.dotc.quoted.reflect._

import scala.quoted.runtime.{QuoteUnpickler, QuoteMatching}
import scala.quoted.runtime.impl.printers._
Expand Down Expand Up @@ -242,6 +243,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def unapply(cdef: ClassDef): (String, DefDef, List[Tree /* Term | TypeTree */], Option[ValDef], List[Statement]) =
val rhs = cdef.rhs.asInstanceOf[tpd.Template]
(cdef.name.toString, cdef.constructor, cdef.parents, cdef.self, rhs.body)

def module(module: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): (ValDef, ClassDef) = {
val cls = module.moduleClass
val clsDef = ClassDef(cls, parents, body)
val newCls = Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil)
val modVal = ValDef(module, Some(newCls))
(modVal, clsDef)
}
end ClassDef

given ClassDefMethods: ClassDefMethods with
Expand Down Expand Up @@ -2481,6 +2490,21 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for sym <- decls(cls) do cls.enter(sym)
cls

def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol =
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
val mod = dotc.core.Symbols.newNormalizedModuleSymbol(
owner,
name.toTermName,
modFlags | dotc.core.Flags.ModuleValCreationFlags,
clsFlags | dotc.core.Flags.ModuleClassCreationFlags,
parents,
dotc.core.Scopes.newScope,
privateWithin)
val cls = mod.moduleClass.asClass
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, Nil, Nil))
for sym <- decls(cls) do cls.enter(sym)
mod

def newMethod(owner: Symbol, name: String, tpe: TypeRepr): Symbol =
newMethod(owner, name, tpe, Flags.EmptyFlags, noSymbol)
def newMethod(owner: Symbol, name: String, tpe: TypeRepr, flags: Flags, privateWithin: Symbol): Symbol =
Expand Down
44 changes: 34 additions & 10 deletions library/src/scala/annotation/MacroAnnotation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ package annotation

import scala.quoted._

/** Base trait for macro annotation that will transform a definition */
/** Base trait for macro annotation implementation.
* Macro annotations can transform definitions and add new definitions.
*
* See: `MacroAnnotation.transform`
*
* @syntax markdown
*/
@experimental
trait MacroAnnotation extends StaticAnnotation:

/** Transform the `tree` definition and add other definitions
/** Transform the `tree` definition and add new definitions
*
* This method takes as argument the annotated definition.
* It returns a non-empty list containing the modified version of the annotated definition.
* The new tree for the definition must use the original symbol.
* New definitions can be added to the list before or after the transformed definitions, this order
* will be retained.
* will be retained. New definitions will not be visible from outside the macro expansion.
*
* All definitions in the result must have the same owner. The owner can be recovered from `tree.symbol.owner`.
* #### Restrictions
* - All definitions in the result must have the same owner. The owner can be recovered from `Symbol.spliceOwner`.
* - Special case: an annotated top-level `def`, `val`, `var`, `lazy val` can return a `class`/`object`
definition that is owned by the package or package object.
* - Can not return a `type`.
* - Annotated top-level `class`/`object` can not return top-level `def`, `val`, `var`, `lazy val`.
* - Can not see new definition in user written code.
*
* The result cannot add new `class`, `object` or `type` definition. This limitation will be relaxed in the future.
* #### Good practices
* - Make your new definitions private if you can.
* - New definitions added as class members should use a fresh name (`Symbol.freshName`) to avoid collisions.
* - New top-level definitions should use a fresh name (`Symbol.freshName`) that includes the name of the annotated
* member as a prefix to avoid collisions of definitions added in other files.
*
* IMPORTANT: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
* **IMPORTANT**: When developing and testing a macro annotation, you must enable `-Xcheck-macros` and `-Ycheck:all`.
*
* Example 1:
* #### Example 1
* This example shows how to modify a `def` and add a `val` next to it using a macro annotation.
* ```scala
* import scala.quoted.*
Expand Down Expand Up @@ -54,7 +70,10 @@ trait MacroAnnotation extends StaticAnnotation:
* List(tree)
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class memoize extends scala.annotation.StaticAnnotation
* //}
* @memoize
* def fib(n: Int): Int =
* println(s"compute fib of $n")
Expand All @@ -74,7 +93,7 @@ trait MacroAnnotation extends StaticAnnotation:
* )
* ```
*
* Example 2:
* #### Example 2
* This example shows how to modify a `class` using a macro annotation.
* It shows how to override inherited members and add new ones.
* ```scala
Expand Down Expand Up @@ -164,7 +183,10 @@ trait MacroAnnotation extends StaticAnnotation:
* }
* ```
* with this macro annotation a user can write
* ```scala sc:nocompile
* ```scala
* //{
* class equals extends scala.annotation.StaticAnnotation
* //}
* @equals class User(val name: String, val id: Int)
* ```
* and the macro will modify the class definition to generate the following code
Expand All @@ -184,5 +206,7 @@ trait MacroAnnotation extends StaticAnnotation:
*
* @param Quotes Implicit instance of Quotes used for tree reflection
* @param tree Tree that will be transformed
*
* @syntax markdown
*/
def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition]
84 changes: 83 additions & 1 deletion library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,33 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* otherwise the can be `Term` containing the `New` applied to the parameters of the extended class.
* @param body List of members of the class. The members must align with the members of `cls`.
*/
// TODO add selfOpt: Option[ValDef]?
@experimental def apply(cls: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): ClassDef
def copy(original: Tree)(name: String, constr: DefDef, parents: List[Tree /* Term | TypeTree */], selfOpt: Option[ValDef], body: List[Statement]): ClassDef
def unapply(cdef: ClassDef): (String, DefDef, List[Tree /* Term | TypeTree */], Option[ValDef], List[Statement])


/** Create the ValDef and ClassDef of a module.
nicolasstucki marked this conversation as resolved.
Show resolved Hide resolved
*
* Equivalent to
* ```
* def module(module: Symbol, parents: List[Tree], body: List[Statement]): (ValDef, ClassDef) =
* val modCls = module.moduleClass
* val modClassDef = ClassDef(modCls, parents, body)
* val modValDef = ValDef(module, Some(Apply(Select(New(TypeIdent(modCls)), cls.primaryConstructor), Nil)))
* List(modValDef, modClassDef)
* ```
*
* @param module the module symbol (of the module lazy val)
nicolasstucki marked this conversation as resolved.
Show resolved Hide resolved
* @param parents parents of the module class
* @param body body of the module class
* @return The module lazy val definition and module class definition.
* These should be added one after the other (in that order) in the body of a class or statements of a block.
*
* @syntax markdown
*/
// TODO add selfOpt: Option[ValDef]?
@experimental def module(module: Symbol, parents: List[Tree /* Term | TypeTree */], body: List[Statement]): (ValDef, ClassDef)
}

/** Makes extension methods on `ClassDef` available without any imports */
Expand Down Expand Up @@ -3638,8 +3662,66 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*/
// TODO: add flags and privateWithin
@experimental def newClass(parent: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol

/** Generates a new module symbol with an associated module class symbol,
* This returns the module symbol. The module class can be accessed calling `moduleClass` on this symbol.
*
* Example usage:
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* val moduleName: String = Symbol.freshName("MyModule")
* val parents = List(TypeTree.of[Object])
* def decls(cls: Symbol): List[Symbol] =
* List(Symbol.newMethod(cls, "run", MethodType(Nil)(_ => Nil, _ => TypeRepr.of[Unit]), Flags.EmptyFlags, Symbol.noSymbol))
*
* val mod = Symbol.newModule(Symbol.spliceOwner, moduleName, Flags.EmptyFlags, Flags.EmptyFlags, parents.map(_.tpe), decls, Symbol.noSymbol)
* val cls = mod.moduleClass
* val runSym = cls.declaredMethod("run").head
*
* val runDef = DefDef(runSym, _ => Some('{ println("run") }.asTerm))
* val modDef = ClassDef.module(mod, parents, body = List(runDef))
*
* val callRun = Apply(Select(Ref(mod), runSym), Nil)
*
* Block(modDef.toList, callRun)
* ```
* constructs the equivalent to
* ```scala
* //{
* given Quotes = ???
* import quotes.reflect._
* //}
* '{
* object MyModule$macro$1 extends Object:
* def run(): Unit = println("run")
* MyModule$macro$1.run()
* }
* ```
*
* @param parent The owner of the class
* @param name The name of the class
* @param modFlags extra flags with which the module symbol should be constructed
* @param clsFlags extra flags with which the module class symbol should be constructed
* @param parents The parent classes of the class. The first parent must not be a trait.
* @param decls A function that takes the symbol of the module class as input and return the symbols of its declared members
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
*
* This symbol starts without an accompanying definition.
* It is the meta-programmer's responsibility to provide exactly one corresponding definition by passing
* this symbol to the ClassDef and ValDef constructor.
nicolasstucki marked this conversation as resolved.
Show resolved Hide resolved
*
* @note As a macro can only splice code into the point at which it is expanded, all generated symbols must be
* direct or indirect children of the reflection context's owner.
*
* @syntax markdown
smarter marked this conversation as resolved.
Show resolved Hide resolved
*/
@experimental def newModule(owner: Symbol, name: String, modFlags: Flags, clsFlags: Flags, parents: List[TypeRepr], decls: Symbol => List[Symbol], privateWithin: Symbol): Symbol

/** Generates a new method symbol with the given parent, name and type.
*
* To define a member method of a class, use the `newMethod` within the `decls` function of `newClass`.
Expand Down Expand Up @@ -4217,7 +4299,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
// FLAGS //
///////////////

/** FlagSet of a Symbol */
/** Flags of a Symbol */
type Flags

/** Module object of `type Flags` */
Expand Down
Loading