Consolidate and optimize input stream json reading

We had similar code for reading json frames from an input stream in
NetworkChannel and ServerConnection. I reworked and consolidated this
logic into a shared method in ReadJsonFromInputStream.

This commit also removes the ObjectMessage reporting methods that
weren't doing anything.
This commit is contained in:
Ethan Atkins 2020-06-23 17:09:36 -07:00
parent b0a859acb5
commit fcfe4333fe
5 changed files with 122 additions and 222 deletions

View File

@ -9,10 +9,13 @@ package sbt
package internal
package client
import java.net.{ SocketTimeoutException, Socket }
import java.io.IOException
import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean
import sbt.protocol._
import sbt.internal.protocol._
import sbt.internal.util.ReadJsonFromInputStream
abstract class ServerConnection(connection: Socket) {
@ -25,69 +28,29 @@ abstract class ServerConnection(connection: Socket) {
val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") {
override def run(): Unit = {
try {
val readBuffer = new Array[Byte](4096)
val in = connection.getInputStream
connection.setSoTimeout(5000)
var buffer: Vector[Byte] = Vector.empty
def readFrame: Vector[Byte] = {
def getContentLength: Int = {
readLine.drop(16).toInt
}
val l = getContentLength
readLine
readLine
readContentLength(l)
}
def readLine: String = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
}
val delimPos = buffer.indexOf(delimiter)
if (delimPos > 0) {
val chunk0 = buffer.take(delimPos)
buffer = buffer.drop(delimPos + 1)
// remove \r at the end of line.
val chunk1 = if (chunk0.lastOption contains retByte) chunk0.dropRight(1) else chunk0
new String(chunk1.toArray, "utf-8")
} else readLine
}
def readContentLength(length: Int): Vector[Byte] = {
if (buffer.size < length) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
} else ()
} else ()
if (length <= buffer.size) {
val chunk = buffer.take(length)
buffer = buffer.drop(length)
chunk
} else readContentLength(length)
}
while (running.get) {
try {
val frame = readFrame
Serialization
.deserializeJsonMessage(frame)
.fold(
{ errorDesc =>
val s = frame.mkString("") // new String(: Array[Byte], "UTF-8")
println(s"Got invalid chunk from server: $s \n" + errorDesc)
},
_ match {
case msg: JsonRpcRequestMessage => onRequest(msg)
case msg: JsonRpcResponseMessage => onResponse(msg)
case msg: JsonRpcNotificationMessage => onNotification(msg)
}
)
val frame = ReadJsonFromInputStream(in, running, None)
if (running.get) {
Serialization
.deserializeJsonMessage(frame)
.fold(
{ errorDesc =>
val s = frame.mkString("") // new String(: Array[Byte], "UTF-8")
println(s"Got invalid chunk from server: $s \n" + errorDesc)
},
_ match {
case msg: JsonRpcRequestMessage => onRequest(msg)
case msg: JsonRpcResponseMessage => onResponse(msg)
case msg: JsonRpcNotificationMessage => onNotification(msg)
}
)
}
} catch {
case _: SocketTimeoutException => // its ok
case e: IOException => running.set(false)
}
}
} finally {

View File

@ -0,0 +1,84 @@
/*
* sbt
* Copyright 2011 - 2018, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package sbt.internal.util
import java.io.InputStream
import java.nio.channels.ClosedChannelException
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable
import scala.util.Try
private[sbt] object ReadJsonFromInputStream {
def apply(
inputStream: InputStream,
running: AtomicBoolean,
onHeader: Option[String => Unit]
): Seq[Byte] = {
val newline = '\n'.toInt
val carriageReturn = '\r'.toInt
val contentLength = "Content-Length: "
var bytes = new mutable.ArrayBuffer[Byte]
def getLine(): String = {
val line = new String(bytes.toArray, "UTF-8")
bytes = new mutable.ArrayBuffer[Byte]
onHeader.foreach(oh => oh(line))
line
}
var content: Seq[Byte] = Seq.empty[Byte]
var consecutiveLineEndings = 0
var onCarriageReturn = false
do {
val byte = inputStream.read
byte match {
case `newline` =>
val line = getLine()
if (onCarriageReturn) consecutiveLineEndings += 1
onCarriageReturn = false
if (line.startsWith(contentLength)) {
Try(line.drop(contentLength.length).toInt) foreach { len =>
def drainHeaders(): Unit =
do {
inputStream.read match {
case `newline` if onCarriageReturn =>
getLine()
onCarriageReturn = false
consecutiveLineEndings += 1
case `carriageReturn` => onCarriageReturn = true
case c =>
if (c == newline) getLine()
else bytes += c.toByte
onCarriageReturn = false
consecutiveLineEndings = 0
}
} while (consecutiveLineEndings < 2)
drainHeaders()
val buf = new Array[Byte](len)
var offset = 0
do {
offset += inputStream.read(buf, offset, len - offset)
} while (offset < len)
content = buf.toSeq
}
} else if (line.startsWith("{")) {
// Assume this is a json object with no headers
content = line.getBytes.toSeq
}
case i if i < 0 =>
running.set(false)
throw new ClosedChannelException
case `carriageReturn` => onCarriageReturn = true
case c =>
onCarriageReturn = false
bytes += c.toByte
}
} while (content.isEmpty && running.get)
content
}
}

View File

@ -17,7 +17,7 @@ import sbt.BasicKeys._
import sbt.nio.Watch.NullLogger
import sbt.internal.protocol.JsonRpcResponseError
import sbt.internal.server._
import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal }
import sbt.internal.util.{ ConsoleOut, MainAppender, Terminal }
import sbt.io.syntax._
import sbt.io.{ Hash, IO }
import sbt.protocol.{ ExecStatusEvent, LogEvent }
@ -271,19 +271,6 @@ private[sbt] final class CommandExchange {
}
}
/**
* This publishes object events. The type information has been
* erased because it went through logging.
*/
private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = {
for {
source <- event.channelName
channel <- channels.collectFirst {
case c: NetworkChannel if c.name == source => c
}
} tryTo(_.respond(event))(channel)
}
def prompt(event: ConsolePromptEvent): Unit = {
activePrompt.set(Terminal.systemInIsAttached)
channels

View File

@ -44,7 +44,7 @@ class RelayAppender(name: String)
def appendEvent(event: AnyRef): Unit =
event match {
case x: StringEvent => exchange.logMessage(LogEvent(level = x.level, message = x.message))
case x: ObjectEvent[_] => exchange.respondObjectEvent(x)
case x: ObjectEvent[_] => // ignore object events
case _ =>
println(s"appendEvent: ${event.getClass}")
()

View File

@ -9,6 +9,7 @@ package sbt
package internal
package server
import java.io.IOException
import java.net.{ Socket, SocketTimeoutException }
import java.util.concurrent.atomic.AtomicBoolean
@ -19,14 +20,13 @@ import sbt.internal.protocol.{
JsonRpcResponseError,
JsonRpcResponseMessage
}
import sbt.internal.util.ObjectEvent
import sbt.internal.util.ReadJsonFromInputStream
import sbt.internal.util.complete.Parser
import sbt.protocol._
import sbt.util.Logger
import sjsonnew._
import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter }
import scala.annotation.tailrec
import scala.collection.mutable
import scala.util.Try
import scala.util.control.NonFatal
@ -40,17 +40,11 @@ final class NetworkChannel(
handlers: Seq[ServerHandler],
val log: Logger
) extends CommandChannel { self =>
import NetworkChannel._
private val running = new AtomicBoolean(true)
private val delimiter: Byte = '\n'.toByte
private val RetByte = '\r'.toByte
private val out = connection.getOutputStream
private var initialized = false
private val Curly = '{'.toByte
private val ContentLength = """^Content\-Length\:\s*(\d+)""".r
private val ContentType = """^Content\-Type\:\s*(.+)""".r
private var _contentType: String = ""
private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map()
private lazy val callback: ServerCallback = new ServerCallback {
@ -81,9 +75,6 @@ final class NetworkChannel(
self.onCancellationRequest(execId, crp)
}
def setContentType(ct: String): Unit = synchronized { _contentType = ct }
def contentType: String = _contentType
protected def authenticate(token: String): Boolean = instance.authenticate(token)
protected def setInitialized(value: Boolean): Unit = initialized = value
@ -91,105 +82,26 @@ final class NetworkChannel(
protected def authOptions: Set[ServerAuthentication] = auth
val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") {
var contentLength: Int = 0
var state: ChannelState = SingleLine
private val ct = "Content-Type: "
private val x1 = "application/sbt-x1"
override def run(): Unit = {
try {
val readBuffer = new Array[Byte](4096)
val in = connection.getInputStream
connection.setSoTimeout(5000)
var buffer: Vector[Byte] = Vector.empty
var bytesRead = 0
def resetChannelState(): Unit = {
contentLength = 0
state = SingleLine
}
def tillEndOfLine: Option[Vector[Byte]] = {
val delimPos = buffer.indexOf(delimiter)
if (delimPos > 0) {
val chunk0 = buffer.take(delimPos)
buffer = buffer.drop(delimPos + 1)
// remove \r at the end of line.
if (chunk0.size > 0 && chunk0.indexOf(RetByte) == chunk0.size - 1)
Some(chunk0.dropRight(1))
else Some(chunk0)
} else None // no EOL yet, so skip this turn.
}
def tillContentLength: Option[Vector[Byte]] = {
if (contentLength <= buffer.size) {
val chunk = buffer.take(contentLength)
buffer = buffer.drop(contentLength)
resetChannelState()
Some(chunk)
} else None // have not read enough yet, so skip this turn.
}
@tailrec def process(): Unit = {
// handle un-framing
state match {
case SingleLine =>
val line = tillEndOfLine
line match {
case Some(chunk) =>
chunk.headOption match {
case None => // ignore blank line
case Some(Curly) =>
// When Content-Length header is not found, interpret the line as JSON message.
handleBody(chunk)
process()
case Some(_) =>
val str = (new String(chunk.toArray, "UTF-8")).trim
handleHeader(str) match {
case Some(_) =>
state = InHeader
process()
case _ =>
val msg = s"got invalid chunk from client: $str"
log.error(msg)
logMessage("error", msg)
}
}
case _ => ()
}
case InHeader =>
tillEndOfLine match {
case Some(chunk) =>
val str = (new String(chunk.toArray, "UTF-8")).trim
if (str == "") {
state = InBody
process()
} else
handleHeader(str) match {
case Some(_) => process()
case _ =>
log.error("Got invalid header from client: " + str)
resetChannelState()
}
case _ => ()
}
case InBody =>
tillContentLength match {
case Some(chunk) =>
handleBody(chunk)
process()
case _ => ()
}
}
}
val in = connection.getInputStream
// keep going unless the socket has closed
while (bytesRead != -1 && running.get) {
while (running.get) {
try {
bytesRead = in.read(readBuffer)
// log.debug(s"bytesRead: $bytesRead")
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
val onHeader: String => Unit = line => {
if (line.startsWith(ct) && line.contains(x1)) {
logMessage("error", s"server protocol $x1 is no longer supported")
}
}
process()
val content = ReadJsonFromInputStream(in, running, Some(onHeader))
if (content.nonEmpty) handleBody(content)
} catch {
case _: SocketTimeoutException => // its ok
case _: SocketTimeoutException => // its ok
case _: IOException | _: InterruptedException => running.set(false)
}
} // while
} finally {
@ -213,7 +125,7 @@ final class NetworkChannel(
case (f, i) => f orElse i.onNotification
}
def handleBody(chunk: Vector[Byte]): Unit = {
def handleBody(chunk: Seq[Byte]): Unit = {
Serialization.deserializeJsonMessage(chunk) match {
case Right(req: JsonRpcRequestMessage) =>
try {
@ -240,22 +152,6 @@ final class NetworkChannel(
logMessage("error", msg)
}
}
def handleHeader(str: String): Option[Unit] = {
val sbtX1Protocol = "application/sbt-x1"
str match {
case ContentLength(len) =>
contentLength = len.toInt
Some(())
case ContentType(ct) =>
if (ct == sbtX1Protocol) {
logMessage("error", s"server protocol $ct is no longer supported")
}
setContentType(ct)
Some(())
case _ => None
}
}
}
thread.start()
@ -322,14 +218,6 @@ final class NetworkChannel(
}
}
/**
* This publishes object events. The type information has been
* erased because it went through logging.
*/
private[sbt] def respond(event: ObjectEvent[_]): Unit = {
onObjectEvent(event)
}
def publishBytes(event: Array[Byte]): Unit = publishBytes(event, false)
def publishBytes(event: Array[Byte], delimit: Boolean): Unit = {
@ -491,28 +379,6 @@ final class NetworkChannel(
out.close()
}
/**
* This reacts to various events that happens inside sbt, sometime
* in response to the previous requests.
* The type information has been erased because it went through logging.
*/
protected def onObjectEvent(event: ObjectEvent[_]): Unit = {
// import sbt.internal.langserver.codec.JsonProtocol._
val msgContentType = event.contentType
msgContentType match {
// LanguageServerReporter sends PublishDiagnosticsParams
case "sbt.internal.langserver.PublishDiagnosticsParams" =>
// val p = event.message.asInstanceOf[PublishDiagnosticsParams]
// jsonRpcNotify("textDocument/publishDiagnostics", p)
case "xsbti.Problem" =>
() // ignore
case _ =>
// log.debug(event)
()
}
}
/** Respond back to Language Server's client. */
private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = {
val m =