From 0b3c2dada58aae4f7af7d514c689b3bdd35ad917 Mon Sep 17 00:00:00 2001 From: Mark Harrah Date: Sat, 18 Dec 2010 12:44:18 -0500 Subject: [PATCH] expression evaluator --- compile/Eval.scala | 109 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 compile/Eval.scala diff --git a/compile/Eval.scala b/compile/Eval.scala new file mode 100644 index 000000000..e4e7fd0e5 --- /dev/null +++ b/compile/Eval.scala @@ -0,0 +1,109 @@ +package sbt +package compile + +import scala.reflect.Manifest +import scala.tools.nsc.{ast, interpreter, io, reporters, util, CompilerCommand, Global, Phase, Settings} +import interpreter.AbstractFileClassLoader +import io.VirtualDirectory +import ast.parser.Tokens +import reporters.{ConsoleReporter, Reporter} +import util.BatchSourceFile +import Tokens.EOF + +final class EvalException(msg: String) extends RuntimeException(msg) +// not thread safe, since it reuses a Global instance +final class Eval(options: Seq[String], mkReporter: Settings => Reporter, parent: ClassLoader) +{ + def this() = this("-cp" :: IO.classLocationFile[ScalaObject].getAbsolutePath :: Nil, s => new ConsoleReporter(s), getClass.getClassLoader) + + val settings = new Settings(Console.println) + val command = new CompilerCommand(options.toList, settings) + val reporter = mkReporter(settings) + val global: Global = new Global(settings, reporter) + import global._ + import definitions._ + + def eval[T](expression: String)(implicit mf: Manifest[T]): T = eval(expression, mf.toString).asInstanceOf[T] + def eval(expression: String, tpeName: String): Any = + { + reporter.reset + val unit = mkUnit(expression) + val run = new Run { + override def units = (unit :: Nil).iterator + } + def unlinkAll(): Unit = for( (sym, _) <- run.symSource ) unlink(sym) + def unlink(sym: Symbol) = sym.owner.info.decls.unlink(sym) + + try { eval0(expression, tpeName, run, unit) } finally { unlinkAll() } + } + def eval0(expression: String, tpeName: String, run: Run, unit: CompilationUnit): Any = + { + val virtualDirectory = new VirtualDirectory("", None) + settings.outputDirs setSingleOutput virtualDirectory + + val parser = new syntaxAnalyzer.UnitParser(unit) + val tree: Tree = parser.expr() + parser.accept(EOF) + checkError("Error parsing expression.") + + val tpeParser = new syntaxAnalyzer.UnitParser(mkUnit(tpeName)) + val tpt: Tree = tpeParser.typ() + tpeParser.accept(EOF) + checkError("Error parsing type.") + + unit.body = augment(parser, tree, tpt) + + def compile(phase: Phase): Unit = + { + globalPhase = phase + if(phase == null || phase == phase.next || reporter.hasErrors) + () + else + { + reporter.withSource(unit.source) { + atPhase(phase) { phase.run } + } + compile(phase.next) + } + } + + compile(run.namerPhase) + checkError("Type error.") + + val loader = new AbstractFileClassLoader(virtualDirectory, parent) + getValue(loader) + } + val WrapObjectName = "$sbtobj" + val WrapValName = "$sbtdef" + //wrap tree in object WrapObjectName { def WrapValName = } + def augment(parser: global.syntaxAnalyzer.UnitParser, tree: Tree, tpt: Tree): Tree = + { + def emptyPkg = parser.atPos(0, 0, 0) { Ident(nme.EMPTY_PACKAGE_NAME) } + def emptyInit = DefDef( + NoMods, + nme.CONSTRUCTOR, + Nil, + List(Nil), + TypeTree(), + Block(List(Apply(Select(Super("", ""), nme.CONSTRUCTOR), Nil)), Literal(Constant(()))) + ) + + def method = DefDef(NoMods, WrapValName, Nil, Nil, tpt, tree) + def moduleBody = Template(List(gen.scalaScalaObjectConstr), emptyValDef, List(emptyInit, method)) + def moduleDef = ModuleDef(NoMods, WrapObjectName, moduleBody) + parser.makePackaging(0, emptyPkg, List(moduleDef)) + } + + def getValue[T](loader: ClassLoader): T = + { + val clazz = Class.forName(WrapObjectName + "$", true, loader) + val module = clazz.getField("MODULE$").get(null) + val accessor = module.getClass.getMethod(WrapValName) + val value = accessor.invoke(module) + value.asInstanceOf[T] + } + + + def mkUnit(s: String) = new CompilationUnit(new BatchSourceFile("", s)) + def checkError(label: String) = if(reporter.hasErrors) throw new EvalException(label) +}