// sbt/server-test/src/test/scala/testpkg/TestServer.scala

/*
* sbt
* Copyright 2011 - 2018, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under Apache License 2.0 (see LICENSE)
*/
package testpkg
import java.io.{ File, IOException }
import java.net.Socket
import java.nio.file.{ Files, Path }
import java.util.concurrent.{ LinkedBlockingQueue, TimeUnit }
import java.util.concurrent.atomic.AtomicBoolean
import sbt.{ ForkOptions, OutputStrategy, RunFromSourceMain }
import sbt.io.IO
import sbt.io.syntax._
import sbt.protocol.ClientSocket
import sjsonnew.JsonReader
import sjsonnew.support.scalajson.unsafe.{ Converter, Parser }
import scala.annotation.tailrec
import scala.concurrent._
import scala.concurrent.duration._
import scala.util.{ Failure, Success, Try }
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.BeforeAndAfterAll
/**
 * Mixin for suites that share a single sbt server across all of their tests.
 *
 * `beforeAll` creates a fresh temporary directory under `target/test-server` and starts a
 * [[TestServer]] on a copy of the build named by [[testDirectory]]. `afterAll` shuts that server
 * down and deletes the directory.
 */
trait AbstractServerTest extends AnyFunSuite with BeforeAndAfterAll {
  // Temporary working directory for this suite; null until beforeAll has created it.
  private var temp: File = _

  /** The running server; null before beforeAll succeeds and after afterAll. */
  var svr: TestServer = _

  /** Name of the test build (a directory) that TestServer.get copies into the temp directory. */
  def testDirectory: String

  /** Location of the copied test build inside the temporary directory. */
  def testPath: Path = temp.toPath.resolve(testDirectory)

  /**
   * The sbt version from the `sbt.server.version` system property.
   *
   * NOTE(review): beforeAll starts the server with TestProperties.version instead; confirm the
   * two sources are meant to agree.
   *
   * @throws IllegalStateException if the property is not set
   */
  def sbtVersion: String = sys.props
    .get("sbt.server.version")
    .getOrElse(throw new IllegalStateException("No server version was specified."))

  // Prefer the parent project's target directory when it exists (the tests may run from either
  // the repository root or the server-test subproject).
  private val targetDir: File = {
    val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target"
    val p1 = new File("target").getAbsoluteFile
    if (p0.exists) p0
    else p1
  }

  override def beforeAll(): Unit = {
    val base = Files.createTempDirectory(
      Files.createDirectories(targetDir.toPath.resolve("test-server")),
      "server-test"
    )
    temp = base.toFile
    val classpath = TestProperties.classpath.split(File.pathSeparator).map(new File(_))
    // Named so they don't shadow the `sbtVersion` member above.
    val serverSbtVersion = TestProperties.version
    val serverScalaVersion = TestProperties.scalaVersion
    svr = TestServer.get(testDirectory, serverScalaVersion, serverSbtVersion, classpath.toSeq, temp)
  }

  /**
   * Shuts the server down and deletes the temporary directory.
   *
   * afterAll may run after a beforeAll that failed partway, so `svr` and `temp` can still be
   * null. The directory is deleted even when shutting the server down throws.
   */
  override protected def afterAll(): Unit =
    try {
      if (svr != null) svr.bye()
    } finally {
      svr = null
      if (temp != null) IO.delete(temp)
    }
}
object TestServer {
  // Directory holding the checked-in test builds. The working directory differs depending on
  // whether the tests are forked ("forking affects this"), so both candidate locations are probed.
  private val serverTestBase: File = {
    val p0 = new File(".").getAbsoluteFile / "server-test" / "src" / "server-test"
    val p1 = new File(".").getAbsoluteFile / "src" / "server-test"
    if (p0.exists) p0
    else p1
  }

  // How many servers withTestServer starts (each on a fresh copy) before giving up.
  private val maxStartAttempts = 3

