mirror of https://github.com/sbt/sbt.git
Consolidate and optimize input stream json reading
We had similar code for reading json frames from an input stream in NetworkChannel and ServerConnection. I reworked and consolidated this logic into a shared method in ReadJsonFromInputStream. This commit also removes the ObjectMessage reporting methods that weren't doing anything.
This commit is contained in:
parent
b0a859acb5
commit
fcfe4333fe
|
|
@ -9,10 +9,13 @@ package sbt
|
|||
package internal
|
||||
package client
|
||||
|
||||
import java.net.{ SocketTimeoutException, Socket }
|
||||
import java.io.IOException
|
||||
import java.net.{ Socket, SocketTimeoutException }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
import sbt.protocol._
|
||||
import sbt.internal.protocol._
|
||||
import sbt.internal.util.ReadJsonFromInputStream
|
||||
|
||||
abstract class ServerConnection(connection: Socket) {
|
||||
|
||||
|
|
@ -25,69 +28,29 @@ abstract class ServerConnection(connection: Socket) {
|
|||
val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") {
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
val readBuffer = new Array[Byte](4096)
|
||||
val in = connection.getInputStream
|
||||
connection.setSoTimeout(5000)
|
||||
var buffer: Vector[Byte] = Vector.empty
|
||||
def readFrame: Vector[Byte] = {
|
||||
def getContentLength: Int = {
|
||||
readLine.drop(16).toInt
|
||||
}
|
||||
val l = getContentLength
|
||||
readLine
|
||||
readLine
|
||||
readContentLength(l)
|
||||
}
|
||||
|
||||
def readLine: String = {
|
||||
if (buffer.isEmpty) {
|
||||
val bytesRead = in.read(readBuffer)
|
||||
if (bytesRead > 0) {
|
||||
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
|
||||
}
|
||||
}
|
||||
val delimPos = buffer.indexOf(delimiter)
|
||||
if (delimPos > 0) {
|
||||
val chunk0 = buffer.take(delimPos)
|
||||
buffer = buffer.drop(delimPos + 1)
|
||||
// remove \r at the end of line.
|
||||
val chunk1 = if (chunk0.lastOption contains retByte) chunk0.dropRight(1) else chunk0
|
||||
new String(chunk1.toArray, "utf-8")
|
||||
} else readLine
|
||||
}
|
||||
|
||||
def readContentLength(length: Int): Vector[Byte] = {
|
||||
if (buffer.size < length) {
|
||||
val bytesRead = in.read(readBuffer)
|
||||
if (bytesRead > 0) {
|
||||
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
|
||||
} else ()
|
||||
} else ()
|
||||
if (length <= buffer.size) {
|
||||
val chunk = buffer.take(length)
|
||||
buffer = buffer.drop(length)
|
||||
chunk
|
||||
} else readContentLength(length)
|
||||
}
|
||||
|
||||
while (running.get) {
|
||||
try {
|
||||
val frame = readFrame
|
||||
Serialization
|
||||
.deserializeJsonMessage(frame)
|
||||
.fold(
|
||||
{ errorDesc =>
|
||||
val s = frame.mkString("") // new String(: Array[Byte], "UTF-8")
|
||||
println(s"Got invalid chunk from server: $s \n" + errorDesc)
|
||||
},
|
||||
_ match {
|
||||
case msg: JsonRpcRequestMessage => onRequest(msg)
|
||||
case msg: JsonRpcResponseMessage => onResponse(msg)
|
||||
case msg: JsonRpcNotificationMessage => onNotification(msg)
|
||||
}
|
||||
)
|
||||
val frame = ReadJsonFromInputStream(in, running, None)
|
||||
if (running.get) {
|
||||
Serialization
|
||||
.deserializeJsonMessage(frame)
|
||||
.fold(
|
||||
{ errorDesc =>
|
||||
val s = frame.mkString("") // new String(: Array[Byte], "UTF-8")
|
||||
println(s"Got invalid chunk from server: $s \n" + errorDesc)
|
||||
},
|
||||
_ match {
|
||||
case msg: JsonRpcRequestMessage => onRequest(msg)
|
||||
case msg: JsonRpcResponseMessage => onResponse(msg)
|
||||
case msg: JsonRpcNotificationMessage => onNotification(msg)
|
||||
}
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
case _: SocketTimeoutException => // its ok
|
||||
case e: IOException => running.set(false)
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* sbt
|
||||
* Copyright 2011 - 2018, Lightbend, Inc.
|
||||
* Copyright 2008 - 2010, Mark Harrah
|
||||
* Licensed under Apache License 2.0 (see LICENSE)
|
||||
*/
|
||||
|
||||
package sbt.internal.util
|
||||
|
||||
import java.io.InputStream
|
||||
import java.nio.channels.ClosedChannelException
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import scala.collection.mutable
|
||||
import scala.util.Try
|
||||
|
||||
private[sbt] object ReadJsonFromInputStream {
|
||||
def apply(
|
||||
inputStream: InputStream,
|
||||
running: AtomicBoolean,
|
||||
onHeader: Option[String => Unit]
|
||||
): Seq[Byte] = {
|
||||
val newline = '\n'.toInt
|
||||
val carriageReturn = '\r'.toInt
|
||||
val contentLength = "Content-Length: "
|
||||
var bytes = new mutable.ArrayBuffer[Byte]
|
||||
def getLine(): String = {
|
||||
val line = new String(bytes.toArray, "UTF-8")
|
||||
bytes = new mutable.ArrayBuffer[Byte]
|
||||
onHeader.foreach(oh => oh(line))
|
||||
line
|
||||
}
|
||||
var content: Seq[Byte] = Seq.empty[Byte]
|
||||
var consecutiveLineEndings = 0
|
||||
var onCarriageReturn = false
|
||||
do {
|
||||
val byte = inputStream.read
|
||||
byte match {
|
||||
case `newline` =>
|
||||
val line = getLine()
|
||||
if (onCarriageReturn) consecutiveLineEndings += 1
|
||||
onCarriageReturn = false
|
||||
if (line.startsWith(contentLength)) {
|
||||
Try(line.drop(contentLength.length).toInt) foreach { len =>
|
||||
def drainHeaders(): Unit =
|
||||
do {
|
||||
inputStream.read match {
|
||||
case `newline` if onCarriageReturn =>
|
||||
getLine()
|
||||
onCarriageReturn = false
|
||||
consecutiveLineEndings += 1
|
||||
case `carriageReturn` => onCarriageReturn = true
|
||||
case c =>
|
||||
if (c == newline) getLine()
|
||||
else bytes += c.toByte
|
||||
onCarriageReturn = false
|
||||
consecutiveLineEndings = 0
|
||||
}
|
||||
} while (consecutiveLineEndings < 2)
|
||||
drainHeaders()
|
||||
val buf = new Array[Byte](len)
|
||||
var offset = 0
|
||||
do {
|
||||
offset += inputStream.read(buf, offset, len - offset)
|
||||
} while (offset < len)
|
||||
content = buf.toSeq
|
||||
}
|
||||
} else if (line.startsWith("{")) {
|
||||
// Assume this is a json object with no headers
|
||||
content = line.getBytes.toSeq
|
||||
}
|
||||
case i if i < 0 =>
|
||||
running.set(false)
|
||||
throw new ClosedChannelException
|
||||
case `carriageReturn` => onCarriageReturn = true
|
||||
case c =>
|
||||
onCarriageReturn = false
|
||||
bytes += c.toByte
|
||||
|
||||
}
|
||||
} while (content.isEmpty && running.get)
|
||||
content
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -17,7 +17,7 @@ import sbt.BasicKeys._
|
|||
import sbt.nio.Watch.NullLogger
|
||||
import sbt.internal.protocol.JsonRpcResponseError
|
||||
import sbt.internal.server._
|
||||
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal }
|
||||
import sbt.internal.util.{ ConsoleOut, MainAppender, Terminal }
|
||||
import sbt.io.syntax._
|
||||
import sbt.io.{ Hash, IO }
|
||||
import sbt.protocol.{ ExecStatusEvent, LogEvent }
|
||||
|
|
@ -271,19 +271,6 @@ private[sbt] final class CommandExchange {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This publishes object events. The type information has been
|
||||
* erased because it went through logging.
|
||||
*/
|
||||
private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = {
|
||||
for {
|
||||
source <- event.channelName
|
||||
channel <- channels.collectFirst {
|
||||
case c: NetworkChannel if c.name == source => c
|
||||
}
|
||||
} tryTo(_.respond(event))(channel)
|
||||
}
|
||||
|
||||
def prompt(event: ConsolePromptEvent): Unit = {
|
||||
activePrompt.set(Terminal.systemInIsAttached)
|
||||
channels
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class RelayAppender(name: String)
|
|||
def appendEvent(event: AnyRef): Unit =
|
||||
event match {
|
||||
case x: StringEvent => exchange.logMessage(LogEvent(level = x.level, message = x.message))
|
||||
case x: ObjectEvent[_] => exchange.respondObjectEvent(x)
|
||||
case x: ObjectEvent[_] => // ignore object events
|
||||
case _ =>
|
||||
println(s"appendEvent: ${event.getClass}")
|
||||
()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ package sbt
|
|||
package internal
|
||||
package server
|
||||
|
||||
import java.io.IOException
|
||||
import java.net.{ Socket, SocketTimeoutException }
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
|
|
@ -19,14 +20,13 @@ import sbt.internal.protocol.{
|
|||
JsonRpcResponseError,
|
||||
JsonRpcResponseMessage
|
||||
}
|
||||
import sbt.internal.util.ObjectEvent
|
||||
import sbt.internal.util.ReadJsonFromInputStream
|
||||
import sbt.internal.util.complete.Parser
|
||||
import sbt.protocol._
|
||||
import sbt.util.Logger
|
||||
import sjsonnew._
|
||||
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
|
||||
|
||||
import scala.annotation.tailrec
|
||||
import scala.collection.mutable
|
||||
import scala.util.Try
|
||||
import scala.util.control.NonFatal
|
||||
|
|
@ -40,17 +40,11 @@ final class NetworkChannel(
|
|||
handlers: Seq[ServerHandler],
|
||||
val log: Logger
|
||||
) extends CommandChannel { self =>
|
||||
import NetworkChannel._
|
||||
|
||||
private val running = new AtomicBoolean(true)
|
||||
private val delimiter: Byte = '\n'.toByte
|
||||
private val RetByte = '\r'.toByte
|
||||
private val out = connection.getOutputStream
|
||||
private var initialized = false
|
||||
private val Curly = '{'.toByte
|
||||
private val ContentLength = """^Content\-Length\:\s*(\d+)""".r
|
||||
private val ContentType = """^Content\-Type\:\s*(.+)""".r
|
||||
private var _contentType: String = ""
|
||||
private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map()
|
||||
|
||||
private lazy val callback: ServerCallback = new ServerCallback {
|
||||
|
|
@ -81,9 +75,6 @@ final class NetworkChannel(
|
|||
self.onCancellationRequest(execId, crp)
|
||||
}
|
||||
|
||||
def setContentType(ct: String): Unit = synchronized { _contentType = ct }
|
||||
def contentType: String = _contentType
|
||||
|
||||
protected def authenticate(token: String): Boolean = instance.authenticate(token)
|
||||
|
||||
protected def setInitialized(value: Boolean): Unit = initialized = value
|
||||
|
|
@ -91,105 +82,26 @@ final class NetworkChannel(
|
|||
protected def authOptions: Set[ServerAuthentication] = auth
|
||||
|
||||
val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") {
|
||||
var contentLength: Int = 0
|
||||
var state: ChannelState = SingleLine
|
||||
|
||||
private val ct = "Content-Type: "
|
||||
private val x1 = "application/sbt-x1"
|
||||
override def run(): Unit = {
|
||||
try {
|
||||
val readBuffer = new Array[Byte](4096)
|
||||
val in = connection.getInputStream
|
||||
connection.setSoTimeout(5000)
|
||||
var buffer: Vector[Byte] = Vector.empty
|
||||
var bytesRead = 0
|
||||
def resetChannelState(): Unit = {
|
||||
contentLength = 0
|
||||
state = SingleLine
|
||||
}
|
||||
def tillEndOfLine: Option[Vector[Byte]] = {
|
||||
val delimPos = buffer.indexOf(delimiter)
|
||||
if (delimPos > 0) {
|
||||
val chunk0 = buffer.take(delimPos)
|
||||
buffer = buffer.drop(delimPos + 1)
|
||||
// remove \r at the end of line.
|
||||
if (chunk0.size > 0 && chunk0.indexOf(RetByte) == chunk0.size - 1)
|
||||
Some(chunk0.dropRight(1))
|
||||
else Some(chunk0)
|
||||
} else None // no EOL yet, so skip this turn.
|
||||
}
|
||||
|
||||
def tillContentLength: Option[Vector[Byte]] = {
|
||||
if (contentLength <= buffer.size) {
|
||||
val chunk = buffer.take(contentLength)
|
||||
buffer = buffer.drop(contentLength)
|
||||
resetChannelState()
|
||||
Some(chunk)
|
||||
} else None // have not read enough yet, so skip this turn.
|
||||
}
|
||||
|
||||
@tailrec def process(): Unit = {
|
||||
// handle un-framing
|
||||
state match {
|
||||
case SingleLine =>
|
||||
val line = tillEndOfLine
|
||||
line match {
|
||||
case Some(chunk) =>
|
||||
chunk.headOption match {
|
||||
case None => // ignore blank line
|
||||
case Some(Curly) =>
|
||||
// When Content-Length header is not found, interpret the line as JSON message.
|
||||
handleBody(chunk)
|
||||
process()
|
||||
case Some(_) =>
|
||||
val str = (new String(chunk.toArray, "UTF-8")).trim
|
||||
handleHeader(str) match {
|
||||
case Some(_) =>
|
||||
state = InHeader
|
||||
process()
|
||||
case _ =>
|
||||
val msg = s"got invalid chunk from client: $str"
|
||||
log.error(msg)
|
||||
logMessage("error", msg)
|
||||
}
|
||||
}
|
||||
case _ => ()
|
||||
}
|
||||
case InHeader =>
|
||||
tillEndOfLine match {
|
||||
case Some(chunk) =>
|
||||
val str = (new String(chunk.toArray, "UTF-8")).trim
|
||||
if (str == "") {
|
||||
state = InBody
|
||||
process()
|
||||
} else
|
||||
handleHeader(str) match {
|
||||
case Some(_) => process()
|
||||
case _ =>
|
||||
log.error("Got invalid header from client: " + str)
|
||||
resetChannelState()
|
||||
}
|
||||
case _ => ()
|
||||
}
|
||||
case InBody =>
|
||||
tillContentLength match {
|
||||
case Some(chunk) =>
|
||||
handleBody(chunk)
|
||||
process()
|
||||
case _ => ()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val in = connection.getInputStream
|
||||
// keep going unless the socket has closed
|
||||
while (bytesRead != -1 && running.get) {
|
||||
while (running.get) {
|
||||
try {
|
||||
bytesRead = in.read(readBuffer)
|
||||
// log.debug(s"bytesRead: $bytesRead")
|
||||
if (bytesRead > 0) {
|
||||
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
|
||||
val onHeader: String => Unit = line => {
|
||||
if (line.startsWith(ct) && line.contains(x1)) {
|
||||
logMessage("error", s"server protocol $x1 is no longer supported")
|
||||
}
|
||||
}
|
||||
process()
|
||||
val content = ReadJsonFromInputStream(in, running, Some(onHeader))
|
||||
if (content.nonEmpty) handleBody(content)
|
||||
} catch {
|
||||
case _: SocketTimeoutException => // its ok
|
||||
case _: SocketTimeoutException => // its ok
|
||||
case _: IOException | _: InterruptedException => running.set(false)
|
||||
}
|
||||
} // while
|
||||
} finally {
|
||||
|
|
@ -213,7 +125,7 @@ final class NetworkChannel(
|
|||
case (f, i) => f orElse i.onNotification
|
||||
}
|
||||
|
||||
def handleBody(chunk: Vector[Byte]): Unit = {
|
||||
def handleBody(chunk: Seq[Byte]): Unit = {
|
||||
Serialization.deserializeJsonMessage(chunk) match {
|
||||
case Right(req: JsonRpcRequestMessage) =>
|
||||
try {
|
||||
|
|
@ -240,22 +152,6 @@ final class NetworkChannel(
|
|||
logMessage("error", msg)
|
||||
}
|
||||
}
|
||||
|
||||
def handleHeader(str: String): Option[Unit] = {
|
||||
val sbtX1Protocol = "application/sbt-x1"
|
||||
str match {
|
||||
case ContentLength(len) =>
|
||||
contentLength = len.toInt
|
||||
Some(())
|
||||
case ContentType(ct) =>
|
||||
if (ct == sbtX1Protocol) {
|
||||
logMessage("error", s"server protocol $ct is no longer supported")
|
||||
}
|
||||
setContentType(ct)
|
||||
Some(())
|
||||
case _ => None
|
||||
}
|
||||
}
|
||||
}
|
||||
thread.start()
|
||||
|
||||
|
|
@ -322,14 +218,6 @@ final class NetworkChannel(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This publishes object events. The type information has been
|
||||
* erased because it went through logging.
|
||||
*/
|
||||
private[sbt] def respond(event: ObjectEvent[_]): Unit = {
|
||||
onObjectEvent(event)
|
||||
}
|
||||
|
||||
def publishBytes(event: Array[Byte]): Unit = publishBytes(event, false)
|
||||
|
||||
def publishBytes(event: Array[Byte], delimit: Boolean): Unit = {
|
||||
|
|
@ -491,28 +379,6 @@ final class NetworkChannel(
|
|||
out.close()
|
||||
}
|
||||
|
||||
/**
|
||||
* This reacts to various events that happens inside sbt, sometime
|
||||
* in response to the previous requests.
|
||||
* The type information has been erased because it went through logging.
|
||||
*/
|
||||
protected def onObjectEvent(event: ObjectEvent[_]): Unit = {
|
||||
// import sbt.internal.langserver.codec.JsonProtocol._
|
||||
|
||||
val msgContentType = event.contentType
|
||||
msgContentType match {
|
||||
// LanguageServerReporter sends PublishDiagnosticsParams
|
||||
case "sbt.internal.langserver.PublishDiagnosticsParams" =>
|
||||
// val p = event.message.asInstanceOf[PublishDiagnosticsParams]
|
||||
// jsonRpcNotify("textDocument/publishDiagnostics", p)
|
||||
case "xsbti.Problem" =>
|
||||
() // ignore
|
||||
case _ =>
|
||||
// log.debug(event)
|
||||
()
|
||||
}
|
||||
}
|
||||
|
||||
/** Respond back to Language Server's client. */
|
||||
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
|
||||
val m =
|
||||
|
|
|
|||
Loading…
Reference in New Issue