diff --git a/.github/workflows/client-test.yml b/.github/workflows/client-test.yml index 65c286adc..f7a289059 100644 --- a/.github/workflows/client-test.yml +++ b/.github/workflows/client-test.yml @@ -47,11 +47,24 @@ jobs: - name: Setup Windows C++ toolchain uses: ilammy/msvc-dev-cmd@v1 if: ${{ matrix.os == 'windows-latest' }} + - name: Setup musl toolchain (Linux) + if: ${{ matrix.os == 'ubuntu-latest' }} + shell: bash + run: | + sudo apt-get update + sudo apt-get install -y musl-tools musl-dev + curl -fsSL -o zlib.tar.gz https://github.com/madler/zlib/releases/download/v1.3.1/zlib-1.3.1.tar.gz + tar xzf zlib.tar.gz + cd zlib-1.3.1 + CC=musl-gcc ./configure --static --prefix=/usr/local/musl + make + sudo make install + sudo ln -sf /usr/local/musl/lib/libz.a /usr/lib/x86_64-linux-musl/libz.a - name: Client test (Linux) if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash run: | - # test building sbtn on Linux + # test building sbtn on Linux with musl static linking sbt "-Dsbt.io.virtual=false" nativeImage # smoke test native Image ./client/target/bin/sbtn --sbt-script=$(pwd)/sbt about diff --git a/build.sbt b/build.sbt index 753c18884..c7d60616e 100644 --- a/build.sbt +++ b/build.sbt @@ -845,7 +845,7 @@ lazy val sbtClientProj = (project in file("client")) "-H:+ReportExceptionStackTraces", "-H:-ParseRuntimeOptions", s"-H:Name=${target.value / "bin" / "sbtn"}", - ), + ) ++ (if (isLinux) Seq("--static", "--libc=musl") else Nil), buildThinClient := { val isFish = Def.spaceDelimited("").parsed.headOption.fold(false)(_ == "--fish") val ext = if (isWin) ".bat" else if (isFish) ".fish" else ".sh" diff --git a/client/src/main/resources/META-INF/native-image/reflect-config.json b/client/src/main/resources/META-INF/native-image/reflect-config.json index a2b39abca..694e27688 100644 --- a/client/src/main/resources/META-INF/native-image/reflect-config.json +++ b/client/src/main/resources/META-INF/native-image/reflect-config.json @@ -14,5 +14,22 @@ { "name":"jline.UnixTerminal", "methods":[{"name":"","parameterTypes":[] }] + }, + { + "name":"java.net.UnixDomainSocketAddress", + "methods":[{"name":"of","parameterTypes":["java.nio.file.Path"] }] + }, + { + "name":"java.net.StandardProtocolFamily", + "allDeclaredFields":true, + "allPublicFields":true + }, + { + "name":"java.nio.channels.SocketChannel", + "methods":[{"name":"open","parameterTypes":["java.net.ProtocolFamily"] }] + }, + { + "name":"java.nio.channels.ServerSocketChannel", + "methods":[{"name":"open","parameterTypes":["java.net.ProtocolFamily"] }] } ] diff --git a/main-command/src/main/java/sbt/internal/BootServerSocket.java b/main-command/src/main/java/sbt/internal/BootServerSocket.java index 84cd3e69c..0689377a6 100644 --- a/main-command/src/main/java/sbt/internal/BootServerSocket.java +++ b/main-command/src/main/java/sbt/internal/BootServerSocket.java @@ -31,9 +31,7 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import net.openhft.hashing.LongHashFunction; -import org.scalasbt.ipcsocket.UnixDomainServerSocket; import org.scalasbt.ipcsocket.Win32NamedPipeServerSocket; -import org.scalasbt.ipcsocket.Win32NamedPipeSocket; import org.scalasbt.ipcsocket.Win32SecurityLevel; import sbt.internal.util.Terminal; import xsbti.AppConfiguration; @@ -351,7 +349,7 @@ public class BootServerSocket implements AutoCloseable { socket = isWindows ? new Win32NamedPipeServerSocket(name, jni, Win32SecurityLevel.OWNER_DACL) - : new UnixDomainServerSocket(name, jni); + : UnixDomainSocketFactory.newServerSocket(name, jni); return socket; } catch (final IOException e) { throw new ServerAlreadyBootingException(e); diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index e267e2ffe..eb693a203 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -32,6 +32,7 @@ import sbt.internal.util.{ } import sbt.io.IO import sbt.io.syntax.* +import sbt.internal.UnixDomainSocketFactory import sbt.protocol.* import sbt.util.Level import sjsonnew.BasicJsonProtocol.* @@ -1355,7 +1356,7 @@ object NetworkClient { } def main(args: Array[String]): Unit = { val (jnaArg, restOfArgs) = args.partition(_ == "--jna") - val useJNI = jnaArg.isEmpty + val useJNI = jnaArg.isEmpty && (Util.isWindows || !UnixDomainSocketFactory.isJdk17Available) val base = new File("").getCanonicalFile if (restOfArgs.exists(_.startsWith(NetworkClient.completions))) System.exit(complete(base, restOfArgs, useJNI, System.in, System.out)) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 6043ebd42..51e2160db 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -27,7 +27,8 @@ import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } import sbt.internal.protocol.codec.* import sbt.internal.util.ErrorHandling import sbt.internal.util.Util.isWindows -import org.scalasbt.ipcsocket.* +import sbt.internal.UnixDomainSocketFactory +import org.scalasbt.ipcsocket.{ UnixDomainSocketLibraryProvider, Win32NamedPipeServerSocket } import sbt.internal.bsp.BuildServerConnection import xsbti.AppConfiguration @@ -82,9 +83,9 @@ private[sbt] object Server { "or define a short \"SBT_GLOBAL_SERVER_DIR\" value. " + s"Current path: ${path}" ) - tryClient(new UnixDomainSocket(path, connection.useJni)) + tryClient(UnixDomainSocketFactory.newSocket(path, connection.useJni)) prepareSocketfile() - addServerError(new UnixDomainServerSocket(path, connection.useJni)) + addServerError(UnixDomainSocketFactory.newServerSocket(path, connection.useJni)) case ConnectionType.Tcp => tryClient(new Socket(InetAddress.getByName(host), port)) addServerError(new ServerSocket(port, 50, InetAddress.getByName(host))) diff --git a/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java b/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java new file mode 100644 index 000000000..5ef3abd70 --- /dev/null +++ b/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java @@ -0,0 +1,242 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.Method; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.channels.Channels; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * Factory for creating Unix domain sockets. + * + *

