diff --git a/bootstrap/src/main/java/coursier/Bootstrap.java b/bootstrap/src/main/java/coursier/Bootstrap.java new file mode 100644 index 000000000..8222ea293 --- /dev/null +++ b/bootstrap/src/main/java/coursier/Bootstrap.java @@ -0,0 +1,223 @@ +package coursier; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.URI; +import java.net.URL; +import java.net.URLClassLoader; +import java.net.URLConnection; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.*; + +public class Bootstrap { + + static void exit(String message) { + System.err.println(message); + System.exit(255); + } + + static byte[] readFullySync(InputStream is) throws IOException { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + byte[] data = new byte[16384]; + + int nRead = is.read(data, 0, data.length); + while (nRead != -1) { + buffer.write(data, 0, nRead); + nRead = is.read(data, 0, data.length); + } + + buffer.flush(); + return buffer.toByteArray(); + } + + final static String usage = "Usage: bootstrap main-class JAR-directory JAR-URLs..."; + + final static int concurrentDownloadCount = 6; + + public static void main(String[] args) throws Throwable { + + ThreadFactory threadFactory = new ThreadFactory() { + // from scalaz Strategy.DefaultDaemonThreadFactory + ThreadFactory defaultThreadFactory = Executors.defaultThreadFactory(); + public Thread newThread(Runnable r) { + Thread t = defaultThreadFactory.newThread(r); + t.setDaemon(true); + return t; + } + }; + + ExecutorService pool = Executors.newFixedThreadPool(concurrentDownloadCount, threadFactory); + + boolean prependClasspath = false; + + if (args.length > 0 && args[0].equals("-B")) + prependClasspath = true; + + if (args.length < 2 || (prependClasspath && args.length < 3)) { + exit(usage); + } + + int offset = 0; + if (prependClasspath) + offset += 1; + + String mainClass0 = args[offset]; + String jarDir0 = args[offset + 1]; + + List remainingArgs = new ArrayList<>(); + for (int i = offset + 2; i < args.length; i++) + remainingArgs.add(args[i]); + + File jarDir = new File(jarDir0); + + if (jarDir.exists()) { + if (!jarDir.isDirectory()) + exit("Error: " + jarDir0 + " is not a directory"); + } else if (!jarDir.mkdirs()) + System.err.println("Warning: cannot create " + jarDir0 + ", continuing anyway."); + + int splitIdx = remainingArgs.indexOf("--"); + List jarStrUrls; + List userArgs; + + if (splitIdx < 0) { + jarStrUrls = remainingArgs; + userArgs = new ArrayList<>(); + } else { + jarStrUrls = remainingArgs.subList(0, splitIdx); + userArgs = remainingArgs.subList(splitIdx + 1, remainingArgs.size()); + } + + List errors = new ArrayList<>(); + List urls = new ArrayList<>(); + + for (String urlStr : jarStrUrls) { + try { + URL url = URI.create(urlStr).toURL(); + urls.add(url); + } catch (Exception ex) { + String message = urlStr + ": " + ex.getMessage(); + errors.add(message); + } + } + + if (!errors.isEmpty()) { + StringBuilder builder = new StringBuilder("Error parsing " + errors.size() + " URL(s):"); + for (String error: errors) { + builder.append('\n'); + builder.append(error); + } + exit(builder.toString()); + } + + CompletionService completionService = + new ExecutorCompletionService<>(pool); + + List localURLs = new ArrayList<>(); + + for (URL url : urls) { + if (!url.getProtocol().equals("file")) { + completionService.submit(new Callable() { + @Override + public URL call() throws Exception { + String path = url.getPath(); + int idx = path.lastIndexOf('/'); + // FIXME Add other components in path to prevent conflicts? + String fileName = path.substring(idx + 1); + File dest = new File(jarDir, fileName); + + if (!dest.exists()) { + System.err.println("Downloading " + url); + try { + URLConnection conn = url.openConnection(); + long lastModified = conn.getLastModified(); + InputStream s = conn.getInputStream(); + byte[] b = readFullySync(s); + Files.write(dest.toPath(), b); + dest.setLastModified(lastModified); + } catch (Exception e) { + System.err.println("Error while downloading " + url + ": " + e.getMessage() + ", ignoring it"); + throw e; + } + } + + return dest.toURI().toURL(); + } + }); + } else { + localURLs.add(url); + } + } + + try { + while (localURLs.size() < urls.size()) { + Future future = completionService.take(); + try { + URL url = future.get(); + localURLs.add(url); + } catch (ExecutionException ex) { + // Error message already printed from the Callable above + System.exit(255); + } + } + } catch (InterruptedException ex) { + exit("Interrupted"); + } + + Thread thread = Thread.currentThread(); + ClassLoader parentClassLoader = thread.getContextClassLoader(); + + URLClassLoader classLoader = new URLClassLoader(localURLs.toArray(new URL[localURLs.size()]), parentClassLoader); + + Class mainClass = null; + Method mainMethod = null; + + try { + mainClass = classLoader.loadClass(mainClass0); + } catch (ClassNotFoundException ex) { + exit("Error: class " + mainClass0 + " not found"); + } + + try { + Class params[] = { String[].class }; + mainMethod = mainClass.getMethod("main", params); + } + catch (NoSuchMethodException ex) { + exit("Error: main method not found in class " + mainClass0); + } + + List userArgs0 = new ArrayList<>(); + + if (prependClasspath) { + for (URL url : localURLs) { + assert url.getProtocol().equals("file"); + userArgs0.add("-B"); + userArgs0.add(url.getPath()); + } + } + + userArgs0.addAll(userArgs); + + thread.setContextClassLoader(classLoader); + try { + Object mainArgs[] = { userArgs0.toArray(new String[userArgs0.size()]) }; + mainMethod.invoke(null, mainArgs); + } + catch (IllegalAccessException ex) { + exit(ex.getMessage()); + } + catch (InvocationTargetException ex) { + throw ex.getCause(); + } + finally { + thread.setContextClassLoader(parentClassLoader); + } + } + +} diff --git a/bootstrap/src/main/scala/coursier/Bootstrap.scala b/bootstrap/src/main/scala/coursier/Bootstrap.scala deleted file mode 100644 index a56b248bd..000000000 --- a/bootstrap/src/main/scala/coursier/Bootstrap.scala +++ /dev/null @@ -1,155 +0,0 @@ -package coursier - -import java.io.{ ByteArrayOutputStream, InputStream, File } -import java.net.{ URI, URLClassLoader } -import java.nio.file.Files -import java.util.concurrent.{ Executors, ThreadFactory } - -import scala.concurrent.duration.Duration -import scala.concurrent.{ ExecutionContext, Future, Await } - -import scala.util.{ Try, Success, Failure } - -object Bootstrap extends App { - - val concurrentDownloadCount = 6 - val threadFactory = new ThreadFactory { - // from scalaz Strategy.DefaultDaemonThreadFactory - val defaultThreadFactory = Executors.defaultThreadFactory() - def newThread(r: Runnable) = { - val t = defaultThreadFactory.newThread(r) - t.setDaemon(true) - t - } - } - val defaultPool = Executors.newFixedThreadPool(concurrentDownloadCount, threadFactory) - implicit val ec = ExecutionContext.fromExecutorService(defaultPool) - - private def readFullySync(is: InputStream) = { - val buffer = new ByteArrayOutputStream() - val data = Array.ofDim[Byte](16384) - - var nRead = is.read(data, 0, data.length) - while (nRead != -1) { - buffer.write(data, 0, nRead) - nRead = is.read(data, 0, data.length) - } - - buffer.flush() - buffer.toByteArray - } - - private def errPrintln(s: String): Unit = - Console.err.println(s) - - private def exit(msg: String = ""): Nothing = { - if (msg.nonEmpty) - errPrintln(msg) - sys.exit(255) - } - - val (prependClasspath, mainClass0, jarDir0, remainingArgs) = args match { - case Array("-B", mainClass0, jarDir0, remainingArgs @ _*) => - (true, mainClass0, jarDir0, remainingArgs) - case Array(mainClass0, jarDir0, remainingArgs @ _*) => - (false, mainClass0, jarDir0, remainingArgs) - case _ => - exit("Usage: bootstrap main-class JAR-directory JAR-URLs...") - } - - val jarDir = new File(jarDir0) - - if (jarDir.exists()) { - if (!jarDir.isDirectory) - exit(s"Error: $jarDir0 is not a directory") - } else if (!jarDir.mkdirs()) - errPrintln(s"Warning: cannot create $jarDir0, continuing anyway.") - - val splitIdx = remainingArgs.indexOf("--") - val (jarStrUrls, userArgs) = - if (splitIdx < 0) - (remainingArgs, Nil) - else - (remainingArgs.take(splitIdx), remainingArgs.drop(splitIdx + 1)) - - val tryUrls = jarStrUrls.map(urlStr => urlStr -> Try(URI.create(urlStr).toURL)) - - val failedUrls = tryUrls.collect { - case (strUrl, Failure(t)) => strUrl -> t - } - if (failedUrls.nonEmpty) - exit( - s"Error parsing ${failedUrls.length} URL(s):\n" + - failedUrls.map { case (s, t) => s"$s: ${t.getMessage}" }.mkString("\n") - ) - - val jarUrls = tryUrls.collect { - case (_, Success(url)) => url - } - - val jarLocalUrlFutures = jarUrls.map { url => - if (url.getProtocol == "file") - Future.successful(url) - else - Future { - val path = url.getPath - val idx = path.lastIndexOf('/') - // FIXME Add other components in path to prevent conflicts? - val fileName = path.drop(idx + 1) - val dest = new File(jarDir, fileName) - - // FIXME If dest exists, do a HEAD request and check that its size or last modified time is OK? - - if (!dest.exists()) { - Console.err.println(s"Downloading $url") - try { - val conn = url.openConnection() - val lastModified = conn.getLastModified - val s = conn.getInputStream - val b = readFullySync(s) - Files.write(dest.toPath, b) - dest.setLastModified(lastModified) - } catch { case e: Exception => - Console.err.println(s"Error while downloading $url: ${e.getMessage}, ignoring it") - } - } - - dest.toURI.toURL - } - } - - val jarLocalUrls = Await.result(Future.sequence(jarLocalUrlFutures), Duration.Inf) - - val thread = Thread.currentThread() - val parentClassLoader = thread.getContextClassLoader - - val classLoader = new URLClassLoader(jarLocalUrls.toArray, parentClassLoader) - - val mainClass = - try classLoader.loadClass(mainClass0) - catch { case e: ClassNotFoundException => - exit(s"Error: class $mainClass0 not found") - } - - val mainMethod = - try mainClass.getMethod("main", classOf[Array[String]]) - catch { case e: NoSuchMethodException => - exit(s"Error: main method not found in class $mainClass0") - } - - val userArgs0 = - if (prependClasspath) - jarLocalUrls.flatMap { url => - assert(url.getProtocol == "file") - Seq("-B", url.getPath) - } ++ userArgs - else - userArgs - - thread.setContextClassLoader(classLoader) - try mainMethod.invoke(null, userArgs0.toArray) - finally { - thread.setContextClassLoader(parentClassLoader) - } - -} \ No newline at end of file diff --git a/build.sbt b/build.sbt index 1e1b535ea..fe825db36 100644 --- a/build.sbt +++ b/build.sbt @@ -40,15 +40,18 @@ lazy val noPublishSettings = Seq( publishArtifact := false ) -lazy val commonSettings = Seq( +lazy val baseCommonSettings = Seq( organization := "com.github.alexarchambault", - scalaVersion := "2.11.7", - crossScalaVersions := Seq("2.10.6", "2.11.7"), resolvers ++= Seq( "Scalaz Bintray Repo" at "http://dl.bintray.com/scalaz/releases", Resolver.sonatypeRepo("releases"), Resolver.sonatypeRepo("snapshots") - ), + ) +) + +lazy val commonSettings = baseCommonSettings ++ Seq( + scalaVersion := "2.11.7", + crossScalaVersions := Seq("2.10.6", "2.11.7"), libraryDependencies ++= { if (scalaVersion.value startsWith "2.10.") Seq(compilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)) @@ -120,7 +123,7 @@ lazy val cli = project "com.github.alexarchambault" %% "case-app" % "1.0.0-SNAPSHOT", "ch.qos.logback" % "logback-classic" % "1.1.3" ), - resourceGenerators in Compile += assembly.in(bootstrap).in(assembly).map { jar => + resourceGenerators in Compile += packageBin.in(bootstrap).in(Compile).map { jar => Seq(jar) }.taskValue ) @@ -157,14 +160,20 @@ lazy val web = project ) lazy val bootstrap = project - .settings(commonSettings) - .settings(noPublishSettings) + .settings(baseCommonSettings) + .settings(publishingSettings) .settings( name := "coursier-bootstrap", - assemblyJarName in assembly := s"bootstrap.jar", - assemblyShadeRules in assembly := Seq( - ShadeRule.rename("scala.**" -> "shadedscala.@1").inAll - ) + artifactName := { + val artifactName0 = artifactName.value + (sv, m, artifact) => + if (artifact.`type` == "jar" && artifact.extension == "jar") + "bootstrap.jar" + else + artifactName0(sv, m, artifact) + }, + crossPaths := false, + autoScalaLibrary := false ) lazy val `coursier` = project.in(file(".")) diff --git a/coursier b/coursier index 278ca87eb..c73f9338e 100755 Binary files a/coursier and b/coursier differ diff --git a/project/plugins.sbt b/project/plugins.sbt index 591cf84ea..d26d5c6bc 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,4 +3,3 @@ addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.5") addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.0.0") addSbtPlugin("com.github.gseitz" % "sbt-release" % "0.8.5") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0") -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.0")