From e16298521b90b0c3eb041210381bca90b0b12ea6 Mon Sep 17 00:00:00 2001 From: MkDev11 Date: Tue, 13 Jan 2026 23:14:45 -0500 Subject: [PATCH] [2.x] feat: Enable musl static linking for sbtn on JDK 17+ (#8464) ** Problem ** The sbtn (sbt thin client) native image on Linux currently depends on glibc because ipcsocket uses JNI for Unix domain sockets. When building with musl for static linking, the JNI library fails to load since musl doesn't support `dlopen`. ** Solution ** Instead of upgrading to ipcsocket 2.x (which isn't ready for production), I created a `UnixDomainSocketFactory` that detects JDK 17+ at runtime and uses the native `java.net.UnixDomainSocketAddress` API directly via reflection. This completely bypasses JNI on modern JDKs. For older JDKs (8 and 11), the factory falls back to ipcsocket 1.6.3, which is stable and well-tested. ** How It Works ** The factory checks at startup whether `java.net.UnixDomainSocketAddress` is available: - **JDK 17+**: Uses native NIO Unix domain sockets (no JNI, no native libraries) - **JDK 8/11**: Falls back to ipcsocket's JNI-based implementation This approach: - Enables musl static linking on JDK 17+ without any native dependencies - Maintains full backward compatibility with older JDKs - Keeps the stable ipcsocket 1.6.3 instead of the unstable 2.x --- .github/workflows/client-test.yml | 15 +- build.sbt | 2 +- .../META-INF/native-image/reflect-config.json | 17 ++ .../java/sbt/internal/BootServerSocket.java | 4 +- .../sbt/internal/client/NetworkClient.scala | 3 +- .../scala/sbt/internal/server/Server.scala | 7 +- .../sbt/internal/UnixDomainSocketFactory.java | 242 ++++++++++++++++++ .../scala/sbt/protocol/ClientSocket.scala | 5 +- 8 files changed, 284 insertions(+), 11 deletions(-) create mode 100644 protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java diff --git a/.github/workflows/client-test.yml b/.github/workflows/client-test.yml index 65c286adc..f7a289059 100644 --- a/.github/workflows/client-test.yml +++ b/.github/workflows/client-test.yml @@ -47,11 +47,24 @@ jobs: - name: Setup Windows C++ toolchain uses: ilammy/msvc-dev-cmd@v1 if: ${{ matrix.os == 'windows-latest' }} + - name: Setup musl toolchain (Linux) + if: ${{ matrix.os == 'ubuntu-latest' }} + shell: bash + run: | + sudo apt-get update + sudo apt-get install -y musl-tools musl-dev + curl -fsSL -o zlib.tar.gz https://github.com/madler/zlib/releases/download/v1.3.1/zlib-1.3.1.tar.gz + tar xzf zlib.tar.gz + cd zlib-1.3.1 + CC=musl-gcc ./configure --static --prefix=/usr/local/musl + make + sudo make install + sudo ln -sf /usr/local/musl/lib/libz.a /usr/lib/x86_64-linux-musl/libz.a - name: Client test (Linux) if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash run: | - # test building sbtn on Linux + # test building sbtn on Linux with musl static linking sbt "-Dsbt.io.virtual=false" nativeImage # smoke test native Image ./client/target/bin/sbtn --sbt-script=$(pwd)/sbt about diff --git a/build.sbt b/build.sbt index 753c18884..c7d60616e 100644 --- a/build.sbt +++ b/build.sbt @@ -845,7 +845,7 @@ lazy val sbtClientProj = (project in file("client")) "-H:+ReportExceptionStackTraces", "-H:-ParseRuntimeOptions", s"-H:Name=${target.value / "bin" / "sbtn"}", - ), + ) ++ (if (isLinux) Seq("--static", "--libc=musl") else Nil), buildThinClient := { val isFish = Def.spaceDelimited("").parsed.headOption.fold(false)(_ == "--fish") val ext = if (isWin) ".bat" else if (isFish) ".fish" else ".sh" diff --git a/client/src/main/resources/META-INF/native-image/reflect-config.json b/client/src/main/resources/META-INF/native-image/reflect-config.json index a2b39abca..694e27688 100644 --- a/client/src/main/resources/META-INF/native-image/reflect-config.json +++ b/client/src/main/resources/META-INF/native-image/reflect-config.json @@ -14,5 +14,22 @@ { "name":"jline.UnixTerminal", "methods":[{"name":"","parameterTypes":[] }] + }, + { + "name":"java.net.UnixDomainSocketAddress", + "methods":[{"name":"of","parameterTypes":["java.nio.file.Path"] }] + }, + { + "name":"java.net.StandardProtocolFamily", + "allDeclaredFields":true, + "allPublicFields":true + }, + { + "name":"java.nio.channels.SocketChannel", + "methods":[{"name":"open","parameterTypes":["java.net.ProtocolFamily"] }] + }, + { + "name":"java.nio.channels.ServerSocketChannel", + "methods":[{"name":"open","parameterTypes":["java.net.ProtocolFamily"] }] } ] diff --git a/main-command/src/main/java/sbt/internal/BootServerSocket.java b/main-command/src/main/java/sbt/internal/BootServerSocket.java index 84cd3e69c..0689377a6 100644 --- a/main-command/src/main/java/sbt/internal/BootServerSocket.java +++ b/main-command/src/main/java/sbt/internal/BootServerSocket.java @@ -31,9 +31,7 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import net.openhft.hashing.LongHashFunction; -import org.scalasbt.ipcsocket.UnixDomainServerSocket; import org.scalasbt.ipcsocket.Win32NamedPipeServerSocket; -import org.scalasbt.ipcsocket.Win32NamedPipeSocket; import org.scalasbt.ipcsocket.Win32SecurityLevel; import sbt.internal.util.Terminal; import xsbti.AppConfiguration; @@ -351,7 +349,7 @@ public class BootServerSocket implements AutoCloseable { socket = isWindows ? new Win32NamedPipeServerSocket(name, jni, Win32SecurityLevel.OWNER_DACL) - : new UnixDomainServerSocket(name, jni); + : UnixDomainSocketFactory.newServerSocket(name, jni); return socket; } catch (final IOException e) { throw new ServerAlreadyBootingException(e); diff --git a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala index e267e2ffe..eb693a203 100644 --- a/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala +++ b/main-command/src/main/scala/sbt/internal/client/NetworkClient.scala @@ -32,6 +32,7 @@ import sbt.internal.util.{ } import sbt.io.IO import sbt.io.syntax.* +import sbt.internal.UnixDomainSocketFactory import sbt.protocol.* import sbt.util.Level import sjsonnew.BasicJsonProtocol.* @@ -1355,7 +1356,7 @@ object NetworkClient { } def main(args: Array[String]): Unit = { val (jnaArg, restOfArgs) = args.partition(_ == "--jna") - val useJNI = jnaArg.isEmpty + val useJNI = jnaArg.isEmpty && (Util.isWindows || !UnixDomainSocketFactory.isJdk17Available) val base = new File("").getCanonicalFile if (restOfArgs.exists(_.startsWith(NetworkClient.completions))) System.exit(complete(base, restOfArgs, useJNI, System.in, System.out)) diff --git a/main-command/src/main/scala/sbt/internal/server/Server.scala b/main-command/src/main/scala/sbt/internal/server/Server.scala index 6043ebd42..51e2160db 100644 --- a/main-command/src/main/scala/sbt/internal/server/Server.scala +++ b/main-command/src/main/scala/sbt/internal/server/Server.scala @@ -27,7 +27,8 @@ import sjsonnew.support.scalajson.unsafe.{ CompactPrinter, Converter } import sbt.internal.protocol.codec.* import sbt.internal.util.ErrorHandling import sbt.internal.util.Util.isWindows -import org.scalasbt.ipcsocket.* +import sbt.internal.UnixDomainSocketFactory +import org.scalasbt.ipcsocket.{ UnixDomainSocketLibraryProvider, Win32NamedPipeServerSocket } import sbt.internal.bsp.BuildServerConnection import xsbti.AppConfiguration @@ -82,9 +83,9 @@ private[sbt] object Server { "or define a short \"SBT_GLOBAL_SERVER_DIR\" value. " + s"Current path: ${path}" ) - tryClient(new UnixDomainSocket(path, connection.useJni)) + tryClient(UnixDomainSocketFactory.newSocket(path, connection.useJni)) prepareSocketfile() - addServerError(new UnixDomainServerSocket(path, connection.useJni)) + addServerError(UnixDomainSocketFactory.newServerSocket(path, connection.useJni)) case ConnectionType.Tcp => tryClient(new Socket(InetAddress.getByName(host), port)) addServerError(new ServerSocket(port, 50, InetAddress.getByName(host))) diff --git a/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java b/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java new file mode 100644 index 000000000..5ef3abd70 --- /dev/null +++ b/protocol/src/main/java/sbt/internal/UnixDomainSocketFactory.java @@ -0,0 +1,242 @@ +/* + * sbt + * Copyright 2023, Scala center + * Copyright 2011 - 2022, Lightbend, Inc. + * Copyright 2008 - 2010, Mark Harrah + * Licensed under Apache License 2.0 (see LICENSE) + */ + +package sbt.internal; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.Method; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.channels.Channels; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * Factory for creating Unix domain sockets. + * + *

On JDK 17+, uses native java.net.UnixDomainSocketAddress (no JNI required). On older JDKs, + * falls back to ipcsocket library (requires JNI). + * + *

This enables musl static linking on JDK 17+ by avoiding JNI dependencies. The ipcsocket + * classes are loaded via reflection only when needed, so they won't be loaded on JDK 17+. + */ +public final class UnixDomainSocketFactory { + + private static final boolean JDK17_AVAILABLE; + private static final Method UNIX_ADDRESS_OF_METHOD; + private static final Object UNIX_PROTOCOL_FAMILY; + + static { + boolean available = false; + Method ofMethod = null; + Object unixFamily = null; + + try { + Class unixAddressClass = Class.forName("java.net.UnixDomainSocketAddress"); + ofMethod = unixAddressClass.getMethod("of", Path.class); + + @SuppressWarnings("unchecked") + Class> protocolFamilyClass = + (Class>) Class.forName("java.net.StandardProtocolFamily"); + for (Object constant : protocolFamilyClass.getEnumConstants()) { + if ("UNIX".equals(((Enum) constant).name())) { + unixFamily = constant; + break; + } + } + if (unixFamily != null) { + available = true; + } + } catch (ClassNotFoundException | NoSuchMethodException | SecurityException e) { + available = false; + } + + JDK17_AVAILABLE = available; + UNIX_ADDRESS_OF_METHOD = ofMethod; + UNIX_PROTOCOL_FAMILY = unixFamily; + } + + public static boolean isJdk17Available() { + return JDK17_AVAILABLE; + } + + public static Socket newSocket(String path, boolean useJni) throws IOException { + if (JDK17_AVAILABLE && !useJni) { + return newJdk17Socket(path); + } else { + return newLegacySocket(path, useJni); + } + } + + public static ServerSocket newServerSocket(String path, boolean useJni) throws IOException { + if (JDK17_AVAILABLE && !useJni) { + return newJdk17ServerSocket(path); + } else { + return newLegacyServerSocket(path, useJni); + } + } + + private static Socket newLegacySocket(String path, boolean useJni) throws IOException { + try { + Class clazz = Class.forName("org.scalasbt.ipcsocket.UnixDomainSocket"); + return (Socket) clazz.getConstructor(String.class, boolean.class).newInstance(path, useJni); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create ipcsocket UnixDomainSocket", e); + } + } + + private static ServerSocket newLegacyServerSocket(String path, boolean useJni) + throws IOException { + try { + Class clazz = Class.forName("org.scalasbt.ipcsocket.UnixDomainServerSocket"); + return (ServerSocket) + clazz.getConstructor(String.class, boolean.class).newInstance(path, useJni); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create ipcsocket UnixDomainServerSocket", e); + } + } + + private static Socket newJdk17Socket(String path) throws IOException { + try { + SocketAddress address = (SocketAddress) UNIX_ADDRESS_OF_METHOD.invoke(null, Paths.get(path)); + SocketChannel channel = + (SocketChannel) + SocketChannel.class + .getMethod("open", java.net.ProtocolFamily.class) + .invoke(null, UNIX_PROTOCOL_FAMILY); + channel.connect(address); + return new ChannelSocket(channel); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create JDK 17 Unix domain socket", e); + } + } + + private static ServerSocket newJdk17ServerSocket(String path) throws IOException { + try { + SocketAddress address = (SocketAddress) UNIX_ADDRESS_OF_METHOD.invoke(null, Paths.get(path)); + ServerSocketChannel channel = + (ServerSocketChannel) + ServerSocketChannel.class + .getMethod("open", java.net.ProtocolFamily.class) + .invoke(null, UNIX_PROTOCOL_FAMILY); + channel.bind(address); + return new ChannelServerSocket(channel); + } catch (ReflectiveOperationException e) { + throw new IOException("Failed to create JDK 17 Unix domain server socket", e); + } + } + + private UnixDomainSocketFactory() {} + + public static class ChannelSocket extends Socket { + private final SocketChannel channel; + private final InputStream inputStream; + private final OutputStream outputStream; + + public ChannelSocket(SocketChannel channel) { + this.channel = channel; + this.inputStream = Channels.newInputStream(channel); + this.outputStream = Channels.newOutputStream(channel); + } + + @Override + public InputStream getInputStream() { + return inputStream; + } + + @Override + public OutputStream getOutputStream() { + return outputStream; + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public boolean isClosed() { + return !channel.isOpen(); + } + + @Override + public boolean isConnected() { + return channel.isConnected(); + } + + @Override + public SocketChannel getChannel() { + return channel; + } + } + + public static class ChannelServerSocket extends ServerSocket { + private final ServerSocketChannel channel; + private int soTimeout = 0; + + public ChannelServerSocket(ServerSocketChannel channel) throws IOException { + this.channel = channel; + channel.configureBlocking(true); + } + + @Override + public Socket accept() throws IOException { + if (soTimeout > 0) { + channel.configureBlocking(false); + long deadline = System.currentTimeMillis() + soTimeout; + while (System.currentTimeMillis() < deadline) { + SocketChannel clientChannel = channel.accept(); + if (clientChannel != null) { + return new ChannelSocket(clientChannel); + } + try { + Thread.sleep(50); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new java.net.SocketTimeoutException("Accept interrupted"); + } + } + throw new java.net.SocketTimeoutException("Accept timed out"); + } else { + channel.configureBlocking(true); + SocketChannel clientChannel = channel.accept(); + return new ChannelSocket(clientChannel); + } + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public boolean isClosed() { + return !channel.isOpen(); + } + + @Override + public ServerSocketChannel getChannel() { + return channel; + } + + @Override + public void setSoTimeout(int timeout) throws java.net.SocketException { + this.soTimeout = timeout; + } + + @Override + public int getSoTimeout() throws java.net.SocketException { + return soTimeout; + } + } +} diff --git a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala index 79f333cc5..16f9d3f9b 100644 --- a/protocol/src/main/scala/sbt/protocol/ClientSocket.scala +++ b/protocol/src/main/scala/sbt/protocol/ClientSocket.scala @@ -17,7 +17,8 @@ import sjsonnew.shaded.scalajson.ast.unsafe.JValue import sbt.internal.protocol.{ PortFile, TokenFile } import sbt.internal.protocol.codec.{ PortFileFormats, TokenFileFormats } import sbt.internal.util.Util.isWindows -import org.scalasbt.ipcsocket.* +import sbt.internal.UnixDomainSocketFactory +import org.scalasbt.ipcsocket.Win32NamedPipeSocket object ClientSocket { private lazy val fileFormats = new BasicJsonProtocol with PortFileFormats with TokenFileFormats {} @@ -44,5 +45,5 @@ object ClientSocket { } def localSocket(name: String, useJNI: Boolean): Socket = if (isWindows) new Win32NamedPipeSocket(s"\\\\.\\pipe\\$name", useJNI) - else new UnixDomainSocket(name, useJNI) + else UnixDomainSocketFactory.newSocket(name, useJNI) }