diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala
index 5972cf623..b1aa51318 100644
--- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala
+++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala
@@ -11,7 +11,7 @@ import java.io.{ InputStream, OutputStream, PrintStream }
import java.nio.channels.ClosedChannelException
import java.util.Locale
import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference }
-import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue }
+import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit }
import jline.DefaultTerminal2
import jline.console.ConsoleReader
@@ -362,6 +362,7 @@ object Terminal {
}
private[this] val nonBlockingIn: WriteableInputStream =
new WriteableInputStream(jline.TerminalFactory.get.wrapInIfNeeded(originalIn), "console")
+
private[this] val inputStream = new AtomicReference[InputStream](System.in)
private[this] def withOut[T](f: => T): T = {
try {
@@ -397,11 +398,86 @@ object Terminal {
*/
private[this] val activeTerminal = new AtomicReference[Terminal](consoleTerminalHolder.get)
jline.TerminalFactory.set(consoleTerminalHolder.get.toJLine)
+
+ /**
+ * The boot input stream allows a remote client to forward input to the sbt process while
+ * it is still loading. It works by updating proxyInputStream to read from the
+ * value of bootInputStreamHolder if it is non-null as well as from the normal process
+ * console io (assuming there is console io).
+ */
+ private[this] val bootInputStreamHolder = new AtomicReference[InputStream]
+
+ /**
+ * The boot output stream allows sbt to relay the bytes written to stdout to one or
+ * more remote clients while the sbt build is loading and hasn't yet loaded a server.
+ * The output stream of TerminalConsole is updated to write to value of
+ * bootOutputStreamHolder when it is non-null as well as the normal process console
+ * output stream.
+ */
+ private[this] val bootOutputStreamHolder = new AtomicReference[OutputStream]
+ private[sbt] def setBootStreams(
+ bootInputStream: InputStream,
+ bootOutputStream: OutputStream
+ ): Unit = {
+ bootInputStreamHolder.set(bootInputStream)
+ bootOutputStreamHolder.set(bootOutputStream)
+ }
+
private[this] object proxyInputStream extends InputStream {
- def read(): Int = activeTerminal.get().inputStream.read()
+ private[this] val isScripted = System.getProperty("sbt.scripted", "false") == "true"
+ /*
+ * This is to handle the case when a remote client starts sbt and the build fails.
+ * We need to be able to consume input bytes from the remote client, but they
+ * haven't yet connected to the main server but may be connected to the
+ * BootServerSocket. Unfortunately there is no poll method on input stream that
+ * takes a duration so we have to manually implement that here. All of the input
+ * streams that we create in sbt are interruptible, so we can just poll each
+ * of the input streams and periodically interrupt the thread to switch between
+ * the two input streams.
+ */
+ private class ReadThread extends Thread with AutoCloseable {
+ val result = new LinkedBlockingQueue[Integer]
+ setDaemon(true)
+ start()
+ val running = new AtomicBoolean(true)
+ override def run(): Unit = while (running.get) {
+ bootInputStreamHolder.get match {
+ case null =>
+ case is =>
+ def readFrom(inputStream: InputStream) =
+ try {
+ if (running.get) {
+ inputStream.read match {
+ case -1 =>
+ case i =>
+ result.put(i)
+ running.set(false)
+ }
+ }
+ } catch { case _: InterruptedException => }
+ readFrom(is)
+ readFrom(activeTerminal.get().inputStream)
+ }
+ }
+ override def close(): Unit = if (running.compareAndSet(true, false)) this.interrupt()
+ }
+ def read(): Int = {
+ if (isScripted) -1
+ else if (bootInputStreamHolder.get == null) activeTerminal.get().inputStream.read()
+ else {
+ val thread = new ReadThread
+ @tailrec def poll(): Int = thread.result.poll(10, TimeUnit.MILLISECONDS) match {
+ case null =>
+ thread.interrupt()
+ poll()
+ case i => i
+ }
+ poll()
+ }
+ }
}
private[this] object proxyOutputStream extends OutputStream {
- private[this] def os = activeTerminal.get().outputStream
+ private[this] def os: OutputStream = activeTerminal.get().outputStream
def write(byte: Int): Unit = {
os.write(byte)
os.flush()
@@ -611,12 +687,28 @@ object Terminal {
}
def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f
+ private val combinedOutputStream = new OutputStream {
+ override def write(b: Int): Unit = {
+ Option(bootOutputStreamHolder.get).foreach(_.write(b))
+ out.write(b)
+ }
+ override def write(b: Array[Byte]): Unit = write(b, 0, b.length)
+ override def write(b: Array[Byte], offset: Int, len: Int): Unit = {
+ Option(bootOutputStreamHolder.get).foreach(_.write(b, offset, len))
+ out.write(b, offset, len)
+ }
+ override def flush(): Unit = {
+ Option(bootOutputStreamHolder.get).foreach(_.flush())
+ out.flush()
+ }
+ }
+
override val outputStream = new OutputStream {
override def write(b: Int): Unit = throwIfClosed {
writeLock.synchronized {
if (b == Int.MinValue) currentLine.set(new ArrayBuffer[Byte])
else doWrite(Vector((b & 0xFF).toByte))
- if (b == 10) out.flush()
+ if (b == 10) combinedOutputStream.flush()
}
}
override def write(b: Array[Byte]): Unit = throwIfClosed(write(b, 0, b.length))
@@ -629,6 +721,7 @@ object Terminal {
}
}
}
+ override def flush(): Unit = combinedOutputStream.flush()
private[this] val clear = s"$CursorLeft1000$ClearScreenAfterCursor"
private def doWrite(bytes: Seq[Byte]): Unit = {
def doWrite(b: Byte): Unit = out.write(b & 0xFF)
@@ -638,8 +731,8 @@ object Terminal {
progressState.clearBytes()
val cl = currentLine.get
if (buf.nonEmpty && isAnsiSupported && cl.isEmpty) clear.getBytes.foreach(doWrite)
- out.write(buf.toArray)
- out.write(10)
+ combinedOutputStream.write(buf.toArray)
+ combinedOutputStream.write(10)
currentLine.get match {
case s if s.nonEmpty => currentLine.set(new ArrayBuffer[Byte])
case _ =>
@@ -654,9 +747,9 @@ object Terminal {
clear.getBytes.foreach(doWrite)
}
cl ++= remaining
- out.write(remaining.toArray)
+ combinedOutputStream.write(remaining.toArray)
}
- out.flush()
+ combinedOutputStream.flush()
}
}
override private[sbt] val printStream: PrintStream = new PrintStream(outputStream, true)
@@ -681,7 +774,7 @@ object Terminal {
Some(new String(bytes.toArray).replaceAllLiterally(ClearScreenAfterCursor, ""))
}
- private[this] val rawPrintStream: PrintStream = new PrintStream(out, true) {
+ private[this] val rawPrintStream: PrintStream = new PrintStream(combinedOutputStream, true) {
override def close(): Unit = {}
}
override def withPrintStream[T](f: PrintStream => T): T =
diff --git a/main-command/src/main/java/sbt/internal/BootServerSocket.java b/main-command/src/main/java/sbt/internal/BootServerSocket.java
new file mode 100644
index 000000000..06ba74de2
--- /dev/null
+++ b/main-command/src/main/java/sbt/internal/BootServerSocket.java
@@ -0,0 +1,318 @@
+/*
+ * sbt
+ * Copyright 2011 - 2018, Lightbend, Inc.
+ * Copyright 2008 - 2010, Mark Harrah
+ * Licensed under Apache License 2.0 (see LICENSE)
+ */
+
+package sbt.internal;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.net.Socket;
+import java.net.ServerSocket;
+import java.net.SocketException;
+import java.net.SocketTimeoutException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import net.openhft.hashing.LongHashFunction;
+import org.scalasbt.ipcsocket.UnixDomainServerSocket;
+import org.scalasbt.ipcsocket.UnixDomainSocket;
+import org.scalasbt.ipcsocket.Win32NamedPipeServerSocket;
+import org.scalasbt.ipcsocket.Win32NamedPipeSocket;
+import org.scalasbt.ipcsocket.Win32SecurityLevel;
+import xsbti.AppConfiguration;
+
+/**
+ * A BootServerSocket is used for remote clients to connect to sbt for io while sbt is still loading
+ * the build. There are two scenarios in which this functionality is needed:
+ *
+ *
1. client a starts an sbt server and then client b tries to connect to the server before the
+ * server has loaded. Presently, client b will try to start a new server even though there is one
+ * booting. This can cause a java process leak because the second server launched by client b is
+ * unable to create a server because there is an existing portfile by the time it starts up.
+ *
+ *
2. a remote client initiates a reboot command. Reboot causes sbt to shutdown the server which
+ * makes the client disconnect. Since sbt does not start the server until the project has
+ * successfully loaded, there is no way for the client to see the output of the server. This is
+ * particularly problematic if loading fails because the server will be stuck waiting for input that
+ * will not be forthcoming.
+ *
+ *
To address these issues, the BootServerSocket can be used to immediately create a server
+ * socket before sbt even starts loading the build. It works by creating a local socket either in
+ * project/target/SOCK_NAME or a windows named pipe with name SOCK_NAME where SOCK_NAME is computed
+ * as the hash of the project's base directory (for disambiguation in the windows case). If the
+ * server can't create a server socket because there is already one running, it either prompts the
+ * user if they want to start a new server even if there is already one running if there is a
+ * console available or exits with the status code 2 which indicates that there is another sbt
+ * process starting up.
+ *
+ *
Once the server socket is created, it listens for new client connections. When a client
+ * connects, the server will forward its input and output to the client via Terminal.setBootStreams
+ * which updates the Terminal.proxyOutputStream to forward all bytes written to the
+ * BootServerSocket's outputStream which in turn writes the output to each of the connected clients.
+ * Input is handed similarly.
+ *
+ *
When the server finishes loading, it closes the boot server socket.
+ *
+ *
BootServerSocket is implemented in java so that it can be classloaded as quickly as possible.
+ */
+public class BootServerSocket implements AutoCloseable {
+ private ServerSocket serverSocket = null;
+ private final AtomicBoolean closed = new AtomicBoolean(false);
+ private final AtomicBoolean running = new AtomicBoolean(false);
+ private final AtomicInteger threadId = new AtomicInteger(1);
+ private final Future> acceptFuture;
+ private final ExecutorService service =
+ Executors.newCachedThreadPool(
+ r -> new Thread(r, "boot-server-socket-thread-" + threadId.getAndIncrement()));
+ private final Set clientSockets = ConcurrentHashMap.newKeySet();
+ private final Object lock = new Object();
+ private final LinkedBlockingQueue clientSocketReads = new LinkedBlockingQueue<>();
+ private final Path socketFile;
+
+ private class ClientSocket implements AutoCloseable {
+ final Socket socket;
+ final AtomicBoolean alive = new AtomicBoolean(true);
+ final Future> future;
+ private final LinkedBlockingQueue bytes = new LinkedBlockingQueue();
+ private final AtomicBoolean closed = new AtomicBoolean(false);
+
+ @SuppressWarnings("deprecation")
+ ClientSocket(final Socket socket) {
+ this.socket = socket;
+ clientSockets.add(this);
+ Future> f = null;
+ try {
+ f =
+ service.submit(
+ () -> {
+ try {
+ final InputStream inputStream = socket.getInputStream();
+ while (alive.get()) {
+ try {
+ int b = inputStream.read();
+ if (b != -1) {
+ bytes.put(b);
+ clientSocketReads.put(ClientSocket.this);
+ } else {
+ alive.set(false);
+ }
+ } catch (IOException e) {
+ alive.set(false);
+ }
+ }
+ } catch (final Exception ex) {
+ }
+ });
+ } catch (final RejectedExecutionException e) {
+ alive.set(false);
+ }
+ future = f;
+ }
+
+ private void write(final int i) {
+ try {
+ if (alive.get()) socket.getOutputStream().write(i);
+ } catch (final IOException e) {
+ alive.set(false);
+ close();
+ }
+ }
+
+ private void write(final byte[] b, final int offset, final int len) {
+ try {
+ if (alive.get()) socket.getOutputStream().write(b, offset, len);
+ } catch (final IOException e) {
+ alive.set(false);
+ close();
+ }
+ }
+
+ private void flush() {
+ try {
+ socket.getOutputStream().flush();
+ } catch (final IOException e) {
+ alive.set(false);
+ close();
+ }
+ }
+
+ @SuppressWarnings("EmptyCatchBlock")
+ @Override
+ public void close() {
+ if (closed.compareAndSet(false, true)) {
+ if (alive.get()) {
+ write(2);
+ bytes.forEach(this::write);
+ bytes.clear();
+ write(3);
+ flush();
+ }
+ alive.set(false);
+ if (future != null) future.cancel(true);
+ try {
+ socket.getOutputStream().close();
+ socket.getInputStream().close();
+ // Windows is very slow to close the socket for whatever reason
+ // We close the server socket anyway, so this should die then.
+ if (!System.getProperty("os.name", "").toLowerCase().startsWith("win")) socket.close();
+ } catch (final IOException e) {
+ }
+ clientSockets.remove(this);
+ }
+ }
+ }
+
+ private final Object writeLock = new Object();
+
+ public InputStream inputStream() {
+ return inputStream;
+ }
+
+ private final InputStream inputStream =
+ new InputStream() {
+ @Override
+ public int read() {
+ try {
+ ClientSocket clientSocket = clientSocketReads.take();
+ return clientSocket.bytes.take();
+ } catch (final InterruptedException e) {
+ return -1;
+ }
+ }
+ };
+ private final OutputStream outputStream =
+ new OutputStream() {
+ @Override
+ public void write(final int b) {
+ synchronized (lock) {
+ clientSockets.forEach(cs -> cs.write(b));
+ }
+ }
+
+ @Override
+ public void write(final byte[] b) {
+ write(b, 0, b.length);
+ }
+
+ @Override
+ public void write(final byte[] b, final int offset, final int len) {
+ synchronized (lock) {
+ clientSockets.forEach(cs -> cs.write(b, offset, len));
+ }
+ }
+
+ @Override
+ public void flush() {
+ synchronized (lock) {
+ clientSockets.forEach(cs -> cs.flush());
+ }
+ }
+ };
+
+ public OutputStream outputStream() {
+ return outputStream;
+ }
+
+ private final Runnable acceptRunnable =
+ () -> {
+ try {
+ serverSocket.setSoTimeout(5000);
+ while (running.get()) {
+ try {
+ ClientSocket clientSocket = new ClientSocket(serverSocket.accept());
+ } catch (final SocketTimeoutException e) {
+ } catch (final IOException e) {
+ running.set(false);
+ }
+ }
+ } catch (final SocketException e) {
+ }
+ };
+
+ public BootServerSocket(final AppConfiguration configuration)
+ throws ServerAlreadyBootingException, IOException {
+ final Path base = configuration.baseDirectory().toPath().toRealPath();
+ final Path target = base.resolve("project").resolve("target");
+ if (!isWindows) {
+ socketFile = Paths.get(socketLocation(base));
+ Files.createDirectories(target);
+ } else {
+ socketFile = null;
+ }
+ serverSocket = newSocket(socketLocation(base));
+ if (serverSocket != null) {
+ running.set(true);
+ acceptFuture = service.submit(acceptRunnable);
+ } else {
+ closed.set(true);
+ acceptFuture = null;
+ }
+ }
+
+ public static String socketLocation(final Path base) throws UnsupportedEncodingException {
+ final Path target = base.resolve("project").resolve("target");
+ if (isWindows) {
+ long hash = LongHashFunction.farmNa().hashBytes(target.toString().getBytes("UTF-8"));
+ return "sbt-load" + hash;
+ } else {
+ return base.relativize(target.resolve("sbt-load.sock")).toString();
+ }
+ }
+
+ @SuppressWarnings("EmptyCatchBlock")
+ @Override
+ public void close() {
+ if (closed.compareAndSet(false, true)) {
+ // avoid concurrent modification exception
+ clientSockets.forEach(ClientSocket::close);
+ if (acceptFuture != null) acceptFuture.cancel(true);
+ service.shutdownNow();
+ try {
+ if (serverSocket != null) serverSocket.close();
+ } catch (final IOException e) {
+ }
+ try {
+ if (socketFile != null) Files.deleteIfExists(socketFile);
+ } catch (final IOException e) {
+ }
+ }
+ }
+
+ static final boolean isWindows =
+ System.getProperty("os.name", "").toLowerCase().startsWith("win");
+
+ static ServerSocket newSocket(final String sock) throws ServerAlreadyBootingException {
+ ServerSocket socket = null;
+ String name = socketName(sock);
+ try {
+ if (!isWindows) Files.deleteIfExists(Paths.get(sock));
+ socket =
+ isWindows
+ ? new Win32NamedPipeServerSocket(name, false, Win32SecurityLevel.OWNER_DACL)
+ : new UnixDomainServerSocket(name);
+ return socket;
+ } catch (final IOException e) {
+ throw new ServerAlreadyBootingException();
+ }
+ }
+
+ private static String socketName(String sock) {
+ return isWindows ? "\\\\.\\pipe\\" + sock : sock;
+ }
+}
diff --git a/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java b/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java
new file mode 100644
index 000000000..070dc8b92
--- /dev/null
+++ b/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java
@@ -0,0 +1,10 @@
+/*
+ * sbt
+ * Copyright 2011 - 2018, Lightbend, Inc.
+ * Copyright 2008 - 2010, Mark Harrah
+ * Licensed under Apache License 2.0 (see LICENSE)
+ */
+
+package sbt.internal;
+
+public class ServerAlreadyBootingException extends Exception {}
diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala
index 1d52eb8b7..233ca0526 100644
--- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala
+++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala
@@ -34,7 +34,7 @@ import scala.annotation.tailrec
import scala.collection.mutable
import scala.concurrent.duration._
import scala.util.control.NonFatal
-import scala.util.{ Failure, Properties, Success }
+import scala.util.{ Failure, Properties, Success, Try }
import Serialization.{
CancelAll,
attach,
@@ -108,14 +108,11 @@ class NetworkClient(
}
private[this] val stdinBytes = new LinkedBlockingQueue[Int]
- private[this] val stdin: InputStream = new InputStream {
- override def available(): Int = stdinBytes.size
- override def read: Int = stdinBytes.take
- }
private[this] val inputThread = new AtomicReference(new RawInputThread)
private[this] val exitClean = new AtomicBoolean(true)
private[this] val sbtProcess = new AtomicReference[Process](null)
private class ConnectionRefusedException(t: Throwable) extends Throwable(t)
+ private class ServerFailedException extends Exception
// Open server connection based on the portfile
def init(prompt: Boolean, retry: Boolean): ServerConnection =
@@ -138,9 +135,23 @@ class NetworkClient(
forkServer(portfile, log = true)
}
}
- val (sk, tkn) =
- try mkSocket(portfile)
- catch { case e: IOException => throw new ConnectionRefusedException(e) }
+ @tailrec def connect(attempt: Int): (Socket, Option[String]) = {
+ val res = try Some(mkSocket(portfile))
+ catch {
+ // This catches a pipe busy exception which can happen if two windows clients
+ // attempt to connect in rapid succession
+ case e: IOException if e.getMessage.contains("Couldn't open") && attempt < 10 => None
+ case e: IOException => throw new ConnectionRefusedException(e)
+ }
+ res match {
+ case Some(r) => r
+ case None =>
+ // Use a random sleep to spread out the competing processes
+ Thread.sleep(new java.util.Random().nextInt(20).toLong)
+ connect(attempt + 1)
+ }
+ }
+ val (sk, tkn) = connect(0)
val conn = new ServerConnection(sk) {
override def onNotification(msg: JsonRpcNotificationMessage): Unit = {
msg.method match {
@@ -188,57 +199,129 @@ class NetworkClient(
* This instance must be shutdown explicitly via `sbt -client shutdown`
*/
def forkServer(portfile: File, log: Boolean): Unit = {
- if (log) console.appendLog(Level.Info, "server was not detected. starting an instance")
- val term = Terminal.console
- val props =
- Seq(
- term.getWidth,
- term.getHeight,
- term.isAnsiSupported,
- term.isColorEnabled,
- term.isSupershellEnabled
- ).mkString(",")
-
- val cmd = arguments.sbtScript +: arguments.sbtArguments :+ BasicCommandStrings.CloseIOStreams
- val processBuilder =
- new ProcessBuilder(cmd: _*)
- .directory(arguments.baseDirectory)
- .redirectInput(Redirect.PIPE)
- processBuilder.environment.put(Terminal.TERMINAL_PROPS, props)
- val process = processBuilder.start()
- sbtProcess.set(process)
+ val bootSocketName =
+ BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath())
+ var socket: Option[Socket] = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
+ val process = socket match {
+ case None =>
+ val term = Terminal.console
+ if (log) console.appendLog(Level.Info, "server was not detected. starting an instance")
+ val props =
+ Seq(
+ term.getWidth,
+ term.getHeight,
+ term.isAnsiSupported,
+ term.isColorEnabled,
+ term.isSupershellEnabled
+ ).mkString(",")
+ val cmd = arguments.sbtScript +: arguments.sbtArguments :+ BasicCommandStrings.CloseIOStreams
+ val processBuilder =
+ new ProcessBuilder(cmd: _*)
+ .directory(arguments.baseDirectory)
+ .redirectInput(Redirect.PIPE)
+ processBuilder.environment.put(Terminal.TERMINAL_PROPS, props)
+ val process = processBuilder.start()
+ sbtProcess.set(process)
+ Some(process)
+ case _ =>
+ if (log) console.appendLog(Level.Info, "sbt server is booting up")
+ None
+ }
val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly()))
Runtime.getRuntime.addShutdownHook(hook)
- val stdout = process.getInputStream
- val stderr = process.getErrorStream
- val stdin = process.getOutputStream
+ val isWin = Properties.isWin
+ var gotInputBack = false
+ val readThreadAlive = new AtomicBoolean(true)
+ /*
+ * Socket.getInputStream.available doesn't always return a value greater than 0
+ * so it is necessary to read the process output from the socket on a background
+ * thread.
+ */
+ val readThread = new Thread("client-read-thread") {
+ setDaemon(true)
+ start()
+ override def run(): Unit = {
+ try {
+ while (readThreadAlive.get) {
+ socket.foreach { s =>
+ try {
+ s.getInputStream.read match {
+ case -1 | 0 => readThreadAlive.set(false)
+ case 2 => gotInputBack = true
+ case 3 if gotInputBack => readThreadAlive.set(false)
+ case i if gotInputBack => stdinBytes.offer(i)
+ case i => printStream.write(i)
+ }
+ } catch {
+ case e @ (_: IOException | _: InterruptedException) =>
+ readThreadAlive.set(false)
+ }
+ }
+ if (socket.isEmpty && readThreadAlive.get) {
+ try Thread.sleep(10)
+ catch { case _: InterruptedException => }
+ }
+ }
+ } catch { case e: IOException => e.printStackTrace(System.err) }
+ }
+ }
@tailrec
def blockUntilStart(): Unit = {
+ if (socket.isEmpty) {
+ socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption
+ }
val stop = try {
- while (stdout.available > 0) {
- val byte = stdout.read
- printStream.write(byte)
+ socket match {
+ case None =>
+ process.foreach { p =>
+ val output = p.getInputStream
+ while (output.available > 0) {
+ printStream.write(output.read())
+ }
+ }
+ case Some(s) =>
+ while (!gotInputBack && !stdinBytes.isEmpty && socket.isDefined) {
+ val out = s.getOutputStream
+ val b = stdinBytes.poll
+ // echo stdin during boot
+ printStream.write(b)
+ printStream.flush()
+ out.write(b)
+ out.flush()
+ }
}
- while (stderr.available > 0) {
- val byte = stderr.read
- errorStream.write(byte)
- }
- while (!stdinBytes.isEmpty) {
- stdin.write(stdinBytes.take)
- stdin.flush()
+ process.foreach { p =>
+ val error = p.getErrorStream
+ while (error.available > 0) {
+ errorStream.write(error.read())
+ }
}
false
- } catch {
- case _: IOException => true
- }
+ } catch { case e: IOException => true }
Thread.sleep(10)
- if (!portfile.exists && !stop) blockUntilStart()
- else {
- stdin.close()
- stdout.close()
- stderr.close()
- process.getOutputStream.close()
+ printStream.flush()
+ errorStream.flush()
+ /*
+ * If an earlier server process is launching, the process launched by this client
+ * will return with exit value 2. In that case, we can treat the process as alive
+ * even if it is actually dead.
+ */
+ val existsValidProcess = process.fold(socket.isDefined)(p => p.isAlive || p.exitValue == 2)
+ if (!portfile.exists && !stop && readThreadAlive.get && existsValidProcess) {
+ blockUntilStart()
+ } else {
+ socket.foreach { s =>
+ s.getInputStream.close()
+ s.getOutputStream.close()
+ s.close()
+ }
+ readThread.interrupt()
+ process.foreach { p =>
+ p.getOutputStream.close()
+ p.getErrorStream.close()
+ p.getInputStream.close()
+ }
}
}
@@ -247,6 +330,8 @@ class NetworkClient(
sbtProcess.set(null)
Util.ignoreResult(Runtime.getRuntime.removeShutdownHook(hook))
}
+ if (!portfile.exists()) throw new ServerFailedException
+ if (attached.get && !stdinBytes.isEmpty) Option(inputThread.get).foreach(_.drain())
}
/** Called on the response for a returning message. */
@@ -443,10 +528,16 @@ class NetworkClient(
}
}
- def connect(log: Boolean, prompt: Boolean): Unit = {
+ def connect(log: Boolean, prompt: Boolean): Boolean = {
if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR")
- init(prompt, retry = true)
- ()
+ try {
+ init(prompt, retry = true)
+ true
+ } catch {
+ case _: ServerFailedException =>
+ console.appendLog(Level.Error, "failed to connect to server")
+ false
+ }
}
private[this] val contHandler: () => Unit = () => {
@@ -505,7 +596,6 @@ class NetworkClient(
}
def getCompletions(query: String): Seq[String] = {
- connect(log = false, prompt = true)
val quoteCount = query.foldLeft(0) {
case (count, '"') => count + 1
case (count, _) => count
@@ -639,7 +729,10 @@ class NetworkClient(
stdinBytes.offer(-1)
val mainThread = interactiveThread.getAndSet(null)
if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt
- connection.shutdown()
+ connectionHolder.get match {
+ case null =>
+ case c => c.shutdown()
+ }
Option(inputThread.get).foreach(_.interrupt())
} catch {
case t: Throwable => t.printStackTrace(); throw t
@@ -784,8 +877,8 @@ object NetworkClient {
useJNI,
)
try {
- client.connect(log = true, prompt = false)
- client.run()
+ if (client.connect(log = true, prompt = false)) client.run()
+ else 1
} catch { case _: Exception => 1 } finally client.close()
}
private def simpleClient(
@@ -857,7 +950,9 @@ object NetworkClient {
useJNI = useJNI,
)
try {
- val results = client.getCompletions(cmd)
+ val results =
+ if (client.connect(log = false, prompt = true)) client.getCompletions(cmd)
+ else Nil
out.println(results.sorted.distinct mkString "\n")
0
} catch { case _: Exception => 1 } finally client.close()
@@ -867,8 +962,8 @@ object NetworkClient {
try {
val client = new NetworkClient(configuration, parseArgs(arguments.toArray))
try {
- client.connect(log = true, prompt = false)
- client.run()
+ if (client.connect(log = true, prompt = false)) client.run()
+ else 1
} catch { case _: Throwable => 1 } finally client.close()
} catch {
case NonFatal(e) =>
diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala
index 8756fab7c..51f4ad4eb 100644
--- a/main/src/main/scala/sbt/Keys.scala
+++ b/main/src/main/scala/sbt/Keys.scala
@@ -549,6 +549,7 @@ object Keys {
val globalPluginUpdate = taskKey[UpdateReport]("A hook to get the UpdateReport of the global plugin.").withRank(DTask)
private[sbt] val taskCancelStrategy = settingKey[State => TaskCancellationStrategy]("Experimental task cancellation handler.").withRank(DTask)
private[sbt] val cacheStoreFactoryFactory = AttributeKey[CacheStoreFactoryFactory]("cache-store-factory-factory")
+ private[sbt] val bootServerSocket = AttributeKey[BootServerSocket]("boot-server-socket")
val fileCacheSize = settingKey[String]("The approximate maximum size in bytes of the cache used to store previous task results. For example, it could be set to \"256M\" to make the maximum size 256 megabytes.")
// Experimental in sbt 0.13.2 to enable grabbing semantic compile failures.
diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala
index 282d7b01f..a2c85eb6e 100644
--- a/main/src/main/scala/sbt/Main.scala
+++ b/main/src/main/scala/sbt/Main.scala
@@ -57,7 +57,7 @@ private[sbt] object xMain {
override def provider: AppProvider = config.provider()
}
}
- private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult =
+ private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = {
try {
import BasicCommandStrings.{ DashClient, DashDashClient, runEarly }
import BasicCommands.early
@@ -65,6 +65,10 @@ private[sbt] object xMain {
import sbt.internal.CommandStrings.{ BootCommand, DefaultsCommand, InitCommand }
import sbt.internal.client.NetworkClient
+ val bootServerSocket = getSocketOrExit(configuration) match {
+ case (_, Some(e)) => return e
+ case (s, _) => s
+ }
// if we detect -Dsbt.client=true or -client, run thin client.
val clientModByEnv = SysProp.client
val userCommands = configuration.arguments.map(_.trim)
@@ -73,6 +77,7 @@ private[sbt] object xMain {
if (userCommands.exists(isBsp)) {
BspClient.run(dealiasBaseDirectory(configuration))
} else {
+ bootServerSocket.foreach(l => Terminal.setBootStreams(l.inputStream, l.outputStream))
Terminal.withStreams {
if (clientModByEnv || userCommands.exists(isClient)) {
val args = userCommands.toList.filterNot(isClient)
@@ -80,20 +85,43 @@ private[sbt] object xMain {
Exit(0)
} else {
val closeStreams = userCommands.exists(_ == BasicCommandStrings.CloseIOStreams)
- val state = StandardMain
+ val state0 = StandardMain
.initialState(
dealiasBaseDirectory(configuration),
Seq(defaults, early),
runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil
)
.put(BasicKeys.closeIOStreams, closeStreams)
- StandardMain.runManaged(state)
+ val state = bootServerSocket match {
+ case Some(l) => state0.put(Keys.bootServerSocket, l)
+ case _ => state0
+ }
+ try StandardMain.runManaged(state)
+ finally bootServerSocket.foreach(_.close())
}
}
}
} finally {
ShutdownHooks.close()
}
+ }
+
+ private def getSocketOrExit(
+ configuration: xsbti.AppConfiguration
+ ): (Option[BootServerSocket], Option[Exit]) =
+ try (Some(new BootServerSocket(configuration)) -> None)
+ catch {
+ case _: ServerAlreadyBootingException if System.console != null =>
+ println("sbt server is already booting. Create a new server? y/n (default y)")
+ val exit = Terminal.get.withRawSystemIn(System.in.read) match {
+ case 110 => Some(Exit(1))
+ case _ => None
+ }
+ (None, exit)
+ case _: ServerAlreadyBootingException =>
+ if (SysProp.forceServerStart) (None, None)
+ else (None, Some(Exit(2)))
+ }
}
final class ScriptMain extends xsbti.AppMain {
@@ -805,8 +833,7 @@ object BuiltinCommands {
@tailrec
private[this] def doLoadFailed(s: State, loadArg: String): State = {
s.log.warn("Project loading failed: (r)etry, (q)uit, (l)ast, or (i)gnore? (default: r)")
- val terminal = Terminal.get
- val result = try terminal.withRawSystemIn(terminal.inputStream.read) match {
+ val result = try Terminal.get.withRawSystemIn(System.in.read) match {
case -1 => 'q'.toInt
case b => b
} catch { case _: ClosedChannelException => 'q' }
diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala
index 3be85d4e8..a4e048edc 100644
--- a/main/src/main/scala/sbt/internal/CommandExchange.scala
+++ b/main/src/main/scala/sbt/internal/CommandExchange.scala
@@ -248,9 +248,11 @@ private[sbt] final class CommandExchange {
server = None
firstInstance.set(false)
}
+ Terminal.setBootStreams(null, null)
if (s.get(BasicKeys.closeIOStreams).getOrElse(false)) Terminal.close()
+ s.get(Keys.bootServerSocket).foreach(_.close())
}
- s
+ s.remove(Keys.bootServerSocket)
}
def shutdown(): Unit = {
diff --git a/main/src/main/scala/sbt/internal/SysProp.scala b/main/src/main/scala/sbt/internal/SysProp.scala
index 473487609..fe2fd2a7f 100644
--- a/main/src/main/scala/sbt/internal/SysProp.scala
+++ b/main/src/main/scala/sbt/internal/SysProp.scala
@@ -79,6 +79,7 @@ object SysProp {
def allowRootDir: Boolean = getOrFalse("sbt.rootdir")
def legacyTestReport: Boolean = getOrFalse("sbt.testing.legacyreport")
def semanticdb: Boolean = getOrFalse("sbt.semanticdb")
+ def forceServerStart: Boolean = getOrFalse("sbt.server.forcestart")
def watchMode: String =
sys.props.get("sbt.watch.mode").getOrElse("auto")
diff --git a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala
index d5c498eee..165148f2a 100644
--- a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala
+++ b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala
@@ -35,13 +35,13 @@ object ClientSocket {
t.token
}
val sk = uri.getScheme match {
- case "local" if isWindows =>
- (new Win32NamedPipeSocket("""\\.\pipe\""" + uri.getSchemeSpecificPart, useJNI): Socket)
- case "local" =>
- (new UnixDomainSocket(uri.getSchemeSpecificPart, useJNI): Socket)
- case "tcp" => new Socket(InetAddress.getByName(uri.getHost), uri.getPort)
- case _ => sys.error(s"Unsupported uri: $uri")
+ case "local" => localSocket(uri.getSchemeSpecificPart, useJNI)
+ case "tcp" => new Socket(InetAddress.getByName(uri.getHost), uri.getPort)
+ case _ => sys.error(s"Unsupported uri: $uri")
}
(sk, token)
}
+ def localSocket(name: String, useJNI: Boolean): Socket =
+ if (isWindows) new Win32NamedPipeSocket(s"\\\\.\\pipe\\$name", useJNI)
+ else new UnixDomainSocket(name, useJNI)
}
diff --git a/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala b/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala
index be1c67a56..41044871f 100644
--- a/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala
+++ b/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala
@@ -32,8 +32,9 @@ final class LauncherBasedRemoteSbtCreator(
def newRemote(server: IPC.Server) = {
val launcherJar = launcher.getAbsolutePath
val globalBase = "-Dsbt.global.base=" + (new File(directory, "global")).getAbsolutePath
+ val scripted = "-Dsbt.scripted=true"
val args = List("<" + server.port)
- val cmd = "java" :: launchOpts.toList ::: globalBase :: "-jar" :: launcherJar :: args ::: Nil
+ val cmd = "java" :: launchOpts.toList ::: globalBase :: scripted :: "-jar" :: launcherJar :: args ::: Nil
val io = BasicIO(false, log).withInput(_.close())
val p = Process(cmd, directory) run (io)
val thread = new Thread() { override def run() = { p.exitValue(); server.close() } }
@@ -52,11 +53,12 @@ final class RunFromSourceBasedRemoteSbtCreator(
) extends RemoteSbtCreator {
def newRemote(server: IPC.Server): Process = {
val globalBase = "-Dsbt.global.base=" + new File(directory, "global").getAbsolutePath
+ val scripted = "-Dsbt.scripted=true"
val mainClassName = "sbt.RunFromSourceMain"
val cpString = classpath.mkString(java.io.File.pathSeparator)
val args =
List(mainClassName, directory.toString, scalaVersion, sbtVersion, cpString, "<" + server.port)
- val cmd = "java" :: launchOpts.toList ::: globalBase :: "-cp" :: cpString :: args ::: Nil
+ val cmd = "java" :: launchOpts.toList ::: globalBase :: scripted :: "-cp" :: cpString :: args ::: Nil
val io = BasicIO(false, log).withInput(_.close())
val p = Process(cmd, directory) run (io)
val thread = new Thread() { override def run() = { p.exitValue(); server.close() } }