From 6daa8a01edcc8388e729ef878a553a9f8322ff62 Mon Sep 17 00:00:00 2001 From: Akash Levy Date: Tue, 7 Apr 2026 22:07:14 -0700 Subject: [PATCH] opt_vps improvements for VPS read --- passes/silimate/opt_vps.cc | 1166 +++++++++++++++++++++++++++- tests/silimate/opt_vps.ys | 82 ++ tests/silimate/opt_vps_read.sv | 17 + tests/silimate/opt_vps_read_ref.sv | 16 + 4 files changed, 1259 insertions(+), 22 deletions(-) create mode 100644 tests/silimate/opt_vps_read.sv create mode 100644 tests/silimate/opt_vps_read_ref.sv diff --git a/passes/silimate/opt_vps.cc b/passes/silimate/opt_vps.cc index 6569bd8bd..a42f039eb 100644 --- a/passes/silimate/opt_vps.cc +++ b/passes/silimate/opt_vps.cc @@ -1,8 +1,7 @@ /* * yosys -- Yosys Open SYnthesis Suite * - * Copyright (C) 2012 Claire Xenia Wolf - * 2025 Silimate Inc. + * Copyright (C) 2025 Silimate Inc. * * Permission to use, copy, modify, and/or distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -45,7 +44,9 @@ struct OptVpsWorker int pmux_replaced = 0; int reduce_or_replaced = 0; int feedback_collapsed = 0; + int vps_reads_replaced = 0; int min_stride; + pool vps_shr_cells; OptVpsWorker(Module *module, int min_stride) : module(module), sigmap(module), min_stride(min_stride) @@ -148,6 +149,113 @@ struct OptVpsWorker return -1; } + // Extract the constant addend from a binary_index signal. + // If binary_index = $add(dynamic, C) or $add(C, dynamic), + // return C. Otherwise return 0. Handles chains of + // $add/$sub up to 8 levels deep. + // Evaluate a signal assuming all primary inputs are 0. + // Uses recursive constant propagation through the driver + // graph. Handles $add, $sub, $not, $and, $or, $xor, $shl, + // $shr, $neg and pass-through (no driver → 0). + int eval_at_zero(SigSpec sig) + { + dict cache; + int result = 0; + for (int i = 0; i < GetSize(sig); i++) { + int v = eval_bit_at_zero(sigmap(sig[i]), cache, 0); + result |= (v << i); + } + return result; + } + + int eval_bit_at_zero(SigBit b, dict &cache, int depth) + { + if (b == State::S0) return 0; + if (b == State::S1) return 1; + if (!b.wire) return 0; + + auto it = cache.find(b); + if (it != cache.end()) return it->second; + if (depth > 64) return 0; + + cache[b] = 0; + + Cell *drv = bit_drivers.at(b, nullptr); + if (!drv) return 0; + + int bit_pos = -1; + SigSpec dy = drv->getPort(ID::Y); + for (int j = 0; j < GetSize(dy); j++) + if (sigmap(dy[j]) == b) { bit_pos = j; break; } + if (bit_pos < 0) return 0; + + int64_t av = 0, bv = 0; + SigSpec pa = drv->getPort(ID::A); + SigSpec pb = drv->hasPort(ID::B) ? drv->getPort(ID::B) : SigSpec(); + for (int i = 0; i < GetSize(pa) && i < 62; i++) + av |= ((int64_t)eval_bit_at_zero(sigmap(pa[i]), cache, + depth+1) << i); + for (int i = 0; i < GetSize(pb) && i < 62; i++) + bv |= ((int64_t)eval_bit_at_zero(sigmap(pb[i]), cache, + depth+1) << i); + + int64_t rv = 0; + if (drv->type == ID($add)) rv = av + bv; + else if (drv->type == ID($sub)) rv = av - bv; + else if (drv->type == ID($mul)) rv = av * bv; + else if (drv->type == ID($and) || drv->type == ID($_AND_)) rv = av & bv; + else if (drv->type == ID($or) || drv->type == ID($_OR_)) rv = av | bv; + else if (drv->type == ID($xor) || drv->type == ID($_XOR_)) rv = av ^ bv; + else if (drv->type == ID($not) || drv->type == ID($_NOT_)) rv = ~av; + else if (drv->type == ID($neg)) rv = -av; + else if (drv->type == ID($shl) || drv->type == ID($sshl)) rv = av << bv; + else if (drv->type == ID($shr) || drv->type == ID($sshr)) rv = av >> bv; + else if (drv->type == ID($mux)) { + SigSpec sp = drv->getPort(ID::S); + int sv = eval_bit_at_zero(sigmap(sp[0]), cache, depth+1); + rv = sv ? bv : av; + } else { + cache[b] = 0; + return 0; + } + + int val = (rv >> bit_pos) & 1; + cache[b] = val; + return val; + } + + // Trace a signal back through the driver graph to find the + // set of root bits (primary inputs / FF outputs) that + // influence it. Returns them as a sorted SigSpec. + SigSpec trace_input_roots(SigSpec sig) + { + pool roots, visited; + std::vector worklist; + for (auto bit : sig) + worklist.push_back(sigmap(bit)); + while (!worklist.empty()) { + SigBit b = worklist.back(); + worklist.pop_back(); + if (!visited.insert(b).second) + continue; + Cell *drv = bit_drivers.at(b, nullptr); + if (!drv) { + if (b.wire) + roots.insert(b); + continue; + } + for (auto &conn : drv->connections()) + if (!drv->output(conn.first)) + for (auto bit2 : conn.second) + worklist.push_back(sigmap(bit2)); + } + SigSpec result; + for (auto b : roots) + result.append(b); + result.sort(); + return result; + } + void run() { std::vector decoders; @@ -155,8 +263,1017 @@ struct OptVpsWorker if (is_decoder_shl(cell)) decoders.push_back(cell); + // --- Cross-decoder VPS read merge --- + // Collect stride-1 VPS read candidates across ALL decoders. + // Group by the underlying SOURCE REGISTER (identified by the + // wire of the first reconstructed source bit). When multiple + // reads extract adjacent byte lanes from the same register + // with verified data overlap, merge them into one wider + // barrel shifter. + + struct XReadCandidate { + Cell *decoder; + Cell *pmux; + int W; + int base; + int valid_n; + std::vector s_indices; + Wire *src_wire; // register wire (from first source bit) + int src_offset; // offset of first source bit in that wire + int idx_const; // constant part of decoder binary_index + SigSpec idx_roots; // primary input bits influencing decoder + }; + + std::vector all_reads; + + for (auto decoder : decoders) { + SigSpec decoder_y = decoder->getPort(ID::Y); + + for (auto cell : module->selected_cells()) { + if (cell->type != ID($pmux)) + continue; + int W = cell->getParam(ID::WIDTH).as_int(); + if (W <= 1) continue; + SigSpec sig_a = cell->getPort(ID::A); + if (!sig_a.is_fully_zero()) continue; + + SigSpec sig_s = cell->getPort(ID::S); + SigSpec sig_b = cell->getPort(ID::B); + int N = GetSize(sig_s); + + std::vector dec_positions, s_indices; + for (int i = 0; i < N; i++) { + SigBit sb = sigmap(sig_s[i]); + if (sb == State::S0) continue; + int pos = trace_to_decoder_pos(sig_s[i], decoder_y); + if (pos < 0) break; + dec_positions.push_back(pos); + s_indices.push_back(i); + } + if (GetSize(dec_positions) < 2) continue; + + bool contiguous = true; + for (int i = 1; i < GetSize(dec_positions); i++) + if (dec_positions[i] != dec_positions[0] + i) + { contiguous = false; break; } + if (!contiguous) continue; + + int sliding_n = 1; + for (int k = 0; k < GetSize(s_indices) - 1; k++) { + int si_cur = s_indices[k], si_nxt = s_indices[k + 1]; + bool ok = true; + for (int j = 1; j < W && ok; j++) + if (sigmap(sig_b[si_cur * W + j]) != + sigmap(sig_b[si_nxt * W + (j - 1)])) + ok = false; + if (!ok) break; + sliding_n = k + 2; + } + if (sliding_n < 2) continue; + + int base = dec_positions[0]; + + // Find first source bit with a valid wire (skip + // don't-care bits that arise when base < W-1) + Wire *reg_wire = nullptr; + int reg_offset = -1; + for (int k = 0; k < sliding_n + W - 1 && !reg_wire; k++) { + int idx = std::min(k, sliding_n - 1); + int si = s_indices[idx]; + int j = k - idx; + if (j >= W) break; + SigBit sb = sigmap(sig_b[si * W + j]); + if (sb.wire) { + reg_wire = sb.wire; + reg_offset = sb.offset - k; + } + } + if (!reg_wire) continue; + + SigSpec binary_idx = decoder->getPort(ID::B); + SigSpec roots = trace_input_roots(binary_idx); + int idx_c = eval_at_zero(binary_idx); + + all_reads.push_back({decoder, cell, W, base, sliding_n, + {s_indices.begin(), s_indices.begin() + sliding_n}, + reg_wire, reg_offset, idx_c, roots}); + } + } + + // Group by (source wire, width, index root bits) + struct SrcKey { + Wire *wire; + int W; + SigSpec roots; + bool operator<(const SrcKey &o) const { + if (wire != o.wire) return wire < o.wire; + if (W != o.W) return W < o.W; + return roots < o.roots; + } + }; + std::map> src_groups; + for (int i = 0; i < GetSize(all_reads); i++) + src_groups[{all_reads[i].src_wire, all_reads[i].W, + all_reads[i].idx_roots}].push_back(i); + + // Compute effective register offset: for reads sharing + // the same data port but using different decoders with + // different constant offsets in binary_index, the + // effective offset = src_offset + idx_const. This lets + // reads like source[(n+31)-:32] and source[(n+63)-:32] + // be recognized as accessing adjacent 32-bit windows. + auto eff_offset = [&](const XReadCandidate &r) -> int { + return r.src_offset + r.idx_const; + }; + + for (auto &[key, indices] : src_groups) { + if (GetSize(indices) < 2) continue; + int W0 = key.W; + + // Sort by effective offset (data position in register) + std::sort(indices.begin(), indices.end(), + [&](int a, int b) { + return eff_offset(all_reads[a]) < eff_offset(all_reads[b]); + }); + + + + // Find maximal contiguous runs where effective offset + // values differ by exactly W + int run_start = 0; + while (run_start < GetSize(indices)) { + int run_end = run_start + 1; + while (run_end < GetSize(indices) && + eff_offset(all_reads[indices[run_end]]) == + eff_offset(all_reads[indices[run_end-1]]) + W0) + run_end++; + + int run_len = run_end - run_start; + if (run_len < 2) { + run_start = run_end; + continue; + } + + // Check if reads come from different decoders + bool different_decoders = false; + for (int ri = run_start + 1; ri < run_end; ri++) + if (all_reads[indices[ri]].idx_const != + all_reads[indices[run_start]].idx_const) + { different_decoders = true; break; } + + auto &lowest = all_reads[indices[run_start]]; + int combined_W = W0 * run_len; + + if (different_decoders) { + // Different decoders: build source from + // the register wire directly and compute + // shift from the dynamic part of binary_index. + Wire *reg = lowest.src_wire; + int lowest_eff = eff_offset(lowest); + + // The source covers the register range + // needed by all reads across all dynamic + // index values. valid_n gives the decoder + // range; combined_W is the output width. + int src_start = std::max(0, lowest_eff); + int src_end = std::min(GetSize(reg) - 1, + lowest_eff + lowest.valid_n + + combined_W - 2); + if (src_end < src_start) { + run_start = run_end; + continue; + } + SigSpec source = SigSpec(reg, src_start, + src_end - src_start + 1); + + // Shift amount: the dynamic part. For + // binary_index = n + idx_const and VPS + // [idx -: W], the register position of + // output bit 0 is idx - W + 1 = n + + // idx_const - W + 1. The lowest eff_offset + // is src_offset + idx_const_lowest, and + // register bit at eff_offset is at source + // position eff_offset - src_start. So + // shift = (n + idx_const - W + 1 - + // src_offset) - src_start. Since + // eff_offset = src_offset + idx_const, + // shift = n - W + 1 + eff_offset - + // src_start = n - W + 1 + lowest_eff - + // src_start. + // + // Equivalently, shift = binary_index - + // (W - 1) - src_offset - src_start. Since + // binary_index = decoder B port and + // src_offset = lowest.src_offset: + SigSpec binary_index = lowest.decoder->getPort(ID::B); + int base_sub = (W0 - 1) + src_start; + + SigSpec raw_idx; + if (base_sub != 0) { + Wire *sub_w = module->addWire( + NEW_ID_SUFFIX("vps_merge_idx"), + GetSize(binary_index)); + module->addSub(NEW_ID_SUFFIX("vps_merge_sub"), + binary_index, + Const(base_sub, GetSize(binary_index)), + sub_w); + raw_idx = SigSpec(sub_w); + } else { + raw_idx = binary_index; + } + + // Detect alignment: look for constant + // lower bits in the shift expression, + // tracing through $add/$sub if needed + // (mirrors process_vps_reads). + auto count_const_lower = [&](SigSpec sig) -> std::pair { + int count = 0, value = 0; + for (int i = 0; i < GetSize(sig); i++) { + SigBit b = sigmap(sig[i]); + if (b == State::S0) count++; + else if (b == State::S1) { value |= (1 << i); count++; } + else break; + } + return {count, value}; + }; + + int log2_align = 0; + int fixed_lower = 0; + { + auto [n0, v0] = count_const_lower(binary_index); + if (n0 > 0) { log2_align = n0; fixed_lower = v0; } + if (log2_align == 0) { + Cell *drv = nullptr; + for (int i = 0; i < GetSize(binary_index); i++) { + Cell *d = bit_drivers.at(sigmap(binary_index[i]), nullptr); + if (!d) { drv = nullptr; break; } + if (!drv) drv = d; + else if (drv != d) { drv = nullptr; break; } + } + if (drv && (drv->type == ID($add) || drv->type == ID($sub))) { + SigSpec aa = drv->getPort(ID::A); + SigSpec ab = drv->getPort(ID::B); + SigSpec non_const; + int offset = 0; + bool is_sub = (drv->type == ID($sub)); + if (aa.is_fully_const()) { offset = aa.as_int(); non_const = ab; } + else if (ab.is_fully_const()) { offset = ab.as_int(); non_const = aa; } + if (GetSize(non_const) > 0) { + auto [nc, nv] = count_const_lower(non_const); + if (nc > 0) { + log2_align = nc; + int mask = (1 << nc) - 1; + if (is_sub) { + if (non_const == ab) + fixed_lower = ((offset & mask) - nv) & mask; + else + fixed_lower = (nv - (offset & mask)) & mask; + } else { + fixed_lower = (nv + (offset & mask)) & mask; + } + } + } + } + } + } + + SigSpec shift_amount; + if (log2_align > 0) { + int adj_lower = (fixed_lower - (base_sub & ((1 << log2_align) - 1))) + & ((1 << log2_align) - 1); + for (int i = 0; i < log2_align; i++) + shift_amount.append((adj_lower >> i) & 1 ? State::S1 : State::S0); + shift_amount.append(raw_idx.extract( + log2_align, GetSize(raw_idx) - log2_align)); + } else { + shift_amount = raw_idx; + } + + Wire *merged_y = module->addWire( + NEW_ID_SUFFIX("vps_merge_y"), combined_W); + Cell *shr = module->addShr(NEW_ID_SUFFIX("vps_merge_shr"), + source, shift_amount, SigSpec(merged_y)); + shr->add_strpool_attribute(ID::src, + lowest.pmux->get_strpool_attribute(ID::src)); + vps_shr_cells.insert(shr); + + int lowest_eff_off = eff_offset(lowest); + for (int i = 0; i < run_len; i++) { + auto &r = all_reads[indices[run_start + i]]; + SigSpec pmux_y = r.pmux->getPort(ID::Y); + int byte_offset = eff_offset(r) - lowest_eff_off; + module->connect(pmux_y, + SigSpec(merged_y, byte_offset, r.W)); + + log(" VPS xread merge: pmux %s (W=%d, eff=%d)" + " -> merged $shr [%d:%d]\n", + log_id(r.pmux->name), r.W, eff_offset(r), + byte_offset + r.W - 1, byte_offset); + + module->remove(r.pmux); + pmux_replaced++; + vps_reads_replaced++; + } + + log(" VPS merged %d cross-decoder reads -> $shr" + " WIDTH=%d, src=%d, align=%d\n", + run_len, combined_W, GetSize(source), + log2_align > 0 ? (1 << log2_align) : 1); + groups_optimized++; + run_start = run_end; + continue; + } + + // Same decoder: verify data overlap between + // adjacent reads. + // source[k] = sig_b[s_indices[min(k,vn-1)] * W + k - min(k,vn-1)] + bool overlap_ok = true; + for (int ri = run_start; ri < run_end - 1 && overlap_ok; ri++) { + auto &rA = all_reads[indices[ri]]; + auto &rB = all_reads[indices[ri + 1]]; + SigSpec bA = rA.pmux->getPort(ID::B); + SigSpec bB = rB.pmux->getPort(ID::B); + int check_len = std::min(W0, + std::min(rA.valid_n, rB.valid_n)); + for (int c = 0; c < check_len && overlap_ok; c++) { + int posA = c + W0; + int idxA = std::min(posA, rA.valid_n - 1); + int jA = posA - idxA; + int idxB = std::min(c, rB.valid_n - 1); + int jB = c - idxB; + if (jA >= W0 || jB >= W0) break; + if (sigmap(bA[rA.s_indices[idxA] * W0 + jA]) != + sigmap(bB[rB.s_indices[idxB] * W0 + jB])) + overlap_ok = false; + } + } + + if (!overlap_ok) { + log(" Run [%d..%d]: overlap check failed\n", + run_start, run_end - 1); + run_start = run_end; + continue; + } + + // --- Merge this run into one wider barrel shifter --- + int lowest_base = lowest.base; + + Cell *ref_pmux = lowest.pmux; + SigSpec ref_sig_b = ref_pmux->getPort(ID::B); + int ref_valid_n = lowest.valid_n; + int source_width = ref_valid_n + combined_W - 1; + SigSpec source; + for (int k = 0; k < source_width; k++) { + int idx = std::min(k, ref_valid_n - 1); + int si = lowest.s_indices[idx]; + int j = k - idx; + if (j < W0) + source.append(sigmap(ref_sig_b[si * W0 + j])); + else { + int extra = j - W0; + int next_ri = run_start + 1 + extra / W0; + if (next_ri < run_end) { + auto &rN = all_reads[indices[next_ri]]; + SigSpec bN = rN.pmux->getPort(ID::B); + int siN = rN.s_indices[std::min(idx, rN.valid_n - 1)]; + int jN = extra % W0; + source.append(sigmap(bN[siN * W0 + jN])); + } else { + source.append(State::S0); + } + } + } + + // Truncate to actual register range + Wire *merge_reg = nullptr; + for (int i = 0; i < GetSize(source); i++) { + SigBit b = source[i]; + if (b.wire) { merge_reg = b.wire; break; } + } + if (merge_reg) { + int first_reg = -1, last_reg = -1; + for (int i = 0; i < GetSize(source); i++) { + SigBit b = source[i]; + if (b.wire == merge_reg) { + if (first_reg < 0) first_reg = i; + last_reg = i; + } + } + if (first_reg > 0 || last_reg < GetSize(source) - 1) { + int new_len = last_reg - first_reg + 1; + source = source.extract(first_reg, new_len); + lowest_base += first_reg; + } + } + + SigSpec binary_index = lowest.decoder->getPort(ID::B); + SigSpec shift_amount; + SigSpec raw_idx = binary_index; + if (lowest_base > 0) { + Wire *sub_w = module->addWire( + NEW_ID_SUFFIX("vps_merge_idx"), + GetSize(binary_index)); + module->addSub(NEW_ID_SUFFIX("vps_merge_sub"), + binary_index, + Const(lowest_base, GetSize(binary_index)), + sub_w); + raw_idx = SigSpec(sub_w); + } + + auto count_const_lower = [&](SigSpec sig) -> std::pair { + int count = 0, value = 0; + for (int i = 0; i < GetSize(sig); i++) { + SigBit b = sigmap(sig[i]); + if (b == State::S0) count++; + else if (b == State::S1) { value |= (1 << i); count++; } + else break; + } + return {count, value}; + }; + + int log2_align = 0; + int fixed_lower = 0; + { + auto [n0, v0] = count_const_lower(binary_index); + if (n0 > 0) { log2_align = n0; fixed_lower = v0; } + if (log2_align == 0) { + Cell *drv = nullptr; + for (int i = 0; i < GetSize(binary_index); i++) { + Cell *d = bit_drivers.at(sigmap(binary_index[i]), nullptr); + if (!d) { drv = nullptr; break; } + if (!drv) drv = d; + else if (drv != d) { drv = nullptr; break; } + } + if (drv && (drv->type == ID($add) || drv->type == ID($sub))) { + SigSpec aa = drv->getPort(ID::A); + SigSpec ab = drv->getPort(ID::B); + SigSpec non_const; + int offset = 0; + bool is_sub = (drv->type == ID($sub)); + if (aa.is_fully_const()) { offset = aa.as_int(); non_const = ab; } + else if (ab.is_fully_const()) { offset = ab.as_int(); non_const = aa; } + if (GetSize(non_const) > 0) { + auto [nc, nv] = count_const_lower(non_const); + if (nc > 0) { + log2_align = nc; + int mask = (1 << nc) - 1; + if (is_sub) { + if (non_const == ab) + fixed_lower = ((offset & mask) - nv) & mask; + else + fixed_lower = (nv - (offset & mask)) & mask; + } else { + fixed_lower = (nv + (offset & mask)) & mask; + } + } + } + } + } + } + + if (log2_align > 0) { + int adj_lower = (fixed_lower - (lowest_base & ((1 << log2_align) - 1))) + & ((1 << log2_align) - 1); + for (int i = 0; i < log2_align; i++) + shift_amount.append((adj_lower >> i) & 1 ? State::S1 : State::S0); + shift_amount.append(raw_idx.extract( + log2_align, GetSize(binary_index) - log2_align)); + } else { + shift_amount = raw_idx; + } + + Wire *merged_y = module->addWire( + NEW_ID_SUFFIX("vps_merge_y"), combined_W); + Cell *shr = module->addShr(NEW_ID_SUFFIX("vps_merge_shr"), + source, shift_amount, SigSpec(merged_y)); + shr->add_strpool_attribute(ID::src, + ref_pmux->get_strpool_attribute(ID::src)); + vps_shr_cells.insert(shr); + + int lowest_eff_off = eff_offset(lowest); + for (int i = 0; i < run_len; i++) { + auto &r = all_reads[indices[run_start + i]]; + SigSpec pmux_y = r.pmux->getPort(ID::Y); + int byte_offset = eff_offset(r) - lowest_eff_off; + module->connect(pmux_y, + SigSpec(merged_y, byte_offset, r.W)); + + log(" VPS read merge: pmux %s (WIDTH=%d, base=%d)" + " -> merged $shr [%d:%d]\n", + log_id(r.pmux->name), r.W, r.base, + byte_offset + r.W - 1, byte_offset); + + module->remove(r.pmux); + pmux_replaced++; + vps_reads_replaced++; + } + + log(" VPS merged %d reads (WIDTH=%d each) -> $shr WIDTH=%d%s\n", + run_len, W0, combined_W, + log2_align > 0 ? + stringf(", align=%d", 1 << log2_align).c_str() : ""); + groups_optimized++; + + run_start = run_end; + } + } + + // Process remaining decoders normally (for VPS writes and + // unmerged VPS reads — merged reads' $pmux cells were + // already removed, so they won't be found again) for (auto decoder : decoders) process_decoder(decoder); + + // --- Shared barrel shifter merge --- + // After all VPS reads have been converted to $shr cells, + // find groups that access the same register with byte-aligned + // shifts sharing the same dynamic index variable. Replace + // each group with a single barrel shifter whose output feeds + // all reads via simple wire slices. + merge_shared_barrel_shifters(); + } + + void merge_shared_barrel_shifters() + { + // Rebuild bit_drivers to include cells created during + // process_vps_reads (e.g. $sub cells for index adjustment) + bit_drivers.clear(); + for (auto cell : module->cells()) + for (auto &conn : cell->connections()) + if (cell->output(conn.first)) + for (int i = 0; i < GetSize(conn.second); i++) { + SigBit bit = sigmap(conn.second[i]); + bit_drivers[bit] = cell; + } + + struct ShrInfo { + Cell *shr; + Wire *reg_wire; + int reg_offset; // bit offset of read within register + int output_width; + SigSpec shift_variable; // the variable (non-constant) upper shift bits + int const_shift_lower; // constant value of lower shift bits + int shift_align; // number of constant lower shift bits + }; + + std::vector shr_infos; + for (auto *shr : vps_shr_cells) { + if (!shr->type.in(ID($shr))) + continue; + + SigSpec source = shr->getPort(ID::A); + SigSpec shift = shr->getPort(ID::B); + SigSpec output = shr->getPort(ID::Y); + int out_w = GetSize(output); + + // Find the register wire: all source bits must + // come from the same wire + Wire *reg_wire = nullptr; + bool single_wire = true; + for (auto b : source) { + if (!b.wire) continue; + if (!reg_wire) reg_wire = b.wire; + else if (b.wire != reg_wire) { single_wire = false; break; } + } + if (!reg_wire || !single_wire) + continue; + + // Determine the register bit offset: position of + // output bit 0 relative to the register when the + // variable part of the shift is zero + // For the source, find the offset of source[0] within reg_wire + int src_base = -1; + for (int i = 0; i < GetSize(source); i++) { + SigBit b = source[i]; + if (b.wire == reg_wire) { + src_base = b.offset - i; + break; + } + } + if (src_base < 0) continue; + + // Count constant lower shift bits + int shift_align = 0; + int const_lower = 0; + for (int i = 0; i < GetSize(shift); i++) { + SigBit b = sigmap(shift[i]); + if (b == State::S0) + shift_align++; + else if (b == State::S1) + { const_lower |= (1 << i); shift_align++; } + else + break; + } + + SigSpec shift_var = shift.extract(shift_align, + GetSize(shift) - shift_align); + + // reg_offset: register position of output bit 0 when + // all dynamic inputs are zero. Evaluated by + // constant-propagating the full shift signal. + int shift_at_zero = eval_at_zero(shift); + int reg_offset = src_base + shift_at_zero; + + shr_infos.push_back({shr, reg_wire, reg_offset, + out_w, shift_var, const_lower, shift_align}); + } + + if (shr_infos.empty()) + return; + + // Group by (register wire, input root bits of shift + // variable, alignment). Using trace_input_roots lets + // reads with different carry patterns but the same + // underlying dynamic variable group together. + struct MergeKey { + Wire *wire; + SigSpec roots; + int align; + bool operator<(const MergeKey &o) const { + if (wire != o.wire) return wire < o.wire; + if (align != o.align) return align < o.align; + return roots < o.roots; + } + }; + + std::map> groups; + for (int i = 0; i < GetSize(shr_infos); i++) { + auto &info = shr_infos[i]; + SigSpec roots = trace_input_roots(info.shift_variable); + groups[{info.reg_wire, roots, info.shift_align}].push_back(i); + } + + for (auto &[key, indices] : groups) { + if (GetSize(indices) < 2) + continue; + + Wire *reg = key.wire; + int reg_width = reg->width; + int align = key.align; + + // Find the reference read (lowest reg_offset) + int ref_idx = indices[0]; + for (int idx : indices) + if (shr_infos[idx].reg_offset < shr_infos[ref_idx].reg_offset) + ref_idx = idx; + auto &ref_info = shr_infos[ref_idx]; + int ref_offset = ref_info.reg_offset; + + // Use the reference read's full shift signal + // for the shared barrel shifter + SigSpec ref_shift = ref_info.shr->getPort(ID::B); + SigSpec reg_source(reg); + + Wire *shared_y = module->addWire( + NEW_ID_SUFFIX("vps_shared_y"), reg_width); + Cell *shared_shr = module->addShr( + NEW_ID_SUFFIX("vps_shared_shr"), + reg_source, ref_shift, SigSpec(shared_y)); + shared_shr->add_strpool_attribute(ID::src, + ref_info.shr->get_strpool_attribute(ID::src)); + + log(" VPS shared barrel shifter: %s (reg=%s, width=%d, " + "align=%d, serves %d reads, ref_offset=%d)\n", + log_id(shared_shr->name), log_id(reg->name), + reg_width, 1 << align, GetSize(indices), ref_offset); + + for (int idx : indices) { + auto &info = shr_infos[idx]; + SigSpec orig_y = info.shr->getPort(ID::Y); + int off = info.reg_offset - ref_offset; + if (off < 0) off = 0; + + SigSpec slice; + for (int j = 0; j < info.output_width; j++) { + int pos = off + j; + if (pos >= 0 && pos < reg_width) + slice.append(SigBit(shared_y, pos)); + else + slice.append(State::S0); + } + module->connect(orig_y, slice); + + log(" read %s: WIDTH=%d, reg_offset=%d, " + "slice_offset=%d -> shared[%d:%d]\n", + log_id(info.shr->name), info.output_width, + info.reg_offset, off, + off, off + info.output_width - 1); + + module->remove(info.shr); + } + + groups_optimized++; + } + } + + void process_vps_reads(Cell *decoder) + { + SigSpec decoder_y = decoder->getPort(ID::Y); + SigSpec binary_index = decoder->getPort(ID::B); + + struct ReadCandidate { + Cell *cell; + std::vector dec_positions; + std::vector s_indices; + int valid_n; + bool strided; + }; + std::vector read_candidates; + + for (auto cell : module->selected_cells()) { + if (cell->type != ID($pmux)) + continue; + int W = cell->getParam(ID::WIDTH).as_int(); + if (W <= 1) + continue; + + SigSpec sig_a = cell->getPort(ID::A); + if (!sig_a.is_fully_zero()) + continue; + + SigSpec sig_s = cell->getPort(ID::S); + SigSpec sig_b = cell->getPort(ID::B); + int N = GetSize(sig_s); + + // Trace S bits to decoder positions, skipping constant-zero + // padding bits (Verific may insert zeros between one-hot bits + // and append overflow bits at the MSB). + std::vector dec_positions; + std::vector s_indices; + for (int i = 0; i < N; i++) { + SigBit sb = sigmap(sig_s[i]); + if (sb == State::S0) + continue; + int pos = trace_to_decoder_pos(sig_s[i], decoder_y); + if (pos < 0) + break; + dec_positions.push_back(pos); + s_indices.push_back(i); + } + if (GetSize(dec_positions) < 2) + continue; + + // Check that decoder positions are contiguous + bool contiguous = true; + for (int i = 1; i < GetSize(dec_positions); i++) { + if (dec_positions[i] != dec_positions[0] + i) { + contiguous = false; + break; + } + } + if (!contiguous) + continue; + + // Check for sliding window (stride-1) pattern in B + int sliding_n = 1; + for (int k = 0; k < GetSize(s_indices) - 1; k++) { + int si_cur = s_indices[k]; + int si_nxt = s_indices[k + 1]; + bool ok = true; + for (int j = 1; j < W && ok; j++) + if (sigmap(sig_b[si_cur * W + j]) != sigmap(sig_b[si_nxt * W + (j - 1)])) + ok = false; + if (!ok) + break; + sliding_n = k + 2; + } + + if (sliding_n >= 2) { + read_candidates.push_back({cell, dec_positions, s_indices, sliding_n, false}); + continue; + } + + // No stride-1 overlap; fall back to general window + // packing (stride=W) which works for any W that is a + // power of 2. The packed source is built by + // concatenating B-port windows for each valid select + // line; the $shr indexes it with binary_index << log2(W). + bool strided_ok = (W & (W - 1)) == 0 && GetSize(s_indices) >= 2; + if (strided_ok) + read_candidates.push_back({cell, dec_positions, s_indices, GetSize(s_indices), true}); + } + + for (auto &rc : read_candidates) { + Cell *cell = rc.cell; + int W = cell->getParam(ID::WIDTH).as_int(); + SigSpec sig_b = cell->getPort(ID::B); + SigSpec sig_y = cell->getPort(ID::Y); + int full_s = GetSize(cell->getPort(ID::S)); + int valid_n = rc.valid_n; + int base = rc.dec_positions[0]; + + // Detect if binary_index has constant lower bits. + // Verific encodes VPS reads as `source[(idx + W-1) -: W]` + // where idx = stride * k. The decoder's B port is then + // `stride * k + offset`, driven by a chain of + // $mul(stride, k) → $add(., offset) + // When stride is a power of 2, the lower log2(stride) + // bits of (stride*k + offset) are the constant value + // (offset & (stride-1)). Making those bits structural + // constants lets techmap's constmap skip the + // corresponding barrel-shifter stages. + int log2_align = 0; + int fixed_lower = 0; + { + auto count_const_lower_bits = [&](SigSpec sig) -> std::pair { + int count = 0, value = 0; + for (int i = 0; i < GetSize(sig); i++) { + SigBit b = sigmap(sig[i]); + if (b == State::S0) + count++; + else if (b == State::S1) { + value |= (1 << i); + count++; + } else + break; + } + return {count, value}; + }; + + auto find_sole_driver = [&](SigSpec sig) -> Cell * { + Cell *drv = nullptr; + for (int i = 0; i < GetSize(sig); i++) { + Cell *d = bit_drivers.at(sigmap(sig[i]), nullptr); + if (!d) return nullptr; + if (!drv) drv = d; + else if (drv != d) return nullptr; + } + return drv; + }; + + auto [n0, v0] = count_const_lower_bits(binary_index); + if (n0 > 0) { + log2_align = n0; + fixed_lower = v0; + } + + if (log2_align == 0) { + Cell *drv = find_sole_driver(binary_index); + if (drv && (drv->type == ID($add) || drv->type == ID($sub))) { + SigSpec aa = drv->getPort(ID::A); + SigSpec ab = drv->getPort(ID::B); + SigSpec non_const; + int offset = 0; + bool is_sub = (drv->type == ID($sub)); + if (aa.is_fully_const()) { + offset = aa.as_int(); + non_const = ab; + } else if (ab.is_fully_const()) { + offset = ab.as_int(); + non_const = aa; + } + if (GetSize(non_const) > 0) { + auto [nc, nv] = count_const_lower_bits(non_const); + if (nc > 0) { + log2_align = nc; + int mask = (1 << nc) - 1; + if (is_sub) { + if (non_const == ab) + fixed_lower = ((offset & mask) - nv) & mask; + else + fixed_lower = (nv - (offset & mask)) & mask; + } else { + fixed_lower = (nv + (offset & mask)) & mask; + } + } + } + } + } + } + + int src_bits = 0; + if (!rc.strided) { + // Stride-1: reconstruct source from overlapping windows + int source_width = valid_n + W - 1; + SigSpec source; + for (int k = 0; k < source_width; k++) { + int idx = std::min(k, valid_n - 1); + int si = rc.s_indices[idx]; + int j = k - idx; + source.append(sigmap(sig_b[si * W + j])); + } + + // Identify the register wire from the source bits + // and truncate to the actual register range + Wire *reg_wire = nullptr; + int reg_lo = INT_MAX, reg_hi = INT_MIN; + for (int i = 0; i < GetSize(source); i++) { + SigBit b = source[i]; + if (b.wire) { + if (!reg_wire) reg_wire = b.wire; + if (b.wire == reg_wire) { + reg_lo = std::min(reg_lo, b.offset); + reg_hi = std::max(reg_hi, b.offset); + } + } + } + if (reg_wire) { + int first_reg = -1, last_reg = -1; + for (int i = 0; i < GetSize(source); i++) { + SigBit b = source[i]; + if (b.wire == reg_wire) { + if (first_reg < 0) first_reg = i; + last_reg = i; + } + } + if (first_reg > 0 || last_reg < GetSize(source) - 1) { + int new_len = last_reg - first_reg + 1; + source = source.extract(first_reg, new_len); + base += first_reg; + } + } + + SigSpec shift_amount; + SigSpec raw_idx = binary_index; + if (base > 0) { + Wire *sub_w = module->addWire(NEW_ID_SUFFIX("vps_rd_idx"), GetSize(binary_index)); + module->addSub(NEW_ID_SUFFIX("vps_rd_sub"), + binary_index, Const(base, GetSize(binary_index)), sub_w); + raw_idx = SigSpec(sub_w); + } + if (log2_align > 0) { + int adj_lower = (fixed_lower - (base & ((1 << log2_align) - 1))) + & ((1 << log2_align) - 1); + for (int i = 0; i < log2_align; i++) + shift_amount.append((adj_lower >> i) & 1 ? State::S1 : State::S0); + shift_amount.append(raw_idx.extract( + log2_align, GetSize(binary_index) - log2_align)); + } else { + shift_amount = raw_idx; + } + + src_bits = GetSize(source); + Cell *shr = module->addShr(NEW_ID_SUFFIX("vps_rd_shr"), + source, shift_amount, sig_y); + shr->add_strpool_attribute(ID::src, cell->get_strpool_attribute(ID::src)); + vps_shr_cells.insert(shr); + } else { + // Stride=W: pack windows sequentially, shift by W*binary_index + SigSpec packed; + for (int k = 0; k < valid_n; k++) { + int si = rc.s_indices[k]; + for (int j = 0; j < W; j++) + packed.append(sigmap(sig_b[si * W + j])); + } + + // Identify the register wire from packed source bits + // and trim trailing windows that have no register data + Wire *reg_wire = nullptr; + for (int i = 0; i < GetSize(packed); i++) { + SigBit b = packed[i]; + if (b.wire) { reg_wire = b.wire; break; } + } + if (reg_wire) { + int last_valid_window = -1; + for (int k = 0; k < valid_n; k++) { + for (int j = 0; j < W; j++) { + if (packed[k * W + j].wire == reg_wire) + { last_valid_window = k; break; } + } + } + if (last_valid_window >= 0 && + last_valid_window < valid_n - 1) { + packed = packed.extract(0, + (last_valid_window + 1) * W); + } + } + + int log2w = 0; + for (int v = W; v > 1; v >>= 1) + log2w++; + + SigSpec shifted_idx; + shifted_idx.append(Const(0, log2w)); + if (log2_align > 0) { + for (int i = 0; i < log2_align; i++) + shifted_idx.append((fixed_lower >> i) & 1 ? + State::S1 : State::S0); + shifted_idx.append(binary_index.extract( + log2_align, GetSize(binary_index) - log2_align)); + } else { + shifted_idx.append(binary_index); + } + + src_bits = GetSize(packed); + Cell *shr = module->addShr(NEW_ID_SUFFIX("vps_rd_shr"), + packed, shifted_idx, sig_y); + shr->add_strpool_attribute(ID::src, cell->get_strpool_attribute(ID::src)); + vps_shr_cells.insert(shr); + } + + log(" VPS read: pmux %s (WIDTH=%d, S_WIDTH=%d/%d, base=%d, src=%d%s%s) -> $shr\n", + log_id(cell->name), W, valid_n, full_s, base, src_bits, + rc.strided ? ", strided" : "", + log2_align > 0 ? + stringf(", align=%d", 1 << log2_align).c_str() : ""); + + module->remove(cell); + pmux_replaced++; + vps_reads_replaced++; + } + + if (!read_candidates.empty()) + groups_optimized++; } void process_decoder(Cell *decoder) @@ -203,6 +1320,9 @@ struct OptVpsWorker candidates.push_back({cell, positions[0]}); } + // Detect VPS read patterns (WIDTH > 1) from this decoder + process_vps_reads(decoder); + if (candidates.empty()) return; @@ -224,7 +1344,7 @@ struct OptVpsWorker return a.window_start < b.window_start; }); - // Build position buckets: window_start → list of cells + // Build position buckets: window_start -> list of cells dict> by_pos; for (auto &c : cells) by_pos[c.window_start].push_back(c); @@ -546,26 +1666,25 @@ struct OptVpsPass : public Pass { log("\n"); log(" opt_vps [options] [selection]\n"); log("\n"); - log("Detect variable-part-select (VPS) write patterns generated by Verific\n"); - log("and replace the per-bit sliding-window $pmux cells with per-lane\n"); - log("enable logic and direct data wiring.\n"); + log("Detect variable-part-select (VPS) patterns generated by Verific and\n"); + log("replace them with efficient equivalents.\n"); log("\n"); - log("Verific lowers VPS writes like `reg[idx -: W] <= data` into a\n"); + log("VPS WRITES: Verific lowers `reg[idx -: W] <= data` into a\n"); log("bit-granularity decoder ($shl with A=1) followed by overflow-gated\n"); log("AND gates and N sliding-window one-hot $pmux cells (one per output\n"); - log("bit, each with S_WIDTH=W). This structure has O(N*W) gates after\n"); - log("pmuxtree expansion.\n"); + log("bit, each with S_WIDTH=W). This pass recovers the lane structure\n"); + log("and replaces each W-entry $pmux with a single 2:1 $mux gated by a\n"); + log("shared per-lane enable, reducing gates from O(N*W) to O(N + N/W).\n"); log("\n"); - log("This pass recovers the lane structure and replaces each W-entry\n"); - log("$pmux with a single 2:1 $mux gated by a shared per-lane enable,\n"); - log("reducing the gate count to O(N + N/W).\n"); - log("\n"); - log("The pass also replaces per-bit $reduce_or enable cells with the\n"); - log("shared lane enable signal.\n"); + log("VPS READS: Verific lowers `out = reg[idx +: W]` into a one-hot\n"); + log("decoder plus a wide $pmux (WIDTH=W, S_WIDTH=N) that selects among\n"); + log("all N sliding windows. This pass detects the sliding-window pattern\n"); + log("and replaces the $pmux with a $shr barrel shifter, reducing gates\n"); + log("from O(N*W) to O(log(N)*W).\n"); log("\n"); log(" -min_stride \n"); - log(" Minimum stride (S_WIDTH of the $pmux cells) to consider.\n"); - log(" Default: 4.\n"); + log(" Minimum stride (S_WIDTH of the VPS write $pmux cells) to\n"); + log(" consider. Default: 4.\n"); log("\n"); } void execute(std::vector args, RTLIL::Design *design) override @@ -584,7 +1703,7 @@ struct OptVpsPass : public Pass { } extra_args(args, argidx, design); - int total_groups = 0, total_pmux = 0, total_ror = 0, total_fb = 0; + int total_groups = 0, total_pmux = 0, total_ror = 0, total_fb = 0, total_rd = 0; for (auto module : design->selected_modules()) { if (module->has_processes_warn()) @@ -595,20 +1714,23 @@ struct OptVpsPass : public Pass { if (worker.groups_optimized > 0) log(" Module %s: %d VPS group(s), %d $pmux replaced, " - "%d $reduce_or replaced, %d feedback collapsed.\n", + "%d $reduce_or replaced, %d feedback collapsed, " + "%d VPS reads -> $shr.\n", log_id(module->name), worker.groups_optimized, worker.pmux_replaced, worker.reduce_or_replaced, - worker.feedback_collapsed); + worker.feedback_collapsed, worker.vps_reads_replaced); total_groups += worker.groups_optimized; total_pmux += worker.pmux_replaced; total_ror += worker.reduce_or_replaced; total_fb += worker.feedback_collapsed; + total_rd += worker.vps_reads_replaced; } log("Optimized %d VPS group(s), %d $pmux replaced, " - "%d $reduce_or replaced, %d feedback collapsed.\n", - total_groups, total_pmux, total_ror, total_fb); + "%d $reduce_or replaced, %d feedback collapsed, " + "%d VPS reads -> $shr.\n", + total_groups, total_pmux, total_ror, total_fb, total_rd); } } OptVpsPass; diff --git a/tests/silimate/opt_vps.ys b/tests/silimate/opt_vps.ys index 6f35f8808..87327d9a6 100644 --- a/tests/silimate/opt_vps.ys +++ b/tests/silimate/opt_vps.ys @@ -157,3 +157,85 @@ select -assert-count 1 t:$dff design -reset log -pop +# ============================================================================= +# Test 7: SAT equivalence — VPS read vs right-shift reference +# Proves opt_vps produces a logically equivalent circuit to a hand-written +# right-shift for a 256-bit register with 32-bit read window. +# ============================================================================= +log -header "SAT equivalence: VPS read vs right-shift ref" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 + +read -sv opt_vps_read.sv +verific -import opt_vps_read +proc; opt_clean +opt_vps; opt_clean +rename opt_vps_read gate + +read -sv opt_vps_read_ref.sv +verific -import opt_vps_read +proc; opt_clean +rename opt_vps_read gold + +miter -equiv -flatten -make_assert gold gate miter +hierarchy -top miter +proc; opt; memory; opt +clk2fflogic +sat -set-init-zero -tempinduct -prove-asserts -verify +design -reset +log -pop + +# ============================================================================= +# Test 8: SAT self-equivalence — VPS read before vs after opt_vps +# Proves opt_vps does not change the functional behavior for VPS reads. +# ============================================================================= +log -header "SAT self-equivalence: VPS read before vs after opt_vps" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 + +read -sv opt_vps_read.sv +verific -import opt_vps_read +proc; opt_clean +rename opt_vps_read gold + +read -sv opt_vps_read.sv +verific -import opt_vps_read +proc; opt_clean +opt_vps; opt_clean +rename opt_vps_read gate + +miter -equiv -flatten -make_assert gold gate miter +hierarchy -top miter +proc; opt; memory; opt +clk2fflogic +sat -set-init-zero -tempinduct -prove-asserts -verify +design -reset +log -pop + +# ============================================================================= +# Test 9: Cell count verification — VPS read +# After opt_vps, the $pmux should be replaced with a $shr. +# ============================================================================= +log -header "Cell counts: VPS read post-opt_vps" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 + +read -sv opt_vps_read.sv +verific -import opt_vps_read +proc; opt_clean +opt_vps; opt_clean + +select -assert-none t:$pmux +select -assert-count 1 t:$shr +select -assert-count 1 t:$dff +select -assert-count 1 t:$mux +design -reset +log -pop + + diff --git a/tests/silimate/opt_vps_read.sv b/tests/silimate/opt_vps_read.sv new file mode 100644 index 000000000..8daffbea0 --- /dev/null +++ b/tests/silimate/opt_vps_read.sv @@ -0,0 +1,17 @@ +// Minimal variable-part-select (VPS) read: extracts a 32-bit word +// from a 256-bit register at a dynamic byte offset. +module opt_vps_read ( + input logic clk, + input logic wr_en, + input logic [7:0] index, + input logic [255:0] wdata, + output logic [31:0] q +); + logic [255:0] reg_data; + + always_ff @(posedge clk) + if (wr_en) + reg_data <= wdata; + + assign q = reg_data[index +: 32]; +endmodule diff --git a/tests/silimate/opt_vps_read_ref.sv b/tests/silimate/opt_vps_read_ref.sv new file mode 100644 index 000000000..a7c2992af --- /dev/null +++ b/tests/silimate/opt_vps_read_ref.sv @@ -0,0 +1,16 @@ +// Reference for VPS read: uses right-shift instead of variable part-select. +module opt_vps_read ( + input logic clk, + input logic wr_en, + input logic [7:0] index, + input logic [255:0] wdata, + output logic [31:0] q +); + logic [255:0] reg_data; + + always_ff @(posedge clk) + if (wr_en) + reg_data <= wdata; + + assign q = (reg_data >> index); +endmodule