Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support signature polymorphic methods (MethodHandle and VarHandle) #16225

Merged
merged 10 commits into from
Nov 22, 2022
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/backend/jvm/CoreBTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ class CoreBTypes[BTFS <: BTypesFromSymbols[_ <: DottyBackendInterface]](val bTyp

private lazy val jliCallSiteRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.CallSite])
private lazy val jliLambdaMetafactoryRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.LambdaMetafactory])
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandle])
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodHandles.Lookup])
private lazy val jliMethodHandleRef : ClassBType = classBTypeFromSymbol(defn.MethodHandleClass)
private lazy val jliMethodHandlesLookupRef : ClassBType = classBTypeFromSymbol(defn.MethodHandlesLookupClass)
private lazy val jliMethodTypeRef : ClassBType = classBTypeFromSymbol(requiredClass[java.lang.invoke.MethodType])
private lazy val jliStringConcatFactoryRef : ClassBType = classBTypeFromSymbol(requiredClass("java.lang.invoke.StringConcatFactory")) // since JDK 9
private lazy val srLambdaDeserialize : ClassBType = classBTypeFromSymbol(requiredClass[scala.runtime.LambdaDeserialize])
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,10 @@ class Definitions {
}
def JavaEnumType = JavaEnumClass.typeRef

@tu lazy val MethodHandleClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandle")
@tu lazy val MethodHandlesLookupClass: ClassSymbol = requiredClass("java.lang.invoke.MethodHandles.Lookup")
@tu lazy val VarHandleClass: ClassSymbol = requiredClass("java.lang.invoke.VarHandle")

