diff --git a/util/collection/Relation.scala b/util/collection/Relation.scala index b305bb47c..c5195ffb7 100644 --- a/util/collection/Relation.scala +++ b/util/collection/Relation.scala @@ -29,7 +29,7 @@ object Relation private[sbt] def combine[X,Y](a: M[X,Y], b: M[X,Y]): M[X,Y] = (a /: b) { (map, mapping) => add(map, mapping._1, mapping._2) } - private[sbt] def add[X,Y](map: M[X,Y], from: X, to: Iterable[Y]): M[X,Y] = + private[sbt] def add[X,Y](map: M[X,Y], from: X, to: Traversable[Y]): M[X,Y] = map.updated(from, get(map, from) ++ to) private[sbt] def get[X,Y](map: M[X,Y], t: X): Set[Y] = map.getOrElse(t, Set.empty[Y]) @@ -49,15 +49,15 @@ trait Relation[A,B] /** Includes the relation (a, b). */ def +(a: A, b: B): Relation[A,B] /** Includes the relations (a, b) for all b in bs. */ - def +(a: A, bs: Iterable[B]): Relation[A,B] + def +(a: A, bs: Traversable[B]): Relation[A,B] /** Returns the union of the relation r with this relation. */ def ++(r: Relation[A,B]): Relation[A,B] /** Includes the given relations. */ - def ++(rs: Iterable[(A,B)]): Relation[A,B] + def ++(rs: Traversable[(A,B)]): Relation[A,B] /** Removes all relations (_1, _2) for all _1 in _1s. */ - def --(_1s: Iterable[A]): Relation[A,B] + def --(_1s: Traversable[A]): Relation[A,B] /** Removes all `pairs` from this relation. */ - def --(pairs: Traversable[(A,B)]): Relation[A,B] + def --(pairs: TraversableOnce[(A,B)]): Relation[A,B] /** Removes all pairs (_1, _2) from this relation. */ def -(_1: A): Relation[A,B] /** Removes `pair` from this relation. */ @@ -69,6 +69,11 @@ trait Relation[A,B] /** Returns the number of pairs in this relation */ def size: Int + /** Returns true iff (a,b) is in this relation*/ + def contains(a: A, b: B): Boolean + /** Returns a relation with only pairs (a,b) for which f(a,b) is true.*/ + def filter(f: (A,B) => Boolean): Relation[A,B] + /** Returns all pairs in this relation.*/ def all: Traversable[(A,B)] @@ -92,14 +97,14 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext def +(pair: (A,B)) = this + (pair._1, Set(pair._2)) def +(from: A, to: B) = this + (from, to :: Nil) - def +(from: A, to: Iterable[B]) = + def +(from: A, to: Traversable[B]) = new MRelation( add(fwd, from, to), (rev /: to) { (map, t) => add(map, t, from :: Nil) }) - def ++(rs: Iterable[(A,B)]) = ((this: Relation[A,B]) /: rs) { _ + _ } + def ++(rs: Traversable[(A,B)]) = ((this: Relation[A,B]) /: rs) { _ + _ } def ++(other: Relation[A,B]) = new MRelation[A,B]( combine(fwd, other.forwardMap), combine(rev, other.reverseMap) ) - def --(ts: Iterable[A]): Relation[A,B] = ((this: Relation[A,B]) /: ts) { _ - _ } - def --(pairs: Traversable[(A,B)]): Relation[A,B] = ((this: Relation[A,B]) /: pairs) { _ - _ } + def --(ts: Traversable[A]): Relation[A,B] = ((this: Relation[A,B]) /: ts) { _ - _ } + def --(pairs: TraversableOnce[(A,B)]): Relation[A,B] = ((this: Relation[A,B]) /: pairs) { _ - _ } def -(pair: (A,B)): Relation[A,B] = new MRelation( remove(fwd, pair._1, pair._2), remove(rev, pair._2, pair._1) ) def -(t: A): Relation[A,B] = @@ -110,5 +115,9 @@ private final class MRelation[A,B](fwd: Map[A, Set[B]], rev: Map[B, Set[A]]) ext case None => this } + def filter(f: (A,B) => Boolean): Relation[A,B] = Relation.empty[A,B] ++ all.filter(f.tupled) + + def contains(a: A, b: B): Boolean = forward(a)(b) + override def toString = all.map { case (a,b) => a + " -> " + b }.mkString("Relation [", ", ", "]") } \ No newline at end of file