On JDK 17+, uses native java.net.UnixDomainSocketAddress (no JNI required). On older JDKs, + * falls back to ipcsocket library (requires JNI). + * + *

This enables musl static linking on JDK 17+ by avoiding JNI dependencies. The ipcsocket + * classes are loaded via reflection only when needed, so they won't be loaded on JDK 17+. + */ +public final class UnixDomainSocketFactory { + + private static final boolean JDK17_AVAILABLE; + private static final Method UNIX_ADDRESS_OF_METHOD; + private static final Object UNIX_PROTOCOL_FAMILY; + + static { + boolean available = false; + Method ofMethod = null; + Object unixFamily = null; + + try { + Class unixAddressClass = Class.forName("java.net.UnixDomainSocketAddress"); + ofMethod = unixAddressClass.getMethod("of", Path.class); + + @SuppressWarnings("unchecked") + Class> protocolFamilyClass = + (Class>) Class.forName("java.net.StandardProtocolFamily"); + for (Object constant : protocolFamilyClass.getEnumConstants()) { + if ("UNIX".equals(((Enum) constant).name())) { + unixFamily = constant; + break; + } + } + if (unixFamily != null) { + available = true; + } + } catch (ClassNotFoundException | NoSuchMethodException | SecurityException e) { + available = false; + } + + JDK17_AVAILABLE = available; + UNIX_ADDRESS_OF_METHOD = ofMethod; + UNIX_PROTOCOL_FAMILY = unixFamily; + } + + public static boolean isJdk17Available() { + return JDK17_AVAILABLE; + } + + public static Socket newSocket(String path, boolean useJni) throws IOException { + if (JDK17_AVAILABLE && !useJni) { + return newJdk17Socket(path); + } else { + return newLegacySocket(path, useJni); + } + } + + public static ServerSocket newServerSocket(String path, boolean useJni) throws IOException { + if (JDK17_AVAILABLE && !useJni) { + return newJdk17ServerSocket(path); + } else { + return newLegacyServerSocket(path, useJni); + } + } + + private static Socket newLegacySocket(String path, boolean useJni) throws IOException { + try { + Class clazz = Class.forName("org.scalasbt.ipcsocket.UnixDomainSocket"); + return (Socket) clazz.getConstructor(String.class, boolean.class).newInstance(path, useJni); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create ipcsocket UnixDomainSocket", e); + } + } + + private static ServerSocket newLegacyServerSocket(String path, boolean useJni) + throws IOException { + try { + Class clazz = Class.forName("org.scalasbt.ipcsocket.UnixDomainServerSocket"); + return (ServerSocket) + clazz.getConstructor(String.class, boolean.class).newInstance(path, useJni); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create ipcsocket UnixDomainServerSocket", e); + } + } + + private static Socket newJdk17Socket(String path) throws IOException { + try { + SocketAddress address = (SocketAddress) UNIX_ADDRESS_OF_METHOD.invoke(null, Paths.get(path)); + SocketChannel channel = + (SocketChannel) + SocketChannel.class + .getMethod("open", java.net.ProtocolFamily.class) + .invoke(null, UNIX_PROTOCOL_FAMILY); + channel.connect(address); + return new ChannelSocket(channel); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create JDK 17 Unix domain socket", e); + } + } + + private static ServerSocket newJdk17ServerSocket(String path) throws IOException { + try { + SocketAddress address = (SocketAddress) UNIX_ADDRESS_OF_METHOD.invoke(null, Paths.get(path)); + ServerSocketChannel channel = + (ServerSocketChannel) + ServerSocketChannel.class + .getMethod("open", java.net.ProtocolFamily.class) + .invoke(null, UNIX_PROTOCOL_FAMILY); + channel.bind(address); + return new ChannelServerSocket(channel); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create JDK 17 Unix domain server socket", e); + } + } + + private UnixDomainSocketFactory() {} + + public static class ChannelSocket extends Socket { + private final SocketChannel channel; + private final InputStream inputStream; + private final OutputStream outputStream; + + public ChannelSocket(SocketChannel channel) { + this.channel = channel; + this.inputStream = Channels.newInputStream(channel); + this.outputStream = Channels.newOutputStream(channel); + } + + @Override + public InputStream getInputStream() { + return inputStream; + } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public boolean isClosed() { + return !channel.isOpen(); + } + + @Override + public boolean isConnected() { + return channel.isConnected(); + } + + @Override + public SocketChannel getChannel() { + return channel; + } + } + + public static class ChannelServerSocket extends ServerSocket { + private final ServerSocketChannel channel; + private int soTimeout = 0; + + public ChannelServerSocket(ServerSocketChannel channel) throws IOException { + this.channel = channel; + channel.configureBlocking(true); + } + + @Override + public Socket accept() throws IOException { + if (soTimeout > 0) { + channel.configureBlocking(false); + long deadline = System.currentTimeMillis() + soTimeout; + while (System.currentTimeMillis() < deadline) { + SocketChannel clientChannel = channel.accept(); + if (clientChannel != null) { + return new ChannelSocket(clientChannel); + } + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new java.net.SocketTimeoutException("Accept interrupted"); + } + } + throw new java.net.SocketTimeoutException("Accept timed out"); + } else { + channel.configureBlocking(true); + SocketChannel clientChannel = channel.accept(); + return new ChannelSocket(clientChannel); + } + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public boolean isClosed() { + return !channel.isOpen(); + } + + @Override + public ServerSocketChannel getChannel() { + return channel; + } + + @Override + public void setSoTimeout(int timeout) throws java.net.SocketException { + this.soTimeout = timeout; + } + + @Override + public int getSoTimeout() throws java.net.SocketException { + return soTimeout; + } + } +} diff --git a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala index 79f333cc5..16f9d3f9b 100644 --- a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala +++ b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala @@ -17,7 +17,8 @@ import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.internal.protocol.codec.{ PortFileFormats, TokenFileFormats } import sbt.internal.util.Util.isWindows -import org.scalasbt.ipcsocket.* +import sbt.internal.UnixDomainSocketFactory +import org.scalasbt.ipcsocket.Win32NamedPipeSocket object ClientSocket { private lazy val fileFormats = new BasicJsonProtocol with PortFileFormats with TokenFileFormats {} @@ -44,5 +45,5 @@ object ClientSocket { } def localSocket(name: String, useJNI: Boolean): Socket = if (isWindows) new Win32NamedPipeSocket(s"\\\\.\\pipe\\$name", useJNI) - else new UnixDomainSocket(name, useJNI) + else UnixDomainSocketFactory.newSocket(name, useJNI) }