/* * sbt * Copyright 2011 - 2018, Lightbend, Inc. * Copyright 2008 - 2010, Mark Harrah * Licensed under Apache License 2.0 (see LICENSE) */ package testpkg import java.io.{ File, IOException } import java.nio.file.Path import java.util.concurrent.TimeoutException import verify._ import sbt.RunFromSourceMain import sbt.io.IO import sbt.io.syntax._ import sbt.protocol.ClientSocket import scala.annotation.tailrec import scala.concurrent._ import scala.concurrent.duration._ import scala.util.{ Success, Try } trait AbstractServerTest extends TestSuite[Unit] { implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global private var temp: File = _ var svr: TestServer = _ def testDirectory: String def testPath: Path = temp.toPath.resolve(testDirectory) private val targetDir: File = { val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target" val p1 = new File("target").getAbsoluteFile if (p0.exists) p0 else p1 } override def setupSuite(): Unit = { temp = targetDir / "test-server" / testDirectory if (temp.exists) { IO.delete(temp) } val classpath = sys.props.get("sbt.server.classpath") match { case Some(s: String) => s.split(java.io.File.pathSeparator).map(file) case _ => throw new IllegalStateException("No server classpath was specified.") } val sbtVersion = sys.props.get("sbt.server.version") match { case Some(v: String) => v case _ => throw new IllegalStateException("No server version was specified.") } val scalaVersion = sys.props.get("sbt.server.scala.version") match { case Some(v: String) => v case _ => throw new IllegalStateException("No server scala version was specified.") } svr = TestServer.get(testDirectory, scalaVersion, sbtVersion, classpath, temp) } override def tearDownSuite(): Unit = { svr.bye() svr = null IO.delete(temp) } override def setup(): Unit = () override def tearDown(env: Unit): Unit = () } object TestServer { // forking affects this private val serverTestBase: File = { val p0 = new File(".").getAbsoluteFile / "server-test" / "src" / "server-test" val p1 = new File(".").getAbsoluteFile / "src" / "server-test" if (p0.exists) p0 else p1 } def get( testBuild: String, scalaVersion: String, sbtVersion: String, classpath: Seq[File], temp: File ): TestServer = { println(s"Starting test server $testBuild") IO.copyDirectory(serverTestBase / testBuild, temp / testBuild) // Each test server instance will be executed in a Thread pool separated from the tests val testServer = TestServer(temp / testBuild, scalaVersion, sbtVersion, classpath) // checking last log message after initialization // if something goes wrong here the communication streams are corrupted, restarting val init = Try { testServer.waitForString(30.seconds) { s => println(s) s contains """"message":"Done"""" } } init.get testServer } def withTestServer( testBuild: String )(f: TestServer => Future[Unit]): Future[Unit] = { println(s"Starting test") IO.withTemporaryDirectory { temp => IO.copyDirectory(serverTestBase / testBuild, temp / testBuild) withTestServer(testBuild, temp / testBuild)(f) } } def withTestServer(testBuild: String, baseDirectory: File)( f: TestServer => Future[Unit] ): Future[Unit] = { val classpath = sys.props.get("sbt.server.classpath") match { case Some(s: String) => s.split(java.io.File.pathSeparator).map(file) case _ => throw new IllegalStateException("No server classpath was specified.") } val sbtVersion = sys.props.get("sbt.server.version") match { case Some(v: String) => v case _ => throw new IllegalStateException("No server version was specified.") } val scalaVersion = sys.props.get("sbt.server.scala.version") match { case Some(v: String) => v case _ => throw new IllegalStateException("No server scala version was specified.") } // Each test server instance will be executed in a Thread pool separated from the tests val testServer = TestServer(baseDirectory, scalaVersion, sbtVersion, classpath) // checking last log message after initialization // if something goes wrong here the communication streams are corrupted, restarting val init = Try { testServer.waitForString(30.seconds) { s => if (s.nonEmpty) println(s) s contains """"message":"Done"""" } } init match { case Success(_) => try { f(testServer) } finally { try { testServer.bye() } finally {} } case _ => try { testServer.bye() } finally {} hostLog("Server started but not connected properly... restarting...") withTestServer(testBuild)(f) } } def hostLog(s: String): Unit = { println(s"""[${scala.Console.MAGENTA}build-1${scala.Console.RESET}] $s""") } } case class TestServer( baseDirectory: File, scalaVersion: String, sbtVersion: String, classpath: Seq[File] ) { import scala.concurrent.ExecutionContext.Implicits._ import TestServer.hostLog val readBuffer = new Array[Byte](40960) var buffer: Vector[Byte] = Vector.empty var bytesRead = 0 private val delimiter: Byte = '\n'.toByte private val RetByte = '\r'.toByte hostLog("fork to a new sbt instance") val process = RunFromSourceMain.fork(baseDirectory, scalaVersion, sbtVersion, classpath) lazy val portfile = baseDirectory / "project" / "target" / "active.json" def portfileIsEmpty(): Boolean = try IO.read(portfile).isEmpty catch { case _: IOException => true } def waitForPortfile(duration: FiniteDuration): Unit = { val deadline = duration.fromNow var nextLog = 10.seconds.fromNow while (portfileIsEmpty && !deadline.isOverdue && process.isAlive) { if (nextLog.isOverdue) { hostLog("waiting for the server...") nextLog = 10.seconds.fromNow } Thread.sleep(10) // Don't spam the portfile } if (deadline.isOverdue) sys.error(s"Timeout. $portfile is not found.") if (!process.isAlive) sys.error(s"Server unexpectedly terminated.") } private val waitDuration: FiniteDuration = 120.seconds hostLog(s"wait $waitDuration until the server is ready to respond") waitForPortfile(90.seconds) // make connection to the socket described in the portfile var (sk, _) = ClientSocket.socket(portfile) var out = sk.getOutputStream var in = sk.getInputStream // initiate handshake sendJsonRpc( """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" ) def resetConnection() = { Option(sk).foreach(_.close()) sk = ClientSocket.socket(portfile)._1 out = sk.getOutputStream in = sk.getInputStream sendJsonRpc( """{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""" ) } def test(f: TestServer => Future[Assertion]): Future[Assertion] = { f(this) } def bye(): Unit = { hostLog("sending exit") sendJsonRpc( """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }""" ) val deadline = 10.seconds.fromNow while (!deadline.isOverdue && process.isAlive) { Thread.sleep(10) } // We gave the server a chance to exit but it didn't within a reasonable time frame. if (deadline.isOverdue) process.destroy() } def sendJsonRpc(message: String): Unit = { writeLine(s"""Content-Length: ${message.size + 2}""") writeLine("") writeLine(message) } private def writeLine(s: String): Unit = { def writeEndLine(): Unit = { val retByte: Byte = '\r'.toByte val delimiter: Byte = '\n'.toByte out.write(retByte.toInt) out.write(delimiter.toInt) out.flush } if (s != "") { out.write(s.getBytes("UTF-8")) } writeEndLine } def readFrame: Future[Option[String]] = Future { def getContentLength: Int = { readLine map { line => line.drop(16).toInt } getOrElse (0) } val l = getContentLength readLine readLine readContentLength(l) } final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = { val deadline = duration.fromNow @tailrec def impl(): Boolean = { val res = try { Await.result(readFrame, deadline.timeLeft).fold(false)(f) } catch { case _: TimeoutException => resetConnection() // create a new connection to invalidate the running readFrame future false } if (!res && !deadline.isOverdue) impl() else !deadline.isOverdue() } impl() } final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = { val deadline = duration.fromNow @tailrec def impl(): Boolean = { val res = try { Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s)) } catch { case _: TimeoutException => resetConnection() // create a new connection to invalidate the running readFrame future true } if (res && !deadline.isOverdue) impl else res || !deadline.isOverdue } impl() } def readLine: Option[String] = { if (buffer.isEmpty) { val bytesRead = in.read(readBuffer) if (bytesRead > 0) { buffer = buffer ++ readBuffer.toVector.take(bytesRead) } } val delimPos = buffer.indexOf(delimiter) if (delimPos > 0) { val chunk0 = buffer.take(delimPos) buffer = buffer.drop(delimPos + 1) // remove \r at the end of line. val chunk1 = if (chunk0.lastOption contains RetByte) chunk0.dropRight(1) else chunk0 Some(new String(chunk1.toArray, "utf-8")) } else None // no EOL yet, so skip this turn. } def readContentLength(length: Int): Option[String] = { if (buffer.isEmpty) { val bytesRead = in.read(readBuffer) if (bytesRead > 0) { buffer = buffer ++ readBuffer.toVector.take(bytesRead) } } if (length <= buffer.size) { val chunk = buffer.take(length) buffer = buffer.drop(length) Some(new String(chunk.toArray, "utf-8")) } else None // have not read enough yet, so skip this turn. } }