From 754334e31b519603eeb7f6e6527240fc47241e56 Mon Sep 17 00:00:00 2001 From: Benjy Date: Thu, 10 Oct 2013 21:04:23 -0700 Subject: [PATCH] Add merge, partition and groupBy methods to Relation. Also add equals/hashCode to Relation. Also add a basic test for groupBy. --- .../src/main/scala/sbt/Relation.scala | 34 ++++++++++++++++--- .../src/test/scala/RelationTest.scala | 10 ++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/util/relation/src/main/scala/sbt/Relation.scala b/util/relation/src/main/scala/sbt/Relation.scala index acf19d6b7..d97ee2321 100644 --- a/util/relation/src/main/scala/sbt/Relation.scala +++ b/util/relation/src/main/scala/sbt/Relation.scala @@ -19,9 +19,11 @@ object Relation { val reversePairs = for( (a,bs) <- forward.view; b <- bs.view) yield (b, a) val reverse = (Map.empty[B,Set[A]] /: reversePairs) { case (m, (b, a)) => add(m, b, a :: Nil) } - make(forward, reverse) + make(forward filter { case (a, bs) => bs.nonEmpty }, reverse) } + def merge[A,B](rels: Traversable[Relation[A,B]]): Relation[A,B] = (Relation.empty[A, B] /: rels)(_ ++ _) + private[sbt] def remove[X,Y](map: M[X,Y], from: X, to: Y): M[X,Y] = map.get(from) match { case Some(tos) => @@ -62,6 +64,8 @@ trait Relation[A,B] def --(_1s: Traversable[A]): Relation[A,B] /** Removes all `pairs` from this relation. */ def --(pairs: TraversableOnce[(A,B)]): Relation[A,B] + /** Removes all `relations` from this relation. */ + def --(relations: Relation[A,B]): Relation[A,B] /** Removes all pairs `(_1, _2)` from this relation. */ def -(_1: A): Relation[A,B] /** Removes `pair` from this relation. */ @@ -75,11 +79,16 @@ trait Relation[A,B] /** Returns true iff `(a,b)` is in this relation*/ def contains(a: A, b: B): Boolean + /** Returns a relation with only pairs `(a,b)` for which `f(a,b)` is true.*/ def filter(f: (A,B) => Boolean): Relation[A,B] - /** Partitions this relation into a map of relations according to some discriminator function `f`. */ - def groupBy[K](f: ((A,B)) => K): Map[K, Relation[A,B]] + /** Returns a pair of relations: the first contains only pairs `(a,b)` for which `f(a,b)` is true and + * the other only pairs `(a,b)` for which `f(a,b)` is false. */ + def partition(f: (A,B) => Boolean): (Relation[A,B], Relation[A,B]) + + /** Partitions this relation into a map of relations according to some discriminator function. */ + def groupBy[K](discriminator: ((A,B)) => K): Map[K, Relation[A,B]] /** Returns all pairs in this relation.*/ def all: Traversable[(A,B)] @@ -96,6 +105,8 @@ trait Relation[A,B] * The value associated with a given `_2` is the set of all `_1`s such that `(_1, _2)` is in this relation.*/ def reverseMap: Map[B, Set[A]] } + +// Note that we assume without checking that fwd and rev are consistent. private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) extends Relation[A,B] { def forwardMap = fwd @@ -121,6 +132,7 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext def --(ts: Traversable[A]): Relation[A,B] = ((this: Relation[A,B]) /: ts) { _ - _ } def --(pairs: TraversableOnce[(A,B)]): Relation[A,B] = ((this: Relation[A,B]) /: pairs) { _ - _ } + def --(relations: Relation[A,B]): Relation[A,B] = --(relations.all) def -(pair: (A,B)): Relation[A,B] = new MRelation( remove(fwd, pair._1, pair._2), remove(rev, pair._2, pair._1) ) def -(t: A): Relation[A,B] = @@ -133,9 +145,23 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext def filter(f: (A,B) => Boolean): Relation[A,B] = Relation.empty[A,B] ++ all.filter(f.tupled) - def groupBy[K](f: ((A,B)) => K): Map[K, Relation[A,B]] = all.groupBy(f) mapValues { Relation.empty[A,B] ++ _ } + def partition(f: (A,B) => Boolean): (Relation[A,B], Relation[A,B]) = { + val (y, n) = all.partition(f.tupled) + (Relation.empty[A,B] ++ y, Relation.empty[A,B] ++ n) + } + + def groupBy[K](discriminator: ((A,B)) => K): Map[K, Relation[A,B]] = all.groupBy(discriminator) mapValues { Relation.empty[A,B] ++ _ } def contains(a: A, b: B): Boolean = forward(a)(b) + override def equals(other: Any) = other match { + // We assume that the forward and reverse maps are consistent, so we only use the forward map + // for equality. Note that key -> Empty is semantically the same as key not existing. + case o: MRelation[A,B] => forwardMap.filterNot(_._2.isEmpty) == o.forwardMap.filterNot(_._2.isEmpty) + case _ => false + } + + override def hashCode = fwd.filterNot(_._2.isEmpty).hashCode() + override def toString = all.map { case (a,b) => a + " -> " + b }.mkString("Relation [", ", ", "]") } diff --git a/util/relation/src/test/scala/RelationTest.scala b/util/relation/src/test/scala/RelationTest.scala index e82bd861d..de63fe893 100644 --- a/util/relation/src/test/scala/RelationTest.scala +++ b/util/relation/src/test/scala/RelationTest.scala @@ -50,6 +50,16 @@ object RelationTest extends Properties("Relation") ("Reverse map does not contain removed" |: ( notIn(r.reverseMap, b, a) ) ) } } + + property("Groups correctly") = forAll { (entries: List[(Int, Double)], randomInt: Int) => + val splitInto = randomInt % 10 + 1 // Split into 1-10 groups. + val rel = Relation.empty[Int, Double] ++ entries + val grouped = rel groupBy (_._1 % splitInto) + all(grouped.toSeq) { + case (k, rel_k) => rel_k._1s forall { _ % splitInto == k } + } + } + def all[T](s: Seq[T])(p: T => Prop): Prop = if(s.isEmpty) true else s.map(p).reduceLeft(_ && _) }