utils: refactor TopoSort

This commit is contained in:
Emil J. Tywoniak 2026-04-06 15:03:09 +02:00
parent 37ca545b65
commit 2033df5958
1 changed files with 31 additions and 26 deletions

View File

@ -125,20 +125,20 @@ public:
};
// ------------------------------------------------
// A simple class for topological sorting
// ------------------------------------------------
// ---------------------------------------------------
// Best-effort topological sorting with loop detection
// ---------------------------------------------------
template <typename T, typename C = std::less<T>> class TopoSort
{
public:
public:
// We use this ordering of the edges in the adjacency matrix for
// exact compatibility with an older implementation.
struct IndirectCmp {
IndirectCmp(const std::vector<T> &nodes) : node_cmp_(), nodes_(nodes) {}
IndirectCmp(const std::vector<T> &nodes) : node_cmp_(), nodes_(nodes) {}
bool operator()(int a, int b) const
{
log_assert(static_cast<size_t>(a) < nodes_.size());
log_assert(static_cast<size_t>(a) < nodes_.size());
log_assert(static_cast<size_t>(b) < nodes_.size());
return node_cmp_(nodes_[a], nodes_[b]);
}
@ -147,7 +147,9 @@ template <typename T, typename C = std::less<T>> class TopoSort
};
bool analyze_loops;
// The stability doesn't rely on std::less of T, so pointers are safe
std::map<T, int, C> node_to_index;
// edges[i] is the set of nodes with an edge into node i
std::vector<std::set<int, IndirectCmp>> edges;
std::vector<T> sorted;
std::set<std::vector<T>> loops;
@ -160,10 +162,10 @@ template <typename T, typename C = std::less<T>> class TopoSort
int node(T n)
{
auto rv = node_to_index.emplace(n, static_cast<int>(nodes.size()));
if (rv.second) {
nodes.push_back(n);
edges.push_back(std::set<int, IndirectCmp>(indirect_cmp));
auto rv = node_to_index.emplace(n, static_cast<int>(nodes.size()));
if (rv.second) {
nodes.push_back(n);
edges.push_back(std::set<int, IndirectCmp>(indirect_cmp));
}
return rv.first->second;
}
@ -183,13 +185,14 @@ template <typename T, typename C = std::less<T>> class TopoSort
sorted.clear();
found_loops = false;
std::vector<bool> marked_cells(edges.size(), false);
std::vector<bool> active_cells(edges.size(), false);
std::vector<int> active_stack;
std::vector<bool> node_is_sorted(edges.size(), false);
std::vector<bool> node_is_on_stack(edges.size(), false);
// Only used with analyze_loops
std::vector<int> stack;
sorted.reserve(edges.size());
for (const auto &it : node_to_index)
sort_worker(it.second, marked_cells, active_cells, active_stack);
sort_worker(it.second, node_is_sorted, node_is_on_stack, stack);
log_assert(GetSize(sorted) == GetSize(nodes));
@ -211,19 +214,20 @@ template <typename T, typename C = std::less<T>> class TopoSort
return database;
}
private:
private:
bool found_loops;
std::vector<T> nodes;
const IndirectCmp indirect_cmp;
void sort_worker(const int root_index, std::vector<bool> &marked_cells, std::vector<bool> &active_cells, std::vector<int> &active_stack)
void sort_worker(const int root_index, std::vector<bool> &node_is_sorted, std::vector<bool> &node_is_on_stack, std::vector<int> &stack)
{
if (active_cells[root_index]) {
if (node_is_on_stack[root_index]) {
// We've been here before, meaning we have a loop
found_loops = true;
if (analyze_loops) {
std::vector<T> loop;
for (int i = GetSize(active_stack) - 1; i >= 0; i--) {
const int index = active_stack[i];
for (int i = GetSize(stack) - 1; i >= 0; i--) {
const int index = stack[i];
loop.push_back(nodes[index]);
if (index == root_index)
break;
@ -233,23 +237,24 @@ template <typename T, typename C = std::less<T>> class TopoSort
return;
}
if (marked_cells[root_index])
// We're done if we've already sorted this subgraph
if (node_is_sorted[root_index])
return;
if (!edges[root_index].empty()) {
if (analyze_loops)
active_stack.push_back(root_index);
active_cells[root_index] = true;
stack.push_back(root_index);
node_is_on_stack[root_index] = true;
for (int left_n : edges[root_index])
sort_worker(left_n, marked_cells, active_cells, active_stack);
sort_worker(left_n, node_is_sorted, node_is_on_stack, stack);
if (analyze_loops)
active_stack.pop_back();
active_cells[root_index] = false;
stack.pop_back();
node_is_on_stack[root_index] = false;
}
marked_cells[root_index] = true;
node_is_sorted[root_index] = true;
sorted.push_back(nodes[root_index]);
}
};