Improve AstNode::foreach (also exists and forall)

Speed improvements:
- Use a direct, recursion-free implementation
- Improve pre-fetching

Functionality:
- Support remove/replace of currently iterated node
This commit is contained in:
Geza Lore 2022-07-31 18:13:55 +01:00
parent f91793e931
commit 152a6cd886
2 changed files with 230 additions and 122 deletions

View File

@ -1334,25 +1334,27 @@ inline std::ostream& operator<<(std::ostream& os, const VNRelinker& rhs) {
return os;
}
//######################################################################
// Callback base class to determine if node matches some formula
// ######################################################################
// Callback base class to determine if node matches some formula
class VNodeMatcher VL_NOT_FINAL {
public:
virtual bool nodeMatch(const AstNode* nodep) const { return true; }
};
//######################################################################
// AstNode -- Base type of all Ast types
// ######################################################################
// AstNode -- Base type of all Ast types
// Prefetch a node: issues read prefetches for the addresses of its 'm_nextp' and 'm_type'
// members. 'nodep' must be non-null, as it is dereferenced to form those addresses. The
// argument is expanded twice, so pass an expression without side effects.
#define ASTNODE_PREFETCH_NON_NULL(nodep) \
do { \
VL_PREFETCH_RD(&((nodep)->m_nextp)); \
VL_PREFETCH_RD(&((nodep)->m_type)); \
} while (false)
// The if() makes it faster, even though prefetch won't fault on null pointers
#define ASTNODE_PREFETCH(nodep) \
do { \
if (nodep) { \
VL_PREFETCH_RD(&((nodep)->m_nextp)); \
VL_PREFETCH_RD(&((nodep)->m_type)); \
} \
if (nodep) ASTNODE_PREFETCH_NON_NULL(nodep); \
} while (false)
class AstNode VL_NOT_FINAL {
@ -1859,12 +1861,6 @@ private:
// Note: specializations for particular node types are provided by 'astgen'
template <typename T> inline static bool privateTypeTest(const AstNode* nodep);
// For internal use only.
// Note: specializations for particular node types are provided below
template <typename T_Node> inline static bool privateMayBeUnder(const AstNode* nodep) {
return true;
}
// For internal use only.
template <typename TargetType, typename DeclType> constexpr static bool uselessCast() {
using NonRef = typename std::remove_reference<DeclType>::type;
@ -1923,102 +1919,39 @@ public:
// Predicate that returns true if the given 'nodep' might have a descendant of type 'T_Node'.
// This is conservative and is used to speed up traversals.
template <typename T_Node> inline static bool mayBeUnder(const AstNode* nodep) {
// Note: specializations for particular node types are provided below
template <typename T_Node> static bool mayBeUnder(const AstNode* nodep) {
static_assert(!std::is_const<T_Node>::value,
"Type parameter 'T_Node' should not be const qualified");
static_assert(std::is_base_of<AstNode, T_Node>::value,
"Type parameter 'T_Node' must be a subtype of AstNode");
return privateMayBeUnder<T_Node>(nodep);
return true;
}
// Predicate that is true for node subtypes 'T_Node' that do not have any children
// This is conservative and is used to speed up traversals.
// Note: specializations for particular node types are provided below
template <typename T_Node> static constexpr bool isLeaf() {
static_assert(!std::is_const<T_Node>::value,
"Type parameter 'T_Node' should not be const qualified");
static_assert(std::is_base_of<AstNode, T_Node>::value,
"Type parameter 'T_Node' must be a subtype of AstNode");
return false;
}
private:
template <typename T_Arg, bool VisitNext>
static void foreachImpl(
// Using std::conditional for const correctness in the public 'foreach' functions
typename std::conditional<std::is_const<T_Arg>::value, const AstNode*, AstNode*>::type
nodep,
std::function<void(T_Arg*)> f) {
// Using std::conditional for const correctness in the public 'foreach' functions
template <typename T_Arg>
using ConstCorrectAstNode =
typename std::conditional<std::is_const<T_Arg>::value, const AstNode, AstNode>::type;
// Note: Using a loop to iterate the nextp() chain, instead of tail recursion, because
// debug builds don't eliminate tail calls, causing stack overflow on long lists of nodes.
do {
// Prefetch children and next
ASTNODE_PREFETCH(nodep->op1p());
ASTNODE_PREFETCH(nodep->op2p());
ASTNODE_PREFETCH(nodep->op3p());
ASTNODE_PREFETCH(nodep->op4p());
if VL_CONSTEXPR_CXX17 (VisitNext) ASTNODE_PREFETCH(nodep->nextp());
template <typename T_Arg>
inline static void foreachImpl(ConstCorrectAstNode<T_Arg>* nodep,
const std::function<void(T_Arg*)>& f, bool visitNext);
// Apply function in pre-order
if (privateTypeTest<typename std::remove_const<T_Arg>::type>(nodep)) {
f(static_cast<T_Arg*>(nodep));
}
// Traverse children (including their 'nextp()' chains), unless futile
if (mayBeUnder<typename std::remove_const<T_Arg>::type>(nodep)) {
if (AstNode* const op1p = nodep->op1p()) foreachImpl<T_Arg, true>(op1p, f);
if (AstNode* const op2p = nodep->op2p()) foreachImpl<T_Arg, true>(op2p, f);
if (AstNode* const op3p = nodep->op3p()) foreachImpl<T_Arg, true>(op3p, f);
if (AstNode* const op4p = nodep->op4p()) foreachImpl<T_Arg, true>(op4p, f);
}
// Traverse 'nextp()' chain if requested
if VL_CONSTEXPR_CXX17 (VisitNext) {
nodep = nodep->nextp();
} else {
break;
}
} while (nodep);
}
template <typename T_Arg, bool Default, bool VisitNext>
static bool predicateImpl(
// Using std::conditional for const correctness in the public 'foreach' functions
typename std::conditional<std::is_const<T_Arg>::value, const AstNode*, AstNode*>::type
nodep,
std::function<bool(T_Arg*)> p) {
// Note: Using a loop to iterate the nextp() chain, instead of tail recursion, because
// debug builds don't eliminate tail calls, causing stack overflow on long lists of nodes.
do {
// Prefetch children and next
ASTNODE_PREFETCH(nodep->op1p());
ASTNODE_PREFETCH(nodep->op2p());
ASTNODE_PREFETCH(nodep->op3p());
ASTNODE_PREFETCH(nodep->op4p());
if VL_CONSTEXPR_CXX17 (VisitNext) ASTNODE_PREFETCH(nodep->nextp());
// Apply function in pre-order
if (privateTypeTest<typename std::remove_const<T_Arg>::type>(nodep)) {
if (p(static_cast<T_Arg*>(nodep)) != Default) return !Default;
}
// Traverse children (including their 'nextp()' chains), unless futile
if (mayBeUnder<typename std::remove_const<T_Arg>::type>(nodep)) {
if (AstNode* const op1p = nodep->op1p()) {
if (predicateImpl<T_Arg, Default, true>(op1p, p) != Default) return !Default;
}
if (AstNode* const op2p = nodep->op2p()) {
if (predicateImpl<T_Arg, Default, true>(op2p, p) != Default) return !Default;
}
if (AstNode* const op3p = nodep->op3p()) {
if (predicateImpl<T_Arg, Default, true>(op3p, p) != Default) return !Default;
}
if (AstNode* const op4p = nodep->op4p()) {
if (predicateImpl<T_Arg, Default, true>(op4p, p) != Default) return !Default;
}
}
// Traverse 'nextp()' chain if requested
if VL_CONSTEXPR_CXX17 (VisitNext) {
nodep = nodep->nextp();
} else {
break;
}
} while (nodep);
return Default;
}
template <typename T_Arg, bool Default>
inline static bool predicateImpl(ConstCorrectAstNode<T_Arg>* nodep,
const std::function<bool(T_Arg*)>& p);
template <typename T_Node> constexpr static bool checkTypeParameter() {
static_assert(!std::is_const<T_Node>::value,
@ -2030,31 +1963,32 @@ private:
public:
// Traverse subtree and call given function 'f' in pre-order on each node that has type
// 'T_Node'. Prefer 'foreach' over simple VNVisitor that only needs to handle a single (or a
// few) node types, as it's easier to write, but more importantly, the dispatch to the
// operation function in 'foreach' should be completely predictable by branch target caches in
// modern CPUs, while it is basically unpredictable for VNVisitor.
// 'T_Node'. The node passed to the function 'f' can be removed or replaced, but other editing
// of the iterated tree is not safe. Prefer 'foreach' over simple VNVisitor that only needs to
// handle a single (or a few) node types, as it's easier to write, but more importantly, the
// dispatch to the operation function in 'foreach' should be completely predictable by branch
// target caches in modern CPUs, while it is basically unpredictable for VNVisitor.
template <typename T_Node> void foreach (std::function<void(T_Node*)> f) {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
foreachImpl<T_Node, /* VisitNext: */ false>(this, f);
foreachImpl<T_Node>(this, f, /* visitNext: */ false);
}
// Same as above, but for 'const' nodes
template <typename T_Node> void foreach (std::function<void(const T_Node*)> f) const {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
foreachImpl<const T_Node, /* VisitNext: */ false>(this, f);
foreachImpl<const T_Node>(this, f, /* visitNext: */ false);
}
// Same as 'foreach' but also follows 'this->nextp()'
template <typename T_Node> void foreachAndNext(std::function<void(T_Node*)> f) {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
foreachImpl<T_Node, /* VisitNext: */ true>(this, f);
foreachImpl<T_Node>(this, f, /* visitNext: */ true);
}
// Same as 'foreach' but also follows 'this->nextp()'
template <typename T_Node> void foreachAndNext(std::function<void(const T_Node*)> f) const {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
foreachImpl<const T_Node, /* VisitNext: */ true>(this, f);
foreachImpl<const T_Node>(this, f, /* visitNext: */ true);
}
// Given a predicate function 'p' return true if and only if there exists a node of type
@ -2063,13 +1997,13 @@ public:
// result can be determined.
template <typename T_Node> bool exists(std::function<bool(T_Node*)> p) {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
return predicateImpl<T_Node, /* Default: */ false, /* VisitNext: */ false>(this, p);
return predicateImpl<T_Node, /* Default: */ false>(this, p);
}
// Same as above, but for 'const' nodes
template <typename T_Node> void exists(std::function<bool(const T_Node*)> p) const {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
return predicateImpl<const T_Node, /* Default: */ false, /* VisitNext: */ false>(this, p);
return predicateImpl<const T_Node, /* Default: */ false>(this, p);
}
// Given a predicate function 'p' return true if and only if all nodes of type
@ -2078,13 +2012,13 @@ public:
// result can be determined.
template <typename T_Node> bool forall(std::function<bool(T_Node*)> p) {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
return predicateImpl<T_Node, /* Default: */ true, /* VisitNext: */ false>(this, p);
return predicateImpl<T_Node, /* Default: */ true>(this, p);
}
// Same as above, but for 'const' nodes
template <typename T_Node> void forall(std::function<bool(const T_Node*)> p) const {
static_assert(checkTypeParameter<T_Node>(), "Invalid type parameter 'T_Node'");
return predicateImpl<const T_Node, /* Default: */ true, /* VisitNext: */ false>(this, p);
return predicateImpl<const T_Node, /* Default: */ true>(this, p);
}
int nodeCount() const {
@ -2098,22 +2032,196 @@ public:
// Specialisations of privateTypeTest
#include "V3Ast__gen_impl.h" // From ./astgen
// Specializations of privateMayBeUnder
template <> inline bool AstNode::privateMayBeUnder<AstCell>(const AstNode* nodep) {
// Specializations of AstNode::mayBeUnder
template <> inline bool AstNode::mayBeUnder<AstCell>(const AstNode* nodep) {
return !VN_IS(nodep, NodeStmt) && !VN_IS(nodep, NodeMath);
}
template <> inline bool AstNode::privateMayBeUnder<AstNodeAssign>(const AstNode* nodep) {
template <> inline bool AstNode::mayBeUnder<AstNodeAssign>(const AstNode* nodep) {
return !VN_IS(nodep, NodeMath);
}
template <> inline bool AstNode::privateMayBeUnder<AstVarScope>(const AstNode* nodep) {
template <> inline bool AstNode::mayBeUnder<AstVarScope>(const AstNode* nodep) {
return !VN_IS(nodep, NodeStmt) && !VN_IS(nodep, NodeMath);
}
template <> inline bool AstNode::privateMayBeUnder<AstExecGraph>(const AstNode* nodep) {
template <> inline bool AstNode::mayBeUnder<AstExecGraph>(const AstNode* nodep) {
if (VN_IS(nodep, ExecGraph)) return false; // Should not nest
if (VN_IS(nodep, NodeStmt)) return false; // Should be directly under CFunc
return true;
}
// Specializations of AstNode::isLeaf
// The generic AstNode::isLeaf<T_Node>() returns false. For the types below it returns true,
// which lets the foreach/exists/forall traversals return right after calling the client
// function on a matching node, without enqueuing that node's children.
template <> constexpr bool AstNode::isLeaf<AstNodeVarRef>() { return true; }
template <> constexpr bool AstNode::isLeaf<AstVarRef>() { return true; }
template <> constexpr bool AstNode::isLeaf<AstVarXRef>() { return true; }
// foreach implementation
//
// Traverses the subtree rooted at 'nodep' in pre-order and calls 'f' on every node that passes
// privateTypeTest for 'T_Arg'. The nextp() chain of 'nodep' itself is followed only if
// 'visitNext' is true; the nextp() chains of all nodes popped from the traversal stack are
// always followed. Issues a fatal error via 'nodep' if 'f' is unbound.
// NOTE(review): after 'f' returns, the children of the visited node are still read unless
// isLeaf<T_Arg>() is true or mayBeUnder<T_Arg>() is false, so removing/replacing the current
// node inside 'f' looks safe only in those cases - confirm against callers.
template <typename T_Arg>
void AstNode::foreachImpl(ConstCorrectAstNode<T_Arg>* nodep, const std::function<void(T_Arg*)>& f,
                          bool visitNext) {
    // Checking the function is bound up front eliminates this check from the loop at invocation
    if (!f) {
        nodep->v3fatal("AstNode::foreach called with unbound function");  // LCOV_EXCL_LINE
    } else {
        // Pre-order traversal implemented directly (without recursion) for speed reasons. The very
        // first iteration (the one that operates on the input nodep) is special, as we might or
        // might not need to enqueue nodep->nextp() depending on 'visitNext', while in all other
        // iterations, we do want to enqueue nodep->nextp(). The root is therefore handled before
        // the loop (sharing the 'visit' lambda) to avoid an extra branch in the loop.

        using T_Arg_NonConst = typename std::remove_const<T_Arg>::type;
        using Node = ConstCorrectAstNode<T_Arg>;

        // Traversal stack
        std::vector<Node*> stack;  // Kept as a vector for easy resizing
        Node** basep = nullptr;  // Pointer to base of stack
        Node** topp = nullptr;  // Pointer to top of stack (next free slot)
        Node** limp = nullptr;  // Pointer to stack limit (when need growing)

        // We prefetch this far into the stack
        constexpr int prefetchDistance = 2;
        // We push max 5 items per iteration (nextp and the 4 children)
        constexpr int maxPushesPerVisit = 5;

        // Grow stack to given size
        const auto grow = [&](size_t size) VL_ATTR_ALWINLINE {
            const ptrdiff_t occupancy = topp - basep;
            stack.resize(size);
            basep = stack.data() + prefetchDistance;
            topp = basep + occupancy;
            // Only 'size - prefetchDistance' slots are usable above 'basep', as the first
            // 'prefetchDistance' entries of the vector are reserved for the prefetch padding.
            // The limit must account for that, otherwise the last push of an iteration could
            // land one element past the end of the vector.
            limp = basep + (size - prefetchDistance) - maxPushesPerVisit;
        };

        // Initial stack size
        grow(32);

        // We want some non-null pointers at the beginning. These will be prefetched, but not
        // visited, so the root node will suffice. This eliminates needing branches in the loop.
        for (int i = -prefetchDistance; i; ++i) basep[i] = nodep;

        // Visit given node, enqueue children for traversal
        const auto visit = [&](Node* currp) VL_ATTR_ALWINLINE {
            // Type test this node
            if (AstNode::privateTypeTest<T_Arg_NonConst>(currp)) {
                // Call the client function
                f(static_cast<T_Arg*>(currp));
                // Short circuit if iterating leaf nodes
                if VL_CONSTEXPR_CXX17 (isLeaf<T_Arg_NonConst>()) return;
            }

            // Enqueue children for traversal, unless futile. Pushed in reverse order so that
            // op1p is popped (and hence visited) first.
            if (mayBeUnder<T_Arg_NonConst>(currp)) {
                if (AstNode* const op4p = currp->op4p()) *topp++ = op4p;
                if (AstNode* const op3p = currp->op3p()) *topp++ = op3p;
                if (AstNode* const op2p = currp->op2p()) *topp++ = op2p;
                if (AstNode* const op1p = currp->op1p()) *topp++ = op1p;
            }
        };

        // Enqueue the next of the root node, if required
        if (visitNext && nodep->nextp()) *topp++ = nodep->nextp();

        // Visit the root node
        visit(nodep);

        // Visit the rest of the tree
        while (VL_LIKELY(topp > basep)) {
            // Pop next node in the traversal
            Node* const headp = *--topp;

            // Prefetch in case we are ascending the tree
            ASTNODE_PREFETCH_NON_NULL(topp[-prefetchDistance]);

            // Ensure we have stack space for nextp and the 4 children
            if (VL_UNLIKELY(topp >= limp)) grow(stack.size() * 2);

            // Enqueue the next node. This is read before the client function runs on 'headp'.
            if (headp->nextp()) *topp++ = headp->nextp();

            // Visit the head node
            visit(headp);
        }
    }
}
// predicate implementation
//
// Traverses the subtree rooted at 'nodep' in pre-order (never following the nextp() chain of
// 'nodep' itself) and evaluates 'p' on every node that passes privateTypeTest for 'T_Arg'.
// Returns '!Default' as soon as 'p' yields a value different from 'Default', otherwise returns
// 'Default' once the whole subtree has been visited. With Default == false this implements
// 'exists', with Default == true it implements 'forall'. Issues a fatal error via 'nodep' if
// 'p' is unbound.
template <typename T_Arg, bool Default>
bool AstNode::predicateImpl(ConstCorrectAstNode<T_Arg>* nodep,
                            const std::function<bool(T_Arg*)>& p) {
    // Implementation similar to foreach, but abort traversal as soon as result is determined.
    if (!p) {
        nodep->v3fatal("AstNode::exists/forall called with unbound function");  // LCOV_EXCL_LINE
        // Presumably not reached (v3fatal is expected to abort), but do not fall off the end
        // of a function returning 'bool' if it ever returns.
        return Default;  // LCOV_EXCL_LINE
    } else {
        using T_Arg_NonConst = typename std::remove_const<T_Arg>::type;
        using Node = ConstCorrectAstNode<T_Arg>;

        // Traversal stack
        std::vector<Node*> stack;  // Kept as a vector for easy resizing
        Node** basep = nullptr;  // Pointer to base of stack
        Node** topp = nullptr;  // Pointer to top of stack (next free slot)
        Node** limp = nullptr;  // Pointer to stack limit (when need growing)

        // We prefetch this far into the stack
        constexpr int prefetchDistance = 2;
        // We push max 5 items per iteration (nextp and the 4 children)
        constexpr int maxPushesPerVisit = 5;

        // Grow stack to given size
        const auto grow = [&](size_t size) VL_ATTR_ALWINLINE {
            const ptrdiff_t occupancy = topp - basep;
            stack.resize(size);
            basep = stack.data() + prefetchDistance;
            topp = basep + occupancy;
            // Only 'size - prefetchDistance' slots are usable above 'basep', as the first
            // 'prefetchDistance' entries of the vector are reserved for the prefetch padding.
            // The limit must account for that, otherwise the last push of an iteration could
            // land one element past the end of the vector.
            limp = basep + (size - prefetchDistance) - maxPushesPerVisit;
        };

        // Initial stack size
        grow(32);

        // We want some non-null pointers at the beginning. These will be prefetched, but not
        // visited, so the root node will suffice. This eliminates needing branches in the loop.
        for (int i = -prefetchDistance; i; ++i) basep[i] = nodep;

        // Visit given node, enqueue children for traversal, return true if result determined.
        const auto visit = [&](Node* currp) VL_ATTR_ALWINLINE {
            // Type test this node
            if (AstNode::privateTypeTest<T_Arg_NonConst>(currp)) {
                // Call the client function
                if (p(static_cast<T_Arg*>(currp)) != Default) return true;
                // Short circuit if iterating leaf nodes
                if VL_CONSTEXPR_CXX17 (isLeaf<T_Arg_NonConst>()) return false;
            }

            // Enqueue children for traversal, unless futile. Pushed in reverse order so that
            // op1p is popped (and hence visited) first.
            if (mayBeUnder<T_Arg_NonConst>(currp)) {
                if (AstNode* const op4p = currp->op4p()) *topp++ = op4p;
                if (AstNode* const op3p = currp->op3p()) *topp++ = op3p;
                if (AstNode* const op2p = currp->op2p()) *topp++ = op2p;
                if (AstNode* const op1p = currp->op1p()) *topp++ = op1p;
            }

            return false;
        };

        // Visit the root node
        if (visit(nodep)) return !Default;

        // Visit the rest of the tree
        while (VL_LIKELY(topp > basep)) {
            // Pop next node in the traversal
            Node* const headp = *--topp;

            // Prefetch in case we are ascending the tree
            ASTNODE_PREFETCH_NON_NULL(topp[-prefetchDistance]);

            // Ensure we have stack space for nextp and the 4 children
            if (VL_UNLIKELY(topp >= limp)) grow(stack.size() * 2);

            // Enqueue the next node
            if (headp->nextp()) *topp++ = headp->nextp();

            // Visit the head node
            if (visit(headp)) return !Default;
        }

        return Default;
    }
}
inline std::ostream& operator<<(std::ostream& os, const AstNode* rhs) {
if (!rhs) {
os << "nullptr";

View File

@ -151,14 +151,14 @@ class ForceConvertVisitor final : public VNVisitor {
// referenced AstVarScope with the given function.
void transformWritenVarScopes(AstNode* nodep, std::function<AstVarScope*(AstVarScope*)> f) {
UASSERT_OBJ(nodep->backp(), nodep, "Must have backp, otherwise will be lost if replaced");
nodep->foreach<AstNodeVarRef>([this, &f](AstNodeVarRef* refp) {
nodep->foreach<AstNodeVarRef>([&f](AstNodeVarRef* refp) {
if (refp->access() != VAccess::WRITE) return;
// TODO: this is not strictly speaking safe for some complicated lvalues, eg.:
// 'force foo[a(cnt)] = 1;', where 'cnt' is an out parameter, but it will
// do for now...
refp->replaceWith(
new AstVarRef{refp->fileline(), f(refp->varScopep()), VAccess::WRITE});
pushDeletep(refp);
VL_DO_DANGLING(refp->deleteTree(), refp);
});
}
@ -238,7 +238,7 @@ class ForceConvertVisitor final : public VNVisitor {
flp->warnOff(V3ErrorCode::BLKANDNBLK, true);
AstVarRef* const newpRefp = new AstVarRef{flp, newVscp, VAccess::WRITE};
refp->replaceWith(newpRefp);
pushDeletep(refp);
VL_DO_DANGLING(refp->deleteTree(), refp);
});
// Replace write refs on RHS
resetRdp->rhsp()->foreach<AstNodeVarRef>([this](AstNodeVarRef* refp) {
@ -249,7 +249,7 @@ class ForceConvertVisitor final : public VNVisitor {
AstVarRef* const newpRefp = new AstVarRef{refp->fileline(), newVscp, VAccess::READ};
newpRefp->user2(1); // Don't replace this read ref with the read signal
refp->replaceWith(newpRefp);
pushDeletep(refp);
VL_DO_DANGLING(refp->deleteTree(), refp);
});
resetEnp->addNext(resetRdp);