  /**
   * Waits up to 10 seconds for the reply to the initialize handshake, i.e. a message containing
   * `"capabilities":{`. Every message received while waiting is passed to `log`.
   *
   * @return true if the handshake reply arrived in time, false on timeout
   */
  private def awaitInitialized(testServer: TestServer)(log: String => Unit): Boolean =
    testServer.waitForString(10.seconds) { s =>
      log(s)
      s contains """"capabilities":{""""
    }

  // Reads a mandatory system property, failing with `missingMessage` when it is absent.
  private def requiredProp(key: String, missingMessage: String): String =
    sys.props.get(key).getOrElse(throw new IllegalStateException(missingMessage))

  /**
   * Copies `testBuild` into `temp`, forks a server on the copy and waits for the handshake.
   *
   * @return a connected, initialized server; the caller must call `bye()` on it
   * @throws IllegalStateException if the server does not answer the initialize request in time
   *         (the forked server is shut down first)
   */
  def get(
      testBuild: String,
      scalaVersion: String,
      sbtVersion: String,
      classpath: Seq[File],
      temp: File
  ): TestServer = {
    println(s"Starting test server $testBuild")
    IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
    // Forks a separate sbt process for this build (see the TestServer constructor).
    val testServer = TestServer(temp / testBuild, scalaVersion, sbtVersion, classpath)
    // If the handshake is not seen, the communication streams cannot be trusted.
    Try(awaitInitialized(testServer)(line => println(line))) match {
      case Success(true) => testServer
      case result =>
        // Don't leak the forked process. The shutdown is best-effort so it cannot hide the cause.
        Try(testServer.bye())
        throw new IllegalStateException(
          s"Test server $testBuild did not complete the initialize handshake",
          result.failed.toOption.orNull
        )
    }
  }

  /**
   * Runs `f` against a server started on a fresh temporary copy of `testBuild`.
   *
   * NOTE(review): the server is shut down, and the temporary directory is presumably deleted by
   * IO.withTemporaryDirectory, as soon as `f` returns its Future. So `f` must finish its work
   * before returning.
   */
  def withTestServer(
      testBuild: String
  )(f: TestServer => Future[Unit]): Future[Unit] =
    withFreshCopy(testBuild, maxStartAttempts)(f)

  /**
   * Runs `f` against a server started in `baseDirectory`, then shuts the server down.
   *
   * If the handshake fails, the server is restarted on a fresh temporary copy of `testBuild`.
   * There are at most `maxStartAttempts` attempts in total.
   *
   * @throws IllegalStateException if a required `sbt.server.*` system property is missing, or if
   *         no attempt completes the handshake
   */
  def withTestServer(testBuild: String, baseDirectory: File)(
      f: TestServer => Future[Unit]
  ): Future[Unit] =
    runWithServer(testBuild, baseDirectory, maxStartAttempts)(f)

  // Copies `testBuild` into a new temporary directory and runs `f` against a server there.
  private def withFreshCopy(testBuild: String, attemptsLeft: Int)(
      f: TestServer => Future[Unit]
  ): Future[Unit] = {
    println(s"Starting test")
    IO.withTemporaryDirectory { temp =>
      IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
      runWithServer(testBuild, temp / testBuild, attemptsLeft)(f)
    }
  }