@tu lazy val StringBuilderClass: ClassSymbol = requiredClass("scala.collection.mutable.StringBuilder")
@tu lazy val MatchErrorClass : ClassSymbol = requiredClass("scala.MatchError")
@tu lazy val ConversionClass : ClassSymbol = requiredClass("scala.Conversion").typeRef.symbol.asClass
Expand Down
20 changes: 20 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,26 @@ object SymDenotations {

def isSkolem: Boolean = name == nme.SKOLEM

// Java language spec: https://docs.oracle.com/javase/specs/jls/se11/html/jls-15.html#jls-15.12.3
// Scala 2 spec: https://scala-lang.org/files/archive/spec/2.13/06-expressions.html#signature-polymorphic-methods
def isSignaturePolymorphic(using Context): Boolean =
containsSignaturePolymorphic
&& is(JavaDefined)
&& hasAnnotation(defn.NativeAnnot)
&& atPhase(typerPhase)(symbol.denot).paramSymss.match
case List(List(p)) => p.info.isRepeatedParam
case _ => false

def containsSignaturePolymorphic(using Context): Boolean =
maybeOwner == defn.MethodHandleClass
|| maybeOwner == defn.VarHandleClass

def originalSignaturePolymorphic(using Context): Denotation =
if containsSignaturePolymorphic && !isSignaturePolymorphic then
val d = owner.info.member(name)
if d.symbol.isSignaturePolymorphic then d else NoDenotation
else NoDenotation

def isInlineMethod(using Context): Boolean =
isAllOf(InlineMethod, butNot = Accessor)

Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,13 @@ class TreePickler(pickler: TastyPickler) {
writeByte(THROW)
pickleTree(args.head)
}
else if fun.symbol.originalSignaturePolymorphic.exists then
writeByte(APPLYsigpoly)
withLength {
pickleTree(fun)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't originalSignaturePolymorphic be involved here, somehow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like how? The function Select will be pickled as a SELECTin with the name "invokeExact" and the "in" as MethodHandle, which during unpickling will pick the original method and that's why we have to fix it up.

pickleType(fun.tpe.widenTermRefExpr, richTypes = true) // this widens to a MethodType, so need richTypes
args.foreach(pickleTree)
}
else {
writeByte(APPLY)
withLength {
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,12 @@ class TreeUnpickler(reader: TastyReader,
else tpd.Apply(fn, args)
case TYPEAPPLY =>
tpd.TypeApply(readTerm(), until(end)(readTpt()))
case APPLYsigpoly =>
val fn = readTerm()
val methType = readType()
val args = until(end)(readTerm())
val fun2 = typer.Applications.retypeSignaturePolymorphicFn(fn, methType)
tpd.Apply(fun2, args)
case TYPED =>
val expr = readTerm()
val tpt = readTpt()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ class SemanticSymbolBuilder:
def addOwner(owner: Symbol): Unit =
if !owner.isRoot then addSymName(b, owner)

def addOverloadIdx(sym: Symbol): Unit =
def addOverloadIdx(initSym: Symbol): Unit =
// revert from the compiler-generated overload of the signature polymorphic method
val sym = initSym.originalSignaturePolymorphic.symbol.orElse(initSym)
val decls =
val decls0 = sym.owner.info.decls.lookupAll(sym.name)
if sym.owner.isAllOf(JavaModule) then
Expand Down
4 changes: 3 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/Recheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ abstract class Recheck extends Phase, SymTransformer:
mt.instantiate(argTypes)

def recheckApply(tree: Apply, pt: Type)(using Context): Type =
val funtpe = recheck(tree.fun)
val funTp = recheck(tree.fun)
// reuse the tree's type on signature polymorphic methods, instead of using the (wrong) rechecked one
val funtpe = if tree.fun.symbol.originalSignaturePolymorphic.exists then tree.fun.tpe else funTp
funtpe.widen match
case fntpe: MethodType =>
assert(fntpe.paramInfos.hasSameLengthAs(tree.args))
Expand Down
23 changes: 22 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Inferencing._
import reporting._
import transform.TypeUtils._
import transform.SymUtils._
import Nullables._
import Nullables._, NullOpsDecorator.*
import config.Feature

import collection.mutable
Expand Down Expand Up @@ -340,6 +340,12 @@ object Applications {
val getter = findDefaultGetter(fn, n, testOnly)
if getter.isEmpty then getter
else spliceMeth(getter.withSpan(fn.span), fn)

def retypeSignaturePolymorphicFn(fun: Tree, methType: Type)(using Context): Tree =
val sym1 = fun.symbol
val flags2 = sym1.flags | NonMember // ensures Select typing doesn't let TermRef#withPrefix revert the type
val sym2 = sym1.copy(info = methType, flags = flags2) // symbol not entered, to avoid overload resolution problems
fun.withType(sym2.termRef)
}

trait Applications extends Compatibility {
Expand Down Expand Up @@ -936,6 +942,21 @@ trait Applications extends Compatibility {
/** Type application where arguments come from prototype, and no implicits are inserted */
def simpleApply(fun1: Tree, proto: FunProto)(using Context): Tree =
methPart(fun1).tpe match {
case funRef: TermRef if funRef.symbol.isSignaturePolymorphic =>
// synthesize a method type based on the types at the call site.
// one can imagine the original signature-polymorphic method as
// being infinitely overloaded, with each individual overload only
// being brought into existence as needed
val originalResultType = funRef.symbol.info.resultType.stripNull
val resultType =
if !originalResultType.isRef(defn.ObjectClass) then originalResultType
else AvoidWildcardsMap()(proto.resultType.deepenProtoTrans) match
case SelectionProto(nme.asInstanceOf_, PolyProto(_, resTp), _, _) => resTp
case resTp if isFullyDefined(resTp, ForceDegree.all) => resTp
case _ => defn.ObjectType
val methType = MethodType(proto.typedArgs().map(_.tpe.widen), resultType)
val fun2 = Applications.retypeSignaturePolymorphicFn(fun1, methType)
simpleApply(fun2, proto)
case funRef: TermRef =>
val app = ApplyTo(tree, fun1, funRef, proto, pt)
convertNewGenericArray(
Expand Down
7 changes: 4 additions & 3 deletions project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1816,9 +1816,10 @@ object Build {
settings(disableDocSetting).
settings(
versionScheme := Some("semver-spec"),
if (mode == Bootstrapped) {
commonMiMaSettings
} else {
if (mode == Bootstrapped) Def.settings(
commonMiMaSettings,
mimaBinaryIssueFilters ++= MiMaFilters.TastyCore,
) else {
Nil
}
)
Expand Down
3 changes: 3 additions & 0 deletions project/MiMaFilters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,7 @@ object MiMaFilters {
ProblemFilters.exclude[MissingClassProblem]("scala.caps$Pure"),
ProblemFilters.exclude[MissingClassProblem]("scala.caps$unsafe$"),
)
val TastyCore: Seq[ProblemFilter] = Seq(
ProblemFilters.exclude[MissingMethodProblem]("dotty.tools.tasty.TastyFormat.APPLYsigpoly"),
dwijnand marked this conversation as resolved.
Show resolved Hide resolved
)
}
3 changes: 3 additions & 0 deletions tasty/src/dotty/tools/tasty/TastyFormat.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Standard-Section: "ASTs" TopLevelStat*
THROW throwableExpr_Term -- throw throwableExpr
NAMEDARG paramName_NameRef arg_Term -- paramName = arg
APPLY Length fn_Term arg_Term* -- fn(args)
APPLYsigpoly Length fn_Term meth_Type arg_Term* -- The application of a signature-polymorphic method
TYPEAPPLY Length fn_Term arg_Type* -- fn[args]
SUPER Length this_Term mixinTypeIdent_Tree? -- super[mixin]
TYPED Length expr_Term ascriptionType_Term -- expr: ascription
Expand Down Expand Up @@ -578,6 +579,7 @@ object TastyFormat {
// final val ??? = 178
// final val ??? = 179
final val METHODtype = 180
final val APPLYsigpoly = 181

final val MATCHtype = 190
final val MATCHtpt = 191
Expand Down Expand Up @@ -744,6 +746,7 @@ object TastyFormat {
case BOUNDED => "BOUNDED"
case APPLY => "APPLY"
case TYPEAPPLY => "TYPEAPPLY"
case APPLYsigpoly => "APPLYsigpoly"
case NEW => "NEW"
case THROW => "THROW"
case TYPED => "TYPED"
Expand Down
22 changes: 22 additions & 0 deletions tests/explicit-nulls/run/i11332.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// scalajs: --skip
import scala.language.unsafeNulls

import java.lang.invoke._, MethodType.methodType

// A copy of tests/run/i11332.scala
// to test the bootstrap minimisation which failed
// (because bootstrap runs under explicit nulls)
class Foo:
def neg(x: Int): Int = -x

object Test:
def main(args: Array[String]): Unit =
val l = MethodHandles.lookup()
val self = new Foo()

val res4 = {
l // explicit chain method call - previously derivedSelect broke the type
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 4): Int
}
assert(-4 == res4)
72 changes: 72 additions & 0 deletions tests/run/i11332.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// scalajs: --skip
import scala.language.unsafeNulls

import java.lang.invoke._, MethodType.methodType

class Foo:
def neg(x: Int): Int = -x
def rev(s: String): String = s.reverse
def over(l: Long): String = "long"
def over(i: Int): String = "int"
def unit(s: String): Unit = ()
def obj(s: String): Object = s

object Test:
def main(args: Array[String]): Unit =
val l = MethodHandles.lookup()
val self = new Foo()
val mhNeg = l.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
val mhRev = l.findVirtual(classOf[Foo], "rev", methodType(classOf[String], classOf[String]))
val mhOverL = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Long]))
val mhOverI = l.findVirtual(classOf[Foo], "over", methodType(classOf[String], classOf[Int]))
val mhUnit = l.findVirtual(classOf[Foo], "unit", methodType(classOf[Unit], classOf[String]))
val mhObj = l.findVirtual(classOf[Foo], "obj", methodType(classOf[Any], classOf[String]))
val mhCL = l.findStatic(classOf[ClassLoader], "getPlatformClassLoader", methodType(classOf[ClassLoader]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This broke the CI because some of our jobs are running on Java 8 which does not define this method:
https://github.com/lampepfl/dotty/actions/runs/3524847656/jobs/5910798965

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


assert(-42 == (mhNeg.invokeExact(self, 42): Int))
assert(-33 == (mhNeg.invokeExact(self, 33): Int))

assert("oof" == (mhRev.invokeExact(self, "foo"): String))
assert("rab" == (mhRev.invokeExact(self, "bar"): String))

assert("long" == (mhOverL.invokeExact(self, 1L): String))
assert("int" == (mhOverI.invokeExact(self, 1): String))

assert(-3 == (id(mhNeg.invokeExact(self, 3)): Int))
expectWrongMethod(mhNeg.invokeExact(self, 4))

{ mhUnit.invokeExact(self, "hi"): Unit; () } // explicit block
val hi2: Unit = mhUnit.invokeExact(self, "hi2")
assert((()) == hi2)
def hi3: Unit = mhUnit.invokeExact(self, "hi3")
assert((()) == hi3)

{ mhObj.invokeExact(self, "any"); () } // explicit block
val any2 = mhObj.invokeExact(self, "any2")
assert("any2" == any2)
def any3 = mhObj.invokeExact(self, "any3")
assert("any3" == any3)

assert(null != (mhCL.invoke(): ClassLoader))
assert(null != (mhCL.invoke().asInstanceOf[ClassLoader]: ClassLoader))
assert(null != (mhCL.invokeExact(): ClassLoader))
assert(null != (mhCL.invokeExact().asInstanceOf[ClassLoader]: ClassLoader))

expectWrongMethod {
l // explicit chain method call
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 3)
}
val res4 = {
l // explicit chain method call
.findVirtual(classOf[Foo], "neg", methodType(classOf[Int], classOf[Int]))
.invokeExact(self, 4): Int
}
assert(-4 == res4)

def id[T](x: T): T = x

def expectWrongMethod(op: => Any) = try {
op
throw new AssertionError("expected operation to fail but it didn't")
} catch case expected: WrongMethodTypeException => ()
22 changes: 22 additions & 0 deletions tests/run/t12348.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// test: -jvm 11+
// scalajs: --skip
import java.lang.invoke._
import scala.runtime.IntRef

object Test {
def main(args: Array[String]): Unit = {
val ref = new scala.runtime.IntRef(0)
val varHandle = MethodHandles.lookup()
.in(classOf[IntRef])
.findVarHandle(classOf[IntRef], "elem", classOf[Int])
assert(0 == (varHandle.getAndSet(ref, 1): Int))
assert(1 == (varHandle.getAndSet(ref, 2): Int))
assert(2 == ref.elem)

assert((()) == (varHandle.set(ref, 3): Any))
assert(3 == (varHandle.get(ref): Int))

assert(true == (varHandle.compareAndSet(ref, 3, 4): Any))
assert(4 == (varHandle.get(ref): Int))
}
}