From 89678735e18a3a390b464cdc2230c57675b5e146 Mon Sep 17 00:00:00 2001 From: Indrajit Raychaudhuri Date: Tue, 27 Mar 2012 01:33:18 +0530 Subject: [PATCH] Improved implementation for `parents` accumulation for java classes It now considers `ParameterizedType` and includes all interfaces recursively --- compile/api/ClassToAPI.scala | 41 ++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/compile/api/ClassToAPI.scala b/compile/api/ClassToAPI.scala index f7a2bb25d..41f2ce61f 100644 --- a/compile/api/ClassToAPI.scala +++ b/compile/api/ClassToAPI.scala @@ -1,11 +1,12 @@ package sbt - import java.lang.reflect.{Array => _, _} - import java.lang.annotation.Annotation - import xsbti.api - import xsbti.SafeLazy - import SafeLazy.strict - import collection.mutable +import java.lang.reflect.{Array => _, _} +import java.lang.annotation.Annotation +import annotation.tailrec +import xsbti.api +import xsbti.SafeLazy +import SafeLazy.strict +import collection.mutable object ClassToAPI { @@ -76,17 +77,31 @@ object ClassToAPI private val lzyEmptyTpeArray = lzy(emptyTypeArray) private val lzyEmptyDefArray = lzy(new Array[xsbti.api.Definition](0)) - private def allSuperclasses(t: Type): Seq[Type] = + private def allSuperTypes(t: Type): Seq[Type] = { - def accumulate(t: Type, accum: Seq[Type] = Seq.empty): Seq[Type] = t match { - case cl: Class[_] => { val s = cl.getGenericSuperclass ; accumulate(s, accum :+ s) } - case _ => accum + @tailrec def accumulate(t: Type, accum: Seq[Type] = Seq.empty): Seq[Type] = t match { + case c: Class[_] => + val (parent, interfaces) = (c.getGenericSuperclass, c.getGenericInterfaces) + accumulate(parent, (accum :+ parent) ++ flattenAll(interfaces)) + case p: ParameterizedType => + accumulate(p.getRawType, accum) + case _ => + accum } - accumulate(t) + @tailrec def flattenAll(interfaces: Seq[Type], accum: Seq[Type] = Seq.empty): Seq[Type] = + { + if (!interfaces.isEmpty) { + val raw = interfaces map { case p: ParameterizedType => p.getRawType; case i => i } + val children = raw flatMap { case i: Class[_] => i.getGenericInterfaces; case _ => Seq.empty } + flattenAll(children, accum ++ interfaces ++ children) + } + else + accum + } + accumulate(t).filterNot(_ == null).distinct } - def parents(c: Class[_]): Seq[api.Type] = - types(allSuperclasses(c) ++ c.getGenericInterfaces) + def parents(c: Class[_]): Seq[api.Type] = types(allSuperTypes(c)) def types(ts: Seq[Type]): Array[api.Type] = ts filter (_ ne null) map reference toArray; def upperBounds(ts: Array[Type]): api.Type = new api.Structure(lzy(types(ts)), lzyEmptyDefArray, lzyEmptyDefArray)