Pure Java launcher (size ~5 MB -> 8 kB)

This commit is contained in:
Alexandre Archambault 2015-11-28 17:14:18 +01:00
parent a391b8c36e
commit 4657991531
5 changed files with 243 additions and 167 deletions

View File

@ -0,0 +1,223 @@
package coursier;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
public class Bootstrap {
static void exit(String message) {
System.err.println(message);
System.exit(255);
}
static byte[] readFullySync(InputStream is) throws IOException {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
byte[] data = new byte[16384];
int nRead = is.read(data, 0, data.length);
while (nRead != -1) {
buffer.write(data, 0, nRead);
nRead = is.read(data, 0, data.length);
}
buffer.flush();
return buffer.toByteArray();
}
final static String usage = "Usage: bootstrap main-class JAR-directory JAR-URLs...";
final static int concurrentDownloadCount = 6;
public static void main(String[] args) throws Throwable {
ThreadFactory threadFactory = new ThreadFactory() {
// from scalaz Strategy.DefaultDaemonThreadFactory
ThreadFactory defaultThreadFactory = Executors.defaultThreadFactory();
public Thread newThread(Runnable r) {
Thread t = defaultThreadFactory.newThread(r);
t.setDaemon(true);
return t;
}
};
ExecutorService pool = Executors.newFixedThreadPool(concurrentDownloadCount, threadFactory);
boolean prependClasspath = false;
if (args.length > 0 && args[0].equals("-B"))
prependClasspath = true;
if (args.length < 2 || (prependClasspath && args.length < 3)) {
exit(usage);
}
int offset = 0;
if (prependClasspath)
offset += 1;
String mainClass0 = args[offset];
String jarDir0 = args[offset + 1];
List<String> remainingArgs = new ArrayList<>();
for (int i = offset + 2; i < args.length; i++)
remainingArgs.add(args[i]);
File jarDir = new File(jarDir0);
if (jarDir.exists()) {
if (!jarDir.isDirectory())
exit("Error: " + jarDir0 + " is not a directory");
} else if (!jarDir.mkdirs())
System.err.println("Warning: cannot create " + jarDir0 + ", continuing anyway.");
int splitIdx = remainingArgs.indexOf("--");
List<String> jarStrUrls;
List<String> userArgs;
if (splitIdx < 0) {
jarStrUrls = remainingArgs;
userArgs = new ArrayList<>();
} else {
jarStrUrls = remainingArgs.subList(0, splitIdx);
userArgs = remainingArgs.subList(splitIdx + 1, remainingArgs.size());
}
List<String> errors = new ArrayList<>();
List<URL> urls = new ArrayList<>();
for (String urlStr : jarStrUrls) {
try {
URL url = URI.create(urlStr).toURL();
urls.add(url);
} catch (Exception ex) {
String message = urlStr + ": " + ex.getMessage();
errors.add(message);
}
}
if (!errors.isEmpty()) {
StringBuilder builder = new StringBuilder("Error parsing " + errors.size() + " URL(s):");
for (String error: errors) {
builder.append('\n');
builder.append(error);
}
exit(builder.toString());
}
CompletionService<URL> completionService =
new ExecutorCompletionService<>(pool);
List<URL> localURLs = new ArrayList<>();
for (URL url : urls) {
if (!url.getProtocol().equals("file")) {
completionService.submit(new Callable<URL>() {
@Override
public URL call() throws Exception {
String path = url.getPath();
int idx = path.lastIndexOf('/');
// FIXME Add other components in path to prevent conflicts?
String fileName = path.substring(idx + 1);
File dest = new File(jarDir, fileName);
if (!dest.exists()) {
System.err.println("Downloading " + url);
try {
URLConnection conn = url.openConnection();
long lastModified = conn.getLastModified();
InputStream s = conn.getInputStream();
byte[] b = readFullySync(s);
Files.write(dest.toPath(), b);
dest.setLastModified(lastModified);
} catch (Exception e) {
System.err.println("Error while downloading " + url + ": " + e.getMessage() + ", ignoring it");
throw e;
}
}
return dest.toURI().toURL();
}
});
} else {
localURLs.add(url);
}
}
try {
while (localURLs.size() < urls.size()) {
Future<URL> future = completionService.take();
try {
URL url = future.get();
localURLs.add(url);
} catch (ExecutionException ex) {
// Error message already printed from the Callable above
System.exit(255);
}
}
} catch (InterruptedException ex) {
exit("Interrupted");
}
Thread thread = Thread.currentThread();
ClassLoader parentClassLoader = thread.getContextClassLoader();
URLClassLoader classLoader = new URLClassLoader(localURLs.toArray(new URL[localURLs.size()]), parentClassLoader);
Class<?> mainClass = null;
Method mainMethod = null;
try {
mainClass = classLoader.loadClass(mainClass0);
} catch (ClassNotFoundException ex) {
exit("Error: class " + mainClass0 + " not found");
}
try {
Class params[] = { String[].class };
mainMethod = mainClass.getMethod("main", params);
}
catch (NoSuchMethodException ex) {
exit("Error: main method not found in class " + mainClass0);
}
List<String> userArgs0 = new ArrayList<>();
if (prependClasspath) {
for (URL url : localURLs) {
assert url.getProtocol().equals("file");
userArgs0.add("-B");
userArgs0.add(url.getPath());
}
}
userArgs0.addAll(userArgs);
thread.setContextClassLoader(classLoader);
try {
Object mainArgs[] = { userArgs0.toArray(new String[userArgs0.size()]) };
mainMethod.invoke(null, mainArgs);
}
catch (IllegalAccessException ex) {
exit(ex.getMessage());
}
catch (InvocationTargetException ex) {
throw ex.getCause();
}
finally {
thread.setContextClassLoader(parentClassLoader);
}
}
}

