From fcfe4333feb8a743bf8eba7e20ec2ef3c0883957 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Tue, 23 Jun 2020 17:09:36 -0700 Subject: [PATCH] Consolidate and optimize input stream json reading We had similar code for reading json frames from an input stream in NetworkChannel and ServerConnection. I reworked and consolidated this logic into a shared method in ReadJsonFromInputStream. This commit also removes the ObjectMessage reporting methods that weren't doing anything. --- .../internal/client/ServerConnection.scala | 79 +++------ .../util/ReadJsonFromInputStream.scala | 84 +++++++++ .../scala/sbt/internal/CommandExchange.scala | 15 +- .../scala/sbt/internal/RelayAppender.scala | 2 +- .../sbt/internal/server/NetworkChannel.scala | 164 ++---------------- 5 files changed, 122 insertions(+), 222 deletions(-) create mode 100644 main-command/src/main/scala/sbt/internal/util/ReadJsonFromInputStream.scala diff --git a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala index ebde66bd1..85c3b9867 100644 --- a/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala +++ b/main-command/src/main/scala/sbt/internal/client/ServerConnection.scala @@ -9,10 +9,13 @@ package sbt package internal package client -import java.net.{ SocketTimeoutException, Socket } +import java.io.IOException +import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean + import sbt.protocol._ import sbt.internal.protocol._ +import sbt.internal.util.ReadJsonFromInputStream abstract class ServerConnection(connection: Socket) { @@ -25,69 +28,29 @@ abstract class ServerConnection(connection: Socket) { val thread = new Thread(s"sbt-serverconnection-${connection.getPort}") { override def run(): Unit = { try { - val readBuffer = new Array[Byte](4096) val in = connection.getInputStream connection.setSoTimeout(5000) - var buffer: Vector[Byte] = Vector.empty - def readFrame: Vector[Byte] = { - def getContentLength: Int = { - readLine.drop(16).toInt - } - val l = getContentLength - readLine - readLine - readContentLength(l) - } - - def readLine: String = { - if (buffer.isEmpty) { - val bytesRead = in.read(readBuffer) - if (bytesRead > 0) { - buffer = buffer ++ readBuffer.toVector.take(bytesRead) - } - } - val delimPos = buffer.indexOf(delimiter) - if (delimPos > 0) { - val chunk0 = buffer.take(delimPos) - buffer = buffer.drop(delimPos + 1) - // remove \r at the end of line. - val chunk1 = if (chunk0.lastOption contains retByte) chunk0.dropRight(1) else chunk0 - new String(chunk1.toArray, "utf-8") - } else readLine - } - - def readContentLength(length: Int): Vector[Byte] = { - if (buffer.size < length) { - val bytesRead = in.read(readBuffer) - if (bytesRead > 0) { - buffer = buffer ++ readBuffer.toVector.take(bytesRead) - } else () - } else () - if (length <= buffer.size) { - val chunk = buffer.take(length) - buffer = buffer.drop(length) - chunk - } else readContentLength(length) - } - while (running.get) { try { - val frame = readFrame - Serialization - .deserializeJsonMessage(frame) - .fold( - { errorDesc => - val s = frame.mkString("") // new String(: Array[Byte], "UTF-8") - println(s"Got invalid chunk from server: $s \n" + errorDesc) - }, - _ match { - case msg: JsonRpcRequestMessage => onRequest(msg) - case msg: JsonRpcResponseMessage => onResponse(msg) - case msg: JsonRpcNotificationMessage => onNotification(msg) - } - ) + val frame = ReadJsonFromInputStream(in, running, None) + if (running.get) { + Serialization + .deserializeJsonMessage(frame) + .fold( + { errorDesc => + val s = frame.mkString("") // new String(: Array[Byte], "UTF-8") + println(s"Got invalid chunk from server: $s \n" + errorDesc) + }, + _ match { + case msg: JsonRpcRequestMessage => onRequest(msg) + case msg: JsonRpcResponseMessage => onResponse(msg) + case msg: JsonRpcNotificationMessage => onNotification(msg) + } + ) + } } catch { case _: SocketTimeoutException => // its ok + case e: IOException => running.set(false) } } } finally { diff --git a/main-command/src/main/scala/sbt/internal/util/ReadJsonFromInputStream.scala b/main-command/src/main/scala/sbt/internal/util/ReadJsonFromInputStream.scala new file mode 100644 index 000000000..faf63f7e4 --- /dev/null +++ b/main-command/src/main/scala/sbt/internal/util/ReadJsonFromInputStream.scala @@ -0,0 +1,84 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal.util + +import java.io.InputStream +import java.nio.channels.ClosedChannelException +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable +import scala.util.Try + +private[sbt] object ReadJsonFromInputStream { + def apply( + inputStream: InputStream, + running: AtomicBoolean, + onHeader: Option[String => Unit] + ): Seq[Byte] = { + val newline = '\n'.toInt + val carriageReturn = '\r'.toInt + val contentLength = "Content-Length: " + var bytes = new mutable.ArrayBuffer[Byte] + def getLine(): String = { + val line = new String(bytes.toArray, "UTF-8") + bytes = new mutable.ArrayBuffer[Byte] + onHeader.foreach(oh => oh(line)) + line + } + var content: Seq[Byte] = Seq.empty[Byte] + var consecutiveLineEndings = 0 + var onCarriageReturn = false + do { + val byte = inputStream.read + byte match { + case `newline` => + val line = getLine() + if (onCarriageReturn) consecutiveLineEndings += 1 + onCarriageReturn = false + if (line.startsWith(contentLength)) { + Try(line.drop(contentLength.length).toInt) foreach { len => + def drainHeaders(): Unit = + do { + inputStream.read match { + case `newline` if onCarriageReturn => + getLine() + onCarriageReturn = false + consecutiveLineEndings += 1 + case `carriageReturn` => onCarriageReturn = true + case c => + if (c == newline) getLine() + else bytes += c.toByte + onCarriageReturn = false + consecutiveLineEndings = 0 + } + } while (consecutiveLineEndings < 2) + drainHeaders() + val buf = new Array[Byte](len) + var offset = 0 + do { + offset += inputStream.read(buf, offset, len - offset) + } while (offset < len) + content = buf.toSeq + } + } else if (line.startsWith("{")) { + // Assume this is a json object with no headers + content = line.getBytes.toSeq + } + case i if i < 0 => + running.set(false) + throw new ClosedChannelException + case `carriageReturn` => onCarriageReturn = true + case c => + onCarriageReturn = false + bytes += c.toByte + + } + } while (content.isEmpty && running.get) + content + } + +} diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 785c106d2..8b7da864a 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -17,7 +17,7 @@ import sbt.BasicKeys._ import sbt.nio.Watch.NullLogger import sbt.internal.protocol.JsonRpcResponseError import sbt.internal.server._ -import sbt.internal.util.{ ConsoleOut, MainAppender, ObjectEvent, Terminal } +import sbt.internal.util.{ ConsoleOut, MainAppender, Terminal } import sbt.io.syntax._ import sbt.io.{ Hash, IO } import sbt.protocol.{ ExecStatusEvent, LogEvent } @@ -271,19 +271,6 @@ private[sbt] final class CommandExchange { } } - /** - * This publishes object events. The type information has been - * erased because it went through logging. - */ - private[sbt] def respondObjectEvent(event: ObjectEvent[_]): Unit = { - for { - source <- event.channelName - channel <- channels.collectFirst { - case c: NetworkChannel if c.name == source => c - } - } tryTo(_.respond(event))(channel) - } - def prompt(event: ConsolePromptEvent): Unit = { activePrompt.set(Terminal.systemInIsAttached) channels diff --git a/main/src/main/scala/sbt/internal/RelayAppender.scala b/main/src/main/scala/sbt/internal/RelayAppender.scala index edfb70d0a..e140754b6 100644 --- a/main/src/main/scala/sbt/internal/RelayAppender.scala +++ b/main/src/main/scala/sbt/internal/RelayAppender.scala @@ -44,7 +44,7 @@ class RelayAppender(name: String) def appendEvent(event: AnyRef): Unit = event match { case x: StringEvent => exchange.logMessage(LogEvent(level = x.level, message = x.message)) - case x: ObjectEvent[_] => exchange.respondObjectEvent(x) + case x: ObjectEvent[_] => // ignore object events case _ => println(s"appendEvent: ${event.getClass}") () diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index b3bd4e644..417b80689 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -9,6 +9,7 @@ package sbt package internal package server +import java.io.IOException import java.net.{ Socket, SocketTimeoutException } import java.util.concurrent.atomic.AtomicBoolean @@ -19,14 +20,13 @@ import sbt.internal.protocol.{ JsonRpcResponseError, JsonRpcResponseMessage } -import sbt.internal.util.ObjectEvent +import sbt.internal.util.ReadJsonFromInputStream import sbt.internal.util.complete.Parser import sbt.protocol._ import sbt.util.Logger import sjsonnew._ import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } -import scala.annotation.tailrec import scala.collection.mutable import scala.util.Try import scala.util.control.NonFatal @@ -40,17 +40,11 @@ final class NetworkChannel( handlers: Seq[ServerHandler], val log: Logger ) extends CommandChannel { self => - import NetworkChannel._ private val running = new AtomicBoolean(true) private val delimiter: Byte = '\n'.toByte - private val RetByte = '\r'.toByte private val out = connection.getOutputStream private var initialized = false - private val Curly = '{'.toByte - private val ContentLength = """^Content\-Length\:\s*(\d+)""".r - private val ContentType = """^Content\-Type\:\s*(.+)""".r - private var _contentType: String = "" private val pendingRequests: mutable.Map[String, JsonRpcRequestMessage] = mutable.Map() private lazy val callback: ServerCallback = new ServerCallback { @@ -81,9 +75,6 @@ final class NetworkChannel( self.onCancellationRequest(execId, crp) } - def setContentType(ct: String): Unit = synchronized { _contentType = ct } - def contentType: String = _contentType - protected def authenticate(token: String): Boolean = instance.authenticate(token) protected def setInitialized(value: Boolean): Unit = initialized = value @@ -91,105 +82,26 @@ final class NetworkChannel( protected def authOptions: Set[ServerAuthentication] = auth val thread = new Thread(s"sbt-networkchannel-${connection.getPort}") { - var contentLength: Int = 0 - var state: ChannelState = SingleLine - + private val ct = "Content-Type: " + private val x1 = "application/sbt-x1" override def run(): Unit = { try { - val readBuffer = new Array[Byte](4096) - val in = connection.getInputStream connection.setSoTimeout(5000) - var buffer: Vector[Byte] = Vector.empty - var bytesRead = 0 - def resetChannelState(): Unit = { - contentLength = 0 - state = SingleLine - } - def tillEndOfLine: Option[Vector[Byte]] = { - val delimPos = buffer.indexOf(delimiter) - if (delimPos > 0) { - val chunk0 = buffer.take(delimPos) - buffer = buffer.drop(delimPos + 1) - // remove \r at the end of line. - if (chunk0.size > 0 && chunk0.indexOf(RetByte) == chunk0.size - 1) - Some(chunk0.dropRight(1)) - else Some(chunk0) - } else None // no EOL yet, so skip this turn. - } - - def tillContentLength: Option[Vector[Byte]] = { - if (contentLength <= buffer.size) { - val chunk = buffer.take(contentLength) - buffer = buffer.drop(contentLength) - resetChannelState() - Some(chunk) - } else None // have not read enough yet, so skip this turn. - } - - @tailrec def process(): Unit = { - // handle un-framing - state match { - case SingleLine => - val line = tillEndOfLine - line match { - case Some(chunk) => - chunk.headOption match { - case None => // ignore blank line - case Some(Curly) => - // When Content-Length header is not found, interpret the line as JSON message. - handleBody(chunk) - process() - case Some(_) => - val str = (new String(chunk.toArray, "UTF-8")).trim - handleHeader(str) match { - case Some(_) => - state = InHeader - process() - case _ => - val msg = s"got invalid chunk from client: $str" - log.error(msg) - logMessage("error", msg) - } - } - case _ => () - } - case InHeader => - tillEndOfLine match { - case Some(chunk) => - val str = (new String(chunk.toArray, "UTF-8")).trim - if (str == "") { - state = InBody - process() - } else - handleHeader(str) match { - case Some(_) => process() - case _ => - log.error("Got invalid header from client: " + str) - resetChannelState() - } - case _ => () - } - case InBody => - tillContentLength match { - case Some(chunk) => - handleBody(chunk) - process() - case _ => () - } - } - } + val in = connection.getInputStream // keep going unless the socket has closed - while (bytesRead != -1 && running.get) { + while (running.get) { try { - bytesRead = in.read(readBuffer) - // log.debug(s"bytesRead: $bytesRead") - if (bytesRead > 0) { - buffer = buffer ++ readBuffer.toVector.take(bytesRead) + val onHeader: String => Unit = line => { + if (line.startsWith(ct) && line.contains(x1)) { + logMessage("error", s"server protocol $x1 is no longer supported") + } } - process() + val content = ReadJsonFromInputStream(in, running, Some(onHeader)) + if (content.nonEmpty) handleBody(content) } catch { - case _: SocketTimeoutException => // its ok + case _: SocketTimeoutException => // its ok + case _: IOException | _: InterruptedException => running.set(false) } } // while } finally { @@ -213,7 +125,7 @@ final class NetworkChannel( case (f, i) => f orElse i.onNotification } - def handleBody(chunk: Vector[Byte]): Unit = { + def handleBody(chunk: Seq[Byte]): Unit = { Serialization.deserializeJsonMessage(chunk) match { case Right(req: JsonRpcRequestMessage) => try { @@ -240,22 +152,6 @@ final class NetworkChannel( logMessage("error", msg) } } - - def handleHeader(str: String): Option[Unit] = { - val sbtX1Protocol = "application/sbt-x1" - str match { - case ContentLength(len) => - contentLength = len.toInt - Some(()) - case ContentType(ct) => - if (ct == sbtX1Protocol) { - logMessage("error", s"server protocol $ct is no longer supported") - } - setContentType(ct) - Some(()) - case _ => None - } - } } thread.start() @@ -322,14 +218,6 @@ final class NetworkChannel( } } - /** - * This publishes object events. The type information has been - * erased because it went through logging. - */ - private[sbt] def respond(event: ObjectEvent[_]): Unit = { - onObjectEvent(event) - } - def publishBytes(event: Array[Byte]): Unit = publishBytes(event, false) def publishBytes(event: Array[Byte], delimit: Boolean): Unit = { @@ -491,28 +379,6 @@ final class NetworkChannel( out.close() } - /** - * This reacts to various events that happens inside sbt, sometime - * in response to the previous requests. - * The type information has been erased because it went through logging. - */ - protected def onObjectEvent(event: ObjectEvent[_]): Unit = { - // import sbt.internal.langserver.codec.JsonProtocol._ - - val msgContentType = event.contentType - msgContentType match { - // LanguageServerReporter sends PublishDiagnosticsParams - case "sbt.internal.langserver.PublishDiagnosticsParams" => - // val p = event.message.asInstanceOf[PublishDiagnosticsParams] - // jsonRpcNotify("textDocument/publishDiagnostics", p) - case "xsbti.Problem" => - () // ignore - case _ => - // log.debug(event) - () - } - } - /** Respond back to Language Server's client. */ private[sbt] def jsonRpcRespond[A: JsonFormat](event: A, execId: String): Unit = { val m =