diff --git a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcRequestMessageFormats.scala b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcRequestMessageFormats.scala index dac0987e4..2c1273534 100644 --- a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcRequestMessageFormats.scala +++ b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcRequestMessageFormats.scala @@ -24,7 +24,10 @@ trait JsonRpcRequestMessageFormats { val id = try { unbuilder.readField[String]("id") } catch { - case _: Throwable => unbuilder.readField[Long]("id").toString + case _: Throwable => { + val prefix = "\u2668" // Append prefix to show the original type was Number + prefix + unbuilder.readField[Long]("id").toString + } } val method = unbuilder.readField[String]("method") val params = unbuilder.lookupField("params") map { diff --git a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala index d10164e67..c9d943296 100644 --- a/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala +++ b/protocol/src/main/scala/sbt/internal/protocol/codec/JsonRpcResponseMessageFormats.scala @@ -7,8 +7,8 @@ package sbt.internal.protocol.codec -import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } -import sjsonnew.shaded.scalajson.ast.unsafe.JValue +import _root_.sjsonnew.{ Builder, JsonFormat, Unbuilder, deserializationError } +import sjsonnew.shaded.scalajson.ast.unsafe._ trait JsonRpcResponseMessageFormats { self: sbt.internal.util.codec.JValueFormats @@ -45,10 +45,35 @@ trait JsonRpcResponseMessageFormats { } override def write[J](obj: sbt.internal.protocol.JsonRpcResponseMessage, builder: Builder[J]): Unit = { + // Parse given id to Long or String judging by prefix + def parseId(str: String): Either[Long, String] = { + if (str.startsWith("\u2668")) Left(str.substring(1).toLong) + else Right(str) + } + def parseResult(jValue: JValue): JValue = jValue match { + case JObject(jFields) => + val replaced = jFields map { + case field @ JField("execId", JString(str)) => + parseId(str) match { + case Right(strId) => field.copy(value = JString(strId)) + case Left(longId) => field.copy(value = JNumber(longId)) + } + case other => + other + } + JObject(replaced) + case other => + other + } builder.beginObject() builder.addField("jsonrpc", obj.jsonrpc) - builder.addField("id", obj.id) - builder.addField("result", obj.result) + obj.id foreach { id => + parseId(id) match { + case Right(strId) => builder.addField("id", strId) + case Left(longId) => builder.addField("id", longId) + } + } + builder.addField("result", obj.result map parseResult) builder.addField("error", obj.error) builder.endObject() } diff --git a/sbt/src/test/scala/sbt/ServerSpec.scala b/sbt/src/test/scala/sbt/ServerSpec.scala index 7ad307fd8..a328c6dfe 100644 --- a/sbt/src/test/scala/sbt/ServerSpec.scala +++ b/sbt/src/test/scala/sbt/ServerSpec.scala @@ -9,6 +9,7 @@ package sbt import org.scalatest._ import scala.concurrent._ +import scala.annotation.tailrec import java.io.{ InputStream, OutputStream } import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{ ThreadFactory, ThreadPoolExecutor } @@ -23,15 +24,22 @@ class ServerSpec extends AsyncFlatSpec with Matchers { """{ "jsonrpc": "2.0", "id": 3, "method": "sbt/setting", "params": { "setting": "root/name" } }""", out) Thread.sleep(100) - val l2 = contentLength(in) - println(l2) - readLine(in) - readLine(in) - val x2 = readContentLength(in, l2) - println(x2) - assert(1 == 1) + assert(waitFor(in, 10) { s => + s contains """"id":3""" + }) } } + + @tailrec + private[this] def waitFor(in: InputStream, num: Int)(f: String => Boolean): Boolean = { + if (num < 0) false + else + readFrame(in) match { + case Some(x) if f(x) => true + case _ => + waitFor(in, num - 1)(f) + } + } } object ServerSpec { @@ -90,6 +98,13 @@ object ServerSpec { writeLine(message, out) } + def readFrame(in: InputStream): Option[String] = { + val l = contentLength(in) + readLine(in) + readLine(in) + readContentLength(in, l) + } + def contentLength(in: InputStream): Int = { readLine(in) map { line => line.drop(16).toInt