Speed fix

This commit is contained in:
Akash Levy 2026-05-20 15:42:26 -07:00
parent 7eff462881
commit a5617f90ac
1 changed file with 92 additions and 12 deletions

View File

@ -139,7 +139,8 @@ struct OptPriEncWorker {
// in the cone (cells whose output is reached by BFS) and the "leaf" bits
// (port-input bits or bits driven by sequential cells / undriven).
// Returns false if the cone touches anything we don't want to drive a PE.
bool get_cone(SigSpec from, pool<Cell*>& cone_cells, pool<SigBit>& leaf_bits) {
bool get_cone(SigSpec from, pool<Cell*>& cone_cells, pool<SigBit>& leaf_bits,
int max_cone_cells, int max_leaf_bits) {
pool<SigBit> visited;
std::queue<SigBit> worklist;
for (auto bit : sigmap(from)) {
@ -149,12 +150,25 @@ struct OptPriEncWorker {
while (!worklist.empty()) {
SigBit bit = worklist.front();
worklist.pop();
if (input_port_bits.count(bit)) { leaf_bits.insert(bit); continue; }
if (input_port_bits.count(bit)) {
leaf_bits.insert(bit);
if (GetSize(leaf_bits) > max_leaf_bits) return false;
continue;
}
auto it = bit_to_driver.find(bit);
if (it == bit_to_driver.end()) { leaf_bits.insert(bit); continue; }
if (it == bit_to_driver.end()) {
leaf_bits.insert(bit);
if (GetSize(leaf_bits) > max_leaf_bits) return false;
continue;
}
Cell* drv = it->second;
if (sequential_cells.count(drv)) { leaf_bits.insert(bit); continue; }
if (sequential_cells.count(drv)) {
leaf_bits.insert(bit);
if (GetSize(leaf_bits) > max_leaf_bits) return false;
continue;
}
if (!cone_cells.insert(drv).second) continue;
if (GetSize(cone_cells) > max_cone_cells) return false;
for (auto& conn : drv->connections()) {
if (!drv->input(conn.first)) continue;
for (auto in_bit : sigmap(conn.second)) {
@ -174,15 +188,17 @@ struct OptPriEncWorker {
// the fingerprint be the final arbiter.
vector<Wire*> find_candidate_Ts(Wire* S_wire,
const pool<SigBit>& cone_bits,
const pool<SigBit>& control_bits,
const vector<Wire*>& possible_Ts) {
vector<Wire*> out;
for (Wire* w : possible_Ts) {
if (w == S_wire) continue;
bool all_in = true;
bool all_in = true, any_control = false;
for (auto bit : sigmap(SigSpec(w))) {
if (!cone_bits.count(bit)) { all_in = false; break; }
if (control_bits.count(bit)) any_control = true;
}
if (all_in) out.push_back(w);
if (all_in && any_control) out.push_back(w);
}
// Try wider candidates first: the more bits the fingerprint constrains,
// the lower the chance of false positives, and longer chains usually
@ -390,6 +406,7 @@ struct OptPriEncWorker {
pool<Cell*> cone_cells;
pool<SigBit> leaf_bits;
pool<SigBit> cone_bits;
pool<SigBit> control_bits;
Cell* sole_driver;
IdString out_port;
};
@ -416,6 +433,53 @@ struct OptPriEncWorker {
return out_sig == S_sig;
}
// Reports whether `port` of cell `c` counts as a control input.
// Multiplexer cells ($mux/$pmux) qualify only through their select port S.
// For every other cell the port name is not consulted: the answer is true
// exactly when the cell type is one of the comparison, logic, reduction or
// bitwise types listed below. Callers are expected to have already checked
// that `port` is an input of `c`; that is not verified here.
bool is_control_input(Cell* c, IdString port) {
    const bool is_multiplexer = c->type.in(ID($mux), ID($pmux));
    if (!is_multiplexer) {
        // Non-mux cells: classification depends on the cell type alone.
        return c->type.in(
            ID($eq), ID($ne), ID($eqx), ID($nex), ID($lt), ID($le), ID($gt), ID($ge),
            ID($logic_not), ID($logic_and), ID($logic_or),
            ID($reduce_bool), ID($reduce_or), ID($reduce_and),
            ID($and), ID($or), ID($xor), ID($xnor), ID($not));
    }
    // Mux cells: only the select input steers the result.
    return port == ID::S;
}
// Cheap structural prefilter for a candidate S=f(T). ConstEval will only
// assign T, so any other variable leaf in the fanin cone guarantees the
// fingerprint will fail. Traversal stops at T bits, which allows T to be an
// internal wire produced by logic outside the PE region.
//
// Parameters:
//   S_sig  - signal whose fanin cone is walked backwards (sigmap is applied).
//   T_bits - bits at which the walk stops; expected to be sigmapped already.
// Returns true when every non-constant path backwards from S_sig ends at a
// bit in T_bits. Returns false as soon as a bit outside T_bits is found that
// is in input_port_bits, has no entry in bit_to_driver, or is driven by a
// cell in sequential_cells.
bool cone_depends_only_on_T(SigSpec S_sig, const pool<SigBit>& T_bits) {
    pool<SigBit> visited;
    // Drivers whose inputs have already been queued. A multi-bit cell is
    // reached once per output bit; without this set each arrival would
    // re-scan all of its input connections even though `visited` already
    // holds every one of those bits.
    pool<Cell*> expanded_drivers;
    std::queue<SigBit> worklist;

    // Queue a bit once; constant bits (no wire) can never be variable leaves.
    auto enqueue = [&](const SigBit& bit) {
        if (!bit.wire) return;
        if (visited.insert(bit).second) worklist.push(bit);
    };

    for (auto bit : sigmap(S_sig))
        enqueue(bit);

    while (!worklist.empty()) {
        SigBit bit = worklist.front();
        worklist.pop();

        // Reaching T is the only acceptable way for a path to terminate.
        if (T_bits.count(bit)) continue;
        if (input_port_bits.count(bit)) return false;

        auto it = bit_to_driver.find(bit);
        if (it == bit_to_driver.end()) return false;

        Cell* drv = it->second;
        if (sequential_cells.count(drv)) return false;

        // Same guard get_cone() uses: expand each driver cell only once.
        if (!expanded_drivers.insert(drv).second) continue;

        for (auto& conn : drv->connections()) {
            if (!drv->input(conn.first)) continue;
            for (auto in_bit : sigmap(conn.second))
                enqueue(in_bit);
        }
    }
    return true;
}
void run() {
vector<Wire*> wires_snapshot(module->wires().begin(), module->wires().end());
dict<int, vector<Wire*>> possible_Ts_by_Wbits;
@ -431,6 +495,8 @@ struct OptPriEncWorker {
// Stage 1: build candidate set with cones, filter by driver/width.
vector<Candidate> candidates;
int max_W = clog2_int(max_input_width + 1);
int max_cone_cells = std::max(256, max_input_width * 16);
int max_leaf_bits = max_input_width + max_W + 8;
for (Wire* S_wire : wires_snapshot) {
if (S_wire->port_input) continue;
int Wbits = S_wire->width;
@ -442,19 +508,27 @@ struct OptPriEncWorker {
pool<Cell*> cone_cells;
pool<SigBit> leaf_bits;
if (!get_cone(SigSpec(S_wire), cone_cells, leaf_bits)) continue;
if (!get_cone(SigSpec(S_wire), cone_cells, leaf_bits,
max_cone_cells, max_leaf_bits)) continue;
if (cone_cells.empty()) continue;
pool<SigBit> cone_bits = leaf_bits;
pool<SigBit> control_bits;
for (Cell* c : cone_cells) {
for (auto& conn : c->connections()) {
if (!c->output(conn.first)) continue;
for (auto bit : sigmap(conn.second))
if (bit.wire) cone_bits.insert(bit);
if (c->output(conn.first)) {
for (auto bit : sigmap(conn.second))
if (bit.wire) cone_bits.insert(bit);
}
if (c->input(conn.first) && is_control_input(c, conn.first)) {
for (auto bit : sigmap(conn.second))
if (bit.wire) control_bits.insert(bit);
}
}
}
candidates.push_back({S_wire, std::move(cone_cells), std::move(leaf_bits),
std::move(cone_bits), sole_driver, out_port});
std::move(cone_bits), std::move(control_bits),
sole_driver, out_port});
}
// Stage 2: process candidates in order of cone size (LARGEST first).
@ -485,10 +559,16 @@ struct OptPriEncWorker {
auto possible_Ts_it = possible_Ts_by_Wbits.find(Wbits);
if (possible_Ts_it == possible_Ts_by_Wbits.end()) continue;
vector<Wire*> Ts = find_candidate_Ts(cand.S_wire, cand.cone_bits, possible_Ts_it->second);
vector<Wire*> Ts = find_candidate_Ts(cand.S_wire, cand.cone_bits,
cand.control_bits, possible_Ts_it->second);
for (Wire* T_wire : Ts) {
int N = T_wire->width;
SigSpec T_sig = sigmap(SigSpec(T_wire));
pool<SigBit> T_bits;
for (auto bit : T_sig)
if (bit.wire) T_bits.insert(bit);
if (!cone_depends_only_on_T(S_sig, T_bits)) continue;
PEVariant variant = fingerprint(T_sig, S_sig, N, Wbits);
if (variant == PEVariant::NONE) continue;