Add tab completion support to thin client

The sbtc client can provide a ux very similar to using the sbt shell
when combined with tab completions. In fact, since some shells have a
better tab completion engine than that provided by jilne2, the
experience can be even better. To make this work, we add another entry
point to the thin client that is capable of generating completions for
an input string. It queries sbt for the completions and prints the
result to stdout, where they are consumed by the shell and fed into its
completion engine.

In addition to providing tab completions, if there is no server running
or if the user is completing `runMain`, `testOnly` or `testQuick`, the
thin client will prompt the user to ask if they would like to start an
sbt server or if they would like to compile to generate the main class
or test names. Neither powershell nor zsh support forwarding input to
the tab completion script. Zsh will print output to stderr so we
opportunistically start the server or complete the test class names.
Powershell does not print completion output at all, so we do not start a
server or fill completions in that case*. For fish and bash, we prompt
the user that they can take these actions so that they can avoid the
expensive operation if desired.

* Powershell users can set the environment variable SBTC_AUTO_COMPLETE
if they want to automatically start a server of compile for run and test
names. No output will be displayed so there can be a long latency
between pressing <tab> and seeing completion results if this variable is
set.
This commit is contained in:
Ethan Atkins 2020-06-24 18:31:18 -07:00
parent ea823f1051
commit d5cbc43075
13 changed files with 373 additions and 51 deletions

5
client/completions/_sbtc Executable file
View File

@ -0,0 +1,5 @@
#compdef sbtc
COMPLETE="--completions=${words[@]}"
COMPLETIONS=($(sbtc --no-tab ${COMPLETE}))
_alternative 'arguments:custom arg:($COMPLETIONS)'

View File

@ -0,0 +1,7 @@
#!/usr/bin/env bash
_do_sbtc_completions() {
COMPREPLY=($(sbtc "--completions=${COMP_LINE}"))
}
complete -F _do_sbtc_completions sbtc

4
client/completions/sbtc.fish Executable file
View File

@ -0,0 +1,4 @@
function __sbtcomp
sbtc --completions="$argv"
end
complete --command sbtc -f --arguments '(__sbtcomp (commandline -cp))'

View File

@ -0,0 +1,10 @@
$scriptblock = {
param($commandName, $line, $position)
$len = $line.ToString().length
$spaces = " " * ($position - $len)
$arg="--completions=$line$spaces"
& 'sbtc.exe' @('--no-tab', '--no-stderr', $arg)
}
Set-Alias -Name sbtc -Value sbtc.exe
Register-ArgumentCompleter -CommandName sbtc.exe -ScriptBlock $scriptBlock
Register-ArgumentCompleter -CommandName sbtc -ScriptBlock $scriptBlock

View File

