Elim equiv bits.

This commit is contained in:
nella 2026-04-29 11:21:26 +02:00
parent d2f7fecef5
commit c58ed410da
1 changed files with 283 additions and 1 deletions

View File

@ -919,6 +919,286 @@ struct OptDffWorker
return did_something;
}
struct EqBit {
Cell *cell;
int idx;
SigBit q;
};
struct SigKey {
enum Flag : uint16_t {
InitOne = 1u << 0,
InitX = 1u << 1,
PolClk = 1u << 2,
PolCe = 1u << 3,
PolSrst = 1u << 4,
PolArst = 1u << 5,
PolAload = 1u << 6,
PolClr = 1u << 7,
PolSet = 1u << 8,
CeOverSrst = 1u << 9,
};
SigBit clk, ce, srst, arst, aload, clr, set;
IdString cell_type; // for SR
uint16_t flags;
bool operator==(const SigKey &o) const {
return flags == o.flags && clk == o.clk && ce == o.ce && srst == o.srst && arst == o.arst
&& aload == o.aload && clr == o.clr && set == o.set && cell_type == o.cell_type;
}
Hasher hash_into(Hasher h) const {
h.eat(flags);
h.eat(clk);
h.eat(ce);
h.eat(srst);
h.eat(arst);
h.eat(aload);
h.eat(clr);
h.eat(set);
h.eat(cell_type);
return h;
}
};
bool is_def(State s) {
return s == State::S0 || s == State::S1;
}
int sat_mux(QuickConeSat &qcsat, int s, int a, int b) {
return qcsat.ez->OR(qcsat.ez->AND(s, a), qcsat.ez->AND(qcsat.ez->NOT(s), b));
}
int sat_const(QuickConeSat &qcsat, State v) {
return v == State::S1 ? qcsat.ez->CONST_TRUE : qcsat.ez->CONST_FALSE;
}
bool run_eqbits()
{
std::vector<EqBit> bits;
std::vector<SigKey> keys;
dict<Cell*, FfData> ff_for_cell;
// Collect FF bits eligible for merging
for (auto cell : module->selected_cells()) {
if (!cell->is_builtin_ff())
continue;
FfData ff(&initvals, cell);
if (!ff.has_clk && !ff.has_gclk)
continue;
ff_for_cell.emplace(cell, ff);
for (int i = 0; i < ff.width; i++) {
// X value
if (ff.has_srst && !is_def(ff.val_srst[i])) continue;
if (ff.has_arst && !is_def(ff.val_arst[i])) continue;
// Missing anchor
bool def_init = is_def(ff.val_init[i]);
if (!def_init && !ff.has_srst && !ff.has_arst)
continue;
SigKey k = {};
// Flags
if (def_init && ff.val_init[i] == State::S1)
k.flags |= SigKey::InitOne;
else if (!def_init)
k.flags |= SigKey::InitX;
if (ff.has_clk) {
k.clk = ff.sig_clk;
if (ff.pol_clk) k.flags |= SigKey::PolClk;
}
if (ff.has_ce) {
k.ce = ff.sig_ce;
if (ff.pol_ce) k.flags |= SigKey::PolCe;
}
if (ff.has_srst) {
k.srst = ff.sig_srst;
if (ff.pol_srst) k.flags |= SigKey::PolSrst;
if (ff.ce_over_srst) k.flags |= SigKey::CeOverSrst;
}
if (ff.has_arst) {
k.arst = ff.sig_arst;
if (ff.pol_arst) k.flags |= SigKey::PolArst;
}
if (ff.has_aload) {
k.aload = ff.sig_aload;
if (ff.pol_aload) k.flags |= SigKey::PolAload;
}
if (ff.has_sr) {
k.clr = ff.sig_clr[i];
k.set = ff.sig_set[i];
k.cell_type = cell->type;
if (ff.pol_clr) k.flags |= SigKey::PolClr;
if (ff.pol_set) k.flags |= SigKey::PolSet;
}
bits.push_back({cell, i, ff.sig_q[i]});
keys.push_back(k);
}
}
if (GetSize(bits) < 2)
return false;
// Group bits by control signature
dict<SigKey, std::vector<int>> buckets;
for (int i = 0; i < GetSize(bits); i++)
buckets[keys[i]].push_back(i);
std::vector<std::vector<int>> classes;
classes.reserve(GetSize(buckets));
for (auto &kv : buckets)
if (GetSize(kv.second) >= 2)
classes.push_back(std::move(kv.second));
if (classes.empty())
return false;
ModWalker modwalker(module->design, module);
QuickConeSat qcsat(modwalker);
std::vector<int> q_lit(bits.size(), -1);
std::vector<int> n_lit(bits.size(), -1);
// Per candidate SAT for its next state, model difference
for (auto &cls : classes) {
for (int idx : cls) {
const EqBit &eb = bits[idx];
const FfData &ff = ff_for_cell.at(eb.cell);
q_lit[idx] = qcsat.importSigBit(eb.q);
int n = qcsat.importSigBit(ff.sig_d[eb.idx]);
if (ff.has_aload) {
int al = qcsat.importSigBit(ff.sig_aload);
if (!ff.pol_aload) al = qcsat.ez->NOT(al);
int ad = qcsat.importSigBit(ff.sig_ad[eb.idx]);
n = sat_mux(qcsat, al, ad, n);
}
if (ff.has_arst) {
int ar = qcsat.importSigBit(ff.sig_arst);
if (!ff.pol_arst) ar = qcsat.ez->NOT(ar);
n = sat_mux(qcsat, ar, sat_const(qcsat, ff.val_arst[eb.idx]), n);
}
if (ff.has_sr) {
int clr = qcsat.importSigBit(ff.sig_clr[eb.idx]);
if (!ff.pol_clr) clr = qcsat.ez->NOT(clr);
int set = qcsat.importSigBit(ff.sig_set[eb.idx]);
if (!ff.pol_set) set = qcsat.ez->NOT(set);
n = qcsat.ez->AND(qcsat.ez->NOT(clr), qcsat.ez->OR(set, n));
}
if (ff.has_srst) {
int srst = qcsat.importSigBit(ff.sig_srst);
if (!ff.pol_srst) srst = qcsat.ez->NOT(srst);
n = sat_mux(qcsat,srst, sat_const(qcsat, ff.val_srst[eb.idx]), n);
}
n_lit[idx] = n;
}
}
qcsat.prepare();
bool any_change = false;
bool changed = true;
// Bit = class rep, split classes whenever two next states differ
while (changed) {
changed = false;
int joint = qcsat.ez->CONST_TRUE;
for (auto &cls : classes) {
int rep = cls[0];
for (int k = 1; k < GetSize(cls); k++)
joint = qcsat.ez->AND(joint, qcsat.ez->IFF(q_lit[rep], q_lit[cls[k]]));
}
std::vector<std::vector<int>> new_classes;
new_classes.reserve(classes.size());
for (auto &cls : classes) {
std::vector<std::vector<int>> subs;
for (int b : cls) {
bool placed = false;
// Identical literal - trivially eq
for (auto &sub : subs) {
if (n_lit[sub[0]] == n_lit[b]) {
sub.push_back(b);
placed = true;
break;
}
}
if (placed) continue;
for (auto &sub : subs) {
int rep = sub[0];
int query = qcsat.ez->NOT(qcsat.ez->IFF(n_lit[rep], n_lit[b]));
if (!qcsat.ez->solve(joint, query)) {
sub.push_back(b);
placed = true;
break;
}
}
if (!placed)
subs.push_back({b});
}
if (GetSize(subs) > 1)
changed = true;
for (auto &sub : subs)
if (GetSize(sub) >= 2)
new_classes.push_back(std::move(sub));
}
classes = std::move(new_classes);
if (changed)
any_change = true;
}
if (classes.empty())
return any_change;
dict<Cell *, std::set<int>> remove_bits;
// Drive every non-rep Q from its class rep, drop merged bits from their FFs
for (auto &cls : classes) {
SigBit rep_q = bits[cls[0]].q;
for (int k = 1; k < GetSize(cls); k++) {
const EqBit &eb = bits[cls[k]];
initvals.remove_init(eb.q);
module->connect(eb.q, rep_q);
remove_bits[eb.cell].insert(eb.idx);
}
}
for (auto &kv : remove_bits) {
Cell *cell = kv.first;
const std::set<int> &drop = kv.second;
FfData &ff = ff_for_cell.at(cell);
std::vector<int> keep;
for (int i = 0; i < ff.width; i++)
if (!drop.count(i))
keep.push_back(i);
if (keep.empty()) {
module->remove(cell);
} else {
FfData new_ff = ff.slice(keep);
new_ff.cell = cell;
new_ff.emit();
}
}
return true;
}
};
struct OptDffPass : public Pass {
@ -946,7 +1226,7 @@ struct OptDffPass : public Pass {
log(" -simple-dffe\n");
log(" only enables clock enable recognition transform for obvious cases\n");
log("\n");
log(" -sat\n");
log(" -sat AAA\n");
log(" additionally invoke SAT solver to detect and remove flip-flops (with\n");
log(" non-constant inputs) that can also be replaced with a constant driver\n");
log("\n");
@ -987,6 +1267,8 @@ struct OptDffPass : public Pass {
did_something = true;
if (worker.run_constbits())
did_something = true;
if (opt.sat && worker.run_eqbits())
did_something = true;
}
if (did_something)