diff --git a/build.sbt b/build.sbt index c5b2c6cfc..61842c40c 100644 --- a/build.sbt +++ b/build.sbt @@ -55,7 +55,7 @@ def commonSettings: Seq[Setting[_]] = concurrentRestrictions in Global += Util.testExclusiveRestriction, testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-w", "1"), testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-verbosity", "2"), - javacOptions in compile ++= Seq("-target", "6", "-source", "6", "-Xlint", "-Xlint:-serial"), + javacOptions in compile ++= Seq("-Xlint", "-Xlint:-serial"), crossScalaVersions := Seq(baseScalaVersion), bintrayPackage := (bintrayPackage in ThisBuild).value, bintrayRepository := (bintrayRepository in ThisBuild).value, @@ -309,7 +309,8 @@ lazy val commandProj = (project in file("main-command")) .settings( testedBaseSettings, name := "Command", - libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi), + libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi, + jna, jnaPlatform), managedSourceDirectories in Compile += baseDirectory.value / "src" / "main" / "contraband-scala", sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", @@ -324,7 +325,11 @@ lazy val commandProj = (project in file("main-command")) exclude[ReversedMissingMethodProblem]("sbt.internal.CommandChannel.*"), // Added an overload to reboot. The overload is private[sbt]. exclude[ReversedMissingMethodProblem]("sbt.StateOps.reboot"), - ) + ), + unmanagedSources in (Compile, headerCreate) := { + val old = (unmanagedSources in (Compile, headerCreate)).value + old filterNot { x => (x.getName startsWith "NG") || (x.getName == "ReferenceCountedFileDescriptor.java") } + }, ) .configure( addSbtIO, diff --git a/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala b/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala new file mode 100644 index 000000000..4562f75b9 --- /dev/null +++ b/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala @@ -0,0 +1,28 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait ConnectionTypeFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val ConnectionTypeFormat: JsonFormat[sbt.ConnectionType] = new JsonFormat[sbt.ConnectionType] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ConnectionType = { + jsOpt match { + case Some(js) => + unbuilder.readString(js) match { + case "Local" => sbt.ConnectionType.Local + case "Tcp" => sbt.ConnectionType.Tcp + } + case None => + deserializationError("Expected JsString but found None") + } + } + override def write[J](obj: sbt.ConnectionType, builder: Builder[J]): Unit = { + val str = obj match { + case sbt.ConnectionType.Local => "Local" + case sbt.ConnectionType.Tcp => "Tcp" + } + builder.writeString(str) + } +} +} diff --git a/main-command/src/main/contraband-scala/sbt/ConnectionType.scala b/main-command/src/main/contraband-scala/sbt/ConnectionType.scala new file mode 100644 index 000000000..af50ed2e9 --- /dev/null +++ b/main-command/src/main/contraband-scala/sbt/ConnectionType.scala @@ -0,0 +1,13 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt +sealed abstract class ConnectionType extends Serializable +object ConnectionType { + + /** This uses Unix domain socket on POSIX, and named pipe on Windows. */ + case object Local extends ConnectionType + case object Tcp extends ConnectionType +} diff --git a/main-command/src/main/contraband/state.contra b/main-command/src/main/contraband/state.contra index 79d0bcaab..2737ce8ab 100644 --- a/main-command/src/main/contraband/state.contra +++ b/main-command/src/main/contraband/state.contra @@ -16,3 +16,10 @@ type CommandSource { enum ServerAuthentication { Token } + +enum ConnectionType { + ## This uses Unix domain socket on POSIX, and named pipe on Windows. + Local + Tcp + # Ssh +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java new file mode 100644 index 000000000..89d3bcf43 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java @@ -0,0 +1,178 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainServerSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicInteger; + +import com.sun.jna.LastErrorException; +import com.sun.jna.ptr.IntByReference; + +/** + * Implements a {@link ServerSocket} which binds to a local Unix domain socket + * and returns instances of {@link NGUnixDomainSocket} from + * {@link #accept()}. + */ +public class NGUnixDomainServerSocket extends ServerSocket { + private static final int DEFAULT_BACKLOG = 50; + + // We use an AtomicInteger to prevent a race in this situation which + // could happen if fd were just an int: + // + // Thread 1 -> NGUnixDomainServerSocket.accept() + // -> lock this + // -> check isBound and isClosed + // -> unlock this + // -> descheduled while still in method + // Thread 2 -> NGUnixDomainServerSocket.close() + // -> lock this + // -> check isClosed + // -> NGUnixDomainSocketLibrary.close(fd) + // -> now fd is invalid + // -> unlock this + // Thread 1 -> re-scheduled while still in method + // -> NGUnixDomainSocketLibrary.accept(fd, which is invalid and maybe re-used) + // + // By using an AtomicInteger, we'll set this to -1 after it's closed, which + // will cause the accept() call above to cleanly fail instead of possibly + // being called on an unrelated fd (which may or may not fail). + private final AtomicInteger fd; + + private final int backlog; + private boolean isBound; + private boolean isClosed; + + public static class NGUnixDomainServerSocketAddress extends SocketAddress { + private final String path; + + public NGUnixDomainServerSocketAddress(String path) { + this.path = path; + } + + public String getPath() { + return path; + } + } + + /** + * Constructs an unbound Unix domain server socket. + */ + public NGUnixDomainServerSocket() throws IOException { + this(DEFAULT_BACKLOG, null); + } + + /** + * Constructs an unbound Unix domain server socket with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog) throws IOException { + this(backlog, null); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path. + */ + public NGUnixDomainServerSocket(String path) throws IOException { + this(DEFAULT_BACKLOG, path); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path + * with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog, String path) throws IOException { + try { + fd = new AtomicInteger( + NGUnixDomainSocketLibrary.socket( + NGUnixDomainSocketLibrary.PF_LOCAL, + NGUnixDomainSocketLibrary.SOCK_STREAM, + 0)); + this.backlog = backlog; + if (path != null) { + bind(new NGUnixDomainServerSocketAddress(path)); + } + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void bind(SocketAddress endpoint) throws IOException { + if (!(endpoint instanceof NGUnixDomainServerSocketAddress)) { + throw new IllegalArgumentException( + "endpoint must be an instance of NGUnixDomainServerSocketAddress"); + } + if (isBound) { + throw new IllegalStateException("Socket is already bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + NGUnixDomainServerSocketAddress unEndpoint = (NGUnixDomainServerSocketAddress) endpoint; + NGUnixDomainSocketLibrary.SockaddrUn address = + new NGUnixDomainSocketLibrary.SockaddrUn(unEndpoint.getPath()); + try { + int socketFd = fd.get(); + NGUnixDomainSocketLibrary.bind(socketFd, address, address.size()); + NGUnixDomainSocketLibrary.listen(socketFd, backlog); + isBound = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public Socket accept() throws IOException { + // We explicitly do not make this method synchronized, since the + // call to NGUnixDomainSocketLibrary.accept() will block + // indefinitely, causing another thread's call to close() to deadlock. + synchronized (this) { + if (!isBound) { + throw new IllegalStateException("Socket is not bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + } + try { + NGUnixDomainSocketLibrary.SockaddrUn sockaddrUn = + new NGUnixDomainSocketLibrary.SockaddrUn(); + IntByReference addressLen = new IntByReference(); + addressLen.setValue(sockaddrUn.size()); + int clientFd = NGUnixDomainSocketLibrary.accept(fd.get(), sockaddrUn, addressLen); + return new NGUnixDomainSocket(clientFd); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void close() throws IOException { + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + try { + // Ensure any pending call to accept() fails. + NGUnixDomainSocketLibrary.close(fd.getAndSet(-1)); + isClosed = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java new file mode 100644 index 000000000..1a9942ad9 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java @@ -0,0 +1,171 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import java.nio.ByteBuffer; + +import java.net.Socket; + +/** + * Implements a {@link Socket} backed by a native Unix domain socket. + * + * Instances of this class always return {@code null} for + * {@link Socket#getInetAddress()}, {@link Socket#getLocalAddress()}, + * {@link Socket#getLocalSocketAddress()}, {@link Socket#getRemoteSocketAddress()}. + */ +public class NGUnixDomainSocket extends Socket { + private final ReferenceCountedFileDescriptor fd; + private final InputStream is; + private final OutputStream os; + + /** + * Creates a Unix domain socket backed by a native file descriptor. + */ + public NGUnixDomainSocket(int fd) { + this.fd = new ReferenceCountedFileDescriptor(fd); + this.is = new NGUnixDomainSocketInputStream(); + this.os = new NGUnixDomainSocketOutputStream(); + } + + public InputStream getInputStream() { + return is; + } + + public OutputStream getOutputStream() { + return os; + } + + public void shutdownInput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_RD); + } + + public void shutdownOutput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_WR); + } + + private void doShutdown(int how) throws IOException { + try { + int socketFd = fd.acquire(); + if (socketFd != -1) { + NGUnixDomainSocketLibrary.shutdown(socketFd, how); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + + public void close() throws IOException { + super.close(); + try { + // This might not close the FD right away. In case we are about + // to read or write on another thread, it will delay the close + // until the read or write completes, to prevent the FD from + // being re-used for a different purpose and the other thread + // reading from a different FD. + fd.close(); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + private class NGUnixDomainSocketInputStream extends InputStream { + public int read() throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + int result; + if (doRead(buf) == 0) { + result = -1; + } else { + // Make sure to & with 0xFF to avoid sign extension + result = 0xFF & buf.get(); + } + return result; + } + + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + int result = doRead(buf); + if (result == 0) { + result = -1; + } + return result; + } + + private int doRead(ByteBuffer buf) throws IOException { + try { + int fdToRead = fd.acquire(); + if (fdToRead == -1) { + return -1; + } + return NGUnixDomainSocketLibrary.read(fdToRead, buf, buf.remaining()); + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } + + private class NGUnixDomainSocketOutputStream extends OutputStream { + + public void write(int b) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + buf.put(0, (byte) (0xFF & b)); + doWrite(buf); + } + + public void write(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + doWrite(buf); + } + + private void doWrite(ByteBuffer buf) throws IOException { + try { + int fdToWrite = fd.acquire(); + if (fdToWrite == -1) { + return; + } + int ret = NGUnixDomainSocketLibrary.write(fdToWrite, buf, buf.remaining()); + if (ret != buf.remaining()) { + // This shouldn't happen with standard blocking Unix domain sockets. + throw new IOException("Could not write " + buf.remaining() + " bytes as requested " + + "(wrote " + ret + " bytes instead)"); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java new file mode 100644 index 000000000..7e760d37a --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java @@ -0,0 +1,140 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocketLibrary.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; +import com.sun.jna.Native; +import com.sun.jna.Platform; +import com.sun.jna.Structure; +import com.sun.jna.Union; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; + +/** + * Utility class to bridge native Unix domain socket calls to Java using JNA. + */ +public class NGUnixDomainSocketLibrary { + public static final int PF_LOCAL = 1; + public static final int AF_LOCAL = 1; + public static final int SOCK_STREAM = 1; + + public static final int SHUT_RD = 0; + public static final int SHUT_WR = 1; + + // Utility class, do not instantiate. + private NGUnixDomainSocketLibrary() { } + + // BSD platforms write a length byte at the start of struct sockaddr_un. + private static final boolean HAS_SUN_LEN = + Platform.isMac() || Platform.isFreeBSD() || Platform.isNetBSD() || + Platform.isOpenBSD() || Platform.iskFreeBSD(); + + /** + * Bridges {@code struct sockaddr_un} to and from native code. + */ + public static class SockaddrUn extends Structure implements Structure.ByReference { + /** + * On BSD platforms, the {@code sun_len} and {@code sun_family} values in + * {@code struct sockaddr_un}. + */ + public static class SunLenAndFamily extends Structure { + public byte sunLen; + public byte sunFamily; + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunLen", "sunFamily" }); + } + } + + /** + * On BSD platforms, {@code sunLenAndFamily} will be present. + * On other platforms, only {@code sunFamily} will be present. + */ + public static class SunFamily extends Union { + public SunLenAndFamily sunLenAndFamily; + public short sunFamily; + } + + public SunFamily sunFamily = new SunFamily(); + public byte[] sunPath = new byte[104]; + + /** + * Constructs an empty {@code struct sockaddr_un}. + */ + public SockaddrUn() { + if (HAS_SUN_LEN) { + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + /** + * Constructs a {@code struct sockaddr_un} with a path whose bytes are encoded + * using the default encoding of the platform. + */ + public SockaddrUn(String path) throws IOException { + byte[] pathBytes = path.getBytes(); + if (pathBytes.length > sunPath.length - 1) { + throw new IOException("Cannot fit name [" + path + "] in maximum unix domain socket length"); + } + System.arraycopy(pathBytes, 0, sunPath, 0, pathBytes.length); + sunPath[pathBytes.length] = (byte) 0; + if (HAS_SUN_LEN) { + int len = fieldOffset("sunPath") + pathBytes.length; + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.sunLenAndFamily.sunLen = (byte) len; + sunFamily.sunLenAndFamily.sunFamily = AF_LOCAL; + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.sunFamily = AF_LOCAL; + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunFamily", "sunPath" }); + } + } + + static { + Native.register(Platform.C_LIBRARY_NAME); + } + + public static native int socket(int domain, int type, int protocol) throws LastErrorException; + public static native int bind(int fd, SockaddrUn address, int addressLen) + throws LastErrorException; + public static native int listen(int fd, int backlog) throws LastErrorException; + public static native int accept(int fd, SockaddrUn address, IntByReference addressLen) + throws LastErrorException; + public static native int read(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int write(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int close(int fd) throws LastErrorException; + public static native int shutdown(int fd, int how) throws LastErrorException; +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java new file mode 100644 index 000000000..ba535691f --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java @@ -0,0 +1,90 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeLibrary.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.nio.ByteBuffer; + +import com.sun.jna.*; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.*; +import com.sun.jna.platform.win32.WinBase.*; +import com.sun.jna.ptr.IntByReference; + +import com.sun.jna.win32.W32APIOptions; + +public interface NGWin32NamedPipeLibrary extends WinNT { + int PIPE_ACCESS_DUPLEX = 3; + int PIPE_UNLIMITED_INSTANCES = 255; + int FILE_FLAG_FIRST_PIPE_INSTANCE = 524288; + + NGWin32NamedPipeLibrary INSTANCE = + (NGWin32NamedPipeLibrary) Native.loadLibrary( + "kernel32", + NGWin32NamedPipeLibrary.class, + W32APIOptions.UNICODE_OPTIONS); + + HANDLE CreateNamedPipe( + String lpName, + int dwOpenMode, + int dwPipeMode, + int nMaxInstances, + int nOutBufferSize, + int nInBufferSize, + int nDefaultTimeOut, + SECURITY_ATTRIBUTES lpSecurityAttributes); + boolean ConnectNamedPipe( + HANDLE hNamedPipe, + Pointer lpOverlapped); + boolean DisconnectNamedPipe( + HANDLE hObject); + boolean ReadFile( + HANDLE hFile, + Memory lpBuffer, + int nNumberOfBytesToRead, + IntByReference lpNumberOfBytesRead, + Pointer lpOverlapped); + boolean WriteFile( + HANDLE hFile, + ByteBuffer lpBuffer, + int nNumberOfBytesToWrite, + IntByReference lpNumberOfBytesWritten, + Pointer lpOverlapped); + boolean CloseHandle( + HANDLE hObject); + boolean GetOverlappedResult( + HANDLE hFile, + Pointer lpOverlapped, + IntByReference lpNumberOfBytesTransferred, + boolean wait); + boolean CancelIoEx( + HANDLE hObject, + Pointer lpOverlapped); + HANDLE CreateEvent( + SECURITY_ATTRIBUTES lpEventAttributes, + boolean bManualReset, + boolean bInitialState, + String lpName); + int WaitForSingleObject( + HANDLE hHandle, + int dwMilliseconds + ); + + int GetLastError(); +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java new file mode 100644 index 000000000..137d9b5dc --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java @@ -0,0 +1,173 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeServerSocket.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; + +public class NGWin32NamedPipeServerSocket extends ServerSocket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private static final String WIN32_PIPE_PREFIX = "\\\\.\\pipe\\"; + private static final int BUFFER_SIZE = 65535; + private final LinkedBlockingQueue openHandles; + private final LinkedBlockingQueue connectedHandles; + private final NGWin32NamedPipeSocket.CloseCallback closeCallback; + private final String path; + private final int maxInstances; + private final HANDLE lockHandle; + + public NGWin32NamedPipeServerSocket(String path) throws IOException { + this(NGWin32NamedPipeLibrary.PIPE_UNLIMITED_INSTANCES, path); + } + + public NGWin32NamedPipeServerSocket(int maxInstances, String path) throws IOException { + this.openHandles = new LinkedBlockingQueue<>(); + this.connectedHandles = new LinkedBlockingQueue<>(); + this.closeCallback = handle -> { + if (connectedHandles.remove(handle)) { + closeConnectedPipe(handle, false); + } + if (openHandles.remove(handle)) { + closeOpenPipe(handle); + } + }; + this.maxInstances = maxInstances; + if (!path.startsWith(WIN32_PIPE_PREFIX)) { + this.path = WIN32_PIPE_PREFIX + path; + } else { + this.path = path; + } + String lockPath = this.path + "_lock"; + lockHandle = API.CreateNamedPipe( + lockPath, + NGWin32NamedPipeLibrary.FILE_FLAG_FIRST_PIPE_INSTANCE | NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX, + 0, + 1, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (lockHandle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create lock for %s, error %d", lockPath, API.GetLastError())); + } else { + if (!API.DisconnectNamedPipe(lockHandle)) { + throw new IOException(String.format("Could not disconnect lock %d", API.GetLastError())); + } + } + + } + + public void bind(SocketAddress endpoint) throws IOException { + throw new IOException("Win32 named pipes do not support bind(), pass path to constructor"); + } + + public Socket accept() throws IOException { + HANDLE handle = API.CreateNamedPipe( + path, + NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX | WinNT.FILE_FLAG_OVERLAPPED, + 0, + maxInstances, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (handle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create named pipe, error %d", API.GetLastError())); + } + openHandles.add(handle); + + HANDLE connWaitable = API.CreateEvent(null, true, false, null); + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = connWaitable; + olap.write(); + + boolean immediate = API.ConnectNamedPipe(handle, olap.getPointer()); + if (immediate) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } + + int connectError = API.GetLastError(); + if (connectError == WinError.ERROR_PIPE_CONNECTED) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_NO_DATA) { + // Client has connected and disconnected between CreateNamedPipe() and ConnectNamedPipe() + // connection is broken, but it is returned it avoid loop here. + // Actual error will happen for NGSession when it will try to read/write from/to pipe + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_IO_PENDING) { + if (!API.GetOverlappedResult(handle, olap.getPointer(), new IntByReference(), true)) { + openHandles.remove(handle); + closeOpenPipe(handle); + throw new IOException("GetOverlappedResult() failed for connect operation: " + API.GetLastError()); + } + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else { + throw new IOException("ConnectNamedPipe() failed with: " + connectError); + } + } + + public void close() throws IOException { + try { + List handlesToClose = new ArrayList<>(); + openHandles.drainTo(handlesToClose); + for (HANDLE handle : handlesToClose) { + closeOpenPipe(handle); + } + + List handlesToDisconnect = new ArrayList<>(); + connectedHandles.drainTo(handlesToDisconnect); + for (HANDLE handle : handlesToDisconnect) { + closeConnectedPipe(handle, true); + } + } finally { + API.CloseHandle(lockHandle); + } + } + + private void closeOpenPipe(HANDLE handle) throws IOException { + API.CancelIoEx(handle, null); + API.CloseHandle(handle); + } + + private void closeConnectedPipe(HANDLE handle, boolean shutdown) throws IOException { + if (!shutdown) { + API.WaitForSingleObject(handle, 10000); + } + API.DisconnectNamedPipe(handle); + API.CloseHandle(handle); + } +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java new file mode 100644 index 000000000..b22bb6bbf --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java @@ -0,0 +1,172 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeSocket.java +// Made change in `read` to read just the amount of bytes available. + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.Memory; +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.ByteBuffer; + +public class NGWin32NamedPipeSocket extends Socket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private final HANDLE handle; + private final CloseCallback closeCallback; + private final InputStream is; + private final OutputStream os; + private final HANDLE readerWaitable; + private final HANDLE writerWaitable; + + interface CloseCallback { + void onNamedPipeSocketClose(HANDLE handle) throws IOException; + } + + public NGWin32NamedPipeSocket( + HANDLE handle, + NGWin32NamedPipeSocket.CloseCallback closeCallback) throws IOException { + this.handle = handle; + this.closeCallback = closeCallback; + this.readerWaitable = API.CreateEvent(null, true, false, null); + if (readerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + writerWaitable = API.CreateEvent(null, true, false, null); + if (writerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + this.is = new NGWin32NamedPipeSocketInputStream(handle); + this.os = new NGWin32NamedPipeSocketOutputStream(handle); + } + + @Override + public InputStream getInputStream() { + return is; + } + + @Override + public OutputStream getOutputStream() { + return os; + } + + @Override + public void close() throws IOException { + closeCallback.onNamedPipeSocketClose(handle); + } + + @Override + public void shutdownInput() throws IOException { + } + + @Override + public void shutdownOutput() throws IOException { + } + + private class NGWin32NamedPipeSocketInputStream extends InputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketInputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public int read() throws IOException { + int result; + byte[] b = new byte[1]; + if (read(b) == 0) { + result = -1; + } else { + result = 0xFF & b[0]; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Memory readBuffer = new Memory(len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = readerWaitable; + olap.write(); + + boolean immediate = API.ReadFile(handle, readBuffer, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("ReadFile() failed: " + lastError); + } + } + + IntByReference read = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), read, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for read operation: " + lastError); + } + int actualLen = read.getValue(); + byte[] byteArray = readBuffer.getByteArray(0, actualLen); + System.arraycopy(byteArray, 0, b, off, actualLen); + return actualLen; + } + } + + private class NGWin32NamedPipeSocketOutputStream extends OutputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketOutputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public void write(int b) throws IOException { + write(new byte[]{(byte) (0xFF & b)}); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + ByteBuffer data = ByteBuffer.wrap(b, off, len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = writerWaitable; + olap.write(); + + boolean immediate = API.WriteFile(handle, data, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("WriteFile() failed: " + lastError); + } + } + IntByReference written = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), written, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for write operation: " + lastError); + } + if (written.getValue() != len) { + throw new IOException("WriteFile() wrote less bytes than requested"); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java new file mode 100644 index 000000000..7fb5d9d53 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java @@ -0,0 +1,82 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/ReferenceCountedFileDescriptor.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; + +/** + * Encapsulates a file descriptor plus a reference count to ensure close requests + * only close the file descriptor once the last reference to the file descriptor + * is released. + * + * If not explicitly closed, the file descriptor will be closed when + * this object is finalized. + */ +public class ReferenceCountedFileDescriptor { + private int fd; + private int fdRefCount; + private boolean closePending; + + public ReferenceCountedFileDescriptor(int fd) { + this.fd = fd; + this.fdRefCount = 0; + this.closePending = false; + } + + protected void finalize() throws IOException { + close(); + } + + public synchronized int acquire() { + fdRefCount++; + return fd; + } + + public synchronized void release() throws IOException { + fdRefCount--; + if (fdRefCount == 0 && closePending && fd != -1) { + doClose(); + } + } + + public synchronized void close() throws IOException { + if (fd == -1 || closePending) { + return; + } + + if (fdRefCount == 0) { + doClose(); + } else { + // Another thread has the FD. We'll close it when they release the reference. + closePending = true; + } + } + + private void doClose() throws IOException { + try { + NGUnixDomainSocketLibrary.close(fd); + fd = -1; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +} diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index ef6bb00cd..42b68daa4 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -33,6 +33,11 @@ object BasicKeys { "Method of authenticating server command.", 10000) + val serverConnectionType = + AttributeKey[ConnectionType]("serverConnectionType", + "The wire protocol for the server command.", + 10000) + private[sbt] val interactive = AttributeKey[Boolean]( "interactive", "True if commands are currently being entered from an interactive environment.", diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 3a7c2785b..c4d3b542c 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -17,13 +17,14 @@ import java.security.SecureRandom import java.math.BigInteger import scala.concurrent.{ Future, Promise } import scala.util.{ Try, Success, Failure } -import sbt.internal.util.ErrorHandling import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.util.Logger import sbt.io.IO import sbt.io.syntax._ import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } import sbt.internal.protocol.codec._ +import sbt.internal.util.ErrorHandling +import sbt.internal.util.Util.isWindows private[sbt] sealed trait ServerInstance { def shutdown(): Unit @@ -38,31 +39,37 @@ private[sbt] object Server { with TokenFileFormats object JsonProtocol extends JsonProtocol - def start(host: String, - port: Int, + def start(connection: ServerConnection, onIncomingSocket: (Socket, ServerInstance) => Unit, - auth: Set[ServerAuthentication], - portfile: File, - tokenfile: File, log: Logger): ServerInstance = new ServerInstance { self => + import connection._ val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future private[this] val rand = new SecureRandom private[this] var token: String = nextToken + private[this] var serverSocketOpt: Option[ServerSocket] = None val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { Try { - ErrorHandling.translate(s"server failed to start on $host:$port. ") { - new ServerSocket(port, 50, InetAddress.getByName(host)) + ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { + connection.connectionType match { + case ConnectionType.Local if isWindows => + new NGWin32NamedPipeServerSocket(pipeName) + case ConnectionType.Local => + prepareSocketfile() + new NGUnixDomainServerSocket(socketfile.getAbsolutePath) + case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host)) + } } } match { case Failure(e) => p.failure(e) case Success(serverSocket) => serverSocket.setSoTimeout(5000) - log.info(s"sbt server started at $host:$port") + serverSocketOpt = Option(serverSocket) + log.info(s"sbt server started at ${connection.shortName}") writePortfile() running.set(true) p.success(()) @@ -74,6 +81,7 @@ private[sbt] object Server { case _: SocketTimeoutException => // its ok } } + serverSocket.close() } } } @@ -106,7 +114,7 @@ private[sbt] object Server { private[this] def writeTokenfile(): Unit = { import JsonProtocol._ - val uri = s"tcp://$host:$port" + val uri = connection.shortName val t = TokenFile(uri, token) val jsonToken = Converter.toJson(t).get @@ -141,7 +149,7 @@ private[sbt] object Server { private[this] def writePortfile(): Unit = { import JsonProtocol._ - val uri = s"tcp://$host:$port" + val uri = connection.shortName val p = auth match { case _ if auth(ServerAuthentication.Token) => @@ -153,5 +161,32 @@ private[sbt] object Server { val json = Converter.toJson(p).get IO.write(portfile, CompactPrinter(json)) } + + private[sbt] def prepareSocketfile(): Unit = { + if (socketfile.exists) { + IO.delete(socketfile) + } + IO.createDirectory(socketfile.getParentFile) + } } } + +private[sbt] case class ServerConnection( + connectionType: ConnectionType, + host: String, + port: Int, + auth: Set[ServerAuthentication], + portfile: File, + tokenfile: File, + socketfile: File, + pipeName: String +) { + def shortName: String = { + connectionType match { + case ConnectionType.Local if isWindows => s"local:$pipeName" + case ConnectionType.Local => s"local://$socketfile" + case ConnectionType.Tcp => s"tcp://$host:$port" + // case ConnectionType.Ssh => s"ssh://$host:$port" + } + } +} diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 3275fe092..7041fe6de 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -272,7 +272,11 @@ object Defaults extends BuildCommon { serverPort := 5000 + (Hash .toHex(Hash(appConfiguration.value.baseDirectory.toString)) .## % 1000), - serverAuthentication := Set(ServerAuthentication.Token), + serverConnectionType := ConnectionType.Local, + serverAuthentication := { + if (serverConnectionType.value == ConnectionType.Tcp) Set(ServerAuthentication.Token) + else Set() + }, insideCI :== sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI"), )) diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index ca92e306b..211b46b1b 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -133,6 +133,8 @@ object Keys { val serverPort = SettingKey(BasicKeys.serverPort) val serverHost = SettingKey(BasicKeys.serverHost) val serverAuthentication = SettingKey(BasicKeys.serverAuthentication) + val serverConnectionType = SettingKey(BasicKeys.serverConnectionType) + val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting) val watch = SettingKey(BasicKeys.watch) val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting) diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index 6ec260e86..680bba811 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -23,6 +23,7 @@ import Keys.{ serverHost, serverPort, serverAuthentication, + serverConnectionType, watch } import Scope.{ Global, ThisScope } @@ -461,6 +462,7 @@ object Project extends ProjectExtra { val host: Option[String] = get(serverHost) val port: Option[Int] = get(serverPort) val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication) + val connectionType: Option[ConnectionType] = get(serverConnectionType) val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true)) val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands, projectCommand) @@ -471,6 +473,7 @@ object Project extends ProjectExtra { .setCond(serverPort.key, port) .setCond(serverHost.key, host) .setCond(serverAuthentication.key, authentication) + .setCond(serverConnectionType.key, connectionType) .put(historyPath.key, history) .put(templateResolverInfos.key, trs) .setCond(shellPrompt.key, prompt) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 5d63c3329..f03546b64 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -13,7 +13,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ListBuffer import scala.annotation.tailrec -import BasicKeys.{ serverHost, serverPort, serverAuthentication } +import BasicKeys.{ serverHost, serverPort, serverAuthentication, serverConnectionType } import java.net.Socket import sjsonnew.JsonFormat import sjsonnew.shaded.scalajson.ast.unsafe._ @@ -84,6 +84,7 @@ private[sbt] final class CommandExchange { } private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}" + private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" /** * Check if a server instance is running already, and start one if it isn't. @@ -101,19 +102,23 @@ private[sbt] final class CommandExchange { case Some(xs) => xs case None => Set(ServerAuthentication.Token) } + lazy val connectionType = (s get serverConnectionType) match { + case Some(x) => x + case None => ConnectionType.Tcp + } val serverLogLevel: Level.Value = Level.Debug def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { - s.log.info(s"new client connected from: ${socket.getPort}") + val name = newNetworkName + s.log.info(s"new client connected: $name") val logger: Logger = { - val loggerName = s"network-${socket.getPort}" - val log = LogExchange.logger(loggerName, None, None) - LogExchange.unbindLoggerAppenders(loggerName) + val log = LogExchange.logger(name, None, None) + LogExchange.unbindLoggerAppenders(name) val appender = MainAppender.defaultScreen(s.globalLogging.console) - LogExchange.bindLoggerAppenders(loggerName, List(appender -> serverLogLevel)) + LogExchange.bindLoggerAppenders(name, List(appender -> serverLogLevel)) log } val channel = - new NetworkChannel(newChannelName, socket, Project structure s, auth, instance, logger) + new NetworkChannel(name, socket, Project structure s, auth, instance, logger) subscribe(channel) } server match { @@ -122,7 +127,18 @@ private[sbt] final class CommandExchange { val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" val h = Hash.halfHashString(portfile.toURI.toString) val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" - val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log) + val socketfile = BuildPaths.getGlobalBase(s) / "server" / h / "sock" + val pipeName = "sbt-server-" + h + val connection = + ServerConnection(connectionType, + host, + port, + auth, + portfile, + tokenfile, + socketfile, + pipeName) + val x = Server.start(connection, onIncomingSocket, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 61bc30342..c9f430e17 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -106,6 +106,8 @@ object Dependencies { val specs2 = "org.specs2" %% "specs2-junit" % "4.0.1" val junit = "junit" % "junit" % "4.11" val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1" + val jna = "net.java.dev.jna" % "jna" % "4.1.0" + val jnaPlatform = "net.java.dev.jna" % "jna-platform" % "4.1.0" private def scala211Module(name: String, moduleVersion: String) = Def setting ( scalaBinaryVersion.value match { diff --git a/sbt/src/sbt-test/server/handshake/build.sbt b/sbt/src/sbt-test/server/handshake/build.sbt index fd924f0e4..851648f3c 100644 --- a/sbt/src/sbt-test/server/handshake/build.sbt +++ b/sbt/src/sbt-test/server/handshake/build.sbt @@ -2,6 +2,7 @@ lazy val runClient = taskKey[Unit]("") lazy val root = (project in file(".")) .settings( + serverConnectionType in Global := ConnectionType.Tcp, scalaVersion := "2.12.3", serverPort in Global := 5123, libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1", diff --git a/vscode-sbt-scala/client/src/extension.ts b/vscode-sbt-scala/client/src/extension.ts index f4a3eabb4..768e0fc03 100644 --- a/vscode-sbt-scala/client/src/extension.ts +++ b/vscode-sbt-scala/client/src/extension.ts @@ -23,19 +23,25 @@ export function activate(context: ExtensionContext) { let clientOptions: LanguageClientOptions = { documentSelector: [{ language: 'scala', scheme: 'file' }, { language: 'java', scheme: 'file' }], initializationOptions: () => { - return { - token: discoverToken() - }; + return discoverToken(); } } // the port file is hardcoded to a particular location relative to the build. - function discoverToken(): String { + function discoverToken(): any { let pf = path.join(workspace.rootPath, 'project', 'target', 'active.json'); let portfile = JSON.parse(fs.readFileSync(pf)); - let tf = portfile.tokenfilePath; - let tokenfile = JSON.parse(fs.readFileSync(tf)); - return tokenfile.token; + + // if tokenfilepath exists, return the token. + if (portfile.hasOwnProperty('tokenfilePath')) { + let tf = portfile.tokenfilePath; + let tokenfile = JSON.parse(fs.readFileSync(tf)); + return { + token: tokenfile.token + }; + } else { + return {}; + } } // Create the language client and start the client. diff --git a/vscode-sbt-scala/server/src/server.ts b/vscode-sbt-scala/server/src/server.ts index bef7e50a9..05932d451 100644 --- a/vscode-sbt-scala/server/src/server.ts +++ b/vscode-sbt-scala/server/src/server.ts @@ -4,6 +4,7 @@ import * as path from 'path'; import * as url from 'url'; let net = require('net'), fs = require('fs'), + os = require('os'), stdin = process.stdin, stdout = process.stdout; @@ -16,7 +17,17 @@ socket.on('data', (chunk: any) => { }).on('end', () => { stdin.pause(); }); -socket.connect(u.port, '127.0.0.1'); + +if (u.protocol == 'tcp:') { + socket.connect(u.port, '127.0.0.1'); +} else if (u.protocol == 'local:' && os.platform() == 'win32') { + let pipePath = '\\\\.\\pipe\\' + u.hostname; + socket.connect(pipePath); +} else if (u.protocol == 'local:') { + socket.connect(u.path); +} else { + throw 'Unknown protocol ' + u.protocol; +} stdin.resume(); stdin.on('data', (chunk: any) => {