diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java new file mode 100644 index 000000000..89d3bcf43 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainServerSocket.java @@ -0,0 +1,178 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainServerSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicInteger; + +import com.sun.jna.LastErrorException; +import com.sun.jna.ptr.IntByReference; + +/** + * Implements a {@link ServerSocket} which binds to a local Unix domain socket + * and returns instances of {@link NGUnixDomainSocket} from + * {@link #accept()}. + */ +public class NGUnixDomainServerSocket extends ServerSocket { + private static final int DEFAULT_BACKLOG = 50; + + // We use an AtomicInteger to prevent a race in this situation which + // could happen if fd were just an int: + // + // Thread 1 -> NGUnixDomainServerSocket.accept() + // -> lock this + // -> check isBound and isClosed + // -> unlock this + // -> descheduled while still in method + // Thread 2 -> NGUnixDomainServerSocket.close() + // -> lock this + // -> check isClosed + // -> NGUnixDomainSocketLibrary.close(fd) + // -> now fd is invalid + // -> unlock this + // Thread 1 -> re-scheduled while still in method + // -> NGUnixDomainSocketLibrary.accept(fd, which is invalid and maybe re-used) + // + // By using an AtomicInteger, we'll set this to -1 after it's closed, which + // will cause the accept() call above to cleanly fail instead of possibly + // being called on an unrelated fd (which may or may not fail). + private final AtomicInteger fd; + + private final int backlog; + private boolean isBound; + private boolean isClosed; + + public static class NGUnixDomainServerSocketAddress extends SocketAddress { + private final String path; + + public NGUnixDomainServerSocketAddress(String path) { + this.path = path; + } + + public String getPath() { + return path; + } + } + + /** + * Constructs an unbound Unix domain server socket. + */ + public NGUnixDomainServerSocket() throws IOException { + this(DEFAULT_BACKLOG, null); + } + + /** + * Constructs an unbound Unix domain server socket with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog) throws IOException { + this(backlog, null); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path. + */ + public NGUnixDomainServerSocket(String path) throws IOException { + this(DEFAULT_BACKLOG, path); + } + + /** + * Constructs and binds a Unix domain server socket to the specified path + * with the specified listen backlog. + */ + public NGUnixDomainServerSocket(int backlog, String path) throws IOException { + try { + fd = new AtomicInteger( + NGUnixDomainSocketLibrary.socket( + NGUnixDomainSocketLibrary.PF_LOCAL, + NGUnixDomainSocketLibrary.SOCK_STREAM, + 0)); + this.backlog = backlog; + if (path != null) { + bind(new NGUnixDomainServerSocketAddress(path)); + } + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void bind(SocketAddress endpoint) throws IOException { + if (!(endpoint instanceof NGUnixDomainServerSocketAddress)) { + throw new IllegalArgumentException( + "endpoint must be an instance of NGUnixDomainServerSocketAddress"); + } + if (isBound) { + throw new IllegalStateException("Socket is already bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + NGUnixDomainServerSocketAddress unEndpoint = (NGUnixDomainServerSocketAddress) endpoint; + NGUnixDomainSocketLibrary.SockaddrUn address = + new NGUnixDomainSocketLibrary.SockaddrUn(unEndpoint.getPath()); + try { + int socketFd = fd.get(); + NGUnixDomainSocketLibrary.bind(socketFd, address, address.size()); + NGUnixDomainSocketLibrary.listen(socketFd, backlog); + isBound = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public Socket accept() throws IOException { + // We explicitly do not make this method synchronized, since the + // call to NGUnixDomainSocketLibrary.accept() will block + // indefinitely, causing another thread's call to close() to deadlock. + synchronized (this) { + if (!isBound) { + throw new IllegalStateException("Socket is not bound"); + } + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + } + try { + NGUnixDomainSocketLibrary.SockaddrUn sockaddrUn = + new NGUnixDomainSocketLibrary.SockaddrUn(); + IntByReference addressLen = new IntByReference(); + addressLen.setValue(sockaddrUn.size()); + int clientFd = NGUnixDomainSocketLibrary.accept(fd.get(), sockaddrUn, addressLen); + return new NGUnixDomainSocket(clientFd); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + public synchronized void close() throws IOException { + if (isClosed) { + throw new IllegalStateException("Socket is already closed"); + } + try { + // Ensure any pending call to accept() fails. + NGUnixDomainSocketLibrary.close(fd.getAndSet(-1)); + isClosed = true; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java new file mode 100644 index 000000000..1a9942ad9 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocket.java @@ -0,0 +1,171 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocket.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +import java.nio.ByteBuffer; + +import java.net.Socket; + +/** + * Implements a {@link Socket} backed by a native Unix domain socket. + * + * Instances of this class always return {@code null} for + * {@link Socket#getInetAddress()}, {@link Socket#getLocalAddress()}, + * {@link Socket#getLocalSocketAddress()}, {@link Socket#getRemoteSocketAddress()}. + */ +public class NGUnixDomainSocket extends Socket { + private final ReferenceCountedFileDescriptor fd; + private final InputStream is; + private final OutputStream os; + + /** + * Creates a Unix domain socket backed by a native file descriptor. + */ + public NGUnixDomainSocket(int fd) { + this.fd = new ReferenceCountedFileDescriptor(fd); + this.is = new NGUnixDomainSocketInputStream(); + this.os = new NGUnixDomainSocketOutputStream(); + } + + public InputStream getInputStream() { + return is; + } + + public OutputStream getOutputStream() { + return os; + } + + public void shutdownInput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_RD); + } + + public void shutdownOutput() throws IOException { + doShutdown(NGUnixDomainSocketLibrary.SHUT_WR); + } + + private void doShutdown(int how) throws IOException { + try { + int socketFd = fd.acquire(); + if (socketFd != -1) { + NGUnixDomainSocketLibrary.shutdown(socketFd, how); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + + public void close() throws IOException { + super.close(); + try { + // This might not close the FD right away. In case we are about + // to read or write on another thread, it will delay the close + // until the read or write completes, to prevent the FD from + // being re-used for a different purpose and the other thread + // reading from a different FD. + fd.close(); + } catch (LastErrorException e) { + throw new IOException(e); + } + } + + private class NGUnixDomainSocketInputStream extends InputStream { + public int read() throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + int result; + if (doRead(buf) == 0) { + result = -1; + } else { + // Make sure to & with 0xFF to avoid sign extension + result = 0xFF & buf.get(); + } + return result; + } + + public int read(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return 0; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + int result = doRead(buf); + if (result == 0) { + result = -1; + } + return result; + } + + private int doRead(ByteBuffer buf) throws IOException { + try { + int fdToRead = fd.acquire(); + if (fdToRead == -1) { + return -1; + } + return NGUnixDomainSocketLibrary.read(fdToRead, buf, buf.remaining()); + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } + + private class NGUnixDomainSocketOutputStream extends OutputStream { + + public void write(int b) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(1); + buf.put(0, (byte) (0xFF & b)); + doWrite(buf); + } + + public void write(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return; + } + ByteBuffer buf = ByteBuffer.wrap(b, off, len); + doWrite(buf); + } + + private void doWrite(ByteBuffer buf) throws IOException { + try { + int fdToWrite = fd.acquire(); + if (fdToWrite == -1) { + return; + } + int ret = NGUnixDomainSocketLibrary.write(fdToWrite, buf, buf.remaining()); + if (ret != buf.remaining()) { + // This shouldn't happen with standard blocking Unix domain sockets. + throw new IOException("Could not write " + buf.remaining() + " bytes as requested " + + "(wrote " + ret + " bytes instead)"); + } + } catch (LastErrorException e) { + throw new IOException(e); + } finally { + fd.release(); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java new file mode 100644 index 000000000..7e760d37a --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGUnixDomainSocketLibrary.java @@ -0,0 +1,140 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGUnixDomainSocketLibrary.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; +import com.sun.jna.Native; +import com.sun.jna.Platform; +import com.sun.jna.Structure; +import com.sun.jna.Union; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; + +/** + * Utility class to bridge native Unix domain socket calls to Java using JNA. + */ +public class NGUnixDomainSocketLibrary { + public static final int PF_LOCAL = 1; + public static final int AF_LOCAL = 1; + public static final int SOCK_STREAM = 1; + + public static final int SHUT_RD = 0; + public static final int SHUT_WR = 1; + + // Utility class, do not instantiate. + private NGUnixDomainSocketLibrary() { } + + // BSD platforms write a length byte at the start of struct sockaddr_un. + private static final boolean HAS_SUN_LEN = + Platform.isMac() || Platform.isFreeBSD() || Platform.isNetBSD() || + Platform.isOpenBSD() || Platform.iskFreeBSD(); + + /** + * Bridges {@code struct sockaddr_un} to and from native code. + */ + public static class SockaddrUn extends Structure implements Structure.ByReference { + /** + * On BSD platforms, the {@code sun_len} and {@code sun_family} values in + * {@code struct sockaddr_un}. + */ + public static class SunLenAndFamily extends Structure { + public byte sunLen; + public byte sunFamily; + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunLen", "sunFamily" }); + } + } + + /** + * On BSD platforms, {@code sunLenAndFamily} will be present. + * On other platforms, only {@code sunFamily} will be present. + */ + public static class SunFamily extends Union { + public SunLenAndFamily sunLenAndFamily; + public short sunFamily; + } + + public SunFamily sunFamily = new SunFamily(); + public byte[] sunPath = new byte[104]; + + /** + * Constructs an empty {@code struct sockaddr_un}. + */ + public SockaddrUn() { + if (HAS_SUN_LEN) { + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + /** + * Constructs a {@code struct sockaddr_un} with a path whose bytes are encoded + * using the default encoding of the platform. + */ + public SockaddrUn(String path) throws IOException { + byte[] pathBytes = path.getBytes(); + if (pathBytes.length > sunPath.length - 1) { + throw new IOException("Cannot fit name [" + path + "] in maximum unix domain socket length"); + } + System.arraycopy(pathBytes, 0, sunPath, 0, pathBytes.length); + sunPath[pathBytes.length] = (byte) 0; + if (HAS_SUN_LEN) { + int len = fieldOffset("sunPath") + pathBytes.length; + sunFamily.sunLenAndFamily = new SunLenAndFamily(); + sunFamily.sunLenAndFamily.sunLen = (byte) len; + sunFamily.sunLenAndFamily.sunFamily = AF_LOCAL; + sunFamily.setType(SunLenAndFamily.class); + } else { + sunFamily.sunFamily = AF_LOCAL; + sunFamily.setType(Short.TYPE); + } + allocateMemory(); + } + + protected List getFieldOrder() { + return Arrays.asList(new String[] { "sunFamily", "sunPath" }); + } + } + + static { + Native.register(Platform.C_LIBRARY_NAME); + } + + public static native int socket(int domain, int type, int protocol) throws LastErrorException; + public static native int bind(int fd, SockaddrUn address, int addressLen) + throws LastErrorException; + public static native int listen(int fd, int backlog) throws LastErrorException; + public static native int accept(int fd, SockaddrUn address, IntByReference addressLen) + throws LastErrorException; + public static native int read(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int write(int fd, ByteBuffer buffer, int count) + throws LastErrorException; + public static native int close(int fd) throws LastErrorException; + public static native int shutdown(int fd, int how) throws LastErrorException; +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java new file mode 100644 index 000000000..ba535691f --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeLibrary.java @@ -0,0 +1,90 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeLibrary.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import java.nio.ByteBuffer; + +import com.sun.jna.*; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.*; +import com.sun.jna.platform.win32.WinBase.*; +import com.sun.jna.ptr.IntByReference; + +import com.sun.jna.win32.W32APIOptions; + +public interface NGWin32NamedPipeLibrary extends WinNT { + int PIPE_ACCESS_DUPLEX = 3; + int PIPE_UNLIMITED_INSTANCES = 255; + int FILE_FLAG_FIRST_PIPE_INSTANCE = 524288; + + NGWin32NamedPipeLibrary INSTANCE = + (NGWin32NamedPipeLibrary) Native.loadLibrary( + "kernel32", + NGWin32NamedPipeLibrary.class, + W32APIOptions.UNICODE_OPTIONS); + + HANDLE CreateNamedPipe( + String lpName, + int dwOpenMode, + int dwPipeMode, + int nMaxInstances, + int nOutBufferSize, + int nInBufferSize, + int nDefaultTimeOut, + SECURITY_ATTRIBUTES lpSecurityAttributes); + boolean ConnectNamedPipe( + HANDLE hNamedPipe, + Pointer lpOverlapped); + boolean DisconnectNamedPipe( + HANDLE hObject); + boolean ReadFile( + HANDLE hFile, + Memory lpBuffer, + int nNumberOfBytesToRead, + IntByReference lpNumberOfBytesRead, + Pointer lpOverlapped); + boolean WriteFile( + HANDLE hFile, + ByteBuffer lpBuffer, + int nNumberOfBytesToWrite, + IntByReference lpNumberOfBytesWritten, + Pointer lpOverlapped); + boolean CloseHandle( + HANDLE hObject); + boolean GetOverlappedResult( + HANDLE hFile, + Pointer lpOverlapped, + IntByReference lpNumberOfBytesTransferred, + boolean wait); + boolean CancelIoEx( + HANDLE hObject, + Pointer lpOverlapped); + HANDLE CreateEvent( + SECURITY_ATTRIBUTES lpEventAttributes, + boolean bManualReset, + boolean bInitialState, + String lpName); + int WaitForSingleObject( + HANDLE hHandle, + int dwMilliseconds + ); + + int GetLastError(); +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java new file mode 100644 index 000000000..137d9b5dc --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeServerSocket.java @@ -0,0 +1,173 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeServerSocket.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.LinkedBlockingQueue; + +public class NGWin32NamedPipeServerSocket extends ServerSocket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private static final String WIN32_PIPE_PREFIX = "\\\\.\\pipe\\"; + private static final int BUFFER_SIZE = 65535; + private final LinkedBlockingQueue openHandles; + private final LinkedBlockingQueue connectedHandles; + private final NGWin32NamedPipeSocket.CloseCallback closeCallback; + private final String path; + private final int maxInstances; + private final HANDLE lockHandle; + + public NGWin32NamedPipeServerSocket(String path) throws IOException { + this(NGWin32NamedPipeLibrary.PIPE_UNLIMITED_INSTANCES, path); + } + + public NGWin32NamedPipeServerSocket(int maxInstances, String path) throws IOException { + this.openHandles = new LinkedBlockingQueue<>(); + this.connectedHandles = new LinkedBlockingQueue<>(); + this.closeCallback = handle -> { + if (connectedHandles.remove(handle)) { + closeConnectedPipe(handle, false); + } + if (openHandles.remove(handle)) { + closeOpenPipe(handle); + } + }; + this.maxInstances = maxInstances; + if (!path.startsWith(WIN32_PIPE_PREFIX)) { + this.path = WIN32_PIPE_PREFIX + path; + } else { + this.path = path; + } + String lockPath = this.path + "_lock"; + lockHandle = API.CreateNamedPipe( + lockPath, + NGWin32NamedPipeLibrary.FILE_FLAG_FIRST_PIPE_INSTANCE | NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX, + 0, + 1, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (lockHandle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create lock for %s, error %d", lockPath, API.GetLastError())); + } else { + if (!API.DisconnectNamedPipe(lockHandle)) { + throw new IOException(String.format("Could not disconnect lock %d", API.GetLastError())); + } + } + + } + + public void bind(SocketAddress endpoint) throws IOException { + throw new IOException("Win32 named pipes do not support bind(), pass path to constructor"); + } + + public Socket accept() throws IOException { + HANDLE handle = API.CreateNamedPipe( + path, + NGWin32NamedPipeLibrary.PIPE_ACCESS_DUPLEX | WinNT.FILE_FLAG_OVERLAPPED, + 0, + maxInstances, + BUFFER_SIZE, + BUFFER_SIZE, + 0, + null); + if (handle == NGWin32NamedPipeLibrary.INVALID_HANDLE_VALUE) { + throw new IOException(String.format("Could not create named pipe, error %d", API.GetLastError())); + } + openHandles.add(handle); + + HANDLE connWaitable = API.CreateEvent(null, true, false, null); + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = connWaitable; + olap.write(); + + boolean immediate = API.ConnectNamedPipe(handle, olap.getPointer()); + if (immediate) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } + + int connectError = API.GetLastError(); + if (connectError == WinError.ERROR_PIPE_CONNECTED) { + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_NO_DATA) { + // Client has connected and disconnected between CreateNamedPipe() and ConnectNamedPipe() + // connection is broken, but it is returned it avoid loop here. + // Actual error will happen for NGSession when it will try to read/write from/to pipe + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else if (connectError == WinError.ERROR_IO_PENDING) { + if (!API.GetOverlappedResult(handle, olap.getPointer(), new IntByReference(), true)) { + openHandles.remove(handle); + closeOpenPipe(handle); + throw new IOException("GetOverlappedResult() failed for connect operation: " + API.GetLastError()); + } + openHandles.remove(handle); + connectedHandles.add(handle); + return new NGWin32NamedPipeSocket(handle, closeCallback); + } else { + throw new IOException("ConnectNamedPipe() failed with: " + connectError); + } + } + + public void close() throws IOException { + try { + List handlesToClose = new ArrayList<>(); + openHandles.drainTo(handlesToClose); + for (HANDLE handle : handlesToClose) { + closeOpenPipe(handle); + } + + List handlesToDisconnect = new ArrayList<>(); + connectedHandles.drainTo(handlesToDisconnect); + for (HANDLE handle : handlesToDisconnect) { + closeConnectedPipe(handle, true); + } + } finally { + API.CloseHandle(lockHandle); + } + } + + private void closeOpenPipe(HANDLE handle) throws IOException { + API.CancelIoEx(handle, null); + API.CloseHandle(handle); + } + + private void closeConnectedPipe(HANDLE handle, boolean shutdown) throws IOException { + if (!shutdown) { + API.WaitForSingleObject(handle, 10000); + } + API.DisconnectNamedPipe(handle); + API.CloseHandle(handle); + } +} diff --git a/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java new file mode 100644 index 000000000..072ed3901 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/NGWin32NamedPipeSocket.java @@ -0,0 +1,173 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/NGWin32NamedPipeSocket.java + +/* + + Copyright 2004-2017, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.Memory; +import com.sun.jna.platform.win32.WinBase; +import com.sun.jna.platform.win32.WinError; +import com.sun.jna.platform.win32.WinNT.HANDLE; +import com.sun.jna.ptr.IntByReference; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.ByteBuffer; + +public class NGWin32NamedPipeSocket extends Socket { + private static final NGWin32NamedPipeLibrary API = NGWin32NamedPipeLibrary.INSTANCE; + private final HANDLE handle; + private final CloseCallback closeCallback; + private final InputStream is; + private final OutputStream os; + private final HANDLE readerWaitable; + private final HANDLE writerWaitable; + + interface CloseCallback { + void onNamedPipeSocketClose(HANDLE handle) throws IOException; + } + + public NGWin32NamedPipeSocket( + HANDLE handle, + NGWin32NamedPipeSocket.CloseCallback closeCallback) throws IOException { + this.handle = handle; + this.closeCallback = closeCallback; + this.readerWaitable = API.CreateEvent(null, true, false, null); + if (readerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + writerWaitable = API.CreateEvent(null, true, false, null); + if (writerWaitable == null) { + throw new IOException("CreateEvent() failed "); + } + this.is = new NGWin32NamedPipeSocketInputStream(handle); + this.os = new NGWin32NamedPipeSocketOutputStream(handle); + } + + @Override + public InputStream getInputStream() { + return is; + } + + @Override + public OutputStream getOutputStream() { + return os; + } + + @Override + public void close() throws IOException { + closeCallback.onNamedPipeSocketClose(handle); + } + + @Override + public void shutdownInput() throws IOException { + } + + @Override + public void shutdownOutput() throws IOException { + } + + private class NGWin32NamedPipeSocketInputStream extends InputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketInputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public int read() throws IOException { + int result; + byte[] b = new byte[1]; + if (read(b) == 0) { + result = -1; + } else { + result = 0xFF & b[0]; + } + return result; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + Memory readBuffer = new Memory(len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = readerWaitable; + olap.write(); + + boolean immediate = API.ReadFile(handle, readBuffer, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("ReadFile() failed: " + lastError); + } + } + + IntByReference read = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), read, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for read operation: " + lastError); + } + if (read.getValue() != len) { + throw new IOException("ReadFile() read less bytes than requested"); + } + byte[] byteArray = readBuffer.getByteArray(0, len); + System.arraycopy(byteArray, 0, b, off, len); + return len; + } + } + + private class NGWin32NamedPipeSocketOutputStream extends OutputStream { + private final HANDLE handle; + + NGWin32NamedPipeSocketOutputStream(HANDLE handle) { + this.handle = handle; + } + + @Override + public void write(int b) throws IOException { + write(new byte[]{(byte) (0xFF & b)}); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + ByteBuffer data = ByteBuffer.wrap(b, off, len); + + WinBase.OVERLAPPED olap = new WinBase.OVERLAPPED(); + olap.hEvent = writerWaitable; + olap.write(); + + boolean immediate = API.WriteFile(handle, data, len, null, olap.getPointer()); + if (!immediate) { + int lastError = API.GetLastError(); + if (lastError != WinError.ERROR_IO_PENDING) { + throw new IOException("WriteFile() failed: " + lastError); + } + } + IntByReference written = new IntByReference(); + if (!API.GetOverlappedResult(handle, olap.getPointer(), written, true)) { + int lastError = API.GetLastError(); + throw new IOException("GetOverlappedResult() failed for write operation: " + lastError); + } + if (written.getValue() != len) { + throw new IOException("WriteFile() wrote less bytes than requested"); + } + } + } +} diff --git a/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java new file mode 100644 index 000000000..7fb5d9d53 --- /dev/null +++ b/main-command/src/main/java/sbt/internal/ReferenceCountedFileDescriptor.java @@ -0,0 +1,82 @@ +// Copied from https://github.com/facebook/nailgun/blob/af623fddedfdca010df46302a0711ce0e2cc1ba6/nailgun-server/src/main/java/com/martiansoftware/nailgun/ReferenceCountedFileDescriptor.java + +/* + + Copyright 2004-2015, Martian Software, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + */ +package sbt.internal; + +import com.sun.jna.LastErrorException; + +import java.io.IOException; + +/** + * Encapsulates a file descriptor plus a reference count to ensure close requests + * only close the file descriptor once the last reference to the file descriptor + * is released. + * + * If not explicitly closed, the file descriptor will be closed when + * this object is finalized. + */ +public class ReferenceCountedFileDescriptor { + private int fd; + private int fdRefCount; + private boolean closePending; + + public ReferenceCountedFileDescriptor(int fd) { + this.fd = fd; + this.fdRefCount = 0; + this.closePending = false; + } + + protected void finalize() throws IOException { + close(); + } + + public synchronized int acquire() { + fdRefCount++; + return fd; + } + + public synchronized void release() throws IOException { + fdRefCount--; + if (fdRefCount == 0 && closePending && fd != -1) { + doClose(); + } + } + + public synchronized void close() throws IOException { + if (fd == -1 || closePending) { + return; + } + + if (fdRefCount == 0) { + doClose(); + } else { + // Another thread has the FD. We'll close it when they release the reference. + closePending = true; + } + } + + private void doClose() throws IOException { + try { + NGUnixDomainSocketLibrary.close(fd); + fd = -1; + } catch (LastErrorException e) { + throw new IOException(e); + } + } +}