#!/usr/bin/env python3
# pylint: disable=C0103,C0114,C0115,C0116,C0123,C0301,R0902,R0913,R0914,R0912,R0915,W0621
######################################################################

import argparse
import glob
import os
import re
import sys
import textwrap
# from pprint import pprint, pformat


class Node:
    """One class of the AST class hierarchy.

    A node starts out mutable: sub-classes and operands may be added. Calling
    `complete` (on the root of the hierarchy) freezes the whole tree, assigns
    ordering indices and type identifiers, and enables the "post completion"
    properties. Completion is marked by `_subClasses` turning into a tuple.
    """

    def __init__(self, name, superClass, file, lineno):
        # What the class is and where it was defined
        self._name = name
        self._superClass = superClass
        self._file = file  # File this class is defined in
        self._lineno = lineno  # Line this class is defined on
        # Hierarchy: a list while being built, a tuple after completion
        self._subClasses = []
        self._allSuperClasses = None  # Cached on first use after completion
        self._allSubClasses = None  # Cached on first use after completion
        # Numbering assigned by 'complete' (min/max are cached on first use)
        self._typeId = None  # Concrete type identifier, leaf classes only
        self._typeIdMin = None  # Lowest type identifier below this class
        self._typeIdMax = None  # Highest type identifier below this class
        self._ordIdx = None  # Ordering index of this class
        # Operands
        self._arity = -1  # Highest operand number seen; finalized on completion
        self._ops = {}  # Operand number -> (name, monad, kind)

    @property
    def name(self):
        """Name of the class, as passed to the constructor."""
        return self._name

    @property
    def superClass(self):
        """Direct super-class Node, or None for the root."""
        return self._superClass

    @property
    def isRoot(self):
        """True if this class has no super-class."""
        return self.superClass is None

    @property
    def isCompleted(self):
        """True once `complete` has run on this node."""
        return isinstance(self._subClasses, tuple)

    @property
    def file(self):
        """File this class is defined in."""
        return self._file

    @property
    def lineno(self):
        """Line this class is defined on."""
        return self._lineno

    # ------------------------------------------------------------------
    # Methods usable before completion

    def addSubClass(self, subClass):
        """Register `subClass` as a direct sub-class (only before completion)."""
        assert not self.isCompleted
        self._subClasses.append(subClass)

    def addOp(self, n, name, monad, kind):
        """Define operand `n` (1..4) of this class and raise the arity to `n` if needed."""
        assert 1 <= n <= 4
        self._ops[n] = (name, monad, kind)
        self._arity = max(self._arity, n)

    def getOp(self, n):
        """Return the (name, monad, kind) of operand `n`, or None if undefined.

        The lookup starts at this class and walks up the super-class chain,
        so the definition nearest to this class wins.
        """
        assert 1 <= n <= 4
        cls = self
        while True:
            found = cls._ops.get(n)
            if found is not None:
                return found
            if cls.isRoot:
                return None
            cls = cls.superClass

    def complete(self, typeId=0, ordIdx=0):
        """Freeze this node and everything below it, computing derived data.

        `typeId` is the next free leaf type identifier and `ordIdx` the next
        free ordering index; the updated pair is returned. No more changes to
        the hierarchy are allowed once this was called.
        """
        assert not self.isCompleted

        def sortKey(cls):
            # Classes without sub-classes first, then alphabetically
            return (bool(cls._subClasses), cls.name)

        # Converting to a tuple is what marks this node as completed
        ordered = list(self._subClasses)
        ordered.sort(key=sortKey)
        self._subClasses = tuple(ordered)

        self._ordIdx = ordIdx
        nextOrdIdx = ordIdx + 1

        # Arity is at least that of the super-class; the root has none
        self._arity = 0 if self.isRoot else max(self._arity,
                                                self._superClass.arity)

        if self.isLeaf:
            # Only leaves consume a type identifier
            self._typeId = typeId
            return typeId + 1, nextOrdIdx

        for child in self._subClasses:
            typeId, nextOrdIdx = child.complete(typeId, nextOrdIdx)
        return typeId, nextOrdIdx

    # ------------------------------------------------------------------
    # Properties and methods usable only after completion

    @property
    def subClasses(self):
        """Sorted tuple of the direct sub-classes."""
        assert self.isCompleted
        return self._subClasses

    @property
    def isLeaf(self):
        """True if this class has no sub-classes."""
        assert self.isCompleted
        return not self.subClasses

    @property
    def allSuperClasses(self):
        """Tuple of all ancestors, root first and direct super-class last."""
        assert self.isCompleted
        if self._allSuperClasses is None:
            parent = self.superClass
            self._allSuperClasses = (
                ) if parent is None else parent.allSuperClasses + (parent, )
        return self._allSuperClasses

    @property
    def allSubClasses(self):
        """Tuple of all descendants: direct sub-classes first, then theirs in turn."""
        assert self.isCompleted
        if self._allSubClasses is None:
            if self.isLeaf:
                self._allSubClasses = ()
            else:
                deeper = []
                for child in self.subClasses:
                    deeper.extend(child.allSubClasses)
                self._allSubClasses = self.subClasses + tuple(deeper)
        return self._allSubClasses

    @property
    def typeId(self):
        """Type identifier of this class; only leaf classes have one."""
        assert self.isCompleted
        assert self.isLeaf
        return self._typeId

    @property
    def typeIdMin(self):
        """Lowest type identifier among this class and its descendants."""
        assert self.isCompleted
        if self.isLeaf:
            return self.typeId
        if self._typeIdMin is None:
            self._typeIdMin = min(sub.typeIdMin for sub in self.allSubClasses)
        return self._typeIdMin

    @property
    def typeIdMax(self):
        """Highest type identifier among this class and its descendants."""
        assert self.isCompleted
        if self.isLeaf:
            return self.typeId
        if self._typeIdMax is None:
            self._typeIdMax = max(sub.typeIdMax for sub in self.allSubClasses)
        return self._typeIdMax

    @property
    def ordIdx(self):
        """Ordering index assigned by `complete`."""
        assert self.isCompleted
        return self._ordIdx

    @property
    def arity(self):
        """Highest operand number defined on this class or any ancestor (0 for the root)."""
        assert self.isCompleted
        return self._arity

    def isSubClassOf(self, other):
        """Return True if this class is `other` itself or one of its descendants."""
        assert self.isCompleted
        return self is other or self in other.allSubClasses


# Class name (without the 'Ast' prefix) -> Node. read_types adds one entry
# per class it parses and looks super-classes up here.
Nodes = {}
# None until assigned outside the code in this section; the generators in
# this file then iterate it as a sequence of completed Node objects.
SortedNodes = None
# NOTE(review): not referenced by the code in this section - purpose to be
# confirmed against the rest of the file.
DfgVertices = None

# 'Ast<Name>' -> {'newed': {basename: 1}, 'used': {basename: 1}}, recording
# which source files create/mention the class; filled in by read_refs.
ClassRefs = {}
# '<Name>.cpp' -> ordinal (counting up from 100) in order of first
# appearance; filled in by read_stages.
Stages = {}


