Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle binding of beta reduced inlined lambdas #16377

Merged
merged 2 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/TreeInfo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -743,8 +743,6 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
Some(meth)
case Block(Nil, expr) =>
unapply(expr)
case Inlined(_, bindings, expr) if bindings.forall(isPureBinding) =>
unapply(expr)
case _ =>
None
}
Expand Down
65 changes: 38 additions & 27 deletions compiler/src/dotty/tools/dotc/inlines/InlineReducer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,35 +158,46 @@ class InlineReducer(inliner: Inliner)(using Context):
*
* where `def` is used for call-by-name parameters. However, we shortcut any NoPrefix
* refs among the ei's directly without creating an intermediate binding.
*
* This variant of beta-reduction preserves the integrity of `Inlined` tree nodes.
*/
def betaReduce(tree: Tree)(using Context): Tree = tree match {
case Apply(Select(cl @ closureDef(ddef), nme.apply), args) if defn.isFunctionType(cl.tpe) =>
// closureDef also returns a result for closures wrapped in Inlined nodes.
// These need to be preserved.
def recur(cl: Tree): Tree = cl match
case Inlined(call, bindings, expr) =>
cpy.Inlined(cl)(call, bindings, recur(expr))
case _ => ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val bindingsBuf = new DefBuffer
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
case Apply(Select(cl, nme.apply), args) if defn.isFunctionType(cl.tpe) =>
val bindingsBuf = new DefBuffer
def recur(cl: Tree): Option[Tree] = cl match
case Block((ddef : DefDef) :: Nil, closure: Closure) if ddef.symbol == closure.meth.symbol =>
ddef.tpe.widen match
case mt: MethodType if ddef.paramss.head.length == args.length =>
val argSyms = mt.paramNames.lazyZip(mt.paramInfos).lazyZip(args).map { (name, paramtp, arg) =>
arg.tpe.dealias match {
case ref @ TermRef(NoPrefix, _) => ref.symbol
case _ =>
paramBindingDef(name, paramtp, arg, bindingsBuf)(
using ctx.withSource(cl.source)
).symbol
}
}
}
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Block(bindingsBuf.toList, expander.transform(ddef.rhs)).withSpan(tree.span)
case _ => tree
recur(cl)
case _ => tree
val expander = new TreeTypeMap(
oldOwners = ddef.symbol :: Nil,
newOwners = ctx.owner :: Nil,
substFrom = ddef.paramss.head.map(_.symbol),
substTo = argSyms)
Some(expander.transform(ddef.rhs))
case _ => None
case Block(stats, expr) if stats.forall(isPureBinding) =>
recur(expr).map(cpy.Block(cl)(stats, _))
case Inlined(call, bindings, expr) if bindings.forall(isPureBinding) =>
recur(expr).map(cpy.Inlined(cl)(call, bindings, _))
case Typed(expr, tpt) =>
recur(expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason we don't preserve the Typed node here?

Suggested change
recur(expr)
recur(expr).map(cpy.Typed(cl)(_, tpt))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is intentional. This ascription is on the lambda and hence has the lambda type, but the result of this transformation has the type of the result of this lambda. Dropping it is the simplest option. We could attempt to add the ascription again on the result. This would be more complicated and would also prevent optimizations on constants. We also do the same in https://github.com/lampepfl/dotty/blob/main/compiler/src/dotty/tools/dotc/transform/BetaReduce.scala#L57-L58.

case _ => None
recur(cl) match
case Some(reduced) =>
Block(bindingsBuf.toList, reduced).withSpan(tree.span)
case None =>
tree
case _ =>
tree
}

/** The result type of reducing a match. It consists optionally of a list of bindings
Expand Down Expand Up @@ -281,7 +292,7 @@ class InlineReducer(inliner: Inliner)(using Context):
// Test case is pos-macros/i15971
val tptBinds = getBinds(Set.empty[TypeSymbol], tpt)
val binds: Set[TypeSymbol] = pat match {
case UnApply(TypeApply(_, tpts), _, _) =>
case UnApply(TypeApply(_, tpts), _, _) =>
getBinds(Set.empty[TypeSymbol], tpts) ++ tptBinds
case _ => tptBinds
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,12 @@ class InlineBytecodeTests extends DottyBytecodeTest {
val instructions = instructionsFromMethod(fun)
val expected = // TODO room for constant folding
List(
Op(ICONST_1),
Op(ICONST_2),
VarOp(ISTORE, 1),
Op(ICONST_1),
VarOp(ISTORE, 2),
Op(ICONST_2),
VarOp(ILOAD, 1),
VarOp(ILOAD, 2),
Op(IADD),
Op(ICONST_3),
Op(IADD),
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/i16374a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def method(using String): String = ???

inline def inlineMethod(inline op: String => Unit)(using String): Unit =
println(op(method))

def test(using String) =
inlineMethod(c => print(c))
9 changes: 9 additions & 0 deletions tests/pos/i16374b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def method(using String): String = ???

inline def identity[T](inline x: T): T = x

inline def inlineMethod(inline op: String => Unit)(using String): Unit =
println(identity(op)(method))

def test(using String) =
inlineMethod(c => print(c))
7 changes: 7 additions & 0 deletions tests/pos/i16374c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
def method(using String): String = ???

inline def inlineMethod(inline op: String => Unit)(using String): Unit =
println({ val a: Int = 1; op }.apply(method))

def test(using String) =
inlineMethod(c => print(c))
4 changes: 4 additions & 0 deletions tests/pos/i16374d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
inline def inline1(inline f: Int => Int): Int => Int = i => f(1)
inline def inline2(inline f: Int => Int): Int = f(2) + 3
def test: Int = inline2(inline1(2.+))