Merge pull request #4032 from eed3si9n/wip/servertest

improve server testing
This commit is contained in:
Dale Wijnand 2018-04-18 07:12:53 +02:00 committed by GitHub
commit dfff1ed928
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 215 additions and 198 deletions

View File

@ -510,10 +510,12 @@ lazy val sbtProj = (project in file("sbt"))
buildInfoKeys in Test := Seq[BuildInfoKey](
// WORKAROUND https://github.com/sbt/sbt-buildinfo/issues/117
BuildInfoKey.map((fullClasspath in Compile).taskValue) { case (ident, cp) => ident -> cp.files },
classDirectory in Compile,
classDirectory in Test,
),
connectInput in run in Test := true,
outputStrategy in run in Test := Some(StdoutOutput),
fork in Test := true,
Test / run / connectInput := true,
Test / run / outputStrategy := Some(StdoutOutput),
Test / run / fork := true,
)
.configure(addSbtCompilerBridge)

View File

@ -27,6 +27,6 @@ object TestUtil {
val mainClassesDir = buildinfo.TestBuildInfo.classDirectory
val testClassesDir = buildinfo.TestBuildInfo.test_classDirectory
val depsClasspath = buildinfo.TestBuildInfo.dependencyClasspath
mainClassesDir +: testClassesDir +: depsClasspath mkString ":"
mainClassesDir +: testClassesDir +: depsClasspath mkString java.io.File.pathSeparator
}
}

View File