class Cpt:
    """Expands the TREEOP directives of one source file into C++ code.

    `process` copies the input file to the output file. Every indented line
    starting with TREE is parsed by `tree_line` and replaced by a comment; at
    the position of the first TREEOP the generated match functions and
    visitors are inserted. Output is first collected in `out_lines` as a list
    of strings and callables; the callables are run while writing, so the
    generated code can take all directives of the file into account.
    """

    def __init__(self):
        self.did_out_tree = False  # True once the generated code was scheduled
        self.in_filename = ""
        self.in_linenum = 1  # Input line reported by error()
        self.out_filename = ""
        self.out_linenum = 1  # 1 + number of newlines written by process()
        self.out_lines = []  # Output chunks: strings, or callables taking self
        self.tree_skip_visit = {}  # Type names named in TREE_SKIP_VISIT
        self.treeop = {}  # Type name -> list of parsed TREEOP descriptions
        self._exec_nsyms = 0  # Number of '$' symbols numbered so far
        self._exec_syms = {}  # '$' symbol -> name of the C++ temporary

    def error(self, txt):
        """Exit the program, reporting `txt` at the current input location."""
        sys.exit("%%Error: %s:%d: %s" %
                 (self.in_filename, self.in_linenum, txt))

    def print(self, txt):
        """Append the string `txt` to the pending output."""
        self.out_lines.append(txt)

    def output_func(self, func):
        """Append a callable, run with this object when output is written."""
        self.out_lines.append(func)

    def _output_line(self):
        """Emit a '#line' directive pointing at the output file itself."""
        self.print("#line " + str(self.out_linenum + 2) + " \"" +
                   self.out_filename + "\"\n")

    def process(self, in_filename, out_filename):
        """Read `in_filename`, expand its TREE directives, write `out_filename`.

        Exits the program through `error` on a malformed directive.
        """
        self.in_filename = in_filename
        self.out_filename = out_filename
        didln = False

        # Read the file and parse into list of functions that generate output
        with open(self.in_filename) as fhi:
            for ln, line in enumerate(fhi, start=1):
                # Track every line, so that all errors name the right one
                self.in_linenum = ln
                if not didln:
                    self.print("#line " + str(ln) + " \"" + self.in_filename +
                               "\"\n")
                    didln = True
                match = re.match(r'^\s+(TREE.*)$', line)
                if match:
                    func = match.group(1)
                    self.print("//" + line)
                    self.output_func(lambda self: self._output_line())
                    self.tree_line(func)
                    # Generated code may follow, so resynchronize afterwards
                    didln = False
                elif not re.match(r'^\s*/[/\*]\s*TREE', line) and re.search(
                        r'\s+TREE', line):
                    self.error("Unknown astgen line: " + line)
                else:
                    self.print(line)

        # Put out the resultant file, if the list has a reference to a
        # function, then call that func to generate output
        with open_file(self.out_filename) as fho:
            togen = self.out_lines
            for line in togen:
                if isinstance(line, str):
                    self.out_lines = [line]
                else:
                    self.out_lines = []
                    line(self)  # lambda call
                for out in self.out_lines:
                    self.out_linenum += out.count("\n")
                    fho.write(out)

    def tree_line(self, func):
        """Parse one TREEOP.../TREE_SKIP_VISIT directive and record it.

        TREEOP("Ast<Type> {<conditions>}", "<action>") appends a description
        to `self.treeop[<Type>]`; TREE_SKIP_VISIT("<Type>") records the type
        in `self.tree_skip_visit`. Anything else is an error.
        """
        func = re.sub(r'\s*//.*$', '', func)
        func = re.sub(r'\s*;\s*$', '', func)

        # doflag "S" indicates an op specifying short-circuiting for a type.
        match = re.search(
            #       1   2                 3                  4
            r'TREEOP(1?)([ACSV]?)\s*\(\s*\"([^\"]*)\"\s*,\s*\"([^\"]*)\"\s*\)',
            func)
        match_skip = re.search(r'TREE_SKIP_VISIT\s*\(\s*\"([^\"]*)\"\s*\)',
                               func)

        if match:
            order = match.group(1)
            doflag = match.group(2)
            fromn = match.group(3)
            to = match.group(4)
            if not self.did_out_tree:
                # Generated code goes where the first TREEOP was found
                self.did_out_tree = True
                self.output_func(lambda self: self.tree_match_base())
            match = re.search(r'Ast([a-zA-Z0-9]+)\s*\{(.*)\}\s*$', fromn)
            if not match:
                self.error("Can't parse from function: " + func)
            typen = match.group(1)
            subnodes = match.group(2)
            # An unknown type must give this diagnostic, not a KeyError
            if typen not in Nodes or Nodes[typen].isRoot:
                self.error("Unknown AstNode typen: " + typen + ": in " + func)

            # First term of the match condition, selected by the flag
            guards = {
                '': "m_doNConst",
                'A': "",
                'C': "m_doCpp",
                'S': "m_doNConst",  # Not just for m_doGenerate
                'V': "m_doV",
            }
            if doflag not in guards:
                self.error("Unknown flag: " + doflag)
            mif = guards[doflag]

            # ',,' stands for a literal comma within a condition
            subnodes = re.sub(r',,', '__ESCAPEDCOMMA__', subnodes)
            for subnode in re.split(r'\s*,\s*', subnodes):
                subnode = re.sub(r'__ESCAPEDCOMMA__', ',', subnode)
                if re.match(r'^\$([a-zA-Z0-9]+)$', subnode):
                    continue  # "$lhs" is just a comment that this op has a lhs
                subnodeif = subnode
                subnodeif = re.sub(
                    r'\$([a-zA-Z0-9]+)\.cast([A-Z][A-Za-z0-9]+)$',
                    r'VN_IS(nodep->\1(),\2)', subnodeif)
                subnodeif = re.sub(r'\$([a-zA-Z0-9]+)\.([a-zA-Z0-9]+)$',
                                   r'nodep->\1()->\2()', subnodeif)
                subnodeif = self.add_nodep(subnodeif)
                if mif != "" and subnodeif != "":
                    mif += " && "
                mif += subnodeif

            exec_func = self.treeop_exec_func(to)
            exec_func = re.sub(
                r'([-()a-zA-Z0-9_>]+)->cast([A-Z][A-Za-z0-9]+)\(\)',
                r'VN_CAST(\1,\2)', exec_func)

            if typen not in self.treeop:
                self.treeop[typen] = []
            n = len(self.treeop[typen])
            typefunc = {
                'order': order,
                'comment': func,
                'match_func': "match_" + typen + "_" + str(n),
                'match_if': mif,
                'exec_func': exec_func,
                'uinfo': re.sub(r'[ \t\"\{\}]+', ' ', func),
                'uinfo_level': (0 if re.match(r'^!', to) else 7),
                'short_circuit': (doflag == 'S'),
            }
            self.treeop[typen].append(typefunc)

        elif match_skip:
            typen = match_skip.group(1)
            self.tree_skip_visit[typen] = 1
            if typen not in Nodes:
                self.error("Unknown node type: " + typen)

        else:
            self.error("Unknown astgen op: " + func)

    @staticmethod
    def add_nodep(strg):
        """Return `strg` with every '$name' replaced by 'nodep->name()'."""
        strg = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', strg)
        return strg

    def _exec_syms_recurse(self, aref):
        """Give each distinct '$' symbol in the nested list `aref` a temporary name."""
        for sym in aref:
            if isinstance(sym, list):
                self._exec_syms_recurse(sym)
            elif re.search(r'^\$.*', sym):
                if sym not in self._exec_syms:
                    self._exec_nsyms += 1
                    self._exec_syms[sym] = "arg" + str(self._exec_nsyms) + "p"

    def _exec_new_recurse(self, aref):
        """Return the C++ 'new' expression for the nested list `aref`.

        `aref[0]` is the class name; the remaining items are arguments:
        nested lists (nested 'new'), '$' symbols (replaced by their
        temporaries) or text copied through unchanged.
        """
        out = "new " + aref[0] + "(nodep->fileline()"
        for sym in aref[1:]:
            out += ", "
            if isinstance(sym, list):
                out += self._exec_new_recurse(sym)
            elif re.match(r'^\$.*', sym):
                out += self._exec_syms[sym]
            else:
                out += sym
        return out + ")"

    def treeop_exec_func(self, func):
        """Return the C++ statements implementing the TREEOP action `func`.

        `func` (after removal of a leading '!') is either a function call, in
        which '$name' becomes 'nodep->name()', an 'Ast<Type>{...}' tree to
        replace the node with, 'NEVER' (fatal if executed) or 'DONE' (nothing
        to execute). Any other form is an error.
        """
        out = ""
        func = re.sub(r'^!', '', func)

        if re.match(r'^\s*[a-zA-Z0-9]+\s*\(', func):  # Function call
            outl = re.sub(r'\$([a-zA-Z0-9]+)', r'nodep->\1()', func)
            out += outl + ";"
        elif re.match(r'^\s*Ast([a-zA-Z0-9]+)\s*\{\s*(.*)\s*\}$', func):
            aref = None
            # Recursive array with structure to form
            astack = []
            forming = ""
            argtext = func + "\000"  # EOF character
            for tok in argtext:
                if tok == "\000":
                    pass
                elif re.match(r'\s+', tok):
                    pass
                elif tok == "{":
                    # Text collected so far is the class name of a new level
                    newref = [forming]
                    if not aref:
                        aref = []
                    aref.append(newref)
                    astack.append(aref)
                    aref = newref
                    forming = ""
                elif tok == "}":
                    if forming:
                        aref.append(forming)
                    if len(astack) == 0:
                        self.error("Too many } in execution function: " + func)
                    aref = astack.pop()
                    forming = ""
                elif tok == ",":
                    if forming:
                        aref.append(forming)
                    forming = ""
                else:
                    forming += tok
            if not (aref and len(aref) == 1):
                self.error("Badly formed execution function: " + func)
            aref = aref[0]

            # Assign numbers to each $ symbol
            self._exec_syms = {}
            self._exec_nsyms = 0
            self._exec_syms_recurse(aref)

            # Unlink every referenced operand before building the new tree
            for sym in sorted(self._exec_syms.keys(),
                              key=lambda val: self._exec_syms[val]):
                argnp = self._exec_syms[sym]
                arg = self.add_nodep(sym)
                out += "AstNode* " + argnp + " = " + arg + "->unlinkFrBack();\n"

            out += "AstNode* newp = " + self._exec_new_recurse(aref) + ";\n"
            out += "nodep->replaceWith(newp);"
            out += "VL_DO_DANGLING(nodep->deleteTree(), nodep);"
        elif func == "NEVER":
            out += "nodep->v3fatalSrc(\"Executing transform that was NEVERed\");"
        elif func == "DONE":
            pass
        else:
            self.error("Unknown execution function format: " + func + "\n")
        return out

    def tree_match_base(self):
        """Emit all generated code: match functions, then visitors."""
        self.tree_match()
        self.tree_base()

    def tree_match(self):
        """Emit one 'bool match_<Type>_<n>(...)' function per recorded TREEOP."""
        self.print(
            "    // TREEOP functions, each return true if they matched & transformed\n"
        )
        for base in sorted(self.treeop.keys()):
            for typefunc in self.treeop[base]:
                self.print("    // Generated by astgen\n")
                self.print("    bool " + typefunc['match_func'] + "(Ast" +
                           base + "* nodep) {\n")
                self.print("\t// " + typefunc['comment'] + "\n")
                self.print("\tif (" + typefunc['match_if'] + ") {\n")
                self.print("\t    UINFO(" + str(typefunc['uinfo_level']) +
                           ", cvtToHex(nodep)" + " << \" " +
                           typefunc['uinfo'] + "\\n\");\n")
                self.print("\t    " + typefunc['exec_func'] + "\n")
                self.print("\t    return true;\n")
                self.print("\t}\n")
                self.print("\treturn false;\n")
                self.print("    }\n")

    def tree_base(self):
        """Emit a visitor for every node type that has applicable TREEOPs.

        A type's visitor tries the match functions recorded for each of its
        super-classes and for the type itself, in that order.
        """
        self.print("    // TREEOP visitors, call each base type's match\n")
        self.print(
            "    // Bottom class up, as more simple transforms are generally better\n"
        )
        for node in SortedNodes:
            out_for_type_sc = []
            out_for_type = []
            classes = list(node.allSuperClasses)
            classes.append(node)
            for base in classes:
                base = base.name
                if base not in self.treeop:
                    continue
                for typefunc in self.treeop[base]:
                    lines = [
                        "        if (" + typefunc['match_func'] +
                        "(nodep)) return;\n"
                    ]
                    if typefunc['short_circuit']:  # short-circuit match fn
                        out_for_type_sc.extend(lines)
                    elif typefunc['order']:  # TREEOP1's go in front of others
                        out_for_type = lines + out_for_type
                    else:  # Standard match fn
                        out_for_type.extend(lines)

            # We need to deal with two cases. For short circuited functions we
            # evaluate the LHS, then apply the short-circuit matches, then
            # evaluate the RHS and possibly THS (ternary operators may
            # short-circuit) and apply all the other matches.

            # For types without short-circuits, we just use iterateChildren, which
            # saves one comparison.
            if out_for_type_sc:  # Short-circuited types
                self.print(
                    "    // Generated by astgen with short-circuiting\n" +
                    "    void visit(Ast" + node.name +
                    "* nodep) override {\n" +
                    "      iterateAndNextNull(nodep->lhsp());\n" +
                    "".join(out_for_type_sc))
                # The remaining operands are visited and the function is
                # closed even if there are no standard matches, so the
                # generated C++ is always complete
                self.print("      iterateAndNextNull(nodep->rhsp());\n")
                if node.isSubClassOf(Nodes["NodeTriop"]):
                    self.print("      iterateAndNextNull(nodep->thsp());\n")
                self.print("".join(out_for_type) + "    }\n")
            elif out_for_type:  # Other types with something to print
                # Skipped types get a non-overriding 'visitGen' which does
                # not iterate the children
                skip = node.name in self.tree_skip_visit
                gen = "Gen" if skip else ""
                virtual = "virtual " if skip else ""
                override = "" if skip else " override"
                self.print(
                    "    // Generated by astgen\n" + "    " + virtual +
                    "void visit" + gen + "(Ast" + node.name + "* nodep)" +
                    override + " {\n" +
                    ("" if skip else "        iterateChildren(nodep);\n") +
                    ''.join(out_for_type) + "    }\n")


