diff --git a/main-command/src/main/scala/xsbt/IPC.scala b/main-command/src/main/scala/xsbt/IPC.scala index 2b9750438..c964bca7b 100644 --- a/main-command/src/main/scala/xsbt/IPC.scala +++ b/main-command/src/main/scala/xsbt/IPC.scala @@ -10,47 +10,57 @@ package xsbt import java.io.{ BufferedReader, BufferedWriter, InputStreamReader, OutputStreamWriter } import java.net.{ InetAddress, ServerSocket, Socket } +import scala.annotation.tailrec import scala.util.control.NonFatal object IPC { private val portMin = 1025 private val portMax = 65536 - private val loopback = InetAddress.getByName(null) // loopback + private val loopback = InetAddress.getByName(null) - def client[T](port: Int)(f: IPC => T): T = - ipc(new Socket(loopback, port))(f) + def client[T](port: Int)(f: IPC => T): T = ipc(new Socket(loopback, port))(f) def pullServer[T](f: Server => T): T = { val server = makeServer - try { f(new Server(server)) } finally { server.close() } + try f(new Server(server)) + finally server.close() } + def unmanagedServer: Server = new Server(makeServer) + def makeServer: ServerSocket = { val random = new java.util.Random def nextPort = random.nextInt(portMax - portMin + 1) + portMin + def createServer(attempts: Int): ServerSocket = - if (attempts > 0) - try { new ServerSocket(nextPort, 1, loopback) } catch { - case NonFatal(_) => createServer(attempts - 1) - } else - sys.error("Could not connect to socket: maximum attempts exceeded") + if (attempts > 0) { + try new ServerSocket(nextPort, 1, loopback) + catch { case NonFatal(_) => createServer(attempts - 1) } + } else sys.error("Could not connect to socket: maximum attempts exceeded") + createServer(10) } + def server[T](f: IPC => Option[T]): T = serverImpl(makeServer, f) + def server[T](port: Int)(f: IPC => Option[T]): T = serverImpl(new ServerSocket(port, 1, loopback), f) + private def serverImpl[T](server: ServerSocket, f: IPC => Option[T]): T = { - def listen(): T = { + @tailrec def listen(): T = { ipc(server.accept())(f) match { case Some(done) => done case None => listen() } } - try { listen() } finally { server.close() } + try listen() + finally server.close() } + private def ipc[T](s: Socket)(f: IPC => T): T = - try { f(new IPC(s)) } finally { s.close() } + try f(new IPC(s)) + finally s.close() final class Server private[IPC] (s: ServerSocket) { def port = s.getLocalPort @@ -59,6 +69,7 @@ object IPC { def connection[T](f: IPC => T): T = IPC.ipc(s.accept())(f) } } + final class IPC private (s: Socket) { def port = s.getLocalPort private val in = new BufferedReader(new InputStreamReader(s.getInputStream))