From 6361013601cdfea5ad624b0c18522fd3b7eee552 Mon Sep 17 00:00:00 2001 From: Benjy Date: Fri, 11 Oct 2013 15:52:23 -0700 Subject: [PATCH] Add equals/hashCode to generated API datatype classes. Equality is reference equality for classes that have lazy members (currently Structure and ClassLike) and member equality for everything else. This avoids the circularity issue due to lazy members. Forces each class to be either abstract or final, to ensure that the equals implementation is always correct. Fixes toString to avoid infinite recursion. --- .../scala/xsbt/datatype/DatatypeParser.scala | 8 +- .../main/scala/xsbt/datatype/Definition.scala | 5 +- .../main/scala/xsbt/datatype/Generator.scala | 80 ++++++++++++++++--- 3 files changed, 76 insertions(+), 17 deletions(-) diff --git a/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala b/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala index 060527af0..07ed819a2 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/DatatypeParser.scala @@ -55,17 +55,19 @@ class DatatypeParser extends NotNull def processLine(open: Array[ClassDef], definitions: List[Definition], line: Line): (Array[ClassDef], List[Definition]) = { + def makeAbstract(x: ClassDef) = new ClassDef(x.name, x.parent, x.members, true) + line match { case w: WhitespaceLine => (open, definitions) case e: EnumLine => (Array(), new EnumDef(e.name, e.members) :: open.toList ::: definitions) case m: MemberLine => if(m.level == 0 || m.level > open.length) error(m, "Member must be declared in a class definition") - else withCurrent(open, definitions, m.level) { c => List( c + m) } + else withCurrent(open, definitions, m.level) { c => List(c + m) } case c: ClassLine => - if(c.level == 0) (Array( new ClassDef(c.name, None, Nil) ), open.toList ::: definitions) + if(c.level == 0) (Array( new ClassDef(c.name, None, Nil, false) ), open.toList ::: definitions) else if(c.level > open.length) error(c, "Class must be declared as top level or as a subclass") - else withCurrent(open, definitions, c.level) { p => p :: new ClassDef(c.name, Some(p), Nil) :: Nil} + else withCurrent(open, definitions, c.level) { p => val p1 = makeAbstract(p); p1 :: new ClassDef(c.name, Some(p1), Nil, false) :: Nil} } } private def withCurrent(open: Array[ClassDef], definitions: List[Definition], level: Int)(onCurrent: ClassDef => Seq[ClassDef]): (Array[ClassDef], List[Definition]) = diff --git a/util/datatype/src/main/scala/xsbt/datatype/Definition.scala b/util/datatype/src/main/scala/xsbt/datatype/Definition.scala index ed787ae19..d53b4bad7 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/Definition.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/Definition.scala @@ -7,11 +7,12 @@ sealed trait Definition extends NotNull { val name: String } -final class ClassDef(val name: String, val parent: Option[ClassDef], val members: Seq[MemberDef]) extends Definition +final class ClassDef(val name: String, val parent: Option[ClassDef], val members: Seq[MemberDef], val isAbstract: Boolean) extends Definition { def allMembers = members ++ inheritedMembers def inheritedMembers: Seq[MemberDef] = parent.toList.flatMap(_.allMembers) - def + (m: MemberLine) = new ClassDef(name, parent, members ++ Seq(new MemberDef(m.name, m.tpe.stripPrefix("~"), m.single, m.tpe.startsWith("~"))) ) + def hasLazyMembers = members exists (_.lzy) + def + (m: MemberLine) = new ClassDef(name, parent, members ++ Seq(new MemberDef(m.name, m.tpe.stripPrefix("~"), m.single, m.tpe.startsWith("~"))), isAbstract) } final class EnumDef(val name: String, val members: Seq[String]) extends Definition diff --git a/util/datatype/src/main/scala/xsbt/datatype/Generator.scala b/util/datatype/src/main/scala/xsbt/datatype/Generator.scala index 551850215..2b78228c1 100644 --- a/util/datatype/src/main/scala/xsbt/datatype/Generator.scala +++ b/util/datatype/src/main/scala/xsbt/datatype/Generator.scala @@ -5,7 +5,6 @@ package xsbt package datatype import java.io.File -import sbt.Path import sbt.IO.write import Generator._ @@ -51,7 +50,7 @@ class ImmutableGenerator(pkgName: String, baseDir: File) extends GeneratorBase(p def writeClass(c: ClassDef): Unit = writeSource(c.name, basePkgName, classContent(c)) def classContent(c: ClassDef): String = { - val hasParent = c.parent.isDefined + val abstractStr = if (c.isAbstract) "abstract " else "final " val allMembers = c.allMembers.map(normalize) val normalizedMembers = c.members.map(normalize) val fields = normalizedMembers.map(m => "private final " + m.asJavaDeclaration(false) + ";") @@ -71,33 +70,90 @@ class ImmutableGenerator(pkgName: String, baseDir: File) extends GeneratorBase(p "}" "import java.util.Arrays;\n" + "import java.util.List;\n" + - "public class " + c.name + c.parent.map(" extends " + _.name + " ").getOrElse(" implements java.io.Serializable") + "\n" + + "public " + abstractStr + "class " + c.name + c.parent.map(" extends " + _.name + " ").getOrElse(" implements java.io.Serializable") + "\n" + "{\n\t" + constructor + "\n\t" + - (fields ++ accessors).mkString("\n\t") + "\n\t" + - toStringMethod(c) + "\n" + + (fields ++ accessors).mkString("\n\t") + "\n" + + (if (!c.isAbstract) "\t" + equalsMethod(c) + "\n\t" + hashCodeMethod(c) + "\n\t" + toStringMethod(c) + "\n" else "") + "}\n" } - } object Generator { def methodSignature(modifiers: String, returnType: String, name: String, parameters: String) = modifiers + " " + returnType + " " + name + "(" + parameters + ")" - def method(modifiers: String, returnType: String, name: String, parameters: String, content: String) = - methodSignature(modifiers, returnType, name, parameters) + "\n\t{\n\t\treturn " + content + ";\n\t}" + def method(modifiers: String, returnType: String, name: String, parameters: String, body: String) = + methodSignature(modifiers, returnType, name, parameters) + "\n\t{\n\t\t " + body + "\n\t}" def fieldToString(name: String, single: Boolean) = "\"" + name + ": \" + " + fieldString(name + "()", single) def fieldString(arg: String, single: Boolean) = if(single) arg else "Arrays.toString(" + arg + ")" + def fieldEquals(arg: String, single: Boolean, primitive: Boolean) = { + if(single) { + if (primitive) arg + " == o." + arg else arg + ".equals(o." + arg + ")" + } else { + "Arrays." + (if (primitive) "equals" else "deepEquals") + "(" + arg + ", o." + arg + ")" + } + } def normalize(m: MemberDef): MemberDef = - m.mapType(tpe => if(primitives(tpe.toLowerCase(Locale.ENGLISH))) tpe.toLowerCase(Locale.ENGLISH) else tpe) + m.mapType(tpe => if(isPrimitive(tpe)) tpe.toLowerCase(Locale.ENGLISH) else tpe) + def isPrimitive(tpe: String) = primitives(tpe.toLowerCase(Locale.ENGLISH)) private val primitives = Set("int", "boolean", "float", "long", "short", "byte", "char", "double") + def equalsMethod(c: ClassDef): String = + { + val content = if (c.hasLazyMembers) { + "return this == obj; // We have lazy members, so use object identity to avoid circularity." + } else { + val allMembers = c.allMembers.map(normalize) + val memberComparisons = if (allMembers.isEmpty) "true" else allMembers.map(m => fieldEquals(m.name + "()", m.single, isPrimitive(m.tpe))).mkString(" && ") + "if (this == obj) {\n\t\t\t return true;\n\t\t} else if (!(obj instanceof " + c.name + ")) {\n\t\t\t return false;\n\t\t} else {\n\t\t\t" + c.name + " o = (" + c.name + ")obj;\n\t\t\treturn " + memberComparisons + ";\n\t\t}" + } + method("public", "boolean", "equals", "Object obj", content) + } + + def hashCodeMethod(c: ClassDef): String = + { + def hashCodeExprForMember(m: MemberDef) = + { + val primitive = isPrimitive(m.tpe) + val f = m.name + "()" // Assumes m has already been normalized. + if (m.single) { + if (primitive) { + m.tpe.toLowerCase match { + case "boolean" => "(" + f + " ? 0 : 1)" + case "long" => "(int)(" + f + " ^ (" + f + " >>> 32))" + case "float" => "Float.floatToIntBits(" + f + ")" + case "double" => "(int)(Double.doubleToLongBits(" + f + ") ^ (Double.doubleToLongBits(" + f + ") >>> 32))" + case "int" => f + case _ => "(int)" + f + } + } else { + f + ".hashCode()" + } + } else { + "Arrays." + (if (primitive) "hashCode" else "deepHashCode") + "(" + f + ")" + } + } + val hashCodeExpr = if (c.hasLazyMembers) { + "super.hashCode()" + } else { + val allMembers = c.allMembers.map(normalize) + val memberHashCodes = allMembers.map(hashCodeExprForMember) + ("17" /: memberHashCodes){ "37 * (" + _ + ") + " + _ } + } + method("public", "int", "hashCode", "", "return " + hashCodeExpr + ";") + } + def toStringMethod(c: ClassDef): String = { - val allMembers = c.allMembers.map(normalize) - val parametersString = if(allMembers.isEmpty) "\"\"" else allMembers.map(m => fieldToString(m.name, m.single)).mkString(" + \", \" + ") - method("public", "String", "toString", "", "\"" + c.name + "(\" + " + parametersString + "+ \")\"") + val content = if (c.hasLazyMembers) { + "return super.toString();" + } else { + val allMembers = c.allMembers.map(normalize) + val parametersString = if(allMembers.isEmpty) "\"\"" else allMembers.map(m => fieldToString(m.name, m.single)).mkString(" + \", \" + ") + "return \"" + c.name + "(\" + " + parametersString + " + \")\";" + } + method("public", "String", "toString", "", content) } def writeDefinitions(ds: Iterable[Definition])(writeDefinition: Definition => Unit)