######################################################################
######################################################################


def partitionAndStrip(string, separator):
    """Split `string` at the first `separator` and strip each piece.

    Returns a one-shot iterator over the three pieces given by
    `str.partition` (head, separator, tail), each with surrounding whitespace
    removed. The separator and tail are empty if `separator` does not occur.
    """
    pieces = string.partition(separator)
    return map(str.strip, pieces)


def parseOpType(string):
    """Parse the type of an operand declaration.

    Accepts 'Ast<Kind>', 'Optional[Ast<Kind>]' and 'List[Ast<Kind>]' and
    returns the pair (monad, kind), where monad is '', 'Optional' or 'List'
    and kind is the class name without the 'Ast' prefix. Returns None for
    anything else, including nested wrappers.
    """
    wrapped = re.match(r'^(\w+)\[(\w+)\]$', string)
    if wrapped is None:
        # Not of the form 'Wrapper[...]', so it must be a plain class name
        if re.match(r'^Ast(\w+)$', string):
            return "", string[3:]
        return None

    monad, inner = wrapped.groups()
    if monad not in ("Optional", "List"):
        return None
    parsed = parseOpType(inner)
    # The wrapped type must itself be a plain (unwrapped) class name
    if parsed and not parsed[0]:
        return monad, parsed[1]
    return None


