From 2b2c1f05684d05f589bafd5b0f0107e93e308c8a Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Tue, 5 Dec 2017 06:32:50 -0500 Subject: [PATCH 1/2] prevent "shutdown" when server didn't come up --- main/src/main/scala/sbt/internal/CommandExchange.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 476ea5f76..32175d256 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -152,11 +152,13 @@ private[sbt] final class CommandExchange { Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => + // rememeber to shutdown only when the server comes up + server = Some(x) case Some(Failure(e)) => s.log.error(e.toString) + server = None case None => // this won't happen because we awaited } - server = Some(x) } s } From 322f9b31cdea4daab1d8029b2622b50069fe7555 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Tue, 5 Dec 2017 08:07:20 -0500 Subject: [PATCH 2/2] Only the first session starts the server Currently the server will try to start even if there are existing sbt sessions. This causes the second session to take over the server for Unix domain socket. This adds a check before the server comes up and make sure that the socket is not taken. --- .../java/sbt/internal/NGUnixDomainSocket.java | 21 ++++++++++ .../internal/NGUnixDomainSocketLibrary.java | 2 + .../scala/sbt/internal/server/Server.scala | 40 ++++++++++++++----- 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java index 1a9942ad9..b70bef611 100644 --- a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java @@ -29,6 +29,8 @@ import java.nio.ByteBuffer; import java.net.Socket; +import java.util.concurrent.atomic.AtomicInteger; + /** * Implements a {@link Socket} backed by a native Unix domain socket. * @@ -41,6 +43,25 @@ public class NGUnixDomainSocket extends Socket { private final InputStream is; private final OutputStream os; + public NGUnixDomainSocket(String path) throws IOException { + try { + AtomicInteger fd = new AtomicInteger( + NGUnixDomainSocketLibrary.socket( + NGUnixDomainSocketLibrary.PF_LOCAL, + NGUnixDomainSocketLibrary.SOCK_STREAM, + 0)); + NGUnixDomainSocketLibrary.SockaddrUn address = + new NGUnixDomainSocketLibrary.SockaddrUn(path); + int socketFd = fd.get(); + NGUnixDomainSocketLibrary.connect(socketFd, address, address.size()); + this.fd = new ReferenceCountedFileDescriptor(socketFd); + this.is = new NGUnixDomainSocketInputStream(); + this.os = new NGUnixDomainSocketOutputStream(); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + /** * Creates a Unix domain socket backed by a native file descriptor. */ diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java index 7e760d37a..4d781b6b6 100644 --- a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java @@ -131,6 +131,8 @@ public class NGUnixDomainSocketLibrary { public static native int listen(int fd, int backlog) throws LastErrorException; public static native int accept(int fd, SockaddrUn address, IntByReference addressLen) throws LastErrorException; + public static native int connect(int fd, SockaddrUn address, int addressLen) + throws LastErrorException; public static native int read(int fd, ByteBuffer buffer, int count) throws LastErrorException; public static native int write(int fd, ByteBuffer buffer, int count) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index c4d3b542c..0b8ee4b32 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -9,7 +9,7 @@ package sbt package internal package server -import java.io.File +import java.io.{ File, IOException } import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } import java.util.concurrent.atomic.AtomicBoolean import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } @@ -54,15 +54,17 @@ private[sbt] object Server { val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { Try { - ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { - connection.connectionType match { - case ConnectionType.Local if isWindows => - new NGWin32NamedPipeServerSocket(pipeName) - case ConnectionType.Local => - prepareSocketfile() - new NGUnixDomainServerSocket(socketfile.getAbsolutePath) - case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host)) - } + connection.connectionType match { + case ConnectionType.Local if isWindows => + // Named pipe already has an exclusive lock. + addServerError(new NGWin32NamedPipeServerSocket(pipeName)) + case ConnectionType.Local => + tryClient(new NGUnixDomainSocket(socketfile.getAbsolutePath)) + prepareSocketfile() + addServerError(new NGUnixDomainServerSocket(socketfile.getAbsolutePath)) + case ConnectionType.Tcp => + tryClient(new Socket(InetAddress.getByName(host), port)) + addServerError(new ServerSocket(port, 50, InetAddress.getByName(host))) } } match { case Failure(e) => p.failure(e) @@ -87,6 +89,24 @@ private[sbt] object Server { } serverThread.start() + // Try the socket as a client to make sure that the server is not already up. + // f tries to connect to the server, and flip the result. + def tryClient(f: => Socket): Unit = { + if (portfile.exists) { + Try { f } match { + case Failure(e) => () + case Success(socket) => + socket.close() + throw new IOException("sbt server is already running.") + } + } else () + } + + def addServerError(f: => ServerSocket): ServerSocket = + ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { + f + } + override def authenticate(challenge: String): Boolean = synchronized { if (token == challenge) { token = nextToken