@ -7,14 +7,34 @@
package sbt
import scala.util.Try
import sbt.util.LogExchange
import scala.annotation.tailrec
import buildinfo.TestBuildInfo
import xsbti._
object RunFromSourceMain {
private val sbtVersion = "1.0.3" // "dev"
private val scalaVersion = "2.12.4"
def fork(workingDirectory: File): Try[Unit] = {
val fo = ForkOptions()
.withOutputStrategy(OutputStrategy.StdoutOutput)
fork(fo, workingDirectory)
}
def fork(fo0: ForkOptions, workingDirectory: File): Try[Unit] = {
val fo = fo0
.withWorkingDirectory(workingDirectory)
implicit val runner = new ForkRun(fo)
val cp = {
TestBuildInfo.test_classDirectory +: TestBuildInfo.fullClasspath
}
val options = Vector(workingDirectory.toString)
val log = LogExchange.logger("RunFromSourceMain.fork", None, None)
Run.run("sbt.RunFromSourceMain", cp, options, log)
}
def main(args: Array[String]): Unit = args match {
case Array() => sys.error(s"Must specify working directory as the first argument")
case Array(wd, args @ _*) => run(file(wd), args)

View File

@ -1,193 +0,0 @@
/*
* sbt
* Copyright 2011 - 2017, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under BSD-3-Clause license (see LICENSE)
*/
package sbt
import org.scalatest._
import scala.concurrent._
import scala.annotation.tailrec
import java.io.{ InputStream, OutputStream }
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{ ThreadFactory, ThreadPoolExecutor }
import sbt.protocol.ClientSocket
class ServerSpec extends AsyncFlatSpec with Matchers {
import ServerSpec._
"server" should "start" in {
withBuildSocket("handshake") { (out, in, tkn) =>
writeLine(
"""{ "jsonrpc": "2.0", "id": 3, "method": "sbt/setting", "params": { "setting": "root/name" } }""",
out)
Thread.sleep(100)
assert(waitFor(in, 10) { s =>
s contains """"id":3"""
})
}
}
@tailrec
private[this] def waitFor(in: InputStream, num: Int)(f: String => Boolean): Boolean = {
if (num < 0) false
else
readFrame(in) match {
case Some(x) if f(x) => true
case _ =>
waitFor(in, num - 1)(f)
}
}
}
object ServerSpec {
private val serverTestBase: File = new File(".").getAbsoluteFile / "sbt" / "src" / "server-test"
private val nextThreadId = new AtomicInteger(1)
private val threadGroup = Thread.currentThread.getThreadGroup()
val readBuffer = new Array[Byte](4096)
var buffer: Vector[Byte] = Vector.empty
var bytesRead = 0
private val delimiter: Byte = '\n'.toByte
private val RetByte = '\r'.toByte
private val threadFactory = new ThreadFactory() {
override def newThread(runnable: Runnable): Thread = {
val thread =
new Thread(threadGroup,
runnable,
s"sbt-test-server-threads-${nextThreadId.getAndIncrement}")
// Do NOT setDaemon because then the code in TaskExit.scala in sbt will insta-kill
// the backgrounded process, at least for the case of the run task.
thread
}
}
private val executor = new ThreadPoolExecutor(
0, /* corePoolSize */
1, /* maxPoolSize, max # of servers */
2,
java.util.concurrent.TimeUnit.SECONDS,
/* keep alive unused threads this long (if corePoolSize < maxPoolSize) */
new java.util.concurrent.SynchronousQueue[Runnable](),
threadFactory
)
def backgroundRun(baseDir: File, args: Seq[String]): Unit = {
executor.execute(new Runnable {
def run(): Unit = {
RunFromSourceMain.run(baseDir, args)
}
})
}
def shutdown(): Unit = executor.shutdown()
def withBuildSocket(testBuild: String)(
f: (OutputStream, InputStream, Option[String]) => Future[Assertion]): Future[Assertion] = {
IO.withTemporaryDirectory { temp =>
IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
withBuildSocket(temp / testBuild)(f)
}
}
def sendJsonRpc(message: String, out: OutputStream): Unit = {
writeLine(s"""Content-Length: ${message.size + 2}""", out)
writeLine("", out)
writeLine(message, out)
}
def readFrame(in: InputStream): Option[String] = {
val l = contentLength(in)
readLine(in)
readLine(in)
readContentLength(in, l)
}
def contentLength(in: InputStream): Int = {
readLine(in) map { line =>
line.drop(16).toInt
} getOrElse (0)
}
def readLine(in: InputStream): Option[String] = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
}
val delimPos = buffer.indexOf(delimiter)
if (delimPos > 0) {
val chunk0 = buffer.take(delimPos)
buffer = buffer.drop(delimPos + 1)
// remove \r at the end of line.
val chunk1 = if (chunk0.lastOption contains RetByte) chunk0.dropRight(1) else chunk0
Some(new String(chunk1.toArray, "utf-8"))
} else None // no EOL yet, so skip this turn.
}
def readContentLength(in: InputStream, length: Int): Option[String] = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
}
if (length <= buffer.size) {
val chunk = buffer.take(length)
buffer = buffer.drop(length)
Some(new String(chunk.toArray, "utf-8"))
} else None // have not read enough yet, so skip this turn.
}
def writeLine(s: String, out: OutputStream): Unit = {
def writeEndLine(): Unit = {
val retByte: Byte = '\r'.toByte
val delimiter: Byte = '\n'.toByte
out.write(retByte.toInt)
out.write(delimiter.toInt)
out.flush
}
if (s != "") {
out.write(s.getBytes("UTF-8"))
}
writeEndLine
}
def withBuildSocket(baseDirectory: File)(
f: (OutputStream, InputStream, Option[String]) => Future[Assertion]): Future[Assertion] = {
backgroundRun(baseDirectory, Nil)
val portfile = baseDirectory / "project" / "target" / "active.json"
def waitForPortfile(n: Int): Unit =
if (portfile.exists) ()
else {
if (n <= 0) sys.error(s"Timeout. $portfile is not found.")
else {
Thread.sleep(1000)
waitForPortfile(n - 1)
}
}
waitForPortfile(10)
val (sk, tkn) = ClientSocket.socket(portfile)
val out = sk.getOutputStream
val in = sk.getInputStream
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""",
out)
try {
f(out, in, tkn)
} finally {
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "exit" } }""",
out)
shutdown()
}
}
}

View File