def read_types(filename):
    """Parse the AST class definitions and '@astgen' directives in `filename`.

    Every 'class'/'struct' line with a public super-class whose name contains
    'Ast' creates a Node, registered in `Nodes` and linked to its super-class
    (which must already be in `Nodes`). Within a class, '// @astgen op<N> :=
    <identifier> : <type>' defines an operand and '// @astgen alias op<N> :=
    <identifier>' renames an inherited one.

    Problems are printed to stderr with their line number; if there were any,
    the program exits after the whole file was read. An unknown super-class
    exits immediately, as parsing cannot continue without it.
    """
    hasErrors = False

    def error(lineno, message):
        # Report a problem, and remember to exit once the file was read
        nonlocal hasErrors
        print(filename + ":" + str(lineno) + ": %Error: " + message,
              file=sys.stderr)
        hasErrors = True

    node = None  # Class currently being parsed
    hasAstgenMembers = False  # 'ASTGEN_MEMBERS_<name>;' seen in current class

    def checkFinishedNode(node):
        # Called when a class ends: check it used its ASTGEN_MEMBERS macro
        nonlocal hasAstgenMembers
        if not node:
            return
        if not hasAstgenMembers:
            error(
                node.lineno, "'Ast" + node.name +
                "' does not contain 'ASTGEN_MEMBERS_" + node.name + ";'")
        hasAstgenMembers = False

    with open(filename) as fh:
        for (lineno, line) in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue

            match = re.search(r'^\s*(class|struct)\s*(\S+)', line)
            if match:
                classn = match.group(2)
                match = re.search(r':\s*public\s+(\S+)', line)
                supern = match.group(1) if match else ""
                if re.search(r'Ast', supern):
                    classn = re.sub(r'^Ast', '', classn)
                    supern = re.sub(r'^Ast', '', supern)
                    if not supern:
                        sys.exit("%Error: 'Ast{}' has no super-class".format(
                            classn))
                    checkFinishedNode(node)
                    superClass = Nodes.get(supern)
                    if superClass is None:
                        # Give the location instead of a bare KeyError
                        sys.exit(
                            "%Error: {}:{}: 'Ast{}' derives from unknown super-class 'Ast{}'"
                            .format(filename, lineno, classn, supern))
                    node = Node(classn, superClass, filename, lineno)
                    superClass.addSubClass(node)
                    Nodes[classn] = node
            if not node:
                continue  # Ignore everything before the first class

            if re.match(r'^\s*ASTGEN_MEMBERS_' + node.name + ';', line):
                hasAstgenMembers = True
            match = re.match(r'^\s*//\s*@astgen\s+(.*)$', line)
            if match:
                decl = re.sub(r'//.*$', '', match.group(1))
                what, sep, rest = partitionAndStrip(decl, ":=")
                what = re.sub(r'\s+', ' ', what)
                if not sep:
                    error(
                        lineno,
                        "Malformed '@astgen' directive (expecting '<keywords> := <description>'): "
                        + decl)
                elif what in ("op1", "op2", "op3", "op4"):
                    n = int(what[-1])
                    ident, sep, kind = partitionAndStrip(rest, ":")
                    if not sep or not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what +
                            "' directive (expecting '" + what +
                            " := <identifier> : <type>': " + decl)
                    else:
                        kind = parseOpType(kind)
                        if not kind:
                            error(
                                lineno, "Bad type for '@astgen " + what +
                                "' (expecting Ast*, Optional[Ast*], or List[Ast*]):"
                                + decl)
                        elif node.getOp(n) is not None:
                            error(
                                lineno, "Already defined " + what + " for " +
                                node.name)
                        else:
                            node.addOp(n, ident, *kind)
                elif what in ("alias op1", "alias op2", "alias op3",
                              "alias op4"):
                    n = int(what[-1])
                    ident = rest.strip()
                    if not re.match(r'^\w+$', ident):
                        error(
                            lineno, "Malformed '@astgen " + what +
                            "' directive (expecting '" + what +
                            " := <identifier>': " + decl)
                    else:
                        op = node.getOp(n)
                        if op is None:
                            error(lineno,
                                  "Aliased op" + str(n) + " is not defined")
                        else:
                            # Same monad and kind, under the new name
                            node.addOp(n, ident, *op[1:])
            else:
                # Outside of comments, code must not refer to op<N> directly
                line = re.sub(r'//.*$', '', line)
                if re.match(r'.*[Oo]p[1-9].*', line):
                    error(lineno,
                          "Use generated accessors to access op<N> operands")

        checkFinishedNode(node)
    if hasErrors:
        sys.exit("%Error: Stopping due to errors reported above")


def read_stages(filename):
    """Record in `Stages` the order in which stages are named in `filename`.

    Each '<Name>::' preceded by whitespace (the first one on a line, comments
    excluded) names the stage '<Name>.cpp'. A stage gets the next ordinal,
    counting up from 100, the first time it is seen.
    """
    nextOrdinal = 100
    with open(filename) as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            if re.match(r'^\s*$', code):
                continue  # Nothing but whitespace and comment
            found = re.search(r'\s([A-Za-z0-9]+)::', code)
            if found is None:
                continue
            stage = found.group(1) + ".cpp"
            if stage in Stages:
                continue  # Only the first appearance counts
            Stages[stage] = nextOrdinal
            nextOrdinal += 1


