Fix EvaluateTask memory leak

EvaluateTask was holding references to SafeState that could be quite
large. This was reported as #5992. In that project, I ran the `ci` task
and observed the OOM as reported. I took a heap dump prior to OOM and
got the retained size graph from visualvm (which took hours to compute).
The lastEvaluatedState was holding a reference to SafeState that was
1.7GB. The project max heap size was set to 2GB. Instead of using the
lastEvaluatedState, we can just use StandardMain.exchange.withState.
The cached instances of state were used for task cancellation and
completions. While it is possible that early on in booting
StandardMain.exchange.withState could return a null state, in practice
this won't happen because it is set early on during the sbt boot
commands.

After this change, I successfully ran the `ci` task in the #5992 issue
project with the same memory parameters as their ci config.
This commit is contained in:
Ethan Atkins 2020-10-20 18:15:47 -07:00
parent d1ff067d54
commit d9acfed220
3 changed files with 58 additions and 58 deletions

View File

@ -72,12 +72,15 @@ final case class State(
* @param currentExecId provide the execId extracted from the original State. * @param currentExecId provide the execId extracted from the original State.
* @param combinedParser the parser extracted from the original State. * @param combinedParser the parser extracted from the original State.
*/ */
@deprecated("unused", "1.4.2")
private[sbt] final case class SafeState( private[sbt] final case class SafeState(
currentExecId: Option[String], currentExecId: Option[String],
combinedParser: Parser[() => sbt.State] combinedParser: Parser[() => sbt.State]
) )
@deprecated("unused", "1.4.2")
private[sbt] object SafeState { private[sbt] object SafeState {
@deprecated("use StandardMain.exchange.withState", "1.4.2")
def apply(s: State) = { def apply(s: State) = {
new SafeState( new SafeState(
currentExecId = s.currentCommand.map(_.execId).flatten, currentExecId = s.currentCommand.map(_.execId).flatten,

View File

@ -413,9 +413,13 @@ object EvaluateTask {
(dummyRoots, roots) :: (Def.dummyStreamsManager, streams) :: (dummyState, state) :: dummies (dummyRoots, roots) :: (Def.dummyStreamsManager, streams) :: (dummyState, state) :: dummies
) )
@deprecated("use StandardMain.exchange.withState to obtain an instance of State", "1.4.2")
val lastEvaluatedState: AtomicReference[SafeState] = new AtomicReference() val lastEvaluatedState: AtomicReference[SafeState] = new AtomicReference()
@deprecated("use currentlyRunningTaskEngine", "1.4.2")
val currentlyRunningEngine: AtomicReference[(SafeState, RunningTaskEngine)] = val currentlyRunningEngine: AtomicReference[(SafeState, RunningTaskEngine)] =
new AtomicReference() new AtomicReference()
private[sbt] val currentlyRunningTaskEngine: AtomicReference[RunningTaskEngine] =
new AtomicReference()
/** /**
* The main method for the task engine. * The main method for the task engine.
@ -486,7 +490,7 @@ object EvaluateTask {
shutdownImpl(true) shutdownImpl(true)
} }
} }
currentlyRunningEngine.set((SafeState(state), runningEngine)) currentlyRunningTaskEngine.set(runningEngine)
// Register with our cancel handler we're about to start. // Register with our cancel handler we're about to start.
val strat = config.cancelStrategy val strat = config.cancelStrategy
val cancelState = strat.onTaskEngineStart(runningEngine) val cancelState = strat.onTaskEngineStart(runningEngine)
@ -494,8 +498,7 @@ object EvaluateTask {
try run() try run()
finally { finally {
strat.onTaskEngineFinish(cancelState) strat.onTaskEngineFinish(cancelState)
currentlyRunningEngine.set(null) currentlyRunningTaskEngine.set(null)
lastEvaluatedState.set(SafeState(state))
} }
} }

View File

@ -422,58 +422,50 @@ final class NetworkChannel(
protected def onCompletionRequest(execId: Option[String], cp: CompletionParams) = { protected def onCompletionRequest(execId: Option[String], cp: CompletionParams) = {
if (initialized) { if (initialized) {
try { try {
Option(EvaluateTask.lastEvaluatedState.get) match { StandardMain.exchange.withState { sstate =>
case Some(sstate) => import sbt.protocol.codec.JsonProtocol._
import sbt.protocol.codec.JsonProtocol._ def completionItems(s: State) = {
def completionItems(s: State) = { Parser
Parser .completions(s.combinedParser, cp.query, cp.level.getOrElse(9))
.completions(s.combinedParser, cp.query, cp.level.getOrElse(9)) .get
.get .flatMap { c =>
.flatMap { c => if (!c.isEmpty) Some(c.append.replaceAll("\n", " "))
if (!c.isEmpty) Some(c.append.replaceAll("\n", " ")) else None
else None }
} .map(c => cp.query + c)
.map(c => cp.query + c) }
} val (items, cachedMainClassNames, cachedTestNames) = {
val (items, cachedMainClassNames, cachedTestNames) = StandardMain.exchange.withState { val scopedKeyParser: Parser[Seq[Def.ScopedKey[_]]] =
s => Act.aggregatedKeyParser(sstate) <~ Parsers.any.*
val scopedKeyParser: Parser[Seq[Def.ScopedKey[_]]] = Parser.parse(cp.query, scopedKeyParser) match {
Act.aggregatedKeyParser(s) <~ Parsers.any.* case Right(keys) =>
Parser.parse(cp.query, scopedKeyParser) match { val testKeys =
case Right(keys) => keys.filter(k => k.key.label == "testOnly" || k.key.label == "testQuick")
val testKeys = val (testState, cachedTestNames) = testKeys.foldLeft((sstate, true)) {
keys.filter(k => k.key.label == "testOnly" || k.key.label == "testQuick") case ((st, allCached), k) =>
val (testState, cachedTestNames) = testKeys.foldLeft((s, true)) { SessionVar.loadAndSet(sbt.Keys.definedTestNames in k.scope, st, true) match {
case ((st, allCached), k) => case (nst, d) => (nst, allCached && d.isDefined)
SessionVar.loadAndSet(sbt.Keys.definedTestNames in k.scope, st, true) match {
case (nst, d) => (nst, allCached && d.isDefined)
}
} }
val runKeys = keys.filter(_.key.label == "runMain")
val (runState, cachedMainClassNames) = runKeys.foldLeft((testState, true)) {
case ((st, allCached), k) =>
SessionVar.loadAndSet(sbt.Keys.discoveredMainClasses in k.scope, st, true) match {
case (nst, d) => (nst, allCached && d.isDefined)
}
}
(completionItems(runState), cachedMainClassNames, cachedTestNames)
case _ => (completionItems(s), true, true)
} }
val runKeys = keys.filter(_.key.label == "runMain")
val (runState, cachedMainClassNames) = runKeys.foldLeft((testState, true)) {
case ((st, allCached), k) =>
SessionVar.loadAndSet(sbt.Keys.discoveredMainClasses in k.scope, st, true) match {
case (nst, d) => (nst, allCached && d.isDefined)
}
}
(completionItems(runState), cachedMainClassNames, cachedTestNames)
case _ => (completionItems(sstate), true, true)
} }
respondResult( }
CompletionResponse( respondResult(
items = items.toVector, CompletionResponse(
cachedMainClassNames = cachedMainClassNames, items = items.toVector,
cachedTestNames = cachedTestNames cachedMainClassNames = cachedMainClassNames,
), cachedTestNames = cachedTestNames
execId ),
) execId
case _ => )
respondError(
ErrorCodes.UnknownError,
"No available sbt state",
execId
)
} }
} catch { } catch {
case NonFatal(_) => case NonFatal(_) =>
@ -498,9 +490,10 @@ final class NetworkChannel(
) )
try { try {
Option(EvaluateTask.currentlyRunningEngine.get) match { Option(EvaluateTask.currentlyRunningTaskEngine.get) match {
case Some((state, runningEngine)) => case Some(runningEngine) =>
val runningExecId = state.currentExecId.getOrElse("") val runningExecId =
StandardMain.exchange.withState(_.currentCommand.flatMap(_.execId).getOrElse(""))
val expected = StandardMain.exchange.withState( val expected = StandardMain.exchange.withState(
_.get(BasicCommands.execMap) _.get(BasicCommands.execMap)
.flatMap(s => s.get(crp.id) orElse s.get("\u2668" + crp.id)) .flatMap(s => s.get(crp.id) orElse s.get("\u2668" + crp.id))
@ -936,9 +929,10 @@ object NetworkChannel {
id: String id: String
): Either[String, String] = { ): Either[String, String] = {
Option(EvaluateTask.currentlyRunningEngine.get) match { Option(EvaluateTask.currentlyRunningTaskEngine.get) match {
case Some((state, runningEngine)) => case Some(runningEngine) =>
val runningExecId = state.currentExecId.getOrElse("") val runningExecId =
StandardMain.exchange.withState(_.currentCommand.flatMap(_.execId).getOrElse(""))
def checkId(): Boolean = { def checkId(): Boolean = {
if (runningExecId.startsWith("\u2668")) { if (runningExecId.startsWith("\u2668")) {