Skip to content

Commit

Permalink
Allow to add type parameters to newClass
Browse files Browse the repository at this point in the history
  • Loading branch information
jchyb committed Dec 31, 2024
1 parent bcc08aa commit a839263
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 25 deletions.
114 changes: 99 additions & 15 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.quoted.runtime.impl.printers.*
import scala.reflect.TypeTest
import dotty.tools.dotc.core.NameKinds.ExceptionBinderName
import dotty.tools.dotc.transform.TreeChecker
import dotty.tools.dotc.core.Names

object QuotesImpl {

Expand Down Expand Up @@ -243,15 +244,21 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def apply(cls: Symbol, parents: List[Tree], body: List[Statement]): ClassDef =
val paramsDefs: List[untpd.ParamClause] =
cls.primaryConstructor.paramSymss.map { paramSym =>
paramSym.map( symm =>
ValDef(symm, None)
)
if paramSym.headOption.map(_.isType).getOrElse(false) then
paramSym.map(sym => TypeDef(sym))
else
paramSym.map(ValDef(_, None))
}
val paramsAccessDefs: List[untpd.ParamClause] =
cls.primaryConstructor.paramSymss.map { paramSym =>
paramSym.map( symm =>
ValDef(cls.fieldMember(symm.name.toString()), None) // TODO I don't like the toString here
)
if paramSym.headOption.map(_.isType).getOrElse(false) then
paramSym.map { symm =>
TypeDef(cls.typeMember(symm.name.toString()))
}
else
paramSym.map { symm =>
ValDef(cls.fieldMember(symm.name.toString()), None)// TODO I don't like the toString here
}
}

val termSymbol: dotc.core.Symbols.TermSymbol = cls.primaryConstructor.asTerm
Expand Down Expand Up @@ -2620,10 +2627,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
def requiredMethod(path: String): Symbol = dotc.core.Symbols.requiredMethod(path)
def classSymbol(fullName: String): Symbol = dotc.core.Symbols.requiredClass(fullName)

def newClass(parent: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol =
def newClass(owner: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol =
assert(parents.nonEmpty && !parents.head.typeSymbol.is(dotc.core.Flags.Trait), "First parent must be a class")
val cls = dotc.core.Symbols.newNormalizedClassSymbol(
parent,
owner,
name.toTypeName,
dotc.core.Flags.EmptyFlags,
parents,
Expand All @@ -2633,19 +2640,96 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
for sym <- decls(cls) do cls.enter(sym)
cls

def newClass(parent: Symbol, name: String, parents: Symbol => List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr], paramNames: List[String], paramTypes: List[TypeRepr], flags: Flags, privateWithin: Symbol): Symbol =
checkValidFlags(flags.toTermFlags, Flags.validClassFlags)
assert(!privateWithin.exists || privateWithin.isType, "privateWithin must be a type symbol or `Symbol.noSymbol`")
def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol],
selfType: Option[TypeRepr],
paramNames: List[String],
paramTypes: List[TypeRepr],
clsFlags: Flags,
clsPrivateWithin: Symbol
): Symbol =
checkValidFlags(clsFlags.toTermFlags, Flags.validClassFlags)
assert(paramNames.length == paramTypes.length, "paramNames and paramTypes must have the same length")
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
val cls = dotc.core.Symbols.newNormalizedClassSymbolUsingClassSymbolinParents(
parent,
owner,
name.toTypeName,
flags,
clsFlags,
parents,
selfType.getOrElse(Types.NoType),
privateWithin)
clsPrivateWithin)
cls.enter(dotc.core.Symbols.newConstructor(cls, dotc.core.Flags.Synthetic, paramNames.map(_.toTermName), paramTypes))
for (name, tpe) <- paramNames.zip(paramTypes) do
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor, tpe, Symbol.noSymbol)) // add other flags (local, private, privatelocal) and set privateWithin
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor, tpe, Symbol.noSymbol))
for sym <- decls(cls) do cls.enter(sym)
cls

