Skip to content

Commit

Permalink
Fix scala#6622: Add code interpolation
Browse files Browse the repository at this point in the history
Allows to get string representations for code passed in the interpolated
values

```scala
inline def logged(p1: => Any) = {
  val c = code"code: $p1"
  val res = p1
  (c, p1)
}
logged(indentity("foo"))
```

is equivalent to:
```scala
("code: indentity("foo")", indentity("foo"))
```
  • Loading branch information
rtfpessoa committed Jun 11, 2019
1 parent 9ca016e commit 7899ae0
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 15 deletions.
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class Definitions {
def Compiletime_constValue(implicit ctx: Context): Symbol = Compiletime_constValueR.symbol
@threadUnsafe lazy val Compiletime_constValueOptR: TermRef = CompiletimePackageObjectRef.symbol.requiredMethodRef("constValueOpt")
def Compiletime_constValueOpt(implicit ctx: Context): Symbol = Compiletime_constValueOptR.symbol
@threadUnsafe lazy val Compiletime_codeR: TermRef = CompiletimePackageObjectRef.symbol.requiredMethodRef("code")
def Compiletime_code(implicit ctx: Context): Symbol = Compiletime_codeR.symbol

/** The `scalaShadowing` package is used to safely modify classes and
* objects in scala so that they can be used from dotty. They will
Expand Down
58 changes: 44 additions & 14 deletions compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
val expansion = inliner.transform(rhsToInline)

def issueError() = callValueArgss match {
case (msgArg :: rest) :: Nil =>
case (msgArg :: Nil) :: Nil =>
msgArg.tpe match {
case ConstantType(Constant(msg: String)) =>
// Usually `error` is called from within a rewrite method. In this
Expand All @@ -482,23 +482,49 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
val callToReport = if (enclosingInlineds.nonEmpty) enclosingInlineds.last else call
val ctxToReport = ctx.outersIterator.dropWhile(enclosingInlineds(_).nonEmpty).next
def issueInCtx(implicit ctx: Context) = {
def decompose(arg: Tree): String = arg match {
case Typed(arg, _) => decompose(arg)
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
case arg =>
arg.tpe.widenTermRefExpr match {
case ConstantType(Constant(c)) => c.toString
case _ => arg.show
}
}
ctx.error(s"$msg${rest.map(decompose).mkString(", ")}", callToReport.sourcePos)
ctx.error(msg, callToReport.sourcePos)
}
issueInCtx(ctxToReport)
case _ =>
}
case _ =>
}

def issueCode()(implicit ctx: Context): Literal = {
def decompose(arg: Tree): String = arg match {
case Typed(arg, _) => decompose(arg)
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
case Block(Nil, expr) => decompose(expr)
case Inlined(_, Nil, expr) => decompose(expr)
case arg =>
arg.tpe.widenTermRefExpr match {
case ConstantType(Constant(c)) => c.toString
case _ => arg.show
}
}

def malformedString(): String = {
ctx.error("Malformed part `code` string interpolator", call.sourcePos)
""
}

callValueArgss match {
case List(List(Apply(_,List(Typed(SeqLiteral(Literal(headConst) :: parts,_),_)))), List(Typed(SeqLiteral(interpolatedParts,_),_)))
if parts.size == interpolatedParts.size =>
val constantParts = parts.map {
case Literal(const) => const.stringValue
case _ => malformedString()
}
val decomposedInterpolations = interpolatedParts.map(decompose)
val constantString = decomposedInterpolations.zip(constantParts)
.foldLeft(headConst.stringValue) { case (acc, (p1, p2)) => acc + p1 + p2 }

Literal(Constant(constantString)).withSpan(call.span)
case _ =>
Literal(Constant(malformedString()))
}
}

trace(i"inlining $call", inlining, show = true) {

// The normalized bindings collected in `bindingsBuf`
Expand All @@ -522,9 +548,13 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {

if (inlinedMethod == defn.Compiletime_error) issueError()

// Take care that only argument bindings go into `bindings`, since positions are
// different for bindings from arguments and bindings from body.
tpd.Inlined(call, finalBindings, finalExpansion)
if (inlinedMethod == defn.Compiletime_code) {
issueCode()(ctx.fresh.setSetting(ctx.settings.color, "never"))
} else {
// Take care that only argument bindings go into `bindings`, since positions are
// different for bindings from arguments and bindings from body.
tpd.Inlined(call, finalBindings, finalExpansion)
}
}
}

Expand Down
18 changes: 17 additions & 1 deletion library/src-3.x/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,23 @@ package object compiletime {

erased def erasedValue[T]: T = ???

inline def error(inline msg: String, objs: Any*): Nothing = ???
inline def error(inline msg: String): Nothing = ???

/** Returns the string representations for code passed in the interpolated values
* ```scala
* inline def logged(p1: => Any) = {
* val c = code"code: $p1"
* val res = p1
* (c, p1)
* }
* logged(indentity("foo"))
* ```
* is equivalent to:
* ```scala
* ("code: indentity("foo")", indentity("foo"))
* ```
*/
inline def (self: => StringContext) code (args: => Any*): String = ???

inline def constValueOpt[T]: Option[T] = ???

Expand Down
9 changes: 9 additions & 0 deletions tests/neg/i6622.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
println(StringContext("abc ", "", "").code(println(34))) // error
}

}
11 changes: 11 additions & 0 deletions tests/neg/i6622a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.compiletime._

object Test {

def nonConstant: String = ""

def main(args: Array[String]): Unit = {
println(StringContext("abc ", nonConstant).code(println(34))) // error
}

}
9 changes: 9 additions & 0 deletions tests/neg/i6622b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
println(StringContext("abc ").code(println(34), 34)) // error
}

}
9 changes: 9 additions & 0 deletions tests/neg/i6622c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
println(StringContext(Seq.empty[String]:_*).code(println(34))) // error
}

}
9 changes: 9 additions & 0 deletions tests/neg/i6622d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
println(StringContext("abc").code(Seq.empty[Any]:_*)) // error
}

}
9 changes: 9 additions & 0 deletions tests/neg/i6622e.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
println(StringContext(Seq.empty[String]:_*).code(Seq.empty[Any]:_*)) // error
}

}
4 changes: 4 additions & 0 deletions tests/neg/i6622f.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/neg/i6622f.scala:6:8 -----------------------------------------------------------------------------------
6 | fail(println("foo")) // error
| ^^^^^^^^^^^^^^^^^^^^
| failed: println("foo") ...
11 changes: 11 additions & 0 deletions tests/neg/i6622f.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
fail(println("foo")) // error
}

inline def fail(p1: => Any) = error(code"failed: $p1 ...")

}
15 changes: 15 additions & 0 deletions tests/run/i6622.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import scala.compiletime._

object Test {

def main(args: Array[String]): Unit = {
assert(code"abc ${println(34)} ..." == "abc println(34) ...")
assert(code"abc ${println(34)}" == "abc println(34)")
assert(code"${println(34)} ..." == "println(34) ...")
assert(code"${println(34)}" == "println(34)")
assert(code"..." == "...")
assert(testConstant(code"") == "")
}

inline def testConstant(inline msg: String): String = msg
}

0 comments on commit 7899ae0

Please sign in to comment.