Add merge, partition and groupBy methods to Relation.

Also add equals/hashCode to Relation.

Also add a basic test for groupBy.
This commit is contained in:
Benjy 2013-10-10 21:04:23 -07:00
parent 8e96a7e145
commit 754334e31b
2 changed files with 40 additions and 4 deletions

View File

@ -19,9 +19,11 @@ object Relation
{
val reversePairs = for( (a,bs) <- forward.view; b <- bs.view) yield (b, a)
val reverse = (Map.empty[B,Set[A]] /: reversePairs) { case (m, (b, a)) => add(m, b, a :: Nil) }
make(forward, reverse)
make(forward filter { case (a, bs) => bs.nonEmpty }, reverse)
}
def merge[A,B](rels: Traversable[Relation[A,B]]): Relation[A,B] = (Relation.empty[A, B] /: rels)(_ ++ _)
private[sbt] def remove[X,Y](map: M[X,Y], from: X, to: Y): M[X,Y] =
map.get(from) match {
case Some(tos) =>
@ -62,6 +64,8 @@ trait Relation[A,B]
def --(_1s: Traversable[A]): Relation[A,B]
/** Removes all `pairs` from this relation. */
def --(pairs: TraversableOnce[(A,B)]): Relation[A,B]
/** Removes all `relations` from this relation. */
def --(relations: Relation[A,B]): Relation[A,B]
/** Removes all pairs `(_1, _2)` from this relation. */
def -(_1: A): Relation[A,B]
/** Removes `pair` from this relation. */
@ -75,11 +79,16 @@ trait Relation[A,B]
/** Returns true iff `(a,b)` is in this relation*/
def contains(a: A, b: B): Boolean
/** Returns a relation with only pairs `(a,b)` for which `f(a,b)` is true.*/
def filter(f: (A,B) => Boolean): Relation[A,B]
/** Partitions this relation into a map of relations according to some discriminator function `f`. */
def groupBy[K](f: ((A,B)) => K): Map[K, Relation[A,B]]
/** Returns a pair of relations: the first contains only pairs `(a,b)` for which `f(a,b)` is true and
* the other only pairs `(a,b)` for which `f(a,b)` is false. */
def partition(f: (A,B) => Boolean): (Relation[A,B], Relation[A,B])
/** Partitions this relation into a map of relations according to some discriminator function. */
def groupBy[K](discriminator: ((A,B)) => K): Map[K, Relation[A,B]]
/** Returns all pairs in this relation.*/
def all: Traversable[(A,B)]
@ -96,6 +105,8 @@ trait Relation[A,B]
* The value associated with a given `_2` is the set of all `_1`s such that `(_1, _2)` is in this relation.*/
def reverseMap: Map[B, Set[A]]
}
// Note that we assume without checking that fwd and rev are consistent.
private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) extends Relation[A,B]
{
def forwardMap = fwd
@ -121,6 +132,7 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext
def --(ts: Traversable[A]): Relation[A,B] = ((this: Relation[A,B]) /: ts) { _ - _ }
def --(pairs: TraversableOnce[(A,B)]): Relation[A,B] = ((this: Relation[A,B]) /: pairs) { _ - _ }
def --(relations: Relation[A,B]): Relation[A,B] = --(relations.all)
def -(pair: (A,B)): Relation[A,B] =
new MRelation( remove(fwd, pair._1, pair._2), remove(rev, pair._2, pair._1) )
def -(t: A): Relation[A,B] =
@ -133,9 +145,23 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext
def filter(f: (A,B) => Boolean): Relation[A,B] = Relation.empty[A,B] ++ all.filter(f.tupled)
def groupBy[K](f: ((A,B)) => K): Map[K, Relation[A,B]] = all.groupBy(f) mapValues { Relation.empty[A,B] ++ _ }
def partition(f: (A,B) => Boolean): (Relation[A,B], Relation[A,B]) = {
val (y, n) = all.partition(f.tupled)
(Relation.empty[A,B] ++ y, Relation.empty[A,B] ++ n)
}
def groupBy[K](discriminator: ((A,B)) => K): Map[K, Relation[A,B]] = all.groupBy(discriminator) mapValues { Relation.empty[A,B] ++ _ }
def contains(a: A, b: B): Boolean = forward(a)(b)
override def equals(other: Any) = other match {
// We assume that the forward and reverse maps are consistent, so we only use the forward map
// for equality. Note that key -> Empty is semantically the same as key not existing.
case o: MRelation[A,B] => forwardMap.filterNot(_._2.isEmpty) == o.forwardMap.filterNot(_._2.isEmpty)
case _ => false
}
override def hashCode = fwd.filterNot(_._2.isEmpty).hashCode()
override def toString = all.map { case (a,b) => a + " -> " + b }.mkString("Relation [", ", ", "]")
}

View File

@ -50,6 +50,16 @@ object RelationTest extends Properties("Relation")
("Reverse map does not contain removed" |: ( notIn(r.reverseMap, b, a) ) )
}
}
property("Groups correctly") = forAll { (entries: List[(Int, Double)], randomInt: Int) =>
val splitInto = randomInt % 10 + 1 // Split into 1-10 groups.
val rel = Relation.empty[Int, Double] ++ entries
val grouped = rel groupBy (_._1 % splitInto)
all(grouped.toSeq) {
case (k, rel_k) => rel_k._1s forall { _ % splitInto == k }
}
}
def all[T](s: Seq[T])(p: T => Prop): Prop =
if(s.isEmpty) true else s.map(p).reduceLeft(_ && _)
}