def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol],
selfType: Option[TypeRepr],
constructorMethodType: TypeRepr => MethodOrPoly,
clsFlags: Flags,
clsPrivateWithin: Symbol,
consFlags: Flags,
consPrivateWithin: Symbol,
consParamFlags: List[List[Flags]]
) =
assert(!clsPrivateWithin.exists || clsPrivateWithin.isType, "clsPrivateWithin must be a type symbol or `Symbol.noSymbol`")
assert(!consPrivateWithin.exists || consPrivateWithin.isType, "consPrivateWithin must be a type symbol or `Symbol.noSymbol`")
checkValidFlags(clsFlags.toTermFlags, Flags.validClassFlags)
val cls = dotc.core.Symbols.newNormalizedClassSymbolUsingClassSymbolinParents(
owner,
name.toTypeName,
clsFlags,
parents,
selfType.getOrElse(Types.NoType),
clsPrivateWithin)
val methodType: MethodOrPoly = constructorMethodType(cls.typeRef)
def throwShapeException() = throw new Exception("Shapes of constructorMethodType and consParamFlags differ.")
def checkMethodOrPolyShape(checkedMethodType: TypeRepr, clauseIdx: Int): Unit =
checkedMethodType match
case PolyType(params, _, res) if clauseIdx == 0 =>
if (consParamFlags.length < clauseIdx) throwShapeException()
if (consParamFlags(clauseIdx).length != params.length) throwShapeException()
checkMethodOrPolyShape(res, clauseIdx + 1)
case PolyType(_, _, _) => throw new Exception("Clause interleaving not supported for constructors")
case MethodType(params, _, res) =>
if (consParamFlags.length < clauseIdx) throwShapeException()
if (consParamFlags(clauseIdx).length != params.length) throwShapeException()
checkMethodOrPolyShape(res, clauseIdx + 1)
case _ =>
checkMethodOrPolyShape(methodType, clauseIdx = 0)
cls.enter(dotc.core.Symbols.newSymbol(cls, nme.CONSTRUCTOR, Flags.Synthetic | Flags.Method | consFlags, methodType, consPrivateWithin, dotty.tools.dotc.util.Spans.NoCoord)) // constructor flags
def getParamAccessors(methodType: TypeRepr, clauseIdx: Int): List[((String, TypeRepr, Boolean, Int), Int)] =
methodType match
case MethodType(paramInfosExp, resultTypeExp, res) =>
paramInfosExp.zip(resultTypeExp).map(_ :* false :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
case pt @ PolyType(paramNames, paramBounds, res) =>
paramNames.zip(paramBounds).map(_ :* true :* clauseIdx).zipWithIndex ++ getParamAccessors(res, clauseIdx + 1)
case result =>
List()
// Maps PolyType indexes to type symbols
val paramRefMap = collection.mutable.HashMap[Int, Symbol]()
val paramRefRemapper = new Types.TypeMap {
def apply(tp: Types.Type) = tp match {
case pRef: ParamRef if pRef.binder == methodType => paramRefMap(pRef.paramNum).typeRef
case _ => mapOver(tp)
}
}
for ((name, tpe, isType, clauseIdx), elementIdx) <- getParamAccessors(methodType, 0) do
if isType then
val symbol = dotc.core.Symbols.newSymbol(cls, name.toTypeName, Flags.Param | Flags.Deferred | consParamFlags(clauseIdx)(elementIdx), tpe, Symbol.noSymbol)
paramRefMap.addOne(elementIdx, symbol)
cls.enter(symbol)
else
val fixedType = paramRefRemapper(tpe)
cls.enter(dotc.core.Symbols.newSymbol(cls, name.toTermName, Flags.ParamAccessor | consParamFlags(clauseIdx)(elementIdx), fixedType, Symbol.noSymbol)) // add other flags (local, private, privatelocal) and set privateWithin
for sym <- decls(cls) do cls.enter(sym)
cls

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1379,13 +1379,13 @@ object SourceCode {
printTypeTree(bounds.low)
else
bounds.low match {
case Inferred() =>
case Inferred() if bounds.low.tpe.typeSymbol == TypeRepr.of[Nothing].typeSymbol =>
case low =>
this += " >: "
printTypeTree(low)
}
bounds.hi match {
case Inferred() => this
case Inferred() if bounds.hi.tpe.typeSymbol == TypeRepr.of[Any].typeSymbol => this
case hi =>
this += " <: "
printTypeTree(hi)
Expand Down
60 changes: 52 additions & 8 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3796,7 +3796,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** The class Symbol of a global class definition */
def classSymbol(fullName: String): Symbol

/** Generates a new class symbol for a class with a parameterless constructor.
/** Generates a new class symbol for a class with a public parameterless constructor.
*
* Example usage:
* ```
Expand Down Expand Up @@ -3824,7 +3824,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
* }
* ```
*
* @param parent The owner of the class
* @param owner The owner of the class
* @param name The name of the class
* @param parents The parent classes of the class. The first parent must not be a trait.
* @param decls The member declarations of the class provided the symbol of this class
Expand All @@ -3840,17 +3840,61 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
// TODO: add flags and privateWithin
@experimental def newClass(owner: Symbol, name: String, parents: List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr]): Symbol

/**
* @param parent declerations of this class provided the symbol of this class.
* Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
/** Generates a new class symbol for a class with a public constructor.
*
* @param owner The owner of the class
* @param name The name of the class
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
* @param paramNames constructor parameter names.
* @param paramTypes constructor parameter types.
* @param flags extra flags with which the class symbol should be constructed.
* @param privateWithin the symbol within which this new method symbol should be private. May be noSymbol.
* @param clsFlags extra flags with which the class symbol should be constructed.
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol.
*
* Parameters can be obtained via classSymbol.memberField
*/
@experimental def newClass(owner: Symbol, name: String, parents: Symbol => List[TypeRepr], decls: Symbol => List[Symbol], selfType: Option[TypeRepr], paramNames: List[String], paramTypes: List[TypeRepr], flags: Flags, privateWithin: Symbol): Symbol
@experimental def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol], selfType: Option[TypeRepr],
paramNames: List[String],
paramTypes: List[TypeRepr],
clsFlags: Flags,
clsPrivateWithin: Symbol
): Symbol

/**
*
*
* @param owner The owner of the class
* @param name The name of the class
* @param parents Function returning the parent classes of the class. The first parent must not be a trait.
* Takes the constructed class symbol as an argument. Calling `cls.typeRef.asType` as part of this function will lead to cyclic reference errors.
* @param decls The member declarations of the class provided the symbol of this class
* @param selfType The self type of the class if it has one
* @param constructorMethodType The MethodOrPoly type representing the type of the constructor.
* PolyType may only represent only the first clause of the constructor.
* @param clsFlags extra flags with which the class symbol should be constructed.
* @param clsPrivateWithin the symbol within which this new class symbol should be private. May be noSymbol
* @param consFlags extra flags with which the constructor symbol should be constructed.
* @param consPrivateWithin the symbol within which the constructor for this new class symbol should be private. May be noSymbol
* @param conParamFlags extra flags with which the constructor parameter symbols should be constructed. Must match the shape of @param constructorMethodType
*
*/
@experimental def newClass(
owner: Symbol,
name: String,
parents: Symbol => List[TypeRepr],
decls: Symbol => List[Symbol],
selfType: Option[TypeRepr],
constructorMethodType: TypeRepr => MethodOrPoly,
clsFlags: Flags,
clsPrivateWithin: Symbol,
consFlags: Flags,
consPrivateWithin: Symbol,
conParamFlags: List[List[Flags]]
): Symbol

/** Generates a new module symbol with an associated module class symbol,
* this is equivalent to an `object` declaration in source code.
Expand Down
9 changes: 9 additions & 0 deletions tests/run-macros/newClassTypeParams.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Test_2$package$foo$1
{
class foo[A, B <: scala.Int](val param1: A, val param2: B) extends java.lang.Object {
type A
type B <: scala.Int
}

(new foo[java.lang.String, scala.Int]("test", 1): scala.Any)
}
45 changes: 45 additions & 0 deletions tests/run-macros/newClassTypeParams/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//> using options -experimental

import scala.quoted.*

transparent inline def makeClass(inline name: String): Any = ${ makeClassExpr('name) }
private def makeClassExpr(nameExpr: Expr[String])(using Quotes): Expr[Any] = {
import quotes.reflect.*

val name = nameExpr.valueOrAbort
def decls(cls: Symbol): List[Symbol] = Nil
val constrType =
(classType: TypeRepr) => PolyType(List("A", "B"))(
_ => List(TypeBounds.empty, TypeBounds.upper(TypeRepr.of[Int])),
polyType => MethodType(List("param1", "param2"))((_: MethodType) => List(polyType.param(0), polyType.param(1)), (_: MethodType) => classType)
)

val cls = Symbol.newClass(
Symbol.spliceOwner,
name,
parents = _ => List(TypeRepr.of[Object]),
decls,
selfType = None,
constrType,
Flags.EmptyFlags,
Symbol.noSymbol,
Flags.EmptyFlags,
Symbol.noSymbol,
List(List(Flags.EmptyFlags, Flags.EmptyFlags), List(Flags.EmptyFlags, Flags.EmptyFlags))
)

val clsDef = ClassDef(cls, List(TypeTree.of[Object]), body = Nil)
val newCls =
cls.typeRef.asType match
case '[t] =>
Typed(Apply(TypeApply(Select(New(TypeIdent(cls)), cls.primaryConstructor), List(TypeTree.of[String], TypeTree.of[Int])), List(Expr("test").asTerm, Expr(1).asTerm)), TypeTree.of[Any])

val res = Block(List(clsDef), newCls).asExpr

Expr.ofTuple(res, Expr(res.show))

// '{
// class `name`[A, B <: Int](param1: A, param2: B)
// new `name`[String, Int]("a", 1)
// }
}
7 changes: 7 additions & 0 deletions tests/run-macros/newClassTypeParams/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//> using options -experimental

@main def Test: Unit = {
val (cls, show) = makeClass("foo")
println(cls.getClass)
println(show)
}

0 comments on commit a839263

Please sign in to comment.