From faea2d904aba193259a32e331e1b78be2141d596 Mon Sep 17 00:00:00 2001 From: Akash Levy Date: Tue, 9 Jun 2026 10:55:10 -0700 Subject: [PATCH] Make opt_compact_prefix match more --- passes/silimate/opt_compact_prefix.cc | 231 ++++++++++++++++++++++- tests/silimate/opt_compact_prefix.ys | 114 +++++++++++ tests/silimate/opt_compact_prefix_mod.sv | 86 +++++++++ 3 files changed, 424 insertions(+), 7 deletions(-) create mode 100644 tests/silimate/opt_compact_prefix_mod.sv diff --git a/passes/silimate/opt_compact_prefix.cc b/passes/silimate/opt_compact_prefix.cc index e367e0ee3..59b8a62c4 100644 --- a/passes/silimate/opt_compact_prefix.cc +++ b/passes/silimate/opt_compact_prefix.cc @@ -19,6 +19,7 @@ #include "kernel/yosys.h" #include "kernel/sigtools.h" +#include "kernel/consteval.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN @@ -44,6 +45,7 @@ struct OptCompactPrefixWorker int forward_rewrites = 0; int reverse_rewrites = 0; + int modulo_rewrites = 0; int old_cells_removed = 0; int new_cells_emitted = 0; @@ -329,6 +331,25 @@ struct OptCompactPrefixWorker return SigBit(out); } + SigBit emit_bmux(SigSpec table, SigSpec sel) + { + Cell *cell = ref_cell; + log_assert(cell != nullptr); + Wire *out = module->addWire(NEW_ID2_SUFFIX("compact_div"), 1); + module->addBmux(NEW_ID2_SUFFIX("compact_bmux"), table, sel, out); + new_cells_emitted++; + return SigBit(out); + } + + static Const const_u64(uint64_t value, int width) + { + vector bits(width, State::S0); + for (int i = 0; i < width && i < 64; i++) + if ((value >> i) & 1ULL) + bits[i] = State::S1; + return Const(bits); + } + void remove_old_cells(const vector &old_cells) { for (auto cell : old_cells) { @@ -468,13 +489,200 @@ struct OptCompactPrefixWorker return true; } + // Reference function for the modulo-n decimation loop: scanning the enable + // vector from one end, mark every n-th enabled bit. Equivalent closed form: + // mask[I] = en[I] && n>0 && (inclusive_popcount(I) % n == 0) + // where the popcount runs from the scan's leading end down to (incl.) I. + uint64_t expected_modulo_mask(uint64_t en, uint64_t n, int width, bool msb_first) + { + if (n == 0) + return 0; + uint64_t mask = 0; + for (int i = 0; i < width; i++) { + if (!((en >> i) & 1ULL)) + continue; + uint64_t v = 0; + if (msb_first) + for (int k = i; k < width; k++) + v += (en >> k) & 1ULL; + else + for (int k = 0; k <= i; k++) + v += (en >> k) & 1ULL; + if (v % n == 0) + mask |= (1ULL << i); + } + return mask; + } + + // Confirm the combinational function from (en, n) to mask matches the modulo + // decimation reference for the given scan direction, using ConstEval over a + // structured + pseudo-random vector set (cf. opt_argmax's fingerprint). + bool fingerprint_modulo(Wire *en, Wire *n, Wire *mask, int width, bool msb_first) + { + if (width <= 0 || width > 62) + return false; + ConstEval ce(module); + SigSpec en_sig = sigmap(SigSpec(en)); + SigSpec n_sig = sigmap(SigSpec(n)); + SigSpec mask_sig = sigmap(SigSpec(mask)); + int nw = GetSize(n); + + vector nvals; + uint64_t nmax = (nw >= 64) ? ~0ULL : ((1ULL << nw) - 1); + for (uint64_t v = 0; v <= (uint64_t)width + 1 && v <= nmax; v++) + nvals.push_back(v); + if (nmax > (uint64_t)width + 1) + nvals.push_back(nmax); + + uint64_t full = (width >= 64) ? ~0ULL : ((1ULL << width) - 1); + vector envals; + envals.push_back(0); + envals.push_back(full); + envals.push_back(full & 0x5555555555555555ULL); + envals.push_back(full & 0xAAAAAAAAAAAAAAAAULL); + for (int i = 0; i < width; i++) + envals.push_back(1ULL << i); + uint64_t lfsr = 0x1234567089abcdefULL; + for (int r = 0; r < 64; r++) { + lfsr ^= lfsr << 13; + lfsr ^= lfsr >> 7; + lfsr ^= lfsr << 17; + envals.push_back(lfsr & full); + } + + for (uint64_t nv : nvals) + for (uint64_t ev : envals) { + ce.push(); + ce.set(en_sig, const_u64(ev, width)); + ce.set(n_sig, const_u64(nv, nw)); + SigSpec out = mask_sig; + SigSpec undef; + bool ok = ce.eval(out, undef); + ce.pop(); + if (!ok || !out.is_fully_const()) + return false; + Const cv = out.as_const(); + uint64_t actual = 0; + for (int i = 0; i < width && i < 64; i++) + if (cv[i] == State::S1) + actual |= (1ULL << i); + if (actual != expected_modulo_mask(ev, nv, width, msb_first)) + return false; + } + return true; + } + + bool rewrite_modulo_decimation() + { + vector inputs = input_ports(); + vector outputs = output_ports(); + if (GetSize(inputs) != 2 || GetSize(outputs) != 1) + return false; + if (GetSize(module->ports) != 3) + return false; + + Wire *mask = outputs[0]; + int width = GetSize(mask); + if (width < 4 || width > max_width) + return false; + + // en matches the mask width; n (the modulus) is a distinct, narrower bus. + // Requiring different widths also disambiguates from the reverse suffix + // read form, whose two inputs share the mask width. + Wire *en = nullptr, *n = nullptr; + for (auto in : inputs) { + if (GetSize(in) == width && en == nullptr) + en = in; + else + n = in; + } + if (en == nullptr || n == nullptr || en == n || GetSize(n) == width) + return false; + + bool msb_first; + if (fingerprint_modulo(en, n, mask, width, true)) + msb_first = true; + else if (fingerprint_modulo(en, n, mask, width, false)) + msb_first = false; + else + return false; + + vector old_cells(module->cells().begin(), module->cells().end()); + if (old_cells.empty()) + return false; + ref_cell = old_cells.front(); + + int cnt_width = ceil_log2_int(width + 1); + int table_size = 1 << cnt_width; + int cmp_width = std::max(GetSize(n), cnt_width); + + auto en_bit = [&](int i) { return sigmap(SigBit(en, i)); }; + + // 1) Inclusive prefix popcount as a naive linear $add cascade. Each + // running sum is consumed below, so the downstream opt_parallel_prefix + // pass rebuilds the cascade into a shared log-depth prefix network. + vector popcount(width); + Cell *cell = ref_cell; + int start = msb_first ? width - 1 : 0; + int step = msb_first ? -1 : 1; + int last = msb_first ? 0 : width - 1; + SigSpec acc = SigSpec(en_bit(start)); + popcount[start] = zext(acc, cnt_width); + for (int i = start + step; i != last + step; i += step) { + Wire *sum = module->addWire(NEW_ID2_SUFFIX("compact_pop"), cnt_width); + module->addAdd(NEW_ID2_SUFFIX("compact_pop_add"), acc, SigSpec(en_bit(i)), sum); + new_cells_emitted++; + acc = SigSpec(sum); + popcount[i] = SigSpec(sum); + } + + // 2) Decode the modulus once: eq_d = (n == d) for d in [1..width]. + vector eqd(width + 1, State::S0); + for (int d = 1; d <= width; d++) + eqd[d] = emit_eq(SigSpec(n), d, cmp_width); + + // 3) Divisibility per popcount value: div_k = OR_{d | k} eq_d. n>0 and + // n>width fall out for free (no divisor in range matches). + vector divisible(table_size, State::S0); + divisible[0] = State::S1; // gated away by en[]; defined to size the table + for (int k = 1; k <= width; k++) { + SigSpec terms; + for (int d = 1; d <= k; d++) + if (k % d == 0) + terms.append(SigSpec(eqd[d])); + divisible[k] = emit_reduce_or(terms); + } + + // 4) Shared divisibility table; select per bit by its popcount value. + SigSpec table; + for (int k = 0; k < table_size; k++) + table.append(SigSpec(divisible[k])); + + SigSpec out_bits; + for (int i = 0; i < width; i++) { + SigBit sel_divisible = emit_bmux(table, popcount[i]); + out_bits.append(emit_and(en_bit(i), sel_divisible)); + } + + module->connect(SigSpec(mask), out_bits); + remove_old_cells(old_cells); + + log(" Modulo decimation: en=%s, n=%s -> %s, width=%d, %s scan.\n", + log_id(en->name), log_id(n->name), log_id(mask->name), width, + msb_first ? "MSB-first" : "LSB-first"); + modulo_rewrites++; + return true; + } + void run() { if (module->has_processes_warn()) return; if (rewrite_forward_dense_pack()) return; - rewrite_reverse_suffix_read(); + if (rewrite_reverse_suffix_read()) + return; + rewrite_modulo_decimation(); } }; @@ -492,9 +700,16 @@ struct OptCompactPrefixPass : public Pass log("lowering of SystemVerilog loops and replace their long loop-carried\n"); log("index/update cones with balanced prefix-count and routing logic.\n"); log("\n"); - log("Currently this pass handles the dense bit-pack and reverse suffix-read\n"); - log("forms used by the qor_spi_ra_add_chain and qor_spi_ra_sub_chain\n"); - log("regressions. Non-matching modules are left unchanged.\n"); + log("Currently this pass handles the dense bit-pack, reverse suffix-read,\n"); + log("and modulo-n decimation forms used by the qor_spi_ra_add_chain,\n"); + log("qor_spi_ra_sub_chain, and qor_spi_ra_add_chain2 regressions.\n"); + log("Non-matching modules are left unchanged.\n"); + log("\n"); + log("The modulo decimation form (mark every n-th enabled bit while scanning\n"); + log("the enable vector) is lowered to a prefix-popcount plus divisor-decode\n"); + log("divisibility check. The popcount is emitted as a plain linear $add\n"); + log("cascade so a subsequent opt_parallel_prefix pass can rebuild it into a\n"); + log("shared log-depth network.\n"); log("\n"); log(" -max_width \n"); log(" Maximum compaction width to rewrite. Default: 64.\n"); @@ -518,6 +733,7 @@ struct OptCompactPrefixPass : public Pass int total_forward = 0; int total_reverse = 0; + int total_modulo = 0; int total_removed = 0; int total_emitted = 0; @@ -526,15 +742,16 @@ struct OptCompactPrefixPass : public Pass worker.run(); total_forward += worker.forward_rewrites; total_reverse += worker.reverse_rewrites; + total_modulo += worker.modulo_rewrites; total_removed += worker.old_cells_removed; total_emitted += worker.new_cells_emitted; } - log("Rewrote %d forward pack(s), %d reverse suffix read(s); " + log("Rewrote %d forward pack(s), %d reverse suffix read(s), %d modulo decimation(s); " "removed %d old cell(s), emitted %d new cell(s).\n", - total_forward, total_reverse, total_removed, total_emitted); + total_forward, total_reverse, total_modulo, total_removed, total_emitted); - if (total_forward || total_reverse) + if (total_forward || total_reverse || total_modulo) Yosys::run_pass("clean -purge"); } } OptCompactPrefixPass; diff --git a/tests/silimate/opt_compact_prefix.ys b/tests/silimate/opt_compact_prefix.ys index 4ed9f0044..c93c6342d 100644 --- a/tests/silimate/opt_compact_prefix.ys +++ b/tests/silimate/opt_compact_prefix.ys @@ -350,3 +350,117 @@ select -assert-none opt_compact_prefix_multi_match/t:$mux select -assert-count 1 opt_compact_prefix_multi_keep/t:$mux design -reset log -pop + +log -header "Modulo decimation self-equivalence (MSB-first)" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod8 +proc; opt_clean +rename opt_compact_prefix_mod8 gold + +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod8 +proc; opt_clean +opt_compact_prefix +opt_clean +bmuxmap +rename opt_compact_prefix_mod8 gate + +miter -equiv -flatten -make_assert gold gate miter +hierarchy -top miter +proc; opt; memory; opt +sat -prove-asserts -verify +design -reset +log -pop + +log -header "Modulo decimation self-equivalence (LSB-first)" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod_lsb8 +proc; opt_clean +rename opt_compact_prefix_mod_lsb8 gold + +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod_lsb8 +proc; opt_clean +opt_compact_prefix +opt_clean +bmuxmap +rename opt_compact_prefix_mod_lsb8 gate + +miter -equiv -flatten -make_assert gold gate miter +hierarchy -top miter +proc; opt; memory; opt +sat -prove-asserts -verify +design -reset +log -pop + +log -header "Modulo decimation structural rewrite" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod16 +proc; opt_clean +opt_compact_prefix +opt_clean +select -assert-none t:$mux +select -assert-none t:$sub +select -assert-min 1 t:$add +select -assert-min 1 t:$eq +select -assert-min 1 t:$bmux +design -reset +log -pop + +log -header "Modulo decimation: opt_parallel_prefix collapses the popcount cascade" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod16 +proc; opt_clean +opt_compact_prefix +opt_clean +opt_parallel_prefix -arith +opt_clean +select -assert-none t:$mux +select -assert-none t:$sub +select -assert-min 1 t:$add +design -reset +log -pop + +log -header "Negative: modulo off-by-one near miss unchanged" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod_offbyone +proc; opt_clean +opt_compact_prefix +select -assert-none w:*compact* +design -reset +log -pop + +log -header "Max width: modulo decimation left unchanged" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_compact_prefix_mod.sv +verific -import opt_compact_prefix_mod16 +proc; opt_clean +select -assert-min 1 t:$mux +opt_compact_prefix -max_width 8 +select -assert-min 1 t:$mux +select -assert-none w:*compact* +design -reset +log -pop diff --git a/tests/silimate/opt_compact_prefix_mod.sv b/tests/silimate/opt_compact_prefix_mod.sv new file mode 100644 index 000000000..9fbc0001e --- /dev/null +++ b/tests/silimate/opt_compact_prefix_mod.sv @@ -0,0 +1,86 @@ +// Modulo-n decimation loops: scanning the enable vector, mark every n-th +// enabled bit. Exercised by opt_compact_prefix's modulo decimation rewrite +// (cf. the qor_spi_ra_add_chain2 regression). + +module opt_compact_prefix_mod8 ( + input logic [7:0] en, + input logic [3:0] n, + output logic [7:0] mask +); + always_comb begin + mask = '0; + for (int I = 7, cnt = 0; I >= 0; I--) begin + if (en[I] && (n > 0)) begin + if (cnt == (n - 1)) begin + mask[I] = 1'b1; + cnt = 0; + end else begin + cnt++; + end + end + end + end +endmodule + +module opt_compact_prefix_mod16 ( + input logic [15:0] en, + input logic [4:0] n, + output logic [15:0] mask +); + always_comb begin + mask = '0; + for (int I = 15, cnt = 0; I >= 0; I--) begin + if (en[I] && (n > 0)) begin + if (cnt == (n - 1)) begin + mask[I] = 1'b1; + cnt = 0; + end else begin + cnt++; + end + end + end + end +endmodule + +// Same function, but scanned LSB-first (exercises the mirrored direction). +module opt_compact_prefix_mod_lsb8 ( + input logic [7:0] en, + input logic [3:0] n, + output logic [7:0] mask +); + always_comb begin + mask = '0; + for (int I = 0, cnt = 0; I < 8; I++) begin + if (en[I] && (n > 0)) begin + if (cnt == (n - 1)) begin + mask[I] = 1'b1; + cnt = 0; + end else begin + cnt++; + end + end + end + end +endmodule + +// Negative near-miss: marks every (n+1)-th enabled bit (reset on cnt == n), +// a different function that must NOT be rewritten. +module opt_compact_prefix_mod_offbyone ( + input logic [7:0] en, + input logic [3:0] n, + output logic [7:0] mask +); + always_comb begin + mask = '0; + for (int I = 7, cnt = 0; I >= 0; I--) begin + if (en[I] && (n > 0)) begin + if (cnt == n) begin + mask[I] = 1'b1; + cnt = 0; + end else begin + cnt++; + end + end + end + end +endmodule