diff --git a/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala b/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala index 660406559..ab082c600 100644 --- a/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala +++ b/internal/util-complete/src/main/scala/sbt/internal/util/LineReader.scala @@ -33,6 +33,11 @@ abstract class JLine extends LineReader { Option("") } + override def redraw(): Unit = { + reader.drawLine() + reader.flush() + } + private[this] def unsynchronizedReadLine(prompt: String, mask: Option[Char]): Option[String] = readLineWithHistory(prompt, mask) map { x => x.trim @@ -169,6 +174,7 @@ private[sbt] class InputStreamWrapper(is: InputStream, val poll: Duration) trait LineReader { def readLine(prompt: String, mask: Option[Char] = None): Option[String] + def redraw(): Unit = () } final class FullReader( diff --git a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala index d45dce4cb..11bee3a2a 100644 --- a/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala +++ b/main-command/src/main/scala/sbt/internal/ConsoleChannel.scala @@ -10,27 +10,29 @@ package internal import java.io.File import java.nio.channels.ClosedChannelException +import java.util.concurrent.atomic.AtomicReference import sbt.BasicKeys._ -import sbt.internal.util.Util.AnyOps import sbt.internal.util._ import sbt.protocol.EventMessage import sjsonnew.JsonFormat private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel { - private var askUserThread: Option[Thread] = None - def makeAskUserThread(s: State): Thread = new Thread("ask-user-thread") { - val history = (s get historyPath) getOrElse (new File(s.baseDir, ".history")).some - val prompt = (s get shellPrompt) match { - case Some(pf) => pf(s) - case None => - def ansi(s: String): String = if (ConsoleAppender.formatEnabledInEnv) s"$s" else "" - s"${ansi(ConsoleAppender.DeleteLine)}> ${ansi(ConsoleAppender.clearScreen(0))}" - } - val reader = + private[this] val askUserThread = new AtomicReference[AskUserThread] + private[this] def getPrompt(s: State): String = s.get(shellPrompt) match { + case Some(pf) => pf(s) + case None => + def ansi(s: String): String = if (ConsoleAppender.formatEnabledInEnv) s"$s" else "" + s"${ansi(ConsoleAppender.DeleteLine)}> ${ansi(ConsoleAppender.clearScreen(0))}" + } + private[this] class AskUserThread(s: State) extends Thread("ask-user-thread") { + private val history = s.get(historyPath).getOrElse(Some(new File(s.baseDir, ".history"))) + private val prompt = getPrompt(s) + private val reader = new FullReader(history, s.combinedParser, JLine.HandleCONT, Terminal.throwOnClosedSystemIn) - override def run(): Unit = { - // This internally handles thread interruption and returns Some("") + setDaemon(true) + start() + override def run(): Unit = try { reader.readLine(prompt) match { case Some(cmd) => append(Exec(cmd, Some(Exec.newExecId), Some(CommandSource(name)))) @@ -38,12 +40,18 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel println("") // Prevents server shutdown log lines from appearing on the prompt line append(Exec("exit", Some(Exec.newExecId), Some(CommandSource(name)))) } + () } catch { case _: ClosedChannelException => - } - askUserThread = None + } finally askUserThread.synchronized(askUserThread.set(null)) + def redraw(): Unit = { + System.out.print(ConsoleAppender.clearLine(0)) + reader.redraw() + System.out.print(ConsoleAppender.clearScreen(0)) + System.out.flush() } } + private[this] def makeAskUserThread(s: State): AskUserThread = new AskUserThread(s) def run(s: State): State = s @@ -54,21 +62,24 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel def publishEventMessage(event: EventMessage): Unit = event match { case e: ConsolePromptEvent => - askUserThread match { - case Some(_) => - case _ => - val x = makeAskUserThread(e.state) - askUserThread = Some(x) - x.start() + if (Terminal.systemInIsAttached) { + askUserThread.synchronized { + askUserThread.get match { + case null => askUserThread.set(makeAskUserThread(e.state)) + case t => t.redraw() + } + } } case _ => // } - def shutdown(): Unit = - askUserThread match { - case Some(x) if x.isAlive => - x.interrupt() - askUserThread = None + def shutdown(): Unit = askUserThread.synchronized { + askUserThread.get match { + case null => + case t if t.isAlive => + t.interrupt() + askUserThread.set(null) case _ => () } + } } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 45421ac0f..85aa09f6a 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -19,7 +19,7 @@ import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.langserver.{ LogMessageParams, MessageType } import sbt.internal.server._ import sbt.internal.util.codec.JValueFormats -import sbt.internal.util.{ MainAppender, ObjectEvent, StringEvent } +import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } import sbt.protocol.{ EventMessage, ExecStatusEvent } @@ -50,6 +50,7 @@ private[sbt] final class CommandExchange { private val channelBufferLock = new AnyRef {} private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel] private val nextChannelId: AtomicInteger = new AtomicInteger(0) + private[this] val activePrompt = new AtomicBoolean(false) private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {} def channels: List[CommandChannel] = channelBuffer.toList @@ -83,7 +84,11 @@ private[sbt] final class CommandExchange { commandChannelQueue.poll(1, TimeUnit.SECONDS) slurpMessages() Option(commandQueue.poll) match { - case Some(x) => x + case Some(exec) => + val needFinish = needToFinishPromptLine() + if (exec.source.fold(needFinish)(s => needFinish && s.channelName != "console0")) + ConsoleOut.systemOut.println("") + exec case None => val newDeadline = if (deadline.fold(false)(_.isOverdue())) { GCUtil.forceGcWithInterval(interval, logger) @@ -129,6 +134,7 @@ private[sbt] final class CommandExchange { def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { val name = newNetworkName + if (needToFinishPromptLine()) ConsoleOut.systemOut.println("") s.log.info(s"new client connected: $name") val logger: Logger = { val log = LogExchange.logger(name, None, None) @@ -362,7 +368,9 @@ private[sbt] final class CommandExchange { // Special treatment for ConsolePromptEvent since it's hand coded without codec. case entry: ConsolePromptEvent => channels collect { - case c: ConsoleChannel => c.publishEventMessage(entry) + case c: ConsoleChannel => + c.publishEventMessage(entry) + activePrompt.set(Terminal.systemInIsAttached) } case entry: ExecStatusEvent => channels collect { @@ -380,4 +388,5 @@ private[sbt] final class CommandExchange { removeChannels(toDel.toList) } + private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false) }