send response to the source channel only

This commit is contained in:
Adrien Piquerez 2020-05-12 14:33:19 +02:00
parent df293fbfd5
commit 255a0a6ea6
6 changed files with 95 additions and 181 deletions

View File

@ -11,7 +11,6 @@ package internal
import java.util.concurrent.ConcurrentLinkedQueue
import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
/**
* A command channel represents an IO device such as network socket or human
@ -42,9 +41,6 @@ abstract class CommandChannel {
}
def poll: Option[Exec] = Option(commandQueue.poll)
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit
final def publishEvent[A: JsonFormat](event: A): Unit = publishEvent(event, None)
def publishEventMessage(event: EventMessage): Unit
def publishBytes(bytes: Array[Byte]): Unit
def shutdown(): Unit
def name: String

View File

@ -14,8 +14,6 @@ import java.util.concurrent.atomic.AtomicReference
import sbt.BasicKeys._
import sbt.internal.util._
import sbt.protocol.EventMessage
import sjsonnew.JsonFormat
private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel {
private[this] val askUserThread = new AtomicReference[AskUserThread]
@ -62,21 +60,16 @@ private[sbt] final class ConsoleChannel(val name: String) extends CommandChannel
def publishBytes(bytes: Array[Byte]): Unit = ()
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = ()
def publishEventMessage(event: EventMessage): Unit =
event match {
case e: ConsolePromptEvent =>
if (Terminal.systemInIsAttached) {
askUserThread.synchronized {
askUserThread.get match {
case null => askUserThread.set(makeAskUserThread(e.state))
case t => t.redraw()
}
}
def publishEventMessage(event: ConsolePromptEvent): Unit = {
if (Terminal.systemInIsAttached) {
askUserThread.synchronized {
askUserThread.get match {
case null => askUserThread.set(makeAskUserThread(event.state))
case t => t.redraw()
}
case _ => //
}
}
}
def shutdown(): Unit = askUserThread.synchronized {
askUserThread.get match {

View File

@ -16,16 +16,13 @@ import java.util.concurrent.atomic._
import sbt.BasicKeys._
import sbt.nio.Watch.NullLogger
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.langserver.{ LogMessageParams, MessageType }
import sbt.internal.server._
import sbt.internal.util.codec.JValueFormats
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, StringEvent, Terminal }
import sbt.io.syntax._
import sbt.io.{ Hash, IO }
import sbt.protocol.{ EventMessage, ExecStatusEvent }
import sbt.util.{ Level, LogExchange, Logger }
import sjsonnew.JsonFormat
import sjsonnew.shaded.scalajson.ast.unsafe._
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
@ -51,17 +48,12 @@ private[sbt] final class CommandExchange {
private val commandChannelQueue = new LinkedBlockingQueue[CommandChannel]
private val nextChannelId: AtomicInteger = new AtomicInteger(0)
private[this] val activePrompt = new AtomicBoolean(false)
private lazy val jsonFormat = new sjsonnew.BasicJsonProtocol with JValueFormats {}
def channels: List[CommandChannel] = channelBuffer.toList
private[this] def removeChannels(toDel: List[CommandChannel]): Unit = {
toDel match {
case Nil => // do nothing
case xs =>
channelBufferLock.synchronized {
channelBuffer --= xs
()
}
private[this] def removeChannel(channel: CommandChannel): Unit = {
channelBufferLock.synchronized {
channelBuffer -= channel
()
}
}
@ -206,19 +198,7 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondError(code, message, execId)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
respondError(JsonRpcResponseError(code, message), execId, source)
}
private[sbt] def respondError(
@ -226,19 +206,13 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondError(err, execId)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
for {
source <- source.map(_.channelName)
channel <- channels.collectFirst {
// broadcast to the source channel only
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respondError(err, execId), removeChannel)(channel)
}
// This is an interface to directly respond events.
@ -247,86 +221,57 @@ private[sbt] final class CommandExchange {
execId: Option[String],
source: Option[CommandSource]
): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
case c: NetworkChannel =>
try {
// broadcast to all network channels
c.respondEvent(event, execId)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
for {
source <- source.map(_.channelName)
channel <- channels.collectFirst {
// broadcast to the source channel only
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respondResult(event, execId), removeChannel)(channel)
}
// This is an interface to directly notify events.
private[sbt] def notifyEvent[A: JsonFormat](method: String, params: A): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
channels.foreach {
case _: ConsoleChannel =>
// c.publishEvent(event)
case c: NetworkChannel =>
try {
c.notifyEvent(method, params)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.notifyEvent(method, params), removeChannel)
}
}
private def tryTo(x: => Unit, c: CommandChannel, toDel: ListBuffer[CommandChannel]): Unit =
try x
catch { case _: IOException => toDel += c }
private def tryTo(f: NetworkChannel => Unit, fallback: NetworkChannel => Unit)(
channel: NetworkChannel
): Unit =
try f(channel)
catch { case _: IOException => fallback(channel) }
def publishEvent[A: JsonFormat](event: A): Unit = {
val broadcastStringMessage = true
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
event match {
case entry: StringEvent =>
val params = toLogMessageParams(entry)
channels collect {
case c: ConsoleChannel =>
if (broadcastStringMessage || (entry.channelName forall (_ == c.name)))
c.publishEvent(event)
case c: NetworkChannel =>
tryTo(
{
// Note that language server's LogMessageParams does not hold the execid,
// so this is weaker than the StringMessage. We might want to double-send
// in case we have a better client that can utilize the knowledge.
import sbt.internal.langserver.codec.JsonProtocol._
if (broadcastStringMessage || (entry.channelName contains c.name))
c.jsonRpcNotify("window/logMessage", params)
},
c,
toDel
)
}
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEvent(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEvent(event), c, toDel)
}
case _ =>
channels foreach {
case c: ConsoleChannel => c.publishEvent(event)
case c: NetworkChannel =>
tryTo(c.publishEvent(event), c, toDel)
}
}
removeChannels(toDel.toList)
}
// Note that language server's LogMessageParams does not hold the execid,
// so this is weaker than the StringMessage. We might want to double-send
// in case we have a better client that can utilize the knowledge.
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.logMessage(entry.level, entry.message), removeChannel)
}
private[sbt] def toLogMessageParams(event: StringEvent): LogMessageParams = {
LogMessageParams(MessageType.fromLevelString(event.level), event.message)
case entry: ExecStatusEvent =>
for {
source <- entry.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.publishEvent(event), removeChannel)(channel)
case _ =>
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.publishEvent(event), removeChannel)
}
}
}
/**
@ -334,59 +279,38 @@ private[sbt] final class CommandExchange {
* erased because it went through logging.
*/
private[sbt] def publishObjectEvent(event: ObjectEvent[_]): Unit = {
import jsonFormat._
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
def json: JValue = JObject(
JField("type", JString(event.contentType)),
Vector(JField("message", event.json), JField("level", JString(event.level.toString))) ++
(event.channelName.toVector map { channelName =>
JField("channelName", JString(channelName))
}) ++
(event.execId.toVector map { execId =>
JField("execId", JString(execId))
}): _*
)
channels collect {
case c: ConsoleChannel =>
c.publishEvent(json)
case c: NetworkChannel =>
try {
c.publishObjectEvent(event)
} catch {
case _: IOException =>
toDel += c
}
}
removeChannels(toDel.toList)
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.publishObjectEvent(event), removeChannel)(channel)
}
// fanout publishEvent
def publishEventMessage(event: EventMessage): Unit = {
val toDel: ListBuffer[CommandChannel] = ListBuffer.empty
event match {
// Special treatment for ConsolePromptEvent since it's hand coded without codec.
case entry: ConsolePromptEvent =>
channels collect {
case c: ConsoleChannel =>
c.publishEventMessage(entry)
activePrompt.set(Terminal.systemInIsAttached)
}
activePrompt.set(Terminal.systemInIsAttached)
channels
.collect { case c: ConsoleChannel => c }
.foreach { _.publishEventMessage(entry) }
case entry: ExecStatusEvent =>
channels collect {
case c: ConsoleChannel =>
if (entry.channelName forall (_ == c.name)) c.publishEventMessage(event)
case c: NetworkChannel =>
if (entry.channelName contains c.name) tryTo(c.publishEventMessage(event), c, toDel)
}
case _ =>
channels collect {
case c: ConsoleChannel => c.publishEventMessage(event)
case c: NetworkChannel => tryTo(c.publishEventMessage(event), c, toDel)
}
}
for {
source <- entry.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.publishEventMessage(event), removeChannel)(channel)
removeChannels(toDel.toList)
case _ =>
channels
.collect { case c: NetworkChannel => c }
.foreach {
tryTo(_.publishEventMessage(event), removeChannel)
}
}
}
private[this] def needToFinishPromptLine(): Boolean = activePrompt.compareAndSet(true, false)
}