def read_refs(filename):
    """Record in `ClassRefs` which Ast classes `filename` creates and mentions.

    Comments are ignored. 'new Ast<Name>' marks the class as 'newed' by this
    file; any 'Ast<Name>' word, and any VN_IS/VN_AS/VN_CAST(..., <Name>),
    marks it as 'used'. Files are recorded by name without directory.
    """
    basename = re.sub(r'.*/', '', filename)

    def record(ref, what):
        # Note that this file refers to class 'ref' in the way 'what'
        entry = ClassRefs.setdefault(ref, {'newed': {}, 'used': {}})
        entry[what][basename] = 1

    with open(filename) as fh:
        for rawLine in fh:
            code = re.sub(r'//.*$', '', rawLine)
            for found in re.finditer(r'\bnew\s*(Ast[A-Za-z0-9_]+)', code):
                record(found.group(1), 'newed')
            for found in re.finditer(r'\b(Ast[A-Za-z0-9_]+)', code):
                record(found.group(1), 'used')
            for found in re.finditer(
                    r'(VN_IS|VN_AS|VN_CAST)\([^.]+, ([A-Za-z0-9_]+)', code):
                record("Ast" + found.group(2), 'used')


def open_file(filename):
    """Open `filename` for writing and start it with a 'generated' banner.

    Files ending in '.txt' get a plain banner, all others one which also
    carries an editor mode line for C++. Returns the open file object, which
    the caller is responsible for closing.
    """
    fh = open(filename, "w")
    isText = re.search(r'\.txt$', filename)
    banner = ("// Generated by astgen\n" if isText else
              '// Generated by astgen // -*- mode: C++; c-file-style: "cc-mode" -*-'
              + "\n")
    fh.write(banner)
    return fh


# ---------------------------------------------------------------------


def write_report(filename):
    """Write a plain text overview of the stages and AST classes to `filename`.

    Lists the stages by their ordinal in `Stages`, then for every class in
    `SortedNodes` its arity, super-classes (except the root), all
    sub-classes and, if `ClassRefs` has an entry for it, the files that
    create and mention it.
    """

    def stageRank(stage):
        # Files which are not a known stage sort before all stages
        return Stages.get(stage, -1)

    with open_file(filename) as fh:

        fh.write(
            "Processing stages (approximate, based on order in Verilator.cpp):\n"
        )
        for classn in sorted(Stages, key=Stages.get):
            fh.write("  " + classn + "\n")

        fh.write("\nClasses:\n")
        for node in SortedNodes:
            fh.write("  class Ast%-17s\n" % node.name)
            fh.write("    arity:  {}\n".format(node.arity))

            parents = "".join("Ast%-12s " % superClass.name
                              for superClass in node.allSuperClasses
                              if not superClass.isRoot)
            fh.write("    parent: " + parents + "\n")

            childs = "".join("Ast%-12s " % subClass.name
                             for subClass in node.allSubClasses)
            fh.write("    childs:  " + childs + "\n")

            refs = ClassRefs.get("Ast" + node.name)
            if refs is not None:
                for label, what in (("    newed:  ", 'newed'),
                                    ("    used:   ", 'used')):
                    stages = sorted(refs[what].keys(), key=stageRank)
                    fh.write(label + "".join(stage + "  "
                                             for stage in stages) + "\n")
            fh.write("\n")


def write_classes(filename):
    """Write forward declarations of all AST classes to `filename`.

    Each declaration is followed by a comment listing the super-classes of
    the class, root first.
    """
    with open_file(filename) as fh:
        fh.write("class AstNode;\n")
        for node in SortedNodes:
            declaration = "class Ast%-17s // " % (node.name + ";")
            ancestry = "".join("Ast%-12s " % superClass.name
                               for superClass in node.allSuperClasses)
            fh.write(declaration + ancestry + "\n")


def write_visitor_decls(filename):
    """Write the declaration of a virtual 'visit' method for every non-root class."""
    with open_file(filename) as fh:
        for node in SortedNodes:
            if node.isRoot:
                continue  # No declaration is generated for the root
            fh.write("virtual void visit(Ast" + node.name + "*);\n")


def write_visitor_defns(filename):
    """Write the default 'VNVisitor::visit' definitions to `filename`.

    The default visit of every non-root class forwards to the visit of its
    direct super-class.
    """
    with open_file(filename) as fh:
        for node in SortedNodes:
            parent = node.superClass
            if parent is None:
                continue  # The root has nothing to forward to
            fh.write("void VNVisitor::visit(Ast" + node.name +
                     "* nodep) { visit(static_cast<Ast" + parent.name +
                     "*>(nodep)); }\n")


def write_impl(filename):
    """Write an 'AstNode::privateTypeTest' specialization for every class.

    The test is always true for the root, an equality check against the
    'at<Name>' value for leaf classes, and a range check between
    'first<Name>' and 'last<Name>' for all other classes.
    """
    with open_file(filename) as fh:
        fh.write("\n")
        fh.write("// For internal use. They assume argument is not nullptr.\n")
        for node in SortedNodes:
            if node.isRoot:
                body = "return true; "
            elif node.isLeaf:
                body = ("return " + "nodep->type() == VNType::at" +
                        node.name + "; ")
            else:
                body = (
                    "return " +
                    "static_cast<int>(nodep->type()) >= static_cast<int>(VNType::first"
                    + node.name + ") && " +
                    "static_cast<int>(nodep->type()) <= static_cast<int>(VNType::last"
                    + node.name + "); ")
            fh.write("template<> inline bool AstNode::privateTypeTest<Ast" +
                     node.name + ">(const AstNode* nodep) { " + body + "}\n")


def write_types(filename):
    """Write the type enumerations and the 'ascii' name table to `filename`.

    Emits the enum 'en' with an 'at<Name>' value per leaf class, the enum
    'bounds' with the 'first<Name>'/'last<Name>' type identifier range of
    every other class, and the method 'ascii' returning the upper-cased
    name of a leaf type.
    """
    with open_file(filename) as fh:
        fh.write("    enum en : uint16_t {\n")
        leaves = sorted((node for node in SortedNodes if node.isLeaf),
                        key=lambda node: node.typeId)
        for leaf in leaves:
            fh.write("        at" + leaf.name + " = " + str(leaf.typeId) +
                     ",\n")
        fh.write("        _ENUM_END = " + str(Nodes["Node"].typeIdMax + 1) +
                 "\n")
        fh.write("    };\n")

        fh.write("    enum bounds : uint16_t {\n")
        inner = sorted((node for node in SortedNodes if not node.isLeaf),
                       key=lambda node: node.typeIdMin)
        for node in inner:
            fh.write("        first" + node.name + " = " +
                     str(node.typeIdMin) + ",\n")
            fh.write("        last" + node.name + " = " + str(node.typeIdMax) +
                     ",\n")
        fh.write("        _BOUNDS_END\n")
        fh.write("    };\n")

        # Names in the same order as the values of 'en'
        fh.write("    const char* ascii() const {\n")
        fh.write("        static const char* const names[_ENUM_END + 1] = {\n")
        for leaf in leaves:
            fh.write("            \"" + leaf.name.upper() + "\",\n")
        fh.write("            \"_ENUM_END\"\n")
        fh.write("        };\n")
        fh.write("        return names[m_e];\n")
        fh.write("    }\n")


