diff --git a/build.sbt b/build.sbt index a70ce498f..dfbbcea83 100644 --- a/build.sbt +++ b/build.sbt @@ -955,6 +955,10 @@ lazy val mainProj = (project in file("main")) exclude[DirectMissingMethodProblem]("sbt.Classpaths.warnInsecureProtocol"), exclude[DirectMissingMethodProblem]("sbt.Classpaths.warnInsecureProtocolInModules"), exclude[MissingClassProblem]("sbt.internal.ExternalHooks*"), + // This seems to be a mima problem. The older constructor still exists but + // mima seems to incorrectly miss the secondary constructor that provides + // the binary compatible version. + exclude[IncompatibleMethTypeProblem]("sbt.internal.server.NetworkChannel.this"), ) ) .configure( diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala b/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala index 95b670635..9bb00524f 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/ManagedLogger.scala @@ -21,8 +21,11 @@ class ManagedLogger( val name: String, val channelName: Option[String], val execId: Option[String], - xlogger: XLogger + xlogger: XLogger, + terminal: Option[Terminal] ) extends Logger { + def this(name: String, channelName: Option[String], execId: Option[String], xlogger: XLogger) = + this(name, channelName, execId, xlogger, None) override def trace(t: => Throwable): Unit = logEvent(Level.Error, TraceEvent("Error", t, channelName, execId)) override def log(level: Level.Value, message: => String): Unit = { @@ -35,10 +38,12 @@ class ManagedLogger( private lazy val SuccessEventTag = scala.reflect.runtime.universe.typeTag[SuccessEvent] // send special event for success since it's not a real log level override def success(message: => String): Unit = { - infoEvent[SuccessEvent](SuccessEvent(message))( - implicitly[JsonFormat[SuccessEvent]], - SuccessEventTag - ) + if (terminal.fold(true)(_.isSuccessEnabled)) { + infoEvent[SuccessEvent](SuccessEvent(message))( + implicitly[JsonFormat[SuccessEvent]], + SuccessEventTag + ) + } } def registerStringCodec[A: ShowLines: TypeTag]: Unit = { diff --git a/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala b/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala index 869328ce2..a1073b36f 100644 --- a/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala +++ b/internal/util-logging/src/main/scala/sbt/util/LogExchange.scala @@ -49,7 +49,7 @@ sealed abstract class LogExchange { config.addLogger(name, loggerConfig) ctx.updateLoggers val logger = ctx.getLogger(name) - new ManagedLogger(name, channelName, execId, logger) + new ManagedLogger(name, channelName, execId, logger, Some(Terminal.get)) } def unbindLoggerAppenders(loggerName: String): Unit = { val lc = loggerConfig(loggerName) diff --git a/main-command/src/main/scala/sbt/BasicCommandStrings.scala b/main-command/src/main/scala/sbt/BasicCommandStrings.scala index 331852d12..939b95e9e 100644 --- a/main-command/src/main/scala/sbt/BasicCommandStrings.scala +++ b/main-command/src/main/scala/sbt/BasicCommandStrings.scala @@ -235,4 +235,7 @@ $AliasCommand name= (ContinuousExecutePrefix + " ", continuousDetail) def ClearCaches: String = "clearCaches" def ClearCachesDetailed: String = "Clears all of sbt's internal caches." + + private[sbt] val networkExecPrefix = "__" + private[sbt] val DisconnectNetworkChannel = s"${networkExecPrefix}disconnectNetworkChannel" } diff --git a/main-command/src/main/scala/sbt/BasicCommands.scala b/main-command/src/main/scala/sbt/BasicCommands.scala index d74b3568c..cba9c73c1 100644 --- a/main-command/src/main/scala/sbt/BasicCommands.scala +++ b/main-command/src/main/scala/sbt/BasicCommands.scala @@ -346,7 +346,13 @@ object BasicCommands { private[this] def classpathStrings: Parser[Seq[String]] = token(StringBasic.map(s => IO.pathSplit(s).toSeq), "") - def exit: Command = Command.command(TerminateAction, exitBrief, exitBrief)(_ exit true) + def exit: Command = Command.command(TerminateAction, exitBrief, exitBrief) { s => + s.source match { + case Some(c) if c.channelName.startsWith("network") => + s"${DisconnectNetworkChannel} ${c.channelName}" :: s + case _ => s exit true + } + } @deprecated("Replaced by BuiltInCommands.continuous", "1.3.0") def continuous: Command = diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index 107b2df88..985ed5879 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -13,7 +13,7 @@ import com.github.ghik.silencer.silent import sbt.internal.inc.classpath.{ ClassLoaderCache => IncClassLoaderCache } import sbt.internal.classpath.ClassLoaderCache import sbt.internal.server.ServerHandler -import sbt.internal.util.AttributeKey +import sbt.internal.util.{ AttributeKey, Terminal } import sbt.librarymanagement.ModuleID import sbt.util.Level @@ -35,6 +35,11 @@ object BasicKeys { "The function that constructs the command prompt from the current build state.", 10000 ) + val terminalShellPrompt = AttributeKey[(Terminal, State) => String]( + "new-shell-prompt", + "The function that constructs the command prompt from the current build state for a given terminal.", + 10000 + ) @silent val watch = AttributeKey[Watched]("watched", "Continuous execution configuration.", 1000) val serverPort = diff --git a/main-command/src/main/scala/sbt/internal/CommandChannel.scala b/main-command/src/main/scala/sbt/internal/CommandChannel.scala index 9096c8a20..154e4cffd 100644 --- a/main-command/src/main/scala/sbt/internal/CommandChannel.scala +++ b/main-command/src/main/scala/sbt/internal/CommandChannel.scala @@ -9,9 +9,13 @@ package sbt package internal import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicReference +import sbt.internal.ui.{ UITask, UserThread } import sbt.internal.util.Terminal import sbt.protocol.EventMessage +import sbt.util.Level + import scala.collection.JavaConverters._ /** @@ -48,6 +52,8 @@ abstract class CommandChannel { private[sbt] final def initiateMaintenance(task: String): Unit = { maintenance.forEach(q => q.synchronized { q.add(new MaintenanceTask(this, task)); () }) } + private[sbt] def mkUIThread: (State, CommandChannel) => UITask + private[sbt] def makeUIThread(state: State): UITask = mkUIThread(state, this) final def append(exec: Exec): Boolean = { registered.synchronized { exec.commandLine.nonEmpty && { @@ -58,10 +64,29 @@ abstract class CommandChannel { } def poll: Option[Exec] = Option(commandQueue.poll) + def prompt(e: ConsolePromptEvent): Unit = userThread.onConsolePromptEvent(e) + def unprompt(e: ConsoleUnpromptEvent): Unit = userThread.onConsoleUnpromptEvent(e) def publishBytes(bytes: Array[Byte]): Unit - def shutdown(): Unit + private[sbt] def userThread: UserThread + def shutdown(logShutdown: Boolean): Unit = { + userThread.stopThread() + userThread.close() + } + @deprecated("Use the variant that takes the logShutdown parameter", "1.4.0") + def shutdown(): Unit = shutdown(true) def name: String + private[this] val level = new AtomicReference[Level.Value](Level.Info) + private[sbt] final def setLevel(l: Level.Value): Unit = level.set(l) + private[sbt] final def logLevel: Level.Value = level.get + private[this] def setLevel(value: Level.Value, cmd: String): Boolean = { + level.set(value) + append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) + } private[sbt] def onCommand: String => Boolean = { + case "error" => setLevel(Level.Error, "error") + case "debug" => setLevel(Level.Debug, "debug") + case "info" => setLevel(Level.Info, "info") + case "warn" => setLevel(Level.Warn, "warn") case cmd => if (cmd.nonEmpty) append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) else false @@ -89,7 +114,6 @@ case class ConsolePromptEvent(state: State) extends EventMessage /* * This is a data passed specifically for unprompting local console. */ -@deprecated("No longer used", "1.4.0") case class ConsoleUnpromptEvent(lastSource: Option[CommandSource]) extends EventMessage private[internal] class MaintenanceTask(val channel: CommandChannel, val task: String) diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index 1482b0094..dac0c9565 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -8,76 +8,24 @@ package sbt package internal -import java.io.File -import java.nio.channels.ClosedChannelException -import java.util.concurrent.atomic.AtomicReference - -import sbt.BasicKeys._ +import sbt.internal.ui.{ UITask, UserThread } import sbt.internal.util._ +import sjsonnew.JsonFormat -private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { - private[this] val askUserThread = new AtomicReference[AskUserThread] - private[this] def getPrompt(s: State): String = s.get(shellPrompt) match { - case Some(pf) => pf(s) - case None => - def ansi(s: String): String = if (ConsoleAppender.formatEnabledInEnv) s"$s" else "" - s"${ansi(ConsoleAppender.DeleteLine)}> ${ansi(ConsoleAppender.ClearScreenAfterCursor)}" - } - private[this] class AskUserThread(s: State) extends Thread("ask-user-thread") { - private val history = s.get(historyPath).getOrElse(Some(new File(s.baseDir, ".history"))) - private val prompt = getPrompt(s) - private val reader = - new FullReader( - history, - s.combinedParser, - LineReader.HandleCONT, - Terminal.console, - ) - setDaemon(true) - start() - override def run(): Unit = - try { - reader.readLine(prompt) match { - case Some(cmd) => append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) - case None => - println("") // Prevents server shutdown log lines from appearing on the prompt line - append(Exec("exit", Some(Exec.newExecId), Some(CommandSource(name)))) - } - () - } catch { - case _: ClosedChannelException => - } finally askUserThread.synchronized(askUserThread.set(null)) - def redraw(): Unit = { - System.out.print(ConsoleAppender.clearLine(0)) - System.out.print(ConsoleAppender.ClearScreenAfterCursor) - System.out.flush() - } - } - private[this] def makeAskUserThread(s: State): AskUserThread = new AskUserThread(s) +private[sbt] final class ConsoleChannel( + val name: String, + override private[sbt] val mkUIThread: (State, CommandChannel) => UITask +) extends CommandChannel { def run(s: State): State = s def publishBytes(bytes: Array[Byte]): Unit = () - def prompt(event: ConsolePromptEvent): Unit = { - if (Terminal.systemInIsAttached) { - askUserThread.synchronized { - askUserThread.get match { - case null => askUserThread.set(makeAskUserThread(event.state)) - case t => t.redraw() - } - } - } - } + def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = () - def shutdown(): Unit = askUserThread.synchronized { - askUserThread.get match { - case null => - case t if t.isAlive => - t.interrupt() - askUserThread.set(null) - case _ => () - } - } - override private[sbt] def terminal = Terminal.console + override val userThread: UserThread = new UserThread(this) + private[sbt] def terminal = Terminal.console +} +private[sbt] object ConsoleChannel { + private[sbt] def defaultName = "console0" } diff --git a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala index 85c3b9867..93c9afcc0 100644 --- a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala +++ b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala @@ -20,12 +20,14 @@ import sbt.internal.util.ReadJsonFromInputStream abstract class ServerConnection(connection: Socket) { private val running = new AtomicBoolean(true) + private val closed = new AtomicBoolean(false) private val retByte: Byte = '\r'.toByte private val delimiter: Byte = '\n'.toByte private val out = connection.getOutputStream val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") { + setDaemon(true) override def run(): Unit = { try { val in = connection.getInputStream @@ -67,17 +69,22 @@ abstract class ServerConnection(connection: Socket) { writeLine(a) } - def writeLine(a: Array[Byte]): Unit = { - def writeEndLine(): Unit = { - out.write(retByte.toInt) - out.write(delimiter.toInt) - out.flush + def writeLine(a: Array[Byte]): Unit = + try { + def writeEndLine(): Unit = { + out.write(retByte.toInt) + out.write(delimiter.toInt) + out.flush + } + if (a.nonEmpty) { + out.write(a) + } + writeEndLine + } catch { + case e: IOException => + shutdown() + throw e } - if (a.nonEmpty) { - out.write(a) - } - writeEndLine - } def onRequest(msg: JsonRpcRequestMessage): Unit def onResponse(msg: JsonRpcResponseMessage): Unit @@ -85,10 +92,14 @@ abstract class ServerConnection(connection: Socket) { def onShutdown(): Unit - def shutdown(): Unit = { - println("Shutting down client connection") - running.set(false) - out.close() + def shutdown(): Unit = if (closed.compareAndSet(false, true)) { + if (!running.compareAndSet(true, false)) { + System.err.println("\nsbt server connection closed.") + } + try { + out.close() + connection.close() + } catch { case e: IOException => e.printStackTrace() } onShutdown } diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 3d859d403..775683130 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -227,7 +227,7 @@ private[sbt] case class ServerConnection( socketfile: File, pipeName: String, bspConnectionFile: File, - appConfiguration: AppConfiguration + appConfiguration: AppConfiguration, ) { def shortName: String = { connectionType match { diff --git a/main-command/src/main/scala/sbt/internal/ui/UITask.scala b/main-command/src/main/scala/sbt/internal/ui/UITask.scala new file mode 100644 index 000000000..f5d365efe --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/ui/UITask.scala @@ -0,0 +1,99 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.ui + +import java.io.File +import java.nio.channels.ClosedChannelException +import java.util.concurrent.atomic.AtomicBoolean + +import jline.console.history.PersistentHistory +import sbt.BasicKeys.{ historyPath, terminalShellPrompt } +import sbt.State +import sbt.internal.CommandChannel +import sbt.internal.util.ConsoleAppender.{ ClearPromptLine, ClearScreenAfterCursor, DeleteLine } +import sbt.internal.util._ +import sbt.internal.util.complete.{ JLineCompletion, Parser } + +import scala.annotation.tailrec + +private[sbt] trait UITask extends Runnable with AutoCloseable { + private[sbt] def channel: CommandChannel + private[sbt] def reader: UITask.Reader + private[this] final def handleInput(s: Either[String, String]): Boolean = s match { + case Left(m) => channel.onMaintenance(m) + case Right(cmd) => channel.onCommand(cmd) + } + private[this] val isStopped = new AtomicBoolean(false) + override def run(): Unit = { + @tailrec def impl(): Unit = { + val res = reader.readLine() + if (!handleInput(res) && !isStopped.get) impl() + } + try impl() + catch { case _: InterruptedException | _: ClosedChannelException => isStopped.set(true) } + } + override def close(): Unit = isStopped.set(true) +} + +private[sbt] object UITask { + trait Reader { def readLine(): Either[String, String] } + object Reader { + def terminalReader(parser: Parser[_])( + terminal: Terminal, + state: State + ): Reader = { + val lineReader = LineReader.createReader(history(state), terminal, terminal.prompt) + JLineCompletion.installCustomCompletor(lineReader, parser) + () => { + val clear = terminal.ansi(ClearPromptLine, "") + try { + @tailrec def impl(): Either[String, String] = { + lineReader.readLine(clear + terminal.prompt.mkPrompt()) match { + case null => Left("exit") + case s: String => + lineReader.getHistory match { + case p: PersistentHistory => + p.add(s) + p.flush() + case _ => + } + s match { + case "" => impl() + case cmd @ ("shutdown" | "exit" | "cancel") => Left(cmd) + case cmd => + if (terminal.prompt != Prompt.Batch) terminal.setPrompt(Prompt.Running) + terminal.printStream.write(Int.MinValue) + Right(cmd) + } + } + } + impl() + } catch { + case _: InterruptedException => Right("") + } finally lineReader.close() + } + } + } + private[this] def history(s: State): Option[File] = + s.get(historyPath).getOrElse(Some(new File(s.baseDir, ".history"))) + private[sbt] def shellPrompt(terminal: Terminal, s: State): String = + s.get(terminalShellPrompt) match { + case Some(pf) => pf(terminal, s) + case None => + def ansi(s: String): String = if (terminal.isAnsiSupported) s"$s" else "" + s"${ansi(DeleteLine)}> ${ansi(ClearScreenAfterCursor)}" + } + private[sbt] class AskUserTask( + state: State, + override val channel: CommandChannel, + ) extends UITask { + override private[sbt] def reader: UITask.Reader = { + UITask.Reader.terminalReader(state.combinedParser)(channel.terminal, state) + } + } +} diff --git a/main-command/src/main/scala/sbt/internal/ui/UserThread.scala b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala new file mode 100644 index 000000000..e868c3682 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/ui/UserThread.scala @@ -0,0 +1,91 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal + +package ui + +import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } +import java.util.concurrent.Executors + +import sbt.State +import sbt.internal.util.{ ConsoleAppender, ProgressEvent, ProgressState, Util } +import sbt.internal.util.Prompt.{ AskUser, Running } + +private[sbt] class UserThread(val channel: CommandChannel) extends AutoCloseable { + private[this] val uiThread = new AtomicReference[(UITask, Thread)] + private[sbt] final def onProgressEvent(pe: ProgressEvent): Unit = { + lastProgressEvent.set(pe) + ProgressState.updateProgressState(pe, channel.terminal) + } + private[this] val executor = + Executors.newSingleThreadExecutor(r => new Thread(r, s"sbt-$name-ui-thread")) + private[this] val lastProgressEvent = new AtomicReference[ProgressEvent] + private[this] val isClosed = new AtomicBoolean(false) + + private[sbt] def reset(state: State): Unit = if (!isClosed.get) { + uiThread.synchronized { + val task = channel.makeUIThread(state) + def submit(): Thread = { + val thread = new Thread(() => { + task.run() + uiThread.set(null) + }, s"sbt-$name-ui-thread") + thread.setDaemon(true) + thread.start() + uiThread.getAndSet((task, thread)) match { + case null => + case (_, t) => t.interrupt() + } + thread + } + uiThread.get match { + case null => uiThread.set((task, submit())) + case (t, _) if t.getClass == task.getClass => + case (t, thread) => + thread.interrupt() + uiThread.set((task, submit())) + } + } + Option(lastProgressEvent.get).foreach(onProgressEvent) + } + + private[sbt] def stopThread(): Unit = uiThread.synchronized { + uiThread.getAndSet(null) match { + case null => + case (t, thread) => + t.close() + Util.ignoreResult(thread.interrupt()) + } + } + + private[sbt] def onConsolePromptEvent(consolePromptEvent: ConsolePromptEvent): Unit = { + channel.terminal.withPrintStream { ps => + ps.print(ConsoleAppender.ClearScreenAfterCursor) + ps.flush() + } + val state = consolePromptEvent.state + terminal.prompt match { + case Running => terminal.setPrompt(AskUser(() => UITask.shellPrompt(terminal, state))) + case _ => + } + onProgressEvent(ProgressEvent("Info", Vector(), None, None, None)) + reset(state) + } + + private[sbt] def onConsoleUnpromptEvent( + consoleUnpromptEvent: ConsoleUnpromptEvent + ): Unit = { + if (consoleUnpromptEvent.lastSource.fold(true)(_.channelName != name)) { + terminal.progressState.reset() + } else stopThread() + } + + override def close(): Unit = if (isClosed.compareAndSet(false, true)) executor.shutdown() + private def terminal = channel.terminal + private def name: String = channel.name +} diff --git a/main-settings/src/main/scala/sbt/Def.scala b/main-settings/src/main/scala/sbt/Def.scala index 6f169c5c6..2f6833fb8 100644 --- a/main-settings/src/main/scala/sbt/Def.scala +++ b/main-settings/src/main/scala/sbt/Def.scala @@ -170,12 +170,11 @@ object Def extends Init[Scope] with TaskMacroExtra with InitializeImplicits { def displayMasked(scoped: ScopedKey[_], mask: ScopeMask, showZeroConfig: Boolean): String = Scope.displayMasked(scoped.scope, scoped.key.label, mask, showZeroConfig) - def withColor(s: String, color: Option[String]): String = { - val useColor = ConsoleAppender.formatEnabledInEnv - color match { - case Some(c) if useColor => c + s + scala.Console.RESET - case _ => s - } + def withColor(s: String, color: Option[String]): String = + withColor(s, color, useColor = ConsoleAppender.formatEnabledInEnv) + def withColor(s: String, color: Option[String], useColor: Boolean): String = color match { + case Some(c) if useColor => c + s + scala.Console.RESET + case _ => s } override def deriveAllowed[T](s: Setting[T], allowDynamic: Boolean): Option[String] = diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index b3d6f3e92..90b6bc815 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -407,7 +407,7 @@ object Defaults extends BuildCommon { // TODO: This should be on the new default settings for a project. def projectCore: Seq[Setting[_]] = Seq( name := thisProject.value.id, - logManager := LogManager.defaults(extraLoggers.value, StandardMain.console), + logManager := LogManager.defaults(extraLoggers.value, ConsoleOut.terminalOut), onLoadMessage := (onLoadMessage or Def.setting { s"set current project to ${name.value} (in build ${thisProjectRef.value.build})" @@ -1495,13 +1495,13 @@ object Defaults extends BuildCommon { def askForMainClass(classes: Seq[String]): Option[String] = sbt.SelectMainClass( - if (classes.length >= 10) Some(SimpleReader.readLine(_)) + if (classes.length >= 10) Some(SimpleReader(Terminal.get).readLine(_)) else Some(s => { def print(st: String) = { scala.Console.out.print(st); scala.Console.out.flush() } print(s) Terminal.get.withRawSystemIn { - Terminal.read match { + Terminal.get.inputStream.read match { case -1 => None case b => val res = b.toChar.toString @@ -2343,6 +2343,9 @@ object Classpaths { CrossVersion(scalaVersion, binVersion)(base).withCrossVersion(Disabled()) }, shellPrompt := shellPromptFromState, + terminalShellPrompt := { (t, s) => + shellPromptFromState(t)(s) + }, dynamicDependency := { (): Unit }, transitiveClasspathDependency := { (): Unit }, transitiveDynamicInputs :== Nil, @@ -3826,11 +3829,13 @@ object Classpaths { } } - def shellPromptFromState: State => String = { s: State => + def shellPromptFromState: State => String = shellPromptFromState(Terminal.console) + def shellPromptFromState(terminal: Terminal): State => String = { s: State => val extracted = Project.extract(s) (name in extracted.currentRef).get(extracted.structure.data) match { - case Some(name) => s"sbt:$name" + Def.withColor("> ", Option(scala.Console.CYAN)) - case _ => "> " + case Some(name) => + s"sbt:$name" + Def.withColor(s"> ", Option(scala.Console.CYAN), terminal.isColorEnabled) + case _ => "> " } } } diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 2a75c9901..54a3c5ee1 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -89,6 +89,7 @@ object Keys { // Command keys val historyPath = SettingKey(BasicKeys.historyPath) val shellPrompt = SettingKey(BasicKeys.shellPrompt) + val terminalShellPrompt = SettingKey(BasicKeys.terminalShellPrompt) val autoStartServer = SettingKey(BasicKeys.autoStartServer) val serverPort = SettingKey(BasicKeys.serverPort) val serverHost = SettingKey(BasicKeys.serverHost) diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 03a110754..fe155b856 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -13,8 +13,9 @@ import java.nio.channels.ClosedChannelException import java.nio.file.{ FileAlreadyExistsException, FileSystems, Files } import java.util.Properties import java.util.concurrent.ForkJoinPool +import java.util.concurrent.atomic.AtomicBoolean -import sbt.BasicCommandStrings.{ Shell, TemplateCommand } +import sbt.BasicCommandStrings.{ Shell, TemplateCommand, networkExecPrefix } import sbt.Project.LoadAction import sbt.compiler.EvalImports import sbt.internal.Aggregation.AnyKeys @@ -24,6 +25,7 @@ import sbt.internal.client.BspClient import sbt.internal.inc.ScalaInstance import sbt.internal.io.Retry import sbt.internal.nio.CheckBuildSources +import sbt.internal.server.NetworkChannel import sbt.internal.util.Types.{ const, idFun } import sbt.internal.util._ import sbt.internal.util.complete.{ Parser, SizeParser } @@ -34,7 +36,9 @@ import xsbti.compile.CompilerCache import scala.annotation.tailrec import scala.concurrent.ExecutionContext +import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import sbt.internal.io.Retry /** This class is the entry point for sbt. */ final class xMain extends xsbti.AppMain { @@ -130,18 +134,24 @@ object StandardMain { pool.foreach(_.shutdownNow()) } + private[this] val isShutdown = new AtomicBoolean(false) def runManaged(s: State): xsbti.MainResult = { val previous = TrapExit.installManager() try { val hook = ShutdownHooks.add(closeRunnable) try { MainLoop.runLogged(s) + } catch { + case _: InterruptedException if isShutdown.get => + new xsbti.Exit { override def code(): Int = 0 } } finally { try DefaultBackgroundJobService.shutdown() finally hook.close() () } - } finally TrapExit.uninstallManager(previous) + } finally { + TrapExit.uninstallManager(previous) + } } /** The common interface to standard output, used for all built-in ConsoleLoggers. */ @@ -254,6 +264,7 @@ object BuiltinCommands { act, continuous, clearCaches, + NetworkChannel.disconnect, ) ++ allBasicCommands def DefaultBootCommands: Seq[String] = @@ -904,6 +915,16 @@ object BuiltinCommands { Command.command(ClearCaches, help)(f) } + private def getExec(state: State, interval: Duration): Exec = { + val exec: Exec = + StandardMain.exchange.blockUntilNextExec(interval, Some(state), state.globalLogging.full) + if (exec.source.fold(true)(_.channelName != ConsoleChannel.defaultName) && + !exec.commandLine.startsWith(networkExecPrefix)) { + Terminal.consoleLog(s"received remote command: ${exec.commandLine}") + } + exec + } + def shell: Command = Command.command(Shell, Help.more(Shell, ShellDetailed)) { s0 => import sbt.internal.ConsolePromptEvent val exchange = StandardMain.exchange @@ -914,18 +935,17 @@ object BuiltinCommands { .extract(s1) .getOpt(Keys.minForcegcInterval) .getOrElse(GCUtil.defaultMinForcegcInterval) - val exec: Exec = exchange.blockUntilNextExec(minGCInterval, s1.globalLogging.full) - if (exec.source.fold(true)(_.channelName != "console0")) { - s1.log.info(s"received remote command: ${exec.commandLine}") - } + val exec: Exec = getExec(s1, minGCInterval) val newState = s1 .copy( onFailure = Some(Exec(Shell, None)), remainingCommands = exec +: Exec(Shell, None) +: s1.remainingCommands ) .setInteractive(true) - if (exec.commandLine.trim.isEmpty) newState - else newState.clearGlobalLog + val res = + if (exec.commandLine.trim.isEmpty) newState + else newState.clearGlobalLog + res } def startServer: Command = diff --git a/main/src/main/scala/sbt/MainLoop.scala b/main/src/main/scala/sbt/MainLoop.scala index 41f312f0c..fd3005138 100644 --- a/main/src/main/scala/sbt/MainLoop.scala +++ b/main/src/main/scala/sbt/MainLoop.scala @@ -10,11 +10,13 @@ package sbt import java.io.PrintWriter import java.util.Properties +import sbt.BasicCommandStrings.{ StashOnFailure, networkExecPrefix } import sbt.internal.ShutdownHooks import sbt.internal.langserver.ErrorCodes import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.nio.CheckBuildSources.CheckBuildSourcesKey import sbt.internal.util.{ ErrorHandling, GlobalLogBacking, Terminal } +import sbt.internal.{ ConsoleUnpromptEvent, ShutdownHooks } import sbt.io.{ IO, Using } import sbt.protocol._ import sbt.util.Logger @@ -195,16 +197,21 @@ object MainLoop { state.put(sbt.Keys.currentTaskProgress, new Keys.TaskProgress(progress)) } else state } + StandardMain.exchange.setState(progressState) StandardMain.exchange.setExec(Some(exec)) + StandardMain.exchange.unprompt(ConsoleUnpromptEvent(exec.source)) val newState = Command.process(exec.commandLine, progressState) - val doneEvent = ExecStatusEvent( - "Done", - channelName, - exec.execId, - newState.remainingCommands.toVector map (_.commandLine), - exitCode(newState, state), - ) - StandardMain.exchange.respondStatus(doneEvent) + if (exec.execId.fold(true)(!_.startsWith(networkExecPrefix)) && + !exec.commandLine.startsWith(networkExecPrefix)) { + val doneEvent = ExecStatusEvent( + "Done", + channelName, + exec.execId, + newState.remainingCommands.toVector map (_.commandLine), + exitCode(newState, state), + ) + StandardMain.exchange.respondStatus(doneEvent) + } StandardMain.exchange.setExec(None) newState.get(sbt.Keys.currentTaskProgress).foreach(_.progress.stop()) newState.remove(sbt.Keys.currentTaskProgress) @@ -273,7 +280,8 @@ object MainLoop { // it's handled by executing the shell again, instead of the state failing // so we also use that to indicate that the execution failed private[this] def exitCodeFromStateOnFailure(state: State, prevState: State): ExitCode = - if (prevState.onFailure.isDefined && state.onFailure.isEmpty) ExitCode(ErrorCodes.UnknownError) + if (prevState.onFailure.isDefined && state.onFailure.isEmpty && + state.currentCommand.fold(true)(_ != StashOnFailure)) ExitCode(ErrorCodes.UnknownError) else ExitCode.Success } diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index 555737dd5..18cf13ddc 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -19,6 +19,7 @@ import Keys.{ historyPath, projectCommand, sessionSettings, + terminalShellPrompt, shellPrompt, templateResolverInfos, autoStartServer, @@ -508,6 +509,7 @@ object Project extends ProjectExtra { val allCommands = commandsIn(ref) ++ commandsIn(BuildRef(ref.build)) ++ (commands in Global get structure.data toList) val history = get(historyPath) flatMap idFun val prompt = get(shellPrompt) + val newPrompt = get(terminalShellPrompt) val trs = (templateResolverInfos in Global get structure.data).toList.flatten val startSvr: Option[Boolean] = get(autoStartServer) val host: Option[String] = get(serverHost) @@ -532,6 +534,7 @@ object Project extends ProjectExtra { .put(historyPath.key, history) .put(templateResolverInfos.key, trs) .setCond(shellPrompt.key, prompt) + .setCond(terminalShellPrompt.key, newPrompt) .setCond(serverLogLevel, srvLogLevel) .setCond(fullServerHandlers.key, hs) s.copy( diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 6b3178062..cbabbb6a5 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -10,19 +10,21 @@ package sbt package internal import java.io.IOException import java.net.Socket -import java.util.concurrent.{ ConcurrentLinkedQueue, LinkedBlockingQueue, TimeUnit } import java.util.concurrent.atomic._ +import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } +import sbt.BasicCommandStrings.networkExecPrefix import sbt.BasicKeys._ -import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, ProgressEvent, ProgressState, Terminal } +import sbt.internal.ui.UITask +import sbt.internal.util._ import sbt.io.syntax._ import sbt.io.{ Hash, IO } +import sbt.nio.Watch.NullLogger import sbt.protocol.{ ExecStatusEvent, LogEvent } -import sbt.util.{ Level, LogExchange, Logger } -import sjsonnew.JsonFormat +import sbt.util.Logger +import sbt.protocol.Serialization.attach import scala.annotation.tailrec import scala.collection.mutable.ListBuffer @@ -30,6 +32,8 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.util.{ Failure, Success, Try } +import sjsonnew.JsonFormat + /** * The command exchange merges multiple command channels (e.g. network and console), * and acts as the central multiplexing point. @@ -42,23 +46,15 @@ private[sbt] final class CommandExchange { private var server: Option[ServerInstance] = None private val firstInstance: AtomicBoolean = new AtomicBoolean(true) private var consoleChannel: Option[ConsoleChannel] = None - private val commandQueue: ConcurrentLinkedQueue[Exec] = new ConcurrentLinkedQueue() + private val commandQueue: LinkedBlockingQueue[Exec] = new LinkedBlockingQueue[Exec] private val channelBuffer: ListBuffer[CommandChannel] = new ListBuffer() private val channelBufferLock = new AnyRef {} - private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val maintenanceChannelQueue = new LinkedBlockingQueue[MaintenanceTask] private val nextChannelId: AtomicInteger = new AtomicInteger(0) - private[this] val activePrompt = new AtomicBoolean(false) private[this] val lastState = new AtomicReference[State] private[this] val currentExecRef = new AtomicReference[Exec] def channels: List[CommandChannel] = channelBuffer.toList - private[this] def removeChannel(channel: CommandChannel): Unit = { - channelBufferLock.synchronized { - channelBuffer -= channel - () - } - } def subscribe(c: CommandChannel): Unit = channelBufferLock.synchronized { channelBuffer.append(c) @@ -68,23 +64,39 @@ private[sbt] final class CommandExchange { private[sbt] def withState[T](f: State => T): T = f(lastState.get) def blockUntilNextExec: Exec = blockUntilNextExec(Duration.Inf, NullLogger) // periodically move all messages from all the channels - private[sbt] def blockUntilNextExec(interval: Duration, logger: Logger): Exec = { + private[sbt] def blockUntilNextExec(interval: Duration, logger: Logger): Exec = + blockUntilNextExec(interval, None, logger) + private[sbt] def blockUntilNextExec( + interval: Duration, + state: Option[State], + logger: Logger + ): Exec = { @tailrec def impl(deadline: Option[Deadline]): Exec = { - @tailrec def slurpMessages(): Unit = - channels.foldLeft(Option.empty[Exec]) { _ orElse _.poll } match { - case None => () - case Some(x) => - commandQueue.add(x) - slurpMessages() - } - commandChannelQueue.poll(1, TimeUnit.SECONDS) - slurpMessages() - Option(commandQueue.poll) match { - case Some(exec) => - val needFinish = needToFinishPromptLine() - if (exec.source.fold(needFinish)(s => needFinish && s.channelName != "console0")) - ConsoleOut.systemOut.println("") - exec + state.foreach(s => prompt(ConsolePromptEvent(s))) + def poll: Option[Exec] = + Option(deadline match { + case Some(d: Deadline) => + commandQueue.poll(d.timeLeft.toMillis + 1, TimeUnit.MILLISECONDS) + case _ => commandQueue.take + }) + poll match { + case Some(exec) if exec.source.fold(true)(s => channels.exists(_.name == s.channelName)) => + exec.commandLine match { + case "shutdown" => + exec + .withCommandLine("exit") + .withSource(Some(CommandSource(ConsoleChannel.defaultName))) + case "exit" if exec.source.fold(false)(_.channelName.startsWith("network")) => + channels.collectFirst { + case c: NetworkChannel if exec.source.fold(false)(_.channelName == c.name) => c + } match { + case Some(c) if c.isAttached => + c.shutdown(false) + impl(deadline) + case _ => exec + } + case _ => exec + } case None => val newDeadline = if (deadline.fold(false)(_.isOverdue())) { GCUtil.forceGcWithInterval(interval, logger) @@ -100,23 +112,38 @@ private[sbt] final class CommandExchange { }) } - def run(s: State): State = { + private def addConsoleChannel(): Unit = if (consoleChannel.isEmpty) { - val console0 = new ConsoleChannel("console0") + val name = ConsoleChannel.defaultName + val console0 = new ConsoleChannel(name, mkAskUser(name)) consoleChannel = Some(console0) subscribe(console0) } - val autoStartServerAttr = s get autoStartServer match { - case Some(bool) => bool - case None => true - } - if (autoStartServerSysProp && autoStartServerAttr) runServer(s) + def run(s: State): State = run(s, s.get(autoStartServer).getOrElse(true)) + def run(s: State, autoStart: Boolean): State = { + if (autoStartServerSysProp && autoStart) runServer(s) else s } private[sbt] def setState(s: State): Unit = lastState.set(s) private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" + private[sbt] def removeChannel(c: CommandChannel): Unit = { + channelBufferLock.synchronized { + Util.ignoreResult(channelBuffer -= c) + } + commandQueue.removeIf(_.source.map(_.channelName) == Some(c.name)) + currentExec.filter(_.source.map(_.channelName) == Some(c.name)).foreach { e => + Util.ignoreResult(NetworkChannel.cancel(e.execId, e.execId.getOrElse("0"))) + } + } + + private[this] def mkAskUser( + name: String, + ): (State, CommandChannel) => UITask = { (state, channel) => + new UITask.AskUserTask(state, channel) + } + private[sbt] def currentExec = Option(currentExecRef.get) /** @@ -128,22 +155,22 @@ private[sbt] final class CommandExchange { lazy val auth: Set[ServerAuthentication] = s.get(serverAuthentication).getOrElse(Set(ServerAuthentication.Token)) lazy val connectionType = s.get(serverConnectionType).getOrElse(ConnectionType.Tcp) - lazy val level = s.get(serverLogLevel).orElse(s.get(logLevel)).getOrElse(Level.Warn) lazy val handlers = s.get(fullServerHandlers).getOrElse(Nil) def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { val name = newNetworkName - if (needToFinishPromptLine()) ConsoleOut.systemOut.println("") - s.log.info(s"new client connected: $name") - val logger: Logger = { - val log = LogExchange.logger(name, None, None) - LogExchange.unbindLoggerAppenders(name) - val appender = MainAppender.defaultScreen(s.globalLogging.console) - LogExchange.bindLoggerAppenders(name, List(appender -> level)) - log - } + Terminal.consoleLog(s"new client connected: $name") val channel = - new NetworkChannel(name, socket, Project structure s, auth, instance, handlers, logger) + new NetworkChannel( + name, + socket, + Project structure s, + auth, + instance, + handlers, + s.log, + mkAskUser(name) + ) subscribe(channel) } if (server.isEmpty && firstInstance.get) { @@ -165,7 +192,7 @@ private[sbt] final class CommandExchange { socketfile, pipeName, bspConnectionFile, - s.configuration + s.configuration, ) val serverInstance = Server.start(connection, onIncomingSocket, s.log) // don't throw exception when it times out @@ -196,7 +223,7 @@ private[sbt] final class CommandExchange { def shutdown(): Unit = { maintenanceThread.close() - channels foreach (_.shutdown()) + channels foreach (_.shutdown(true)) // interrupt and kill the thread server.foreach(_.shutdown()) server = None @@ -243,11 +270,10 @@ private[sbt] final class CommandExchange { // This is an interface to directly notify events. private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = { - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.notifyEvent(method, params)) - } + channels.foreach { + case c: NetworkChannel => tryTo(_.notifyEvent(method, params))(c) + case _ => + } } private def tryTo(f: NetworkChannel => Unit)( @@ -274,26 +300,22 @@ private[sbt] final class CommandExchange { tryTo(_.respondError(code, event.message.getOrElse(""), event.execId))(channel) } } - - tryTo(_.respond(event, event.execId))(channel) } } private[sbt] def setExec(exec: Option[Exec]): Unit = currentExecRef.set(exec.orNull) def prompt(event: ConsolePromptEvent): Unit = { - activePrompt.set(Terminal.systemInIsAttached) - channels - .collect { case c: ConsoleChannel => c } - .foreach { _.prompt(event) } + currentExecRef.set(null) + channels.foreach(_.prompt(event)) } + def unprompt(event: ConsoleUnpromptEvent): Unit = channels.foreach(_.unprompt(event)) def logMessage(event: LogEvent): Unit = { - channels - .collect { case c: NetworkChannel => c } - .foreach { - tryTo(_.notifyEvent(event)) - } + channels.foreach { + case c: NetworkChannel => tryTo(_.notifyEvent(event))(c) + case _ => + } } def notifyStatus(event: ExecStatusEvent): Unit = { @@ -305,19 +327,23 @@ private[sbt] final class CommandExchange { } tryTo(_.notifyEvent(event))(channel) } - private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) + private[sbt] def killChannel(channel: String): Unit = { + channels.find(_.name == channel).foreach(_.shutdown(false)) + } private[sbt] def updateProgress(pe: ProgressEvent): Unit = { val newPE = currentExec match { - case Some(e) => + case Some(e) if !e.commandLine.startsWith(networkExecPrefix) => pe.withCommand(currentExec.map(_.commandLine)) .withExecId(currentExec.flatMap(_.execId)) .withChannelName(currentExec.flatMap(_.source.map(_.channelName))) case _ => pe } + if (channels.isEmpty) addConsoleChannel() channels.foreach(c => ProgressState.updateProgressState(newPE, c.terminal)) } private[sbt] def shutdown(name: String): Unit = { + Option(currentExecRef.get).foreach(cancel) commandQueue.clear() val exit = Exec("shutdown", Some(Exec.newExecId), Some(CommandSource(name))) @@ -333,7 +359,7 @@ private[sbt] final class CommandExchange { private[this] val isStopped = new AtomicBoolean(false) override def run(): Unit = { def exit(mt: MaintenanceTask): Unit = { - mt.channel.shutdown() + mt.channel.shutdown(false) if (mt.channel.name.contains("console")) shutdown(mt.channel.name) } @tailrec def impl(): Unit = { @@ -341,6 +367,8 @@ private[sbt] final class CommandExchange { case null => case mt: MaintenanceTask => mt.task match { + case `attach` => mt.channel.prompt(ConsolePromptEvent(lastState.get)) + case "cancel" => Option(currentExecRef.get).foreach(cancel) case "exit" => exit(mt) case "shutdown" => shutdown(mt.channel.name) case _ => diff --git a/main/src/main/scala/sbt/internal/ConsoleProject.scala b/main/src/main/scala/sbt/internal/ConsoleProject.scala index 82fe28e8e..52beebc9a 100644 --- a/main/src/main/scala/sbt/internal/ConsoleProject.scala +++ b/main/src/main/scala/sbt/internal/ConsoleProject.scala @@ -28,7 +28,7 @@ object ConsoleProject { extracted.runTask(Keys.scalaCompilerBridgeBinaryJar.in(Keys.consoleProject), state1) val scalaInstance = { val scalaProvider = state.configuration.provider.scalaProvider - ScalaInstance(scalaProvider.version, scalaProvider.launcher) + ScalaInstance(scalaProvider.version, scalaProvider) } val g = BuildPaths.getGlobalBase(state) val zincDir = BuildPaths.getZincDirectory(state, g) @@ -61,13 +61,15 @@ object ConsoleProject { val importString = imports.mkString("", ";\n", ";\n\n") val initCommands = importString + extra - Terminal.get.withCanonicalIn { + val terminal = Terminal.get + terminal.withCanonicalIn { // TODO - Hook up dsl classpath correctly... (new Console(compiler))( unit.classpath, options, initCommands, - cleanupCommands + cleanupCommands, + terminal )(Some(unit.loader), bindings).get } () diff --git a/main/src/main/scala/sbt/internal/Continuous.scala b/main/src/main/scala/sbt/internal/Continuous.scala index eac9e76ca..9267964e2 100644 --- a/main/src/main/scala/sbt/internal/Continuous.scala +++ b/main/src/main/scala/sbt/internal/Continuous.scala @@ -274,7 +274,7 @@ private[sbt] object Continuous extends DeprecatedContinuous { private[this] def withCharBufferedStdIn[R](f: InputStream => R): R = Terminal.get.withRawSystemIn { - val wrapped = Terminal.get.inputStream + val wrapped = Terminal.wrappedSystemIn if (Util.isNonCygwinWindows) { val inputStream: InputStream with AutoCloseable = new InputStream with AutoCloseable { private[this] val buffer = new java.util.LinkedList[Int] diff --git a/main/src/main/scala/sbt/internal/LogManager.scala b/main/src/main/scala/sbt/internal/LogManager.scala index f44bd0371..a9db8d34b 100644 --- a/main/src/main/scala/sbt/internal/LogManager.scala +++ b/main/src/main/scala/sbt/internal/LogManager.scala @@ -152,8 +152,9 @@ object LogManager { case Some(x: Exec) => x.source match { // TODO: Fix this stringliness - case Some(x: CommandSource) if x.channelName == "console0" => Option(console) - case _ => Option(console) + case Some(x: CommandSource) if x.channelName == ConsoleChannel.defaultName => + Option(console) + case _ => Option(console) } case _ => Option(console) } diff --git a/main/src/main/scala/sbt/internal/TaskProgress.scala b/main/src/main/scala/sbt/internal/TaskProgress.scala index e208f4a03..2a03206ee 100644 --- a/main/src/main/scala/sbt/internal/TaskProgress.scala +++ b/main/src/main/scala/sbt/internal/TaskProgress.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ private[sbt] final class TaskProgress extends AbstractTaskExecuteProgress with ExecuteProgress[Task] { - @deprecated("Use the constructor taking an ExecID.", "1.4.0") + @deprecated("Use the no argument constructor.", "1.4.0") def this(log: ManagedLogger) = this() private[this] val lastTaskCount = new AtomicInteger(0) private[this] val currentProgressThread = new AtomicReference[Option[ProgressThread]](None) diff --git a/main/src/main/scala/sbt/internal/XMainConfiguration.scala b/main/src/main/scala/sbt/internal/XMainConfiguration.scala index b0eb06751..d83826c37 100644 --- a/main/src/main/scala/sbt/internal/XMainConfiguration.scala +++ b/main/src/main/scala/sbt/internal/XMainConfiguration.scala @@ -9,7 +9,7 @@ package sbt.internal import java.io.File import java.lang.reflect.InvocationTargetException -import java.net.URL +import java.net.{ URL, URLClassLoader } import java.util.concurrent.{ ExecutorService, Executors } import ClassLoaderClose.close @@ -58,10 +58,21 @@ private[internal] object ClassLoaderWarmup { */ private[sbt] class XMainConfiguration { def run(moduleName: String, configuration: xsbti.AppConfiguration): xsbti.MainResult = { + val topLoader = configuration.provider.scalaProvider.launcher.topLoader val updatedConfiguration = - if (configuration.provider.scalaProvider.launcher.topLoader.getClass.getCanonicalName - .contains("TestInterfaceLoader")) { - configuration + if (topLoader.getClass.getCanonicalName.contains("TestInterfaceLoader")) { + topLoader match { + case u: URLClassLoader => + val urls = u.getURLs + var i = 0 + while (i < urls.length && i >= 0) { + if (urls(i).toString.contains("jline")) i = -2 + else i += 1 + } + if (i < 0) configuration + else makeConfiguration(configuration) + case _ => configuration + } } else { makeConfiguration(configuration) } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index edfc488cc..5a50b3016 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -16,19 +16,19 @@ import java.util.concurrent.{ ConcurrentHashMap, LinkedBlockingQueue } import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes, LogMessageParams, MessageType } +import sbt.internal.langserver.{ CancelRequestParams, ErrorCodes } import sbt.internal.protocol.{ JsonRpcNotificationMessage, JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseMessage } -import sbt.internal.util.{ ReadJsonFromInputStream, Prompt, Terminal, Util } +import sbt.internal.ui.{ UITask, UserThread } +import sbt.internal.util.{ Prompt, ReadJsonFromInputStream, Terminal, Util } import sbt.internal.util.Terminal.TerminalImpl -import sbt.internal.util.complete.Parser +import sbt.internal.util.complete.{ Parser, Parsers } import sbt.protocol._ import sbt.util.Logger -import sjsonnew._ -import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } import scala.annotation.tailrec import scala.collection.mutable @@ -37,6 +37,11 @@ import scala.util.Try import scala.util.control.NonFatal import Serialization.attach +import sjsonnew._ +import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } + +import Serialization.attach + final class NetworkChannel( val name: String, connection: Socket, @@ -44,8 +49,28 @@ final class NetworkChannel( auth: Set[ServerAuthentication], instance: ServerInstance, handlers: Seq[ServerHandler], - val log: Logger + val log: Logger, + mkUIThreadImpl: (State, CommandChannel) => UITask ) extends CommandChannel { self => + def this( + name: String, + connection: Socket, + structure: BuildStructure, + auth: Set[ServerAuthentication], + instance: ServerInstance, + handlers: Seq[ServerHandler], + log: Logger + ) = + this( + name, + connection, + structure, + auth, + instance, + handlers, + log, + new UITask.AskUserTask(_, _) + ) private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte @@ -76,6 +101,7 @@ final class NetworkChannel( private[this] val terminalHolder = new AtomicReference(Terminal.NullTerminal) override private[sbt] def terminal: Terminal = terminalHolder.get + override val userThread: UserThread = new UserThread(this) private lazy val callback: ServerCallback = new ServerCallback { def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit = @@ -111,6 +137,22 @@ final class NetworkChannel( protected def authOptions: Set[ServerAuthentication] = auth + override def mkUIThread: (State, CommandChannel) => UITask = (state, command) => { + if (interactive.get) mkUIThreadImpl(state, command) + else + new UITask { + override private[sbt] def channel = NetworkChannel.this + override def reader: UITask.Reader = () => { + try { + this.synchronized(this.wait) + Left("exit") + } catch { + case _: InterruptedException => Right("") + } + } + } + } + val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { private val ct = "Content-Type: " private val x1 = "application/sbt-x1" @@ -135,7 +177,7 @@ final class NetworkChannel( } } // while } finally { - shutdown() + shutdown(false) } } @@ -242,7 +284,7 @@ final class NetworkChannel( def respond[A: JsonFormat](event: A): Unit = respond(event, None) - def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = { + def respond[A: JsonFormat](event: A, execId: Option[String]): Unit = if (alive.get) { respondResult(event, execId) } @@ -278,7 +320,7 @@ final class NetworkChannel( } catch { case _: IOException => alive.set(false) - shutdown() + shutdown(true) case _: InterruptedException => alive.set(false) } @@ -431,7 +473,7 @@ final class NetworkChannel( errorRespond("No tasks under execution") } } catch { - case NonFatal(e) => + case NonFatal(_) => errorRespond("Cancel request failed") } } else { @@ -439,10 +481,23 @@ final class NetworkChannel( } } - def shutdown(): Unit = { - log.info("Shutting down client connection") + @deprecated("Use variant that takes logShutdown parameter", "1.4.0") + override def shutdown(): Unit = { + shutdown(true) + } + import sjsonnew.BasicJsonProtocol.BooleanJsonFormat + override def shutdown(logShutdown: Boolean): Unit = { + terminal.close() + StandardMain.exchange.removeChannel(this) + super.shutdown(logShutdown) + if (logShutdown) Terminal.consoleLog(s"shutting down client connection $name") + VirtualTerminal.cancelRequests(name) + try jsonRpcNotify("shutdown", logShutdown) + catch { case _: IOException => } running.set(false) out.close() + thread.interrupt() + writeThread.interrupt() } /** Respond back to Language Server's client. */ @@ -643,4 +698,13 @@ object NetworkChannel { case object SingleLine extends ChannelState case object InHeader extends ChannelState case object InBody extends ChannelState + + private[sbt] val disconnect: Command = + Command.arb { s => + val dncParser: Parser[String] = BasicCommandStrings.DisconnectNetworkChannel + dncParser.examples() ~> Parsers.Space.examples() ~> Parsers.any.*.examples() + } { (st, channel) => + StandardMain.exchange.killChannel(channel.mkString) + st + } } diff --git a/main/src/test/scala/PluginCommandTest.scala b/main/src/test/scala/PluginCommandTest.scala index 52b594499..7fd524ac2 100644 --- a/main/src/test/scala/PluginCommandTest.scala +++ b/main/src/test/scala/PluginCommandTest.scala @@ -17,7 +17,8 @@ import sbt.internal.util.{ ConsoleOut, GlobalLogging, MainAppender, - Settings + Settings, + Terminal } object PluginCommandTestPlugin0 extends AutoPlugin { override def requires = empty } @@ -72,19 +73,21 @@ object PluginCommandTest extends Specification { object FakeState { def processCommand(input: String, enabledPlugins: AutoPlugin*): String = { - val previousOut = System.out val outBuffer = new ByteArrayOutputStream + val logFile = File.createTempFile("sbt", ".log") try { - System.setOut(new PrintStream(outBuffer, true)) - val state = FakeState(enabledPlugins: _*) - MainLoop.processCommand(Exec(input, None), state) + val state = FakeState(logFile, enabledPlugins: _*) + Terminal.withOut(new PrintStream(outBuffer, true)) { + MainLoop.processCommand(Exec(input, None), state) + } new String(outBuffer.toByteArray) } finally { - System.setOut(previousOut) + logFile.delete() + () } } - def apply(plugins: AutoPlugin*) = { + def apply(logFile: File, plugins: AutoPlugin*) = { val base = new File("").getAbsoluteFile val testProject = Project("test-project", base).setAutoPlugins(plugins) @@ -154,9 +157,9 @@ object FakeState { State.newHistory, attributes, GlobalLogging.initial( - MainAppender.globalDefault(ConsoleOut.systemOut), - File.createTempFile("sbt", ".log"), - ConsoleOut.systemOut + MainAppender.globalDefault(ConsoleOut.globalProxy), + logFile, + ConsoleOut.globalProxy ), None, State.Continue diff --git a/project/build.properties b/project/build.properties index a919a9b5f..797e7ccfd 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.3.8 +sbt.version=1.3.10 diff --git a/server-test/src/server-test/client/build.sbt b/server-test/src/server-test/client/build.sbt new file mode 100644 index 000000000..3225bd76d --- /dev/null +++ b/server-test/src/server-test/client/build.sbt @@ -0,0 +1,7 @@ +TaskKey[Unit]("willSucceed") := println("success") + +TaskKey[Unit]("willFail") := { throw new Exception("failed") } + +libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.8" % "test" + +TaskKey[Unit]("fooBar") := { () } diff --git a/server-test/src/server-test/client/src/main/scala/A.scala b/server-test/src/server-test/client/src/main/scala/A.scala new file mode 100644 index 000000000..69c493db2 --- /dev/null +++ b/server-test/src/server-test/client/src/main/scala/A.scala @@ -0,0 +1 @@ +object A diff --git a/server-test/src/server-test/client/src/test/scala/FooSpec.scala b/server-test/src/server-test/client/src/test/scala/FooSpec.scala new file mode 100644 index 000000000..269be5624 --- /dev/null +++ b/server-test/src/server-test/client/src/test/scala/FooSpec.scala @@ -0,0 +1,3 @@ +package test.pkg + +class FooSpec extends org.scalatest.FlatSpec diff --git a/server-test/src/server-test/events/Main.scala b/server-test/src/server-test/events/Main.scala index dca34c07f..6030b2602 100644 --- a/server-test/src/server-test/events/Main.scala +++ b/server-test/src/server-test/events/Main.scala @@ -1,8 +1,6 @@ - object Main extends App { - while (true) { - Thread.sleep(1000) - } + try this.synchronized(this.wait) + catch { case _: InterruptedException => } } diff --git a/server-test/src/server-test/handshake/build.sbt b/server-test/src/server-test/handshake/build.sbt index cb7de6741..5f77cb0a3 100644 --- a/server-test/src/server-test/handshake/build.sbt +++ b/server-test/src/server-test/handshake/build.sbt @@ -5,21 +5,16 @@ ThisBuild / scalaVersion := "2.12.11" lazy val root = (project in file(".")) .settings( Global / serverLog / logLevel := Level.Debug, - // custom handler Global / serverHandlers += ServerHandler({ callback => import callback._ import sjsonnew.BasicJsonProtocol._ import sbt.internal.protocol.JsonRpcRequestMessage - ServerIntent( - { - case r: JsonRpcRequestMessage if r.method == "lunar/helo" => - jsonRpcNotify("lunar/oleh", "") - () - }, - PartialFunction.empty - ) + ServerIntent.request { + case r: JsonRpcRequestMessage if r.method == "lunar/helo" => + jsonRpcNotify("lunar/oleh", "") + () + } }), - name := "handshake", ) diff --git a/server-test/src/test/scala/testpkg/EventsTest.scala b/server-test/src/test/scala/testpkg/EventsTest.scala index 5359bb2dd..8865e8eb3 100644 --- a/server-test/src/test/scala/testpkg/EventsTest.scala +++ b/server-test/src/test/scala/testpkg/EventsTest.scala @@ -8,54 +8,79 @@ package testpkg import scala.concurrent.duration._ +import java.util.concurrent.atomic.AtomicInteger +import sbt.Exec // starts svr using server-test/events and perform event related tests object EventsTest extends AbstractServerTest { override val testDirectory: String = "events" + val currentID = new AtomicInteger(0) test("report task failures in case of exceptions") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 11, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" + s"""{ "jsonrpc": "2.0", "id": $id, "method": "sbt/exec", "params": { "commandLine": "hello" } }""" ) assert(svr.waitForString(10.seconds) { s => - (s contains """"id":11""") && (s contains """"error":""") + (s contains s""""id":$id""") && (s contains """"error":""") }) } test("return error if cancelling non-matched task id") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id":$id, "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = currentID.getAndIncrement() + val invalidID = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "55" } }""" + s"""{ "jsonrpc": "2.0", "id":$cancelID, "method": "sbt/cancelRequest", "params": { "id": "$invalidID" } }""" ) assert(svr.waitForString(20.seconds) { s => (s contains """"error":{"code":-32800""") }) + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id":${currentID.getAndIncrement}, "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => + s contains """"result":{"status":"Task cancelled"""" + }) } test("cancel on-going task with numeric id") { _ => + val id = currentID.getAndIncrement() svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":12, "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id":$id, "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) - assert(svr.waitForString(1.minute) { s => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id":13, "method": "sbt/cancelRequest", "params": { "id": "12" } }""" - ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = currentID.getAndIncrement() + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id":$cancelID, "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => s contains """"result":{"status":"Task cancelled"""" }) } - /* test("cancel on-going task with string id") { _ => + val id = Exec.newExecId svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": "foo", "method": "sbt/exec", "params": { "commandLine": "run" } }""" + s"""{ "jsonrpc": "2.0", "id": "$id", "method": "sbt/exec", "params": { "commandLine": "run" } }""" ) - assert(svr.waitForString(1.minute) { s => - svr.sendJsonRpc( - """{ "jsonrpc": "2.0", "id": "bar", "method": "sbt/cancelRequest", "params": { "id": "foo" } }""" - ) + assert(svr.waitForString(10.seconds) { s => + s contains "Waiting for" + }) + val cancelID = Exec.newExecId + svr.sendJsonRpc( + s"""{ "jsonrpc": "2.0", "id": "$cancelID", "method": "sbt/cancelRequest", "params": { "id": "$id" } }""" + ) + assert(svr.waitForString(10.seconds) { s => s contains """"result":{"status":"Task cancelled"""" }) - }*/ + } } diff --git a/server-test/src/test/scala/testpkg/ResponseTest.scala b/server-test/src/test/scala/testpkg/ResponseTest.scala index 75cec4250..4b4057989 100644 --- a/server-test/src/test/scala/testpkg/ResponseTest.scala +++ b/server-test/src/test/scala/testpkg/ResponseTest.scala @@ -18,7 +18,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "10", "method": "foo/export", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"id":"10"""") && (s contains "scala-library.jar") }) @@ -29,7 +29,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "11", "method": "foo/rootClasspath", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"id":"11"""") && (s contains "scala-library.jar") }) @@ -40,7 +40,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "12", "method": "foo/fail", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"error":{"code":-33000,"message":"fail message"""") }) } @@ -50,7 +50,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "13", "method": "foo/customfail", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """"error":{"code":500,"message":"some error"""") }) } @@ -60,7 +60,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "14", "method": "foo/notification", "params": {} }""" ) assert(svr.waitForString(10.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) (s contains """{"jsonrpc":"2.0","method":"foo/something","params":"something"}""") }) } @@ -71,14 +71,14 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.waitForString(1.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"15\"" } } assert { // the second response should never be sent svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"15\"" } } @@ -90,14 +90,14 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.waitForString(1.seconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"16\"" } } assert { // the second response (result or error) should never be sent svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"id\":\"16\"" } } @@ -109,7 +109,7 @@ object ResponseTest extends AbstractServerTest { ) assert { svr.neverReceive(500.milliseconds) { s => - println(s) + if (!s.contains("systemOut")) println(s) s contains "\"result\":\"notification result\"" } } diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index 649923f47..d601e6d95 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -8,6 +8,7 @@ package testpkg import java.io.{ File, IOException } +import java.nio.file.Path import java.util.concurrent.TimeoutException import verify._ @@ -26,6 +27,7 @@ trait AbstractServerTest extends TestSuite[Unit] { private var temp: File = _ var svr: TestServer = _ def testDirectory: String + def testPath: Path = temp.toPath.resolve(testDirectory) private val targetDir: File = { val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target" @@ -224,7 +226,7 @@ case class TestServer( def bye(): Unit = { hostLog("sending exit") sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "exit" } }""" + """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }""" ) val deadline = 10.seconds.fromNow while (!deadline.isOverdue && process.isAlive) {