From 0c803214aa2d99a6f71c1d300b6ced69e3ba1f28 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 27 Nov 2017 21:36:25 -0500 Subject: [PATCH 1/3] IPC Unix Domain Socket and Windows Named Pipe sockets The Java socket implementation for IPC is lifted from facebook/nailgun. https://github.com/facebook/nailgun/tree/af623fddedfdca010df46302a0711ce0e2cc1ba6/ --- .../internal/NGUnixDomainServerSocket.java | 178 ++++++++++++++++++ .../java/sbt/internal/NGUnixDomainSocket.java | 171 +++++++++++++++++ .../internal/NGUnixDomainSocketLibrary.java | 140 ++++++++++++++ .../sbt/internal/NGWin32NamedPipeLibrary.java | 90 +++++++++ .../NGWin32NamedPipeServerSocket.java | 173 +++++++++++++++++ .../sbt/internal/NGWin32NamedPipeSocket.java | 173 +++++++++++++++++ .../ReferenceCountedFileDescriptor.java | 82 ++++++++ 7 files changed, 1007 insertions(+) create mode 100644 main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java create mode 100644 main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java create mode 100644 main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java create mode 100644 main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java create mode 100644 main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java create mode 100644 main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java create mode 100644 main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java new file mode 100644 index 000000000..89d3bcf43 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java @@ -0,0 +1,178 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainServerSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicInteger; + +import com.sun.jna.LastErrorException; +import com.sun.jna.ptr.IntByReference; + +/** + * Implements a {@link ServerSocket} which binds to a local Unix domain socket + * and returns instances of {@link NGUnixDomainSocket} from + * {@link #accept()}. + */ +public class NGUnixDomainServerSocket extends ServerSocket { + private static final int DEFAULT_BACKLOG = 50; + + // We use an AtomicInteger to prevent a race in this situation which + // could happen if fd were just an int: + // + // Thread 1 -> NGUnixDomainServerSocket.accept() + // -> lock this + // -> check isBound and isClosed + // -> unlock this + // -> descheduled while still in method + // Thread 2 -> NGUnixDomainServerSocket.close() + // -> lock this + // -> check isClosed + // -> NGUnixDomainSocketLibrary.close(fd) + // -> now fd is invalid + // -> unlock this + // Thread 1 -> re-scheduled while still in method + // -> NGUnixDomainSocketLibrary.accept(fd, which is invalid and maybe re-used) + // + // By using an AtomicInteger, we'll set this to -1 after it's closed, which + // will cause the accept() call above to cleanly fail instead of possibly + // being called on an unrelated fd (which may or may not fail). + private final AtomicInteger fd; + + private final int backlog; + private boolean isBound; + private boolean isClosed; + + public static class NGUnixDomainServerSocketAddress extends SocketAddress { + private final String path; + + public NGUnixDomainServerSocketAddress(String path) { + this.path = path; + } + + public String getPath() { + return path; + } + } + + /** + * Constructs an unbound Unix domain server socket. + */ + public NGUnixDomainServerSocket() throws IOException { + this(DEFAULT_BACKLOG, null); + } + + /** + * Constructs an unbound Unix domain server socket with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog) throws IOException { + this(backlog, null); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path. + */ + public NGUnixDomainServerSocket(String path) throws IOException { + this(DEFAULT_BACKLOG, path); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path + * with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog, String path) throws IOException { + try { + fd = new AtomicInteger( + NGUnixDomainSocketLibrary.socket( + NGUnixDomainSocketLibrary.PF_LOCAL, + NGUnixDomainSocketLibrary.SOCK_STREAM, + 0)); + this.backlog = backlog; + if (path != null) { + bind(new NGUnixDomainServerSocketAddress(path)); + } + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void bind(SocketAddress endpoint) throws IOException { + if (!(endpoint instanceof NGUnixDomainServerSocketAddress)) { + throw new IllegalArgumentException( + "endpoint must be an instance of NGUnixDomainServerSocketAddress"); + } + if (isBound) { + throw new IllegalStateException("Socket is already bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + NGUnixDomainServerSocketAddress unEndpoint = (NGUnixDomainServerSocketAddress) endpoint; + NGUnixDomainSocketLibrary.SockaddrUn address = + new NGUnixDomainSocketLibrary.SockaddrUn(unEndpoint.getPath()); + try { + int socketFd = fd.get(); + NGUnixDomainSocketLibrary.bind(socketFd, address, address.size()); + NGUnixDomainSocketLibrary.listen(socketFd, backlog); + isBound = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public Socket accept() throws IOException { + // We explicitly do not make this method synchronized, since the + // call to NGUnixDomainSocketLibrary.accept() will block + // indefinitely, causing another thread's call to close() to deadlock. + synchronized (this) { + if (!isBound) { + throw new IllegalStateException("Socket is not bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + } + try { + NGUnixDomainSocketLibrary.SockaddrUn sockaddrUn = + new NGUnixDomainSocketLibrary.SockaddrUn(); + IntByReference addressLen = new IntByReference(); + addressLen.setValue(sockaddrUn.size()); + int clientFd = NGUnixDomainSocketLibrary.accept(fd.get(), sockaddrUn, addressLen); + return new NGUnixDomainSocket(clientFd); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void close() throws IOException { + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + try { + // Ensure any pending call to accept() fails. + NGUnixDomainSocketLibrary.close(fd.getAndSet(-1)); + isClosed = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java new file mode 100644 index 000000000..1a9942ad9 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java @@ -0,0 +1,171 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import java.nio.ByteBuffer; + +import java.net.Socket; + +/** + * Implements a {@link Socket} backed by a native Unix domain socket. + * + * Instances of this class always return {@code null} for + * {@link Socket#getInetAddress()}, {@link Socket#getLocalAddress()}, + * {@link Socket#getLocalSocketAddress()}, {@link Socket#getRemoteSocketAddress()}. + */ +public class NGUnixDomainSocket extends Socket { + private final ReferenceCountedFileDescriptor fd; + private final InputStream is; + private final OutputStream os; + + /** + * Creates a Unix domain socket backed by a native file descriptor. + */ + public NGUnixDomainSocket(int fd) { + this.fd = new ReferenceCountedFileDescriptor(fd); + this.is = new NGUnixDomainSocketInputStream(); + this.os = new NGUnixDomainSocketOutputStream(); + } + + public InputStream getInputStream() { + return is; + } + + public OutputStream getOutputStream() { + return os; + } + + public void shutdownInput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_RD); + } + + public void shutdownOutput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_WR); + } + + private void doShutdown(int how) throws IOException { + try { + int socketFd = fd.acquire(); + if (socketFd != -1) { + NGUnixDomainSocketLibrary.shutdown(socketFd, how); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + + public void close() throws IOException { + super.close(); + try { + // This might not close the FD right away. In case we are about + // to read or write on another thread, it will delay the close + // until the read or write completes, to prevent the FD from + // being re-used for a different purpose and the other thread + // reading from a different FD. + fd.close(); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + private class NGUnixDomainSocketInputStream extends InputStream { + public int read() throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + int result; + if (doRead(buf) == 0) { + result = -1; + } else { + // Make sure to & with 0xFF to avoid sign extension + result = 0xFF & buf.get(); + } + return result; + } + + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + int result = doRead(buf); + if (result == 0) { + result = -1; + } + return result; + } + + private int doRead(ByteBuffer buf) throws IOException { + try { + int fdToRead = fd.acquire(); + if (fdToRead == -1) { + return -1; + } + return NGUnixDomainSocketLibrary.read(fdToRead, buf, buf.remaining()); + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } + + private class NGUnixDomainSocketOutputStream extends OutputStream { + + public void write(int b) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + buf.put(0, (byte) (0xFF & b)); + doWrite(buf); + } + + public void write(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + doWrite(buf); + } + + private void doWrite(ByteBuffer buf) throws IOException { + try { + int fdToWrite = fd.acquire(); + if (fdToWrite == -1) { + return; + } + int ret = NGUnixDomainSocketLibrary.write(fdToWrite, buf, buf.remaining()); + if (ret != buf.remaining()) { + // This shouldn't happen with standard blocking Unix domain sockets. + throw new IOException("Could not write " + buf.remaining() + " bytes as requested " + + "(wrote " + ret + " bytes instead)"); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java new file mode 100644 index 000000000..7e760d37a --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java @@ -0,0 +1,140 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocketLibrary.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; +import com.sun.jna.Native; +import com.sun.jna.Platform; +import com.sun.jna.Structure; +import com.sun.jna.Union; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; + +/** + * Utility class to bridge native Unix domain socket calls to Java using JNA. + */ +public class NGUnixDomainSocketLibrary { + public static final int PF_LOCAL = 1; + public static final int AF_LOCAL = 1; + public static final int SOCK_STREAM = 1; + + public static final int SHUT_RD = 0; + public static final int SHUT_WR = 1; + + // Utility class, do not instantiate. + private NGUnixDomainSocketLibrary() { } + + // BSD platforms write a length byte at the start of struct sockaddr_un. + private static final boolean HAS_SUN_LEN = + Platform.isMac() || Platform.isFreeBSD() || Platform.isNetBSD() || + Platform.isOpenBSD() || Platform.iskFreeBSD(); + + /** + * Bridges {@code struct sockaddr_un} to and from native code. + */ + public static class SockaddrUn extends Structure implements Structure.ByReference { + /** + * On BSD platforms, the {@code sun_len} and {@code sun_family} values in + * {@code struct sockaddr_un}. + */ + public static class SunLenAndFamily extends Structure { + public byte sunLen; + public byte sunFamily; + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunLen", "sunFamily" }); + } + } + + /** + * On BSD platforms, {@code sunLenAndFamily} will be present. + * On other platforms, only {@code sunFamily} will be present. + */ + public static class SunFamily extends Union { + public SunLenAndFamily sunLenAndFamily; + public short sunFamily; + } + + public SunFamily sunFamily = new SunFamily(); + public byte[] sunPath = new byte[104]; + + /** + * Constructs an empty {@code struct sockaddr_un}. + */ + public SockaddrUn() { + if (HAS_SUN_LEN) { + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + /** + * Constructs a {@code struct sockaddr_un} with a path whose bytes are encoded + * using the default encoding of the platform. + */ + public SockaddrUn(String path) throws IOException { + byte[] pathBytes = path.getBytes(); + if (pathBytes.length > sunPath.length - 1) { + throw new IOException("Cannot fit name [" + path + "] in maximum unix domain socket length"); + } + System.arraycopy(pathBytes, 0, sunPath, 0, pathBytes.length); + sunPath[pathBytes.length] = (byte) 0; + if (HAS_SUN_LEN) { + int len = fieldOffset("sunPath") + pathBytes.length; + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.sunLenAndFamily.sunLen = (byte) len; + sunFamily.sunLenAndFamily.sunFamily = AF_LOCAL; + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.sunFamily = AF_LOCAL; + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunFamily", "sunPath" }); + } + } + + static { + Native.register(Platform.C_LIBRARY_NAME); + } + + public static native int socket(int domain, int type, int protocol) throws LastErrorException; + public static native int bind(int fd, SockaddrUn address, int addressLen) + throws LastErrorException; + public static native int listen(int fd, int backlog) throws LastErrorException; + public static native int accept(int fd, SockaddrUn address, IntByReference addressLen) + throws LastErrorException; + public static native int read(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int write(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int close(int fd) throws LastErrorException; + public static native int shutdown(int fd, int how) throws LastErrorException; +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java new file mode 100644 index 000000000..ba535691f --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java @@ -0,0 +1,90 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeLibrary.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.nio.ByteBuffer; + +import com.sun.jna.*; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.*; +import com.sun.jna.platform.win32.WinBase.*; +import com.sun.jna.ptr.IntByReference; + +import com.sun.jna.win32.W32APIOptions; + +public interface NGWin32NamedPipeLibrary extends WinNT { + int PIPE_ACCESS_DUPLEX = 3; + int PIPE_UNLIMITED_INSTANCES = 255; + int FILE_FLAG_FIRST_PIPE_INSTANCE = 524288; + + NGWin32NamedPipeLibrary INSTANCE = + (NGWin32NamedPipeLibrary) Native.loadLibrary( + "kernel32", + NGWin32NamedPipeLibrary.class, + W32APIOptions.UNICODE_OPTIONS); + + HANDLE CreateNamedPipe( + String lpName, + int dwOpenMode, + int dwPipeMode, + int nMaxInstances, + int nOutBufferSize, + int nInBufferSize, + int nDefaultTimeOut, + SECURITY_ATTRIBUTES lpSecurityAttributes); + boolean ConnectNamedPipe( + HANDLE hNamedPipe, + Pointer lpOverlapped); + boolean DisconnectNamedPipe( + HANDLE hObject); + boolean ReadFile( + HANDLE hFile, + Memory lpBuffer, + int nNumberOfBytesToRead, + IntByReference lpNumberOfBytesRead, + Pointer lpOverlapped); + boolean WriteFile( + HANDLE hFile, + ByteBuffer lpBuffer, + int nNumberOfBytesToWrite, + IntByReference lpNumberOfBytesWritten, + Pointer lpOverlapped); + boolean CloseHandle( + HANDLE hObject); + boolean GetOverlappedResult( + HANDLE hFile, + Pointer lpOverlapped, + IntByReference lpNumberOfBytesTransferred, + boolean wait); + boolean CancelIoEx( + HANDLE hObject, + Pointer lpOverlapped); + HANDLE CreateEvent( + SECURITY_ATTRIBUTES lpEventAttributes, + boolean bManualReset, + boolean bInitialState, + String lpName); + int WaitForSingleObject( + HANDLE hHandle, + int dwMilliseconds + ); + + int GetLastError(); +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java new file mode 100644 index 000000000..137d9b5dc --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java @@ -0,0 +1,173 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeServerSocket.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; + +public class NGWin32NamedPipeServerSocket extends ServerSocket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private static final String WIN32_PIPE_PREFIX = "\\\\.\\pipe\\"; + private static final int BUFFER_SIZE = 65535; + private final LinkedBlockingQueue openHandles; + private final LinkedBlockingQueue connectedHandles; + private final NGWin32NamedPipeSocket.CloseCallback closeCallback; + private final String path; + private final int maxInstances; + private final HANDLE lockHandle; + + public NGWin32NamedPipeServerSocket(String path) throws IOException { + this(NGWin32NamedPipeLibrary.PIPE_UNLIMITED_INSTANCES, path); + } + + public NGWin32NamedPipeServerSocket(int maxInstances, String path) throws IOException { + this.openHandles = new LinkedBlockingQueue<>(); + this.connectedHandles = new LinkedBlockingQueue<>(); + this.closeCallback = handle -> { + if (connectedHandles.remove(handle)) { + closeConnectedPipe(handle, false); + } + if (openHandles.remove(handle)) { + closeOpenPipe(handle); + } + }; + this.maxInstances = maxInstances; + if (!path.startsWith(WIN32_PIPE_PREFIX)) { + this.path = WIN32_PIPE_PREFIX + path; + } else { + this.path = path; + } + String lockPath = this.path + "_lock"; + lockHandle = API.CreateNamedPipe( + lockPath, + NGWin32NamedPipeLibrary.FILE_FLAG_FIRST_PIPE_INSTANCE | NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX, + 0, + 1, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (lockHandle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create lock for %s, error %d", lockPath, API.GetLastError())); + } else { + if (!API.DisconnectNamedPipe(lockHandle)) { + throw new IOException(String.format("Could not disconnect lock %d", API.GetLastError())); + } + } + + } + + public void bind(SocketAddress endpoint) throws IOException { + throw new IOException("Win32 named pipes do not support bind(), pass path to constructor"); + } + + public Socket accept() throws IOException { + HANDLE handle = API.CreateNamedPipe( + path, + NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX | WinNT.FILE_FLAG_OVERLAPPED, + 0, + maxInstances, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (handle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create named pipe, error %d", API.GetLastError())); + } + openHandles.add(handle); + + HANDLE connWaitable = API.CreateEvent(null, true, false, null); + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = connWaitable; + olap.write(); + + boolean immediate = API.ConnectNamedPipe(handle, olap.getPointer()); + if (immediate) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } + + int connectError = API.GetLastError(); + if (connectError == WinError.ERROR_PIPE_CONNECTED) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_NO_DATA) { + // Client has connected and disconnected between CreateNamedPipe() and ConnectNamedPipe() + // connection is broken, but it is returned it avoid loop here. + // Actual error will happen for NGSession when it will try to read/write from/to pipe + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_IO_PENDING) { + if (!API.GetOverlappedResult(handle, olap.getPointer(), new IntByReference(), true)) { + openHandles.remove(handle); + closeOpenPipe(handle); + throw new IOException("GetOverlappedResult() failed for connect operation: " + API.GetLastError()); + } + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else { + throw new IOException("ConnectNamedPipe() failed with: " + connectError); + } + } + + public void close() throws IOException { + try { + List handlesToClose = new ArrayList<>(); + openHandles.drainTo(handlesToClose); + for (HANDLE handle : handlesToClose) { + closeOpenPipe(handle); + } + + List handlesToDisconnect = new ArrayList<>(); + connectedHandles.drainTo(handlesToDisconnect); + for (HANDLE handle : handlesToDisconnect) { + closeConnectedPipe(handle, true); + } + } finally { + API.CloseHandle(lockHandle); + } + } + + private void closeOpenPipe(HANDLE handle) throws IOException { + API.CancelIoEx(handle, null); + API.CloseHandle(handle); + } + + private void closeConnectedPipe(HANDLE handle, boolean shutdown) throws IOException { + if (!shutdown) { + API.WaitForSingleObject(handle, 10000); + } + API.DisconnectNamedPipe(handle); + API.CloseHandle(handle); + } +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java new file mode 100644 index 000000000..072ed3901 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java @@ -0,0 +1,173 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeSocket.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.Memory; +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.ByteBuffer; + +public class NGWin32NamedPipeSocket extends Socket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private final HANDLE handle; + private final CloseCallback closeCallback; + private final InputStream is; + private final OutputStream os; + private final HANDLE readerWaitable; + private final HANDLE writerWaitable; + + interface CloseCallback { + void onNamedPipeSocketClose(HANDLE handle) throws IOException; + } + + public NGWin32NamedPipeSocket( + HANDLE handle, + NGWin32NamedPipeSocket.CloseCallback closeCallback) throws IOException { + this.handle = handle; + this.closeCallback = closeCallback; + this.readerWaitable = API.CreateEvent(null, true, false, null); + if (readerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + writerWaitable = API.CreateEvent(null, true, false, null); + if (writerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + this.is = new NGWin32NamedPipeSocketInputStream(handle); + this.os = new NGWin32NamedPipeSocketOutputStream(handle); + } + + @Override + public InputStream getInputStream() { + return is; + } + + @Override + public OutputStream getOutputStream() { + return os; + } + + @Override + public void close() throws IOException { + closeCallback.onNamedPipeSocketClose(handle); + } + + @Override + public void shutdownInput() throws IOException { + } + + @Override + public void shutdownOutput() throws IOException { + } + + private class NGWin32NamedPipeSocketInputStream extends InputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketInputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public int read() throws IOException { + int result; + byte[] b = new byte[1]; + if (read(b) == 0) { + result = -1; + } else { + result = 0xFF & b[0]; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Memory readBuffer = new Memory(len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = readerWaitable; + olap.write(); + + boolean immediate = API.ReadFile(handle, readBuffer, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("ReadFile() failed: " + lastError); + } + } + + IntByReference read = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), read, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for read operation: " + lastError); + } + if (read.getValue() != len) { + throw new IOException("ReadFile() read less bytes than requested"); + } + byte[] byteArray = readBuffer.getByteArray(0, len); + System.arraycopy(byteArray, 0, b, off, len); + return len; + } + } + + private class NGWin32NamedPipeSocketOutputStream extends OutputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketOutputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public void write(int b) throws IOException { + write(new byte[]{(byte) (0xFF & b)}); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + ByteBuffer data = ByteBuffer.wrap(b, off, len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = writerWaitable; + olap.write(); + + boolean immediate = API.WriteFile(handle, data, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("WriteFile() failed: " + lastError); + } + } + IntByReference written = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), written, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for write operation: " + lastError); + } + if (written.getValue() != len) { + throw new IOException("WriteFile() wrote less bytes than requested"); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java new file mode 100644 index 000000000..7fb5d9d53 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java @@ -0,0 +1,82 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/ReferenceCountedFileDescriptor.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; + +/** + * Encapsulates a file descriptor plus a reference count to ensure close requests + * only close the file descriptor once the last reference to the file descriptor + * is released. + * + * If not explicitly closed, the file descriptor will be closed when + * this object is finalized. + */ +public class ReferenceCountedFileDescriptor { + private int fd; + private int fdRefCount; + private boolean closePending; + + public ReferenceCountedFileDescriptor(int fd) { + this.fd = fd; + this.fdRefCount = 0; + this.closePending = false; + } + + protected void finalize() throws IOException { + close(); + } + + public synchronized int acquire() { + fdRefCount++; + return fd; + } + + public synchronized void release() throws IOException { + fdRefCount--; + if (fdRefCount == 0 && closePending && fd != -1) { + doClose(); + } + } + + public synchronized void close() throws IOException { + if (fd == -1 || closePending) { + return; + } + + if (fdRefCount == 0) { + doClose(); + } else { + // Another thread has the FD. We'll close it when they release the reference. + closePending = true; + } + } + + private void doClose() throws IOException { + try { + NGUnixDomainSocketLibrary.close(fd); + fd = -1; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +} From f785750fc40515efd0614d774f2251583ac81118 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 27 Nov 2017 21:37:31 -0500 Subject: [PATCH 2/3] IPC Unix domain socket for sbt server In addition to TCP, this adds sbt server support for IPC (interprocess communication) using Unix domain socket and Windows named pipe. The use of Unix domain socket has performance and security benefits. --- build.sbt | 11 +++- .../ConnectionTypeFormats.scala | 28 +++++++++ .../contraband-scala/sbt/ConnectionType.scala | 13 +++++ main-command/src/main/contraband/state.contra | 7 +++ .../src/main/scala/sbt/BasicKeys.scala | 5 ++ .../scala/sbt/internal/server/Server.scala | 57 +++++++++++++++---- main/src/main/scala/sbt/Defaults.scala | 6 +- main/src/main/scala/sbt/Keys.scala | 2 + main/src/main/scala/sbt/Project.scala | 3 + .../scala/sbt/internal/CommandExchange.scala | 32 ++++++++--- project/Dependencies.scala | 2 + sbt/src/sbt-test/server/handshake/build.sbt | 1 + vscode-sbt-scala/client/src/extension.ts | 20 ++++--- vscode-sbt-scala/server/src/server.ts | 13 ++++- 14 files changed, 169 insertions(+), 31 deletions(-) create mode 100644 main-command/src/main/contraband-scala/ConnectionTypeFormats.scala create mode 100644 main-command/src/main/contraband-scala/sbt/ConnectionType.scala diff --git a/build.sbt b/build.sbt index c5b2c6cfc..61842c40c 100644 --- a/build.sbt +++ b/build.sbt @@ -55,7 +55,7 @@ def commonSettings: Seq[Setting[_]] = concurrentRestrictions in Global += Util.testExclusiveRestriction, testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-w", "1"), testOptions in Test += Tests.Argument(TestFrameworks.ScalaCheck, "-verbosity", "2"), - javacOptions in compile ++= Seq("-target", "6", "-source", "6", "-Xlint", "-Xlint:-serial"), + javacOptions in compile ++= Seq("-Xlint", "-Xlint:-serial"), crossScalaVersions := Seq(baseScalaVersion), bintrayPackage := (bintrayPackage in ThisBuild).value, bintrayRepository := (bintrayRepository in ThisBuild).value, @@ -309,7 +309,8 @@ lazy val commandProj = (project in file("main-command")) .settings( testedBaseSettings, name := "Command", - libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi), + libraryDependencies ++= Seq(launcherInterface, sjsonNewScalaJson.value, templateResolverApi, + jna, jnaPlatform), managedSourceDirectories in Compile += baseDirectory.value / "src" / "main" / "contraband-scala", sourceManaged in (Compile, generateContrabands) := baseDirectory.value / "src" / "main" / "contraband-scala", @@ -324,7 +325,11 @@ lazy val commandProj = (project in file("main-command")) exclude[ReversedMissingMethodProblem]("sbt.internal.CommandChannel.*"), // Added an overload to reboot. The overload is private[sbt]. exclude[ReversedMissingMethodProblem]("sbt.StateOps.reboot"), - ) + ), + unmanagedSources in (Compile, headerCreate) := { + val old = (unmanagedSources in (Compile, headerCreate)).value + old filterNot { x => (x.getName startsWith "NG") || (x.getName == "ReferenceCountedFileDescriptor.java") } + }, ) .configure( addSbtIO, diff --git a/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala b/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala new file mode 100644 index 000000000..4562f75b9 --- /dev/null +++ b/main-command/src/main/contraband-scala/ConnectionTypeFormats.scala @@ -0,0 +1,28 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait ConnectionTypeFormats { self: sjsonnew.BasicJsonProtocol => +implicit lazy val ConnectionTypeFormat: JsonFormat[sbt.ConnectionType] = new JsonFormat[sbt.ConnectionType] { + override def read[J](jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.ConnectionType = { + jsOpt match { + case Some(js) => + unbuilder.readString(js) match { + case "Local" => sbt.ConnectionType.Local + case "Tcp" => sbt.ConnectionType.Tcp + } + case None => + deserializationError("Expected JsString but found None") + } + } + override def write[J](obj: sbt.ConnectionType, builder: Builder[J]): Unit = { + val str = obj match { + case sbt.ConnectionType.Local => "Local" + case sbt.ConnectionType.Tcp => "Tcp" + } + builder.writeString(str) + } +} +} diff --git a/main-command/src/main/contraband-scala/sbt/ConnectionType.scala b/main-command/src/main/contraband-scala/sbt/ConnectionType.scala new file mode 100644 index 000000000..af50ed2e9 --- /dev/null +++ b/main-command/src/main/contraband-scala/sbt/ConnectionType.scala @@ -0,0 +1,13 @@ +/** + * This code is generated using [[http://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt +sealed abstract class ConnectionType extends Serializable +object ConnectionType { + + /** This uses Unix domain socket on POSIX, and named pipe on Windows. */ + case object Local extends ConnectionType + case object Tcp extends ConnectionType +} diff --git a/main-command/src/main/contraband/state.contra b/main-command/src/main/contraband/state.contra index 79d0bcaab..2737ce8ab 100644 --- a/main-command/src/main/contraband/state.contra +++ b/main-command/src/main/contraband/state.contra @@ -16,3 +16,10 @@ type CommandSource { enum ServerAuthentication { Token } + +enum ConnectionType { + ## This uses Unix domain socket on POSIX, and named pipe on Windows. + Local + Tcp + # Ssh +} diff --git a/main-command/src/main/scala/sbt/BasicKeys.scala b/main-command/src/main/scala/sbt/BasicKeys.scala index ef6bb00cd..42b68daa4 100644 --- a/main-command/src/main/scala/sbt/BasicKeys.scala +++ b/main-command/src/main/scala/sbt/BasicKeys.scala @@ -33,6 +33,11 @@ object BasicKeys { "Method of authenticating server command.", 10000) + val serverConnectionType = + AttributeKey[ConnectionType]("serverConnectionType", + "The wire protocol for the server command.", + 10000) + private[sbt] val interactive = AttributeKey[Boolean]( "interactive", "True if commands are currently being entered from an interactive environment.", diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 3a7c2785b..c4d3b542c 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -17,13 +17,14 @@ import java.security.SecureRandom import java.math.BigInteger import scala.concurrent.{ Future, Promise } import scala.util.{ Try, Success, Failure } -import sbt.internal.util.ErrorHandling import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.util.Logger import sbt.io.IO import sbt.io.syntax._ import sjsonnew.support.scalajson.unsafe.{ Converter, CompactPrinter } import sbt.internal.protocol.codec._ +import sbt.internal.util.ErrorHandling +import sbt.internal.util.Util.isWindows private[sbt] sealed trait ServerInstance { def shutdown(): Unit @@ -38,31 +39,37 @@ private[sbt] object Server { with TokenFileFormats object JsonProtocol extends JsonProtocol - def start(host: String, - port: Int, + def start(connection: ServerConnection, onIncomingSocket: (Socket, ServerInstance) => Unit, - auth: Set[ServerAuthentication], - portfile: File, - tokenfile: File, log: Logger): ServerInstance = new ServerInstance { self => + import connection._ val running = new AtomicBoolean(false) val p: Promise[Unit] = Promise[Unit]() val ready: Future[Unit] = p.future private[this] val rand = new SecureRandom private[this] var token: String = nextToken + private[this] var serverSocketOpt: Option[ServerSocket] = None val serverThread = new Thread("sbt-socket-server") { override def run(): Unit = { Try { - ErrorHandling.translate(s"server failed to start on $host:$port. ") { - new ServerSocket(port, 50, InetAddress.getByName(host)) + ErrorHandling.translate(s"server failed to start on ${connection.shortName}. ") { + connection.connectionType match { + case ConnectionType.Local if isWindows => + new NGWin32NamedPipeServerSocket(pipeName) + case ConnectionType.Local => + prepareSocketfile() + new NGUnixDomainServerSocket(socketfile.getAbsolutePath) + case ConnectionType.Tcp => new ServerSocket(port, 50, InetAddress.getByName(host)) + } } } match { case Failure(e) => p.failure(e) case Success(serverSocket) => serverSocket.setSoTimeout(5000) - log.info(s"sbt server started at $host:$port") + serverSocketOpt = Option(serverSocket) + log.info(s"sbt server started at ${connection.shortName}") writePortfile() running.set(true) p.success(()) @@ -74,6 +81,7 @@ private[sbt] object Server { case _: SocketTimeoutException => // its ok } } + serverSocket.close() } } } @@ -106,7 +114,7 @@ private[sbt] object Server { private[this] def writeTokenfile(): Unit = { import JsonProtocol._ - val uri = s"tcp://$host:$port" + val uri = connection.shortName val t = TokenFile(uri, token) val jsonToken = Converter.toJson(t).get @@ -141,7 +149,7 @@ private[sbt] object Server { private[this] def writePortfile(): Unit = { import JsonProtocol._ - val uri = s"tcp://$host:$port" + val uri = connection.shortName val p = auth match { case _ if auth(ServerAuthentication.Token) => @@ -153,5 +161,32 @@ private[sbt] object Server { val json = Converter.toJson(p).get IO.write(portfile, CompactPrinter(json)) } + + private[sbt] def prepareSocketfile(): Unit = { + if (socketfile.exists) { + IO.delete(socketfile) + } + IO.createDirectory(socketfile.getParentFile) + } } } + +private[sbt] case class ServerConnection( + connectionType: ConnectionType, + host: String, + port: Int, + auth: Set[ServerAuthentication], + portfile: File, + tokenfile: File, + socketfile: File, + pipeName: String +) { + def shortName: String = { + connectionType match { + case ConnectionType.Local if isWindows => s"local:$pipeName" + case ConnectionType.Local => s"local://$socketfile" + case ConnectionType.Tcp => s"tcp://$host:$port" + // case ConnectionType.Ssh => s"ssh://$host:$port" + } + } +} diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 3275fe092..7041fe6de 100755 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -272,7 +272,11 @@ object Defaults extends BuildCommon { serverPort := 5000 + (Hash .toHex(Hash(appConfiguration.value.baseDirectory.toString)) .## % 1000), - serverAuthentication := Set(ServerAuthentication.Token), + serverConnectionType := ConnectionType.Local, + serverAuthentication := { + if (serverConnectionType.value == ConnectionType.Tcp) Set(ServerAuthentication.Token) + else Set() + }, insideCI :== sys.env.contains("BUILD_NUMBER") || sys.env.contains("CI"), )) diff --git a/main/src/main/scala/sbt/Keys.scala b/main/src/main/scala/sbt/Keys.scala index ca92e306b..211b46b1b 100644 --- a/main/src/main/scala/sbt/Keys.scala +++ b/main/src/main/scala/sbt/Keys.scala @@ -133,6 +133,8 @@ object Keys { val serverPort = SettingKey(BasicKeys.serverPort) val serverHost = SettingKey(BasicKeys.serverHost) val serverAuthentication = SettingKey(BasicKeys.serverAuthentication) + val serverConnectionType = SettingKey(BasicKeys.serverConnectionType) + val analysis = AttributeKey[CompileAnalysis]("analysis", "Analysis of compilation, including dependencies and generated outputs.", DSetting) val watch = SettingKey(BasicKeys.watch) val suppressSbtShellNotification = settingKey[Boolean]("""True to suppress the "Executing in batch mode.." message.""").withRank(CSetting) diff --git a/main/src/main/scala/sbt/Project.scala b/main/src/main/scala/sbt/Project.scala index 6ec260e86..680bba811 100755 --- a/main/src/main/scala/sbt/Project.scala +++ b/main/src/main/scala/sbt/Project.scala @@ -23,6 +23,7 @@ import Keys.{ serverHost, serverPort, serverAuthentication, + serverConnectionType, watch } import Scope.{ Global, ThisScope } @@ -461,6 +462,7 @@ object Project extends ProjectExtra { val host: Option[String] = get(serverHost) val port: Option[Int] = get(serverPort) val authentication: Option[Set[ServerAuthentication]] = get(serverAuthentication) + val connectionType: Option[ConnectionType] = get(serverConnectionType) val commandDefs = allCommands.distinct.flatten[Command].map(_ tag (projectCommand, true)) val newDefinedCommands = commandDefs ++ BasicCommands.removeTagged(s.definedCommands, projectCommand) @@ -471,6 +473,7 @@ object Project extends ProjectExtra { .setCond(serverPort.key, port) .setCond(serverHost.key, host) .setCond(serverAuthentication.key, authentication) + .setCond(serverConnectionType.key, connectionType) .put(historyPath.key, history) .put(templateResolverInfos.key, trs) .setCond(shellPrompt.key, prompt) diff --git a/main/src/main/scala/sbt/internal/CommandExchange.scala b/main/src/main/scala/sbt/internal/CommandExchange.scala index 9605558a4..4a13bdb80 100644 --- a/main/src/main/scala/sbt/internal/CommandExchange.scala +++ b/main/src/main/scala/sbt/internal/CommandExchange.scala @@ -13,7 +13,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ListBuffer import scala.annotation.tailrec -import BasicKeys.{ serverHost, serverPort, serverAuthentication } +import BasicKeys.{ serverHost, serverPort, serverAuthentication, serverConnectionType } import java.net.Socket import sjsonnew.JsonFormat import sjsonnew.shaded.scalajson.ast.unsafe._ @@ -83,6 +83,7 @@ private[sbt] final class CommandExchange { } private def newChannelName: String = s"channel-${nextChannelId.incrementAndGet()}" + private def newNetworkName: String = s"network-${nextChannelId.incrementAndGet()}" /** * Check if a server instance is running already, and start one if it isn't. @@ -100,19 +101,23 @@ private[sbt] final class CommandExchange { case Some(xs) => xs case None => Set(ServerAuthentication.Token) } + lazy val connectionType = (s get serverConnectionType) match { + case Some(x) => x + case None => ConnectionType.Tcp + } val serverLogLevel: Level.Value = Level.Debug def onIncomingSocket(socket: Socket, instance: ServerInstance): Unit = { - s.log.info(s"new client connected from: ${socket.getPort}") + val name = newNetworkName + s.log.info(s"new client connected: $name") val logger: Logger = { - val loggerName = s"network-${socket.getPort}" - val log = LogExchange.logger(loggerName, None, None) - LogExchange.unbindLoggerAppenders(loggerName) + val log = LogExchange.logger(name, None, None) + LogExchange.unbindLoggerAppenders(name) val appender = MainAppender.defaultScreen(s.globalLogging.console) - LogExchange.bindLoggerAppenders(loggerName, List(appender -> serverLogLevel)) + LogExchange.bindLoggerAppenders(name, List(appender -> serverLogLevel)) log } val channel = - new NetworkChannel(newChannelName, socket, Project structure s, auth, instance, logger) + new NetworkChannel(name, socket, Project structure s, auth, instance, logger) subscribe(channel) } server match { @@ -121,7 +126,18 @@ private[sbt] final class CommandExchange { val portfile = (new File(".")).getAbsoluteFile / "project" / "target" / "active.json" val h = Hash.halfHashString(portfile.toURI.toString) val tokenfile = BuildPaths.getGlobalBase(s) / "server" / h / "token.json" - val x = Server.start(host, port, onIncomingSocket, auth, portfile, tokenfile, s.log) + val socketfile = BuildPaths.getGlobalBase(s) / "server" / h / "sock" + val pipeName = "sbt-server-" + h + val connection = + ServerConnection(connectionType, + host, + port, + auth, + portfile, + tokenfile, + socketfile, + pipeName) + val x = Server.start(connection, onIncomingSocket, s.log) Await.ready(x.ready, Duration("10s")) x.ready.value match { case Some(Success(_)) => diff --git a/project/Dependencies.scala b/project/Dependencies.scala index f714cbc82..ce9bd9223 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -106,6 +106,8 @@ object Dependencies { val specs2 = "org.specs2" %% "specs2" % "2.4.17" val junit = "junit" % "junit" % "4.11" val templateResolverApi = "org.scala-sbt" % "template-resolver" % "0.1" + val jna = "net.java.dev.jna" % "jna" % "4.1.0" + val jnaPlatform = "net.java.dev.jna" % "jna-platform" % "4.1.0" private def scala211Module(name: String, moduleVersion: String) = Def setting ( scalaBinaryVersion.value match { diff --git a/sbt/src/sbt-test/server/handshake/build.sbt b/sbt/src/sbt-test/server/handshake/build.sbt index fd924f0e4..851648f3c 100644 --- a/sbt/src/sbt-test/server/handshake/build.sbt +++ b/sbt/src/sbt-test/server/handshake/build.sbt @@ -2,6 +2,7 @@ lazy val runClient = taskKey[Unit]("") lazy val root = (project in file(".")) .settings( + serverConnectionType in Global := ConnectionType.Tcp, scalaVersion := "2.12.3", serverPort in Global := 5123, libraryDependencies += "org.scala-sbt" %% "io" % "1.0.1", diff --git a/vscode-sbt-scala/client/src/extension.ts b/vscode-sbt-scala/client/src/extension.ts index f4a3eabb4..768e0fc03 100644 --- a/vscode-sbt-scala/client/src/extension.ts +++ b/vscode-sbt-scala/client/src/extension.ts @@ -23,19 +23,25 @@ export function activate(context: ExtensionContext) { let clientOptions: LanguageClientOptions = { documentSelector: [{ language: 'scala', scheme: 'file' }, { language: 'java', scheme: 'file' }], initializationOptions: () => { - return { - token: discoverToken() - }; + return discoverToken(); } } // the port file is hardcoded to a particular location relative to the build. - function discoverToken(): String { + function discoverToken(): any { let pf = path.join(workspace.rootPath, 'project', 'target', 'active.json'); let portfile = JSON.parse(fs.readFileSync(pf)); - let tf = portfile.tokenfilePath; - let tokenfile = JSON.parse(fs.readFileSync(tf)); - return tokenfile.token; + + // if tokenfilepath exists, return the token. + if (portfile.hasOwnProperty('tokenfilePath')) { + let tf = portfile.tokenfilePath; + let tokenfile = JSON.parse(fs.readFileSync(tf)); + return { + token: tokenfile.token + }; + } else { + return {}; + } } // Create the language client and start the client. diff --git a/vscode-sbt-scala/server/src/server.ts b/vscode-sbt-scala/server/src/server.ts index bef7e50a9..05932d451 100644 --- a/vscode-sbt-scala/server/src/server.ts +++ b/vscode-sbt-scala/server/src/server.ts @@ -4,6 +4,7 @@ import * as path from 'path'; import * as url from 'url'; let net = require('net'), fs = require('fs'), + os = require('os'), stdin = process.stdin, stdout = process.stdout; @@ -16,7 +17,17 @@ socket.on('data', (chunk: any) => { }).on('end', () => { stdin.pause(); }); -socket.connect(u.port, '127.0.0.1'); + +if (u.protocol == 'tcp:') { + socket.connect(u.port, '127.0.0.1'); +} else if (u.protocol == 'local:' && os.platform() == 'win32') { + let pipePath = '\\\\.\\pipe\\' + u.hostname; + socket.connect(pipePath); +} else if (u.protocol == 'local:') { + socket.connect(u.path); +} else { + throw 'Unknown protocol ' + u.protocol; +} stdin.resume(); stdin.on('data', (chunk: any) => { From ef61a1efa7b3be9ae296ff0a6de48e1bfe8f2269 Mon Sep 17 00:00:00 2001 From: Eugene Yokota Date: Mon, 27 Nov 2017 21:26:24 -0500 Subject: [PATCH 3/3] Read just the amount of bytes available --- .../java/sbt/internal/NGWin32NamedPipeSocket.java | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java index 072ed3901..b22bb6bbf 100644 --- a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java @@ -1,4 +1,5 @@ // Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeSocket.java +// Made change in `read` to read just the amount of bytes available. /* @@ -124,12 +125,10 @@ public class NGWin32NamedPipeSocket extends Socket { int lastError = API.GetLastError(); throw new IOException("GetOverlappedResult() failed for read operation: " + lastError); } - if (read.getValue() != len) { - throw new IOException("ReadFile() read less bytes than requested"); - } - byte[] byteArray = readBuffer.getByteArray(0, len); - System.arraycopy(byteArray, 0, b, off, len); - return len; + int actualLen = read.getValue(); + byte[] byteArray = readBuffer.getByteArray(0, actualLen); + System.arraycopy(byteArray, 0, b, off, actualLen); + return actualLen; } }