  // Starts a server in `baseDirectory` and runs `f`, retrying on a fresh copy while attempts remain.
  private def runWithServer(testBuild: String, baseDirectory: File, attemptsLeft: Int)(
      f: TestServer => Future[Unit]
  ): Future[Unit] = {
    val classpath =
      requiredProp("sbt.server.classpath", "No server classpath was specified.")
        .split(File.pathSeparator)
        .map(file)
    val sbtVersion = requiredProp("sbt.server.version", "No server version was specified.")
    val scalaVersion =
      requiredProp("sbt.server.scala.version", "No server scala version was specified.")
    val testServer = TestServer(baseDirectory, scalaVersion, sbtVersion, classpath.toSeq)
    // A timeout (Success(false)) is a failed start, just like an exception.
    Try(awaitInitialized(testServer)(s => if (s.nonEmpty) println(s))) match {
      case Success(true) =>
        try f(testServer)
        finally testServer.bye()
      case _ if attemptsLeft > 1 =>
        testServer.bye()
        hostLog("Server started but not connected properly... restarting...")
        withFreshCopy(testBuild, attemptsLeft - 1)(f)
      case result =>
        testServer.bye()
        throw new IllegalStateException(
          s"Server for $testBuild failed to initialize after $maxStartAttempts attempts",
          result.failed.toOption.orNull
        )
    }
  }

  /** Prints `s` with a magenta `build-1` prefix so harness output stands out from server output. */
  def hostLog(s: String): Unit = {
    println(s"""[${scala.Console.MAGENTA}build-1${scala.Console.RESET}] $s""")
  }
}
/**
 * A forked sbt server together with a JSON-RPC client connection to it, for integration tests.
 *
 * Construction blocks and has side effects:
 *  - forks an sbt process rooted at `baseDirectory`;
 *  - waits up to one minute for the server's portfile and connects to the socket it describes;
 *  - starts a daemon thread that queues every incoming message;
 *  - sends the `initialize` request.
 *
 * If any step after the fork fails, the forked process is destroyed before the error
 * propagates. Call `bye()` to shut the server down and release the connection.
 *
 * @param baseDirectory root of the build the server runs in
 * @param scalaVersion  Scala version passed to RunFromSourceMain.fork
 * @param sbtVersion    sbt version passed to RunFromSourceMain.fork
 * @param classpath     classpath passed to RunFromSourceMain.fork
 */
case class TestServer(
    baseDirectory: File,
    scalaVersion: String,
    sbtVersion: String,
    classpath: Seq[File]
) {
  import TestServer.hostLog

  hostLog("fork to a new sbt instance")
  val forkOptions =
    ForkOptions()
      .withOutputStrategy(OutputStrategy.StdoutOutput)
      .withRunJVMOptions(
        Vector(
          "-Djline.terminal=none",
          "-Dsbt.io.virtual=false",
          // "-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=1044"
        )
      )
  val process =
    RunFromSourceMain.fork(forkOptions, baseDirectory, scalaVersion, sbtVersion, classpath)

  // Read by ClientSocket.socket to find out how to connect to the server.
  lazy val portfile = baseDirectory / "project" / "target" / "active.json"

  /** True when the portfile is missing, unreadable or still empty. */
  def portfileIsEmpty(): Boolean =
    try IO.read(portfile).isEmpty
    catch { case _: IOException => true }

  /**
   * Polls (every 10 ms) until the portfile has content, logging a progress line every 10 seconds.
   *
   * @throws RuntimeException if the server process exits, or if `duration` elapses before the
   *         portfile is written
   */
  def waitForPortfile(duration: FiniteDuration): Unit = {
    val deadline = duration.fromNow
    var nextLog = 10.seconds.fromNow
    while (portfileIsEmpty() && !deadline.isOverdue && process.isAlive) {
      if (nextLog.isOverdue) {
        hostLog("waiting for the server...")
        nextLog = 10.seconds.fromNow
      }
      Thread.sleep(10) // Don't spam the portfile
    }
    // Decide from the observed state rather than the clock. A portfile that appeared just as the
    // deadline expired is a success, not a timeout.
    if (!process.isAlive) sys.error(s"Server unexpectedly terminated.")
    if (portfileIsEmpty()) sys.error(s"Timeout. $portfile is not found.")
  }

  // Runs a startup step. If the step throws, destroys the forked process and rethrows the
  // original error, so a failed start does not leave an orphaned sbt server behind.
  private def destroyOnFailure[A](step: => A): A =
    try step
    catch {
      case e: Throwable =>
        process.destroy()
        throw e
    }

  private val waitDuration: FiniteDuration = 1.minute
  hostLog(s"wait $waitDuration until the server is ready to respond")
  destroyOnFailure(waitForPortfile(waitDuration))

  // Connects to the socket described by the portfile. An IOException is retried every 100 ms,
  // for up to 10 attempts; the IOException from the final attempt propagates.
  @tailrec
  private def connect(attempt: Int): Socket = {
    val res =
      try Some(ClientSocket.socket(portfile)._1)
      catch { case _: IOException if attempt < 10 => None }
    res match {
      case Some(s) => s
      case _ =>
        Thread.sleep(100)
        connect(attempt + 1)
    }
  }

  // make connection to the socket described in the portfile
  val sk = destroyOnFailure(connect(0))
  val out = sk.getOutputStream
  val in = sk.getInputStream

  // Messages received from the server, in arrival order; consumed by the waitFor* methods.
  private val lines = new LinkedBlockingQueue[String]
  // Cleared by bye(), or by the reader thread when a read fails.
  val running = new AtomicBoolean(true)

  // Daemon thread that keeps enqueuing messages (presumably one JSON-RPC message body per
  // sbt.ReadJson call) until `running` is cleared. Any exception stops it.
  val readThread =
    new Thread(
      () => {
        while (running.get) {
          try lines.put(sbt.ReadJson(in, running))
          catch { case _: Exception => running.set(false) }
        }
      },
      "sbt-server-test-read-thread"
    ) {
      setDaemon(true)
      start()
    }

  // initiate handshake
  destroyOnFailure(
    sendJsonRpc(
      s"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { "skipAnalysis": true } } }"""
    )
  )

  /** Applies `f` to this server. */
  def test(f: TestServer => Future[Unit]): Future[Unit] = f(this)

  /**
   * Shuts the server down and releases the connection.
   *
   * Sends `shutdown` and waits up to 5 seconds for the process to exit. If it is still alive, it
   * is destroyed and given up to 10 more seconds. Finally the reader thread is joined and the
   * socket is closed.
   *
   * @throws IllegalStateException if the process or the reader thread fails to exit
   */
  def bye(): Unit =
    try {
      running.set(false)
      hostLog("sending exit")
      sendJsonRpc(
        """{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }"""
      )
      val deadline = 5.seconds.fromNow
      while (!deadline.isOverdue && process.isAlive) {
        Thread.sleep(10)
      }
      // We gave the server a chance to exit but it didn't within a reasonable time frame.
      if (deadline.isOverdue && process.isAlive) {
        process.destroy()
        val newDeadline = 10.seconds.fromNow
        while (!newDeadline.isOverdue && process.isAlive) {
          Thread.sleep(10)
        }
      }
      if (process.isAlive) throw new IllegalStateException(s"process $process failed to exit")
    } finally {
      readThread.interrupt()
      /*
       * The UnixDomainSocket input stream cannot be closed while a thread is
       * reading from it (even if the UnixDomainSocket itself is closed):
       * https://github.com/sbt/ipcsocket/blob/f02d29092f9f0c57e5c4b276a31fa16975ddf66e/src/main/java/org/scalasbt/ipcsocket/UnixDomainSocket.java#L111-L118
       * This makes it impossible to interrupt the readThread until after the
       * server process has exited which closes the ServerSocket which does
       * cause the input stream to be closed. We could change the behavior of
       * ipcsocket, but that seems risky without knowing exactly why the behavior
       * exists. For now, ensure that we are able to interrupt and join the
       * read thread and throw an exception if not.
       */
      readThread.join(5000)
      if (readThread.isAlive) throw new IllegalStateException(s"Unable to join read thread")
      // Close only after the reader has exited (see above). This is best-effort, because the
      // server may already have closed its end.
      try sk.close()
      catch { case _: IOException => () }
    }

  /**
   * Sends one JSON-RPC message framed with a Content-Length header.
   *
   * Content-Length counts UTF-8 bytes (not chars), plus 2 for the "\r\n" that writeLine appends
   * after the body.
   */
  def sendJsonRpc(message: String): Unit = {
    writeLine(s"""Content-Length: ${message.getBytes("UTF-8").length + 2}""")
    writeLine("")
    writeLine(message)
  }

  // Writes `s` as UTF-8 followed by "\r\n", then flushes.
  private def writeLine(s: String): Unit = {
    def writeEndLine(): Unit = {
      val retByte: Byte = '\r'.toByte
      val delimiter: Byte = '\n'.toByte
      out.write(retByte.toInt)
      out.write(delimiter.toInt)
      out.flush()
    }
    if (s != "") {
      out.write(s.getBytes("UTF-8"))
    }
    writeEndLine()
  }

  /**
   * Consumes queued messages until one satisfies `f`.
   *
   * @return true if a matching message arrived before `duration` elapsed; false on timeout.
   *         Messages up to and including the match are removed from the queue.
   */
  final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = {
    val deadline = duration.fromNow
    @tailrec def impl(): Boolean =
      lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match {
        case null => false
        case s => if (!f(s) && !deadline.isOverdue) impl() else !deadline.isOverdue()
      }
    impl()
  }

  /**
   * Returns the first queued message whose top-level "result" field decodes as `T`.
   *
   * Messages that fail to parse or decode (including ones without a "result") are skipped.
   *
   * @param debug when true, prints every message consumed
   * @throws TimeoutException if no message arrives in time, or if the deadline passes while a
   *         message is being decoded (the decode error is then attached as the cause)
   * @throws Throwable the last decode error, if the queue times out after such an error
   */
  final def waitFor[T: JsonReader](duration: FiniteDuration, debug: Boolean = false): T = {
    val deadline = duration.fromNow
    var lastEx: Throwable = null
    @tailrec def impl(): T =
      lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match {
        case null =>
          if (lastEx != null) throw lastEx
          else throw new TimeoutException
        case s =>
          if debug then println(s)
          Parser
            .parseFromString(s)
            .flatMap { jvalue =>
              Converter.fromJson[T](
                jvalue.toStandard
                  .asInstanceOf[sjsonnew.shaded.scalajson.ast.JObject]
                  .value("result")
                  .toUnsafe
              )
            } match {
            case Success(value) =>
              value
            case Failure(exception) =>
              if (deadline.isOverdue) {
                val ex = new TimeoutException()
                ex.initCause(exception)
                throw ex
              } else {
                lastEx = exception
                impl()
              }
          }
      }
    impl()
  }

  /**
   * Returns the first queued message containing `"id":"<id>"`; earlier messages are dropped.
   *
   * NOTE(review): this is a plain substring check. It only matches a string-typed id written
   * with no whitespace around the colon — confirm that this is how the server echoes ids.
   *
   * @throws TimeoutException if no matching message arrives before `duration` elapses
   */
  final def waitForResponse(duration: FiniteDuration, id: Int): String = {
    val deadline = duration.fromNow
    @tailrec def impl(): String =
      lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match {
        case null =>
          throw new TimeoutException()
        case s =>
          val correctId = s.contains("\"id\":\"" + id + "\"")
          if (!correctId && !deadline.isOverdue) impl()
          else if (deadline.isOverdue)
            throw new TimeoutException()
          else s
      }
    impl()
  }

  /**
   * Consumes queued messages for up to `duration`.
   *
   * @return true if no message satisfying `f` arrived in that time; false as soon as one does
   */
  final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = {
    val deadline = duration.fromNow
    @tailrec
    def impl(): Boolean =
      lines.poll(deadline.timeLeft.toMillis, TimeUnit.MILLISECONDS) match {
        case null => true
        case s => if (!f(s)) impl() else false
      }
    impl()
  }
}