View File

@ -40,7 +40,7 @@ private[sbt] object Definition {
def send[A: JsonFormat](source: CommandSource, execId: String)(params: A): Unit = {
for {
channel <- StandardMain.exchange.channels.collectFirst {
case c if c.name == source.channelName => c
case c: NetworkChannel if c.name == source.channelName => c
}
} {
channel.publishEvent(params, Option(execId))

View File

@ -116,7 +116,7 @@ private[sbt] trait LanguageServerProtocol { self: NetworkChannel =>
protected lazy val callbackImpl: ServerCallback = new ServerCallback {
def jsonRpcRespond[A: JsonFormat](event: A, execId: Option[String]): Unit =
self.respondEvent(event, execId)
self.respondResult(event, execId)
def jsonRpcRespondError(execId: Option[String], code: Long, message: String): Unit =
self.respondError(code, message, execId)

View File

@ -267,7 +267,6 @@ final class NetworkChannel(
err: JsonRpcResponseError,
execId: Option[String]
): Unit = this.synchronized {
println(s"respond error for $execId")
execId match {
case Some(id) if onGoingRequests.contains(id) =>
onGoingRequests -= id
@ -285,11 +284,10 @@ final class NetworkChannel(
respondError(JsonRpcResponseError(code, message), execId)
}
private[sbt] def respondEvent[A: JsonFormat](
private[sbt] def respondResult[A: JsonFormat](
event: A,
execId: Option[String]
): Unit = this.synchronized {
println(s"respond result for $execId")
execId match {
case Some(id) if onGoingRequests.contains(id) =>
onGoingRequests -= id
@ -307,17 +305,20 @@ final class NetworkChannel(
}
}
def publishEvent[A: JsonFormat](event: A): Unit =
publishEvent(event, None)
def publishEvent[A: JsonFormat](event: A, execId: Option[String]): Unit = {
if (isLanguageServerProtocol) {
event match {
case entry: StringEvent => logMessage(entry.level, entry.message)
case entry: ExecStatusEvent =>
entry.exitCode match {
case None => respondEvent(event, entry.execId)
case Some(0) => respondEvent(event, entry.execId)
case None => respondResult(event, entry.execId)
case Some(0) => respondResult(event, entry.execId)
case Some(exitCode) => respondError(exitCode, entry.message.getOrElse(""), entry.execId)
}
case _ => respondEvent(event, execId)
case _ => respondResult(event, execId)
}
} else {
contentType match {
@ -417,7 +418,7 @@ final class NetworkChannel(
if (initialized) {
import sbt.protocol.codec.JsonProtocol._
SettingQuery.handleSettingQueryEither(req, structure) match {
case Right(x) => respondEvent(x, execId)
case Right(x) => respondResult(x, execId)
case Left(s) => respondError(ErrorCodes.InvalidParams, s, execId)
}
} else {
@ -440,7 +441,7 @@ final class NetworkChannel(
}
.map(c => cp.query + c)
import sbt.protocol.codec.JsonProtocol._
respondEvent(
respondResult(
CompletionResponse(
items = completionItems.toVector
),
@ -498,7 +499,7 @@ final class NetworkChannel(
runningEngine.cancelAndShutdown()
import sbt.protocol.codec.JsonProtocol._
respondEvent(
respondResult(
ExecStatusEvent(
"Task cancelled",
Some(name),