View File

@ -1,155 +0,0 @@
package coursier
import java.io.{ ByteArrayOutputStream, InputStream, File }
import java.net.{ URI, URLClassLoader }
import java.nio.file.Files
import java.util.concurrent.{ Executors, ThreadFactory }
import scala.concurrent.duration.Duration
import scala.concurrent.{ ExecutionContext, Future, Await }
import scala.util.{ Try, Success, Failure }
object Bootstrap extends App {
val concurrentDownloadCount = 6
val threadFactory = new ThreadFactory {
// from scalaz Strategy.DefaultDaemonThreadFactory
val defaultThreadFactory = Executors.defaultThreadFactory()
def newThread(r: Runnable) = {
val t = defaultThreadFactory.newThread(r)
t.setDaemon(true)
t
}
}
val defaultPool = Executors.newFixedThreadPool(concurrentDownloadCount, threadFactory)
implicit val ec = ExecutionContext.fromExecutorService(defaultPool)
private def readFullySync(is: InputStream) = {
val buffer = new ByteArrayOutputStream()
val data = Array.ofDim[Byte](16384)
var nRead = is.read(data, 0, data.length)
while (nRead != -1) {
buffer.write(data, 0, nRead)
nRead = is.read(data, 0, data.length)
}
buffer.flush()
buffer.toByteArray
}
private def errPrintln(s: String): Unit =
Console.err.println(s)
private def exit(msg: String = ""): Nothing = {
if (msg.nonEmpty)
errPrintln(msg)
sys.exit(255)
}
val (prependClasspath, mainClass0, jarDir0, remainingArgs) = args match {
case Array("-B", mainClass0, jarDir0, remainingArgs @ _*) =>
(true, mainClass0, jarDir0, remainingArgs)
case Array(mainClass0, jarDir0, remainingArgs @ _*) =>
(false, mainClass0, jarDir0, remainingArgs)
case _ =>
exit("Usage: bootstrap main-class JAR-directory JAR-URLs...")
}
val jarDir = new File(jarDir0)
if (jarDir.exists()) {
if (!jarDir.isDirectory)
exit(s"Error: $jarDir0 is not a directory")
} else if (!jarDir.mkdirs())
errPrintln(s"Warning: cannot create $jarDir0, continuing anyway.")
val splitIdx = remainingArgs.indexOf("--")
val (jarStrUrls, userArgs) =
if (splitIdx < 0)
(remainingArgs, Nil)
else
(remainingArgs.take(splitIdx), remainingArgs.drop(splitIdx + 1))
val tryUrls = jarStrUrls.map(urlStr => urlStr -> Try(URI.create(urlStr).toURL))
val failedUrls = tryUrls.collect {
case (strUrl, Failure(t)) => strUrl -> t
}
if (failedUrls.nonEmpty)
exit(
s"Error parsing ${failedUrls.length} URL(s):\n" +
failedUrls.map { case (s, t) => s"$s: ${t.getMessage}" }.mkString("\n")
)
val jarUrls = tryUrls.collect {
case (_, Success(url)) => url
}
val jarLocalUrlFutures = jarUrls.map { url =>
if (url.getProtocol == "file")
Future.successful(url)
else
Future {
val path = url.getPath
val idx = path.lastIndexOf('/')
// FIXME Add other components in path to prevent conflicts?
val fileName = path.drop(idx + 1)
val dest = new File(jarDir, fileName)
// FIXME If dest exists, do a HEAD request and check that its size or last modified time is OK?
if (!dest.exists()) {
Console.err.println(s"Downloading $url")
try {
val conn = url.openConnection()
val lastModified = conn.getLastModified
val s = conn.getInputStream
val b = readFullySync(s)
Files.write(dest.toPath, b)
dest.setLastModified(lastModified)
} catch { case e: Exception =>
Console.err.println(s"Error while downloading $url: ${e.getMessage}, ignoring it")
}
}
dest.toURI.toURL
}
}
val jarLocalUrls = Await.result(Future.sequence(jarLocalUrlFutures), Duration.Inf)
val thread = Thread.currentThread()
val parentClassLoader = thread.getContextClassLoader
val classLoader = new URLClassLoader(jarLocalUrls.toArray, parentClassLoader)
val mainClass =
try classLoader.loadClass(mainClass0)
catch { case e: ClassNotFoundException =>
exit(s"Error: class $mainClass0 not found")
}
val mainMethod =
try mainClass.getMethod("main", classOf[Array[String]])
catch { case e: NoSuchMethodException =>
exit(s"Error: main method not found in class $mainClass0")
}
val userArgs0 =
if (prependClasspath)
jarLocalUrls.flatMap { url =>
assert(url.getProtocol == "file")
Seq("-B", url.getPath)
} ++ userArgs
else
userArgs
thread.setContextClassLoader(classLoader)
try mainMethod.invoke(null, userArgs0.toArray)
finally {
thread.setContextClassLoader(parentClassLoader)
}
}

