[2.x] feat: Support fork in console task (#8604)

When enabled, the Scala REPL runs in a separate JVM.
This commit is contained in:
calm 2026-01-25 01:16:49 -06:00 committed by GitHub
parent ba8c340a2b
commit 9951a302c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 519 additions and 2 deletions

View File

@ -0,0 +1,142 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
package internal
import java.io.File
import java.nio.file.Paths
import sbt.internal.inc.{
AnalyzingCompiler,
PlainVirtualFile,
MappedFileConverter,
ScalaInstance,
ZincUtil
}
import sbt.internal.inc.classpath.ClasspathUtil
import sbt.internal.worker.{ ConsoleConfig, ScalaInstanceConfig }
import sbt.io.IO
import sbt.util.{ Level, Logger }
import sjsonnew.support.scalajson.unsafe.{ Parser, Converter }
import xsbti.compile.ClasspathOptionsUtil
/**
* Entry point for the forked console. This class creates a Scala REPL
* in the forked JVM with proper terminal support.
*/
class ConsoleMain:
def run(config: ConsoleConfig): Unit =
val si = scalaInstance(config.scalaInstanceConfig)
val compiler = analyzingCompiler(config, si)
given log: Logger = ConsoleMain.consoleLogger
val externalCp = config.externalDependencyJars.map(Paths.get(_))
val cpFiles = externalCp.map(_.toFile)
IO.withTemporaryDirectory: tempDir =>
val fullCp = cpFiles ++ si.allJars
val loader = ClasspathUtil.makeLoader(fullCp.map(_.toPath), si, tempDir.toPath)
runConsole(
compiler = compiler,
classpath = cpFiles,
options = config.scalacOptions,
loader = loader,
initialCommands = config.initialCommands,
cleanupCommands = config.cleanupCommands,
)(using log)
private def runConsole(
compiler: AnalyzingCompiler,
classpath: Seq[File],
options: Seq[String],
loader: ClassLoader,
initialCommands: String,
cleanupCommands: String,
)(using log: Logger): Unit =
compiler.console(
classpath.map(x => PlainVirtualFile(x.toPath)),
MappedFileConverter.empty,
options,
initialCommands,
cleanupCommands,
log,
)(
Some(loader),
Nil
)
def analyzingCompiler(config: ConsoleConfig, si: ScalaInstance): AnalyzingCompiler =
val bridgeProvider = ZincUtil.constantBridgeProvider(si, File(config.bridgeJar))
val classpathOptions = ClasspathOptionsUtil.repl()
AnalyzingCompiler(
si,
bridgeProvider,
classpathOptions,
_ => (),
None
)
def scalaInstance(siConfig: ScalaInstanceConfig): ScalaInstance =
val libraryJars = siConfig.libraryJars.map(Paths.get(_)).sortBy(_.getFileName.toString)
val allCompilerJars = siConfig.allCompilerJars
.map(Paths.get(_))
.sortBy(_.getFileName.toString)
val jlineJars = allCompilerJars.filter(_.getFileName.toString.contains("jline"))
val compilerJars =
allCompilerJars.filterNot(x => libraryJars.contains(x) || jlineJars.contains(x)).distinct
val allDocJars = siConfig.allDocJars.map(Paths.get(_)).sortBy(_.getFileName.toString)
val docJars = allDocJars
.filterNot(jar => libraryJars.contains(jar) || compilerJars.contains(jar))
.distinct
val allJars = libraryJars ++ compilerJars ++ docJars
// Use parent class loader for JLine to avoid conflicts
val jlineLoader = classOf[org.jline.terminal.Terminal].getClassLoader
val libraryLoader = ClasspathUtil.toLoader(libraryJars, jlineLoader)
val compilerLoader = ClasspathUtil.toLoader(compilerJars, libraryLoader)
val fullLoader =
if docJars.isEmpty then compilerLoader
else ClasspathUtil.toLoader(docJars, compilerLoader)
new ScalaInstance(
version = siConfig.scalaVersion,
loader = fullLoader,
loaderCompilerOnly = compilerLoader,
loaderLibraryOnly = libraryLoader,
libraryJars = libraryJars.map(_.toFile).toArray,
compilerJars = compilerJars.map(_.toFile).toArray,
allJars = allJars.map(_.toFile).toArray,
explicitActual = Some(siConfig.scalaVersion)
)
end ConsoleMain
object ConsoleMain:
/** A simple console logger for the forked REPL process. */
private val consoleLogger: Logger = new Logger:
override def trace(t: => Throwable): Unit = t.printStackTrace()
override def success(message: => String): Unit = log(Level.Info, message)
override def log(level: Level.Value, message: => String): Unit =
level match
case Level.Debug => () // Suppress debug messages
case Level.Info => scala.Console.out.println(message)
case Level.Warn => scala.Console.err.println(s"[warn] $message")
case Level.Error => scala.Console.err.println(s"[error] $message")
def main(args: Array[String]): Unit =
args.toList match
case Nil =>
scala.Console.err.println("ConsoleMain requires a config file argument starting with @")
sys.exit(1)
case arg :: Nil if arg.startsWith("@") =>
import sbt.internal.worker.codec.JsonProtocol.given
val configFile = arg.drop(1)
val content = IO.read(File(configFile))
val json = Parser.parseFromString(content).get
val config = Converter.fromJson[ConsoleConfig](json).get
val main = ConsoleMain()
main.run(config)
case _ =>
scala.Console.err.println("ConsoleMain requires exactly one argument: @<config-file>")
sys.exit(1)
end ConsoleMain

View File

@ -0,0 +1,122 @@
/*
* sbt
* Copyright 2023, Scala center
* Copyright 2011 - 2022, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt
package internal
import java.io.File
import java.net.URLClassLoader
import java.nio.file.{ Path, Paths }
import java.lang.{ ProcessBuilder as JProcessBuilder }
import sbt.internal.worker.ConsoleConfig
import sbt.io.IO
import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter }
/**
* Utilities for running the Scala console in a forked JVM.
*/
private[sbt] object ForkConsole:
/**
* Run the Scala console in a forked JVM.
*
* @param config Configuration for the console
* @param forkOptions Fork options (javaHome, jvmOptions, etc.)
* @return Exit code of the forked process
*/
def apply(config: ConsoleConfig, forkOptions: ForkOptions): Int =
IO.withTemporaryDirectory: tempDir =>
import sbt.internal.worker.codec.JsonProtocol.given
val json = Converter.toJson[ConsoleConfig](config).get
val params = tempDir.toPath.resolve("console-params.json")
IO.write(params.toFile, CompactPrinter(json))
run(
mainClass = classOf[ConsoleMain].getCanonicalName,
classpath = currentClasspath,
args = List(s"@$params"),
forkOptions = forkOptions,
)
/**
* Run an arbitrary main class in a forked JVM with full terminal inheritance.
* This is critical for interactive console to work properly with JLine.
*/
def run(
mainClass: String,
classpath: List[Path],
args: List[String],
forkOptions: ForkOptions,
): Int =
val jlineJars = Seq(
IO.classLocationPath(classOf[jline.Terminal]),
IO.classLocationPath(classOf[org.jline.terminal.Terminal]),
IO.classLocationPath(classOf[org.jline.reader.LineReader]),
IO.classLocationPath(classOf[org.jline.utils.InfoCmp]),
IO.classLocationPath(classOf[org.jline.keymap.KeyMap[?]]),
).distinct
val fullCp = (classpath ++ jlineJars).distinct
// Build environment variables for proper terminal handling
val termEnv = sys.env.get("TERM").getOrElse("xterm-256color")
val baseEnv = forkOptions.envVars ++ Map(
"TERM" -> termEnv,
"COLORTERM" -> sys.env.getOrElse("COLORTERM", "truecolor"),
)
// Add JLine-related JVM options to help with terminal detection
val jlineJvmOpts = Seq(
s"-Dorg.jline.terminal.type=$termEnv",
"-Djline.terminal=auto",
)
val allJvmOpts = forkOptions.runJVMOptions ++ jlineJvmOpts
// Build the java command
val javaHome = forkOptions.javaHome.getOrElse(new File(System.getProperty("java.home")))
val javaCmd = new File(new File(javaHome, "bin"), "java").getAbsolutePath
// Build full command line
val cmdArgs = Seq(javaCmd) ++
allJvmOpts ++
Seq("-classpath", fullCp.mkString(File.pathSeparator), mainClass) ++
args
// Use ProcessBuilder directly with inheritIO() for proper terminal handling
// This is critical for JLine arrow keys to work - all streams must be inherited
val jpb = new JProcessBuilder(cmdArgs*)
jpb.inheritIO() // Inherit stdin, stdout, stderr from parent process
forkOptions.workingDirectory.foreach(jpb.directory(_))
// Set environment variables
val env = jpb.environment()
baseEnv.foreach { case (k, v) => env.put(k, v) }
// Start and wait for process
val process = jpb.start()
process.waitFor()
/**
* Get the classpath of the current class loader.
* This is used to pass the sbt classes to the forked JVM.
*/
def currentClasspath: List[Path] =
val cl = classOf[ForkConsole.type].getClassLoader match
case cl: URLClassLoader => cl
case other =>
throw RuntimeException(
s"Expected URLClassLoader but got ${other.getClass.getName}"
)
val urls = cl.getURLs.toList
val extraJars = Vector(
IO.classLocationPath(classOf[xsbti.compile.ScalaInstance]),
IO.classLocationPath(classOf[xsbti.Logger]),
IO.classLocationPath(classOf[sbt.internal.inc.AnalyzingCompiler]),
IO.classLocationPath(classOf[sbt.internal.inc.classpath.ClasspathUtil.type]),
IO.classLocationPath(classOf[sbt.util.Logger]),
IO.classLocationPath(classOf[sjsonnew.JsonFormat[?]]),
)
(urls.map(u => Paths.get(u.toURI)) ++ extraJars).distinct
end ForkConsole

View File

@ -1052,7 +1052,10 @@ object Defaults extends BuildCommon {
cache.get
},
compileIncSetup := Def.uncached(compileIncSetupTask.value),
console := consoleTask.value,
console := Def.taskDyn {
if (console / fork).value then forkedConsoleTask
else Def.task(consoleTask.value)
}.value,
collectAnalyses := Definition.collectAnalysesTask.map(_ => ()).value,
consoleQuick := consoleQuickTask.value,
discoveredMainClasses := compile
@ -2118,6 +2121,52 @@ object Defaults extends BuildCommon {
println()
}
private def forkedConsoleTask: Initialize[Task[Unit]] =
Def.task {
import sbt.internal.worker.{ ConsoleConfig, ScalaInstanceConfig }
val si = (console / scalaInstance).value
val conv = fileConverter.value
val depsJars = (console / externalDependencyClasspath).value.toVector
.map(_.data)
.map(conv.toPath)
val bridgeJars = scalaCompilerBridgeBin.value
val bridgeJar =
if bridgeJars.nonEmpty then conv.toPath(bridgeJars.head).toFile
else
// Fall back to fetching the bridge module
val dr = scalaCompilerBridgeDependencyResolution.value
val uc = (update / updateConfiguration).value
val uwc = (update / unresolvedWarningConfiguration).value
ZincLmUtil.fetchDefaultBridgeModule(
si.version,
dr,
uc,
uwc,
streams.value.log
)
val siConfig = ScalaInstanceConfig(
scalaVersion = si.version,
libraryJars = si.libraryJars.map(_.toString).toVector,
allCompilerJars = si.compilerJars.map(_.toString).toVector,
allDocJars = Vector.empty,
)
val config = ConsoleConfig(
scalaInstanceConfig = siConfig,
bridgeJar = bridgeJar.toString,
externalDependencyJars = depsJars.map(_.toString),
scalacOptions = (console / scalacOptions).value.toVector,
initialCommands = (console / initialCommands).value,
cleanupCommands = (console / cleanupCommands).value,
)
val fo = (console / forkOptions).value
val terminal = ITerminal.console
terminal.restore()
val exitCode = ForkConsole(config, fo)
terminal.restore()
if exitCode != 0 then throw MessageOnlyException(s"Forked console exited with code $exitCode")
println()
}
private def exported(w: PrintWriter, command: String): Seq[String] => Unit =
args => w.println((command +: args).mkString(" "))

View File

@ -0,0 +1,53 @@
/**
* This code is generated using [[https://www.scala-sbt.org/contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.worker
/** Configuration for forked console. */
final class ConsoleConfig private (
val scalaInstanceConfig: sbt.internal.worker.ScalaInstanceConfig,
val bridgeJar: String,
val externalDependencyJars: Vector[String],
val scalacOptions: Vector[String],
val initialCommands: String,
val cleanupCommands: String) extends Serializable {
override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match {
case x: ConsoleConfig => (this.scalaInstanceConfig == x.scalaInstanceConfig) && (this.bridgeJar == x.bridgeJar) && (this.externalDependencyJars == x.externalDependencyJars) && (this.scalacOptions == x.scalacOptions) && (this.initialCommands == x.initialCommands) && (this.cleanupCommands == x.cleanupCommands)
case _ => false
})
override def hashCode: Int = {
37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.worker.ConsoleConfig".##) + scalaInstanceConfig.##) + bridgeJar.##) + externalDependencyJars.##) + scalacOptions.##) + initialCommands.##) + cleanupCommands.##)
}
override def toString: String = {
"ConsoleConfig(" + scalaInstanceConfig + ", " + bridgeJar + ", " + externalDependencyJars + ", " + scalacOptions + ", " + initialCommands + ", " + cleanupCommands + ")"
}
private def copy(scalaInstanceConfig: sbt.internal.worker.ScalaInstanceConfig = scalaInstanceConfig, bridgeJar: String = bridgeJar, externalDependencyJars: Vector[String] = externalDependencyJars, scalacOptions: Vector[String] = scalacOptions, initialCommands: String = initialCommands, cleanupCommands: String = cleanupCommands): ConsoleConfig = {
new ConsoleConfig(scalaInstanceConfig, bridgeJar, externalDependencyJars, scalacOptions, initialCommands, cleanupCommands)
}
def withScalaInstanceConfig(scalaInstanceConfig: sbt.internal.worker.ScalaInstanceConfig): ConsoleConfig = {
copy(scalaInstanceConfig = scalaInstanceConfig)
}
def withBridgeJar(bridgeJar: String): ConsoleConfig = {
copy(bridgeJar = bridgeJar)
}
def withExternalDependencyJars(externalDependencyJars: Vector[String]): ConsoleConfig = {
copy(externalDependencyJars = externalDependencyJars)
}
def withScalacOptions(scalacOptions: Vector[String]): ConsoleConfig = {
copy(scalacOptions = scalacOptions)
}
def withInitialCommands(initialCommands: String): ConsoleConfig = {
copy(initialCommands = initialCommands)
}
def withCleanupCommands(cleanupCommands: String): ConsoleConfig = {
copy(cleanupCommands = cleanupCommands)
}
}
object ConsoleConfig {
def apply(scalaInstanceConfig: sbt.internal.worker.ScalaInstanceConfig, bridgeJar: String, externalDependencyJars: Vector[String], scalacOptions: Vector[String], initialCommands: String, cleanupCommands: String): ConsoleConfig = new ConsoleConfig(scalaInstanceConfig, bridgeJar, externalDependencyJars, scalacOptions, initialCommands, cleanupCommands)
}

View File

@ -0,0 +1,45 @@
/**
* This code is generated using [[https://www.scala-sbt.org/contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.worker
/** Configuration for creating a ScalaInstance in forked process. */
final class ScalaInstanceConfig private (
val scalaVersion: String,
val libraryJars: Vector[String],
val allCompilerJars: Vector[String],
val allDocJars: Vector[String]) extends Serializable {
override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match {
case x: ScalaInstanceConfig => (this.scalaVersion == x.scalaVersion) && (this.libraryJars == x.libraryJars) && (this.allCompilerJars == x.allCompilerJars) && (this.allDocJars == x.allDocJars)
case _ => false
})
override def hashCode: Int = {
37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.worker.ScalaInstanceConfig".##) + scalaVersion.##) + libraryJars.##) + allCompilerJars.##) + allDocJars.##)
}
override def toString: String = {
"ScalaInstanceConfig(" + scalaVersion + ", " + libraryJars + ", " + allCompilerJars + ", " + allDocJars + ")"
}
private def copy(scalaVersion: String = scalaVersion, libraryJars: Vector[String] = libraryJars, allCompilerJars: Vector[String] = allCompilerJars, allDocJars: Vector[String] = allDocJars): ScalaInstanceConfig = {
new ScalaInstanceConfig(scalaVersion, libraryJars, allCompilerJars, allDocJars)
}
def withScalaVersion(scalaVersion: String): ScalaInstanceConfig = {
copy(scalaVersion = scalaVersion)
}
def withLibraryJars(libraryJars: Vector[String]): ScalaInstanceConfig = {
copy(libraryJars = libraryJars)
}
def withAllCompilerJars(allCompilerJars: Vector[String]): ScalaInstanceConfig = {
copy(allCompilerJars = allCompilerJars)
}
def withAllDocJars(allDocJars: Vector[String]): ScalaInstanceConfig = {
copy(allDocJars = allDocJars)
}
}
object ScalaInstanceConfig {
def apply(scalaVersion: String, libraryJars: Vector[String], allCompilerJars: Vector[String], allDocJars: Vector[String]): ScalaInstanceConfig = new ScalaInstanceConfig(scalaVersion, libraryJars, allCompilerJars, allDocJars)
}

View File

@ -0,0 +1,37 @@
/**
* This code is generated using [[https://www.scala-sbt.org/contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.worker.codec
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait ConsoleConfigFormats { self: sbt.internal.worker.codec.ScalaInstanceConfigFormats & sjsonnew.BasicJsonProtocol =>
given ConsoleConfigFormat: JsonFormat[sbt.internal.worker.ConsoleConfig] = new JsonFormat[sbt.internal.worker.ConsoleConfig] {
override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.worker.ConsoleConfig = {
__jsOpt match {
case Some(__js) =>
unbuilder.beginObject(__js)
val scalaInstanceConfig = unbuilder.readField[sbt.internal.worker.ScalaInstanceConfig]("scalaInstanceConfig")
val bridgeJar = unbuilder.readField[String]("bridgeJar")
val externalDependencyJars = unbuilder.readField[Vector[String]]("externalDependencyJars")
val scalacOptions = unbuilder.readField[Vector[String]]("scalacOptions")
val initialCommands = unbuilder.readField[String]("initialCommands")
val cleanupCommands = unbuilder.readField[String]("cleanupCommands")
unbuilder.endObject()
sbt.internal.worker.ConsoleConfig(scalaInstanceConfig, bridgeJar, externalDependencyJars, scalacOptions, initialCommands, cleanupCommands)
case None =>
deserializationError("Expected JsObject but found None")
}
}
override def write[J](obj: sbt.internal.worker.ConsoleConfig, builder: Builder[J]): Unit = {
builder.beginObject()
builder.addField("scalaInstanceConfig", obj.scalaInstanceConfig)
builder.addField("bridgeJar", obj.bridgeJar)
builder.addField("externalDependencyJars", obj.externalDependencyJars)
builder.addField("scalacOptions", obj.scalacOptions)
builder.addField("initialCommands", obj.initialCommands)
builder.addField("cleanupCommands", obj.cleanupCommands)
builder.endObject()
}
}
}

View File

@ -10,4 +10,6 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol
with sbt.internal.worker.codec.NativeRunInfoFormats
with sbt.internal.worker.codec.RunInfoFormats
with sbt.internal.worker.codec.ClientJobParamsFormats
with sbt.internal.worker.codec.ScalaInstanceConfigFormats
with sbt.internal.worker.codec.ConsoleConfigFormats
object JsonProtocol extends JsonProtocol

View File

@ -0,0 +1,33 @@
/**
* This code is generated using [[https://www.scala-sbt.org/contraband]].
*/
// DO NOT EDIT MANUALLY
package sbt.internal.worker.codec
import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError }
trait ScalaInstanceConfigFormats { self: sjsonnew.BasicJsonProtocol =>
given ScalaInstanceConfigFormat: JsonFormat[sbt.internal.worker.ScalaInstanceConfig] = new JsonFormat[sbt.internal.worker.ScalaInstanceConfig] {
override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.worker.ScalaInstanceConfig = {
__jsOpt match {
case Some(__js) =>
unbuilder.beginObject(__js)
val scalaVersion = unbuilder.readField[String]("scalaVersion")
val libraryJars = unbuilder.readField[Vector[String]]("libraryJars")
val allCompilerJars = unbuilder.readField[Vector[String]]("allCompilerJars")
val allDocJars = unbuilder.readField[Vector[String]]("allDocJars")
unbuilder.endObject()
sbt.internal.worker.ScalaInstanceConfig(scalaVersion, libraryJars, allCompilerJars, allDocJars)
case None =>
deserializationError("Expected JsObject but found None")
}
}
override def write[J](obj: sbt.internal.worker.ScalaInstanceConfig, builder: Builder[J]): Unit = {
builder.beginObject()
builder.addField("scalaVersion", obj.scalaVersion)
builder.addField("libraryJars", obj.libraryJars)
builder.addField("allCompilerJars", obj.allCompilerJars)
builder.addField("allDocJars", obj.allDocJars)
builder.endObject()
}
}
}

View File

@ -43,10 +43,28 @@ type RunInfo {
## Client-side job support.
##
## Notification: sbt/clientJob
##
##
## Parameter for the sbt/clientJob notification.
## A client-side job represents a unit of work that sbt server
## can outsourse back to the client, for example for run task.
type ClientJobParams {
runInfo: sbt.internal.worker.RunInfo
}
## Configuration for creating a ScalaInstance in forked process.
type ScalaInstanceConfig {
scalaVersion: String!
libraryJars: [String]
allCompilerJars: [String]
allDocJars: [String]
}
## Configuration for forked console.
type ConsoleConfig {
scalaInstanceConfig: sbt.internal.worker.ScalaInstanceConfig!
bridgeJar: String!
externalDependencyJars: [String]
scalacOptions: [String]
initialCommands: String!
cleanupCommands: String!
}

View File

@ -0,0 +1,13 @@
scalaVersion := "3.7.4"
// Enable forked console
Compile / console / fork := true
// Test that javaOptions are passed to the forked console
Compile / console / javaOptions += "-Xmx256m"
// Test that initialCommands work in forked console
Compile / console / initialCommands := """println("Forked console initialized!")"""
// Test that cleanupCommands work in forked console
Compile / console / cleanupCommands := """println("Forked console cleanup!")"""

View File

@ -0,0 +1,3 @@
# Test that fork in console setting is accepted
> set Compile / console / fork := true
> show Compile / console / fork