From ba345dd7977ca1e6db995a8fe295fd232c0c9be7 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Wed, 18 Dec 2019 10:24:32 -0800 Subject: [PATCH] Add multi-client ui to server This commit makes it possible for the sbt server to render the same ui to multiple clients. The network client ui should look nearly identical to the console ui except for the log messages about the experimental client. The way that it works is that it associates a ui thread with each terminal. Whenever a command starts or completes, callbacks are invoked on the various channels to update their ui state. For example, if there are two clients and one of them runs compile, then the prompt is changed from AskUser to Running for the terminal that initiated the command while the other client remains in the AskUser state. Whenever the client changes uses ui states, the existing thread is terminated if it is running and a new thread is begun. The UITask formalizes this process. It is based on the AskUser class from older versions of sbt. In fact, there is an AskUserTask which is very similar. It uses jline to read input from the terminal (which could be a network terminal). When it gets a line, it submits it to the CommandExchange and exits. Once the next command is run (which may or may not be the command it submitted), the ui state will be reset. The debug, info, warn and error commands should work with the multi client ui. When run, they set the log level globally, not just for the client that set the level. --- build.sbt | 4 + .../sbt/internal/util/ManagedLogger.scala | 15 +- .../src/main/scala/sbt/util/LogExchange.scala | 2 +- .../main/scala/sbt/BasicCommandStrings.scala | 3 + .../src/main/scala/sbt/BasicCommands.scala | 8 +- .../src/main/scala/sbt/BasicKeys.scala | 7 +- .../scala/sbt/internal/CommandChannel.scala | 28 ++- .../scala/sbt/internal/ConsoleChannel.scala | 76 ++------ .../internal/client/ServerConnection.scala | 39 ++-- .../scala/sbt/internal/server/Server.scala | 2 +- .../main/scala/sbt/internal/ui/UITask.scala | 99 +++++++++++ .../scala/sbt/internal/ui/UserThread.scala | 91 ++++++++++ main-settings/src/main/scala/sbt/Def.scala | 11 +- main/src/main/scala/sbt/Defaults.scala | 17 +- main/src/main/scala/sbt/Keys.scala | 1 + main/src/main/scala/sbt/Main.scala | 36 +++- main/src/main/scala/sbt/MainLoop.scala | 26 ++- main/src/main/scala/sbt/Project.scala | 3 + .../scala/sbt/internal/CommandExchange.scala | 166 ++++++++++-------- .../scala/sbt/internal/ConsoleProject.scala | 8 +- .../main/scala/sbt/internal/Continuous.scala | 2 +- .../main/scala/sbt/internal/LogManager.scala | 5 +- .../scala/sbt/internal/TaskProgress.scala | 2 +- .../sbt/internal/XMainConfiguration.scala | 19 +- .../sbt/internal/server/NetworkChannel.scala | 86 +++++++-- main/src/test/scala/PluginCommandTest.scala | 23 +-- project/build.properties | 2 +- server-test/src/server-test/client/build.sbt | 7 + .../server-test/client/src/main/scala/A.scala | 1 + .../client/src/test/scala/FooSpec.scala | 3 + server-test/src/server-test/events/Main.scala | 6 +- .../src/server-test/handshake/build.sbt | 15 +- .../src/test/scala/testpkg/EventsTest.scala | 57 ++++-- .../src/test/scala/testpkg/ResponseTest.scala | 20 +-- .../src/test/scala/testpkg/TestServer.scala | 4 +- 35 files changed, 633 insertions(+), 261 deletions(-) create mode 100644 main-command/src/main/scala/sbt/internal/ui/UITask.scala create mode 100644 main-command/src/main/scala/sbt/internal/ui/UserThread.scala create mode 100644 server-test/src/server-test/client/build.sbt create mode 100644 server-test/src/server-test/client/src/main/scala/A.scala create mode 100644 server-test/src/server-test/client/src/test/scala/FooSpec.scala diff --git a/build.sbt b/build.sbt index a70ce498f..dfbbcea83 100644 --- a/build.sbt +++ b/build.sbt @@ -955,6 +955,10 @@ lazy val mainProj = (project in file("main")) exclude[DirectMissingMethodProblem]("sbt.Classpaths.warnInsecureProtocol"), exclude[DirectMissingMethodProblem]("sbt.Classpaths.warnInsecureProtocolInModules"), exclude[MissingClassProblem]("sbt.internal.ExternalHooks*"), + // This seems to be a mima problem. The older constructor still exists but + // mima seems to incorrectly miss the secondary constructor that provides + // the binary compatible version. + exclude[IncompatibleMethTypeProblem]("sbt.internal.server.NetworkChannel.this"), ) ) .configure( diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala index 95b670635..9bb00524f 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala @@ -21,8 +21,11 @@ class ManagedLogger( val name: String, val channelName: Option[String], val execId: Option[String], - xlogger: XLogger + xlogger: XLogger, + terminal: Option[Terminal] ) extends Logger { + def this(name: String, channelName: Option[String], execId: Option[String], xlogger: XLogger) = + this(name, channelName, execId, xlogger, None) override def trace(t: => Throwable): Unit = logEvent(Level.Error, TraceEvent("Error", t, channelName, execId)) override def log(level: Level.Value, message: => String): Unit = { @@ -35,10 +38,12 @@ class ManagedLogger( private lazy val SuccessEventTag = scala.reflect.runtime.universe.typeTag[SuccessEvent] // send special event for success since it's not a real log level override def success(message: => String): Unit = { - infoEvent[SuccessEvent](SuccessEvent(message))( - implicitly[JsonFormat[SuccessEvent]], - SuccessEventTag - ) + if (terminal.fold(true)(_.isSuccessEnabled)) { + infoEvent[SuccessEvent](SuccessEvent(message))( + implicitly[JsonFormat[SuccessEvent]], + SuccessEventTag + ) + } } def registerStringCodec[A: ShowLines: TypeTag]: Unit = { diff --git a/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala b/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala index 869328ce2..a1073b36f 100644 --- a/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala +++ b/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala @@ -49,7 +49,7 @@ sealed abstract class LogExchange { config.addLogger(name, loggerConfig) ctx.updateLoggers val logger = ctx.getLogger(name) - new ManagedLogger(name, channelName, execId, logger) + new ManagedLogger(name, channelName, execId, logger, Some(Terminal.get)) } def unbindLoggerAppenders(loggerName: String): Unit = { val lc = loggerConfig(loggerName) diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 331852d12..939b95e9e 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -235,4 +235,7 @@ $AliasCommand name= (ContinuousExecutePrefix + " ", continuousDetail) def ClearCaches: String = "clearCaches" def ClearCachesDetailed: String = "Clears all of sbt's internal caches." + + private[sbt] val networkExecPrefix = "__" + private[sbt] val DisconnectNetworkChannel = s"${networkExecPrefix}disconnectNetworkChannel" } diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index d74b3568c..cba9c73c1 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -346,7 +346,13 @@ object BasicCommands { private[this] def classpathStrings: Parser[Seq[String]] = token(StringBasic.map(s => IO.pathSplit(s).toSeq), "") - def exit: Command = Command.command(TerminateAction, exitBrief, exitBrief)(_ exit true) + def exit: Command = Command.command(TerminateAction, exitBrief, exitBrief) { s => + s.source match { + case Some(c) if c.channelName.startsWith("network") => + s"${DisconnectNetworkChannel} ${c.channelName}" :: s + case _ => s exit true + } + } @deprecated("Replaced by BuiltInCommands.continuous", "1.3.0") def continuous: Command = diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index 107b2df88..985ed5879 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -13,7 +13,7 @@ import com.github.ghik.silencer.silent import sbt.internal.inc.classpath.{ ClassLoaderCache => IncClassLoaderCache } import sbt.internal.classpath.ClassLoaderCache import sbt.internal.server.ServerHandler -import sbt.internal.util.AttributeKey +import sbt.internal.util.{ AttributeKey, Terminal } import sbt.librarymanagement.ModuleID import sbt.util.Level @@ -35,6 +35,11 @@ object BasicKeys { "The function that constructs the command prompt from the current build state.", 10000 ) + val terminalShellPrompt = AttributeKey[(Terminal, State) => String]( + "new-shell-prompt", + "The function that constructs the command prompt from the current build state for a given terminal.", + 10000 + ) @silent val watch = AttributeKey[Watched]("watched", "Continuous execution configuration.", 1000) val serverPort = diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index 9096c8a20..154e4cffd 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -9,9 +9,13 @@ package sbt package internal import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference +import sbt.internal.ui.{ UITask, UserThread } import sbt.internal.util.Terminal import sbt.protocol.EventMessage +import sbt.util.Level + import scala.collection.JavaConverters._ /** @@ -48,6 +52,8 @@ abstract class CommandChannel { private[sbt] final def initiateMaintenance(task: String): Unit = { maintenance.forEach(q => q.synchronized { q.add(new MaintenanceTask(this, task)); () }) } + private[sbt] def mkUIThread: (State, CommandChannel) => UITask + private[sbt] def makeUIThread(state: State): UITask = mkUIThread(state, this) final def append(exec: Exec): Boolean = { registered.synchronized { exec.commandLine.nonEmpty && { @@ -58,10 +64,29 @@ abstract class CommandChannel { } def poll: Option[Exec] = Option(commandQueue.poll) + def prompt(e: ConsolePromptEvent): Unit = userThread.onConsolePromptEvent(e) + def unprompt(e: ConsoleUnpromptEvent): Unit = userThread.onConsoleUnpromptEvent(e) def publishBytes(bytes: Array[Byte]): Unit - def shutdown(): Unit + private[sbt] def userThread: UserThread + def shutdown(logShutdown: Boolean): Unit = { + userThread.stopThread() + userThread.close() + } + @deprecated("Use the variant that takes the logShutdown parameter", "1.4.0") + def shutdown(): Unit = shutdown(true) def name: String + private[this] val level = new AtomicReference[Level.Value](Level.Info) + private[sbt] final def setLevel(l: Level.Value): Unit = level.set(l) + private[sbt] final def logLevel: Level.Value = level.get + private[this] def setLevel(value: Level.Value, cmd: String): Boolean = { + level.set(value) + append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) + } private[sbt] def onCommand: String => Boolean = { + case "error" => setLevel(Level.Error, "error") + case "debug" => setLevel(Level.Debug, "debug") + case "info" => setLevel(Level.Info, "info") + case "warn" => setLevel(Level.Warn, "warn") case cmd => if (cmd.nonEmpty) append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) else false @@ -89,7 +114,6 @@ case class ConsolePromptEvent(state: State) extends EventMessage /* * This is a data passed specifically for unprompting local console. */ -@deprecated("No longer used", "1.4.0") case class ConsoleUnpromptEvent(lastSource: Option[CommandSource]) extends EventMessage private[internal] class MaintenanceTask(val channel: CommandChannel, val task: String) diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index 1482b0094..dac0c9565 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -8,76 +8,24 @@ package sbt package internal -import java.io.File -import java.nio.channels.ClosedChannelException -import java.util.concurrent.atomic.AtomicReference - -import sbt.BasicKeys._ +import sbt.internal.ui.{ UITask, UserThread } import sbt.internal.util._ +import sjsonnew.JsonFormat -private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { - private[this] val askUserThread = new AtomicReference[AskUserThread] - private[this] def getPrompt(s: State): String = s.get(shellPrompt) match { - case Some(pf) => pf(s) - case None => - def ansi(s: String): String = if (ConsoleAppender.formatEnabledInEnv) s"$s" else "" - s"${ansi(ConsoleAppender.DeleteLine)}> ${ansi(ConsoleAppender.ClearScreenAfterCursor)}" - } - private[this] class AskUserThread(s: State) extends Thread("ask-user-thread") { - private val history = s.get(historyPath).getOrElse(Some(new File(s.baseDir, ".history"))) - private val prompt = getPrompt(s) - private val reader = - new FullReader( - history, - s.combinedParser, - LineReader.HandleCONT, - Terminal.console, - ) - setDaemon(true) - start() - override def run(): Unit = - try { - reader.readLine(prompt) match { - case Some(cmd) => append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) - case None => - println("") // Prevents server shutdown log lines from appearing on the prompt line - append(Exec("exit", Some(Exec.newExecId), Some(CommandSource(name)))) - } - () - } catch { - case _: ClosedChannelException => - } finally askUserThread.synchronized(askUserThread.set(null)) - def redraw(): Unit = { - System.out.print(ConsoleAppender.clearLine(0)) - System.out.print(ConsoleAppender.ClearScreenAfterCursor) - System.out.flush() - } - } - private[this] def makeAskUserThread(s: State): AskUserThread = new AskUserThread(s) +private[sbt] final class ConsoleChannel( + val name: String, + override private[sbt] val mkUIThread: (State, CommandChannel) => UITask +) extends CommandChannel { def run(s: State): State = s def publishBytes(bytes: Array[Byte]): Unit = () - def prompt(event: ConsolePromptEvent): Unit = { - if (Terminal.systemInIsAttached) { - askUserThread.synchronized { - askUserThread.get match { - case null => askUserThread.set(makeAskUserThread(event.state)) - case t => t.redraw() - } - } - } - } + def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = () - def shutdown(): Unit = askUserThread.synchronized { - askUserThread.get match { - case null => - case t if t.isAlive => - t.interrupt() - askUserThread.set(null) - case _ => () - } - } - override private[sbt] def terminal = Terminal.console + override val userThread: UserThread = new UserThread(this) + private[sbt] def terminal = Terminal.console +} +private[sbt] object ConsoleChannel { + private[sbt] def defaultName = "console0" } diff --git a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala index 85c3b9867..93c9afcc0 100644 --- a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala +++ b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala @@ -20,12 +20,14 @@ import sbt.internal.util.ReadJsonFromInputStream abstract class ServerConnection(connection: Socket) { private val running = new AtomicBoolean(true) + private val closed = new AtomicBoolean(false) private val retByte: Byte = '\r'.toByte private val delimiter: Byte = '\n'.toByte private val out = connection.getOutputStream val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") { + setDaemon(true) override def run(): Unit = { try { val in = connection.getInputStream @@ -67,17 +69,22 @@ abstract class ServerConnection(connection: Socket) { writeLine(a) } - def writeLine(a: Array[Byte]): Unit = { - def writeEndLine(): Unit = { - out.write(retByte.toInt) - out.write(delimiter.toInt) - out.flush + def writeLine(a: Array[Byte]): Unit = + try { + def writeEndLine(): Unit = { + out.write(retByte.toInt) + out.write(delimiter.toInt) + out.flush + } + if (a.nonEmpty) { + out.write(a) + } + writeEndLine + } catch { + case e: IOException => + shutdown() + throw e } - if (a.nonEmpty) { - out.write(a) - } - writeEndLine - } def onRequest(msg: JsonRpcRequestMessage): Unit def onResponse(msg: JsonRpcResponseMessage): Unit @@ -85,10 +92,14 @@ abstract class ServerConnection(connection: Socket) { def onShutdown(): Unit - def shutdown(): Unit = { - println("Shutting down client connection") - running.set(false) - out.close() + def shutdown(): Unit = if (closed.compareAndSet(false, true)) { + if (!running.compareAndSet(true, false)) { + System.err.println("\nsbt server connection closed.") + } + try { + out.close() + connection.close() + } catch { case e: IOException => e.printStackTrace() } onShutdown } diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 3d859d403..775683130 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -227,7 +227,7 @@ private[sbt] case class ServerConnection( socketfile: File, pipeName: String, bspConnectionFile: File, - appConfiguration: AppConfiguration + appConfiguration: AppConfiguration, ) { def shortName: String = { connectionType match { diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala new file mode 100644 index 000000000..f5d365efe --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -0,0 +1,99 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.ui + +import java.io.File +import java.nio.channels.ClosedChannelException +import java.util.concurrent.atomic.AtomicBoolean + +import jline.console.history.PersistentHistory +import sbt.BasicKeys.{ historyPath, terminalShellPrompt } +import sbt.State +import sbt.internal.CommandChannel +import sbt.internal.util.ConsoleAppender.{ ClearPromptLine, ClearScreenAfterCursor, DeleteLine } +import sbt.internal.util._ +import sbt.internal.util.complete.{ JLineCompletion, Parser } + +import scala.annotation.tailrec + +private[sbt] trait UITask extends Runnable with AutoCloseable { + private[sbt] def channel: CommandChannel + private[sbt] def reader: UITask.Reader + private[this] final def handleInput(s: Either[String, String]): Boolean = s match { + case Left(m) => channel.onMaintenance(m) + case Right(cmd) => channel.onCommand(cmd) + } + private[this] val isStopped = new AtomicBoolean(false) + override def run(): Unit = { + @tailrec def impl(): Unit = { + val res = reader.readLine() + if (!handleInput(res) && !isStopped.get) impl() + } + try impl() + catch { case _: InterruptedException | _: ClosedChannelException => isStopped.set(true) } + } + override def close(): Unit = isStopped.set(true) +} + +private[sbt] object UITask { + trait Reader { def readLine(): Either[String, String] } + object Reader { + def terminalReader(parser: Parser[_])( + terminal: Terminal, + state: State + ): Reader = { + val lineReader = LineReader.createReader(history(state), terminal, terminal.prompt) + JLineCompletion.installCustomCompletor(lineReader, parser) + () => { + val clear = terminal.ansi(ClearPromptLine, "") + try { + @tailrec def impl(): Either[String, String] = { + lineReader.readLine(clear + terminal.prompt.mkPrompt()) match { + case null => Left("exit") + case s: String => + lineReader.getHistory match { + case p: PersistentHistory => + p.add(s) + p.flush() + case _ => + } + s match { + case "" => impl() + case cmd @ ("shutdown" | "exit" | "cancel") => Left(cmd) + case cmd => + if (terminal.prompt != Prompt.Batch) terminal.setPrompt(Prompt.Running) + terminal.printStream.write(Int.MinValue) + Right(cmd) + } + } + } + impl() + } catch { + case _: InterruptedException => Right("") + } finally lineReader.close() + } + } + } + private[this] def history(s: State): Option[File] = + s.get(historyPath).getOrElse(Some(new File(s.baseDir, ".history"))) + private[sbt] def shellPrompt(terminal: Terminal, s: State): String = + s.get(terminalShellPrompt) match { + case Some(pf) => pf(terminal, s) + case None => + def ansi(s: String): String = if (terminal.isAnsiSupported) s"$s" else "" + s"${ansi(DeleteLine)}> ${ansi(ClearScreenAfterCursor)}" + } + private[sbt] class AskUserTask( + state: State, + override val channel: CommandChannel, + ) extends UITask { + override private[sbt] def reader: UITask.Reader = { + UITask.Reader.terminalReader(state.combinedParser)(channel.terminal, state) + } + } +} diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala new file mode 100644 index 000000000..e868c3682 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -0,0 +1,91 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal + +package ui + +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } +import java.util.concurrent.Executors + +import sbt.State +import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } +import sbt.internal.util.Prompt.{ AskUser, Running } + +private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { + private[this] val uiThread = new AtomicReference[(UITask, Thread)] + private[sbt] final def onProgressEvent(pe: ProgressEvent): Unit = { + lastProgressEvent.set(pe) + ProgressState.updateProgressState(pe, channel.terminal) + } + private[this] val executor = + Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-ui-thread")) + private[this] val lastProgressEvent = new AtomicReference[ProgressEvent] + private[this] val isClosed = new AtomicBoolean(false) + + private[sbt] def reset(state: State): Unit = if (!isClosed.get) { + uiThread.synchronized { + val task = channel.makeUIThread(state) + def submit(): Thread = { + val thread = new Thread(() => { + task.run() + uiThread.set(null) + }, s"sbt-$name-ui-thread") + thread.setDaemon(true) + thread.start() + uiThread.getAndSet((task, thread)) match { + case null => + case (_, t) => t.interrupt() + } + thread + } + uiThread.get match { + case null => uiThread.set((task, submit())) + case (t, _) if t.getClass == task.getClass => + case (t, thread) => + thread.interrupt() + uiThread.set((task, submit())) + } + } + Option(lastProgressEvent.get).foreach(onProgressEvent) + } + + private[sbt] def stopThread(): Unit = uiThread.synchronized { + uiThread.getAndSet(null) match { + case null => + case (t, thread) => + t.close() + Util.ignoreResult(thread.interrupt()) + } + } + + private[sbt] def onConsolePromptEvent(consolePromptEvent: ConsolePromptEvent): Unit = { + channel.terminal.withPrintStream { ps => + ps.print(ConsoleAppender.ClearScreenAfterCursor) + ps.flush() + } + val state = consolePromptEvent.state + terminal.prompt match { + case Running => terminal.setPrompt(AskUser(() => UITask.shellPrompt(terminal, state))) + case _ => + } + onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) + reset(state) + } + + private[sbt] def onConsoleUnpromptEvent( + consoleUnpromptEvent: ConsoleUnpromptEvent + ): Unit = { + if (consoleUnpromptEvent.lastSource.fold(true)(_.channelName != name)) { + terminal.progressState.reset() + } else stopThread() + } + + override def close(): Unit = if (isClosed.compareAndSet(false, true)) executor.shutdown() + private def terminal = channel.terminal + private def name: String = channel.name +} diff --git a/main-settings/src/main/scala/sbt/Def.scala b/main-settings/src/main/scala/sbt/Def.scala index 6f169c5c6..2f6833fb8 100644 --- a/main-settings/src/main/scala/sbt/Def.scala +++ b/main-settings/src/main/scala/sbt/Def.scala @@ -170,12 +170,11 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits { def displayMasked(scoped: ScopedKey[_], mask: ScopeMask, showZeroConfig: Boolean): String = Scope.displayMasked(scoped.scope, scoped.key.label, mask, showZeroConfig) - def withColor(s: String, color: Option[String]): String = { - val useColor = ConsoleAppender.formatEnabledInEnv - color match { - case Some(c) if useColor => c + s + scala.Console.RESET - case _ => s - } + def withColor(s: String, color: Option[String]): String = + withColor(s, color, useColor = ConsoleAppender.formatEnabledInEnv) + def withColor(s: String, color: Option[String], useColor: Boolean): String = color match { + case Some(c) if useColor => c + s + scala.Console.RESET + case _ => s } override def deriveAllowed[T](s: Setting[T], allowDynamic: Boolean): Option[String] = diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index b3d6f3e92..90b6bc815 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -407,7 +407,7 @@ object Defaults extends BuildCommon { // TODO: This should be on the new default settings for a project. def projectCore: Seq[Setting[_]] = Seq( name := thisProject.value.id, - logManager := LogManager.defaults(extraLoggers.value, StandardMain.console), + logManager := LogManager.defaults(extraLoggers.value, ConsoleOut.terminalOut), onLoadMessage := (onLoadMessage or Def.setting { s"set current project to ${name.value} (in build ${thisProjectRef.value.build})" @@ -1495,13 +1495,13 @@ object Defaults extends BuildCommon { def askForMainClass(classes: Seq[String]): Option[String] = sbt.SelectMainClass( - if (classes.length >= 10) Some(SimpleReader.readLine(_)) + if (classes.length >= 10) Some(SimpleReader(Terminal.get).readLine(_)) else Some(s => { def print(st: String) = { scala.Console.out.print(st); scala.Console.out.flush() } print(s) Terminal.get.withRawSystemIn { - Terminal.read match { + Terminal.get.inputStream.read match { case -1 => None case b => val res = b.toChar.toString @@ -2343,6 +2343,9 @@ object Classpaths { CrossVersion(scalaVersion, binVersion)(base).withCrossVersion(Disabled()) }, shellPrompt := shellPromptFromState, + terminalShellPrompt := { (t, s) => + shellPromptFromState(t)(s) + }, dynamicDependency := { (): Unit }, transitiveClasspathDependency := { (): Unit }, transitiveDynamicInputs :== Nil, @@ -3826,11 +3829,13 @@ object Classpaths { } } - def shellPromptFromState: State => String = { s: State => + def shellPromptFromState: State => String = shellPromptFromState(Terminal.console) + def shellPromptFromState(terminal: Terminal): State => String = { s: State => val extracted = Project.extract(s) (name in extracted.currentRef).get(extracted.structure.data) match { - case Some(name) => s"sbt:$name" + Def.withColor("> ", Option(scala.Console.CYAN)) - case _ => "> " + case Some(name) => + s"sbt:$name" + Def.withColor(s"> ", Option(scala.Console.CYAN), terminal.isColorEnabled) + case _ => "> " } } } diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 2a75c9901..54a3c5ee1 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -89,6 +89,7 @@ object Keys { // Command keys val historyPath = SettingKey(BasicKeys.historyPath) val shellPrompt = SettingKey(BasicKeys.shellPrompt) + val terminalShellPrompt = SettingKey(BasicKeys.terminalShellPrompt) val autoStartServer = SettingKey(BasicKeys.autoStartServer) val serverPort = SettingKey(BasicKeys.serverPort) val serverHost = SettingKey(BasicKeys.serverHost) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 03a110754..fe155b856 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -13,8 +13,9 @@ import java.nio.channels.ClosedChannelException import java.nio.file.{ FileAlreadyExistsException, FileSystems, Files } import java.util.Properties import java.util.concurrent.ForkJoinPool +import java.util.concurrent.atomic.AtomicBoolean -import sbt.BasicCommandStrings.{ Shell, TemplateCommand } +import sbt.BasicCommandStrings.{ Shell, TemplateCommand, networkExecPrefix } import sbt.Project.LoadAction import sbt.compiler.EvalImports import sbt.internal.Aggregation.AnyKeys @@ -24,6 +25,7 @@ import sbt.internal.client.BspClient import sbt.internal.inc.ScalaInstance import sbt.internal.io.Retry import sbt.internal.nio.CheckBuildSources +import sbt.internal.server.NetworkChannel import sbt.internal.util.Types.{ const, idFun } import sbt.internal.util._ import sbt.internal.util.complete.{ Parser, SizeParser } @@ -34,7 +36,9 @@ import xsbti.compile.CompilerCache import scala.annotation.tailrec import scala.concurrent.ExecutionContext +import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import sbt.internal.io.Retry /** This class is the entry point for sbt. */ final class xMain extends xsbti.AppMain { @@ -130,18 +134,24 @@ object StandardMain { pool.foreach(_.shutdownNow()) } + private[this] val isShutdown = new AtomicBoolean(false) def runManaged(s: State): xsbti.MainResult = { val previous = TrapExit.installManager() try { val hook = ShutdownHooks.add(closeRunnable) try { MainLoop.runLogged(s) + } catch { + case _: InterruptedException if isShutdown.get => + new xsbti.Exit { override def code(): Int = 0 } } finally { try DefaultBackgroundJobService.shutdown() finally hook.close() () } - } finally TrapExit.uninstallManager(previous) + } finally { + TrapExit.uninstallManager(previous) + } } /** The common interface to standard output, used for all built-in ConsoleLoggers. */ @@ -254,6 +264,7 @@ object BuiltinCommands { act, continuous, clearCaches, + NetworkChannel.disconnect, ) ++ allBasicCommands def DefaultBootCommands: Seq[String] = @@ -904,6 +915,16 @@ object BuiltinCommands { Command.command(ClearCaches, help)(f) } + private def getExec(state: State, interval: Duration): Exec = { + val exec: Exec = + StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) + if (exec.source.fold(true)(_.channelName != ConsoleChannel.defaultName) && + !exec.commandLine.startsWith(networkExecPrefix)) { + Terminal.consoleLog(s"received remote command: ${exec.commandLine}") + } + exec + } + def shell: Command = Command.command(Shell, Help.more(Shell, ShellDetailed)) { s0 => import sbt.internal.ConsolePromptEvent val exchange = StandardMain.exchange @@ -914,18 +935,17 @@ object BuiltinCommands { .extract(s1) .getOpt(Keys.minForcegcInterval) .getOrElse(GCUtil.defaultMinForcegcInterval) - val exec: Exec = exchange.blockUntilNextExec(minGCInterval, s1.globalLogging.full) - if (exec.source.fold(true)(_.channelName != "console0")) { - s1.log.info(s"received remote command: ${exec.commandLine}") - } + val exec: Exec = getExec(s1, minGCInterval) val newState = s1 .copy( onFailure = Some(Exec(Shell, None)), remainingCommands = exec +: Exec(Shell, None) +: s1.remainingCommands ) .setInteractive(true) - if (exec.commandLine.trim.isEmpty) newState - else newState.clearGlobalLog + val res = + if (exec.commandLine.trim.isEmpty) newState + else newState.clearGlobalLog + res } def startServer: Command = diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 41f312f0c..fd3005138 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -10,11 +10,13 @@ package sbt import java.io.PrintWriter import java.util.Properties +import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Terminal } +import sbt.internal.{ ConsoleUnpromptEvent, ShutdownHooks } import sbt.io.{ IO, Using } import sbt.protocol._ import sbt.util.Logger @@ -195,16 +197,21 @@ object MainLoop { state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) } else state } + StandardMain.exchange.setState(progressState) StandardMain.exchange.setExec(Some(exec)) + StandardMain.exchange.unprompt(ConsoleUnpromptEvent(exec.source)) val newState = Command.process(exec.commandLine, progressState) - val doneEvent = ExecStatusEvent( - "Done", - channelName, - exec.execId, - newState.remainingCommands.toVector map (_.commandLine), - exitCode(newState, state), - ) - StandardMain.exchange.respondStatus(doneEvent) + if (exec.execId.fold(true)(!_.startsWith(networkExecPrefix)) && + !exec.commandLine.startsWith(networkExecPrefix)) { + val doneEvent = ExecStatusEvent( + "Done", + channelName, + exec.execId, + newState.remainingCommands.toVector map (_.commandLine), + exitCode(newState, state), + ) + StandardMain.exchange.respondStatus(doneEvent) + } StandardMain.exchange.setExec(None) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.remove(sbt.Keys.currentTaskProgress) @@ -273,7 +280,8 @@ object MainLoop { // it's handled by executing the shell again, instead of the state failing // so we also use that to indicate that the execution failed private[this] def exitCodeFromStateOnFailure(state: State, prevState: State): ExitCode = - if (prevState.onFailure.isDefined && state.onFailure.isEmpty) ExitCode(ErrorCodes.UnknownError) + if (prevState.onFailure.isDefined && state.onFailure.isEmpty && + state.currentCommand.fold(true)(_ != StashOnFailure)) ExitCode(ErrorCodes.UnknownError) else ExitCode.Success } diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index 555737dd5..18cf13ddc 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -19,6 +19,7 @@ import Keys.{ historyPath, projectCommand, sessionSettings, + terminalShellPrompt, shellPrompt, templateResolverInfos, autoStartServer, @@ -508,6 +509,7 @@ object Project extends ProjectExtra { val allCommands = commandsIn(ref) ++ commandsIn(BuildRef(ref.build)) ++ (commands in Global get structure.data toList) val history = get(historyPath) flatMap idFun val prompt = get(shellPrompt) + val newPrompt = get(terminalShellPrompt) val trs = (templateResolverInfos in Global get structure.data).toList.flatten val startSvr: Option[Boolean] = get(autoStartServer) val host: Option[String] = get(serverHost) @@ -532,6 +534,7 @@ object Project extends ProjectExtra { .put(historyPath.key, history) .put(templateResolverInfos.key, trs) .setCond(shellPrompt.key, prompt) + .setCond(terminalShellPrompt.key, newPrompt) .setCond(serverLogLevel, srvLogLevel) .setCond(fullServerHandlers.key, hs) s.copy( diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 6b3178062..cbabbb6a5 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -10,19 +10,21 @@ package sbt package internal import java.io.IOException import java.net.Socket -import java.util.concurrent.{ ConcurrentLinkedQueue, LinkedBlockingQueue, TimeUnit } import java.util.concurrent.atomic._ +import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } +import sbt.BasicCommandStrings.networkExecPrefix import sbt.BasicKeys._ -import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, ProgressEvent, ProgressState, Terminal } +import sbt.internal.ui.UITask +import sbt.internal.util._ import sbt.io.syntax._ import sbt.io.{ Hash, IO } +import sbt.nio.Watch.NullLogger import sbt.protocol.{ ExecStatusEvent, LogEvent } -import sbt.util.{ Level, LogExchange, Logger } -import sjsonnew.JsonFormat +import sbt.util.Logger +import sbt.protocol.Serialization.attach import scala.annotation.tailrec import scala.collection.mutable.ListBuffer @@ -30,6 +32,8 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.util.{ Failure, Success, Try } +import sjsonnew.JsonFormat + /** * The command exchange merges multiple command channels (e.g. network and console), * and acts as the central multiplexing point. @@ -42,23 +46,15 @@ private[sbt] final class CommandExchange { private var server: Option[ServerInstance] = None private val firstInstance: AtomicBoolean = new AtomicBoolean(true) private var consoleChannel: Option[ConsoleChannel] = None - private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue() + private val commandQueue: LinkedBlockingQueue[Exec] = new LinkedBlockingQueue[Exec] private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer() private val channelBufferLock = new AnyRef {} - private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val maintenanceChannelQueue = new LinkedBlockingQueue[MaintenanceTask] private val nextChannelId: AtomicInteger = new AtomicInteger(0) - private[this] val activePrompt = new AtomicBoolean(false) private[this] val lastState = new AtomicReference[State] private[this] val currentExecRef = new AtomicReference[Exec] def channels: List[CommandChannel] = channelBuffer.toList - private[this] def removeChannel(channel: CommandChannel): Unit = { - channelBufferLock.synchronized { - channelBuffer -= channel - () - } - } def subscribe(c: CommandChannel): Unit = channelBufferLock.synchronized { channelBuffer.append(c) @@ -68,23 +64,39 @@ private[sbt] final class CommandExchange { private[sbt] def withState[T](f: State => T): T = f(lastState.get) def blockUntilNextExec: Exec = blockUntilNextExec(Duration.Inf, NullLogger) // periodically move all messages from all the channels - private[sbt] def blockUntilNextExec(interval: Duration, logger: Logger): Exec = { + private[sbt] def blockUntilNextExec(interval: Duration, logger: Logger): Exec = + blockUntilNextExec(interval, None, logger) + private[sbt] def blockUntilNextExec( + interval: Duration, + state: Option[State], + logger: Logger + ): Exec = { @tailrec def impl(deadline: Option[Deadline]): Exec = { - @tailrec def slurpMessages(): Unit = - channels.foldLeft(Option.empty[Exec]) { _ orElse _.poll } match { - case None => () - case Some(x) => - commandQueue.add(x) - slurpMessages() - } - commandChannelQueue.poll(1, TimeUnit.SECONDS) - slurpMessages() - Option(commandQueue.poll) match { - case Some(exec) => - val needFinish = needToFinishPromptLine() - if (exec.source.fold(needFinish)(s => needFinish && s.channelName != "console0")) - ConsoleOut.systemOut.println("") - exec + state.foreach(s => prompt(ConsolePromptEvent(s))) + def poll: Option[Exec] = + Option(deadline match { + case Some(d: Deadline) => + commandQueue.poll(d.timeLeft.toMillis + 1, TimeUnit.MILLISECONDS) + case _ => commandQueue.take + }) + poll match { + case Some(exec) if exec.source.fold(true)(s => channels.exists(_.name == s.channelName)) => + exec.commandLine match { + case "shutdown" => + exec + .withCommandLine("exit") + .withSource(Some(CommandSource(ConsoleChannel.defaultName))) + case "exit" if exec.source.fold(false)(_.channelName.startsWith("network")) => + channels.collectFirst { + case c: NetworkChannel if exec.source.fold(false)(_.channelName == c.name) => c + } match { + case Some(c) if c.isAttached => + c.shutdown(false) + impl(deadline) + case _ => exec + } + case _ => exec + } case None => val newDeadline = if (deadline.fold(false)(_.isOverdue())) { GCUtil.forceGcWithInterval(interval, logger) @@ -100,23 +112,38 @@ private[sbt] final class CommandExchange { }) } - def run(s: State): State = { + private def addConsoleChannel(): Unit = if (consoleChannel.isEmpty) { - val console0 = new ConsoleChannel("console0") + val name = ConsoleChannel.defaultName + val console0 = new ConsoleChannel(name, mkAskUser(name)) consoleChannel = Some(console0) subscribe(console0) } - val autoStartServerAttr = s get autoStartServer match { - case Some(bool) => bool - case None => true - } - if (autoStartServerSysProp && autoStartServerAttr) runServer(s) + def run(s: State): State = run(s, s.get(autoStartServer).getOrElse(true)) + def run(s: State, autoStart: Boolean): State = { + if (autoStartServerSysProp && autoStart) runServer(s) else s } private[sbt] def setState(s: State): Unit = lastState.set(s) private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" + private[sbt] def removeChannel(c: CommandChannel): Unit = { + channelBufferLock.synchronized { + Util.ignoreResult(channelBuffer -= c) + } + commandQueue.removeIf(_.source.map(_.channelName) == Some(c.name)) + currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e => + Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0"))) + } + } + + private[this] def mkAskUser( + name: String, + ): (State, CommandChannel) => UITask = { (state, channel) => + new UITask.AskUserTask(state, channel) + } + private[sbt] def currentExec = Option(currentExecRef.get) /** @@ -128,22 +155,22 @@ private[sbt] final class CommandExchange { lazy val auth: Set[ServerAuthentication] = s.get(serverAuthentication).getOrElse(Set(ServerAuthentication.Token)) lazy val connectionType = s.get(serverConnectionType).getOrElse(ConnectionType.Tcp) - lazy val level = s.get(serverLogLevel).orElse(s.get(logLevel)).getOrElse(Level.Warn) lazy val handlers = s.get(fullServerHandlers).getOrElse(Nil) def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { val name = newNetworkName - if (needToFinishPromptLine()) ConsoleOut.systemOut.println("") - s.log.info(s"new client connected: $name") - val logger: Logger = { - val log = LogExchange.logger(name, None, None) - LogExchange.unbindLoggerAppenders(name) - val appender = MainAppender.defaultScreen(s.globalLogging.console) - LogExchange.bindLoggerAppenders(name, List(appender -> level)) - log - } + Terminal.consoleLog(s"new client connected: $name") val channel = - new NetworkChannel(name, socket, Project structure s, auth, instance, handlers, logger) + new NetworkChannel( + name, + socket, + Project structure s, + auth, + instance, + handlers, + s.log, + mkAskUser(name) + ) subscribe(channel) } if (server.isEmpty && firstInstance.get) { @@ -165,7 +192,7 @@ private[sbt] final class CommandExchange { socketfile, pipeName, bspConnectionFile, - s.configuration + s.configuration, ) val serverInstance = Server.start(connection, onIncomingSocket, s.log) // don't throw exception when it times out @@ -196,7 +223,7 @@ private[sbt] final class CommandExchange { def shutdown(): Unit = { maintenanceThread.close() - channels foreach (_.shutdown()) + channels foreach (_.shutdown(true)) // interrupt and kill the thread server.foreach(_.shutdown()) server = None @@ -243,11 +270,10 @@ private[sbt] final class CommandExchange { // This is an interface to directly notify events. private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.notifyEvent(method, params)) - } + channels.foreach { + case c: NetworkChannel => tryTo(_.notifyEvent(method, params))(c) + case _ => + } } private def tryTo(f: NetworkChannel => Unit)( @@ -274,26 +300,22 @@ private[sbt] final class CommandExchange { tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel) } } - - tryTo(_.respond(event, event.execId))(channel) } } private[sbt] def setExec(exec: Option[Exec]): Unit = currentExecRef.set(exec.orNull) def prompt(event: ConsolePromptEvent): Unit = { - activePrompt.set(Terminal.systemInIsAttached) - channels - .collect { case c: ConsoleChannel => c } - .foreach { _.prompt(event) } + currentExecRef.set(null) + channels.foreach(_.prompt(event)) } + def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event)) def logMessage(event: LogEvent): Unit = { - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.notifyEvent(event)) - } + channels.foreach { + case c: NetworkChannel => tryTo(_.notifyEvent(event))(c) + case _ => + } } def notifyStatus(event: ExecStatusEvent): Unit = { @@ -305,19 +327,23 @@ private[sbt] final class CommandExchange { } tryTo(_.notifyEvent(event))(channel) } - private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) + private[sbt] def killChannel(channel: String): Unit = { + channels.find(_.name == channel).foreach(_.shutdown(false)) + } private[sbt] def updateProgress(pe: ProgressEvent): Unit = { val newPE = currentExec match { - case Some(e) => + case Some(e) if !e.commandLine.startsWith(networkExecPrefix) => pe.withCommand(currentExec.map(_.commandLine)) .withExecId(currentExec.flatMap(_.execId)) .withChannelName(currentExec.flatMap(_.source.map(_.channelName))) case _ => pe } + if (channels.isEmpty) addConsoleChannel() channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal)) } private[sbt] def shutdown(name: String): Unit = { + Option(currentExecRef.get).foreach(cancel) commandQueue.clear() val exit = Exec("shutdown", Some(Exec.newExecId), Some(CommandSource(name))) @@ -333,7 +359,7 @@ private[sbt] final class CommandExchange { private[this] val isStopped = new AtomicBoolean(false) override def run(): Unit = { def exit(mt: MaintenanceTask): Unit = { - mt.channel.shutdown() + mt.channel.shutdown(false) if (mt.channel.name.contains("console")) shutdown(mt.channel.name) } @tailrec def impl(): Unit = { @@ -341,6 +367,8 @@ private[sbt] final class CommandExchange { case null => case mt: MaintenanceTask => mt.task match { + case `attach` => mt.channel.prompt(ConsolePromptEvent(lastState.get)) + case "cancel" => Option(currentExecRef.get).foreach(cancel) case "exit" => exit(mt) case "shutdown" => shutdown(mt.channel.name) case _ => diff --git a/main/src/main/scala/sbt/internal/ConsoleProject.scala b/main/src/main/scala/sbt/internal/ConsoleProject.scala index 82fe28e8e..52beebc9a 100644 --- a/main/src/main/scala/sbt/internal/ConsoleProject.scala +++ b/main/src/main/scala/sbt/internal/ConsoleProject.scala @@ -28,7 +28,7 @@ object ConsoleProject { extracted.runTask(Keys.scalaCompilerBridgeBinaryJar.in(Keys.consoleProject), state1) val scalaInstance = { val scalaProvider = state.configuration.provider.scalaProvider - ScalaInstance(scalaProvider.version, scalaProvider.launcher) + ScalaInstance(scalaProvider.version, scalaProvider) } val g = BuildPaths.getGlobalBase(state) val zincDir = BuildPaths.getZincDirectory(state, g) @@ -61,13 +61,15 @@ object ConsoleProject { val importString = imports.mkString("", ";\n", ";\n\n") val initCommands = importString + extra - Terminal.get.withCanonicalIn { + val terminal = Terminal.get + terminal.withCanonicalIn { // TODO - Hook up dsl classpath correctly... (new Console(compiler))( unit.classpath, options, initCommands, - cleanupCommands + cleanupCommands, + terminal )(Some(unit.loader), bindings).get } () diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index eac9e76ca..9267964e2 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -274,7 +274,7 @@ private[sbt] object Continuous extends DeprecatedContinuous { private[this] def withCharBufferedStdIn[R](f: InputStream => R): R = Terminal.get.withRawSystemIn { - val wrapped = Terminal.get.inputStream + val wrapped = Terminal.wrappedSystemIn if (Util.isNonCygwinWindows) { val inputStream: InputStream with AutoCloseable = new InputStream with AutoCloseable { private[this] val buffer = new java.util.LinkedList[Int] diff --git a/main/src/main/scala/sbt/internal/LogManager.scala b/main/src/main/scala/sbt/internal/LogManager.scala index f44bd0371..a9db8d34b 100644 --- a/main/src/main/scala/sbt/internal/LogManager.scala +++ b/main/src/main/scala/sbt/internal/LogManager.scala @@ -152,8 +152,9 @@ object LogManager { case Some(x: Exec) => x.source match { // TODO: Fix this stringliness - case Some(x: CommandSource) if x.channelName == "console0" => Option(console) - case _ => Option(console) + case Some(x: CommandSource) if x.channelName == ConsoleChannel.defaultName => + Option(console) + case _ => Option(console) } case _ => Option(console) } diff --git a/main/src/main/scala/sbt/internal/TaskProgress.scala b/main/src/main/scala/sbt/internal/TaskProgress.scala index e208f4a03..2a03206ee 100644 --- a/main/src/main/scala/sbt/internal/TaskProgress.scala +++ b/main/src/main/scala/sbt/internal/TaskProgress.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ private[sbt] final class TaskProgress extends AbstractTaskExecuteProgress with ExecuteProgress[Task] { - @deprecated("Use the constructor taking an ExecID.", "1.4.0") + @deprecated("Use the no argument constructor.", "1.4.0") def this(log: ManagedLogger) = this() private[this] val lastTaskCount = new AtomicInteger(0) private[this] val currentProgressThread = new AtomicReference[Option[ProgressThread]](None) diff --git a/main/src/main/scala/sbt/internal/XMainConfiguration.scala b/main/src/main/scala/sbt/internal/XMainConfiguration.scala index b0eb06751..d83826c37 100644 --- a/main/src/main/scala/sbt/internal/XMainConfiguration.scala +++ b/main/src/main/scala/sbt/internal/XMainConfiguration.scala @@ -9,7 +9,7 @@ package sbt.internal import java.io.File import java.lang.reflect.InvocationTargetException -import java.net.URL +import java.net.{ URL, URLClassLoader } import java.util.concurrent.{ ExecutorService, Executors } import ClassLoaderClose.close @@ -58,10 +58,21 @@ private[internal] object ClassLoaderWarmup { */ private[sbt] class XMainConfiguration { def run(moduleName: String, configuration: xsbti.AppConfiguration): xsbti.MainResult = { + val topLoader = configuration.provider.scalaProvider.launcher.topLoader val updatedConfiguration = - if (configuration.provider.scalaProvider.launcher.topLoader.getClass.getCanonicalName - .contains("TestInterfaceLoader")) { - configuration + if (topLoader.getClass.getCanonicalName.contains("TestInterfaceLoader")) { + topLoader match { + case u: URLClassLoader => + val urls = u.getURLs + var i = 0 + while (i < urls.length && i >= 0) { + if (urls(i).toString.contains("jline")) i = -2 + else i += 1 + } + if (i < 0) configuration + else makeConfiguration(configuration) + case _ => configuration + } } else { makeConfiguration(configuration) } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index edfc488cc..5a50b3016 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -16,19 +16,19 @@ import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue } import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes, LogMessageParams, MessageType } +import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } import sbt.internal.protocol.{ JsonRpcNotificationMessage, JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseMessage } -import sbt.internal.util.{ ReadJsonFromInputStream, Prompt, Terminal, Util } +import sbt.internal.ui.{ UITask, UserThread } +import sbt.internal.util.{ Prompt, ReadJsonFromInputStream, Terminal, Util } import sbt.internal.util.Terminal.TerminalImpl -import sbt.internal.util.complete.Parser +import sbt.internal.util.complete.{ Parser, Parsers } import sbt.protocol._ import sbt.util.Logger -import sjsonnew._ -import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } import scala.annotation.tailrec import scala.collection.mutable @@ -37,6 +37,11 @@ import scala.util.Try import scala.util.control.NonFatal import Serialization.attach +import sjsonnew._ +import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } + +import Serialization.attach + final class NetworkChannel( val name: String, connection: Socket, @@ -44,8 +49,28 @@ final class NetworkChannel( auth: Set[ServerAuthentication], instance: ServerInstance, handlers: Seq[ServerHandler], - val log: Logger + val log: Logger, + mkUIThreadImpl: (State, CommandChannel) => UITask ) extends CommandChannel { self => + def this( + name: String, + connection: Socket, + structure: BuildStructure, + auth: Set[ServerAuthentication], + instance: ServerInstance, + handlers: Seq[ServerHandler], + log: Logger + ) = + this( + name, + connection, + structure, + auth, + instance, + handlers, + log, + new UITask.AskUserTask(_, _) + ) private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte @@ -76,6 +101,7 @@ final class NetworkChannel( private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal) override private[sbt] def terminal: Terminal = terminalHolder.get + override val userThread: UserThread = new UserThread(this) private lazy val callback: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = @@ -111,6 +137,22 @@ final class NetworkChannel( protected def authOptions: Set[ServerAuthentication] = auth + override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => { + if (interactive.get) mkUIThreadImpl(state, command) + else + new UITask { + override private[sbt] def channel = NetworkChannel.this + override def reader: UITask.Reader = () => { + try { + this.synchronized(this.wait) + Left("exit") + } catch { + case _: InterruptedException => Right("") + } + } + } + } + val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { private val ct = "Content-Type: " private val x1 = "application/sbt-x1" @@ -135,7 +177,7 @@ final class NetworkChannel( } } // while } finally { - shutdown() + shutdown(false) } } @@ -242,7 +284,7 @@ final class NetworkChannel( def respond[A: JsonFormat](event: A): Unit = respond(event, None) - def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = { + def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = if (alive.get) { respondResult(event, execId) } @@ -278,7 +320,7 @@ final class NetworkChannel( } catch { case _: IOException => alive.set(false) - shutdown() + shutdown(true) case _: InterruptedException => alive.set(false) } @@ -431,7 +473,7 @@ final class NetworkChannel( errorRespond("No tasks under execution") } } catch { - case NonFatal(e) => + case NonFatal(_) => errorRespond("Cancel request failed") } } else { @@ -439,10 +481,23 @@ final class NetworkChannel( } } - def shutdown(): Unit = { - log.info("Shutting down client connection") + @deprecated("Use variant that takes logShutdown parameter", "1.4.0") + override def shutdown(): Unit = { + shutdown(true) + } + import sjsonnew.BasicJsonProtocol.BooleanJsonFormat + override def shutdown(logShutdown: Boolean): Unit = { + terminal.close() + StandardMain.exchange.removeChannel(this) + super.shutdown(logShutdown) + if (logShutdown) Terminal.consoleLog(s"shutting down client connection $name") + VirtualTerminal.cancelRequests(name) + try jsonRpcNotify("shutdown", logShutdown) + catch { case _: IOException => } running.set(false) out.close() + thread.interrupt() + writeThread.interrupt() } /** Respond back to Language Server's client. */ @@ -643,4 +698,13 @@ object NetworkChannel { case object SingleLine extends ChannelState case object InHeader extends ChannelState case object InBody extends ChannelState + + private[sbt] val disconnect: Command = + Command.arb { s => + val dncParser: Parser[String] = BasicCommandStrings.DisconnectNetworkChannel + dncParser.examples() ~> Parsers.Space.examples() ~> Parsers.any.*.examples() + } { (st, channel) => + StandardMain.exchange.killChannel(channel.mkString) + st + } } diff --git a/main/src/test/scala/PluginCommandTest.scala b/main/src/test/scala/PluginCommandTest.scala index 52b594499..7fd524ac2 100644 --- a/main/src/test/scala/PluginCommandTest.scala +++ b/main/src/test/scala/PluginCommandTest.scala @@ -17,7 +17,8 @@ import sbt.internal.util.{ ConsoleOut, GlobalLogging, MainAppender, - Settings + Settings, + Terminal } object PluginCommandTestPlugin0 extends AutoPlugin { override def requires = empty } @@ -72,19 +73,21 @@ object PluginCommandTest extends Specification { object FakeState { def processCommand(input: String, enabledPlugins: AutoPlugin*): String = { - val previousOut = System.out val outBuffer = new ByteArrayOutputStream + val logFile = File.createTempFile("sbt", ".log") try { - System.setOut(new PrintStream(outBuffer, true)) - val state = FakeState(enabledPlugins: _*) - MainLoop.processCommand(Exec(input, None), state) + val state = FakeState(logFile, enabledPlugins: _*) + Terminal.withOut(new PrintStream(outBuffer, true)) { + MainLoop.processCommand(Exec(input, None), state) + } new String(outBuffer.toByteArray) } finally { - System.setOut(previousOut) + logFile.delete() + () } } - def apply(plugins: AutoPlugin*) = { + def apply(logFile: File, plugins: AutoPlugin*) = { val base = new File("").getAbsoluteFile val testProject = Project("test-project", base).setAutoPlugins(plugins) @@ -154,9 +157,9 @@ object FakeState { State.newHistory, attributes, GlobalLogging.initial( - MainAppender.globalDefault(ConsoleOut.systemOut), - File.createTempFile("sbt", ".log"), - ConsoleOut.systemOut + MainAppender.globalDefault(ConsoleOut.globalProxy), + logFile, + ConsoleOut.globalProxy ), None, State.Continue diff --git a/project/build.properties b/project/build.properties index a919a9b5f..797e7ccfd 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.3.8 +sbt.version=1.3.10 diff --git a/server-test/src/server-test/client/build.sbt b/server-test/src/server-test/client/build.sbt new file mode 100644 index 000000000..3225bd76d --- /dev/null +++ b/server-test/src/server-test/client/build.sbt @@ -0,0 +1,7 @@ +TaskKey[Unit]("willSucceed") := println("success") + +TaskKey[Unit]("willFail") := { throw new Exception("failed") } + +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test" + +TaskKey[Unit]("fooBar") := { () } diff --git a/server-test/src/server-test/client/src/main/scala/A.scala b/server-test/src/server-test/client/src/main/scala/A.scala new file mode 100644 index 000000000..69c493db2 --- /dev/null +++ b/server-test/src/server-test/client/src/main/scala/A.scala @@ -0,0 +1 @@ +object A diff --git a/server-test/src/server-test/client/src/test/scala/FooSpec.scala b/server-test/src/server-test/client/src/test/scala/FooSpec.scala new file mode 100644 index 000000000..269be5624 --- /dev/null +++ b/server-test/src/server-test/client/src/test/scala/FooSpec.scala @@ -0,0 +1,3 @@ +package test.pkg + +class FooSpec extends org.scalatest.FlatSpec diff --git a/server-test/src/server-test/events/Main.scala b/server-test/src/server-test/events/Main.scala index dca34c07f..6030b2602 100644 --- a/server-test/src/server-test/events/Main.scala +++ b/server-test/src/server-test/events/Main.scala @@ -1,8 +1,6 @@ - object Main extends App { - while (true) { - Thread.sleep(1000) - } + try this.synchronized(this.wait) + catch { case _: InterruptedException => } } diff --git a/server-test/src/server-test/handshake/build.sbt b/server-test/src/server-test/handshake/build.sbt index cb7de6741..5f77cb0a3 100644 --- a/server-test/src/server-test/handshake/build.sbt +++ b/server-test/src/server-test/handshake/build.sbt @@ -5,21 +5,16 @@ ThisBuild / scalaVersion := "2.12.11" lazy val root = (project in file(".")) .settings( Global / serverLog / logLevel := Level.Debug, - // custom handler Global / serverHandlers += ServerHandler({ callback => import callback._ import sjsonnew.BasicJsonProtocol._ import sbt.internal.protocol.JsonRpcRequestMessage - ServerIntent( - { - case r: JsonRpcRequestMessage if r.method == "lunar/helo" => - jsonRpcNotify("lunar/oleh", "") - () - }, - PartialFunction.empty - ) + ServerIntent.request { + case r: JsonRpcRequestMessage if r.method == "lunar/helo" => + jsonRpcNotify("lunar/oleh", "") + () + } }), - name := "handshake", ) diff --git a/server-test/src/test/scala/testpkg/EventsTest.scala b/server-test/src/test/scala/testpkg/EventsTest.scala index 5359bb2dd..8865e8eb3 100644 --- a/server-test/src/test/scala/testpkg/EventsTest.scala +++ b/server-test/src/test/scala/testpkg/EventsTest.scala @@ -8,54 +8,79 @@ package testpkg import scala.concurrent.duration._ +import java.util.concurrent.atomic.AtomicInteger +import sbt.Exec // starts svr using server-test/events and perform event related tests object EventsTest extends AbstractServerTest { override val testDirectory: String = "events" + val currentID = new AtomicInteger(0) test("report task failures in case of exceptions") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" + s"""{ "jsonrpc": "2.0", "id": $id, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" ) assert(svr.waitForString(10.seconds) { s => - (s contains """"id":11""") && (s contains """"error":""") + (s contains s""""id":$id""") && (s contains """"error":""") }) } test("return error if cancelling non-matched task id") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id":$id, "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = currentID.getAndIncrement() + val invalidID = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "55" } }""" + s"""{ "jsonrpc": "2.0", "id":$cancelID, "method": "sbt/cancelRequest", "params": { "id": "$invalidID" } }""" ) assert(svr.waitForString(20.seconds) { s => (s contains """"error":{"code":-32800""") }) + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id":${currentID.getAndIncrement}, "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => + s contains """"result":{"status":"Task cancelled"""" + }) } test("cancel on-going task with numeric id") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id":$id, "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) - assert(svr.waitForString(1.minute) { s => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "12" } }""" - ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = currentID.getAndIncrement() + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id":$cancelID, "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => s contains """"result":{"status":"Task cancelled"""" }) } - /* test("cancel on-going task with string id") { _ => + val id = Exec.newExecId svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": "foo", "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id": "$id", "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) - assert(svr.waitForString(1.minute) { s => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": "bar", "method": "sbt/cancelRequest", "params": { "id": "foo" } }""" - ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = Exec.newExecId + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id": "$cancelID", "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => s contains """"result":{"status":"Task cancelled"""" }) - }*/ + } } diff --git a/server-test/src/test/scala/testpkg/ResponseTest.scala b/server-test/src/test/scala/testpkg/ResponseTest.scala index 75cec4250..4b4057989 100644 --- a/server-test/src/test/scala/testpkg/ResponseTest.scala +++ b/server-test/src/test/scala/testpkg/ResponseTest.scala @@ -18,7 +18,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "10", "method": "foo/export", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"id":"10"""") && (s contains "scala-library.jar") }) @@ -29,7 +29,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "11", "method": "foo/rootClasspath", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"id":"11"""") && (s contains "scala-library.jar") }) @@ -40,7 +40,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "12", "method": "foo/fail", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"error":{"code":-33000,"message":"fail message"""") }) } @@ -50,7 +50,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "13", "method": "foo/customfail", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"error":{"code":500,"message":"some error"""") }) } @@ -60,7 +60,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "14", "method": "foo/notification", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""") }) } @@ -71,14 +71,14 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.waitForString(1.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"15\"" } } assert { // the second response should never be sent svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"15\"" } } @@ -90,14 +90,14 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.waitForString(1.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"16\"" } } assert { // the second response (result or error) should never be sent svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"16\"" } } @@ -109,7 +109,7 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"result\":\"notification result\"" } } diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index 649923f47..d601e6d95 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -8,6 +8,7 @@ package testpkg import java.io.{ File, IOException } +import java.nio.file.Path import java.util.concurrent.TimeoutException import verify._ @@ -26,6 +27,7 @@ trait AbstractServerTest extends TestSuite[Unit] { private var temp: File = _ var svr: TestServer = _ def testDirectory: String + def testPath: Path = temp.toPath.resolve(testDirectory) private val targetDir: File = { val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target" @@ -224,7 +226,7 @@ case class TestServer( def bye(): Unit = { hostLog("sending exit") sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "exit" } }""" + """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }""" ) val deadline = 10.seconds.fromNow while (!deadline.isOverdue && process.isAlive) {