mirror of https://github.com/sbt/sbt.git
[2.x] fix: Fixes script mode using Scala 3.x (#8900)
**Problem**
Scripts with scalaVersion 3.x in /*** */ and a shebang fail: -Xscript is
ignored by Scala 3, and the shebang line causes "Expected a toplevel
definition".
**Solution**
- Strip shebang when copying the script so the compiler never sees it.
- For Scala 3 only: do not add -Xscript; generate Main.scala wrapping the
script body in object Main { def main(...) = { ... } }; use it as the
single source and set run/mainClass to Main.
- For Scala 2: keep existing behavior (shebang stripped, -Xscript + script
base name).
- Use Def.uncached and ScalaArtifacts.isScala3(scalaVersion.value) so
embedded scalaVersion from /*** */ is respected.
This commit is contained in:
parent
3f9bafc153
commit
e12b3c9b9a
|
|
@ -184,6 +184,7 @@ final class ScriptMain extends xsbti.AppMain {
|
||||||
private[sbt] object ScriptMain {
|
private[sbt] object ScriptMain {
|
||||||
private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = {
|
private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = {
|
||||||
import BasicCommandStrings.runEarly
|
import BasicCommandStrings.runEarly
|
||||||
|
Plugins.defaultRequires = sbt.plugins.JvmPlugin
|
||||||
val state = StandardMain.initialState(
|
val state = StandardMain.initialState(
|
||||||
xMain.dealiasBaseDirectory(configuration),
|
xMain.dealiasBaseDirectory(configuration),
|
||||||
BuiltinCommands.ScriptCommands,
|
BuiltinCommands.ScriptCommands,
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@
|
||||||
package sbt
|
package sbt
|
||||||
package internal
|
package internal
|
||||||
|
|
||||||
import sbt.librarymanagement.Configurations
|
import sbt.librarymanagement.{ Configurations, ScalaArtifacts }
|
||||||
|
|
||||||
import sbt.util.Level
|
import sbt.util.Level
|
||||||
|
|
||||||
|
|
@ -26,6 +26,37 @@ import scala.annotation.tailrec
|
||||||
|
|
||||||
object Script {
|
object Script {
|
||||||
final val Name = "script"
|
final val Name = "script"
|
||||||
|
// When shebang is stripped, compiler error line numbers may be off by one for the original file;
|
||||||
|
// position mapping could be added in a future improvement (see sbt/sbt#6274).
|
||||||
|
/** If the first line is a shebang (#!), drop it so the compiler never sees it. */
|
||||||
|
private[internal] def stripShebang(lines: Seq[String]): Seq[String] =
|
||||||
|
if (lines.nonEmpty && lines.head.startsWith("#!")) lines.drop(1) else lines
|
||||||
|
|
||||||
|
/** Lines that are not inside any /*** ... */ block (i.e. the executable script body). */
|
||||||
|
private[internal] def scriptBodyLines(file: File): Seq[String] = {
|
||||||
|
val lines = IO.readLines(file).toIndexedSeq
|
||||||
|
// Block(offset, lines): offset = index of /*** line, lines = content between /*** and */ (excl. both).
|
||||||
|
// Exclude /*** (off), content (off+1..off+ls.size), and */ (off+1+ls.size).
|
||||||
|
val blockSet = blocks(file).flatMap {
|
||||||
|
case Block(off, ls) => (off until off + 2 + ls.size)
|
||||||
|
}.toSet
|
||||||
|
lines.indices.filterNot(blockSet).map(lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Write a Scala 3 compilable file that wraps the script body in object Main { def main(...) = { ... } }. */
|
||||||
|
private def writeWrappedScript(body: Seq[String], out: File): Unit = {
|
||||||
|
val indent = " "
|
||||||
|
val inner = body.map(line => indent + line).mkString("\n")
|
||||||
|
val content =
|
||||||
|
s"""object Main {
|
||||||
|
| def main(args: Array[String]): Unit = {
|
||||||
|
|$inner
|
||||||
|
| }
|
||||||
|
|}
|
||||||
|
|""".stripMargin
|
||||||
|
IO.write(out, content)
|
||||||
|
}
|
||||||
|
|
||||||
lazy val command =
|
lazy val command =
|
||||||
Command.command(Name) { state =>
|
Command.command(Name) { state =>
|
||||||
val scriptArg = state.remainingCommands.headOption map { _.commandLine } getOrElse sys.error(
|
val scriptArg = state.remainingCommands.headOption map { _.commandLine } getOrElse sys.error(
|
||||||
|
|
@ -44,7 +75,11 @@ object Script {
|
||||||
else scriptArg.substring(0, dotIndex) + ".scala"
|
else scriptArg.substring(0, dotIndex) + ".scala"
|
||||||
}
|
}
|
||||||
val script = new File(src, scalaFile)
|
val script = new File(src, scalaFile)
|
||||||
IO.copyFile(scriptFile, script)
|
val linesWithoutShebang = stripShebang(IO.readLines(scriptFile))
|
||||||
|
IO.write(script, linesWithoutShebang.mkString("", "\n", "\n"))
|
||||||
|
|
||||||
|
val scriptMain = new File(src, "Main.scala")
|
||||||
|
writeWrappedScript(scriptBodyLines(script), scriptMain)
|
||||||
|
|
||||||
val (eval, structure) = Load.defaultLoad(state, base, state.log)
|
val (eval, structure) = Load.defaultLoad(state, base, state.log)
|
||||||
val session = Load.initialSession(structure, eval)
|
val session = Load.initialSession(structure, eval)
|
||||||
|
|
@ -55,12 +90,23 @@ object Script {
|
||||||
val embeddedSettings = blocks(script).flatMap { block =>
|
val embeddedSettings = blocks(script).flatMap { block =>
|
||||||
evaluate(eval(), vf, block.lines, currentUnit.imports, block.offset + 1)(currentLoader)
|
evaluate(eval(), vf, block.lines, currentUnit.imports, block.offset + 1)(currentLoader)
|
||||||
}
|
}
|
||||||
val scriptAsSource = (Compile / sources) := Def.uncached(script :: Nil)
|
val scriptBaseName = script.getName.stripSuffix(".scala")
|
||||||
val asScript =
|
val scriptAsSource = (Compile / sources) := Def.uncached {
|
||||||
scalacOptions ++= Def.uncached(Seq("-Xscript", script.getName.stripSuffix(".scala")))
|
if (ScalaArtifacts.isScala3(scalaVersion.value)) scriptMain :: Nil else script :: Nil
|
||||||
|
}
|
||||||
|
val asScript = scalacOptions := Def.uncached {
|
||||||
|
val extra =
|
||||||
|
if (ScalaArtifacts.isScala3(scalaVersion.value)) Nil
|
||||||
|
else Seq("-Xscript", scriptBaseName)
|
||||||
|
scalacOptions.value ++ extra
|
||||||
|
}
|
||||||
|
val scriptMainClass = (run / mainClass) := Def.uncached {
|
||||||
|
if (ScalaArtifacts.isScala3(scalaVersion.value)) Some("Main") else Some(scriptBaseName)
|
||||||
|
}
|
||||||
val scriptSettings = Seq(
|
val scriptSettings = Seq(
|
||||||
asScript,
|
asScript,
|
||||||
scriptAsSource,
|
scriptAsSource,
|
||||||
|
scriptMainClass,
|
||||||
(Global / logLevel) := Level.Warn,
|
(Global / logLevel) := Level.Warn,
|
||||||
(Global / showSuccess) := false
|
(Global / showSuccess) := false
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
/*
|
||||||
|
* sbt
|
||||||
|
* Copyright 2023, Scala center
|
||||||
|
* Copyright 2011 - 2022, Lightbend, Inc.
|
||||||
|
* Copyright 2008 - 2010, Mark Harrah
|
||||||
|
* Licensed under Apache License 2.0 (see LICENSE)
|
||||||
|
*/
|
||||||
|
|
||||||
|
package sbt.internal
|
||||||
|
|
||||||
|
import java.io.File
|
||||||
|
import verify.BasicTestSuite
|
||||||
|
import sbt.io.IO
|
||||||
|
|
||||||
|
object ScriptTest extends BasicTestSuite {
|
||||||
|
|
||||||
|
test("stripShebang returns same lines when first line does not start with #!") {
|
||||||
|
val lines = Seq("println(1)", "val x = 2")
|
||||||
|
val result = Script.stripShebang(lines)
|
||||||
|
assert(result == lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("stripShebang drops first line when it is a shebang") {
|
||||||
|
val lines = Seq("#!/usr/bin/env sbt -Dsbt.main.class=sbt.ScriptMain", "println(1)")
|
||||||
|
val result = Script.stripShebang(lines)
|
||||||
|
assert(result == Seq("println(1)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("stripShebang leaves empty list unchanged") {
|
||||||
|
val result = Script.stripShebang(Seq.empty)
|
||||||
|
assert(result.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("stripShebang leaves single non-shebang line unchanged") {
|
||||||
|
val lines = Seq("println(42)")
|
||||||
|
val result = Script.stripShebang(lines)
|
||||||
|
assert(result == lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("stripShebang drops only first line when it is #!") {
|
||||||
|
val lines = Seq("#!", "*/", "println(1)")
|
||||||
|
val result = Script.stripShebang(lines)
|
||||||
|
assert(result == Seq("*/", "println(1)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
test("scriptBodyLines returns all lines when file has no blocks") {
|
||||||
|
val f = File.createTempFile("script", ".scala")
|
||||||
|
try {
|
||||||
|
IO.write(f, "println(1)\nval x = 2\n")
|
||||||
|
val result = Script.scriptBodyLines(f)
|
||||||
|
assert(result.contains("println(1)"))
|
||||||
|
assert(result.contains("val x = 2"))
|
||||||
|
assert(!result.exists(_.startsWith("/***")))
|
||||||
|
} finally f.delete()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("scriptBodyLines excludes lines inside /*** */ block") {
|
||||||
|
val f = File.createTempFile("script", ".scala")
|
||||||
|
try {
|
||||||
|
IO.write(
|
||||||
|
f,
|
||||||
|
"""println("before")
|
||||||
|
|/***
|
||||||
|
|scalaVersion := "3.0.0"
|
||||||
|
|*/
|
||||||
|
|println("after")
|
||||||
|
|""".stripMargin
|
||||||
|
)
|
||||||
|
val result = Script.scriptBodyLines(f)
|
||||||
|
assert(result.contains("println(\"before\")"))
|
||||||
|
assert(result.contains("println(\"after\")"))
|
||||||
|
assert(!result.contains("scalaVersion := \"3.0.0\""))
|
||||||
|
assert(
|
||||||
|
!result.contains("*/"),
|
||||||
|
"closing */ must not appear in script body (would break wrapped Main.scala)"
|
||||||
|
)
|
||||||
|
} finally f.delete()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("scriptBodyLines excludes block content when file has only a block") {
|
||||||
|
val f = File.createTempFile("script", ".scala")
|
||||||
|
try {
|
||||||
|
IO.write(
|
||||||
|
f,
|
||||||
|
"""/***
|
||||||
|
|scalaVersion := "3.0.0"
|
||||||
|
|*/
|
||||||
|
|""".stripMargin
|
||||||
|
)
|
||||||
|
val result = Script.scriptBodyLines(f)
|
||||||
|
assert(
|
||||||
|
!result.contains("scalaVersion := \"3.0.0\""),
|
||||||
|
s"block content must be excluded, got $result"
|
||||||
|
)
|
||||||
|
} finally f.delete()
|
||||||
|
}
|
||||||
|
|
||||||
|
test("blocks parses block containing settings") {
|
||||||
|
val f = File.createTempFile("script", ".scala")
|
||||||
|
try {
|
||||||
|
IO.write(
|
||||||
|
f,
|
||||||
|
"""line0
|
||||||
|
|/***
|
||||||
|
|scalaVersion := "3.0.0"
|
||||||
|
|*/
|
||||||
|
|line3
|
||||||
|
|""".stripMargin
|
||||||
|
)
|
||||||
|
val result = Script.blocks(f)
|
||||||
|
val settingBlock = result.find(_.lines.contains("scalaVersion := \"3.0.0\""))
|
||||||
|
assert(settingBlock.isDefined, s"expected a block with scalaVersion, got $result")
|
||||||
|
} finally f.delete()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -53,5 +53,6 @@ object ArgParser:
|
||||||
.parse(parser, args, LauncherOptions())
|
.parse(parser, args, LauncherOptions())
|
||||||
.map: opts =>
|
.map: opts =>
|
||||||
val sbtNew = opts.residual.contains("new") || opts.residual.contains("init")
|
val sbtNew = opts.residual.contains("new") || opts.residual.contains("init")
|
||||||
opts.copy(sbtNew = sbtNew)
|
val isScript = opts.residual.exists(_.startsWith("-Dsbt.main.class=sbt.ScriptMain"))
|
||||||
|
opts.copy(sbtNew = sbtNew, allowEmpty = opts.allowEmpty || isScript)
|
||||||
end ArgParser
|
end ArgParser
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue