mirror of https://github.com/sbt/sbt.git
335 lines
10 KiB
Scala
335 lines
10 KiB
Scala
/*
|
|
* sbt
|
|
* Copyright 2011 - 2018, Lightbend, Inc.
|
|
* Copyright 2008 - 2010, Mark Harrah
|
|
* Licensed under Apache License 2.0 (see LICENSE)
|
|
*/
|
|
|
|
package testpkg
|
|
|
|
import java.io.{ File, IOException }
|
|
import java.nio.file.Path
|
|
import java.util.concurrent.TimeoutException
|
|
|
|
import verify._
|
|
import sbt.RunFromSourceMain
|
|
import sbt.io.IO
|
|
import sbt.io.syntax._
|
|
import sbt.protocol.ClientSocket
|
|
|
|
import scala.annotation.tailrec
|
|
import scala.concurrent._
|
|
import scala.concurrent.duration._
|
|
import scala.util.{ Success, Try }
|
|
|
|
trait AbstractServerTest extends TestSuite[Unit] {
|
|
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
|
|
private var temp: File = _
|
|
var svr: TestServer = _
|
|
def testDirectory: String
|
|
def testPath: Path = temp.toPath.resolve(testDirectory)
|
|
|
|
private val targetDir: File = {
|
|
val p0 = new File("..").getAbsoluteFile.getCanonicalFile / "target"
|
|
val p1 = new File("target").getAbsoluteFile
|
|
if (p0.exists) p0
|
|
else p1
|
|
}
|
|
|
|
override def setupSuite(): Unit = {
|
|
temp = targetDir / "test-server" / testDirectory
|
|
if (temp.exists) {
|
|
IO.delete(temp)
|
|
}
|
|
val classpath = sys.props.get("sbt.server.classpath") match {
|
|
case Some(s: String) => s.split(java.io.File.pathSeparator).map(file)
|
|
case _ => throw new IllegalStateException("No server classpath was specified.")
|
|
}
|
|
val sbtVersion = sys.props.get("sbt.server.version") match {
|
|
case Some(v: String) => v
|
|
case _ => throw new IllegalStateException("No server version was specified.")
|
|
}
|
|
val scalaVersion = sys.props.get("sbt.server.scala.version") match {
|
|
case Some(v: String) => v
|
|
case _ => throw new IllegalStateException("No server scala version was specified.")
|
|
}
|
|
svr = TestServer.get(testDirectory, scalaVersion, sbtVersion, classpath, temp)
|
|
}
|
|
override def tearDownSuite(): Unit = {
|
|
svr.bye()
|
|
svr = null
|
|
IO.delete(temp)
|
|
}
|
|
override def setup(): Unit = ()
|
|
override def tearDown(env: Unit): Unit = ()
|
|
}
|
|
|
|
object TestServer {
|
|
// forking affects this
|
|
private val serverTestBase: File = {
|
|
val p0 = new File(".").getAbsoluteFile / "server-test" / "src" / "server-test"
|
|
val p1 = new File(".").getAbsoluteFile / "src" / "server-test"
|
|
if (p0.exists) p0
|
|
else p1
|
|
}
|
|
|
|
def get(
|
|
testBuild: String,
|
|
scalaVersion: String,
|
|
sbtVersion: String,
|
|
classpath: Seq[File],
|
|
temp: File
|
|
): TestServer = {
|
|
println(s"Starting test server $testBuild")
|
|
IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
|
|
|
|
// Each test server instance will be executed in a Thread pool separated from the tests
|
|
val testServer = TestServer(temp / testBuild, scalaVersion, sbtVersion, classpath)
|
|
// checking last log message after initialization
|
|
// if something goes wrong here the communication streams are corrupted, restarting
|
|
val init =
|
|
Try {
|
|
testServer.waitForString(30.seconds) { s =>
|
|
println(s)
|
|
s contains """"message":"Done""""
|
|
}
|
|
}
|
|
init.get
|
|
testServer
|
|
}
|
|
|
|
def withTestServer(
|
|
testBuild: String
|
|
)(f: TestServer => Future[Unit]): Future[Unit] = {
|
|
println(s"Starting test")
|
|
IO.withTemporaryDirectory { temp =>
|
|
IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
|
|
withTestServer(testBuild, temp / testBuild)(f)
|
|
}
|
|
}
|
|
|
|
def withTestServer(testBuild: String, baseDirectory: File)(
|
|
f: TestServer => Future[Unit]
|
|
): Future[Unit] = {
|
|
val classpath = sys.props.get("sbt.server.classpath") match {
|
|
case Some(s: String) => s.split(java.io.File.pathSeparator).map(file)
|
|
case _ => throw new IllegalStateException("No server classpath was specified.")
|
|
}
|
|
val sbtVersion = sys.props.get("sbt.server.version") match {
|
|
case Some(v: String) => v
|
|
case _ => throw new IllegalStateException("No server version was specified.")
|
|
}
|
|
val scalaVersion = sys.props.get("sbt.server.scala.version") match {
|
|
case Some(v: String) => v
|
|
case _ => throw new IllegalStateException("No server scala version was specified.")
|
|
}
|
|
// Each test server instance will be executed in a Thread pool separated from the tests
|
|
val testServer = TestServer(baseDirectory, scalaVersion, sbtVersion, classpath)
|
|
// checking last log message after initialization
|
|
// if something goes wrong here the communication streams are corrupted, restarting
|
|
val init =
|
|
Try {
|
|
testServer.waitForString(30.seconds) { s =>
|
|
if (s.nonEmpty) println(s)
|
|
s contains """"message":"Done""""
|
|
}
|
|
}
|
|
|
|
init match {
|
|
case Success(_) =>
|
|
try {
|
|
f(testServer)
|
|
} finally {
|
|
try {
|
|
testServer.bye()
|
|
} finally {}
|
|
}
|
|
case _ =>
|
|
try {
|
|
testServer.bye()
|
|
} finally {}
|
|
hostLog("Server started but not connected properly... restarting...")
|
|
withTestServer(testBuild)(f)
|
|
}
|
|
}
|
|
|
|
def hostLog(s: String): Unit = {
|
|
println(s"""[${scala.Console.MAGENTA}build-1${scala.Console.RESET}] $s""")
|
|
}
|
|
}
|
|
|
|
case class TestServer(
|
|
baseDirectory: File,
|
|
scalaVersion: String,
|
|
sbtVersion: String,
|
|
classpath: Seq[File]
|
|
) {
|
|
import scala.concurrent.ExecutionContext.Implicits._
|
|
import TestServer.hostLog
|
|
|
|
val readBuffer = new Array[Byte](40960)
|
|
var buffer: Vector[Byte] = Vector.empty
|
|
var bytesRead = 0
|
|
private val delimiter: Byte = '\n'.toByte
|
|
private val RetByte = '\r'.toByte
|
|
|
|
hostLog("fork to a new sbt instance")
|
|
val process = RunFromSourceMain.fork(baseDirectory, scalaVersion, sbtVersion, classpath)
|
|
|
|
lazy val portfile = baseDirectory / "project" / "target" / "active.json"
|
|
|
|
def portfileIsEmpty(): Boolean =
|
|
try IO.read(portfile).isEmpty
|
|
catch { case _: IOException => true }
|
|
def waitForPortfile(duration: FiniteDuration): Unit = {
|
|
val deadline = duration.fromNow
|
|
var nextLog = 10.seconds.fromNow
|
|
while (portfileIsEmpty && !deadline.isOverdue && process.isAlive) {
|
|
if (nextLog.isOverdue) {
|
|
hostLog("waiting for the server...")
|
|
nextLog = 10.seconds.fromNow
|
|
}
|
|
Thread.sleep(10) // Don't spam the portfile
|
|
}
|
|
if (deadline.isOverdue) sys.error(s"Timeout. $portfile is not found.")
|
|
if (!process.isAlive) sys.error(s"Server unexpectedly terminated.")
|
|
}
|
|
private val waitDuration: FiniteDuration = 120.seconds
|
|
hostLog(s"wait $waitDuration until the server is ready to respond")
|
|
waitForPortfile(90.seconds)
|
|
|
|
// make connection to the socket described in the portfile
|
|
var (sk, _) = ClientSocket.socket(portfile)
|
|
var out = sk.getOutputStream
|
|
var in = sk.getInputStream
|
|
|
|
// initiate handshake
|
|
sendJsonRpc(
|
|
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
|
|
)
|
|
|
|
def resetConnection() = {
|
|
Option(sk).foreach(_.close())
|
|
sk = ClientSocket.socket(portfile)._1
|
|
out = sk.getOutputStream
|
|
in = sk.getInputStream
|
|
|
|
sendJsonRpc(
|
|
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }"""
|
|
)
|
|
}
|
|
|
|
def test(f: TestServer => Future[Assertion]): Future[Assertion] = {
|
|
f(this)
|
|
}
|
|
|
|
def bye(): Unit = {
|
|
hostLog("sending exit")
|
|
sendJsonRpc(
|
|
"""{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "shutdown" } }"""
|
|
)
|
|
val deadline = 10.seconds.fromNow
|
|
while (!deadline.isOverdue && process.isAlive) {
|
|
Thread.sleep(10)
|
|
}
|
|
// We gave the server a chance to exit but it didn't within a reasonable time frame.
|
|
if (deadline.isOverdue) process.destroy()
|
|
}
|
|
|
|
def sendJsonRpc(message: String): Unit = {
|
|
writeLine(s"""Content-Length: ${message.size + 2}""")
|
|
writeLine("")
|
|
writeLine(message)
|
|
}
|
|
|
|
private def writeLine(s: String): Unit = {
|
|
def writeEndLine(): Unit = {
|
|
val retByte: Byte = '\r'.toByte
|
|
val delimiter: Byte = '\n'.toByte
|
|
out.write(retByte.toInt)
|
|
out.write(delimiter.toInt)
|
|
out.flush
|
|
}
|
|
|
|
if (s != "") {
|
|
out.write(s.getBytes("UTF-8"))
|
|
}
|
|
writeEndLine
|
|
}
|
|
|
|
def readFrame: Future[Option[String]] = Future {
|
|
def getContentLength: Int = {
|
|
readLine map { line =>
|
|
line.drop(16).toInt
|
|
} getOrElse (0)
|
|
}
|
|
val l = getContentLength
|
|
readLine
|
|
readLine
|
|
readContentLength(l)
|
|
}
|
|
|
|
final def waitForString(duration: FiniteDuration)(f: String => Boolean): Boolean = {
|
|
val deadline = duration.fromNow
|
|
@tailrec def impl(): Boolean = {
|
|
val res = try {
|
|
Await.result(readFrame, deadline.timeLeft).fold(false)(f)
|
|
} catch {
|
|
case _: TimeoutException =>
|
|
resetConnection() // create a new connection to invalidate the running readFrame future
|
|
false
|
|
}
|
|
if (!res && !deadline.isOverdue) impl() else !deadline.isOverdue()
|
|
}
|
|
impl()
|
|
}
|
|
|
|
final def neverReceive(duration: FiniteDuration)(f: String => Boolean): Boolean = {
|
|
val deadline = duration.fromNow
|
|
@tailrec
|
|
def impl(): Boolean = {
|
|
val res = try {
|
|
Await.result(readFrame, deadline.timeLeft).fold(true)(s => !f(s))
|
|
} catch {
|
|
case _: TimeoutException =>
|
|
resetConnection() // create a new connection to invalidate the running readFrame future
|
|
true
|
|
}
|
|
if (res && !deadline.isOverdue) impl else res || !deadline.isOverdue
|
|
}
|
|
impl()
|
|
}
|
|
|
|
def readLine: Option[String] = {
|
|
if (buffer.isEmpty) {
|
|
val bytesRead = in.read(readBuffer)
|
|
if (bytesRead > 0) {
|
|
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
|
|
}
|
|
}
|
|
val delimPos = buffer.indexOf(delimiter)
|
|
if (delimPos > 0) {
|
|
val chunk0 = buffer.take(delimPos)
|
|
buffer = buffer.drop(delimPos + 1)
|
|
// remove \r at the end of line.
|
|
val chunk1 = if (chunk0.lastOption contains RetByte) chunk0.dropRight(1) else chunk0
|
|
Some(new String(chunk1.toArray, "utf-8"))
|
|
} else None // no EOL yet, so skip this turn.
|
|
}
|
|
|
|
def readContentLength(length: Int): Option[String] = {
|
|
if (buffer.isEmpty) {
|
|
val bytesRead = in.read(readBuffer)
|
|
if (bytesRead > 0) {
|
|
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
|
|
}
|
|
}
|
|
if (length <= buffer.size) {
|
|
val chunk = buffer.take(length)
|
|
buffer = buffer.drop(length)
|
|
Some(new String(chunk.toArray, "utf-8"))
|
|
} else None // have not read enough yet, so skip this turn.
|
|
}
|
|
|
|
}
|