View File

@ -40,15 +40,18 @@ lazy val noPublishSettings = Seq(
publishArtifact := false
)
lazy val commonSettings = Seq(
lazy val baseCommonSettings = Seq(
organization := "com.github.alexarchambault",
scalaVersion := "2.11.7",
crossScalaVersions := Seq("2.10.6", "2.11.7"),
resolvers ++= Seq(
"Scalaz Bintray Repo" at "http://dl.bintray.com/scalaz/releases",
Resolver.sonatypeRepo("releases"),
Resolver.sonatypeRepo("snapshots")
),
)
)
lazy val commonSettings = baseCommonSettings ++ Seq(
scalaVersion := "2.11.7",
crossScalaVersions := Seq("2.10.6", "2.11.7"),
libraryDependencies ++= {
if (scalaVersion.value startsWith "2.10.")
Seq(compilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full))
@ -120,7 +123,7 @@ lazy val cli = project
"com.github.alexarchambault" %% "case-app" % "1.0.0-SNAPSHOT",
"ch.qos.logback" % "logback-classic" % "1.1.3"
),
resourceGenerators in Compile += assembly.in(bootstrap).in(assembly).map { jar =>
resourceGenerators in Compile += packageBin.in(bootstrap).in(Compile).map { jar =>
Seq(jar)
}.taskValue
)
@ -157,14 +160,20 @@ lazy val web = project
)
lazy val bootstrap = project
.settings(commonSettings)
.settings(noPublishSettings)
.settings(baseCommonSettings)
.settings(publishingSettings)
.settings(
name := "coursier-bootstrap",
assemblyJarName in assembly := s"bootstrap.jar",
assemblyShadeRules in assembly := Seq(
ShadeRule.rename("scala.**" -> "shadedscala.@1").inAll
)
artifactName := {
val artifactName0 = artifactName.value
(sv, m, artifact) =>
if (artifact.`type` == "jar" && artifact.extension == "jar")
"bootstrap.jar"
else
artifactName0(sv, m, artifact)
},
crossPaths := false,
autoScalaLibrary := false
)
lazy val `coursier` = project.in(file("."))

BIN
coursier

Binary file not shown.

View File

@ -3,4 +3,3 @@ addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.5")
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0")
addSbtPlugin("com.github.gseitz" % "sbt-release" % "0.8.5")
addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.0")