Skip to content

Commit

Permalink
Use Attachment to mark the expression
Browse files Browse the repository at this point in the history
  • Loading branch information
tdudzik authored and adpi2 committed Oct 25, 2021
1 parent 149f007 commit 384bb1a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,25 @@ abstract class ExpressionEvaluatorSuite(scalaVersion: ScalaVersion)
)
}

"should evaluate expression with breakpoint on a field's assignment" - {
val source =
"""class Foo {
| val a = 1
| val b = 2
| def bar() = a + b
|}
|
|object EvaluateTest {
| def main(args: Array[String]): Unit = {
| new Foo()
| }
|}
|""".stripMargin
assertEvaluationInMainClass(source, "EvaluateTest", 3, "a + 2")(
_.exists(_.toInt == 3)
)
}

"should evaluate expression with breakpoint on method definition" - {
val source =
"""class Foo {
Expand Down Expand Up @@ -705,14 +724,14 @@ abstract class ExpressionEvaluatorSuite(scalaVersion: ScalaVersion)
ExpressionEvaluation(
5,
"1 + 1",
_.exists(_ == "\"values are not the same\""),
_.exists(_.toInt == 2),
stoppageNo = 0
),
// evaluating twice because the program stops twice at the same breakpoint...
ExpressionEvaluation(
5,
"1 + 1",
_.exists(_ == "\"values are not the same\""),
_.exists(_.toInt == 2),
stoppageNo = 1
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ private[nsc] class EvalGlobal(
)
}

case object ExpressionAttachment extends PlainAttachment

class InsertExpression extends Transform with TypingTransformers {
override val global: EvalGlobal.this.type = EvalGlobal.this
override val phaseName: String = "insertexpression"
Expand Down Expand Up @@ -176,8 +178,7 @@ private[nsc] class EvalGlobal(

override def transform(tree: Tree): Tree = tree match {
case tree: DefDef if tree.pos.line == line =>
expressionInserted = true
atPos(tree.pos)(
insertAt(tree.pos)(
filterOutTailRec(
treeCopy.DefDef(
tree,
Expand All @@ -193,8 +194,7 @@ private[nsc] class EvalGlobal(
case tree: DefDef =>
super.transform(filterOutTailRec(tree))
case vd: ValDef if vd.pos.line == line =>
expressionInserted = true
atPos(vd.pos)(
insertAt(vd.pos)(
treeCopy.ValDef(
vd,
vd.mods,
Expand All @@ -203,10 +203,8 @@ private[nsc] class EvalGlobal(
mkExprBlock(vd.rhs)
)
)

case tree if tree.pos.line == line =>
expressionInserted = true
atPos(tree.pos)(mkExprBlock(tree))
insertAt(tree.pos)(mkExprBlock(tree))
case tree: PackageDef =>
val transformed = super.transform(tree).asInstanceOf[PackageDef]
if (expressionInserted) {
Expand All @@ -225,11 +223,25 @@ private[nsc] class EvalGlobal(
super.transform(tree)
}

private def mkExprBlock(tree: Tree): Tree =
if (tree.isDef)
Block(List(parsedExpression, tree), Literal(Constant(())))
else
Block(List(parsedExpression), tree)
private def insertAt(pos: Position)(tree: Tree): Tree = {
expressionInserted = true
atPos(pos)(tree)
}

private def mkExprBlock(tree: Tree): Tree = {
val block =
if (tree.isDef)
Block(List(parsedExpression, tree), Literal(Constant(())))
else
Block(List(parsedExpression), tree)
addExpressionAttachment(block)
}

// `ExpressionAttachment` allows to find the inserted expression later on
private def addExpressionAttachment(tree: Tree): Tree = {
val attachments = tree.attachments.update(ExpressionAttachment)
tree.setAttachments(attachments)
}
}
}

Expand Down Expand Up @@ -282,21 +294,26 @@ private[nsc] class EvalGlobal(
// Don't extract expression from the Expression class
case tree: ClassDef if tree.name.decode == expressionClassName =>
// ignore
case tree: DefDef if !expressionExtracted && tree.pos.line == line =>
case tree: DefDef if shouldExtract(tree) =>
expressionOwners = ownerChain(tree)
extractedExpression = extractExpression(tree.rhs)
// default arguments will have an additional method generated, which we need to skip
case tree: ValDef if tree.rhs.isEmpty =>
case tree: ValDef if !expressionExtracted && tree.pos.line == line =>
case tree: ValDef if shouldExtract(tree) =>
expressionOwners = ownerChain(tree)
extractedExpression = extractExpression(tree.rhs)
case _ if !expressionExtracted && tree.pos.line == line =>
case _ if shouldExtract(tree) =>
expressionOwners = ownerChain(tree)
extractedExpression = extractExpression(tree)
case _ =>
super.traverse(tree)
}

private def shouldExtract(tree: Tree): Boolean =
!expressionExtracted && tree.attachments
.get[ExpressionAttachment.type]
.isDefined

private def ownerChain(tree: Tree): List[Symbol] =
if (tree.symbol == null)
currentOwner.ownerChain
Expand Down

0 comments on commit 384bb1a

Please sign in to comment.