@ -91,13 +91,17 @@ class NetworkClient(
private val connectionHolder = new AtomicReference[ServerConnection] private val connectionHolder = new AtomicReference[ServerConnection]
private val batchMode = new AtomicBoolean(false) private val batchMode = new AtomicBoolean(false)
private val interactiveThread = new AtomicReference[Thread](null) private val interactiveThread = new AtomicReference[Thread](null)
private lazy val noTab = arguments.completionArguments.contains("--no-tab")
private lazy val noStdErr = arguments.completionArguments.contains("--no-stderr") &&
System.getenv("SBTC_AUTO_COMPLETE") == null
private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI) private def mkSocket(file: File): (Socket, Option[String]) = ClientSocket.socket(file, useJNI)
private def portfile = arguments.baseDirectory / "project" / "target" / "active.json" private def portfile = arguments.baseDirectory / "project" / "target" / "active.json"
def connection: ServerConnection = connectionHolder.synchronized { def connection: ServerConnection = connectionHolder.synchronized {
connectionHolder.get match { connectionHolder.get match {
case null => init(true) case null => init(prompt = false, retry = true)
case c => c case c => c
} }
} }
@ -113,17 +117,51 @@ class NetworkClient(
private class ConnectionRefusedException(t: Throwable) extends Throwable(t) private class ConnectionRefusedException(t: Throwable) extends Throwable(t)
// Open server connection based on the portfile // Open server connection based on the portfile
def init(retry: Boolean): ServerConnection = def init(prompt: Boolean, retry: Boolean): ServerConnection =
try { try {
if (!portfile.exists) { if (!portfile.exists) {
forkServer(portfile, log = true) if (prompt) {
val msg = if (noTab) "" else "No sbt server is running. Press <tab> to start one..."
errorStream.print(s"\n$msg")
if (noStdErr) System.exit(0)
else if (noTab) forkServer(portfile, log = true)
else {
stdinBytes.take match {
case 9 =>
errorStream.println("\nStarting server...")
forkServer(portfile, !prompt)
case _ => System.exit(0)
}
}
} else {
forkServer(portfile, log = true)
}
} }
val (sk, tkn) = val (sk, tkn) =
try mkSocket(portfile) try mkSocket(portfile)
catch { case e: IOException => throw new ConnectionRefusedException(e) } catch { case e: IOException => throw new ConnectionRefusedException(e) }
val conn = new ServerConnection(sk) { val conn = new ServerConnection(sk) {
override def onNotification(msg: JsonRpcNotificationMessage): Unit = override def onNotification(msg: JsonRpcNotificationMessage): Unit = {
self.onNotification(msg) msg.method match {
case "shutdown" =>
val log = msg.params match {
case Some(jvalue) => Converter.fromJson[Boolean](jvalue).getOrElse(true)
case _ => false
}
if (running.compareAndSet(true, false) && log) {
if (!arguments.commandArguments.contains("shutdown")) {
if (Terminal.console.getLastLine.fold(true)(_.nonEmpty)) errorStream.println()
console.appendLog(Level.Error, "sbt server disconnected")
exitClean.set(false)
}
}
stdinBytes.offer(-1)
Option(inputThread.get).foreach(_.close())
Option(interactiveThread.get).foreach(_.interrupt)
case "readInput" =>
case _ => self.onNotification(msg)
}
}
override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg) override def onRequest(msg: JsonRpcRequestMessage): Unit = self.onRequest(msg)
override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg) override def onResponse(msg: JsonRpcResponseMessage): Unit = self.onResponse(msg)
override def onShutdown(): Unit = { override def onShutdown(): Unit = {
@ -140,7 +178,7 @@ class NetworkClient(
conn conn
} catch { } catch {
case e: ConnectionRefusedException if retry => case e: ConnectionRefusedException if retry =>
if (Files.deleteIfExists(portfile.toPath)) init(retry = false) if (Files.deleteIfExists(portfile.toPath)) init(prompt, retry = false)
else throw e else throw e
} }
@ -275,6 +313,14 @@ class NetworkClient(
.fromJson[Vector[String]](i.value) .fromJson[Vector[String]](i.value)
.getOrElse(Vector.empty[String]) .getOrElse(Vector.empty[String])
) )
else if (i.field == "cachedTestNames")
resp.withCachedTestNames(
Converter.fromJson[Boolean](i.value).getOrElse(true)
)
else if (i.field == "cachedMainClassNames")
resp.withCachedMainClassNames(
Converter.fromJson[Boolean](i.value).getOrElse(true)
)
else resp else resp
} }
case _ => CompletionResponse(Vector.empty[String]) case _ => CompletionResponse(Vector.empty[String])
@ -391,9 +437,9 @@ class NetworkClient(
} }
} }
def connect(log: Boolean): Unit = { def connect(log: Boolean, prompt: Boolean): Unit = {
if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
init(retry = true) init(prompt, retry = true)
() ()
} }
@ -450,6 +496,68 @@ class NetworkClient(
sendAndWait(cmd, None) sendAndWait(cmd, None)
} }
def getCompletions(query: String): Seq[String] = {
connect(log = false, prompt = true)
val quoteCount = query.foldLeft(0) {
case (count, '"') => count + 1
case (count, _) => count
}
val inQuote = quoteCount % 2 != 0
val (rawPrefix, prefix, rawSuffix, suffix) = if (quoteCount > 0) {
query.lastIndexOf('"') match {
case -1 => (query, query, None, None) // shouldn't happen
case i =>
val rawPrefix = query.substring(0, i)
val prefix = rawPrefix.replaceAllLiterally("\"", "").replaceAllLiterally("\\;", ";")
val rawSuffix = query.substring(i).replaceAllLiterally("\\;", ";")
val suffix = if (rawSuffix.length > 1) rawSuffix.substring(1) else ""
(rawPrefix, prefix, Some(rawSuffix), Some(suffix))
}
} else (query, query.replaceAllLiterally("\\;", ";"), None, None)
val tailSpace = query.endsWith(" ") || query.endsWith("\"")
val sanitizedQuery = suffix.foldLeft(prefix) { _ + _ }
def getCompletions(query: String, sendCommand: Boolean): Seq[String] = {
val result = new LinkedBlockingQueue[CompletionResponse]()
val json = s"""{"query":"$query","level":1}"""
val execId = sendJson("sbt/completion", json)
pendingCompletions.put(execId, result.put)
val response = result.take
def fillCompletions(label: String, regex: String, command: String): Seq[String] = {
def updateCompletions(): Seq[String] = {
errorStream.println()
sendJson(attach, s"""{"interactive": false}""")
sendAndWait(query.replaceAll(regex + ".*", command).trim, None)
getCompletions(query, false)
}
if (noStdErr) Nil
else if (noTab) updateCompletions()
else {
errorStream.print(s"\nNo cached $label names found. Press '<tab>' to compile: ")
stdinBytes.take match {
case 9 =>
updateCompletions()
case _ => Nil
}
}
}
val testNameCompletions =
if (!response.cachedTestNames.getOrElse(true) && sendCommand)
fillCompletions("test", "test(Only|Quick)", "definedTestNames")
else Nil
val classNameCompletions =
if (!response.cachedMainClassNames.getOrElse(true) && sendCommand)
fillCompletions("main class", "runMain", "discoveredMainClasses")
else Nil
val completions = response.items
testNameCompletions ++ classNameCompletions ++ completions
}
getCompletions(sanitizedQuery, true) collect {
case c if inQuote => c
case c if tailSpace && c.contains(" ") => c.replaceAllLiterally(prefix, "")
case c if !tailSpace => c.split(" ").last
}
}
private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = { private def sendAndWait(cmd: String, limit: Option[Deadline]): Int = {
val queue = sendExecCommand(cmd) val queue = sendExecCommand(cmd)
var result: Integer = null var result: Integer = null
@ -610,35 +718,48 @@ object NetworkClient {
val baseDirectory: File, val baseDirectory: File,
val sbtArguments: Seq[String], val sbtArguments: Seq[String],
val commandArguments: Seq[String], val commandArguments: Seq[String],
val completionArguments: Seq[String],
val sbtScript: String, val sbtScript: String,
) { ) {
def withBaseDirectory(file: File): Arguments = def withBaseDirectory(file: File): Arguments =
new Arguments(file, sbtArguments, commandArguments, sbtScript) new Arguments(file, sbtArguments, commandArguments, completionArguments, sbtScript)
} }
private[client] val completions = "--completions"
private[client] val noTab = "--no-tab"
private[client] val noStdErr = "--no-stderr"
private[client] val sbtBase = "--sbt-base-directory"
private[client] def parseArgs(args: Array[String]): Arguments = { private[client] def parseArgs(args: Array[String]): Arguments = {
var i = 0
var sbtScript = if (Properties.isWin) "sbt.cmd" else "sbt" var sbtScript = if (Properties.isWin) "sbt.cmd" else "sbt"
val commandArgs = new mutable.ArrayBuffer[String] val commandArgs = new mutable.ArrayBuffer[String]
val sbtArguments = new mutable.ArrayBuffer[String] val sbtArguments = new mutable.ArrayBuffer[String]
val completionArguments = new mutable.ArrayBuffer[String]
val SysProp = "-D([^=]+)=(.*)".r val SysProp = "-D([^=]+)=(.*)".r
val sanitized = args.flatMap { val sanitized = args.flatMap {
case a if a.startsWith("\"") => Array(a) case a if a.startsWith("\"") => Array(a)
case a => a.split(" ") case a => a.split(" ")
} }
var foundCompletions = false
var i = 0
while (i < sanitized.length) { while (i < sanitized.length) {
sanitized(i) match { sanitized(i) match {
case a if foundCompletions => completionArguments += a
case a if a == noStdErr || a == noTab || a.startsWith(completions) =>
foundCompletions = true
completionArguments += a
case a if a.startsWith("--sbt-script=") => case a if a.startsWith("--sbt-script=") =>
sbtScript = a.split("--sbt-script=").lastOption.getOrElse(sbtScript) sbtScript = a.split("--sbt-script=").lastOption.getOrElse(sbtScript)
case a if !a.startsWith("-") => commandArgs += a case a if !a.startsWith("-") => commandArgs += a
case a if !a.startsWith("-") => commandArgs += a
case a @ SysProp(key, value) => case a @ SysProp(key, value) =>
System.setProperty(key, value) System.setProperty(key, value)
sbtArguments += a sbtArguments += a
case a => case a if !foundCompletions =>
sbtArguments += a sbtArguments += a
} }
i += 1 i += 1
} }
new Arguments(new File("").getCanonicalFile, sbtArguments, commandArgs, sbtScript) val base = new File("").getCanonicalFile
new Arguments(base, sbtArguments, commandArgs, completionArguments, sbtScript)
} }
def client( def client(
@ -658,7 +779,7 @@ object NetworkClient {
useJNI, useJNI,
) )
try { try {
client.connect(log = true) client.connect(log = true, prompt = false)
client.run() client.run()
} catch { case _: Exception => 1 } finally client.close() } catch { case _: Exception => 1 } finally client.close()
} }
@ -675,31 +796,73 @@ object NetworkClient {
inputStream, inputStream,
errorStream, errorStream,
printStream, printStream,
useJNI useJNI,
) )
def main(args: Array[String]): Unit = { def main(args: Array[String]): Unit = {
val (jnaArg, restOfArgs) = args.partition(_ == "--jna") val (jnaArg, restOfArgs) = args.partition(_ == "--jna")
val useJNI = jnaArg.isEmpty val useJNI = jnaArg.isEmpty
val hook = new Thread(() => { val base = new File("").getCanonicalFile
System.out.print(ConsoleAppender.ClearScreenAfterCursor) if (restOfArgs.exists(_.startsWith(NetworkClient.completions)))
System.out.flush() System.exit(complete(base, restOfArgs, useJNI, System.in, System.out))
}) else {
Runtime.getRuntime.addShutdownHook(hook) val hook = new Thread(() => {
System.exit(Terminal.withStreams { System.out.print(ConsoleAppender.ClearScreenAfterCursor)
val base = new File("").getCanonicalFile() System.out.flush()
try client(base, restOfArgs, System.in, System.err, System.out, useJNI) })
finally { Runtime.getRuntime.addShutdownHook(hook)
Runtime.getRuntime.removeShutdownHook(hook) System.exit(Terminal.withStreams {
hook.run() try client(base, restOfArgs, System.in, System.err, System.out, useJNI)
} finally {
}) Runtime.getRuntime.removeShutdownHook(hook)
hook.run()
}
})
}
}
def complete(
baseDirectory: File,
args: Array[String],
useJNI: Boolean,
in: InputStream,
out: PrintStream
): Int = {
val cmd: String = args.find(_.startsWith(NetworkClient.completions)) match {
case Some(c) =>
c.split('=').lastOption match {
case Some(query) =>
query.indexOf(" ") match {
case -1 => throw new IllegalArgumentException(query)
case i => query.substring(i + 1)
}
case _ => throw new IllegalArgumentException(c)
}
case _ => throw new IllegalStateException("should be unreachable")
}
val quiet = args.exists(_ == "--quiet")
val errorStream = if (quiet) new PrintStream(_ => {}, false) else System.err
val sbtArgs = args.takeWhile(!_.startsWith(NetworkClient.completions))
val arguments = NetworkClient.parseArgs(sbtArgs)
val noTab = args.contains("--no-tab")
val client =
simpleClient(
arguments.withBaseDirectory(baseDirectory),
inputStream = in,
errorStream = errorStream,
printStream = errorStream,
useJNI = useJNI,
)
try {
val results = client.getCompletions(cmd)
out.println(results.sorted.distinct mkString "\n")
0
} catch { case _: Exception => 1 } finally client.close()
} }
def run(configuration: xsbti.AppConfiguration, arguments: List[String]): Int = def run(configuration: xsbti.AppConfiguration, arguments: List[String]): Int =
try { try {
val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) val client = new NetworkClient(configuration, parseArgs(arguments.toArray))
try { try {
client.connect(log = true) client.connect(log = true, prompt = false)
client.run() client.run()
} catch { case _: Throwable => 1 } finally client.close() } catch { case _: Throwable => 1 } finally client.close()
} catch { } catch {

View File

@ -411,19 +411,47 @@ final class NetworkChannel(
try { try {
Option(EvaluateTask.lastEvaluatedState.get) match { Option(EvaluateTask.lastEvaluatedState.get) match {
case Some(sstate) => case Some(sstate) =>
val completionItems = import sbt.protocol.codec.JsonProtocol._
def completionItems(s: State) = {
Parser Parser
.completions(sstate.combinedParser, cp.query, 9) .completions(s.combinedParser, cp.query, cp.level.getOrElse(9))
.get .get
.flatMap { c => .flatMap { c =>
if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
else None else None
} }
.map(c => cp.query + c) .map(c => cp.query + c)
import sbt.protocol.codec.JsonProtocol._ }
val (items, cachedMainClassNames, cachedTestNames) = StandardMain.exchange.withState {
s =>
val scopedKeyParser: Parser[Seq[Def.ScopedKey[_]]] =
Act.aggregatedKeyParser(s) <~ Parsers.any.*
Parser.parse(cp.query, scopedKeyParser) match {
case Right(keys) =>
val testKeys =
keys.filter(k => k.key.label == "testOnly" || k.key.label == "testQuick")
val (testState, cachedTestNames) = testKeys.foldLeft((s, true)) {
case ((st, allCached), k) =>
SessionVar.loadAndSet(sbt.Keys.definedTestNames in k.scope, st, true) match {
case (nst, d) => (nst, allCached && d.isDefined)
}
}
val runKeys = keys.filter(_.key.label == "runMain")
val (runState, cachedMainClassNames) = runKeys.foldLeft((testState, true)) {
case ((st, allCached), k) =>
SessionVar.loadAndSet(sbt.Keys.discoveredMainClasses in k.scope, st, true) match {
case (nst, d) => (nst, allCached && d.isDefined)
}
}
(completionItems(runState), cachedMainClassNames, cachedTestNames)
case _ => (completionItems(s), true, true)
}
}
respondResult( respondResult(
CompletionResponse( CompletionResponse(
items = completionItems.toVector items = items.toVector,
cachedMainClassNames = cachedMainClassNames,
cachedTestNames = cachedTestNames
), ),
execId execId
) )

View File

@ -5,28 +5,37 @@
// DO NOT EDIT MANUALLY // DO NOT EDIT MANUALLY
package sbt.protocol package sbt.protocol
final class CompletionParams private ( final class CompletionParams private (
val query: String) extends Serializable { val query: String,
val level: Option[Int]) extends Serializable {
private def this(query: String) = this(query, None)
override def equals(o: Any): Boolean = o match { override def equals(o: Any): Boolean = o match {
case x: CompletionParams => (this.query == x.query) case x: CompletionParams => (this.query == x.query) && (this.level == x.level)
case _ => false case _ => false
} }
override def hashCode: Int = { override def hashCode: Int = {
37 * (37 * (17 + "sbt.protocol.CompletionParams".##) + query.##) 37 * (37 * (37 * (17 + "sbt.protocol.CompletionParams".##) + query.##) + level.##)
} }
override def toString: String = { override def toString: String = {
"CompletionParams(" + query + ")" "CompletionParams(" + query + ", " + level + ")"
} }
private[this] def copy(query: String = query): CompletionParams = { private[this] def copy(query: String = query, level: Option[Int] = level): CompletionParams = {
new CompletionParams(query) new CompletionParams(query, level)
} }
def withQuery(query: String): CompletionParams = { def withQuery(query: String): CompletionParams = {
copy(query = query) copy(query = query)
} }
def withLevel(level: Option[Int]): CompletionParams = {
copy(level = level)
}
def withLevel(level: Int): CompletionParams = {
copy(level = Option(level))
}
} }
object CompletionParams { object CompletionParams {
def apply(query: String): CompletionParams = new CompletionParams(query) def apply(query: String): CompletionParams = new CompletionParams(query)
def apply(query: String, level: Option[Int]): CompletionParams = new CompletionParams(query, level)
def apply(query: String, level: Int): CompletionParams = new CompletionParams(query, Option(level))
} }

View File

@ -5,28 +5,44 @@
// DO NOT EDIT MANUALLY // DO NOT EDIT MANUALLY
package sbt.protocol package sbt.protocol
final class CompletionResponse private ( final class CompletionResponse private (
val items: Vector[String]) extends Serializable { val items: Vector[String],
val cachedMainClassNames: Option[Boolean],
val cachedTestNames: Option[Boolean]) extends Serializable {
private def this(items: Vector[String]) = this(items, None, None)
override def equals(o: Any): Boolean = o match { override def equals(o: Any): Boolean = o match {
case x: CompletionResponse => (this.items == x.items) case x: CompletionResponse => (this.items == x.items) && (this.cachedMainClassNames == x.cachedMainClassNames) && (this.cachedTestNames == x.cachedTestNames)
case _ => false case _ => false
} }
override def hashCode: Int = { override def hashCode: Int = {
37 * (37 * (17 + "sbt.protocol.CompletionResponse".##) + items.##) 37 * (37 * (37 * (37 * (17 + "sbt.protocol.CompletionResponse".##) + items.##) + cachedMainClassNames.##) + cachedTestNames.##)
} }
override def toString: String = { override def toString: String = {
"CompletionResponse(" + items + ")" "CompletionResponse(" + items + ", " + cachedMainClassNames + ", " + cachedTestNames + ")"
} }
private[this] def copy(items: Vector[String] = items): CompletionResponse = { private[this] def copy(items: Vector[String] = items, cachedMainClassNames: Option[Boolean] = cachedMainClassNames, cachedTestNames: Option[Boolean] = cachedTestNames): CompletionResponse = {
new CompletionResponse(items) new CompletionResponse(items, cachedMainClassNames, cachedTestNames)
} }
def withItems(items: Vector[String]): CompletionResponse = { def withItems(items: Vector[String]): CompletionResponse = {
copy(items = items) copy(items = items)
} }
def withCachedMainClassNames(cachedMainClassNames: Option[Boolean]): CompletionResponse = {
copy(cachedMainClassNames = cachedMainClassNames)
}
def withCachedMainClassNames(cachedMainClassNames: Boolean): CompletionResponse = {
copy(cachedMainClassNames = Option(cachedMainClassNames))
}
def withCachedTestNames(cachedTestNames: Option[Boolean]): CompletionResponse = {
copy(cachedTestNames = cachedTestNames)
}
def withCachedTestNames(cachedTestNames: Boolean): CompletionResponse = {
copy(cachedTestNames = Option(cachedTestNames))
}
} }
object CompletionResponse { object CompletionResponse {
def apply(items: Vector[String]): CompletionResponse = new CompletionResponse(items) def apply(items: Vector[String]): CompletionResponse = new CompletionResponse(items)
def apply(items: Vector[String], cachedMainClassNames: Option[Boolean], cachedTestNames: Option[Boolean]): CompletionResponse = new CompletionResponse(items, cachedMainClassNames, cachedTestNames)
def apply(items: Vector[String], cachedMainClassNames: Boolean, cachedTestNames: Boolean): CompletionResponse = new CompletionResponse(items, Option(cachedMainClassNames), Option(cachedTestNames))
} }

View File

@ -12,8 +12,9 @@ implicit lazy val CompletionParamsFormat: JsonFormat[sbt.protocol.CompletionPara
case Some(__js) => case Some(__js) =>
unbuilder.beginObject(__js) unbuilder.beginObject(__js)
val query = unbuilder.readField[String]("query") val query = unbuilder.readField[String]("query")
val level = unbuilder.readField[Option[Int]]("level")
unbuilder.endObject() unbuilder.endObject()
sbt.protocol.CompletionParams(query) sbt.protocol.CompletionParams(query, level)
case None => case None =>
deserializationError("Expected JsObject but found None") deserializationError("Expected JsObject but found None")
} }
@ -21,6 +22,7 @@ implicit lazy val CompletionParamsFormat: JsonFormat[sbt.protocol.CompletionPara
override def write[J](obj: sbt.protocol.CompletionParams, builder: Builder[J]): Unit = { override def write[J](obj: sbt.protocol.CompletionParams, builder: Builder[J]): Unit = {
builder.beginObject() builder.beginObject()
builder.addField("query", obj.query) builder.addField("query", obj.query)
builder.addField("level", obj.level)
builder.endObject() builder.endObject()
} }
} }

View File

@ -12,8 +12,10 @@ implicit lazy val CompletionResponseFormat: JsonFormat[sbt.protocol.CompletionRe
case Some(__js) => case Some(__js) =>
unbuilder.beginObject(__js) unbuilder.beginObject(__js)
val items = unbuilder.readField[Vector[String]]("items") val items = unbuilder.readField[Vector[String]]("items")
val cachedMainClassNames = unbuilder.readField[Option[Boolean]]("cachedMainClassNames")
val cachedTestNames = unbuilder.readField[Option[Boolean]]("cachedTestNames")
unbuilder.endObject() unbuilder.endObject()
sbt.protocol.CompletionResponse(items) sbt.protocol.CompletionResponse(items, cachedMainClassNames, cachedTestNames)
case None => case None =>
deserializationError("Expected JsObject but found None") deserializationError("Expected JsObject but found None")
} }
@ -21,6 +23,8 @@ implicit lazy val CompletionResponseFormat: JsonFormat[sbt.protocol.CompletionRe
override def write[J](obj: sbt.protocol.CompletionResponse, builder: Builder[J]): Unit = { override def write[J](obj: sbt.protocol.CompletionResponse, builder: Builder[J]): Unit = {
builder.beginObject() builder.beginObject()
builder.addField("items", obj.items) builder.addField("items", obj.items)
builder.addField("cachedMainClassNames", obj.cachedMainClassNames)
builder.addField("cachedTestNames", obj.cachedTestNames)
builder.endObject() builder.endObject()
} }
} }

View File

@ -29,6 +29,7 @@ type Attach implements CommandMessage {
type CompletionParams { type CompletionParams {
query: String! query: String!
level: Int @since("1.4.0")
} }
## Message for events. ## Message for events.
@ -68,6 +69,8 @@ type SettingQueryFailure implements SettingQueryResponse {
type CompletionResponse { type CompletionResponse {
items: [String] items: [String]
cachedMainClassNames: Boolean @since("1.4.0")
cachedTestNames: Boolean @since("1.4.0")
} }
# enum Status { # enum Status {

View File

@ -1,7 +1,9 @@
package testpkg package testpkg
import java.io.{ InputStream, PrintStream } import java.io.{ InputStream, OutputStream, PrintStream }
import sbt.internal.client.NetworkClient import sbt.internal.client.NetworkClient
import sbt.internal.util.Util
import scala.collection.mutable
object ClientTest extends AbstractServerTest { object ClientTest extends AbstractServerTest {
override val testDirectory: String = "client" override val testDirectory: String = "client"
@ -13,6 +15,31 @@ object ClientTest extends AbstractServerTest {
} }
} }
val NullPrintStream = new PrintStream(_ => {}, false) val NullPrintStream = new PrintStream(_ => {}, false)
class CachingPrintStream extends { val cos = new CachingOutputStream }
with PrintStream(cos, true) {
def lines = cos.lines
}
class CachingOutputStream extends OutputStream {
private val lineBuffer = new mutable.ArrayBuffer[String]
private var byteBuffer = new mutable.ArrayBuffer[Byte]
override def write(i: Int) = {
if (i == 10) {
lineBuffer += new String(byteBuffer.toArray)
byteBuffer = new mutable.ArrayBuffer[Byte]
} else Util.ignoreResult(byteBuffer += i.toByte)
}
def lines = lineBuffer.toVector
}
class FixedInputStream(keys: Char*) extends InputStream {
var i = 0
override def read(): Int = {
if (i < keys.length) {
val res = keys(i).toInt
i += 1
res
} else -1
}
}
private def client(args: String*) = private def client(args: String*) =
NetworkClient.client( NetworkClient.client(
testPath.toFile, testPath.toFile,
@ -22,6 +49,21 @@ object ClientTest extends AbstractServerTest {
NullPrintStream, NullPrintStream,
false false
) )
// This ensures that the completion command will send a tab that triggers
// sbt to call definedTestNames or discoveredMainClasses if there hasn't
// been a necessary compilation
def tabs = new FixedInputStream('\t', '\t')
private def complete(completionString: String): Seq[String] = {
val cps = new CachingPrintStream
NetworkClient.complete(
testPath.toFile,
Array(s"--completions=sbtc $completionString"),
false,
tabs,
cps
)
cps.lines
}
test("exit success") { c => test("exit success") { c =>
assert(client("willSucceed") == 0) assert(client("willSucceed") == 0)
} }
@ -43,4 +85,33 @@ object ClientTest extends AbstractServerTest {
test("three commands with middle failure") { _ => test("three commands with middle failure") { _ =>
assert(client("compile;willFail;willSucceed") == 1) assert(client("compile;willFail;willSucceed") == 1)
} }
test("compi completions") { _ =>
val expected = Vector(
"compile",
"compile:",
"compileAnalysisFile",
"compileAnalysisFilename",
"compileAnalysisTargetRoot",
"compileIncSetup",
"compileIncremental",
"compileOutputs",
"compilers",
)
assert(complete("compi") == expected)
}
test("testOnly completions") { _ =>
val testOnlyExpected = Vector(
"testOnly",
"testOnly/",
"testOnly::",
"testOnly;",
)
assert(complete("testOnly") == testOnlyExpected)
val testOnlyOptionsExpected = Vector("--", ";", "test.pkg.FooSpec")
assert(complete("testOnly ") == testOnlyOptionsExpected)
}
test("quote with semi") { _ =>
assert(complete("\"compile; fooB") == Vector("compile; fooBar"))
}
} }

View File

@ -29,7 +29,7 @@ object ServerCompletionsTest extends AbstractServerTest {
s"""{ "jsonrpc": "2.0", "id": 16, "method": "sbt/completion", "params": $completionStr }""" s"""{ "jsonrpc": "2.0", "id": 16, "method": "sbt/completion", "params": $completionStr }"""
) )
assert(svr.waitForString(10.seconds) { s => assert(svr.waitForString(10.seconds) { s =>
s contains """"result":{"items":["hello"]}""" s contains """"result":{"items":["hello"]"""
}) })
} }
@ -39,7 +39,7 @@ object ServerCompletionsTest extends AbstractServerTest {
s"""{ "jsonrpc": "2.0", "id": 17, "method": "sbt/completion", "params": $completionStr }""" s"""{ "jsonrpc": "2.0", "id": 17, "method": "sbt/completion", "params": $completionStr }"""
) )
assert(svr.waitForString(10.seconds) { s => assert(svr.waitForString(10.seconds) { s =>
s contains """"result":{"items":["testOnly org.sbt.ExampleSpec"]}""" s contains """"result":{"items":["testOnly org.sbt.ExampleSpec"]"""
}) })
} }
} }