diff --git a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala index 5972cf623..b1aa51318 100644 --- a/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala +++ b/internal/util-logging/src/main/scala/sbt/internal/util/Terminal.scala @@ -11,7 +11,7 @@ import java.io.{ InputStream, OutputStream, PrintStream } import java.nio.channels.ClosedChannelException import java.util.Locale import java.util.concurrent.atomic.{ AtomicBoolean, AtomicReference } -import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue } +import java.util.concurrent.{ CountDownLatch, Executors, LinkedBlockingQueue, TimeUnit } import jline.DefaultTerminal2 import jline.console.ConsoleReader @@ -362,6 +362,7 @@ object Terminal { } private[this] val nonBlockingIn: WriteableInputStream = new WriteableInputStream(jline.TerminalFactory.get.wrapInIfNeeded(originalIn), "console") + private[this] val inputStream = new AtomicReference[InputStream](System.in) private[this] def withOut[T](f: => T): T = { try { @@ -397,11 +398,86 @@ object Terminal { */ private[this] val activeTerminal = new AtomicReference[Terminal](consoleTerminalHolder.get) jline.TerminalFactory.set(consoleTerminalHolder.get.toJLine) + + /** + * The boot input stream allows a remote client to forward input to the sbt process while + * it is still loading. It works by updating proxyInputStream to read from the + * value of bootInputStreamHolder if it is non-null as well as from the normal process + * console io (assuming there is console io). + */ + private[this] val bootInputStreamHolder = new AtomicReference[InputStream] + + /** + * The boot output stream allows sbt to relay the bytes written to stdout to one or + * more remote clients while the sbt build is loading and hasn't yet loaded a server. + * The output stream of TerminalConsole is updated to write to value of + * bootOutputStreamHolder when it is non-null as well as the normal process console + * output stream. + */ + private[this] val bootOutputStreamHolder = new AtomicReference[OutputStream] + private[sbt] def setBootStreams( + bootInputStream: InputStream, + bootOutputStream: OutputStream + ): Unit = { + bootInputStreamHolder.set(bootInputStream) + bootOutputStreamHolder.set(bootOutputStream) + } + private[this] object proxyInputStream extends InputStream { - def read(): Int = activeTerminal.get().inputStream.read() + private[this] val isScripted = System.getProperty("sbt.scripted", "false") == "true" + /* + * This is to handle the case when a remote client starts sbt and the build fails. + * We need to be able to consume input bytes from the remote client, but they + * haven't yet connected to the main server but may be connected to the + * BootServerSocket. Unfortunately there is no poll method on input stream that + * takes a duration so we have to manually implement that here. All of the input + * streams that we create in sbt are interruptible, so we can just poll each + * of the input streams and periodically interrupt the thread to switch between + * the two input streams. + */ + private class ReadThread extends Thread with AutoCloseable { + val result = new LinkedBlockingQueue[Integer] + setDaemon(true) + start() + val running = new AtomicBoolean(true) + override def run(): Unit = while (running.get) { + bootInputStreamHolder.get match { + case null => + case is => + def readFrom(inputStream: InputStream) = + try { + if (running.get) { + inputStream.read match { + case -1 => + case i => + result.put(i) + running.set(false) + } + } + } catch { case _: InterruptedException => } + readFrom(is) + readFrom(activeTerminal.get().inputStream) + } + } + override def close(): Unit = if (running.compareAndSet(true, false)) this.interrupt() + } + def read(): Int = { + if (isScripted) -1 + else if (bootInputStreamHolder.get == null) activeTerminal.get().inputStream.read() + else { + val thread = new ReadThread + @tailrec def poll(): Int = thread.result.poll(10, TimeUnit.MILLISECONDS) match { + case null => + thread.interrupt() + poll() + case i => i + } + poll() + } + } } private[this] object proxyOutputStream extends OutputStream { - private[this] def os = activeTerminal.get().outputStream + private[this] def os: OutputStream = activeTerminal.get().outputStream def write(byte: Int): Unit = { os.write(byte) os.flush() @@ -611,12 +687,28 @@ object Terminal { } def throwIfClosed[R](f: => R): R = if (isStopped.get) throw new ClosedChannelException else f + private val combinedOutputStream = new OutputStream { + override def write(b: Int): Unit = { + Option(bootOutputStreamHolder.get).foreach(_.write(b)) + out.write(b) + } + override def write(b: Array[Byte]): Unit = write(b, 0, b.length) + override def write(b: Array[Byte], offset: Int, len: Int): Unit = { + Option(bootOutputStreamHolder.get).foreach(_.write(b, offset, len)) + out.write(b, offset, len) + } + override def flush(): Unit = { + Option(bootOutputStreamHolder.get).foreach(_.flush()) + out.flush() + } + } + override val outputStream = new OutputStream { override def write(b: Int): Unit = throwIfClosed { writeLock.synchronized { if (b == Int.MinValue) currentLine.set(new ArrayBuffer[Byte]) else doWrite(Vector((b & 0xFF).toByte)) - if (b == 10) out.flush() + if (b == 10) combinedOutputStream.flush() } } override def write(b: Array[Byte]): Unit = throwIfClosed(write(b, 0, b.length)) @@ -629,6 +721,7 @@ object Terminal { } } } + override def flush(): Unit = combinedOutputStream.flush() private[this] val clear = s"$CursorLeft1000$ClearScreenAfterCursor" private def doWrite(bytes: Seq[Byte]): Unit = { def doWrite(b: Byte): Unit = out.write(b & 0xFF) @@ -638,8 +731,8 @@ object Terminal { progressState.clearBytes() val cl = currentLine.get if (buf.nonEmpty && isAnsiSupported && cl.isEmpty) clear.getBytes.foreach(doWrite) - out.write(buf.toArray) - out.write(10) + combinedOutputStream.write(buf.toArray) + combinedOutputStream.write(10) currentLine.get match { case s if s.nonEmpty => currentLine.set(new ArrayBuffer[Byte]) case _ => @@ -654,9 +747,9 @@ object Terminal { clear.getBytes.foreach(doWrite) } cl ++= remaining - out.write(remaining.toArray) + combinedOutputStream.write(remaining.toArray) } - out.flush() + combinedOutputStream.flush() } } override private[sbt] val printStream: PrintStream = new PrintStream(outputStream, true) @@ -681,7 +774,7 @@ object Terminal { Some(new String(bytes.toArray).replaceAllLiterally(ClearScreenAfterCursor, "")) } - private[this] val rawPrintStream: PrintStream = new PrintStream(out, true) { + private[this] val rawPrintStream: PrintStream = new PrintStream(combinedOutputStream, true) { override def close(): Unit = {} } override def withPrintStream[T](f: PrintStream => T): T = diff --git a/main-command/src/main/java/sbt/internal/BootServerSocket.java b/main-command/src/main/java/sbt/internal/BootServerSocket.java new file mode 100644 index 000000000..06ba74de2 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/BootServerSocket.java @@ -0,0 +1,318 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.Socket; +import java.net.ServerSocket; +import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import net.openhft.hashing.LongHashFunction; +import org.scalasbt.ipcsocket.UnixDomainServerSocket; +import org.scalasbt.ipcsocket.UnixDomainSocket; +import org.scalasbt.ipcsocket.Win32NamedPipeServerSocket; +import org.scalasbt.ipcsocket.Win32NamedPipeSocket; +import org.scalasbt.ipcsocket.Win32SecurityLevel; +import xsbti.AppConfiguration; + +/** + * A BootServerSocket is used for remote clients to connect to sbt for io while sbt is still loading + * the build. There are two scenarios in which this functionality is needed: + * + *

1. client a starts an sbt server and then client b tries to connect to the server before the + * server has loaded. Presently, client b will try to start a new server even though there is one + * booting. This can cause a java process leak because the second server launched by client b is + * unable to create a server because there is an existing portfile by the time it starts up. + * + *

2. a remote client initiates a reboot command. Reboot causes sbt to shutdown the server which + * makes the client disconnect. Since sbt does not start the server until the project has + * successfully loaded, there is no way for the client to see the output of the server. This is + * particularly problematic if loading fails because the server will be stuck waiting for input that + * will not be forthcoming. + * + *

To address these issues, the BootServerSocket can be used to immediately create a server + * socket before sbt even starts loading the build. It works by creating a local socket either in + * project/target/SOCK_NAME or a windows named pipe with name SOCK_NAME where SOCK_NAME is computed + * as the hash of the project's base directory (for disambiguation in the windows case). If the + * server can't create a server socket because there is already one running, it either prompts the + * user if they want to start a new server even if there is already one running if there is a + * console available or exits with the status code 2 which indicates that there is another sbt + * process starting up. + * + *

Once the server socket is created, it listens for new client connections. When a client + * connects, the server will forward its input and output to the client via Terminal.setBootStreams + * which updates the Terminal.proxyOutputStream to forward all bytes written to the + * BootServerSocket's outputStream which in turn writes the output to each of the connected clients. + * Input is handed similarly. + * + *

When the server finishes loading, it closes the boot server socket. + * + *

BootServerSocket is implemented in java so that it can be classloaded as quickly as possible. + */ +public class BootServerSocket implements AutoCloseable { + private ServerSocket serverSocket = null; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final AtomicBoolean running = new AtomicBoolean(false); + private final AtomicInteger threadId = new AtomicInteger(1); + private final Future acceptFuture; + private final ExecutorService service = + Executors.newCachedThreadPool( + r -> new Thread(r, "boot-server-socket-thread-" + threadId.getAndIncrement())); + private final Set clientSockets = ConcurrentHashMap.newKeySet(); + private final Object lock = new Object(); + private final LinkedBlockingQueue clientSocketReads = new LinkedBlockingQueue<>(); + private final Path socketFile; + + private class ClientSocket implements AutoCloseable { + final Socket socket; + final AtomicBoolean alive = new AtomicBoolean(true); + final Future future; + private final LinkedBlockingQueue bytes = new LinkedBlockingQueue(); + private final AtomicBoolean closed = new AtomicBoolean(false); + + @SuppressWarnings("deprecation") + ClientSocket(final Socket socket) { + this.socket = socket; + clientSockets.add(this); + Future f = null; + try { + f = + service.submit( + () -> { + try { + final InputStream inputStream = socket.getInputStream(); + while (alive.get()) { + try { + int b = inputStream.read(); + if (b != -1) { + bytes.put(b); + clientSocketReads.put(ClientSocket.this); + } else { + alive.set(false); + } + } catch (IOException e) { + alive.set(false); + } + } + } catch (final Exception ex) { + } + }); + } catch (final RejectedExecutionException e) { + alive.set(false); + } + future = f; + } + + private void write(final int i) { + try { + if (alive.get()) socket.getOutputStream().write(i); + } catch (final IOException e) { + alive.set(false); + close(); + } + } + + private void write(final byte[] b, final int offset, final int len) { + try { + if (alive.get()) socket.getOutputStream().write(b, offset, len); + } catch (final IOException e) { + alive.set(false); + close(); + } + } + + private void flush() { + try { + socket.getOutputStream().flush(); + } catch (final IOException e) { + alive.set(false); + close(); + } + } + + @SuppressWarnings("EmptyCatchBlock") + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + if (alive.get()) { + write(2); + bytes.forEach(this::write); + bytes.clear(); + write(3); + flush(); + } + alive.set(false); + if (future != null) future.cancel(true); + try { + socket.getOutputStream().close(); + socket.getInputStream().close(); + // Windows is very slow to close the socket for whatever reason + // We close the server socket anyway, so this should die then. + if (!System.getProperty("os.name", "").toLowerCase().startsWith("win")) socket.close(); + } catch (final IOException e) { + } + clientSockets.remove(this); + } + } + } + + private final Object writeLock = new Object(); + + public InputStream inputStream() { + return inputStream; + } + + private final InputStream inputStream = + new InputStream() { + @Override + public int read() { + try { + ClientSocket clientSocket = clientSocketReads.take(); + return clientSocket.bytes.take(); + } catch (final InterruptedException e) { + return -1; + } + } + }; + private final OutputStream outputStream = + new OutputStream() { + @Override + public void write(final int b) { + synchronized (lock) { + clientSockets.forEach(cs -> cs.write(b)); + } + } + + @Override + public void write(final byte[] b) { + write(b, 0, b.length); + } + + @Override + public void write(final byte[] b, final int offset, final int len) { + synchronized (lock) { + clientSockets.forEach(cs -> cs.write(b, offset, len)); + } + } + + @Override + public void flush() { + synchronized (lock) { + clientSockets.forEach(cs -> cs.flush()); + } + } + }; + + public OutputStream outputStream() { + return outputStream; + } + + private final Runnable acceptRunnable = + () -> { + try { + serverSocket.setSoTimeout(5000); + while (running.get()) { + try { + ClientSocket clientSocket = new ClientSocket(serverSocket.accept()); + } catch (final SocketTimeoutException e) { + } catch (final IOException e) { + running.set(false); + } + } + } catch (final SocketException e) { + } + }; + + public BootServerSocket(final AppConfiguration configuration) + throws ServerAlreadyBootingException, IOException { + final Path base = configuration.baseDirectory().toPath().toRealPath(); + final Path target = base.resolve("project").resolve("target"); + if (!isWindows) { + socketFile = Paths.get(socketLocation(base)); + Files.createDirectories(target); + } else { + socketFile = null; + } + serverSocket = newSocket(socketLocation(base)); + if (serverSocket != null) { + running.set(true); + acceptFuture = service.submit(acceptRunnable); + } else { + closed.set(true); + acceptFuture = null; + } + } + + public static String socketLocation(final Path base) throws UnsupportedEncodingException { + final Path target = base.resolve("project").resolve("target"); + if (isWindows) { + long hash = LongHashFunction.farmNa().hashBytes(target.toString().getBytes("UTF-8")); + return "sbt-load" + hash; + } else { + return base.relativize(target.resolve("sbt-load.sock")).toString(); + } + } + + @SuppressWarnings("EmptyCatchBlock") + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + // avoid concurrent modification exception + clientSockets.forEach(ClientSocket::close); + if (acceptFuture != null) acceptFuture.cancel(true); + service.shutdownNow(); + try { + if (serverSocket != null) serverSocket.close(); + } catch (final IOException e) { + } + try { + if (socketFile != null) Files.deleteIfExists(socketFile); + } catch (final IOException e) { + } + } + } + + static final boolean isWindows = + System.getProperty("os.name", "").toLowerCase().startsWith("win"); + + static ServerSocket newSocket(final String sock) throws ServerAlreadyBootingException { + ServerSocket socket = null; + String name = socketName(sock); + try { + if (!isWindows) Files.deleteIfExists(Paths.get(sock)); + socket = + isWindows + ? new Win32NamedPipeServerSocket(name, false, Win32SecurityLevel.OWNER_DACL) + : new UnixDomainServerSocket(name); + return socket; + } catch (final IOException e) { + throw new ServerAlreadyBootingException(); + } + } + + private static String socketName(String sock) { + return isWindows ? "\\\\.\\pipe\\" + sock : sock; + } +} diff --git a/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java b/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java new file mode 100644 index 000000000..070dc8b92 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/ServerAlreadyBootingException.java @@ -0,0 +1,10 @@ +/* + * sbt + * Copyright 2011 - 2018, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +public class ServerAlreadyBootingException extends Exception {} diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index 1d52eb8b7..233ca0526 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -34,7 +34,7 @@ import scala.annotation.tailrec import scala.collection.mutable import scala.concurrent.duration._ import scala.util.control.NonFatal -import scala.util.{ Failure, Properties, Success } +import scala.util.{ Failure, Properties, Success, Try } import Serialization.{ CancelAll, attach, @@ -108,14 +108,11 @@ class NetworkClient( } private[this] val stdinBytes = new LinkedBlockingQueue[Int] - private[this] val stdin: InputStream = new InputStream { - override def available(): Int = stdinBytes.size - override def read: Int = stdinBytes.take - } private[this] val inputThread = new AtomicReference(new RawInputThread) private[this] val exitClean = new AtomicBoolean(true) private[this] val sbtProcess = new AtomicReference[Process](null) private class ConnectionRefusedException(t: Throwable) extends Throwable(t) + private class ServerFailedException extends Exception // Open server connection based on the portfile def init(prompt: Boolean, retry: Boolean): ServerConnection = @@ -138,9 +135,23 @@ class NetworkClient( forkServer(portfile, log = true) } } - val (sk, tkn) = - try mkSocket(portfile) - catch { case e: IOException => throw new ConnectionRefusedException(e) } + @tailrec def connect(attempt: Int): (Socket, Option[String]) = { + val res = try Some(mkSocket(portfile)) + catch { + // This catches a pipe busy exception which can happen if two windows clients + // attempt to connect in rapid succession + case e: IOException if e.getMessage.contains("Couldn't open") && attempt < 10 => None + case e: IOException => throw new ConnectionRefusedException(e) + } + res match { + case Some(r) => r + case None => + // Use a random sleep to spread out the competing processes + Thread.sleep(new java.util.Random().nextInt(20).toLong) + connect(attempt + 1) + } + } + val (sk, tkn) = connect(0) val conn = new ServerConnection(sk) { override def onNotification(msg: JsonRpcNotificationMessage): Unit = { msg.method match { @@ -188,57 +199,129 @@ class NetworkClient( * This instance must be shutdown explicitly via `sbt -client shutdown` */ def forkServer(portfile: File, log: Boolean): Unit = { - if (log) console.appendLog(Level.Info, "server was not detected. starting an instance") - val term = Terminal.console - val props = - Seq( - term.getWidth, - term.getHeight, - term.isAnsiSupported, - term.isColorEnabled, - term.isSupershellEnabled - ).mkString(",") - - val cmd = arguments.sbtScript +: arguments.sbtArguments :+ BasicCommandStrings.CloseIOStreams - val processBuilder = - new ProcessBuilder(cmd: _*) - .directory(arguments.baseDirectory) - .redirectInput(Redirect.PIPE) - processBuilder.environment.put(Terminal.TERMINAL_PROPS, props) - val process = processBuilder.start() - sbtProcess.set(process) + val bootSocketName = + BootServerSocket.socketLocation(arguments.baseDirectory.toPath.toRealPath()) + var socket: Option[Socket] = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + val process = socket match { + case None => + val term = Terminal.console + if (log) console.appendLog(Level.Info, "server was not detected. starting an instance") + val props = + Seq( + term.getWidth, + term.getHeight, + term.isAnsiSupported, + term.isColorEnabled, + term.isSupershellEnabled + ).mkString(",") + val cmd = arguments.sbtScript +: arguments.sbtArguments :+ BasicCommandStrings.CloseIOStreams + val processBuilder = + new ProcessBuilder(cmd: _*) + .directory(arguments.baseDirectory) + .redirectInput(Redirect.PIPE) + processBuilder.environment.put(Terminal.TERMINAL_PROPS, props) + val process = processBuilder.start() + sbtProcess.set(process) + Some(process) + case _ => + if (log) console.appendLog(Level.Info, "sbt server is booting up") + None + } val hook = new Thread(() => Option(sbtProcess.get).foreach(_.destroyForcibly())) Runtime.getRuntime.addShutdownHook(hook) - val stdout = process.getInputStream - val stderr = process.getErrorStream - val stdin = process.getOutputStream + val isWin = Properties.isWin + var gotInputBack = false + val readThreadAlive = new AtomicBoolean(true) + /* + * Socket.getInputStream.available doesn't always return a value greater than 0 + * so it is necessary to read the process output from the socket on a background + * thread. + */ + val readThread = new Thread("client-read-thread") { + setDaemon(true) + start() + override def run(): Unit = { + try { + while (readThreadAlive.get) { + socket.foreach { s => + try { + s.getInputStream.read match { + case -1 | 0 => readThreadAlive.set(false) + case 2 => gotInputBack = true + case 3 if gotInputBack => readThreadAlive.set(false) + case i if gotInputBack => stdinBytes.offer(i) + case i => printStream.write(i) + } + } catch { + case e @ (_: IOException | _: InterruptedException) => + readThreadAlive.set(false) + } + } + if (socket.isEmpty && readThreadAlive.get) { + try Thread.sleep(10) + catch { case _: InterruptedException => } + } + } + } catch { case e: IOException => e.printStackTrace(System.err) } + } + } @tailrec def blockUntilStart(): Unit = { + if (socket.isEmpty) { + socket = Try(ClientSocket.localSocket(bootSocketName, useJNI)).toOption + } val stop = try { - while (stdout.available > 0) { - val byte = stdout.read - printStream.write(byte) + socket match { + case None => + process.foreach { p => + val output = p.getInputStream + while (output.available > 0) { + printStream.write(output.read()) + } + } + case Some(s) => + while (!gotInputBack && !stdinBytes.isEmpty && socket.isDefined) { + val out = s.getOutputStream + val b = stdinBytes.poll + // echo stdin during boot + printStream.write(b) + printStream.flush() + out.write(b) + out.flush() + } } - while (stderr.available > 0) { - val byte = stderr.read - errorStream.write(byte) - } - while (!stdinBytes.isEmpty) { - stdin.write(stdinBytes.take) - stdin.flush() + process.foreach { p => + val error = p.getErrorStream + while (error.available > 0) { + errorStream.write(error.read()) + } } false - } catch { - case _: IOException => true - } + } catch { case e: IOException => true } Thread.sleep(10) - if (!portfile.exists && !stop) blockUntilStart() - else { - stdin.close() - stdout.close() - stderr.close() - process.getOutputStream.close() + printStream.flush() + errorStream.flush() + /* + * If an earlier server process is launching, the process launched by this client + * will return with exit value 2. In that case, we can treat the process as alive + * even if it is actually dead. + */ + val existsValidProcess = process.fold(socket.isDefined)(p => p.isAlive || p.exitValue == 2) + if (!portfile.exists && !stop && readThreadAlive.get && existsValidProcess) { + blockUntilStart() + } else { + socket.foreach { s => + s.getInputStream.close() + s.getOutputStream.close() + s.close() + } + readThread.interrupt() + process.foreach { p => + p.getOutputStream.close() + p.getErrorStream.close() + p.getInputStream.close() + } } } @@ -247,6 +330,8 @@ class NetworkClient( sbtProcess.set(null) Util.ignoreResult(Runtime.getRuntime.removeShutdownHook(hook)) } + if (!portfile.exists()) throw new ServerFailedException + if (attached.get && !stdinBytes.isEmpty) Option(inputThread.get).foreach(_.drain()) } /** Called on the response for a returning message. */ @@ -443,10 +528,16 @@ class NetworkClient( } } - def connect(log: Boolean, prompt: Boolean): Unit = { + def connect(log: Boolean, prompt: Boolean): Boolean = { if (log) console.appendLog(Level.Info, "entering *experimental* thin client - BEEP WHIRR") - init(prompt, retry = true) - () + try { + init(prompt, retry = true) + true + } catch { + case _: ServerFailedException => + console.appendLog(Level.Error, "failed to connect to server") + false + } } private[this] val contHandler: () => Unit = () => { @@ -505,7 +596,6 @@ class NetworkClient( } def getCompletions(query: String): Seq[String] = { - connect(log = false, prompt = true) val quoteCount = query.foldLeft(0) { case (count, '"') => count + 1 case (count, _) => count @@ -639,7 +729,10 @@ class NetworkClient( stdinBytes.offer(-1) val mainThread = interactiveThread.getAndSet(null) if (mainThread != null && mainThread != Thread.currentThread) mainThread.interrupt - connection.shutdown() + connectionHolder.get match { + case null => + case c => c.shutdown() + } Option(inputThread.get).foreach(_.interrupt()) } catch { case t: Throwable => t.printStackTrace(); throw t @@ -784,8 +877,8 @@ object NetworkClient { useJNI, ) try { - client.connect(log = true, prompt = false) - client.run() + if (client.connect(log = true, prompt = false)) client.run() + else 1 } catch { case _: Exception => 1 } finally client.close() } private def simpleClient( @@ -857,7 +950,9 @@ object NetworkClient { useJNI = useJNI, ) try { - val results = client.getCompletions(cmd) + val results = + if (client.connect(log = false, prompt = true)) client.getCompletions(cmd) + else Nil out.println(results.sorted.distinct mkString "\n") 0 } catch { case _: Exception => 1 } finally client.close() @@ -867,8 +962,8 @@ object NetworkClient { try { val client = new NetworkClient(configuration, parseArgs(arguments.toArray)) try { - client.connect(log = true, prompt = false) - client.run() + if (client.connect(log = true, prompt = false)) client.run() + else 1 } catch { case _: Throwable => 1 } finally client.close() } catch { case NonFatal(e) => diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index 8756fab7c..51f4ad4eb 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -549,6 +549,7 @@ object Keys { val globalPluginUpdate = taskKey[UpdateReport]("A hook to get the UpdateReport of the global plugin.").withRank(DTask) private[sbt] val taskCancelStrategy = settingKey[State => TaskCancellationStrategy]("Experimental task cancellation handler.").withRank(DTask) private[sbt] val cacheStoreFactoryFactory = AttributeKey[CacheStoreFactoryFactory]("cache-store-factory-factory") + private[sbt] val bootServerSocket = AttributeKey[BootServerSocket]("boot-server-socket") val fileCacheSize = settingKey[String]("The approximate maximum size in bytes of the cache used to store previous task results. For example, it could be set to \"256M\" to make the maximum size 256 megabytes.") // Experimental in sbt 0.13.2 to enable grabbing semantic compile failures. diff --git a/main/src/main/scala/sbt/Main.scala b/main/src/main/scala/sbt/Main.scala index 282d7b01f..a2c85eb6e 100644 --- a/main/src/main/scala/sbt/Main.scala +++ b/main/src/main/scala/sbt/Main.scala @@ -57,7 +57,7 @@ private[sbt] object xMain { override def provider: AppProvider = config.provider() } } - private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = + private[sbt] def run(configuration: xsbti.AppConfiguration): xsbti.MainResult = { try { import BasicCommandStrings.{ DashClient, DashDashClient, runEarly } import BasicCommands.early @@ -65,6 +65,10 @@ private[sbt] object xMain { import sbt.internal.CommandStrings.{ BootCommand, DefaultsCommand, InitCommand } import sbt.internal.client.NetworkClient + val bootServerSocket = getSocketOrExit(configuration) match { + case (_, Some(e)) => return e + case (s, _) => s + } // if we detect -Dsbt.client=true or -client, run thin client. val clientModByEnv = SysProp.client val userCommands = configuration.arguments.map(_.trim) @@ -73,6 +77,7 @@ private[sbt] object xMain { if (userCommands.exists(isBsp)) { BspClient.run(dealiasBaseDirectory(configuration)) } else { + bootServerSocket.foreach(l => Terminal.setBootStreams(l.inputStream, l.outputStream)) Terminal.withStreams { if (clientModByEnv || userCommands.exists(isClient)) { val args = userCommands.toList.filterNot(isClient) @@ -80,20 +85,43 @@ private[sbt] object xMain { Exit(0) } else { val closeStreams = userCommands.exists(_ == BasicCommandStrings.CloseIOStreams) - val state = StandardMain + val state0 = StandardMain .initialState( dealiasBaseDirectory(configuration), Seq(defaults, early), runEarly(DefaultsCommand) :: runEarly(InitCommand) :: BootCommand :: Nil ) .put(BasicKeys.closeIOStreams, closeStreams) - StandardMain.runManaged(state) + val state = bootServerSocket match { + case Some(l) => state0.put(Keys.bootServerSocket, l) + case _ => state0 + } + try StandardMain.runManaged(state) + finally bootServerSocket.foreach(_.close()) } } } } finally { ShutdownHooks.close() } + } + + private def getSocketOrExit( + configuration: xsbti.AppConfiguration + ): (Option[BootServerSocket], Option[Exit]) = + try (Some(new BootServerSocket(configuration)) -> None) + catch { + case _: ServerAlreadyBootingException if System.console != null => + println("sbt server is already booting. Create a new server? y/n (default y)") + val exit = Terminal.get.withRawSystemIn(System.in.read) match { + case 110 => Some(Exit(1)) + case _ => None + } + (None, exit) + case _: ServerAlreadyBootingException => + if (SysProp.forceServerStart) (None, None) + else (None, Some(Exit(2))) + } } final class ScriptMain extends xsbti.AppMain { @@ -805,8 +833,7 @@ object BuiltinCommands { @tailrec private[this] def doLoadFailed(s: State, loadArg: String): State = { s.log.warn("Project loading failed: (r)etry, (q)uit, (l)ast, or (i)gnore? (default: r)") - val terminal = Terminal.get - val result = try terminal.withRawSystemIn(terminal.inputStream.read) match { + val result = try Terminal.get.withRawSystemIn(System.in.read) match { case -1 => 'q'.toInt case b => b } catch { case _: ClosedChannelException => 'q' } diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 3be85d4e8..a4e048edc 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -248,9 +248,11 @@ private[sbt] final class CommandExchange { server = None firstInstance.set(false) } + Terminal.setBootStreams(null, null) if (s.get(BasicKeys.closeIOStreams).getOrElse(false)) Terminal.close() + s.get(Keys.bootServerSocket).foreach(_.close()) } - s + s.remove(Keys.bootServerSocket) } def shutdown(): Unit = { diff --git a/main/src/main/scala/sbt/internal/SysProp.scala b/main/src/main/scala/sbt/internal/SysProp.scala index 473487609..fe2fd2a7f 100644 --- a/main/src/main/scala/sbt/internal/SysProp.scala +++ b/main/src/main/scala/sbt/internal/SysProp.scala @@ -79,6 +79,7 @@ object SysProp { def allowRootDir: Boolean = getOrFalse("sbt.rootdir") def legacyTestReport: Boolean = getOrFalse("sbt.testing.legacyreport") def semanticdb: Boolean = getOrFalse("sbt.semanticdb") + def forceServerStart: Boolean = getOrFalse("sbt.server.forcestart") def watchMode: String = sys.props.get("sbt.watch.mode").getOrElse("auto") diff --git a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala index d5c498eee..165148f2a 100644 --- a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala +++ b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala @@ -35,13 +35,13 @@ object ClientSocket { t.token } val sk = uri.getScheme match { - case "local" if isWindows => - (new Win32NamedPipeSocket("""\\.\pipe\""" + uri.getSchemeSpecificPart, useJNI): Socket) - case "local" => - (new UnixDomainSocket(uri.getSchemeSpecificPart, useJNI): Socket) - case "tcp" => new Socket(InetAddress.getByName(uri.getHost), uri.getPort) - case _ => sys.error(s"Unsupported uri: $uri") + case "local" => localSocket(uri.getSchemeSpecificPart, useJNI) + case "tcp" => new Socket(InetAddress.getByName(uri.getHost), uri.getPort) + case _ => sys.error(s"Unsupported uri: $uri") } (sk, token) } + def localSocket(name: String, useJNI: Boolean): Socket = + if (isWindows) new Win32NamedPipeSocket(s"\\\\.\\pipe\\$name", useJNI) + else new UnixDomainSocket(name, useJNI) } diff --git a/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala b/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala index be1c67a56..41044871f 100644 --- a/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala +++ b/scripted-sbt-redux/src/main/scala/sbt/scriptedtest/RemoteSbtCreator.scala @@ -32,8 +32,9 @@ final class LauncherBasedRemoteSbtCreator( def newRemote(server: IPC.Server) = { val launcherJar = launcher.getAbsolutePath val globalBase = "-Dsbt.global.base=" + (new File(directory, "global")).getAbsolutePath + val scripted = "-Dsbt.scripted=true" val args = List("<" + server.port) - val cmd = "java" :: launchOpts.toList ::: globalBase :: "-jar" :: launcherJar :: args ::: Nil + val cmd = "java" :: launchOpts.toList ::: globalBase :: scripted :: "-jar" :: launcherJar :: args ::: Nil val io = BasicIO(false, log).withInput(_.close()) val p = Process(cmd, directory) run (io) val thread = new Thread() { override def run() = { p.exitValue(); server.close() } } @@ -52,11 +53,12 @@ final class RunFromSourceBasedRemoteSbtCreator( ) extends RemoteSbtCreator { def newRemote(server: IPC.Server): Process = { val globalBase = "-Dsbt.global.base=" + new File(directory, "global").getAbsolutePath + val scripted = "-Dsbt.scripted=true" val mainClassName = "sbt.RunFromSourceMain" val cpString = classpath.mkString(java.io.File.pathSeparator) val args = List(mainClassName, directory.toString, scalaVersion, sbtVersion, cpString, "<" + server.port) - val cmd = "java" :: launchOpts.toList ::: globalBase :: "-cp" :: cpString :: args ::: Nil + val cmd = "java" :: launchOpts.toList ::: globalBase :: scripted :: "-cp" :: cpString :: args ::: Nil val io = BasicIO(false, log).withInput(_.close()) val p = Process(cmd, directory) run (io) val thread = new Thread() { override def run() = { p.exitValue(); server.close() } }