def write_yystype(filename):
    """Write a pointer member declaration for every AST class to `filename`.

    The member is named after the class, with the first letter in lower case
    and a 'p' appended.
    """
    with open_file(filename) as fh:
        for node in SortedNodes:
            member = node.name[0].lower() + node.name[1:]
            fh.write("Ast{t}* {m}p;\n".format(t=node.name, m=member))


def write_macros(filename):
    """Write the ASTGEN_MEMBERS_* and ASTGEN_SUPER_* macros to 'filename'.

    For every node class an 'ASTGEN_MEMBERS_<Name>' macro is defined holding:
      - typed cloneTreeNull/cloneTree/clonep/addNext wrappers
      - for leaf classes, the accept() and clone() overrides
      - for each operand slot 1..4 where node.getOp(n) yields an operand, a
        typed getter plus a mutator ('add<Name>' for "List" operands,
        otherwise a setter sharing the getter's name)
    The macro body is closed by a 'static_assert' that swallows the semicolon
    written after the macro use. Leaf classes additionally get an
    'ASTGEN_SUPER_<Name>(...)' macro that forwards to the super class with
    the matching VNType value as first argument.
    """
    with open_file(filename) as fh:

        def emitBlock(pattern, **fmt):
            # Normalize the template to one level of indentation, substitute
            # the fields, then turn each line end into a macro continuation
            text = textwrap.indent(textwrap.dedent(pattern), "    ")
            fh.write(text.format(**fmt).replace("\n", " \\\n"))

        for node in SortedNodes:
            fh.write("#define ASTGEN_MEMBERS_" + node.name + " \\\n")
            emitBlock('''\
                static Ast{t}* cloneTreeNull(Ast{t}* nodep, bool cloneNextLink) {{
                    return nodep ? nodep->cloneTree(cloneNextLink) : nullptr;
                }}
                Ast{t}* cloneTree(bool cloneNext) {{
                    return static_cast<Ast{t}*>(AstNode::cloneTree(cloneNext));
                }}
                Ast{t}* clonep() const {{ return static_cast<Ast{t}*>(AstNode::clonep()); }}
                Ast{t}* addNext(Ast{t}* nodep) {{ return static_cast<Ast{t}*>(AstNode::addNext(this, nodep)); }}
                ''',
                      t=node.name)

            # Only leaf classes get the visitor entry point and clone()
            if node.isLeaf:
                emitBlock('''\
                    void accept(VNVisitor& v) override {{ v.visit(this); }}
                    AstNode* clone() override {{ return new Ast{t}(*this); }}
                    ''',
                          t=node.name)

            for n in range(1, 5):
                op = node.getOp(n)
                if not op:
                    continue
                name, monad, kind = op

                # Operands of the generic 'Node' kind need no downcast
                if kind == "Node":
                    retrieve = "op{n}p()".format(n=n)
                else:
                    retrieve = "VN_AS(op{n}p(), {kind})".format(n=n, kind=kind)

                # The operand flavours differ only in the mutator's name and
                # in the AstNode method the mutator forwards to
                if monad == "List":
                    mutator = "add" + name[0].upper() + name[1:]
                    forward = "addNOp"
                elif monad == "Optional":
                    mutator = name
                    forward = "setNOp"
                else:
                    mutator = name
                    forward = "setOp"

                emitBlock('''\
                    Ast{kind}* {name}() const {{ return {retrieve}; }}
                    void {mutator}(Ast{kind}* nodep) {{ {forward}{n}p(reinterpret_cast<AstNode*>(nodep)); }}
                    ''',
                          kind=kind,
                          name=name,
                          mutator=mutator,
                          forward=forward,
                          n=n,
                          retrieve=retrieve)

            # Swallowing the semicolon
            fh.write('    static_assert(true, "")\n')

            # Only care about leaf classes for the rest
            if node.isLeaf:
                fh.write("#define ASTGEN_SUPER_" + node.name + "(...) Ast" +
                         node.superClass.name + "(VNType::at" + node.name +
                         ", __VA_ARGS__)\n")
            fh.write("\n")


def write_op_checks(filename):
    """Write the per-type operand checks as C++ 'case' blocks to 'filename'.

    Each leaf class gets a 'case VNType::at<Name>: { ... break; }' block.
    Inside it, every operand slot 1..4 is preceded by a '// Checking op<n>p'
    comment and checked according to node.getOp(n):
      - no operand: the slot must be null
      - plain operand (falsy monad): must be non-null and have no nextp()
      - "Optional": if present, must have no nextp()
      - "List": every element is checked and the head/tail links
        (m_headtailp) are verified
    Present operands are recursed into with checkTreeIter(). Any other monad
    value aborts the generator via sys.exit().
    """
    with open_file(filename) as fh:

        def emit(prefix, pattern, **fmt):
            # Dedent the template, prefix its non-blank lines, then substitute
            fh.write(
                textwrap.indent(textwrap.dedent(pattern),
                                prefix).format(**fmt))

        # The header/footer templates carry their own relative indentation;
        # the statements in the case body are pushed in by one level
        flush = ""
        body = "    "

        for node in SortedNodes:
            if not node.isLeaf:
                continue

            emit(flush,
                 '''\
                 case VNType::at{nodeName}: {{
                     const Ast{nodeName}* const currp = static_cast<const Ast{nodeName}*>(this);
                 ''',
                 nodeName=node.name)

            for n in (1, 2, 3, 4):
                op = node.getOp(n)
                emit(body, "// Checking op{n}p\n", n=n)

                if not op:
                    # Slot not used by this class
                    emit(body,
                         '''\
                         UASSERT_OBJ(!currp->op{n}p(), currp, "Ast{nodeName} does not use op{n}p()");
                         ''',
                         n=n,
                         nodeName=node.name)
                    continue

                name, monad, kind = op
                if not monad:
                    emit(body,
                         '''\
                         UASSERT_OBJ(currp->{opName}(), currp, "Ast{nodeName} must have non nullptr {opName}()");
                         UASSERT_OBJ(!currp->{opName}()->nextp(), currp, "Ast{nodeName}::{opName}() cannot have a non nullptr nextp()");
                         currp->{opName}()->checkTreeIter(currp);
                         ''',
                         nodeName=node.name,
                         opName=name)
                elif monad == "Optional":
                    emit(body,
                         '''\
                         if (Ast{kind}* const opp = currp->{opName}()) {{
                             UASSERT_OBJ(!currp->{opName}()->nextp(), currp, "Ast{nodeName}::{opName}() cannot have a non nullptr nextp()");
                             opp->checkTreeIter(currp);
                         }}
                         ''',
                         nodeName=node.name,
                         opName=name,
                         kind=kind)
                elif monad == "List":
                    # Typed lists need a downcast when stepping to the next
                    # element
                    if kind == "Node":
                        advance = "opp->nextp()"
                    else:
                        advance = "VN_AS(opp->nextp(), {kind})".format(
                            kind=kind)
                    emit(body,
                         '''\
                         if (const Ast{kind}* const headp = currp->{opName}()) {{
                             const AstNode* backp = currp;
                             const Ast{kind}* tailp = headp;
                             const Ast{kind}* opp = headp;
                             do {{
                                 opp->checkTreeIter(backp);
                                 UASSERT_OBJ(opp == headp || !opp->nextp() || !opp->m_headtailp, opp, "Headtailp should be null in middle of lists");
                                 backp = tailp = opp;
                                 opp = {next};
                             }} while (opp);
                             UASSERT_OBJ(headp->m_headtailp == tailp, headp, "Tail in headtailp is inconsistent");
                             UASSERT_OBJ(tailp->m_headtailp == headp, tailp, "Head in headtailp is inconsistent");
                         }}
                         ''',
                         opName=name,
                         kind=kind,
                         next=advance)
                else:
                    sys.exit("Unknown operand type")

            emit(flush, '''\
                     break;
                 }}
                 ''')


def write_dfg_vertex_classes(filename):
    """Write the Dfg vertex class definitions and the Ast/Dfg type maps.

    For each entry of DfgVertices a 'Dfg<Name>' class deriving from
    'DfgVertexWithArity<arity>' is written, with a getter and a setter per
    operand (named after the operand) and, when there are operands, a
    'srcName' override returning the operand name for a source index.
    After the classes, the 'DfgForAst' and 'AstForDfg' alias templates are
    written together with one explicit specialization per vertex type.
    """
    with open_file(filename) as fh:

        def emitTypeMap(alias, param, keyPrefix, valuePrefix):
            # '<alias>Impl' is specialized per vertex type to map
            # '<keyPrefix><Name>' to '<valuePrefix><Name>'
            fh.write("\n\ntemplate<typename " + param + ">\n")
            fh.write("struct " + alias + "Impl;\n\n")
            for vtx in DfgVertices:
                fh.write("template <>\n")
                fh.write("struct " + alias + "Impl<" + keyPrefix + vtx.name +
                         "> {\n")
                fh.write("  using type = " + valuePrefix + vtx.name + ";\n")
                fh.write("};\n")
            fh.write("\ntemplate<typename " + param + ">\n")
            fh.write("using " + alias + " = typename " + alias + "Impl<" +
                     param + ">::type;\n")

        fh.write("\n")
        for node in DfgVertices:
            arity = str(node.arity)
            fh.write("class Dfg" + node.name +
                     " final : public DfgVertexWithArity<" + arity + "> {\n")
            fh.write("    friend class DfgVertex;\n")
            fh.write("    friend class DfgVisitor;\n")
            fh.write("    void accept(DfgVisitor& visitor) override;\n")
            fh.write(
                "    static constexpr DfgType dfgType() { return DfgType::at" +
                node.name + "; };\n")
            fh.write("public:\n")
            fh.write(
                "    Dfg" + node.name +
                "(DfgGraph& dfg, FileLine* flp, AstNodeDType* dtypep) : DfgVertexWithArity<"
                + arity + ">{dfg, flp, dtypep, dfgType()} {}\n")

            # Accessors, named after the operands; operand n is source n - 1
            operandNames = tuple(
                node.getOp(n)[0] for n in range(1, node.arity + 1))
            assert not operandNames or len(operandNames) == node.arity
            for idx, opName in enumerate(operandNames):
                fh.write("    DfgVertex* " + opName +
                         "() const { return source<" + str(idx) + ">(); }\n")
            for idx, opName in enumerate(operandNames):
                fh.write("    void " + opName +
                         "(DfgVertex* vtxp) { relinkSource<" + str(idx) +
                         ">(vtxp); }\n")

            if operandNames:
                quoted = ", ".join('"' + opName + '"'
                                   for opName in operandNames)
                fh.write(
                    "    const string srcName(size_t idx) const override {\n")
                fh.write("        static const char* names[" + arity +
                         "] = { " + quoted + " };\n")
                fh.write("        return names[idx];\n")
                fh.write("    }\n")
            fh.write("};\n")
            fh.write("\n")
        fh.write("\n")

        emitTypeMap("DfgForAst", "Node", "Ast", "Dfg")
        emitTypeMap("AstForDfg", "Vertex", "Dfg", "Ast")


def write_dfg_visitor_decls(filename):
    """Write the DfgVisitor 'visit' declarations to 'filename'.

    The overload taking the generic DfgVertex is declared pure virtual; one
    further virtual overload is declared per entry of DfgVertices.
    """
    with open_file(filename) as fh:
        fh.write("\n")
        fh.write("virtual void visit(DfgVertex*) = 0;\n")
        for node in DfgVertices:
            fh.write("virtual void visit(Dfg" + node.name + "*);\n")


def write_dfg_definitions(filename):
    """Write the out-of-line Dfg 'accept' and 'visit' definitions.

    First all 'Dfg<Name>::accept' definitions are written (each dispatching
    to visitor.visit(this)), then all 'DfgVisitor::visit(Dfg<Name>*)'
    definitions, which forward to the overload taking a DfgVertex pointer.
    """
    with open_file(filename) as fh:
        fh.write("\n")
        for node in DfgVertices:
            fh.write("void Dfg" + node.name +
                     "::accept(DfgVisitor& visitor) { visitor.visit(this); }\n")
        fh.write("\n")
        for node in DfgVertices:
            fh.write(
                "void DfgVisitor::visit(Dfg" + node.name +
                "* vtxp) { visit(static_cast<DfgVertex*>(vtxp)); }\n")


def write_dfg_ast_to_dfg(filename):
    """Write the 'visit(Ast<Name>*)' overrides converting Ast nodes to Dfg.

    One override is written per entry of DfgVertices. The generated code
    creates the 'Dfg<Name>' vertex, then for each operand 1..arity iterates
    the child and links the vertex found in the child's user1 as source
    number (operand - 1), and finally stores the new vertex in the node's
    user1p.
    """
    with open_file(filename) as fh:
        fh.write("\n")
        for node in DfgVertices:
            fh.write("void visit(Ast" + node.name + "* nodep) override {\n")
            fh.write(
                '    UASSERT_OBJ(!nodep->user1p(), nodep, "Already has Dfg vertex");\n'
            )
            fh.write("    if (unhandled(nodep)) return;\n")
            fh.write("    Dfg" + node.name + "* const vtxp = makeVertex<Dfg" +
                     node.name + ">(nodep, *m_dfgp);\n")
            fh.write("    if (!vtxp) {\n")
            fh.write("        m_foundUnhandled = true;\n")
            fh.write("        ++m_ctx.m_nonRepNode;\n")
            fh.write("        return;\n")
            fh.write("    }\n\n")
            fh.write("    m_uncommittedVertices.push_back(vtxp);\n")
            for srcIdx in range(node.arity):
                # Operand numbers are 1 based, source indices are 0 based
                opNum = str(srcIdx + 1)
                child = "nodep->op" + opNum + "p()"
                fh.write("    iterate(" + child + ");\n")
                fh.write("    if (m_foundUnhandled) return;\n")
                fh.write("    UASSERT_OBJ(" + child +
                         '->user1p(), nodep, "Child ' + opNum +
                         ' missing Dfg vertex");\n')
                fh.write("    vtxp->relinkSource<" + str(srcIdx) + ">(" +
                         child + "->user1u().to<DfgVertex*>());\n\n")
            fh.write("    nodep->user1p(vtxp);\n")
            fh.write("}\n")


