diff --git a/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala b/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala index 060527af0..07ed819a2 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala @@ -55,17 +55,19 @@ class DatatypeParser extends NotNull def processLine(open: Array[ClassDef], definitions: List[Definition], line: Line): (Array[ClassDef], List[Definition]) = { + def makeAbstract(x: ClassDef) = new ClassDef(x.name, x.parent, x.members, true) + line match { case w: WhitespaceLine => (open, definitions) case e: EnumLine => (Array(), new EnumDef(e.name, e.members) :: open.toList ::: definitions) case m: MemberLine => if(m.level == 0 || m.level > open.length) error(m, "Member must be declared in a class definition") - else withCurrent(open, definitions, m.level) { c => List( c + m) } + else withCurrent(open, definitions, m.level) { c => List(c + m) } case c: ClassLine => - if(c.level == 0) (Array( new ClassDef(c.name, None, Nil) ), open.toList ::: definitions) + if(c.level == 0) (Array( new ClassDef(c.name, None, Nil, false) ), open.toList ::: definitions) else if(c.level > open.length) error(c, "Class must be declared as top level or as a subclass") - else withCurrent(open, definitions, c.level) { p => p :: new ClassDef(c.name, Some(p), Nil) :: Nil} + else withCurrent(open, definitions, c.level) { p => val p1 = makeAbstract(p); p1 :: new ClassDef(c.name, Some(p1), Nil, false) :: Nil} } } private def withCurrent(open: Array[ClassDef], definitions: List[Definition], level: Int)(onCurrent: ClassDef => Seq[ClassDef]): (Array[ClassDef], List[Definition]) = diff --git a/util/datatype/src/main/scala/xsbt/datatype/Definition.scala b/util/datatype/src/main/scala/xsbt/datatype/Definition.scala index ed787ae19..d53b4bad7 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/Definition.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/Definition.scala @@ -7,11 +7,12 @@ sealed trait Definition extends NotNull { val name: String } -final class ClassDef(val name: String, val parent: Option[ClassDef], val members: Seq[MemberDef]) extends Definition +final class ClassDef(val name: String, val parent: Option[ClassDef], val members: Seq[MemberDef], val isAbstract: Boolean) extends Definition { def allMembers = members ++ inheritedMembers def inheritedMembers: Seq[MemberDef] = parent.toList.flatMap(_.allMembers) - def + (m: MemberLine) = new ClassDef(name, parent, members ++ Seq(new MemberDef(m.name, m.tpe.stripPrefix("~"), m.single, m.tpe.startsWith("~"))) ) + def hasLazyMembers = members exists (_.lzy) + def + (m: MemberLine) = new ClassDef(name, parent, members ++ Seq(new MemberDef(m.name, m.tpe.stripPrefix("~"), m.single, m.tpe.startsWith("~"))), isAbstract) } final class EnumDef(val name: String, val members: Seq[String]) extends Definition diff --git a/util/datatype/src/main/scala/xsbt/datatype/Generator.scala b/util/datatype/src/main/scala/xsbt/datatype/Generator.scala index 551850215..2b78228c1 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/Generator.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/Generator.scala @@ -5,7 +5,6 @@ package xsbt package datatype import java.io.File -import sbt.Path import sbt.IO.write import Generator._ @@ -51,7 +50,7 @@ class ImmutableGenerator(pkgName: String, baseDir: File) extends GeneratorBase(p def writeClass(c: ClassDef): Unit = writeSource(c.name, basePkgName, classContent(c)) def classContent(c: ClassDef): String = { - val hasParent = c.parent.isDefined + val abstractStr = if (c.isAbstract) "abstract " else "final " val allMembers = c.allMembers.map(normalize) val normalizedMembers = c.members.map(normalize) val fields = normalizedMembers.map(m => "private final " + m.asJavaDeclaration(false) + ";") @@ -71,33 +70,90 @@ class ImmutableGenerator(pkgName: String, baseDir: File) extends GeneratorBase(p "}" "import java.util.Arrays;\n" + "import java.util.List;\n" + - "public class " + c.name + c.parent.map(" extends " + _.name + " ").getOrElse(" implements java.io.Serializable") + "\n" + + "public " + abstractStr + "class " + c.name + c.parent.map(" extends " + _.name + " ").getOrElse(" implements java.io.Serializable") + "\n" + "{\n\t" + constructor + "\n\t" + - (fields ++ accessors).mkString("\n\t") + "\n\t" + - toStringMethod(c) + "\n" + + (fields ++ accessors).mkString("\n\t") + "\n" + + (if (!c.isAbstract) "\t" + equalsMethod(c) + "\n\t" + hashCodeMethod(c) + "\n\t" + toStringMethod(c) + "\n" else "") + "}\n" } - } object Generator { def methodSignature(modifiers: String, returnType: String, name: String, parameters: String) = modifiers + " " + returnType + " " + name + "(" + parameters + ")" - def method(modifiers: String, returnType: String, name: String, parameters: String, content: String) = - methodSignature(modifiers, returnType, name, parameters) + "\n\t{\n\t\treturn " + content + ";\n\t}" + def method(modifiers: String, returnType: String, name: String, parameters: String, body: String) = + methodSignature(modifiers, returnType, name, parameters) + "\n\t{\n\t\t " + body + "\n\t}" def fieldToString(name: String, single: Boolean) = "\"" + name + ": \" + " + fieldString(name + "()", single) def fieldString(arg: String, single: Boolean) = if(single) arg else "Arrays.toString(" + arg + ")" + def fieldEquals(arg: String, single: Boolean, primitive: Boolean) = { + if(single) { + if (primitive) arg + " == o." + arg else arg + ".equals(o." + arg + ")" + } else { + "Arrays." + (if (primitive) "equals" else "deepEquals") + "(" + arg + ", o." + arg + ")" + } + } def normalize(m: MemberDef): MemberDef = - m.mapType(tpe => if(primitives(tpe.toLowerCase(Locale.ENGLISH))) tpe.toLowerCase(Locale.ENGLISH) else tpe) + m.mapType(tpe => if(isPrimitive(tpe)) tpe.toLowerCase(Locale.ENGLISH) else tpe) + def isPrimitive(tpe: String) = primitives(tpe.toLowerCase(Locale.ENGLISH)) private val primitives = Set("int", "boolean", "float", "long", "short", "byte", "char", "double") + def equalsMethod(c: ClassDef): String = + { + val content = if (c.hasLazyMembers) { + "return this == obj; // We have lazy members, so use object identity to avoid circularity." + } else { + val allMembers = c.allMembers.map(normalize) + val memberComparisons = if (allMembers.isEmpty) "true" else allMembers.map(m => fieldEquals(m.name + "()", m.single, isPrimitive(m.tpe))).mkString(" && ") + "if (this == obj) {\n\t\t\t return true;\n\t\t} else if (!(obj instanceof " + c.name + ")) {\n\t\t\t return false;\n\t\t} else {\n\t\t\t" + c.name + " o = (" + c.name + ")obj;\n\t\t\treturn " + memberComparisons + ";\n\t\t}" + } + method("public", "boolean", "equals", "Object obj", content) + } + + def hashCodeMethod(c: ClassDef): String = + { + def hashCodeExprForMember(m: MemberDef) = + { + val primitive = isPrimitive(m.tpe) + val f = m.name + "()" // Assumes m has already been normalized. + if (m.single) { + if (primitive) { + m.tpe.toLowerCase match { + case "boolean" => "(" + f + " ? 0 : 1)" + case "long" => "(int)(" + f + " ^ (" + f + " >>> 32))" + case "float" => "Float.floatToIntBits(" + f + ")" + case "double" => "(int)(Double.doubleToLongBits(" + f + ") ^ (Double.doubleToLongBits(" + f + ") >>> 32))" + case "int" => f + case _ => "(int)" + f + } + } else { + f + ".hashCode()" + } + } else { + "Arrays." + (if (primitive) "hashCode" else "deepHashCode") + "(" + f + ")" + } + } + val hashCodeExpr = if (c.hasLazyMembers) { + "super.hashCode()" + } else { + val allMembers = c.allMembers.map(normalize) + val memberHashCodes = allMembers.map(hashCodeExprForMember) + ("17" /: memberHashCodes){ "37 * (" + _ + ") + " + _ } + } + method("public", "int", "hashCode", "", "return " + hashCodeExpr + ";") + } + def toStringMethod(c: ClassDef): String = { - val allMembers = c.allMembers.map(normalize) - val parametersString = if(allMembers.isEmpty) "\"\"" else allMembers.map(m => fieldToString(m.name, m.single)).mkString(" + \", \" + ") - method("public", "String", "toString", "", "\"" + c.name + "(\" + " + parametersString + "+ \")\"") + val content = if (c.hasLazyMembers) { + "return super.toString();" + } else { + val allMembers = c.allMembers.map(normalize) + val parametersString = if(allMembers.isEmpty) "\"\"" else allMembers.map(m => fieldToString(m.name, m.single)).mkString(" + \", \" + ") + "return \"" + c.name + "(\" + " + parametersString + " + \")\";" + } + method("public", "String", "toString", "", content) } def writeDefinitions(ds: Iterable[Definition])(writeDefinition: Definition => Unit)