diff --git a/passes/opt/opt_dff.cc b/passes/opt/opt_dff.cc index 90ace69e5..cefc5d3fe 100644 --- a/passes/opt/opt_dff.cc +++ b/passes/opt/opt_dff.cc @@ -919,6 +919,286 @@ struct OptDffWorker return did_something; } + + struct EqBit { + Cell *cell; + int idx; + SigBit q; + }; + + struct SigKey { + enum Flag : uint16_t { + InitOne = 1u << 0, + InitX = 1u << 1, + PolClk = 1u << 2, + PolCe = 1u << 3, + PolSrst = 1u << 4, + PolArst = 1u << 5, + PolAload = 1u << 6, + PolClr = 1u << 7, + PolSet = 1u << 8, + CeOverSrst = 1u << 9, + }; + + SigBit clk, ce, srst, arst, aload, clr, set; + IdString cell_type; // for SR + uint16_t flags; + + bool operator==(const SigKey &o) const { + return flags == o.flags && clk == o.clk && ce == o.ce && srst == o.srst && arst == o.arst + && aload == o.aload && clr == o.clr && set == o.set && cell_type == o.cell_type; + } + + Hasher hash_into(Hasher h) const { + h.eat(flags); + h.eat(clk); + h.eat(ce); + h.eat(srst); + h.eat(arst); + h.eat(aload); + h.eat(clr); + h.eat(set); + h.eat(cell_type); + return h; + } + }; + + bool is_def(State s) { + return s == State::S0 || s == State::S1; + } + + int sat_mux(QuickConeSat &qcsat, int s, int a, int b) { + return qcsat.ez->OR(qcsat.ez->AND(s, a), qcsat.ez->AND(qcsat.ez->NOT(s), b)); + } + + int sat_const(QuickConeSat &qcsat, State v) { + return v == State::S1 ? qcsat.ez->CONST_TRUE : qcsat.ez->CONST_FALSE; + } + + bool run_eqbits() + { + std::vector bits; + std::vector keys; + dict ff_for_cell; + + // Collect FF bits eligible for merging + for (auto cell : module->selected_cells()) { + if (!cell->is_builtin_ff()) + continue; + + FfData ff(&initvals, cell); + if (!ff.has_clk && !ff.has_gclk) + continue; + + ff_for_cell.emplace(cell, ff); + + for (int i = 0; i < ff.width; i++) { + // X value + if (ff.has_srst && !is_def(ff.val_srst[i])) continue; + if (ff.has_arst && !is_def(ff.val_arst[i])) continue; + + // Missing anchor + bool def_init = is_def(ff.val_init[i]); + if (!def_init && !ff.has_srst && !ff.has_arst) + continue; + + SigKey k = {}; + + // Flags + if (def_init && ff.val_init[i] == State::S1) + k.flags |= SigKey::InitOne; + else if (!def_init) + k.flags |= SigKey::InitX; + + if (ff.has_clk) { + k.clk = ff.sig_clk; + if (ff.pol_clk) k.flags |= SigKey::PolClk; + } + if (ff.has_ce) { + k.ce = ff.sig_ce; + if (ff.pol_ce) k.flags |= SigKey::PolCe; + } + if (ff.has_srst) { + k.srst = ff.sig_srst; + if (ff.pol_srst) k.flags |= SigKey::PolSrst; + if (ff.ce_over_srst) k.flags |= SigKey::CeOverSrst; + } + if (ff.has_arst) { + k.arst = ff.sig_arst; + if (ff.pol_arst) k.flags |= SigKey::PolArst; + } + if (ff.has_aload) { + k.aload = ff.sig_aload; + if (ff.pol_aload) k.flags |= SigKey::PolAload; + } + if (ff.has_sr) { + k.clr = ff.sig_clr[i]; + k.set = ff.sig_set[i]; + k.cell_type = cell->type; + if (ff.pol_clr) k.flags |= SigKey::PolClr; + if (ff.pol_set) k.flags |= SigKey::PolSet; + } + + bits.push_back({cell, i, ff.sig_q[i]}); + keys.push_back(k); + } + } + + if (GetSize(bits) < 2) + return false; + + // Group bits by control signature + dict> buckets; + for (int i = 0; i < GetSize(bits); i++) + buckets[keys[i]].push_back(i); + + std::vector> classes; + classes.reserve(GetSize(buckets)); + for (auto &kv : buckets) + if (GetSize(kv.second) >= 2) + classes.push_back(std::move(kv.second)); + + if (classes.empty()) + return false; + + ModWalker modwalker(module->design, module); + QuickConeSat qcsat(modwalker); + std::vector q_lit(bits.size(), -1); + std::vector n_lit(bits.size(), -1); + + // Per candidate SAT for its next state, model difference + for (auto &cls : classes) { + for (int idx : cls) { + const EqBit &eb = bits[idx]; + const FfData &ff = ff_for_cell.at(eb.cell); + q_lit[idx] = qcsat.importSigBit(eb.q); + int n = qcsat.importSigBit(ff.sig_d[eb.idx]); + + if (ff.has_aload) { + int al = qcsat.importSigBit(ff.sig_aload); + if (!ff.pol_aload) al = qcsat.ez->NOT(al); + int ad = qcsat.importSigBit(ff.sig_ad[eb.idx]); + n = sat_mux(qcsat, al, ad, n); + } + if (ff.has_arst) { + int ar = qcsat.importSigBit(ff.sig_arst); + if (!ff.pol_arst) ar = qcsat.ez->NOT(ar); + n = sat_mux(qcsat, ar, sat_const(qcsat, ff.val_arst[eb.idx]), n); + } + if (ff.has_sr) { + int clr = qcsat.importSigBit(ff.sig_clr[eb.idx]); + if (!ff.pol_clr) clr = qcsat.ez->NOT(clr); + int set = qcsat.importSigBit(ff.sig_set[eb.idx]); + if (!ff.pol_set) set = qcsat.ez->NOT(set); + n = qcsat.ez->AND(qcsat.ez->NOT(clr), qcsat.ez->OR(set, n)); + } + if (ff.has_srst) { + int srst = qcsat.importSigBit(ff.sig_srst); + if (!ff.pol_srst) srst = qcsat.ez->NOT(srst); + n = sat_mux(qcsat,srst, sat_const(qcsat, ff.val_srst[eb.idx]), n); + } + + n_lit[idx] = n; + } + } + + qcsat.prepare(); + bool any_change = false; + bool changed = true; + + // Bit = class rep, split classes whenever two next states differ + while (changed) { + changed = false; + int joint = qcsat.ez->CONST_TRUE; + + for (auto &cls : classes) { + int rep = cls[0]; + for (int k = 1; k < GetSize(cls); k++) + joint = qcsat.ez->AND(joint, qcsat.ez->IFF(q_lit[rep], q_lit[cls[k]])); + } + + std::vector> new_classes; + new_classes.reserve(classes.size()); + + for (auto &cls : classes) { + std::vector> subs; + for (int b : cls) { + bool placed = false; + + // Identical literal - trivially eq + for (auto &sub : subs) { + if (n_lit[sub[0]] == n_lit[b]) { + sub.push_back(b); + placed = true; + break; + } + } + + if (placed) continue; + + for (auto &sub : subs) { + int rep = sub[0]; + int query = qcsat.ez->NOT(qcsat.ez->IFF(n_lit[rep], n_lit[b])); + if (!qcsat.ez->solve(joint, query)) { + sub.push_back(b); + placed = true; + break; + } + } + + if (!placed) + subs.push_back({b}); + } + + if (GetSize(subs) > 1) + changed = true; + for (auto &sub : subs) + if (GetSize(sub) >= 2) + new_classes.push_back(std::move(sub)); + } + + classes = std::move(new_classes); + if (changed) + any_change = true; + } + + if (classes.empty()) + return any_change; + + dict> remove_bits; + + // Drive every non-rep Q from its class rep, drop merged bits from their FFs + for (auto &cls : classes) { + SigBit rep_q = bits[cls[0]].q; + for (int k = 1; k < GetSize(cls); k++) { + const EqBit &eb = bits[cls[k]]; + initvals.remove_init(eb.q); + module->connect(eb.q, rep_q); + remove_bits[eb.cell].insert(eb.idx); + } + } + + for (auto &kv : remove_bits) { + Cell *cell = kv.first; + const std::set &drop = kv.second; + FfData &ff = ff_for_cell.at(cell); + std::vector keep; + + for (int i = 0; i < ff.width; i++) + if (!drop.count(i)) + keep.push_back(i); + + if (keep.empty()) { + module->remove(cell); + } else { + FfData new_ff = ff.slice(keep); + new_ff.cell = cell; + new_ff.emit(); + } + } + + return true; + } }; struct OptDffPass : public Pass { @@ -946,7 +1226,7 @@ struct OptDffPass : public Pass { log(" -simple-dffe\n"); log(" only enables clock enable recognition transform for obvious cases\n"); log("\n"); - log(" -sat\n"); + log(" -sat AAA\n"); log(" additionally invoke SAT solver to detect and remove flip-flops (with\n"); log(" non-constant inputs) that can also be replaced with a constant driver\n"); log("\n"); @@ -987,6 +1267,8 @@ struct OptDffPass : public Pass { did_something = true; if (worker.run_constbits()) did_something = true; + if (opt.sat && worker.run_eqbits()) + did_something = true; } if (did_something)