diff --git a/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala b/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala index d289884bf..0c6655e2d 100644 --- a/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala +++ b/main/src/main/scala/sbt/internal/server/BuildServerProtocol.scala @@ -39,6 +39,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import scala.util.{ Failure, Success, Try } import scala.annotation.nowarn +import sbt.testing.Framework object BuildServerProtocol { import sbt.internal.bsp.codec.JsonProtocol._ @@ -844,25 +845,20 @@ object BuildServerProtocol { Keys.definedTests.?.value match { case None => Vector.empty case Some(definitions) => - val fingerprints = Keys.loadedTestFrameworks.?.value - .getOrElse(Map.empty) - .values - .flatMap { framework => - framework.fingerprints().map(fingerprint => (fingerprint, framework)) - } - .toMap + val frameworks: Seq[Framework] = Keys.loadedTestFrameworks.?.value + .map(_.values.toSeq) + .getOrElse(Seq.empty) - definitions - .groupBy(defn => fingerprints.get(defn.fingerprint)) - .map { - case (framework, definitions) => - ScalaTestClassesItem( - bspTargetIdentifier.value, - definitions.map(_.name).toVector, - framework.map(_.name()) - ) - } - .toSeq + val grouped = TestFramework.testMap(frameworks, definitions) + + grouped.map { + case (framework, definitions) => + ScalaTestClassesItem( + bspTargetIdentifier.value, + definitions.map(_.name).toVector, + framework.name() + ) + }.toSeq } } diff --git a/server-test/src/test/scala/testpkg/BuildServerTest.scala b/server-test/src/test/scala/testpkg/BuildServerTest.scala index 5aa231752..9cbe43980 100644 --- a/server-test/src/test/scala/testpkg/BuildServerTest.scala +++ b/server-test/src/test/scala/testpkg/BuildServerTest.scala @@ -308,7 +308,8 @@ object BuildServerTest extends AbstractServerTest { assert(svr.waitForString(10.seconds) { s => (s contains """"id":"72"""") && (s contains """"tests.FailingTest"""") && - (s contains """"tests.PassingTest"""") + (s contains """"tests.PassingTest"""") && + (s contains """"framework":"ScalaTest"""") }) } diff --git a/testing/src/main/scala/sbt/TestFramework.scala b/testing/src/main/scala/sbt/TestFramework.scala index fe3f76196..cb89dd5f4 100644 --- a/testing/src/main/scala/sbt/TestFramework.scala +++ b/testing/src/main/scala/sbt/TestFramework.scala @@ -231,21 +231,26 @@ object TestFramework { ): Vector[(String, TestFunction)] = for (d <- inputs; act <- mapped.get(d.name)) yield (d.name, act) - private[this] def testMap( + def testMap( frameworks: Seq[Framework], tests: Seq[TestDefinition] ): Map[Framework, Set[TestDefinition]] = { import scala.collection.mutable.{ HashMap, HashSet, Set } val map = new HashMap[Framework, Set[TestDefinition]] + def assignTest(test: TestDefinition): Unit = { def isTestForFramework(framework: Framework) = getFingerprints(framework).exists { t => matches(t, test.fingerprint) } - for (framework <- frameworks.find(isTestForFramework)) + + frameworks.find(isTestForFramework).foreach { framework => map.getOrElseUpdate(framework, new HashSet[TestDefinition]) += test + } } + if (frameworks.nonEmpty) for (test <- tests) assignTest(test) + map.toMap.mapValues(_.toSet).toMap }