Improved implementation for `parents` accumulation for java classes

It now considers `ParameterizedType` and includes all interfaces recursively
This commit is contained in:
Indrajit Raychaudhuri 2012-03-27 01:33:18 +05:30
parent 41eb26ae1f
commit 89678735e1
1 changed files with 28 additions and 13 deletions

View File

@ -1,11 +1,12 @@
package sbt
import java.lang.reflect.{Array => _, _}
import java.lang.annotation.Annotation
import xsbti.api
import xsbti.SafeLazy
import SafeLazy.strict
import collection.mutable
import java.lang.reflect.{Array => _, _}
import java.lang.annotation.Annotation
import annotation.tailrec
import xsbti.api
import xsbti.SafeLazy
import SafeLazy.strict
import collection.mutable
object ClassToAPI
{
@ -76,17 +77,31 @@ object ClassToAPI
private val lzyEmptyTpeArray = lzy(emptyTypeArray)
private val lzyEmptyDefArray = lzy(new Array[xsbti.api.Definition](0))
private def allSuperclasses(t: Type): Seq[Type] =
private def allSuperTypes(t: Type): Seq[Type] =
{
def accumulate(t: Type, accum: Seq[Type] = Seq.empty): Seq[Type] = t match {
case cl: Class[_] => { val s = cl.getGenericSuperclass ; accumulate(s, accum :+ s) }
case _ => accum
@tailrec def accumulate(t: Type, accum: Seq[Type] = Seq.empty): Seq[Type] = t match {
case c: Class[_] =>
val (parent, interfaces) = (c.getGenericSuperclass, c.getGenericInterfaces)
accumulate(parent, (accum :+ parent) ++ flattenAll(interfaces))
case p: ParameterizedType =>
accumulate(p.getRawType, accum)
case _ =>
accum
}
accumulate(t)
@tailrec def flattenAll(interfaces: Seq[Type], accum: Seq[Type] = Seq.empty): Seq[Type] =
{
if (!interfaces.isEmpty) {
val raw = interfaces map { case p: ParameterizedType => p.getRawType; case i => i }
val children = raw flatMap { case i: Class[_] => i.getGenericInterfaces; case _ => Seq.empty }
flattenAll(children, accum ++ interfaces ++ children)
}
else
accum
}
accumulate(t).filterNot(_ == null).distinct
}
def parents(c: Class[_]): Seq[api.Type] =
types(allSuperclasses(c) ++ c.getGenericInterfaces)
def parents(c: Class[_]): Seq[api.Type] = types(allSuperTypes(c))
def types(ts: Seq[Type]): Array[api.Type] = ts filter (_ ne null) map reference toArray;
def upperBounds(ts: Array[Type]): api.Type =
new api.Structure(lzy(types(ts)), lzyEmptyDefArray, lzyEmptyDefArray)