diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java index 1a9942ad9..b70bef611 100644 --- a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java @@ -29,6 +29,8 @@ import java.nio.ByteBuffer; import java.net.Socket; +import java.util.concurrent.atomic.AtomicInteger; + /** * Implements a {@link Socket} backed by a native Unix domain socket. * @@ -41,6 +43,25 @@ public class NGUnixDomainSocket extends Socket { private final InputStream is; private final OutputStream os; + public NGUnixDomainSocket(String path) throws IOException { + try { + AtomicInteger fd = new AtomicInteger( + NGUnixDomainSocketLibrary.socket( + NGUnixDomainSocketLibrary.PF_LOCAL, + NGUnixDomainSocketLibrary.SOCK_STREAM, + 0)); + NGUnixDomainSocketLibrary.SockaddrUn address = + new NGUnixDomainSocketLibrary.SockaddrUn(path); + int socketFd = fd.get(); + NGUnixDomainSocketLibrary.connect(socketFd, address, address.size()); + this.fd = new ReferenceCountedFileDescriptor(socketFd); + this.is = new NGUnixDomainSocketInputStream(); + this.os = new NGUnixDomainSocketOutputStream(); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + /** * Creates a Unix domain socket backed by a native file descriptor. */ diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java index 7e760d37a..4d781b6b6 100644 --- a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java @@ -131,6 +131,8 @@ public class NGUnixDomainSocketLibrary { public static native int listen(int fd, int backlog) throws LastErrorException; public static native int accept(int fd, SockaddrUn address, IntByReference addressLen) throws LastErrorException; + public static native int connect(int fd, SockaddrUn address, int addressLen) + throws LastErrorException; public static native int read(int fd, ByteBuffer buffer, int count) throws LastErrorException; public static native int write(int fd, ByteBuffer buffer, int count) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index c4d3b542c..0b8ee4b32 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -9,7 +9,7 @@ package sbt package internal package server -import java.io.File +import java.io.{ File, IOException } import java.net.{ SocketTimeoutException, InetAddress, ServerSocket, Socket } import java.util.concurrent.atomic.AtomicBoolean import java.nio.file.attribute.{ UserPrincipal, AclEntry, AclEntryPermission, AclEntryType } @@ -54,15 +54,17 @@ private[sbt] object Server { val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { Try { - ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { - connection.connectionType match { - case ConnectionType.Local if isWindows => - new NGWin32NamedPipeServerSocket(pipeName) - case ConnectionType.Local => - prepareSocketfile() - new NGUnixDomainServerSocket(socketfile.getAbsolutePath) - case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host)) - } + connection.connectionType match { + case ConnectionType.Local if isWindows => + // Named pipe already has an exclusive lock. + addServerError(new NGWin32NamedPipeServerSocket(pipeName)) + case ConnectionType.Local => + tryClient(new NGUnixDomainSocket(socketfile.getAbsolutePath)) + prepareSocketfile() + addServerError(new NGUnixDomainServerSocket(socketfile.getAbsolutePath)) + case ConnectionType.Tcp => + tryClient(new Socket(InetAddress.getByName(host), port)) + addServerError(new ServerSocket(port, 50, InetAddress.getByName(host))) } } match { case Failure(e) => p.failure(e) @@ -87,6 +89,24 @@ private[sbt] object Server { } serverThread.start() + // Try the socket as a client to make sure that the server is not already up. + // f tries to connect to the server, and flip the result. + def tryClient(f: => Socket): Unit = { + if (portfile.exists) { + Try { f } match { + case Failure(e) => () + case Success(socket) => + socket.close() + throw new IOException("sbt server is already running.") + } + } else () + } + + def addServerError(f: => ServerSocket): ServerSocket = + ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { + f + } + override def authenticate(challenge: String): Boolean = synchronized { if (token == challenge) { token = nextToken diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 476ea5f76..32175d256 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -152,11 +152,13 @@ private[sbt] final class CommandExchange { Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => + // rememeber to shutdown only when the server comes up + server = Some(x) case Some(Failure(e)) => s.log.error(e.toString) + server = None case None => // this won't happen because we awaited } - server = Some(x) } s }