diff --git a/tasks/standard/TaskExtra.scala b/tasks/standard/TaskExtra.scala index 79f40e301..d53527fe7 100644 --- a/tasks/standard/TaskExtra.scala +++ b/tasks/standard/TaskExtra.scala @@ -185,7 +185,7 @@ object TaskExtra extends TaskExtra def processIO(s: TaskStreams[_]): ProcessIO = { def transfer(id: String) = (in: InputStream) => BasicIO.transferFully(in, s.binary(id)) - new ProcessIO(BasicIO.closeOut, transfer(s.outID), transfer(s.errorID)) + new ProcessIO(BasicIO.closeOut, transfer(s.outID), transfer(s.errorID), inheritInput = {_ => None}) } def reduced[S](i: IndexedSeq[Task[S]], f: (S, S) => S): Task[S] = i match @@ -225,4 +225,4 @@ object TaskExtra extends TaskExtra def failures[A](results: Seq[Result[A]]): Seq[Incomplete] = results.collect { case Inc(i) => i } def incompleteDeps(incs: Seq[Incomplete]): Incomplete = Incomplete(None, causes = incs) -} \ No newline at end of file +} diff --git a/util/process/InheritInput.scala b/util/process/InheritInput.scala index bd45ebe62..5cfe30b79 100755 --- a/util/process/InheritInput.scala +++ b/util/process/InheritInput.scala @@ -7,9 +7,9 @@ import java.lang.{ProcessBuilder => JProcessBuilder} /** On java 7, inherit System.in for a ProcessBuilder. */ private[sbt] object InheritInput { - def apply(p: JProcessBuilder): (Boolean, JProcessBuilder) = (redirectInput, inherit) match { - case (Some(m), Some(f)) => (true, m.invoke(p, f).asInstanceOf[JProcessBuilder]) - case _ => (false, p) + def apply(p: JProcessBuilder): Option[JProcessBuilder] = (redirectInput, inherit) match { + case (Some(m), Some(f)) => Some(m.invoke(p, f).asInstanceOf[JProcessBuilder]) + case _ => None } private[this] val pbClass = Class.forName("java.lang.ProcessBuilder") diff --git a/util/process/Process.scala b/util/process/Process.scala index 5a1f46b4a..bd7e5cbc6 100644 --- a/util/process/Process.scala +++ b/util/process/Process.scala @@ -180,15 +180,15 @@ trait ProcessBuilder extends SourcePartialBuilder with SinkPartialBuilder def canPipeTo: Boolean } /** Each method will be called in a separate thread.*/ -final class ProcessIO(val writeInput: OutputStream => Unit, val processOutput: InputStream => Unit, val processError: InputStream => Unit) extends NotNull +final class ProcessIO(val writeInput: OutputStream => Unit, val processOutput: InputStream => Unit, val processError: InputStream => Unit, val inheritInput: JProcessBuilder => Option[JProcessBuilder]) extends NotNull { - def withOutput(process: InputStream => Unit): ProcessIO = new ProcessIO(writeInput, process, processError) - def withError(process: InputStream => Unit): ProcessIO = new ProcessIO(writeInput, processOutput, process) - def withInput(write: OutputStream => Unit): ProcessIO = new ProcessIO(write, processOutput, processError) + def withOutput(process: InputStream => Unit): ProcessIO = new ProcessIO(writeInput, process, processError, inheritInput) + def withError(process: InputStream => Unit): ProcessIO = new ProcessIO(writeInput, processOutput, process, inheritInput) + def withInput(write: OutputStream => Unit): ProcessIO = new ProcessIO(write, processOutput, processError, inheritInput) } trait ProcessLogger { def info(s: => String): Unit def error(s: => String): Unit def buffer[T](f: => T): T -} \ No newline at end of file +} diff --git a/util/process/ProcessImpl.scala b/util/process/ProcessImpl.scala index cea74272f..e58122efa 100644 --- a/util/process/ProcessImpl.scala +++ b/util/process/ProcessImpl.scala @@ -43,8 +43,8 @@ private object Future object BasicIO { - def apply(buffer: StringBuffer, log: Option[ProcessLogger], withIn: Boolean) = new ProcessIO(input(withIn), processFully(buffer), getErr(log)) - def apply(log: ProcessLogger, withIn: Boolean) = new ProcessIO(input(withIn), processInfoFully(log), processErrFully(log)) + def apply(buffer: StringBuffer, log: Option[ProcessLogger], withIn: Boolean) = new ProcessIO(input(withIn), processFully(buffer), getErr(log), inheritInput(withIn)) + def apply(log: ProcessLogger, withIn: Boolean) = new ProcessIO(input(withIn), processInfoFully(log), processErrFully(log), inheritInput(withIn)) def getErr(log: Option[ProcessLogger]) = log match { case Some(lg) => processErrFully(lg); case None => toStdErr } @@ -78,9 +78,9 @@ object BasicIO readFully() } def connectToIn(o: OutputStream) { transferFully(Uncloseable protect System.in, o) } - def input(connect: Boolean): OutputStream => Unit = if(connect) connectToIn else closeOut - def standard(connectInput: Boolean): ProcessIO = standard(input(connectInput)) - def standard(in: OutputStream => Unit): ProcessIO = new ProcessIO(in, toStdOut, toStdErr) + def input(connect: Boolean): OutputStream => Unit = if(connect) connectToIn else closeOut + def standard(connectInput: Boolean): ProcessIO = standard(input(connectInput), inheritInput(connectInput)) + def standard(in: OutputStream => Unit, inheritIn: JProcessBuilder => Option[JProcessBuilder]): ProcessIO = new ProcessIO(in, toStdOut, toStdErr, inheritIn) def toStdErr = (in: InputStream) => transferFully(in, System.err) def toStdOut = (in: InputStream) => transferFully(in, System.out) @@ -113,6 +113,8 @@ object BasicIO read in.close() } + + def inheritInput(connect: Boolean) = { p: JProcessBuilder => if (connect) InheritInput(p) else None } } @@ -154,7 +156,7 @@ private abstract class AbstractProcessBuilder extends ProcessBuilder with SinkPa private[this] def lines(withInput: Boolean, nonZeroException: Boolean, log: Option[ProcessLogger]): Stream[String] = { val streamed = Streamed[String](nonZeroException) - val process = run(new ProcessIO(BasicIO.input(withInput), BasicIO.processFully(streamed.process), BasicIO.getErr(log))) + val process = run(new ProcessIO(BasicIO.input(withInput), BasicIO.processFully(streamed.process), BasicIO.getErr(log), BasicIO.inheritInput(withInput))) Spawn { streamed.done(process.exitValue()) } streamed.stream() } @@ -379,13 +381,14 @@ private[sbt] class SimpleProcessBuilder(p: JProcessBuilder) extends AbstractProc { override def run(io: ProcessIO): Process = { - val (inherited, pp) = InheritInput(p) - val process = pp.start() // start the external process - import io.{writeInput, processOutput, processError} - // spawn threads that process the input, output, and error streams using the functions defined in `io` - if(!inherited) - Spawn(writeInput(process.getOutputStream), true) + import io._ + val process = inheritInput(p) map (_.start()) getOrElse { + val proc = p.start() + Spawn(writeInput(proc.getOutputStream)) + proc + } + // spawn threads that process the output and error streams. val outThread = Spawn(processOutput(process.getInputStream)) val errorThread = if(!p.redirectErrorStream) @@ -408,7 +411,13 @@ private class SimpleProcess(p: JProcess, outputThreads: List[Thread]) extends Pr override def exitValue() = { def waitDone(): Unit = - try { p.waitFor() } catch { case _: InterruptedException => waitDone() } + try { + p.waitFor() + } catch { + case _: InterruptedException => + // Guard against possible spurious wakeups, check thread interrupted status. + if(Thread.interrupted()) p.destroy() else waitDone() + } waitDone() outputThreads.foreach(_.join()) // this ensures that all output is complete before returning (waitFor does not ensure this) p.exitValue()