From bd0aa525d9773ed2f54d60c772bee8ad2528fd6e Mon Sep 17 00:00:00 2001 From: odersky Date: Sun, 15 Jan 2023 13:18:20 +0100 Subject: [PATCH 1/4] Streamline translation of for expressions - [] Avoid redundant map call if the yielded value is the same as the last result. This makes for expressions more efficient and provides more opportunities for tail recursion. --- .../src/dotty/tools/dotc/ast/Desugar.scala | 59 ++++++++++++------- compiler/src/dotty/tools/dotc/ast/untpd.scala | 3 +- tests/run/fors.check | 3 + tests/run/fors.scala | 14 +++++ 4 files changed, 56 insertions(+), 23 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c360712999e2..df5b7c1501d8 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1807,38 +1807,44 @@ object desugar { * * 1. * - * for (P <- G) E ==> G.foreach (P => E) + * for (P <- G) E ==> G.foreach (P => E) * - * Here and in the following (P => E) is interpreted as the function (P => E) - * if P is a variable pattern and as the partial function { case P => E } otherwise. + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. * * 2. * - * for (P <- G) yield E ==> G.map (P => E) + * for (P <- G) yield P ==> G + * + * if P is a variable or a tuple of variables and G is not a withFilter. + * + * for (P <- G) yield E ==> G.map (P => E) + * + * otherwise * * 3. * - * for (P_1 <- G_1; P_2 <- G_2; ...) ... - * ==> - * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) + * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) * * 4. * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; E; ...) ... + * => + * for (P <- G.filter (P => E); ...) ... * * 5. For any N: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) - * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 - * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * ==> + * for (TupleN(P_1, P_2, ... P_N) <- + * for (x_1 @ P_1 <- G) yield { + * val x_2 @ P_2 = E_2 + * ... + * val x_N & P_N = E_N + * TupleN(x_1, ..., x_N) + * } ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i @@ -1951,7 +1957,7 @@ object desugar { case GenCheckMode.FilterAlways => false // pattern was prefixed by `case` case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr) case GenCheckMode.Check => true - case GenCheckMode.Ignore => true + case GenCheckMode.Ignore | GenCheckMode.Filtered => true /** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when * matched against `rhs`. @@ -1961,9 +1967,18 @@ object desugar { Select(rhs, name) } + def deepEquals(t1: Tree, t2: Tree): Boolean = + (unsplice(t1), unsplice(t2)) match + case (Ident(n1), Ident(n2)) => n1 == n2 + case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) + case _ => false + enums match { case (gen: GenFrom) :: Nil => - Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) @@ -1985,7 +2000,7 @@ object desugar { makeFor(mapName, flatMapName, vfrom1 :: rest1, body) case (gen: GenFrom) :: test :: rest => val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) makeFor(mapName, flatMapName, genFrom :: rest, body) case _ => EmptyTree //may happen for erroneous input diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 81228b1588d0..a3aee4dc17d2 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -183,7 +183,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { /** An enum to control checking or filtering of patterns in GenFrom trees */ enum GenCheckMode { - case Ignore // neither filter nor check since filtering was done before + case Ignore // neither filter since pattern is trivially irrefutable + case Filtered // neither filter nor check since filtering was done before case Check // check that pattern is irrefutable case CheckAndFilter // both check and filter (transitional period starting with 3.2) case FilterNow // filter out non-matching elements if we are not in 3.2 or later diff --git a/tests/run/fors.check b/tests/run/fors.check index 50f6385e5845..7b7e8d076108 100644 --- a/tests/run/fors.check +++ b/tests/run/fors.check @@ -45,6 +45,9 @@ hello world hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4 (2,1) (4,3) +testTailrec +List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c))) + testGivens 123 456 diff --git a/tests/run/fors.scala b/tests/run/fors.scala index 682978b5b3d8..bd7de7d32263 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -4,6 +4,8 @@ //############################################################################ +import annotation.tailrec + object Test extends App { val xs = List(1, 2, 3) val ys = List(Symbol("a"), Symbol("b"), Symbol("c")) @@ -108,6 +110,17 @@ object Test extends App { for case (x, y) <- xs do print(s"${(y, x)} "); println() } + /////////////////// elimination of map /////////////////// + + @tailrec + def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] = + if n == 0 then xs.zip(ys) + else for (x, y) <- pair(xs.map(_ + 1), ys, n - 1) yield (x, y) + + def testTailrec() = + println("\ntestTailrec") + println(pair(xs, ys, 3)) + def testGivens(): Unit = { println("\ntestGivens") @@ -141,5 +154,6 @@ object Test extends App { testOld() testNew() testFiltering() + testTailrec() testGivens() } From 9c3e454f045a5bb886e344e3d4ba2910af5334f3 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 22 Jul 2024 20:08:49 +0200 Subject: [PATCH 2/4] Add improvements to for comprehensions - Allow `for`-comprehensions to start with aliases desugaring them into valdefs in a new block - Desugar aliases into simple valdefs, instead of patterns when they are not followed by a guard - Add an experimental language flag that enables the new desugaring method --- .../src/dotty/tools/dotc/ast/Desugar.scala | 161 +++++++++++++----- .../src/dotty/tools/dotc/config/Feature.scala | 3 + .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotty/tools/dotc/parsing/Parsers.scala | 18 +- .../runtime/stdLibPatches/language.scala | 6 + tests/run/better-fors.check | 12 ++ tests/run/better-fors.scala | 105 ++++++++++++ tests/run/fors.scala | 2 + 8 files changed, 263 insertions(+), 45 deletions(-) create mode 100644 tests/run/better-fors.check create mode 100644 tests/run/better-fors.scala diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index df5b7c1501d8..4231505dce62 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D import typer.{Namer, Checking} import util.{Property, SourceFile, SourcePosition, SrcPos, Chars} import config.{Feature, Config} +import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled} import config.SourceVersion.* import collection.mutable import reporting.* @@ -1807,7 +1808,7 @@ object desugar { * * 1. * - * for (P <- G) E ==> G.foreach (P => E) + * for (P <- G) do E ==> G.foreach (P => E) * * Here and in the following (P => E) is interpreted as the function (P => E) * if P is a variable pattern and as the partial function { case P => E } otherwise. @@ -1816,11 +1817,11 @@ object desugar { * * for (P <- G) yield P ==> G * - * if P is a variable or a tuple of variables and G is not a withFilter. + * If P is a variable or a tuple of variables and G is not a withFilter. * * for (P <- G) yield E ==> G.map (P => E) * - * otherwise + * Otherwise * * 3. * @@ -1830,25 +1831,48 @@ object desugar { * * 4. * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; if E; ...) ... + * ==> + * for (P <- G.withFilter (P => E); ...) ... * * 5. For any N: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * for (P <- G; P_1 = E_1; ... P_N = E_N; rest) * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 + * G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-) + * G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise + * + * 6. For any N: + * + * for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...) + * ==> + * for (TupleN(P, P_1, ... P_N) <- + * for (x @ P <- G) yield { + * val x_1 @ P_1 = E_2 * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * val x_N @ P_N = E_N + * TupleN(x, x_1, ..., x_N) + * }; if E; ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i * + * 7. For any N: + * + * for (P_1 = E_1; ... P_N = E_N; ...) + * ==> + * { + * val x_N @ P_N = E_N + * for (...) + * } + * + * 8. + * for () yield E ==> E + * + * (Where empty for-comprehensions are excluded by the parser) + * + * If the aliases are not followed by a guard, otherwise an error. + * * @param mapName The name to be used for maps (either map or foreach) * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) * @param enums The enumerators in the for expression @@ -1973,37 +1997,86 @@ object desugar { case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) case _ => false - enums match { - case (gen: GenFrom) :: Nil => - if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) - then gen.expr // avoid a redundant map with identity - else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) - case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => - val cont = makeFor(mapName, flatMapName, rest, body) - Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) - case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => - val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) - val pats = valeqs map { case GenAlias(pat, _) => pat } - val rhss = valeqs map { case GenAlias(_, rhs) => rhs } - val (defpat0, id0) = makeIdPat(gen.pat) - val (defpats, ids) = (pats map makeIdPat).unzip - val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => - val mods = defpat match - case defTree: DefTree => defTree.mods - case _ => Modifiers() - makePatDef(valeq, mods, defpat, rhs) - } - val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) - val allpats = gen.pat :: pats - val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) - makeFor(mapName, flatMapName, vfrom1 :: rest1, body) - case (gen: GenFrom) :: test :: rest => - val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) - makeFor(mapName, flatMapName, genFrom :: rest, body) - case _ => - EmptyTree //may happen for erroneous input + if betterForsEnabled then + enums match { + case Nil => body + case (gen: GenFrom) :: Nil => + if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: rest + if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => + val cont = makeFor(mapName, flatMapName, rest, body) + val selectName = + if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName + else mapName + Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => + val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) + val pats = valeqs map { case GenAlias(pat, _) => pat } + val rhss = valeqs map { case GenAlias(_, rhs) => rhs } + val (defpat0, id0) = makeIdPat(gen.pat) + val (defpats, ids) = (pats map makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, vfrom1 :: rest1, body) + case (gen: GenFrom) :: test :: rest => + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) + makeFor(mapName, flatMapName, genFrom :: rest, body) + case GenAlias(_, _) :: _ => + val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias]) + val pats = valeqs.map { case GenAlias(pat, _) => pat } + val rhss = valeqs.map { case GenAlias(_, rhs) => rhs } + val (defpats, ids) = pats.map(makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + Block(pdefs, makeFor(mapName, flatMapName, rest, body)) + case _ => + EmptyTree //may happen for erroneous input + } + else { + enums match { + case (gen: GenFrom) :: Nil => + Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => + val cont = makeFor(mapName, flatMapName, rest, body) + Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => + val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) + val pats = valeqs map { case GenAlias(pat, _) => pat } + val rhss = valeqs map { case GenAlias(_, rhs) => rhs } + val (defpat0, id0) = makeIdPat(gen.pat) + val (defpats, ids) = (pats map makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, vfrom1 :: rest1, body) + case (gen: GenFrom) :: test :: rest => + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, genFrom :: rest, body) + case _ => + EmptyTree //may happen for erroneous input + } } } diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 8c1021e91e38..cad9b4e76ca9 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -38,6 +38,7 @@ object Feature: val modularity = experimental("modularity") val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") + val betterFors = experimental("betterFors") def experimentalAutoEnableFeatures(using Context): List[TermName] = defn.languageExperimentalFeatures @@ -125,6 +126,8 @@ object Feature: def clauseInterleavingEnabled(using Context) = sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving) + def betterForsEnabled(using Context) = enabled(betterFors) + def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals) def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index d3e198a7e7a7..bbe405b46bf1 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -435,6 +435,7 @@ object StdNames { val asInstanceOfPM: N = "$asInstanceOf$" val assert_ : N = "assert" val assume_ : N = "assume" + val betterFors: N = "betterFors" val box: N = "box" val break: N = "break" val build : N = "build" diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 37587868da58..f4a6b5b76aa0 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2891,7 +2891,11 @@ object Parsers { /** Enumerators ::= Generator {semi Enumerator | Guard} */ - def enumerators(): List[Tree] = generator() :: enumeratorsRest() + def enumerators(): List[Tree] = + if in.featureEnabled(Feature.betterFors) then + aliasesUntilGenerator() ++ enumeratorsRest() + else + generator() :: enumeratorsRest() def enumeratorsRest(): List[Tree] = if (isStatSep) { @@ -2933,6 +2937,18 @@ object Parsers { GenFrom(pat, subExpr(), checkMode) } + def aliasesUntilGenerator(): List[Tree] = + if in.token == CASE then generator() :: Nil + else { + val pat = pattern1() + if in.token == EQUALS then + atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: { + if (isStatSep) in.nextToken() + aliasesUntilGenerator() + } + else generatorRest(pat, casePat = false) :: Nil + } + /** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 7db326350fa1..3e8c2ab15cd2 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -133,6 +133,12 @@ object language: @compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements") object quotedPatternsWithPolymorphicFunctions + /** Experimental support for improvements in `for` comprehensions + * + * @see [[https://dotty.epfl.ch/docs/reference/experimental/better-fors]] + */ + @compileTimeOnly("`betterFors` can only be used at compile time in import statements") + object betterFors end experimental /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/tests/run/better-fors.check b/tests/run/better-fors.check new file mode 100644 index 000000000000..8b75db2f56ad --- /dev/null +++ b/tests/run/better-fors.check @@ -0,0 +1,12 @@ +List((1,3), (1,4), (2,3), (2,4)) +List((1,2,3), (1,2,4)) +List((1,3), (1,4), (2,3), (2,4)) +List((2,3), (2,4)) +List((2,3), (2,4)) +List((1,2), (2,4)) +List(1, 2, 3) +List((2,3,6)) +List(6) +List(3, 6) +List(6) +List(2) diff --git a/tests/run/better-fors.scala b/tests/run/better-fors.scala new file mode 100644 index 000000000000..8c0bff230632 --- /dev/null +++ b/tests/run/better-fors.scala @@ -0,0 +1,105 @@ +import scala.language.experimental.betterFors + +def for1 = + for { + a = 1 + b <- List(a, 2) + c <- List(3, 4) + } yield (b, c) + +def for2 = + for + a = 1 + b = 2 + c <- List(3, 4) + yield (a, b, c) + +def for3 = + for { + a = 1 + b <- List(a, 2) + c = 3 + d <- List(c, 4) + } yield (b, d) + +def for4 = + for { + a = 1 + b <- List(a, 2) + if b > 1 + c <- List(3, 4) + } yield (b, c) + +def for5 = + for { + a = 1 + b <- List(a, 2) + c = 3 + if b > 1 + d <- List(c, 4) + } yield (b, d) + +def for6 = + for { + a = 1 + b = 2 + c <- for { + x <- List(a, b) + y = x * 2 + } yield (x, y) + } yield c + +def for7 = + for { + a <- List(1, 2, 3) + } yield a + +def for8 = + for { + a <- List(1, 2) + b = a + 1 + if b > 2 + c = b * 2 + if c < 8 + } yield (a, b, c) + +def for9 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 + } yield a + b + +def for10 = + for { + a <- List(1, 2) + b = a * 2 + } yield a + b + +def for11 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 && b % 2 == 0 + } yield a + b + +def for12 = + for { + a <- List(1, 2) + if a > 1 + } yield a + +object Test extends App { + println(for1) + println(for2) + println(for3) + println(for4) + println(for5) + println(for6) + println(for7) + println(for8) + println(for9) + println(for10) + println(for11) + println(for12) +} diff --git a/tests/run/fors.scala b/tests/run/fors.scala index bd7de7d32263..af04beb311b1 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -112,6 +112,8 @@ object Test extends App { /////////////////// elimination of map /////////////////// + import scala.language.experimental.betterFors + @tailrec def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] = if n == 0 then xs.zip(ys) From 6ef7d8e856cc0acceacab27138613a362b2dc5d6 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Tue, 4 Jun 2024 15:49:09 +0200 Subject: [PATCH 3/4] Cleanup for experimental SIP-62 implementation --- .../src/dotty/tools/dotc/ast/Desugar.scala | 42 ++++++++++++++++++- .../src/dotty/tools/dotc/config/Feature.scala | 3 +- .../runtime/stdLibPatches/language.scala | 2 +- project/MiMaFilters.scala | 3 +- tests/run/fors.scala | 1 + 5 files changed, 47 insertions(+), 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 4231505dce62..30868fac4475 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1804,7 +1804,7 @@ object desugar { /** Create tree for for-comprehension `` or * `` where mapName and flatMapName are chosen * corresponding to whether this is a for-do or a for-yield. - * The creation performs the following rewrite rules: + * If betterFors are enabled, the creation performs the following rewrite rules: * * 1. * @@ -1872,6 +1872,46 @@ object desugar { * (Where empty for-comprehensions are excluded by the parser) * * If the aliases are not followed by a guard, otherwise an error. + * + * With betterFors disabled, the translation is as follows: + * + * 1. + * + * for (P <- G) E ==> G.foreach (P => E) + * + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. + * + * 2. + * + * for (P <- G) yield E ==> G.map (P => E) + * + * 3. + * + * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) + * + * 4. + * + * for (P <- G; E; ...) ... + * => + * for (P <- G.filter (P => E); ...) ... + * + * 5. For any N: + * + * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * ==> + * for (TupleN(P_1, P_2, ... P_N) <- + * for (x_1 @ P_1 <- G) yield { + * val x_2 @ P_2 = E_2 + * ... + * val x_N & P_N = E_N + * TupleN(x_1, ..., x_N) + * } ...) + * + * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated + * and the variable constituting P_i is used instead of x_i * * @param mapName The name to be used for maps (either map or foreach) * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index cad9b4e76ca9..fa82f14a81fe 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -68,7 +68,8 @@ object Feature: (into, "Allow into modifier on parameter types"), (namedTuples, "Allow named tuples"), (modularity, "Enable experimental modularity features"), - (betterMatchTypeExtractors, "Enable better match type extractors") + (betterMatchTypeExtractors, "Enable better match type extractors"), + (betterFors, "Enable improvements in `for` comprehensions") ) // legacy language features from Scala 2 that are no longer supported. diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 3e8c2ab15cd2..3d71c0da1481 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -135,7 +135,7 @@ object language: /** Experimental support for improvements in `for` comprehensions * - * @see [[https://dotty.epfl.ch/docs/reference/experimental/better-fors]] + * @see [[https://github.com/scala/improvement-proposals/pull/79]] */ @compileTimeOnly("`betterFors` can only be used at compile time in import statements") object betterFors diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index bf652cb0ee33..88e3f2b27a84 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -8,7 +8,8 @@ object MiMaFilters { val ForwardsBreakingChanges: Map[String, Seq[ProblemFilter]] = Map( // Additions that require a new minor version of the library Build.mimaPreviousDottyVersion -> Seq( - + ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.betterFors"), + ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$betterFors$"), ), // Additions since last LTS diff --git a/tests/run/fors.scala b/tests/run/fors.scala index af04beb311b1..a12d0e977157 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -6,6 +6,7 @@ import annotation.tailrec +@scala.annotation.experimental object Test extends App { val xs = List(1, 2, 3) val ys = List(Symbol("a"), Symbol("b"), Symbol("c")) From 4bc0a4a51426d493624d170830bbec1dc9503387 Mon Sep 17 00:00:00 2001 From: Kacper Korban Date: Mon, 22 Jul 2024 16:48:08 +0200 Subject: [PATCH 4/4] Merge betterFors desugaring with the default implementation --- .../src/dotty/tools/dotc/ast/Desugar.scala | 255 +++++++----------- compiler/src/dotty/tools/dotc/ast/untpd.scala | 2 +- .../src/dotty/tools/dotc/core/StdNames.scala | 1 - 3 files changed, 98 insertions(+), 160 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 30868fac4475..b892e963ea51 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1806,113 +1806,79 @@ object desugar { * corresponding to whether this is a for-do or a for-yield. * If betterFors are enabled, the creation performs the following rewrite rules: * - * 1. + * 1. if betterFors is enabled: * - * for (P <- G) do E ==> G.foreach (P => E) + * for () do E ==> E + * or + * for () yield E ==> E * - * Here and in the following (P => E) is interpreted as the function (P => E) - * if P is a variable pattern and as the partial function { case P => E } otherwise. + * (Where empty for-comprehensions are excluded by the parser) * * 2. * - * for (P <- G) yield P ==> G - * - * If P is a variable or a tuple of variables and G is not a withFilter. + * for (P <- G) do E ==> G.foreach (P => E) * - * for (P <- G) yield E ==> G.map (P => E) - * - * Otherwise + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. * * 3. * - * for (P_1 <- G_1; P_2 <- G_2; ...) ... - * ==> - * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) - * - * 4. - * - * for (P <- G; if E; ...) ... - * ==> - * for (P <- G.withFilter (P => E); ...) ... - * - * 5. For any N: - * - * for (P <- G; P_1 = E_1; ... P_N = E_N; rest) - * ==> - * G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-) - * G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise + * for (P <- G) yield P ==> G * - * 6. For any N: + * If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter. * - * for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...) - * ==> - * for (TupleN(P, P_1, ... P_N) <- - * for (x @ P <- G) yield { - * val x_1 @ P_1 = E_2 - * ... - * val x_N @ P_N = E_N - * TupleN(x, x_1, ..., x_N) - * }; if E; ...) - * - * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated - * and the variable constituting P_i is used instead of x_i - * - * 7. For any N: - * - * for (P_1 = E_1; ... P_N = E_N; ...) - * ==> - * { - * val x_N @ P_N = E_N - * for (...) - * } - * - * 8. - * for () yield E ==> E - * - * (Where empty for-comprehensions are excluded by the parser) + * for (P <- G) yield E ==> G.map (P => E) * - * If the aliases are not followed by a guard, otherwise an error. - * - * With betterFors disabled, the translation is as follows: - * - * 1. + * Otherwise * - * for (P <- G) E ==> G.foreach (P => E) + * 4. * - * Here and in the following (P => E) is interpreted as the function (P => E) - * if P is a variable pattern and as the partial function { case P => E } otherwise. + * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) * - * 2. + * 5. * - * for (P <- G) yield E ==> G.map (P => E) + * for (P <- G; if E; ...) ... + * ==> + * for (P <- G.withFilter (P => E); ...) ... * - * 3. + * 6. For any N, if betterFors is enabled: * - * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * for (P <- G; P_1 = E_1; ... P_N = E_N; P1 <- G1; ...) ... * ==> - * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) + * G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) * - * 4. + * 7. For any N, if betterFors is enabled: * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; P_1 = E_1; ... P_N = E_N) ... + * ==> + * G.map (P => for (P_1 = E_1; ... P_N = E_N) ...) * - * 5. For any N: + * 8. For any N: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * for (P <- G; P_1 = E_1; ... P_N = E_N; ...) * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 + * for (TupleN(P, P_1, ... P_N) <- + * for (x @ P <- G) yield { + * val x_1 @ P_1 = E_2 * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * val x_N @ P_N = E_N + * TupleN(x, x_1, ..., x_N) + * }; if E; ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i * + * 9. For any N, if betterFors is enabled: + * + * for (P_1 = E_1; ... P_N = E_N; ...) + * ==> + * { + * val x_N @ P_N = E_N + * for (...) + * } + * * @param mapName The name to be used for maps (either map or foreach) * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) * @param enums The enumerators in the for expression @@ -2037,86 +2003,59 @@ object desugar { case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) case _ => false - if betterForsEnabled then - enums match { - case Nil => body - case (gen: GenFrom) :: Nil => - if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type - && deepEquals(gen.pat, body) - then gen.expr // avoid a redundant map with identity - else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) - case (gen: GenFrom) :: rest - if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => - val cont = makeFor(mapName, flatMapName, rest, body) - val selectName = - if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName - else mapName - Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) - case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => - val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) - val pats = valeqs map { case GenAlias(pat, _) => pat } - val rhss = valeqs map { case GenAlias(_, rhs) => rhs } - val (defpat0, id0) = makeIdPat(gen.pat) - val (defpats, ids) = (pats map makeIdPat).unzip - val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => - val mods = defpat match - case defTree: DefTree => defTree.mods - case _ => Modifiers() - makePatDef(valeq, mods, defpat, rhs) - } - val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) - val allpats = gen.pat :: pats - val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) - makeFor(mapName, flatMapName, vfrom1 :: rest1, body) - case (gen: GenFrom) :: test :: rest => - val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) - makeFor(mapName, flatMapName, genFrom :: rest, body) - case GenAlias(_, _) :: _ => - val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias]) - val pats = valeqs.map { case GenAlias(pat, _) => pat } - val rhss = valeqs.map { case GenAlias(_, rhs) => rhs } - val (defpats, ids) = pats.map(makeIdPat).unzip - val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => - val mods = defpat match - case defTree: DefTree => defTree.mods - case _ => Modifiers() - makePatDef(valeq, mods, defpat, rhs) - } - Block(pdefs, makeFor(mapName, flatMapName, rest, body)) - case _ => - EmptyTree //may happen for erroneous input - } - else { - enums match { - case (gen: GenFrom) :: Nil => - Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) - case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => - val cont = makeFor(mapName, flatMapName, rest, body) - Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) - case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => - val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) - val pats = valeqs map { case GenAlias(pat, _) => pat } - val rhss = valeqs map { case GenAlias(_, rhs) => rhs } - val (defpat0, id0) = makeIdPat(gen.pat) - val (defpats, ids) = (pats map makeIdPat).unzip - val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => - val mods = defpat match - case defTree: DefTree => defTree.mods - case _ => Modifiers() - makePatDef(valeq, mods, defpat, rhs) - } - val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) - val allpats = gen.pat :: pats - val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) - makeFor(mapName, flatMapName, vfrom1 :: rest1, body) - case (gen: GenFrom) :: test :: rest => - val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) - makeFor(mapName, flatMapName, genFrom :: rest, body) - case _ => - EmptyTree //may happen for erroneous input - } + enums match { + case Nil if betterForsEnabled => body + case (gen: GenFrom) :: Nil => + if betterForsEnabled + && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => + val cont = makeFor(mapName, flatMapName, rest, body) + Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) + case (gen: GenFrom) :: rest + if betterForsEnabled + && rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => // possible aliases followed by a generator or end of for + val cont = makeFor(mapName, flatMapName, rest, body) + val selectName = + if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName + else mapName + Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) + case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => + val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) + val pats = valeqs map { case GenAlias(pat, _) => pat } + val rhss = valeqs map { case GenAlias(_, rhs) => rhs } + val (defpat0, id0) = makeIdPat(gen.pat) + val (defpats, ids) = (pats map makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids))) + val allpats = gen.pat :: pats + val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore) + makeFor(mapName, flatMapName, vfrom1 :: rest1, body) + case (gen: GenFrom) :: test :: rest => + val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) + val genFrom = GenFrom(gen.pat, filtered, if betterForsEnabled then GenCheckMode.Filtered else GenCheckMode.Ignore) + makeFor(mapName, flatMapName, genFrom :: rest, body) + case GenAlias(_, _) :: _ if betterForsEnabled => + val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias]) + val pats = valeqs.map { case GenAlias(pat, _) => pat } + val rhss = valeqs.map { case GenAlias(_, rhs) => rhs } + val (defpats, ids) = pats.map(makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + Block(pdefs, makeFor(mapName, flatMapName, rest, body)) + case _ => + EmptyTree //may happen for erroneous input } } diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index a3aee4dc17d2..60309d4d83bd 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -183,7 +183,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { /** An enum to control checking or filtering of patterns in GenFrom trees */ enum GenCheckMode { - case Ignore // neither filter since pattern is trivially irrefutable + case Ignore // neither filter nor check since pattern is trivially irrefutable case Filtered // neither filter nor check since filtering was done before case Check // check that pattern is irrefutable case CheckAndFilter // both check and filter (transitional period starting with 3.2) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index bbe405b46bf1..d3e198a7e7a7 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -435,7 +435,6 @@ object StdNames { val asInstanceOfPM: N = "$asInstanceOf$" val assert_ : N = "assert" val assume_ : N = "assume" - val betterFors: N = "betterFors" val box: N = "box" val break: N = "break" val build : N = "build"