def write_dfg_dfg_to_ast(filename):
    """Write the 'visit(Dfg<Name>*)' overrides converting Dfg back to Ast.

    One override is written per entry of DfgVertices. The generated code
    converts each source of the vertex via convertSource() into 'op<n>p' and
    passes those, after the vertex itself, to makeNode<Ast<Name>>(), whose
    result is assigned to m_resultp.
    """
    with open_file(filename) as fh:
        fh.write("\n")
        for node in DfgVertices:
            # Operand numbers are 1 based, source indices are 0 based
            opNums = range(1, node.arity + 1)
            fh.write("void visit(Dfg" + node.name + "* vtxp) override {\n")
            for j in opNums:
                fh.write(
                    "    AstNodeMath* const op%dp = convertSource(vtxp->source<%d>());\n"
                    % (j, j - 1))
            operands = "".join(", op%dp" % j for j in opNums)
            fh.write("    m_resultp = makeNode<Ast" + node.name + ">(vtxp" +
                     operands + ");\n")
            fh.write("}\n")


######################################################################
# main

# Command line: option names must be given in full (no abbreviations), and
# the epilog is printed as written (RawDescriptionHelpFormatter).
parser = argparse.ArgumentParser(
    description="""Generate V3Ast headers to reduce C++ code duplication.""",
    epilog=
    """Copyright 2002-2022 by Wilson Snyder. This program is free software; you
can redistribute it and/or modify it under the terms of either the GNU
Lesser General Public License Version 3 or the Perl Artistic License
Version 2.0.

SPDX-License-Identifier: LGPL-3.0-only OR Artistic-2.0""",
    formatter_class=argparse.RawDescriptionHelpFormatter,
    allow_abbrev=False)

# Options taking a value
parser.add_argument("-I", help="source code include directory")
parser.add_argument("--astdef",
                    help="add AST definition file (relative to -I)",
                    action="append")

# Flags
parser.add_argument("--classes",
                    help="makes class declaration files",
                    action="store_true")
parser.add_argument("--debug", help="enable debug", action="store_true")

# Positional arguments
parser.add_argument("infiles", help="list of input .cpp filenames", nargs="*")

Args = parser.parse_args()

# The root AstNode type is standalone, so it is registered directly rather
# than being parsed from the sources.
Nodes["Node"] = Node("Node", None, "AstNode", 1)

# Read Ast node definitions (file names are relative to the -I directory)
for filename in Args.astdef:
    read_types(os.path.join(Args.I, filename))

# Compute derived properties over the whole AstNode hierarchy
Nodes["Node"].complete()

# All node classes, ordered by class name
SortedNodes = tuple(Nodes[_] for _ in sorted(Nodes.keys()))

# Check all leaves are not AstNode* and non-leaves are AstNode*
for node in SortedNodes:
    if node.name.startswith("Node"):
        if node.isLeaf:
            sys.exit(
                "%Error: Final AstNode subclasses must not be named AstNode*: Ast"
                + node.name)
    elif not node.isLeaf:
        sys.exit(
            "%Error: Non-final AstNode subclasses must be named AstNode*: Ast"
            + node.name)

# Dfg vertices are generated for the leaf classes that are a subclass of at
# least one of these bases
DfgBases = tuple(Nodes[_] for _ in ("NodeUniop", "NodeBiop", "NodeTriop"))
DfgVertices = tuple(
    filter(lambda _: _.isLeaf and any(map(_.isSubClassOf, DfgBases)),
           SortedNodes))

# Check ordering of node definitions, separately for each file that defines
# node classes
files = tuple(sorted({_.file for _ in SortedNodes}))

hasOrderingError = False
for file in files:
    nodes = tuple(_ for _ in SortedNodes if _.file == file)
    # Expected: non-leaf classes before leaf classes, each group by ordIdx.
    # Actual: the order of the definitions in the file.
    expectOrder = tuple(sorted(nodes, key=lambda _: (_.isLeaf, _.ordIdx)))
    actualOrder = tuple(sorted(nodes, key=lambda _: _.lineno))
    # Map each node to its predecessor in the given order (None for the
    # first); a node is misplaced when the two predecessors differ
    expect = dict(zip(expectOrder, (None, ) + expectOrder[:-1]))
    actual = dict(zip(actualOrder, (None, ) + actualOrder[:-1]))
    for node in nodes:
        if expect[node] != actual[node]:
            hasOrderingError = True
            pred = expect[node]
            print(
                "{}:{}: %Error: Definition of 'Ast{}' is out of order. Should be {}."
                .format(
                    file, node.lineno, node.name,
                    "right after 'Ast" + pred.name +
                    "'" if pred else "first in file"),
                file=sys.stderr)

# Report every misplaced definition before giving up
if hasOrderingError:
    sys.exit("%Error: Stopping due to out of order definitions listed above")

read_stages(Args.I + "/Verilator.cpp")

# Scan the .y, .h and .cpp files of the include directory (in that order)
source_files = [
    path for suffix in ("y", "h", "cpp")
    for path in glob.glob(Args.I + "/*." + suffix)
]
for filename in source_files:
    read_refs(filename)

if Args.classes:
    # Generated files, paired with their writer, in emission order
    for _writer, _output in (
        (write_report, "V3Ast__gen_report.txt"),
        (write_classes, "V3Ast__gen_classes.h"),
        (write_visitor_decls, "V3Ast__gen_visitor_decls.h"),
        (write_visitor_defns, "V3Ast__gen_visitor_defns.h"),
        (write_impl, "V3Ast__gen_impl.h"),
        (write_types, "V3Ast__gen_types.h"),
        (write_yystype, "V3Ast__gen_yystype.h"),
        (write_macros, "V3Ast__gen_macros.h"),
        (write_op_checks, "V3Ast__gen_op_checks.h"),
        (write_dfg_vertex_classes, "V3Dfg__gen_vertex_classes.h"),
        (write_dfg_visitor_decls, "V3Dfg__gen_visitor_decls.h"),
        (write_dfg_definitions, "V3Dfg__gen_definitions.h"),
        (write_dfg_ast_to_dfg, "V3Dfg__gen_ast_to_dfg.h"),
        (write_dfg_dfg_to_ast, "V3Dfg__gen_dfg_to_ast.h"),
    ):
        _writer(_output)

# Each positional argument names a .cpp file. It is read from the -I
# directory and processed by Cpt into '<name>__gen.cpp'.
for cpt in Args.infiles:
    # The dot is escaped so that only a literal '.cpp' suffix is accepted and
    # stripped. An unescaped '.' matches any character, which would let an
    # argument such as 'V3Constcpp' through and strip one character too many.
    if not re.search(r'\.cpp$', cpt):
        sys.exit("%Error: Expected argument to be .cpp file: " + cpt)
    cpt = re.sub(r'\.cpp$', '', cpt)
    Cpt().process(in_filename=Args.I + "/" + cpt + ".cpp",
                  out_filename=cpt + "__gen.cpp")

######################################################################
# Local Variables:
# compile-command: "cd obj_dbg && ../astgen -I.. V3Const.cpp"
# End:
