diff --git a/server-test/src/test/scala/sbt/ReadJson.scala b/server-test/src/test/scala/sbt/ReadJson.scala index 6b9433bcd..6b2fd14a1 100644 --- a/server-test/src/test/scala/sbt/ReadJson.scala +++ b/server-test/src/test/scala/sbt/ReadJson.scala @@ -11,8 +11,6 @@ import java.util.concurrent.atomic.AtomicBoolean import java.io.InputStream object ReadJson { - def apply(in: InputStream, running: AtomicBoolean): Option[String] = { - val bytes = sbt.internal.util.ReadJsonFromInputStream(in, running, None).toArray - Some(new String(bytes, "UTF-8")) - } + def apply(in: InputStream, running: AtomicBoolean): String = + new String(sbt.internal.util.ReadJsonFromInputStream(in, running, None).toArray, "UTF-8") } diff --git a/server-test/src/test/scala/testpkg/ResponseTest.scala b/server-test/src/test/scala/testpkg/ResponseTest.scala index 4b4057989..23d0f2e16 100644 --- a/server-test/src/test/scala/testpkg/ResponseTest.scala +++ b/server-test/src/test/scala/testpkg/ResponseTest.scala @@ -70,7 +70,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "15", "method": "foo/respondTwice", "params": {} }""" ) assert { - svr.waitForString(1.seconds) { s => + svr.waitForString(10.seconds) { s => if (!s.contains("systemOut")) println(s) s contains "\"id\":\"15\"" } @@ -89,7 +89,7 @@ object ResponseTest extends AbstractServerTest { """{ "jsonrpc": "2.0", "id": "16", "method": "foo/resultAndError", "params": {} }""" ) assert { - svr.waitForString(1.seconds) { s => + svr.waitForString(10.seconds) { s => if (!s.contains("systemOut")) println(s) s contains "\"id\":\"16\"" } diff --git a/server-test/src/test/scala/testpkg/TestServer.scala b/server-test/src/test/scala/testpkg/TestServer.scala index 01964f39f..86c0d5e56 100644 --- a/server-test/src/test/scala/testpkg/TestServer.scala +++ b/server-test/src/test/scala/testpkg/TestServer.scala @@ -9,7 +9,7 @@ package testpkg import java.io.{ File, IOException } import java.nio.file.Path -import java.util.concurrent.TimeoutException +import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit } import java.util.concurrent.atomic.AtomicBoolean import verify._ @@ -24,7 +24,6 @@ import scala.concurrent.duration._ import scala.util.{ Success, Try } trait AbstractServerTest extends TestSuite[Unit] { - implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global private var temp: File = _ var svr: TestServer = _ def testDirectory: String @@ -90,9 +89,9 @@ object TestServer { // if something goes wrong here the communication streams are corrupted, restarting val init = Try { - testServer.waitForString(30.seconds) { s => + testServer.waitForString(10.seconds) { s => println(s) - s contains """"message":"Done"""" + s contains """"capabilities":{"""" } } init.get @@ -130,9 +129,9 @@ object TestServer { // if something goes wrong here the communication streams are corrupted, restarting val init = Try { - testServer.waitForString(30.seconds) { s => + testServer.waitForString(10.seconds) { s => if (s.nonEmpty) println(s) - s contains """"message":"Done"""" + s contains """"capabilities":{"""" } } @@ -165,14 +164,8 @@ case class TestServer( sbtVersion: String, classpath: Seq[File] ) { - import scala.concurrent.ExecutionContext.Implicits._ import TestServer.hostLog - val readBuffer = new Array[Byte](40960) - var buffer: Vector[Byte] = Vector.empty - var bytesRead = 0 - val running = new AtomicBoolean(true) - hostLog("fork to a new sbt instance") val process = RunFromSourceMain.fork(baseDirectory, scalaVersion, sbtVersion, classpath) @@ -194,47 +187,72 @@ case class TestServer( if (deadline.isOverdue) sys.error(s"Timeout. $portfile is not found.") if (!process.isAlive) sys.error(s"Server unexpectedly terminated.") } - private val waitDuration: FiniteDuration = 120.seconds + private val waitDuration: FiniteDuration = 1.minute hostLog(s"wait $waitDuration until the server is ready to respond") - waitForPortfile(90.seconds) + waitForPortfile(waitDuration) // make connection to the socket described in the portfile - var (sk, _) = ClientSocket.socket(portfile) - var out = sk.getOutputStream - var in = sk.getInputStream + val (sk, _) = ClientSocket.socket(portfile) + val out = sk.getOutputStream + val in = sk.getInputStream + private val lines = new LinkedBlockingQueue[String] + val running = new AtomicBoolean(true) + val readThread = + new Thread(() => { + while (running.get) { + try lines.put(sbt.ReadJson(in, running)) + catch { case _: Exception => running.set(false) } + } + }, "sbt-server-test-read-thread") { + setDaemon(true) + start() + } // initiate handshake sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" + s"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { "skipAnalysis": true } } }""" ) - def resetConnection() = { - Option(sk).foreach(_.close()) - sk = ClientSocket.socket(portfile)._1 - out = sk.getOutputStream - in = sk.getInputStream - - sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" - ) - } - def test(f: TestServer => Future[Assertion]): Future[Assertion] = { f(this) } - def bye(): Unit = { - hostLog("sending exit") - sendJsonRpc( - """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }""" - ) - val deadline = 10.seconds.fromNow - while (!deadline.isOverdue && process.isAlive) { - Thread.sleep(10) + def bye(): Unit = + try { + running.set(false) + hostLog("sending exit") + sendJsonRpc( + """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }""" + ) + val deadline = 5.seconds.fromNow + while (!deadline.isOverdue && process.isAlive) { + Thread.sleep(10) + } + // We gave the server a chance to exit but it didn't within a reasonable time frame. + if (deadline.isOverdue && process.isAlive) { + process.destroy() + val newDeadline = 10.seconds.fromNow + while (!newDeadline.isOverdue && process.isAlive) { + Thread.sleep(10) + } + } + if (process.isAlive) throw new IllegalStateException(s"process $process failed to exit") + } finally { + readThread.interrupt() + /* + * The UnixDomainSocket input stream cannot be closed while a thread is + * reading from it (even if the UnixDomainSocket itself is closed): + * https://github.com/sbt/ipcsocket/blob/f02d29092f9f0c57e5c4b276a31fa16975ddf66e/src/main/java/org/scalasbt/ipcsocket/UnixDomainSocket.java#L111-L118 + * This makes it impossible to interrupt the readThread until after the + * server process has exited which closes the ServerSocket which does + * cause the input stream to be closed. We could change the behavior of + * ipcsocket, but that seems risky without knowing exactly why the behavior + * exists. For now, ensure that we are able to interrupt and join the + * read thread and throw an exception if not. + */ + readThread.join(5000) + if (readThread.isAlive) throw new IllegalStateException(s"Unable to join read thread") } - // We gave the server a chance to exit but it didn't within a reasonable time frame. - if (deadline.isOverdue) process.destroy() - } def sendJsonRpc(message: String): Unit = { writeLine(s"""Content-Length: ${message.size + 2}""") @@ -257,36 +275,24 @@ case class TestServer( writeEndLine } - def readFrame: Future[Option[String]] = Future(sbt.ReadJson(in, running)) - final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = { val deadline = duration.fromNow - @tailrec def impl(): Boolean = { - val res = try { - Await.result(readFrame, deadline.timeLeft).fold(false)(f) - } catch { - case _: TimeoutException => - resetConnection() // create a new connection to invalidate the running readFrame future - false + @tailrec def impl(): Boolean = + lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match { + case null => false + case s => if (!f(s) && !deadline.isOverdue) impl() else !deadline.isOverdue() } - if (!res && !deadline.isOverdue) impl() else !deadline.isOverdue() - } impl() } final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = { val deadline = duration.fromNow @tailrec - def impl(): Boolean = { - val res = try { - Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) - } catch { - case _: TimeoutException => - resetConnection() // create a new connection to invalidate the running readFrame future - true + def impl(): Boolean = + lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match { + case null => true + case s => if (!f(s)) impl() else false } - if (res && !deadline.isOverdue) impl else res || !deadline.isOverdue - } impl() }