diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 1bb3351e0..0bbc13f88 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -184,6 +184,7 @@ final class ScriptMain extends xsbti.AppMain { private[sbt] object ScriptMain { private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = { import BasicCommandStrings.runEarly + Plugins.defaultRequires = sbt.plugins.JvmPlugin val state = StandardMain.initialState( xMain.dealiasBaseDirectory(configuration), BuiltinCommands.ScriptCommands, diff --git a/main/src/main/scala/sbt/internal/Script.scala b/main/src/main/scala/sbt/internal/Script.scala index a849d4ca0..cc7923808 100644 --- a/main/src/main/scala/sbt/internal/Script.scala +++ b/main/src/main/scala/sbt/internal/Script.scala @@ -9,7 +9,7 @@ package sbt package internal -import sbt.librarymanagement.Configurations +import sbt.librarymanagement.{ Configurations, ScalaArtifacts } import sbt.util.Level @@ -26,6 +26,37 @@ import scala.annotation.tailrec object Script { final val Name = "script" + // When shebang is stripped, compiler error line numbers may be off by one for the original file; + // position mapping could be added in a future improvement (see sbt/sbt#6274). + /** If the first line is a shebang (#!), drop it so the compiler never sees it. */ + private[internal] def stripShebang(lines: Seq[String]): Seq[String] = + if (lines.nonEmpty && lines.head.startsWith("#!")) lines.drop(1) else lines + + /** Lines that are not inside any /*** ... */ block (i.e. the executable script body). */ + private[internal] def scriptBodyLines(file: File): Seq[String] = { + val lines = IO.readLines(file).toIndexedSeq + // Block(offset, lines): offset = index of /*** line, lines = content between /*** and */ (excl. both). + // Exclude /*** (off), content (off+1..off+ls.size), and */ (off+1+ls.size). + val blockSet = blocks(file).flatMap { + case Block(off, ls) => (off until off + 2 + ls.size) + }.toSet + lines.indices.filterNot(blockSet).map(lines) + } + + /** Write a Scala 3 compilable file that wraps the script body in object Main { def main(...) = { ... } }. */ + private def writeWrappedScript(body: Seq[String], out: File): Unit = { + val indent = " " + val inner = body.map(line => indent + line).mkString("\n") + val content = + s"""object Main { + | def main(args: Array[String]): Unit = { + |$inner + | } + |} + |""".stripMargin + IO.write(out, content) + } + lazy val command = Command.command(Name) { state => val scriptArg = state.remainingCommands.headOption map { _.commandLine } getOrElse sys.error( @@ -44,7 +75,11 @@ object Script { else scriptArg.substring(0, dotIndex) + ".scala" } val script = new File(src, scalaFile) - IO.copyFile(scriptFile, script) + val linesWithoutShebang = stripShebang(IO.readLines(scriptFile)) + IO.write(script, linesWithoutShebang.mkString("", "\n", "\n")) + + val scriptMain = new File(src, "Main.scala") + writeWrappedScript(scriptBodyLines(script), scriptMain) val (eval, structure) = Load.defaultLoad(state, base, state.log) val session = Load.initialSession(structure, eval) @@ -55,12 +90,23 @@ object Script { val embeddedSettings = blocks(script).flatMap { block => evaluate(eval(), vf, block.lines, currentUnit.imports, block.offset + 1)(currentLoader) } - val scriptAsSource = (Compile / sources) := Def.uncached(script :: Nil) - val asScript = - scalacOptions ++= Def.uncached(Seq("-Xscript", script.getName.stripSuffix(".scala"))) + val scriptBaseName = script.getName.stripSuffix(".scala") + val scriptAsSource = (Compile / sources) := Def.uncached { + if (ScalaArtifacts.isScala3(scalaVersion.value)) scriptMain :: Nil else script :: Nil + } + val asScript = scalacOptions := Def.uncached { + val extra = + if (ScalaArtifacts.isScala3(scalaVersion.value)) Nil + else Seq("-Xscript", scriptBaseName) + scalacOptions.value ++ extra + } + val scriptMainClass = (run / mainClass) := Def.uncached { + if (ScalaArtifacts.isScala3(scalaVersion.value)) Some("Main") else Some(scriptBaseName) + } val scriptSettings = Seq( asScript, scriptAsSource, + scriptMainClass, (Global / logLevel) := Level.Warn, (Global / showSuccess) := false ) diff --git a/main/src/test/scala/sbt/internal/ScriptTest.scala b/main/src/test/scala/sbt/internal/ScriptTest.scala new file mode 100644 index 000000000..c6dcc972e --- /dev/null +++ b/main/src/test/scala/sbt/internal/ScriptTest.scala @@ -0,0 +1,115 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal + +import java.io.File +import verify.BasicTestSuite +import sbt.io.IO + +object ScriptTest extends BasicTestSuite { + + test("stripShebang returns same lines when first line does not start with #!") { + val lines = Seq("println(1)", "val x = 2") + val result = Script.stripShebang(lines) + assert(result == lines) + } + + test("stripShebang drops first line when it is a shebang") { + val lines = Seq("#!/usr/bin/env sbt -Dsbt.main.class=sbt.ScriptMain", "println(1)") + val result = Script.stripShebang(lines) + assert(result == Seq("println(1)")) + } + + test("stripShebang leaves empty list unchanged") { + val result = Script.stripShebang(Seq.empty) + assert(result.isEmpty) + } + + test("stripShebang leaves single non-shebang line unchanged") { + val lines = Seq("println(42)") + val result = Script.stripShebang(lines) + assert(result == lines) + } + + test("stripShebang drops only first line when it is #!") { + val lines = Seq("#!", "*/", "println(1)") + val result = Script.stripShebang(lines) + assert(result == Seq("*/", "println(1)")) + } + + test("scriptBodyLines returns all lines when file has no blocks") { + val f = File.createTempFile("script", ".scala") + try { + IO.write(f, "println(1)\nval x = 2\n") + val result = Script.scriptBodyLines(f) + assert(result.contains("println(1)")) + assert(result.contains("val x = 2")) + assert(!result.exists(_.startsWith("/***"))) + } finally f.delete() + } + + test("scriptBodyLines excludes lines inside /*** */ block") { + val f = File.createTempFile("script", ".scala") + try { + IO.write( + f, + """println("before") + |/*** + |scalaVersion := "3.0.0" + |*/ + |println("after") + |""".stripMargin + ) + val result = Script.scriptBodyLines(f) + assert(result.contains("println(\"before\")")) + assert(result.contains("println(\"after\")")) + assert(!result.contains("scalaVersion := \"3.0.0\"")) + assert( + !result.contains("*/"), + "closing */ must not appear in script body (would break wrapped Main.scala)" + ) + } finally f.delete() + } + + test("scriptBodyLines excludes block content when file has only a block") { + val f = File.createTempFile("script", ".scala") + try { + IO.write( + f, + """/*** + |scalaVersion := "3.0.0" + |*/ + |""".stripMargin + ) + val result = Script.scriptBodyLines(f) + assert( + !result.contains("scalaVersion := \"3.0.0\""), + s"block content must be excluded, got $result" + ) + } finally f.delete() + } + + test("blocks parses block containing settings") { + val f = File.createTempFile("script", ".scala") + try { + IO.write( + f, + """line0 + |/*** + |scalaVersion := "3.0.0" + |*/ + |line3 + |""".stripMargin + ) + val result = Script.blocks(f) + val settingBlock = result.find(_.lines.contains("scalaVersion := \"3.0.0\"")) + assert(settingBlock.isDefined, s"expected a block with scalaVersion, got $result") + } finally f.delete() + } +} diff --git a/sbtw/src/main/scala/sbtw/ArgParser.scala b/sbtw/src/main/scala/sbtw/ArgParser.scala index 66c3eb233..2f7d7548e 100644 --- a/sbtw/src/main/scala/sbtw/ArgParser.scala +++ b/sbtw/src/main/scala/sbtw/ArgParser.scala @@ -53,5 +53,6 @@ object ArgParser: .parse(parser, args, LauncherOptions()) .map: opts => val sbtNew = opts.residual.contains("new") || opts.residual.contains("init") - opts.copy(sbtNew = sbtNew) + val isScript = opts.residual.exists(_.startsWith("-Dsbt.main.class=sbt.ScriptMain")) + opts.copy(sbtNew = sbtNew, allowEmpty = opts.allowEmpty || isScript) end ArgParser