From f209d0093c257dfe2d1de92e6a25bf3682977b47 Mon Sep 17 00:00:00 2001 From: Igal Tabachnik Date: Sat, 28 Aug 2021 17:49:19 +0300 Subject: [PATCH] BSP tasks report progress during compilation --- main/src/main/scala/sbt/Defaults.scala | 49 +++++-- .../sbt/internal/server/BspCompileTask.scala | 124 ++++++++++-------- .../sbt/internal/bsp/TaskProgressParams.scala | 92 +++++++++++++ .../sbt/internal/bsp/codec/JsonProtocol.scala | 1 + .../bsp/codec/TaskProgressParamsFormats.scala | 41 ++++++ protocol/src/main/contraband/bsp.contra | 26 ++++ 6 files changed, 272 insertions(+), 61 deletions(-) create mode 100644 protocol/src/main/contraband-scala/sbt/internal/bsp/TaskProgressParams.scala create mode 100644 protocol/src/main/contraband-scala/sbt/internal/bsp/codec/TaskProgressParamsFormats.scala diff --git a/main/src/main/scala/sbt/Defaults.scala b/main/src/main/scala/sbt/Defaults.scala index 236b28a13..721a08b5d 100644 --- a/main/src/main/scala/sbt/Defaults.scala +++ b/main/src/main/scala/sbt/Defaults.scala @@ -2305,14 +2305,14 @@ object Defaults extends BuildCommon { analysis } def compileIncrementalTask = Def.task { + val s = streams.value + val ci = (compile / compileInputs).value + val ping = earlyOutputPing.value + val reporter = (compile / bspReporter).value BspCompileTask.compute(bspTargetIdentifier.value, thisProjectRef.value, configuration.value) { - // TODO - Should readAnalysis + saveAnalysis be scoped by the compile task too? - compileIncrementalTaskImpl( - streams.value, - (compile / compileInputs).value, - earlyOutputPing.value, - (compile / bspReporter).value - ) + task => + // TODO - Should readAnalysis + saveAnalysis be scoped by the compile task too? + compileIncrementalTaskImpl(task, s, ci, ping, reporter) } } private val incCompiler = ZincUtil.defaultIncrementalCompiler @@ -2337,6 +2337,7 @@ object Defaults extends BuildCommon { } } private[this] def compileIncrementalTaskImpl( + task: BspCompileTask, s: TaskStreams, ci: Inputs, promise: PromiseWrap[Boolean], @@ -2351,8 +2352,40 @@ object Defaults extends BuildCommon { } ) } + def onProgress(s: Setup) = { + val p = s.progress.asScala + s.withProgress { + p.collect { + case c: CompileProgress => + new CompileProgress { + override def startUnit(phase: String, unitPath: String): Unit = + c.startUnit(phase, unitPath) + + override def afterEarlyOutput(success: Boolean): Unit = + c.afterEarlyOutput(success) + + override def advance( + current: Int, + total: Int, + prevPhase: String, + nextPhase: String + ): Boolean = { + val percentage = current * 100 / total + // Report percentages every 5% increments + val shouldReportPercentage = percentage % 5 == 0 + if (shouldReportPercentage) { + task.notifyProgress(percentage, total) + } + + c.advance(current, total, prevPhase, nextPhase) + } + } + }.asJava + } + } val compilers: Compilers = ci.compilers - val i = ci.withCompilers(onArgs(compilers)) + val setup: Setup = ci.setup + val i = ci.withCompilers(onArgs(compilers)).withSetup(onProgress(setup)) try { val result = incCompiler.compile(i, s.log) reporter.sendSuccessReport(result.getAnalysis) diff --git a/main/src/main/scala/sbt/internal/server/BspCompileTask.scala b/main/src/main/scala/sbt/internal/server/BspCompileTask.scala index e8a7ae7a0..4e53263ee 100644 --- a/main/src/main/scala/sbt/internal/server/BspCompileTask.scala +++ b/main/src/main/scala/sbt/internal/server/BspCompileTask.scala @@ -10,6 +10,7 @@ package sbt.internal.server import sbt._ import sbt.internal.bsp._ import sbt.internal.io.Retry +import sbt.internal.server.BspCompileTask.{ compileReport, exchange } import sbt.librarymanagement.Configuration import sjsonnew.support.scalajson.unsafe.Converter import xsbti.compile.CompileResult @@ -18,18 +19,16 @@ import xsbti.{ CompileFailed, Problem, Severity } import scala.util.control.NonFatal object BspCompileTask { - import sbt.internal.bsp.codec.JsonProtocol._ - private lazy val exchange = StandardMain.exchange def compute(targetId: BuildTargetIdentifier, project: ProjectRef, config: Configuration)( - compile: => CompileResult + compile: BspCompileTask => CompileResult ): CompileResult = { val task = BspCompileTask(targetId, project, config) try { - notifyStart(task) - val result = Retry(compile) - notifySuccess(task, result) + task.notifyStart + val result = Retry(compile(task)) + task.notifySuccess(result) result } catch { case NonFatal(cause) => @@ -37,7 +36,7 @@ object BspCompileTask { case failed: CompileFailed => Some(failed) case _ => None } - notifyFailure(task, compileFailed) + task.notifyFailure(compileFailed) throw cause } } @@ -52,51 +51,6 @@ object BspCompileTask { BspCompileTask(targetId, targetName, taskId, System.currentTimeMillis()) } - private def notifyStart(task: BspCompileTask): Unit = { - val message = s"Compiling ${task.targetName}" - val data = Converter.toJsonUnsafe(CompileTask(task.targetId)) - val params = TaskStartParams(task.id, task.startTimeMillis, message, "compile-task", data) - exchange.notifyEvent("build/taskStart", params) - } - - private def notifySuccess(task: BspCompileTask, result: CompileResult): Unit = { - import collection.JavaConverters._ - val endTimeMillis = System.currentTimeMillis() - val elapsedTimeMillis = endTimeMillis - task.startTimeMillis - val problems = result match { - case compileResult: CompileResult => - val sourceInfos = compileResult.analysis().readSourceInfos().getAllSourceInfos.asScala - sourceInfos.values.flatMap(_.getReportedProblems).toSeq - case _ => Seq() - } - val report = compileReport(problems, task.targetId, elapsedTimeMillis) - val params = TaskFinishParams( - task.id, - endTimeMillis, - s"Compiled ${task.targetName}", - StatusCode.Success, - "compile-report", - Converter.toJsonUnsafe(report) - ) - exchange.notifyEvent("build/taskFinish", params) - } - - private def notifyFailure(task: BspCompileTask, cause: Option[CompileFailed]): Unit = { - val endTimeMillis = System.currentTimeMillis() - val elapsedTimeMillis = endTimeMillis - task.startTimeMillis - val problems = cause.map(_.problems().toSeq).getOrElse(Seq.empty[Problem]) - val report = compileReport(problems, task.targetId, elapsedTimeMillis) - val params = TaskFinishParams( - task.id, - endTimeMillis, - s"Compiled ${task.targetName}", - StatusCode.Error, - "compile-report", - Converter.toJsonUnsafe(report) - ) - exchange.notifyEvent("build/taskFinish", params) - } - private def compileReport( problems: Seq[Problem], targetId: BuildTargetIdentifier, @@ -114,4 +68,68 @@ case class BspCompileTask private ( targetName: String, id: TaskId, startTimeMillis: Long -) +) { + import sbt.internal.bsp.codec.JsonProtocol._ + + private[sbt] def notifyStart(): Unit = { + val message = s"Compiling $targetName" + val data = Converter.toJsonUnsafe(CompileTask(targetId)) + val params = TaskStartParams(id, startTimeMillis, message, "compile-task", data) + exchange.notifyEvent("build/taskStart", params) + } + + private[sbt] def notifySuccess(result: CompileResult): Unit = { + import collection.JavaConverters._ + val endTimeMillis = System.currentTimeMillis() + val elapsedTimeMillis = endTimeMillis - startTimeMillis + val problems = result match { + case compileResult: CompileResult => + val sourceInfos = compileResult.analysis().readSourceInfos().getAllSourceInfos.asScala + sourceInfos.values.flatMap(_.getReportedProblems).toSeq + case _ => Seq() + } + val report = compileReport(problems, targetId, elapsedTimeMillis) + val params = TaskFinishParams( + id, + endTimeMillis, + s"Compiled $targetName", + StatusCode.Success, + "compile-report", + Converter.toJsonUnsafe(report) + ) + exchange.notifyEvent("build/taskFinish", params) + } + + private[sbt] def notifyProgress(percentage: Int, total: Int): Unit = { + val data = Converter.toJsonUnsafe(CompileTask(targetId)) + val message = s"Compiling $targetName ($percentage%)" + val currentMillis = System.currentTimeMillis() + val params = TaskProgressParams( + id, + Some(currentMillis), + Some(message), + Some(total.toLong), + Some(percentage.toLong), + None, + Some("compile-progress"), + Some(data) + ) + exchange.notifyEvent("build/taskProgress", params) + } + + private[sbt] def notifyFailure(cause: Option[CompileFailed]): Unit = { + val endTimeMillis = System.currentTimeMillis() + val elapsedTimeMillis = endTimeMillis - startTimeMillis + val problems = cause.map(_.problems().toSeq).getOrElse(Seq.empty[Problem]) + val report = compileReport(problems, targetId, elapsedTimeMillis) + val params = TaskFinishParams( + id, + endTimeMillis, + s"Compiled $targetName", + StatusCode.Error, + "compile-report", + Converter.toJsonUnsafe(report) + ) + exchange.notifyEvent("build/taskFinish", params) + } +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/bsp/TaskProgressParams.scala b/protocol/src/main/contraband-scala/sbt/internal/bsp/TaskProgressParams.scala new file mode 100644 index 000000000..43108df4c --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/bsp/TaskProgressParams.scala @@ -0,0 +1,92 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.bsp +/** + * @param taskId Unique id of the task with optional reference to parent task id. + * @param eventTime Optional timestamp of when the event started in milliseconds since Epoch. + * @param message Message describing the task progress. + * @param total If known, total amount of work units in this task. + * @param progress If known, completed amount of work units in this task. + * @param unit Name of a work unit. For example, "files" or "tests". May be empty. + * @param dataKind Kind of data to expect in the `data` field. + * @param data Optional metadata about the task. + */ +final class TaskProgressParams private ( + val taskId: sbt.internal.bsp.TaskId, + val eventTime: Option[Long], + val message: Option[String], + val total: Option[Long], + val progress: Option[Long], + val unit: Option[String], + val dataKind: Option[String], + val data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]) extends Serializable { + + + + override def equals(o: Any): Boolean = this.eq(o.asInstanceOf[AnyRef]) || (o match { + case x: TaskProgressParams => (this.taskId == x.taskId) && (this.eventTime == x.eventTime) && (this.message == x.message) && (this.total == x.total) && (this.progress == x.progress) && (this.unit == x.unit) && (this.dataKind == x.dataKind) && (this.data == x.data) + case _ => false + }) + override def hashCode: Int = { + 37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (37 * (17 + "sbt.internal.bsp.TaskProgressParams".##) + taskId.##) + eventTime.##) + message.##) + total.##) + progress.##) + unit.##) + dataKind.##) + data.##) + } + override def toString: String = { + "TaskProgressParams(" + taskId + ", " + eventTime + ", " + message + ", " + total + ", " + progress + ", " + unit + ", " + dataKind + ", " + data + ")" + } + private[this] def copy(taskId: sbt.internal.bsp.TaskId = taskId, eventTime: Option[Long] = eventTime, message: Option[String] = message, total: Option[Long] = total, progress: Option[Long] = progress, unit: Option[String] = unit, dataKind: Option[String] = dataKind, data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue] = data): TaskProgressParams = { + new TaskProgressParams(taskId, eventTime, message, total, progress, unit, dataKind, data) + } + def withTaskId(taskId: sbt.internal.bsp.TaskId): TaskProgressParams = { + copy(taskId = taskId) + } + def withEventTime(eventTime: Option[Long]): TaskProgressParams = { + copy(eventTime = eventTime) + } + def withEventTime(eventTime: Long): TaskProgressParams = { + copy(eventTime = Option(eventTime)) + } + def withMessage(message: Option[String]): TaskProgressParams = { + copy(message = message) + } + def withMessage(message: String): TaskProgressParams = { + copy(message = Option(message)) + } + def withTotal(total: Option[Long]): TaskProgressParams = { + copy(total = total) + } + def withTotal(total: Long): TaskProgressParams = { + copy(total = Option(total)) + } + def withProgress(progress: Option[Long]): TaskProgressParams = { + copy(progress = progress) + } + def withProgress(progress: Long): TaskProgressParams = { + copy(progress = Option(progress)) + } + def withUnit(unit: Option[String]): TaskProgressParams = { + copy(unit = unit) + } + def withUnit(unit: String): TaskProgressParams = { + copy(unit = Option(unit)) + } + def withDataKind(dataKind: Option[String]): TaskProgressParams = { + copy(dataKind = dataKind) + } + def withDataKind(dataKind: String): TaskProgressParams = { + copy(dataKind = Option(dataKind)) + } + def withData(data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): TaskProgressParams = { + copy(data = data) + } + def withData(data: sjsonnew.shaded.scalajson.ast.unsafe.JValue): TaskProgressParams = { + copy(data = Option(data)) + } +} +object TaskProgressParams { + + def apply(taskId: sbt.internal.bsp.TaskId, eventTime: Option[Long], message: Option[String], total: Option[Long], progress: Option[Long], unit: Option[String], dataKind: Option[String], data: Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]): TaskProgressParams = new TaskProgressParams(taskId, eventTime, message, total, progress, unit, dataKind, data) + def apply(taskId: sbt.internal.bsp.TaskId, eventTime: Long, message: String, total: Long, progress: Long, unit: String, dataKind: String, data: sjsonnew.shaded.scalajson.ast.unsafe.JValue): TaskProgressParams = new TaskProgressParams(taskId, Option(eventTime), Option(message), Option(total), Option(progress), Option(unit), Option(dataKind), Option(data)) +} diff --git a/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/JsonProtocol.scala b/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/JsonProtocol.scala index 3b7446921..c55580da2 100644 --- a/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/JsonProtocol.scala +++ b/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/JsonProtocol.scala @@ -33,6 +33,7 @@ trait JsonProtocol extends sjsonnew.BasicJsonProtocol with sbt.internal.bsp.codec.DependencySourcesItemFormats with sbt.internal.bsp.codec.DependencySourcesResultFormats with sbt.internal.bsp.codec.TaskStartParamsFormats + with sbt.internal.bsp.codec.TaskProgressParamsFormats with sbt.internal.bsp.codec.TaskFinishParamsFormats with sbt.internal.bsp.codec.CompileParamsFormats with sbt.internal.bsp.codec.BspCompileResultFormats diff --git a/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/TaskProgressParamsFormats.scala b/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/TaskProgressParamsFormats.scala new file mode 100644 index 000000000..4931e0a29 --- /dev/null +++ b/protocol/src/main/contraband-scala/sbt/internal/bsp/codec/TaskProgressParamsFormats.scala @@ -0,0 +1,41 @@ +/** + * This code is generated using [[https://www.scala-sbt.org/contraband/ sbt-contraband]]. + */ + +// DO NOT EDIT MANUALLY +package sbt.internal.bsp.codec +import _root_.sjsonnew.{ Unbuilder, Builder, JsonFormat, deserializationError } +trait TaskProgressParamsFormats { self: sbt.internal.bsp.codec.TaskIdFormats with sbt.internal.util.codec.JValueFormats with sjsonnew.BasicJsonProtocol => +implicit lazy val TaskProgressParamsFormat: JsonFormat[sbt.internal.bsp.TaskProgressParams] = new JsonFormat[sbt.internal.bsp.TaskProgressParams] { + override def read[J](__jsOpt: Option[J], unbuilder: Unbuilder[J]): sbt.internal.bsp.TaskProgressParams = { + __jsOpt match { + case Some(__js) => + unbuilder.beginObject(__js) + val taskId = unbuilder.readField[sbt.internal.bsp.TaskId]("taskId") + val eventTime = unbuilder.readField[Option[Long]]("eventTime") + val message = unbuilder.readField[Option[String]]("message") + val total = unbuilder.readField[Option[Long]]("total") + val progress = unbuilder.readField[Option[Long]]("progress") + val unit = unbuilder.readField[Option[String]]("unit") + val dataKind = unbuilder.readField[Option[String]]("dataKind") + val data = unbuilder.readField[Option[sjsonnew.shaded.scalajson.ast.unsafe.JValue]]("data") + unbuilder.endObject() + sbt.internal.bsp.TaskProgressParams(taskId, eventTime, message, total, progress, unit, dataKind, data) + case None => + deserializationError("Expected JsObject but found None") + } + } + override def write[J](obj: sbt.internal.bsp.TaskProgressParams, builder: Builder[J]): Unit = { + builder.beginObject() + builder.addField("taskId", obj.taskId) + builder.addField("eventTime", obj.eventTime) + builder.addField("message", obj.message) + builder.addField("total", obj.total) + builder.addField("progress", obj.progress) + builder.addField("unit", obj.unit) + builder.addField("dataKind", obj.dataKind) + builder.addField("data", obj.data) + builder.endObject() + } +} +} diff --git a/protocol/src/main/contraband/bsp.contra b/protocol/src/main/contraband/bsp.contra index 5f423babc..01b3a450f 100644 --- a/protocol/src/main/contraband/bsp.contra +++ b/protocol/src/main/contraband/bsp.contra @@ -318,6 +318,32 @@ type TaskStartParams { data: sjsonnew.shaded.scalajson.ast.unsafe.JValue } +type TaskProgressParams { + ## Unique id of the task with optional reference to parent task id. + taskId: sbt.internal.bsp.TaskId! + + ## Optional timestamp of when the event started in milliseconds since Epoch. + eventTime: Long + + ## Message describing the task progress. + message: String + + ## If known, total amount of work units in this task. + total: Long + + ## If known, completed amount of work units in this task. + progress: Long + + ## Name of a work unit. For example, "files" or "tests". May be empty. + unit: String + + ## Kind of data to expect in the `data` field. + dataKind: String + + ## Optional metadata about the task. + data: sjsonnew.shaded.scalajson.ast.unsafe.JValue +} + type TaskFinishParams { ## Unique id of the task with optional reference to parent task id. taskId: sbt.internal.bsp.TaskId!