From d9acfed2205145d972d05a8f7c74ed94b68b7248 Mon Sep 17 00:00:00 2001 From: Ethan Atkins Date: Tue, 20 Oct 2020 18:15:47 -0700 Subject: [PATCH] Fix EvaluateTask memory leak EvaluateTask was holding references to SafeState that could be quite large. This was reported as #5992. In that project, I ran the `ci` task and observed the OOM as reported. I took a heap dump prior to OOM and got the retained size graph from visualvm (which took hours to compute). The lastEvaluatedState was holding a reference to SafeState that was 1.7GB. The project max heap size was set to 2GB. Instead of using the lastEvaluatedState, we can just use StandardMain.exchange.withState. The cached instances of state were used for task cancellation and completions. While it is possible that early on in booting StandardMain.exchange.withState could return a null state, in practice this won't happen because it is set early on during the sbt boot commands. After this change, I successfully ran the `ci` task in the #5992 issue project with the same memory parameters as their ci config. --- main-command/src/main/scala/sbt/State.scala | 3 + main/src/main/scala/sbt/EvaluateTask.scala | 9 +- .../sbt/internal/server/NetworkChannel.scala | 104 +++++++++--------- 3 files changed, 58 insertions(+), 58 deletions(-) diff --git a/main-command/src/main/scala/sbt/State.scala b/main-command/src/main/scala/sbt/State.scala index d7db6a51f..00a2d946f 100644 --- a/main-command/src/main/scala/sbt/State.scala +++ b/main-command/src/main/scala/sbt/State.scala @@ -72,12 +72,15 @@ final case class State( * @param currentExecId provide the execId extracted from the original State. * @param combinedParser the parser extracted from the original State. */ +@deprecated("unused", "1.4.2") private[sbt] final case class SafeState( currentExecId: Option[String], combinedParser: Parser[() => sbt.State] ) +@deprecated("unused", "1.4.2") private[sbt] object SafeState { + @deprecated("use StandardMain.exchange.withState", "1.4.2") def apply(s: State) = { new SafeState( currentExecId = s.currentCommand.map(_.execId).flatten, diff --git a/main/src/main/scala/sbt/EvaluateTask.scala b/main/src/main/scala/sbt/EvaluateTask.scala index 4ff043af3..5cbff51db 100644 --- a/main/src/main/scala/sbt/EvaluateTask.scala +++ b/main/src/main/scala/sbt/EvaluateTask.scala @@ -413,9 +413,13 @@ object EvaluateTask { (dummyRoots, roots) :: (Def.dummyStreamsManager, streams) :: (dummyState, state) :: dummies ) + @deprecated("use StandardMain.exchange.withState to obtain an instance of State", "1.4.2") val lastEvaluatedState: AtomicReference[SafeState] = new AtomicReference() + @deprecated("use currentlyRunningTaskEngine", "1.4.2") val currentlyRunningEngine: AtomicReference[(SafeState, RunningTaskEngine)] = new AtomicReference() + private[sbt] val currentlyRunningTaskEngine: AtomicReference[RunningTaskEngine] = + new AtomicReference() /** * The main method for the task engine. @@ -486,7 +490,7 @@ object EvaluateTask { shutdownImpl(true) } } - currentlyRunningEngine.set((SafeState(state), runningEngine)) + currentlyRunningTaskEngine.set(runningEngine) // Register with our cancel handler we're about to start. val strat = config.cancelStrategy val cancelState = strat.onTaskEngineStart(runningEngine) @@ -494,8 +498,7 @@ object EvaluateTask { try run() finally { strat.onTaskEngineFinish(cancelState) - currentlyRunningEngine.set(null) - lastEvaluatedState.set(SafeState(state)) + currentlyRunningTaskEngine.set(null) } } diff --git a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala index 77891846a..2e4cec4c0 100644 --- a/main/src/main/scala/sbt/internal/server/NetworkChannel.scala +++ b/main/src/main/scala/sbt/internal/server/NetworkChannel.scala @@ -422,58 +422,50 @@ final class NetworkChannel( protected def onCompletionRequest(execId: Option[String], cp: CompletionParams) = { if (initialized) { try { - Option(EvaluateTask.lastEvaluatedState.get) match { - case Some(sstate) => - import sbt.protocol.codec.JsonProtocol._ - def completionItems(s: State) = { - Parser - .completions(s.combinedParser, cp.query, cp.level.getOrElse(9)) - .get - .flatMap { c => - if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) - else None - } - .map(c => cp.query + c) - } - val (items, cachedMainClassNames, cachedTestNames) = StandardMain.exchange.withState { - s => - val scopedKeyParser: Parser[Seq[Def.ScopedKey[_]]] = - Act.aggregatedKeyParser(s) <~ Parsers.any.* - Parser.parse(cp.query, scopedKeyParser) match { - case Right(keys) => - val testKeys = - keys.filter(k => k.key.label == "testOnly" || k.key.label == "testQuick") - val (testState, cachedTestNames) = testKeys.foldLeft((s, true)) { - case ((st, allCached), k) => - SessionVar.loadAndSet(sbt.Keys.definedTestNames in k.scope, st, true) match { - case (nst, d) => (nst, allCached && d.isDefined) - } + StandardMain.exchange.withState { sstate => + import sbt.protocol.codec.JsonProtocol._ + def completionItems(s: State) = { + Parser + .completions(s.combinedParser, cp.query, cp.level.getOrElse(9)) + .get + .flatMap { c => + if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) + else None + } + .map(c => cp.query + c) + } + val (items, cachedMainClassNames, cachedTestNames) = { + val scopedKeyParser: Parser[Seq[Def.ScopedKey[_]]] = + Act.aggregatedKeyParser(sstate) <~ Parsers.any.* + Parser.parse(cp.query, scopedKeyParser) match { + case Right(keys) => + val testKeys = + keys.filter(k => k.key.label == "testOnly" || k.key.label == "testQuick") + val (testState, cachedTestNames) = testKeys.foldLeft((sstate, true)) { + case ((st, allCached), k) => + SessionVar.loadAndSet(sbt.Keys.definedTestNames in k.scope, st, true) match { + case (nst, d) => (nst, allCached && d.isDefined) } - val runKeys = keys.filter(_.key.label == "runMain") - val (runState, cachedMainClassNames) = runKeys.foldLeft((testState, true)) { - case ((st, allCached), k) => - SessionVar.loadAndSet(sbt.Keys.discoveredMainClasses in k.scope, st, true) match { - case (nst, d) => (nst, allCached && d.isDefined) - } - } - (completionItems(runState), cachedMainClassNames, cachedTestNames) - case _ => (completionItems(s), true, true) } + val runKeys = keys.filter(_.key.label == "runMain") + val (runState, cachedMainClassNames) = runKeys.foldLeft((testState, true)) { + case ((st, allCached), k) => + SessionVar.loadAndSet(sbt.Keys.discoveredMainClasses in k.scope, st, true) match { + case (nst, d) => (nst, allCached && d.isDefined) + } + } + (completionItems(runState), cachedMainClassNames, cachedTestNames) + case _ => (completionItems(sstate), true, true) } - respondResult( - CompletionResponse( - items = items.toVector, - cachedMainClassNames = cachedMainClassNames, - cachedTestNames = cachedTestNames - ), - execId - ) - case _ => - respondError( - ErrorCodes.UnknownError, - "No available sbt state", - execId - ) + } + respondResult( + CompletionResponse( + items = items.toVector, + cachedMainClassNames = cachedMainClassNames, + cachedTestNames = cachedTestNames + ), + execId + ) } } catch { case NonFatal(_) => @@ -498,9 +490,10 @@ final class NetworkChannel( ) try { - Option(EvaluateTask.currentlyRunningEngine.get) match { - case Some((state, runningEngine)) => - val runningExecId = state.currentExecId.getOrElse("") + Option(EvaluateTask.currentlyRunningTaskEngine.get) match { + case Some(runningEngine) => + val runningExecId = + StandardMain.exchange.withState(_.currentCommand.flatMap(_.execId).getOrElse("")) val expected = StandardMain.exchange.withState( _.get(BasicCommands.execMap) .flatMap(s => s.get(crp.id) orElse s.get("\u2668" + crp.id)) @@ -936,9 +929,10 @@ object NetworkChannel { id: String ): Either[String, String] = { - Option(EvaluateTask.currentlyRunningEngine.get) match { - case Some((state, runningEngine)) => - val runningExecId = state.currentExecId.getOrElse("") + Option(EvaluateTask.currentlyRunningTaskEngine.get) match { + case Some(runningEngine) => + val runningExecId = + StandardMain.exchange.withState(_.currentCommand.flatMap(_.execId).getOrElse("")) def checkId(): Boolean = { if (runningExecId.startsWith("\u2668")) {