@ -0,0 +1,188 @@
/*
* sbt
* Copyright 2011 - 2017, Lightbend, Inc.
* Copyright 2008 - 2010, Mark Harrah
* Licensed under BSD-3-Clause license (see LICENSE)
*/
package testpkg
import org.scalatest._
import scala.concurrent._
import scala.annotation.tailrec
import sbt.protocol.ClientSocket
import TestServer.withTestServer
import java.io.File
import sbt.io.syntax._
import sbt.io.IO
import sbt.RunFromSourceMain
class ServerSpec extends AsyncFreeSpec with Matchers {
"server" - {
"should start" in withTestServer("handshake") { p =>
p.writeLine(
"""{ "jsonrpc": "2.0", "id": "3", "method": "sbt/setting", "params": { "setting": "root/name" } }""")
assert(p.waitForString(10) { s =>
s contains """"id":"3""""
})
}
"return number id when number id is sent" in withTestServer("handshake") { p =>
p.writeLine(
"""{ "jsonrpc": "2.0", "id": 3, "method": "sbt/setting", "params": { "setting": "root/name" } }""")
assert(p.waitForString(10) { s =>
s contains """"id":3"""
})
}
}
}
object TestServer {
private val serverTestBase: File = new File(".").getAbsoluteFile / "sbt" / "src" / "server-test"
def withTestServer(testBuild: String)(f: TestServer => Future[Assertion]): Future[Assertion] = {
IO.withTemporaryDirectory { temp =>
IO.copyDirectory(serverTestBase / testBuild, temp / testBuild)
withTestServer(temp / testBuild)(f)
}
}
def withTestServer(baseDirectory: File)(f: TestServer => Future[Assertion]): Future[Assertion] = {
val testServer = TestServer(baseDirectory)
try {
f(testServer)
} finally {
testServer.bye()
}
}
def hostLog(s: String): Unit = {
println(s"""[${scala.Console.MAGENTA}build-1${scala.Console.RESET}] $s""")
}
}
case class TestServer(baseDirectory: File) {
import TestServer.hostLog
val readBuffer = new Array[Byte](4096)
var buffer: Vector[Byte] = Vector.empty
var bytesRead = 0
private val delimiter: Byte = '\n'.toByte
private val RetByte = '\r'.toByte
hostLog("fork to a new sbt instance")
import scala.concurrent.ExecutionContext.Implicits.global
Future {
RunFromSourceMain.fork(baseDirectory)
()
}
lazy val portfile = baseDirectory / "project" / "target" / "active.json"
hostLog("wait 30s until the server is ready to respond")
def waitForPortfile(n: Int): Unit =
if (portfile.exists) ()
else {
if (n <= 0) sys.error(s"Timeout. $portfile is not found.")
else {
Thread.sleep(1000)
waitForPortfile(n - 1)
}
}
waitForPortfile(30)
// make connection to the socket described in the portfile
val (sk, tkn) = ClientSocket.socket(portfile)
val out = sk.getOutputStream
val in = sk.getInputStream
// initiate handshake
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "initializationOptions": { } } }""")
def test(f: TestServer => Future[Assertion]): Future[Assertion] = {
f(this)
}
def bye(): Unit = {
hostLog("sending exit")
sendJsonRpc(
"""{ "jsonrpc": "2.0", "id": 9, "method": "sbt/exec", "params": { "commandLine": "exit" } }""")
}
def sendJsonRpc(message: String): Unit = {
writeLine(s"""Content-Length: ${message.size + 2}""")
writeLine("")
writeLine(message)
}
def writeLine(s: String): Unit = {
def writeEndLine(): Unit = {
val retByte: Byte = '\r'.toByte
val delimiter: Byte = '\n'.toByte
out.write(retByte.toInt)
out.write(delimiter.toInt)
out.flush
}
if (s != "") {
out.write(s.getBytes("UTF-8"))
}
writeEndLine
}
def readFrame: Option[String] = {
def getContentLength: Int = {
readLine map { line =>
line.drop(16).toInt
} getOrElse (0)
}
val l = getContentLength
readLine
readLine
readContentLength(l)
}
@tailrec
final def waitForString(num: Int)(f: String => Boolean): Boolean = {
if (num < 0) false
else
readFrame match {
case Some(x) if f(x) => true
case _ =>
waitForString(num - 1)(f)
}
}
def readLine: Option[String] = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
}
val delimPos = buffer.indexOf(delimiter)
if (delimPos > 0) {
val chunk0 = buffer.take(delimPos)
buffer = buffer.drop(delimPos + 1)
// remove \r at the end of line.
val chunk1 = if (chunk0.lastOption contains RetByte) chunk0.dropRight(1) else chunk0
Some(new String(chunk1.toArray, "utf-8"))
} else None // no EOL yet, so skip this turn.
}
def readContentLength(length: Int): Option[String] = {
if (buffer.isEmpty) {
val bytesRead = in.read(readBuffer)
if (bytesRead > 0) {
buffer = buffer ++ readBuffer.toVector.take(bytesRead)
}
}
if (length <= buffer.size) {
val chunk = buffer.take(length)
buffer = buffer.drop(length)
Some(new String(chunk.toArray, "utf-8"))
} else None // have not read enough yet, so skip this turn.
}
}