diff --git a/modules/build/src/main/scala/scala/build/Build.scala b/modules/build/src/main/scala/scala/build/Build.scala
index 66424f0998..093252d3c5 100644
--- a/modules/build/src/main/scala/scala/build/Build.scala
+++ b/modules/build/src/main/scala/scala/build/Build.scala
@@ -110,18 +110,11 @@ object Build {
logger: Logger
): Either[Seq[String], String] = {
val scriptInferredMainClasses =
- sources.inMemory.map(im => im.originalPath.map(_._1))
- .flatMap {
- case Right(originalRelPath) if originalRelPath.toString.endsWith(".sc") =>
- Some {
- originalRelPath
- .toString
- .replace(".", "_")
- .replace("/", ".")
- }
- case Left(VirtualScriptNameRegex(name)) => Some(s"${name}_sc")
- case _ => None
- }
+ sources.inMemory.collect {
+ case Sources.InMemory(_, _, _, Some(wrapperParams)) =>
+ wrapperParams.mainClass
+ }
+
val filteredMainClasses =
mainClasses.filter(!scriptInferredMainClasses.contains(_))
if (filteredMainClasses.length == 1) {
@@ -279,10 +272,12 @@ object Build {
val scopedSources = value(crossSources.scopedSources(baseOptions))
- val mainSources = value(scopedSources.sources(Scope.Main, baseOptions, allInputs.workspace))
+ val mainSources =
+ value(scopedSources.sources(Scope.Main, baseOptions, allInputs.workspace, logger))
val mainOptions = mainSources.buildOptions
- val testSources = value(scopedSources.sources(Scope.Test, baseOptions, allInputs.workspace))
+ val testSources =
+ value(scopedSources.sources(Scope.Test, baseOptions, allInputs.workspace, logger))
val testOptions = testSources.buildOptions
val inputs0 = updateInputs(
diff --git a/modules/build/src/main/scala/scala/build/CrossSources.scala b/modules/build/src/main/scala/scala/build/CrossSources.scala
index 62bffee541..f9c917d49b 100644
--- a/modules/build/src/main/scala/scala/build/CrossSources.scala
+++ b/modules/build/src/main/scala/scala/build/CrossSources.scala
@@ -47,7 +47,7 @@ import scala.util.chaining.*
final case class CrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
- defaultMainClass: Option[String],
+ defaultMainElemPath: Option[os.Path],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]],
unwrappedScripts: Seq[WithBuildRequirements[Sources.UnwrappedScript]]
@@ -130,7 +130,7 @@ final case class CrossSources(
ScopedSources(
crossSources0.paths.map(_.scopedValue(defaultScope)),
crossSources0.inMemory.map(_.scopedValue(defaultScope)),
- defaultMainClass,
+ defaultMainElemPath,
crossSources0.resourceDirs.map(_.scopedValue(defaultScope)),
crossSources0.buildOptions.map(_.scopedValue(defaultScope)),
crossSources0.unwrappedScripts.map(_.scopedValue(defaultScope))
@@ -273,12 +273,9 @@ object CrossSources {
)
}).flatten
- val defaultMainClassOpt: Option[String] = for {
- mainClassPath <- allInputs.defaultMainClassElement
- .map(s => ScopePath.fromPath(s.path).subPath)
- processedMainClass <- preprocessedSources.find(_.scopePath.subPath == mainClassPath)
- mainClass <- processedMainClass.mainClassOpt
- } yield mainClass
+ val defaultMainElemPath = for {
+ defaultMainElem <- allInputs.defaultMainClassElement
+ } yield defaultMainElem.path
val pathsWithDirectivePositions
: Seq[(WithBuildRequirements[(os.Path, os.RelPath)], Option[Position.File])] =
@@ -369,7 +366,7 @@ object CrossSources {
CrossSources(
paths,
inMemory,
- defaultMainClassOpt,
+ defaultMainElemPath,
resourceDirs,
buildOptions,
unwrappedScripts
diff --git a/modules/build/src/main/scala/scala/build/ScopedSources.scala b/modules/build/src/main/scala/scala/build/ScopedSources.scala
index 4749912826..75bb29e710 100644
--- a/modules/build/src/main/scala/scala/build/ScopedSources.scala
+++ b/modules/build/src/main/scala/scala/build/ScopedSources.scala
@@ -5,6 +5,8 @@ import java.nio.charset.StandardCharsets
import scala.build.EitherCps.{either, value}
import scala.build.errors.BuildException
import scala.build.info.{BuildInfo, ScopedBuildInfo}
+import scala.build.internal.AppCodeWrapper
+import scala.build.internal.util.WarningMessages
import scala.build.options.{BuildOptions, HasScope, Scope}
import scala.build.preprocessing.ScriptPreprocessor
@@ -28,7 +30,7 @@ import scala.build.preprocessing.ScriptPreprocessor
final case class ScopedSources(
paths: Seq[HasScope[(os.Path, os.RelPath)]],
inMemory: Seq[HasScope[Sources.InMemory]],
- defaultMainClass: Option[String],
+ defaultMainElemPath: Option[os.Path],
resourceDirs: Seq[HasScope[os.Path]],
buildOptions: Seq[HasScope[BuildOptions]],
unwrappedScripts: Seq[HasScope[Sources.UnwrappedScript]]
@@ -55,7 +57,8 @@ final case class ScopedSources(
def sources(
scope: Scope,
baseOptions: BuildOptions,
- workspace: os.Path
+ workspace: os.Path,
+ logger: Logger
): Either[BuildException, Sources] = either {
val combinedOptions = combinedBuildOptions(scope, baseOptions)
@@ -65,6 +68,21 @@ final case class ScopedSources(
.flatMap(_.valueFor(scope).toSeq)
.map(_.wrap(codeWrapper))
+ codeWrapper match {
+ case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
+ wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
+ .foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
+ case _ => ()
+ }
+
+ val defaultMainClass = defaultMainElemPath.flatMap { mainElemPath =>
+ wrappedScripts.collectFirst {
+ case Sources.InMemory(Right((_, path)), _, _, Some(wrapperParams))
+ if mainElemPath == path =>
+ wrapperParams.mainClass
+ }
+ }
+
val needsBuildInfo = combinedOptions.sourceGeneratorOptions.useBuildInfo.getOrElse(false)
val maybeBuildInfoSource = if (needsBuildInfo && scope == Scope.Main)
diff --git a/modules/build/src/main/scala/scala/build/bsp/BspImpl.scala b/modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
index 2ae8977b3c..f208b3e79d 100644
--- a/modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
+++ b/modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
@@ -123,12 +123,12 @@ final class BspImpl(
pprint.err.log(scopedSources)
val sourcesMain = value {
- scopedSources.sources(Scope.Main, sharedOptions, allInputs.workspace)
+ scopedSources.sources(Scope.Main, sharedOptions, allInputs.workspace, persistentLogger)
.left.map((_, Scope.Main))
}
val sourcesTest = value {
- scopedSources.sources(Scope.Test, sharedOptions, allInputs.workspace)
+ scopedSources.sources(Scope.Test, sharedOptions, allInputs.workspace, persistentLogger)
.left.map((_, Scope.Test))
}
diff --git a/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala
new file mode 100644
index 0000000000..c20fd26983
--- /dev/null
+++ b/modules/build/src/main/scala/scala/build/internal/AppCodeWrapper.scala
@@ -0,0 +1,33 @@
+package scala.build.internal
+
+case object AppCodeWrapper extends CodeWrapper {
+ override def mainClassObject(className: Name) = className
+
+ def apply(
+ code: String,
+ pkgName: Seq[Name],
+ indexedWrapperName: Name,
+ extraCode: String,
+ scriptPath: String
+ ) = {
+ val wrapperObjectName = indexedWrapperName.backticked
+
+ val packageDirective =
+ if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
+ val top = AmmUtil.normalizeNewlines(
+ s"""$packageDirective
+ |
+ |object $wrapperObjectName extends App {
+ |val scriptPath = \"\"\"$scriptPath\"\"\"
+ |""".stripMargin
+ )
+ val bottom = AmmUtil.normalizeNewlines(
+ s"""
+ |$extraCode
+ |}
+ |""".stripMargin
+ )
+
+ (top, bottom)
+ }
+}
diff --git a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
index d64d1c5880..8adcd8ad8b 100644
--- a/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
+++ b/modules/build/src/main/scala/scala/build/internal/ClassCodeWrapper.scala
@@ -6,6 +6,9 @@ package scala.build.internal
* Scala 3 feature 'export'
Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {
+
+ override def mainClassObject(className: Name): Name =
+ Name(className.raw ++ "_sc")
def apply(
code: String,
pkgName: Seq[Name],
@@ -13,7 +16,7 @@ case object ClassCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
- val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
+ val name = mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
@@ -34,29 +37,27 @@ case object ClassCodeWrapper extends CodeWrapper {
| }
|}
|
- |export $name.script as ${indexedWrapperName.backticked}
+ |export $name.script as `${indexedWrapperName.raw}`
|""".stripMargin)
val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
- // indentation is important in the generated code, so we don't want scalafmt to touch that
- // format: off
- val top = AmmUtil.normalizeNewlines(s"""
-$packageDirective
-
-
-final class $wrapperClassName {
-def args = $name.args$$
-def scriptPath = \"\"\"$scriptPath\"\"\"
-""")
- val bottom = AmmUtil.normalizeNewlines(s"""
-$extraCode
-}
-
-$mainObjectCode
-""")
- // format: on
+ val top = AmmUtil.normalizeNewlines(
+ s"""$packageDirective
+ |
+ |final class $wrapperClassName {
+ |def args = $name.args$$
+ |def scriptPath = \"\"\"$scriptPath\"\"\"
+ |""".stripMargin
+ )
+ val bottom = AmmUtil.normalizeNewlines(
+ s"""$extraCode
+ |}
+ |
+ |$mainObjectCode
+ |""".stripMargin
+ )
(top, bottom)
}
diff --git a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
index 191d16dace..bbd1f9e9b9 100644
--- a/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
+++ b/modules/build/src/main/scala/scala/build/internal/ObjectCodeWrapper.scala
@@ -5,6 +5,9 @@ package scala.build.internal
* threads from script
*/
case object ObjectCodeWrapper extends CodeWrapper {
+
+ override def mainClassObject(className: Name): Name =
+ Name(className.raw ++ "_sc")
def apply(
code: String,
pkgName: Seq[Name],
@@ -12,7 +15,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
- val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
+ val name = mainClassObject(indexedWrapperName).backticked
val aliasedWrapperName = name + "$$alias"
val funHashCodeMethod =
if (name == "main_sc")
@@ -46,23 +49,22 @@ case object ObjectCodeWrapper extends CodeWrapper {
|}""".stripMargin
else ""
- // indentation is important in the generated code, so we don't want scalafmt to touch that
- // format: off
- val top = AmmUtil.normalizeNewlines(s"""
-$packageDirective
-
+ val top = AmmUtil.normalizeNewlines(
+ s"""$packageDirective
+ |
+ |object ${indexedWrapperName.backticked} {
+ |def args = $name.args$$
+ |def scriptPath = \"\"\"$scriptPath\"\"\"
+ |""".stripMargin
+ )
-object ${indexedWrapperName.backticked} {
-def args = $name.args$$
-def scriptPath = \"\"\"$scriptPath\"\"\"
-""")
- val bottom = AmmUtil.normalizeNewlines(s"""
-$extraCode
-}
-$aliasObject
-$mainObjectCode
-""")
- // format: on
+ val bottom = AmmUtil.normalizeNewlines(
+ s"""$extraCode
+ |}
+ |$aliasObject
+ |$mainObjectCode
+ |""".stripMargin
+ )
(top, bottom)
}
diff --git a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
index 92e0d33e53..c14624b5bd 100644
--- a/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
+++ b/modules/build/src/main/scala/scala/build/internal/util/WarningMessages.scala
@@ -120,4 +120,7 @@ object WarningMessages {
|$recommendedMsg
|""".stripMargin
}
+
+ val mainScriptNameClashesWithAppWrapper =
+ "Script file named 'main.sc' detected, keep in mind that accessing it from other scripts is impossible due to a clash of `main` symbols"
}
diff --git a/modules/build/src/main/scala/scala/build/preprocessing/PreprocessedSource.scala b/modules/build/src/main/scala/scala/build/preprocessing/PreprocessedSource.scala
index 0b1bf9966f..63d644d288 100644
--- a/modules/build/src/main/scala/scala/build/preprocessing/PreprocessedSource.scala
+++ b/modules/build/src/main/scala/scala/build/preprocessing/PreprocessedSource.scala
@@ -13,10 +13,10 @@ sealed abstract class PreprocessedSource extends Product with Serializable {
def scopePath: ScopePath
def directivesPositions: Option[Position.File]
def distinctPathOrSource: String = this match {
- case PreprocessedSource.OnDisk(p, _, _, _, _, _, _) => p.toString
- case PreprocessedSource.InMemory(op, rp, _, _, _, _, _, _, _, _, _) => s"$op; $rp"
- case PreprocessedSource.UnwrappedScript(p, _, _, _, _, _, _, _, _, _) => p.toString
- case PreprocessedSource.NoSourceCode(_, _, _, _, p) => p.toString
+ case p: PreprocessedSource.OnDisk => p.path.toString
+ case p: PreprocessedSource.InMemory => s"${p.originalPath}; ${p.relPath}"
+ case p: PreprocessedSource.UnwrappedScript => p.originalPath.toString
+ case p: PreprocessedSource.NoSourceCode => p.path.toString
}
}
@@ -58,11 +58,12 @@ object PreprocessedSource {
optionsWithTargetRequirements: List[WithBuildRequirements[BuildOptions]],
requirements: Option[BuildRequirements],
scopedRequirements: Seq[Scoped[BuildRequirements]],
- mainClassOpt: Option[String],
scopePath: ScopePath,
directivesPositions: Option[Position.File],
wrapScriptFun: CodeWrapper => (String, WrapperParams)
- ) extends PreprocessedSource
+ ) extends PreprocessedSource {
+ override def mainClassOpt: Option[String] = None
+ }
final case class NoSourceCode(
options: Option[BuildOptions],
diff --git a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
index 89bddc4e72..2603602c76 100644
--- a/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
+++ b/modules/build/src/main/scala/scala/build/preprocessing/ScriptPreprocessor.scala
@@ -6,15 +6,8 @@ import scala.build.EitherCps.{either, value}
import scala.build.Logger
import scala.build.errors.BuildException
import scala.build.input.{Inputs, ScalaCliInvokeData, Script, SingleElement, VirtualScript}
+import scala.build.internal.*
import scala.build.internal.util.WarningMessages
-import scala.build.internal.{
- AmmUtil,
- ClassCodeWrapper,
- CodeWrapper,
- Name,
- ObjectCodeWrapper,
- WrapperParams
-}
import scala.build.options.{BuildOptions, BuildRequirements, Platform, SuppressWarningOptions}
import scala.build.preprocessing.PreprocessedSource
import scala.build.preprocessing.ScalaPreprocessor.ProcessingOutput
@@ -123,7 +116,6 @@ case object ScriptPreprocessor extends Preprocessor {
optionsWithTargetRequirements = processingOutput.optsWithReqs,
requirements = Some(processingOutput.globalReqs),
scopedRequirements = processingOutput.scopedReqs,
- mainClassOpt = Some(CodeWrapper.mainClassObject(Name(className)).backticked),
scopePath = scopePath,
directivesPositions = processingOutput.directivesPositions,
wrapScriptFun = wrapScriptFun
@@ -142,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
(codeWrapper: CodeWrapper) =>
if (containsMainAnnot) logger.diagnostic(
codeWrapper match {
- case _: ObjectCodeWrapper.type =>
+ case _: AppCodeWrapper.type =>
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
}
@@ -165,18 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
* @return
* code wrapper compatible with provided BuildOptions
*/
- def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper =
+ def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
+ val scalaVersionOpt = for {
+ maybeScalaVersion <- buildOptions.scalaOptions.scalaVersion
+ scalaVersion <- maybeScalaVersion.versionOpt
+ } yield scalaVersion
+
+ def objectCodeWrapperForScalaVersion =
+ // AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
+ if (scalaVersionOpt.exists(_.startsWith("2")))
+ AppCodeWrapper
+ else
+ ObjectCodeWrapper
+
buildOptions.scriptOptions.forceObjectWrapper match {
- case Some(true) => ObjectCodeWrapper
+ case Some(true) => objectCodeWrapperForScalaVersion
case _ =>
- val scalaVersionOpt = for {
- maybeScalaVersion <- buildOptions.scalaOptions.scalaVersion
- scalaVersion <- maybeScalaVersion.versionOpt
- } yield scalaVersion
buildOptions.scalaOptions.platform.map(_.value) match {
- case Some(_: Platform.JS.type) => ObjectCodeWrapper
- case _ if scalaVersionOpt.exists(_.startsWith("2")) => ObjectCodeWrapper
+ case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
+ case _ if scalaVersionOpt.exists(_.startsWith("2")) => AppCodeWrapper
case _ => ClassCodeWrapper
}
}
+ }
}
diff --git a/modules/build/src/test/scala/scala/build/tests/BuildTests.scala b/modules/build/src/test/scala/scala/build/tests/BuildTests.scala
index 37cfb49975..77a3d8c3bc 100644
--- a/modules/build/src/test/scala/scala/build/tests/BuildTests.scala
+++ b/modules/build/src/test/scala/scala/build/tests/BuildTests.scala
@@ -13,7 +13,6 @@ import scala.build.errors.{
InvalidBinaryScalaVersionError,
ScalaNativeCompatibilityError
}
-import scala.build.internal.{ClassCodeWrapper, ObjectCodeWrapper}
import scala.build.options.{
BuildOptions,
InternalOptions,
@@ -84,9 +83,8 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
if (checkResults)
maybeBuild.orThrow.assertGeneratedEquals(
"simple.class",
- "simple_sc.class",
"simple$.class",
- "simple_sc$.class"
+ "simple$delayedInit$body.class"
)
}
}
@@ -187,10 +185,9 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
testInputs.withBuild(buildOptions, buildThreads, bloopConfigOpt) { (_, _, maybeBuild) =>
val build = maybeBuild.orThrow
build.assertGeneratedEquals(
- "simple.class",
- "simple_sc.class",
+ "simple$delayedInit$body.class",
"simple$.class",
- "simple_sc$.class",
+ "simple.class",
"META-INF/semanticdb/simple.sc.semanticdb"
)
maybeBuild.orThrow.assertNoDiagnostics
@@ -263,14 +260,12 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
testInputs.withBuild(defaultOptions.enableJs, buildThreads, bloopConfigOpt) {
(_, _, maybeBuild) =>
maybeBuild.orThrow.assertGeneratedEquals(
- "simple.class",
- "simple_sc.class",
- "simple.sjsir",
- "simple$.sjsir",
- "simple_sc.sjsir",
"simple$.class",
- "simple_sc$.class",
- "simple_sc$.sjsir"
+ "simple$.sjsir",
+ "simple$delayedInit$body.class",
+ "simple$delayedInit$body.sjsir",
+ "simple.class",
+ "simple.sjsir"
)
maybeBuild.orThrow.assertNoDiagnostics
}
@@ -288,13 +283,10 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
maybeBuild.orThrow.assertGeneratedEquals(
"simple$.class",
"simple$.nir",
+ "simple$delayedInit$body.class",
+ "simple$delayedInit$body.nir",
"simple.class",
- "simple.nir",
- "simple_sc$$$Lambda$1.nir",
- "simple_sc$.class",
- "simple_sc$.nir",
- "simple_sc.class",
- "simple_sc.nir"
+ "simple.nir"
)
maybeBuild.orThrow.assertNoDiagnostics
}
@@ -316,9 +308,8 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
testInputs.withBuild(defaultOptions, buildThreads, bloopConfigOpt) { (_, _, maybeBuild) =>
maybeBuild.orThrow.assertGeneratedEquals(
"simple.class",
- "simple_sc.class",
"simple$.class",
- "simple_sc$.class"
+ "simple$delayedInit$body.class"
)
maybeBuild.orThrow.assertNoDiagnostics
}
@@ -342,14 +333,12 @@ abstract class BuildTests(server: Boolean) extends munit.FunSuite {
)
testInputs.withBuild(defaultOptions, buildThreads, bloopConfigOpt) { (_, _, maybeBuild) =>
maybeBuild.orThrow.assertGeneratedEquals(
- "simple.class",
- "simple_sc.class",
"simple$.class",
- "simple_sc$.class",
- "simple2.class",
- "simple2_sc.class",
+ "simple$delayedInit$body.class",
+ "simple.class",
"simple2$.class",
- "simple2_sc$.class"
+ "simple2$delayedInit$body.class",
+ "simple2.class"
)
maybeBuild.orThrow.assertNoDiagnostics
}
diff --git a/modules/build/src/test/scala/scala/build/tests/ExcludeTests.scala b/modules/build/src/test/scala/scala/build/tests/ExcludeTests.scala
index 984914152c..139bae2643 100644
--- a/modules/build/src/test/scala/scala/build/tests/ExcludeTests.scala
+++ b/modules/build/src/test/scala/scala/build/tests/ExcludeTests.scala
@@ -97,7 +97,12 @@ class ExcludeTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions())
.orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(sources.paths.nonEmpty)
@@ -126,7 +131,12 @@ class ExcludeTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions())
.orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(sources.paths.nonEmpty)
@@ -155,7 +165,12 @@ class ExcludeTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions())
.orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(sources.paths.nonEmpty)
@@ -184,7 +199,12 @@ class ExcludeTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions())
.orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(sources.paths.nonEmpty)
diff --git a/modules/build/src/test/scala/scala/build/tests/ScriptWrapperTests.scala b/modules/build/src/test/scala/scala/build/tests/ScriptWrapperTests.scala
index 62a2d0714c..640e9352f3 100644
--- a/modules/build/src/test/scala/scala/build/tests/ScriptWrapperTests.scala
+++ b/modules/build/src/test/scala/scala/build/tests/ScriptWrapperTests.scala
@@ -20,6 +20,19 @@ import scala.build.{Build, BuildThreads, Directories, LocalRepo, Position, Posit
class ScriptWrapperTests extends munit.FunSuite {
+ def expectAppWrapper(wrapperName: String, path: os.Path) = {
+ val generatedFileContent = os.read(path)
+ assert(
+ generatedFileContent.contains(s"object $wrapperName extends App {"),
+ clue(s"Generated file content: $generatedFileContent")
+ )
+ assert(
+ !generatedFileContent.contains(s"final class $wrapperName$$_") &&
+ !generatedFileContent.contains(s"object $wrapperName {"),
+ clue(s"Generated file content: $generatedFileContent")
+ )
+ }
+
def expectObjectWrapper(wrapperName: String, path: os.Path) = {
val generatedFileContent = os.read(path)
assert(
@@ -27,7 +40,8 @@ class ScriptWrapperTests extends munit.FunSuite {
clue(s"Generated file content: $generatedFileContent")
)
assert(
- !generatedFileContent.contains(s"final class $wrapperName$$_"),
+ !generatedFileContent.contains(s"final class $wrapperName$$_") &&
+ !generatedFileContent.contains(s"object $wrapperName extends App {"),
clue(s"Generated file content: $generatedFileContent")
)
}
@@ -39,6 +53,7 @@ class ScriptWrapperTests extends munit.FunSuite {
clue(s"Generated file content: $generatedFileContent")
)
assert(
+ !generatedFileContent.contains(s"object $wrapperName extends App {") &&
!generatedFileContent.contains(s"object $wrapperName {"),
clue(s"Generated file content: $generatedFileContent")
)
@@ -116,7 +131,6 @@ class ScriptWrapperTests extends munit.FunSuite {
useDirectives <- Seq(true, false)
(directive, options, optionName) <- Seq(
("//> using object.wrapper", objectWrapperOptions, "--object-wrapper"),
- ("//> using scala 2.13", scala213Options, "--scala 2.13"),
("//> using platform js", platfromJsOptions, "--js")
)
} {
@@ -157,6 +171,49 @@ class ScriptWrapperTests extends munit.FunSuite {
}
}
+ for {
+ useDirectives <- Seq(true, false)
+ (directive, options, optionName) <- Seq(
+ ("//> using scala 2.13", scala213Options, "--scala 2.13")
+ )
+ } {
+ val inputs = TestInputs(
+ os.rel / "script1.sc" ->
+ s"""//> using dep "com.lihaoyi::os-lib:0.9.1"
+ |${if (useDirectives) directive else ""}
+ |
+ |def main(args: String*): Unit = println("Hello")
+ |main()
+ |""".stripMargin,
+ os.rel / "script2.sc" ->
+ """//> using dep "com.lihaoyi::os-lib:0.9.1"
+ |
+ |println("Hello")
+ |""".stripMargin
+ )
+
+ test(
+ s"App object wrapper forced with ${if (useDirectives) directive else optionName}"
+ ) {
+ inputs.withBuild(options orElse baseOptions, buildThreads, bloopConfigOpt) {
+ (root, _, maybeBuild) =>
+ expect(maybeBuild.orThrow.success)
+ val projectDir = os.list(root / ".scala-build").filter(
+ _.baseName.startsWith(root.baseName + "_")
+ )
+ expect(projectDir.size == 1)
+ expectAppWrapper(
+ "script1",
+ projectDir.head / "src_generated" / "main" / "script1.scala"
+ )
+ expectAppWrapper(
+ "script2",
+ projectDir.head / "src_generated" / "main" / "script2.scala"
+ )
+ }
+ }
+ }
+
for {
(targetDirective, enablingDirective) <- Seq(
("target.scala 3.2.2", "scala 3.2.2"),
diff --git a/modules/build/src/test/scala/scala/build/tests/SourcesTests.scala b/modules/build/src/test/scala/scala/build/tests/SourcesTests.scala
index 6394f48271..85ed2271c7 100644
--- a/modules/build/src/test/scala/scala/build/tests/SourcesTests.scala
+++ b/modules/build/src/test/scala/scala/build/tests/SourcesTests.scala
@@ -8,7 +8,6 @@ import dependency.*
import java.nio.charset.StandardCharsets
import scala.build.Ops.*
import scala.build.{CrossSources, Position, Sources}
-import scala.build.internal.ObjectCodeWrapper
import scala.build.errors.{UsingDirectiveValueNumError, UsingDirectiveWrongValueTypeError}
import scala.build.input.ScalaCliInvokeData
import scala.build.options.{BuildOptions, Scope, SuppressWarningOptions}
@@ -66,7 +65,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
val obtainedDeps = sources.buildOptions.classPathOptions.extraDependencies.toSeq.toSeq.map(
@@ -104,7 +108,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -139,7 +148,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -174,7 +188,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(sources.buildOptions.classPathOptions.extraDependencies.toSeq.map(_.value).isEmpty)
@@ -212,7 +231,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -252,7 +276,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -321,7 +350,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
val parsedCodes: Seq[String] =
@@ -358,7 +392,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -397,7 +436,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
expect(
@@ -427,7 +471,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
val javaOpts = sources.buildOptions.javaOptions.javaOpts.toSeq.sortBy(_.toString)
@@ -468,7 +517,12 @@ class SourcesTests extends munit.FunSuite {
val scopedSources = crossSources.scopedSources(BuildOptions()).orThrow
val sources =
- scopedSources.sources(Scope.Main, crossSources.sharedOptions(BuildOptions()), root)
+ scopedSources.sources(
+ Scope.Main,
+ crossSources.sharedOptions(BuildOptions()),
+ root,
+ TestLogger()
+ )
.orThrow
val jsOptions = sources.buildOptions.scalaJsOptions
diff --git a/modules/cli/src/main/scala/scala/cli/commands/dependencyupdate/DependencyUpdate.scala b/modules/cli/src/main/scala/scala/cli/commands/dependencyupdate/DependencyUpdate.scala
index b9ec91c686..a1bd19167d 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/dependencyupdate/DependencyUpdate.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/dependencyupdate/DependencyUpdate.scala
@@ -47,7 +47,8 @@ object DependencyUpdate extends ScalaCommand[DependencyUpdateOptions] {
def generateActionableUpdateDiagnostic(scope: Scope)
: Seq[ActionableDependencyUpdateDiagnostic] = {
- val sources = scopedSources.sources(scope, sharedOptions, inputs.workspace).orExit(logger)
+ val sources =
+ scopedSources.sources(scope, sharedOptions, inputs.workspace, logger).orExit(logger)
if (verbosity >= 3)
pprint.err.log(sources)
diff --git a/modules/cli/src/main/scala/scala/cli/commands/export0/Export.scala b/modules/cli/src/main/scala/scala/cli/commands/export0/Export.scala
index 23a94e1ba3..b8177ffc4a 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/export0/Export.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/export0/Export.scala
@@ -56,7 +56,12 @@ object Export extends ScalaCommand[ExportOptions] {
val scopedSources: ScopedSources = value(crossSources.scopedSources(buildOptions))
val sources: Sources =
- scopedSources.sources(scope, crossSources.sharedOptions(buildOptions), inputs.workspace)
+ scopedSources.sources(
+ scope,
+ crossSources.sharedOptions(buildOptions),
+ inputs.workspace,
+ logger
+ )
.orExit(logger)
if (verbosity >= 3)
diff --git a/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala b/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala
index 3f42d6bb00..bd229e3999 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/fix/Fix.scala
@@ -136,8 +136,8 @@ object Fix extends ScalaCommand[FixOptions] {
val sharedOptions = crossSources.sharedOptions(buildOptions)
val scopedSources = crossSources.scopedSources(sharedOptions).orExit(logger)
- val mainSources = scopedSources.sources(Scope.Main, sharedOptions, inputs.workspace)
- val testSources = scopedSources.sources(Scope.Test, sharedOptions, inputs.workspace)
+ val mainSources = scopedSources.sources(Scope.Main, sharedOptions, inputs.workspace, logger)
+ val testSources = scopedSources.sources(Scope.Test, sharedOptions, inputs.workspace, logger)
(mainSources, testSources).traverseN
}
diff --git a/modules/cli/src/main/scala/scala/cli/commands/publish/PublishSetup.scala b/modules/cli/src/main/scala/scala/cli/commands/publish/PublishSetup.scala
index f02eef6517..a01a47196b 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/publish/PublishSetup.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/publish/PublishSetup.scala
@@ -94,8 +94,9 @@ object PublishSetup extends ScalaCommand[PublishSetupOptions] {
val crossSourcesSharedOptions = crossSources.sharedOptions(cliBuildOptions)
val scopedSources = crossSources.scopedSources(crossSourcesSharedOptions).orExit(logger)
- val sources = scopedSources.sources(Scope.Main, crossSourcesSharedOptions, inputs.workspace)
- .orExit(logger)
+ val sources =
+ scopedSources.sources(Scope.Main, crossSourcesSharedOptions, inputs.workspace, logger)
+ .orExit(logger)
val pureJava = sources.hasJava && !sources.hasScala
diff --git a/modules/cli/src/main/scala/scala/cli/commands/setupide/SetupIde.scala b/modules/cli/src/main/scala/scala/cli/commands/setupide/SetupIde.scala
index 833dd01141..e1552be290 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/setupide/SetupIde.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/setupide/SetupIde.scala
@@ -50,7 +50,8 @@ object SetupIde extends ScalaCommand[SetupIdeOptions] {
val mainSources = value(scopedSources.sources(
Scope.Main,
crossSources.sharedOptions(options),
- allInputs.workspace
+ allInputs.workspace,
+ logger
))
mainSources.buildOptions
diff --git a/modules/cli/src/main/scala/scala/cli/commands/shared/SharedOptions.scala b/modules/cli/src/main/scala/scala/cli/commands/shared/SharedOptions.scala
index b9d480dbbf..7de4834df0 100644
--- a/modules/cli/src/main/scala/scala/cli/commands/shared/SharedOptions.scala
+++ b/modules/cli/src/main/scala/scala/cli/commands/shared/SharedOptions.scala
@@ -24,7 +24,7 @@ import scala.build.interactive.Interactive
import scala.build.interactive.Interactive.{InteractiveAsk, InteractiveNop}
import scala.build.internal.util.ConsoleUtils.ScalaCliConsole
import scala.build.internal.util.WarningMessages
-import scala.build.internal.{Constants, FetchExternalBinary, ObjectCodeWrapper, OsLibc, Util}
+import scala.build.internal.{Constants, FetchExternalBinary, OsLibc, Util}
import scala.build.options.ScalaVersionUtil.fileWithTtl0
import scala.build.options.{BuildOptions, ComputeVersion, Platform, ScalacOpt, ShadowingSeq}
import scala.build.preprocessing.directives.ClasspathUtils.*
diff --git a/modules/core/src/main/scala/scala/build/internals/CodeWrapper.scala b/modules/core/src/main/scala/scala/build/internals/CodeWrapper.scala
index 0c6b21b961..fea1201a2d 100644
--- a/modules/core/src/main/scala/scala/build/internals/CodeWrapper.scala
+++ b/modules/core/src/main/scala/scala/build/internals/CodeWrapper.scala
@@ -1,7 +1,6 @@
package scala.build.internal
abstract class CodeWrapper {
- def wrapperPath: Seq[Name] = Nil
def apply(
code: String,
pkgName: Seq[Name],
@@ -10,6 +9,8 @@ abstract class CodeWrapper {
scriptPath: String
): (String, String)
+ def mainClassObject(className: Name): Name
+
def wrapCode(
pkgName: Seq[Name],
indexedWrapperName: Name,
@@ -30,16 +31,14 @@ abstract class CodeWrapper {
nl + "/**/ /**/" + bottomWrapper
)
- val wrapperParams = WrapperParams(topWrapper0.linesIterator.size, code.linesIterator.size)
+ val mainClassName =
+ (pkgName :+ mainClassObject(indexedWrapperName)).map(_.encoded).mkString(".")
+
+ val wrapperParams =
+ WrapperParams(topWrapper0.linesIterator.size, code.linesIterator.size, mainClassName)
(topWrapper0 + code + bottomWrapper0, wrapperParams)
}
-
-}
-
-object CodeWrapper {
- def mainClassObject(className: Name): Name =
- Name(className.raw ++ "_sc")
}
-case class WrapperParams(topWrapperLineCount: Int, userCodeLineCount: Int)
+case class WrapperParams(topWrapperLineCount: Int, userCodeLineCount: Int, mainClass: String)
diff --git a/modules/integration/src/test/scala/scala/cli/integration/BspTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/BspTestDefinitions.scala
index 5414d95ebf..ba012a46b4 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/BspTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/BspTestDefinitions.scala
@@ -1899,8 +1899,8 @@ abstract class BspTestDefinitions(val scalaVersionOpt: Option[String])
val change = edit.getChanges.asScala.head
val expectedRange = new b.Range(
- new b.Position(11, 19),
- new b.Position(11, 19)
+ new b.Position(9, 19),
+ new b.Position(9, 19)
)
expect(change.getRange == expectedRange)
expect(change.getNewText == "()")
diff --git a/modules/integration/src/test/scala/scala/cli/integration/PackageTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/PackageTestDefinitions.scala
index 934ad354f9..06d2ad060a 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/PackageTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/PackageTestDefinitions.scala
@@ -931,7 +931,13 @@ abstract class PackageTestDefinitions(val scalaVersionOpt: Option[String])
.call(cwd = root)
val output = res.out.trim()
val mainClasses = output.split(" ").toSet
- expect(mainClasses == Set(scalaFile1, scalaFile2, s"$scriptsDir.${scriptName}_sc"))
+
+ val scriptMainClassName = if (actualScalaVersion.startsWith("3"))
+ s"$scriptsDir.${scriptName}_sc"
+ else
+ s"$scriptsDir.$scriptName"
+
+ expect(mainClasses == Set(scalaFile1, scalaFile2, scriptMainClassName))
}
}
diff --git a/modules/integration/src/test/scala/scala/cli/integration/PublishTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/PublishTestDefinitions.scala
index 2c056e0b8f..b00e4f4e32 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/PublishTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/PublishTestDefinitions.scala
@@ -376,7 +376,13 @@ abstract class PublishTestDefinitions(val scalaVersionOpt: Option[String])
val outputLocal = resLocal.out.trim()
expect(output == outputLocal)
val mainClasses = output.linesIterator.toSeq.last.split(" ").toSet
- expect(mainClasses == Set(scalaFile1, scalaFile2, s"$scriptsDir.${scriptName}_sc"))
+
+ val scriptMainClassName = if (actualScalaVersion.startsWith("3"))
+ s"$scriptsDir.${scriptName}_sc"
+ else
+ s"$scriptsDir.$scriptName"
+
+ expect(mainClasses == Set(scalaFile1, scalaFile2, scriptMainClassName))
}
}
diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunScalaJsTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunScalaJsTestDefinitions.scala
index 9fa038866c..b952358e10 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/RunScalaJsTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/RunScalaJsTestDefinitions.scala
@@ -244,7 +244,9 @@ trait RunScalaJsTestDefinitions { _: RunTestDefinitions =>
|""".stripMargin
)
inputs.fromRoot { root =>
- val output = os.proc(TestUtil.cli, extraOptions, "dir", "--js", "--main-class", "print_sc")
+ val mainClassName = if (actualScalaVersion.startsWith("3")) "print_sc" else "print"
+
+ val output = os.proc(TestUtil.cli, extraOptions, "dir", "--js", "--main-class", mainClassName)
.call(cwd = root)
.out.trim()
expect(output == message)
diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunScalaNativeTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunScalaNativeTestDefinitions.scala
index cc03b23860..da5ce6849d 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/RunScalaNativeTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/RunScalaNativeTestDefinitions.scala
@@ -240,7 +240,7 @@ trait RunScalaNativeTestDefinitions { _: RunTestDefinitions =>
)
inputs.fromRoot { root =>
val output =
- os.proc(TestUtil.cli, extraOptions, "dir", "--native", "--main-class", "print_sc", "-q")
+ os.proc(TestUtil.cli, extraOptions, "dir", "--native", "--main-class", "print", "-q")
.call(cwd = root)
.out.trim()
expect(output == message)
diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
index 87b3018f63..27ee5a0f1a 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/RunScriptTestDefinitions.scala
@@ -46,6 +46,9 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
root
).out.trim()
expect(output == message)
+ expect(
+ !output.contains("Script file named 'main.sc' detected, keep in mind that accessing it")
+ )
}
}
@@ -64,23 +67,49 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
}
}
- test("use method from main.sc file") {
- val message = "Hello"
- val inputs = TestInputs(
- os.rel / "message.sc" ->
- s"""println(main.msg)
- |""".stripMargin,
- os.rel / "main.sc" ->
- s"""def msg = "$message"
- |""".stripMargin
- )
- inputs.fromRoot { root =>
- val output = os.proc(TestUtil.cli, extraOptions, "message.sc", "main.sc").call(cwd =
- root
- ).out.trim()
- expect(output == message)
+ if (actualScalaVersion.startsWith("3"))
+ test("use method from main.sc file") {
+ val message = "Hello"
+ val inputs = TestInputs(
+ os.rel / "message.sc" ->
+ s"""println(main.msg)
+ |""".stripMargin,
+ os.rel / "main.sc" ->
+ s"""def msg = "$message"
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val output = os.proc(TestUtil.cli, extraOptions, "message.sc", "main.sc").call(cwd =
+ root
+ ).out.trim()
+ expect(output == message)
+ expect(
+ !output.contains("Script file named 'main.sc' detected, keep in mind that accessing it")
+ )
+ }
+ }
+ else
+ test("warn when main.sc file is used together with other scripts") {
+ val message = "Hello"
+ val inputs = TestInputs(
+ os.rel / "message.sc" ->
+ s"""println(main.msg)
+ |""".stripMargin,
+ os.rel / "main.sc" ->
+ s"""def msg = "$message"
+ |""".stripMargin
+ )
+ inputs.fromRoot { root =>
+ val res = os.proc(TestUtil.cli, extraOptions, "message.sc", "main.sc")
+ .call(cwd = root, check = false, mergeErrIntoOut = true)
+
+ expect(res.exitCode == 1)
+ val output = res.out.trim()
+ expect(
+ output.contains("Script file named 'main.sc' detected, keep in mind that accessing it")
+ )
+ }
}
- }
test("Directory") {
val message = "Hello"
@@ -92,10 +121,13 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
s"""println(messages.msg)
|""".stripMargin
)
+
+ val mainClassName = if (actualScalaVersion.startsWith("3")) "print_sc" else "print"
inputs.fromRoot { root =>
- val output = os.proc(TestUtil.cli, extraOptions, "dir", "--main-class", "print_sc").call(cwd =
- root
- ).out.trim()
+ val output =
+ os.proc(TestUtil.cli, extraOptions, "dir", "--main-class", mainClassName).call(cwd =
+ root
+ ).out.trim()
expect(output == message)
}
}
@@ -180,30 +212,44 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
val tab = "\t"
val expectedLines =
if (actualScalaVersion.startsWith("2.12."))
- s"""Exception in thread "main" java.lang.ExceptionInInitializerError
- |${tab}at throws_sc$$.main(throws.sc:24)
- |${tab}at throws_sc.main(throws.sc)
- |Caused by: java.lang.Exception: Caught exception during processing
- |${tab}at throws$$.(throws.sc:6)
- |${tab}at throws$$.(throws.sc)
- |$tab... 2 more
+ s"""Exception in thread "main" java.lang.Exception: Caught exception during processing
+ |${tab}at throws$$.delayedEndpoint$$throws$$1(throws.sc:6)
+ |${tab}at throws$$delayedInit$$body.apply(throws.sc:65534)
+ |${tab}at scala.Function0.apply$$mcV$$sp(Function0.scala:39)
+ |${tab}at scala.Function0.apply$$mcV$$sp$$(Function0.scala:39)
+ |${tab}at scala.runtime.AbstractFunction0.apply$$mcV$$sp(AbstractFunction0.scala:17)
+ |${tab}at scala.App.$$anonfun$$main$$1$$adapted(App.scala:80)
+ |${tab}at scala.collection.immutable.List.foreach(List.scala:431)
+ |${tab}at scala.App.main(App.scala:80)
+ |${tab}at scala.App.main$$(App.scala:78)
+ |${tab}at throws$$.main(throws.sc:65534)
+ |${tab}at throws.main(throws.sc)
|Caused by: java.lang.RuntimeException: nope
|${tab}at scala.sys.package$$.error(package.scala:30)
|${tab}at throws$$.something(throws.sc:2)
- |${tab}at throws$$.(throws.sc:3)
- |$tab... 3 more""".stripMargin.linesIterator.toVector
+ |${tab}at throws$$.delayedEndpoint$$throws$$1(throws.sc:3)
+ |$tab... 10 more""".stripMargin.linesIterator.toVector
else
- s"""Exception in thread "main" java.lang.ExceptionInInitializerError
- |${tab}at throws_sc$$.main(throws.sc:24)
- |${tab}at throws_sc.main(throws.sc)
- |Caused by: java.lang.Exception: Caught exception during processing
- |${tab}at throws$$.(throws.sc:6)
- |$tab... 2 more
+ s"""Exception in thread "main" java.lang.Exception: Caught exception during processing
+ |${tab}at throws$$.delayedEndpoint$$throws$$1(throws.sc:6)
+ |${tab}at throws$$delayedInit$$body.apply(throws.sc:65534)
+ |${tab}at scala.Function0.apply$$mcV$$sp(Function0.scala:42)
+ |${tab}at scala.Function0.apply$$mcV$$sp$$(Function0.scala:42)
+ |${tab}at scala.runtime.AbstractFunction0.apply$$mcV$$sp(AbstractFunction0.scala:17)
+ |${tab}at scala.App.$$anonfun$$main$$1(App.scala:98)
+ |${tab}at scala.App.$$anonfun$$main$$1$$adapted(App.scala:98)
+ |${tab}at scala.collection.IterableOnceOps.foreach(IterableOnce.scala:576)
+ |${tab}at scala.collection.IterableOnceOps.foreach$$(IterableOnce.scala:574)
+ |${tab}at scala.collection.AbstractIterable.foreach(Iterable.scala:933)
+ |${tab}at scala.App.main(App.scala:98)
+ |${tab}at scala.App.main$$(App.scala:96)
+ |${tab}at throws$$.main(throws.sc:65534)
+ |${tab}at throws.main(throws.sc)
|Caused by: java.lang.RuntimeException: nope
|${tab}at scala.sys.package$$.error(package.scala:27)
|${tab}at throws$$.something(throws.sc:2)
- |${tab}at throws$$.(throws.sc:3)
- |$tab... 2 more
+ |${tab}at throws$$.delayedEndpoint$$throws$$1(throws.sc:3)
+ |$tab... 13 more
|""".stripMargin.linesIterator.toVector
if (exceptionLines != expectedLines) {
println(exceptionLines.mkString("\n"))
@@ -276,10 +322,10 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
val inputs = TestInputs(
os.rel / "Hello.scala" ->
"""object Hello extends App {
- | println(s"Hello ${scripts.Script.world}")
+ | println(s"Hello ${scripts.`Script-1`.world}")
|}
|""".stripMargin,
- os.rel / "scripts" / "Script.sc" -> """def world: String = "world"""".stripMargin
+ os.rel / "scripts" / "Script-1.sc" -> """def world: String = "world"""".stripMargin
)
inputs.fromRoot { root =>
val res = os.proc(
diff --git a/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala b/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala
index 8beb2c8232..6610e08958 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/RunTestDefinitions.scala
@@ -1056,8 +1056,12 @@ abstract class RunTestDefinitions(val scalaVersionOpt: Option[String])
val errorMessage =
output.linesWithSeparators.toSeq.takeRight(6).mkString // dropping compilation logs
val extraOptionsString = extraOptions.mkString(" ")
- val expectedMainClassNames =
- Seq(scalaFile1, scalaFile2, s"$scriptsDir.${scriptName}_sc").sorted
+ val scriptMainClassName = if (actualScalaVersion.startsWith("3"))
+ s"$scriptsDir.${scriptName}_sc"
+ else
+ s"$scriptsDir.$scriptName"
+
+ val expectedMainClassNames = Seq(scalaFile1, scalaFile2, scriptMainClassName).sorted
val expectedErrorMessage =
s"""[${Console.RED}error${Console.RESET}] Found several main classes: ${expectedMainClassNames.mkString(
", "
@@ -1108,7 +1112,13 @@ abstract class RunTestDefinitions(val scalaVersionOpt: Option[String])
.call(cwd = root)
val output = res.out.trim()
val mainClasses = output.split(" ").toSet
- expect(mainClasses == Set(scalaFile1, scalaFile2, s"$scriptsDir.${scriptName}_sc"))
+
+ val scriptMainClassName = if (actualScalaVersion.startsWith("3"))
+ s"$scriptsDir.${scriptName}_sc"
+ else
+ s"$scriptsDir.$scriptName"
+
+ expect(mainClasses == Set(scalaFile1, scalaFile2, scriptMainClassName))
}
}
diff --git a/modules/integration/src/test/scala/scala/cli/integration/ScriptWrapperTests.scala b/modules/integration/src/test/scala/scala/cli/integration/ScriptWrapperTests.scala
index b4f3c061d9..6e1509761d 100644
--- a/modules/integration/src/test/scala/scala/cli/integration/ScriptWrapperTests.scala
+++ b/modules/integration/src/test/scala/scala/cli/integration/ScriptWrapperTests.scala
@@ -6,6 +6,20 @@ import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
class ScriptWrapperTests extends ScalaCliSuite {
+
+ def expectAppWrapper(wrapperName: String, path: os.Path) = {
+ val generatedFileContent = os.read(path)
+ assert(
+ generatedFileContent.contains(s"object $wrapperName extends App {"),
+ clue(s"Generated file content: $generatedFileContent")
+ )
+ assert(
+ !generatedFileContent.contains(s"final class $wrapperName$$_") &&
+ !generatedFileContent.contains(s"object $wrapperName {"),
+ clue(s"Generated file content: $generatedFileContent")
+ )
+ }
+
def expectObjectWrapper(wrapperName: String, path: os.Path) = {
val generatedFileContent = os.read(path)
assert(
@@ -13,7 +27,8 @@ class ScriptWrapperTests extends ScalaCliSuite {
clue(s"Generated file content: $generatedFileContent")
)
assert(
- !generatedFileContent.contains(s"final class $wrapperName$$_"),
+ !generatedFileContent.contains(s"final class $wrapperName$$_") &&
+ !generatedFileContent.contains(s"object $wrapperName wraps App {"),
clue(s"Generated file content: $generatedFileContent")
)
}
@@ -25,6 +40,7 @@ class ScriptWrapperTests extends ScalaCliSuite {
clue(s"Generated file content: $generatedFileContent")
)
assert(
+ !generatedFileContent.contains(s"object $wrapperName extends App {") &&
!generatedFileContent.contains(s"object $wrapperName {"),
clue(s"Generated file content: $generatedFileContent")
)
@@ -85,7 +101,6 @@ class ScriptWrapperTests extends ScalaCliSuite {
useDirectives <- Seq(true, false)
(directive, options) <- Seq(
("//> using object.wrapper", Seq("--object-wrapper")),
- ("//> using scala 2.13", Seq("--scala", "2.13")),
("//> using platform js", Seq("--js"))
)
} {
@@ -147,4 +162,70 @@ class ScriptWrapperTests extends ScalaCliSuite {
}
}
}
+
+ for {
+ useDirectives <- Seq(true, false)
+ (directive, options) <- Seq(
+ ("//> using scala 2.13", Seq("--scala", "2.13"))
+ )
+ } {
+ val inputs = TestInputs(
+ os.rel / "script1.sc" ->
+ s"""//> using platform js
+ |//> using dep "com.lihaoyi::os-lib:0.9.1"
+ |${if (useDirectives) directive else ""}
+ |
+ |def main(args: String*): Unit = println("Hello")
+ |main()
+ |""".stripMargin,
+ os.rel / "script2.sc" ->
+ """//> using dep "com.lihaoyi::os-lib:0.9.1"
+ |
+ |println("Hello")
+ |""".stripMargin
+ )
+
+ test(
+ s"BSP App object wrapper forced with ${if (useDirectives) directive else options.mkString(" ")}"
+ ) {
+ inputs.fromRoot { root =>
+ TestUtil.withThreadPool("script-wrapper-bsp-test", 2) { pool =>
+ val timeout = Duration("60 seconds")
+ implicit val ec = ExecutionContext.fromExecutorService(pool)
+
+ val bspProc = os.proc(
+ TestUtil.cli,
+ "--power",
+ "bsp",
+ "script1.sc",
+ "script2.sc",
+ if (useDirectives) Nil else options
+ )
+ .spawn(cwd = root, mergeErrIntoOut = true, stdout = os.Pipe)
+
+ def lineReaderIter =
+ Iterator.continually(TestUtil.readLine(bspProc.stdout, ec, timeout))
+
+ lineReaderIter.find(_.contains("\"build/taskFinish\""))
+
+ bspProc.destroy()
+ if (bspProc.isAlive())
+ bspProc.destroyForcibly()
+
+ val projectDir = os.list(root / Constants.workspaceDirName).filter(
+ _.baseName.startsWith(root.baseName + "_")
+ )
+ expect(projectDir.size == 1)
+ expectAppWrapper(
+ "script1",
+ projectDir.head / "src_generated" / "main" / "script1.scala"
+ )
+ expectAppWrapper(
+ "script2",
+ projectDir.head / "src_generated" / "main" / "script2.scala"
+ )
+ }
+ }
+ }
+ }
}