Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make object wrapper extend App #2556

Merged
merged 11 commits into from
Nov 20, 2023
23 changes: 9 additions & 14 deletions modules/build/src/main/scala/scala/build/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,18 +110,11 @@ object Build {
logger: Logger
): Either[Seq[String], String] = {
val scriptInferredMainClasses =
sources.inMemory.map(im => im.originalPath.map(_._1))
.flatMap {
case Right(originalRelPath) if originalRelPath.toString.endsWith(".sc") =>
Some {
originalRelPath
.toString
.replace(".", "_")
.replace("/", ".")
}
case Left(VirtualScriptNameRegex(name)) => Some(s"${name}_sc")
case _ => None
}
sources.inMemory.collect {
case Sources.InMemory(_, _, _, Some(wrapperParams)) =>
wrapperParams.mainClass
}

val filteredMainClasses =
mainClasses.filter(!scriptInferredMainClasses.contains(_))
if (filteredMainClasses.length == 1) {
Expand Down Expand Up @@ -279,10 +272,12 @@ object Build {

val scopedSources = value(crossSources.scopedSources(baseOptions))

val mainSources = value(scopedSources.sources(Scope.Main, baseOptions, allInputs.workspace))
val mainSources =
value(scopedSources.sources(Scope.Main, baseOptions, allInputs.workspace, logger))
val mainOptions = mainSources.buildOptions

val testSources = value(scopedSources.sources(Scope.Test, baseOptions, allInputs.workspace))
val testSources =
value(scopedSources.sources(Scope.Test, baseOptions, allInputs.workspace, logger))
val testOptions = testSources.buildOptions

val inputs0 = updateInputs(
Expand Down
15 changes: 6 additions & 9 deletions modules/build/src/main/scala/scala/build/CrossSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import scala.util.chaining.*
final case class CrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
defaultMainClass: Option[String],
defaultMainElemPath: Option[os.Path],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]],
unwrappedScripts: Seq[WithBuildRequirements[Sources.UnwrappedScript]]
Expand Down Expand Up @@ -130,7 +130,7 @@ final case class CrossSources(
ScopedSources(
crossSources0.paths.map(_.scopedValue(defaultScope)),
crossSources0.inMemory.map(_.scopedValue(defaultScope)),
defaultMainClass,
defaultMainElemPath,
crossSources0.resourceDirs.map(_.scopedValue(defaultScope)),
crossSources0.buildOptions.map(_.scopedValue(defaultScope)),
crossSources0.unwrappedScripts.map(_.scopedValue(defaultScope))
Expand Down Expand Up @@ -273,12 +273,9 @@ object CrossSources {
)
}).flatten

val defaultMainClassOpt: Option[String] = for {
mainClassPath <- allInputs.defaultMainClassElement
.map(s => ScopePath.fromPath(s.path).subPath)
processedMainClass <- preprocessedSources.find(_.scopePath.subPath == mainClassPath)
mainClass <- processedMainClass.mainClassOpt
} yield mainClass
val defaultMainElemPath = for {
defaultMainElem <- allInputs.defaultMainClassElement
} yield defaultMainElem.path

val pathsWithDirectivePositions
: Seq[(WithBuildRequirements[(os.Path, os.RelPath)], Option[Position.File])] =
Expand Down Expand Up @@ -369,7 +366,7 @@ object CrossSources {
CrossSources(
paths,
inMemory,
defaultMainClassOpt,
defaultMainElemPath,
resourceDirs,
buildOptions,
unwrappedScripts
Expand Down
22 changes: 20 additions & 2 deletions modules/build/src/main/scala/scala/build/ScopedSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import java.nio.charset.StandardCharsets
import scala.build.EitherCps.{either, value}
import scala.build.errors.BuildException
import scala.build.info.{BuildInfo, ScopedBuildInfo}
import scala.build.internal.AppCodeWrapper
import scala.build.internal.util.WarningMessages
import scala.build.options.{BuildOptions, HasScope, Scope}
import scala.build.preprocessing.ScriptPreprocessor

Expand All @@ -28,7 +30,7 @@ import scala.build.preprocessing.ScriptPreprocessor
final case class ScopedSources(
paths: Seq[HasScope[(os.Path, os.RelPath)]],
inMemory: Seq[HasScope[Sources.InMemory]],
defaultMainClass: Option[String],
defaultMainElemPath: Option[os.Path],
resourceDirs: Seq[HasScope[os.Path]],
buildOptions: Seq[HasScope[BuildOptions]],
unwrappedScripts: Seq[HasScope[Sources.UnwrappedScript]]
Expand All @@ -55,7 +57,8 @@ final case class ScopedSources(
def sources(
scope: Scope,
baseOptions: BuildOptions,
workspace: os.Path
workspace: os.Path,
logger: Logger
): Either[BuildException, Sources] = either {
val combinedOptions = combinedBuildOptions(scope, baseOptions)

Expand All @@ -65,6 +68,21 @@ final case class ScopedSources(
.flatMap(_.valueFor(scope).toSeq)
.map(_.wrap(codeWrapper))

codeWrapper match {
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
case _ => ()
}

val defaultMainClass = defaultMainElemPath.flatMap { mainElemPath =>
wrappedScripts.collectFirst {
case Sources.InMemory(Right((_, path)), _, _, Some(wrapperParams))
if mainElemPath == path =>
wrapperParams.mainClass
}
}

val needsBuildInfo = combinedOptions.sourceGeneratorOptions.useBuildInfo.getOrElse(false)

val maybeBuildInfoSource = if (needsBuildInfo && scope == Scope.Main)
Expand Down
4 changes: 2 additions & 2 deletions modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ final class BspImpl(
pprint.err.log(scopedSources)

val sourcesMain = value {
scopedSources.sources(Scope.Main, sharedOptions, allInputs.workspace)
scopedSources.sources(Scope.Main, sharedOptions, allInputs.workspace, persistentLogger)
.left.map((_, Scope.Main))
}

val sourcesTest = value {
scopedSources.sources(Scope.Test, sharedOptions, allInputs.workspace)
scopedSources.sources(Scope.Test, sharedOptions, allInputs.workspace, persistentLogger)
.left.map((_, Scope.Test))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package scala.build.internal

case object AppCodeWrapper extends CodeWrapper {
override def mainClassObject(className: Name) = className

def apply(
code: String,
pkgName: Seq[Name],
indexedWrapperName: Name,
extraCode: String,
scriptPath: String
) = {
val wrapperObjectName = indexedWrapperName.backticked

val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|object $wrapperObjectName extends App {
|val scriptPath = \"\"\"$scriptPath\"\"\"
|""".stripMargin
)
val bottom = AmmUtil.normalizeNewlines(
s"""
|$extraCode
|}
|""".stripMargin
)

(top, bottom)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ package scala.build.internal
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
def apply(
code: String,
pkgName: Seq[Name],
indexedWrapperName: Name,
extraCode: String,
scriptPath: String
) = {
val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
val name = mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
Expand All @@ -34,29 +37,27 @@ case object ClassCodeWrapper extends CodeWrapper {
| }
|}
|
|export $name.script as ${indexedWrapperName.backticked}
|export $name.script as `${indexedWrapperName.raw}`
|""".stripMargin)

val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"

// indentation is important in the generated code, so we don't want scalafmt to touch that
// format: off
val top = AmmUtil.normalizeNewlines(s"""
$packageDirective


final class $wrapperClassName {
def args = $name.args$$
def scriptPath = \"\"\"$scriptPath\"\"\"
""")
val bottom = AmmUtil.normalizeNewlines(s"""
$extraCode
}

$mainObjectCode
""")
// format: on
val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|final class $wrapperClassName {
|def args = $name.args$$
|def scriptPath = \"\"\"$scriptPath\"\"\"
|""".stripMargin
)
val bottom = AmmUtil.normalizeNewlines(
s"""$extraCode
|}
|
|$mainObjectCode
|""".stripMargin
)

(top, bottom)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ package scala.build.internal
* threads from script
*/
case object ObjectCodeWrapper extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
def apply(
code: String,
pkgName: Seq[Name],
indexedWrapperName: Name,
extraCode: String,
scriptPath: String
) = {
val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
val name = mainClassObject(indexedWrapperName).backticked
val aliasedWrapperName = name + "$$alias"
val funHashCodeMethod =
if (name == "main_sc")
Expand Down Expand Up @@ -46,23 +49,22 @@ case object ObjectCodeWrapper extends CodeWrapper {
|}""".stripMargin
else ""

// indentation is important in the generated code, so we don't want scalafmt to touch that
// format: off
val top = AmmUtil.normalizeNewlines(s"""
$packageDirective

val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|object ${indexedWrapperName.backticked} {
|def args = $name.args$$
|def scriptPath = \"\"\"$scriptPath\"\"\"
|""".stripMargin
)

object ${indexedWrapperName.backticked} {
def args = $name.args$$
def scriptPath = \"\"\"$scriptPath\"\"\"
""")
val bottom = AmmUtil.normalizeNewlines(s"""
$extraCode
}
$aliasObject
$mainObjectCode
""")
// format: on
val bottom = AmmUtil.normalizeNewlines(
s"""$extraCode
|}
|$aliasObject
|$mainObjectCode
|""".stripMargin
)

(top, bottom)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,7 @@ object WarningMessages {
|$recommendedMsg
|""".stripMargin
}

val mainScriptNameClashesWithAppWrapper =
"Script file named 'main.sc' detected, keep in mind that accessing it from other scripts is impossible due to a clash of `main` symbols"
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ sealed abstract class PreprocessedSource extends Product with Serializable {
def scopePath: ScopePath
def directivesPositions: Option[Position.File]
def distinctPathOrSource: String = this match {
case PreprocessedSource.OnDisk(p, _, _, _, _, _, _) => p.toString
case PreprocessedSource.InMemory(op, rp, _, _, _, _, _, _, _, _, _) => s"$op; $rp"
case PreprocessedSource.UnwrappedScript(p, _, _, _, _, _, _, _, _, _) => p.toString
case PreprocessedSource.NoSourceCode(_, _, _, _, p) => p.toString
case p: PreprocessedSource.OnDisk => p.path.toString
case p: PreprocessedSource.InMemory => s"${p.originalPath}; ${p.relPath}"
case p: PreprocessedSource.UnwrappedScript => p.originalPath.toString
case p: PreprocessedSource.NoSourceCode => p.path.toString
}
}

Expand Down Expand Up @@ -58,11 +58,12 @@ object PreprocessedSource {
optionsWithTargetRequirements: List[WithBuildRequirements[BuildOptions]],
requirements: Option[BuildRequirements],
scopedRequirements: Seq[Scoped[BuildRequirements]],
mainClassOpt: Option[String],
scopePath: ScopePath,
directivesPositions: Option[Position.File],
wrapScriptFun: CodeWrapper => (String, WrapperParams)
) extends PreprocessedSource
) extends PreprocessedSource {
override def mainClassOpt: Option[String] = None
}

final case class NoSourceCode(
options: Option[BuildOptions],
Expand Down
Loading