From b3ea5770cd923fa8337c127f6b56d1d580be1b8b Mon Sep 17 00:00:00 2001 From: Akash Levy Date: Tue, 2 Jun 2026 04:11:17 -0700 Subject: [PATCH] opt_argmax pass --- passes/opt/Makefile.inc | 4 +- passes/opt/opt_argmax.cc | 775 +++++++++++++++++++++++++++++++++++ tests/silimate/opt_argmax.sv | 324 +++++++++++++++ tests/silimate/opt_argmax.ys | 332 +++++++++++++++ 4 files changed, 1434 insertions(+), 1 deletion(-) create mode 100644 passes/opt/opt_argmax.cc create mode 100644 tests/silimate/opt_argmax.sv create mode 100644 tests/silimate/opt_argmax.ys diff --git a/passes/opt/Makefile.inc b/passes/opt/Makefile.inc index 220f0d9f8..62f2c537c 100644 --- a/passes/opt/Makefile.inc +++ b/passes/opt/Makefile.inc @@ -22,11 +22,13 @@ OBJS += passes/opt/opt_lut_ins.o OBJS += passes/opt/opt_ffinv.o OBJS += passes/opt/pmux2shiftx.o OBJS += passes/opt/muxpack.o + +OBJS += passes/opt/opt_addcin.o OBJS += passes/opt/opt_andor_pmux.o +OBJS += passes/opt/opt_argmax.o OBJS += passes/opt/opt_balance_tree.o OBJS += passes/opt/opt_parallel_prefix.o OBJS += passes/opt/opt_prienc.o -OBJS += passes/opt/opt_addcin.o OBJS += passes/opt/peepopt.o GENFILES += passes/opt/peepopt_pm.h diff --git a/passes/opt/opt_argmax.cc b/passes/opt/opt_argmax.cc new file mode 100644 index 000000000..8c685e902 --- /dev/null +++ b/passes/opt/opt_argmax.cc @@ -0,0 +1,775 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2026 Akash Levy + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#include "kernel/yosys.h" +#include "kernel/sigtools.h" +#include "kernel/consteval.h" +#include +#include +#include + +USING_YOSYS_NAMESPACE +PRIVATE_NAMESPACE_BEGIN + +static int clog2_int(int x) +{ + int r = 0; + while ((1 << r) < x) + r++; + return r; +} + +static bool is_power_of_two(int x) +{ + return x > 0 && (x & (x - 1)) == 0; +} + +static Const packed_table_const(const vector &values, int elem_width) +{ + vector bits(values.size() * elem_width, State::S0); + for (int i = 0; i < GetSize(values); i++) + for (int b = 0; b < elem_width && b < 64; b++) + if ((values[i] >> b) & 1ULL) + bits[i * elem_width + b] = State::S1; + return Const(bits); +} + +static Const packed_valid_const(const vector &valid) +{ + vector bits(valid.size(), State::S0); + for (int i = 0; i < GetSize(valid); i++) + if (valid[i]) + bits[i] = State::S1; + return Const(bits); +} + +struct OptArgmaxWorker +{ + struct TestVector { + vector valid; + vector index; + vector values; + }; + + struct Candidate { + Wire *out_wire = nullptr; + Wire *valid_wire = nullptr; + SigSpec valid_sig; + SigSpec index_sig; + SigSpec values_sig; + std::string index_name; + std::string values_name; + int width = 0; + int index_width = 0; + int value_width = 0; + Cell *anchor = nullptr; + IdString anchor_port; + }; + + struct OutputCone { + pool cells; + pool leaves; + bool saw_bmux = false; + bool saw_lt = false; + }; + + struct InputBus { + SigSpec sig; + std::string name; + int entries = 0; + int elem_width = 0; + }; + + struct Record { + SigBit valid; + SigSpec value; + SigSpec index; + }; + + Module *module; + SigMap sigmap; + dict bit_to_driver; + pool 
input_port_bits; + Cell *cell = nullptr; + + int min_width = 4; + int max_width = 64; + int regions_rewritten = 0; + int cells_added = 0; + + OptArgmaxWorker(Module *module) : module(module), sigmap(module) + { + build_indexes(); + } + + bool is_sequential(Cell *c) + { + return c->type.in( + ID($ff), ID($dff), ID($dffe), ID($adff), ID($adffe), + ID($sdff), ID($sdffe), ID($sdffce), ID($dffsr), ID($dffsre), + ID($_DFF_P_), ID($_DFF_N_), + ID($_DFFE_PP_), ID($_DFFE_PN_), ID($_DFFE_NP_), ID($_DFFE_NN_), + ID($_DFF_PP0_), ID($_DFF_PP1_), ID($_DFF_PN0_), ID($_DFF_PN1_), + ID($_DFF_NP0_), ID($_DFF_NP1_), ID($_DFF_NN0_), ID($_DFF_NN1_), + ID($dlatch), ID($adlatch), ID($dlatchsr), + ID($mem), ID($mem_v2), ID($meminit), ID($meminit_v2), + ID($memrd), ID($memrd_v2), ID($memwr), ID($memwr_v2), + ID($fsm), + ID($assert), ID($assume), ID($cover), ID($live), ID($fair), + ID($print), ID($check), + ID($anyconst), ID($anyseq), ID($allconst), ID($allseq), + ID($initstate)); + } + + void build_indexes() + { + for (auto c : module->cells()) { + if (is_sequential(c)) + continue; + for (auto &conn : c->connections()) { + if (!c->output(conn.first)) + continue; + for (auto bit : sigmap(conn.second)) { + if (!bit.wire) + continue; + auto it = bit_to_driver.find(bit); + if (it == bit_to_driver.end()) + bit_to_driver[bit] = c; + else if (it->second != c) + it->second = nullptr; + } + } + } + + for (auto w : module->wires()) { + if (!w->port_input) + continue; + for (auto bit : sigmap(SigSpec(w))) + if (bit.wire) + input_port_bits.insert(bit); + } + } + + bool get_cone(SigSpec from, pool &cone_cells, pool &leaf_bits, + int max_cone_cells, int max_leaf_bits) + { + pool visited; + std::queue worklist; + for (auto bit : sigmap(from)) { + if (!bit.wire) + continue; + if (visited.insert(bit).second) + worklist.push(bit); + } + + while (!worklist.empty()) { + SigBit bit = worklist.front(); + worklist.pop(); + + if (input_port_bits.count(bit)) { + leaf_bits.insert(bit); + if (GetSize(leaf_bits) > 
max_leaf_bits) + return false; + continue; + } + + Cell *drv = bit_to_driver.at(bit, nullptr); + if (drv == nullptr) { + leaf_bits.insert(bit); + if (GetSize(leaf_bits) > max_leaf_bits) + return false; + continue; + } + + if (!cone_cells.insert(drv).second) + continue; + if (GetSize(cone_cells) > max_cone_cells) + return false; + + for (auto &conn : drv->connections()) { + if (!drv->input(conn.first)) + continue; + for (auto in_bit : sigmap(conn.second)) { + if (!in_bit.wire) + continue; + if (visited.insert(in_bit).second) + worklist.push(in_bit); + } + } + } + + return true; + } + + OutputCone summarize_output_cone(const pool &cone_cells, pool leaf_bits) + { + OutputCone cone; + cone.cells = cone_cells; + cone.leaves = std::move(leaf_bits); + for (auto c : cone_cells) { + cone.saw_bmux = cone.saw_bmux || c->type == ID($bmux); + cone.saw_lt = cone.saw_lt || c->type == ID($lt); + } + return cone; + } + + bool cone_has_required_shape(const OutputCone &cone, int value_width) + { + return cone.saw_bmux && (cone.saw_lt || value_width == 1); + } + + bool leaves_are_candidate_inputs(const pool &leaf_bits, const Candidate &cand) + { + pool allowed; + for (auto bit : sigmap(cand.valid_sig)) + if (bit.wire) + allowed.insert(bit); + for (auto bit : sigmap(cand.index_sig)) + if (bit.wire) + allowed.insert(bit); + for (auto bit : sigmap(cand.values_sig)) + if (bit.wire) + allowed.insert(bit); + + for (auto bit : leaf_bits) + if (!allowed.count(bit)) + return false; + return true; + } + + bool find_anchor_driver(Wire *out_wire, Cell *&anchor, IdString &anchor_port) + { + for (auto bit : sigmap(SigSpec(out_wire))) { + Cell *drv = bit_to_driver.at(bit, nullptr); + if (drv == nullptr) + continue; + for (auto &conn : drv->connections()) { + if (!drv->output(conn.first)) + continue; + for (auto out_bit : sigmap(conn.second)) { + if (out_bit == bit) { + anchor = drv; + anchor_port = conn.first; + return true; + } + } + } + } + return false; + } + + uint64_t value_mask(int width) + { 
+ if (width >= 64) + return ~0ULL; + return (1ULL << width) - 1; + } + + void add_vector(vector &vectors, const vector &valid, + const vector &index, const vector &values) + { + vectors.push_back({valid, index, values}); + } + + vector make_test_vectors(int width, int value_width) + { + vector vectors; + vector identity(width), reverse(width), inc(width), dec(width), equal(width, 7); + uint64_t mask = value_mask(value_width); + for (int i = 0; i < width; i++) { + identity[i] = i; + reverse[i] = width - 1 - i; + inc[i] = uint64_t(i + 1) & mask; + dec[i] = uint64_t(width - i) & mask; + } + + vector valid(width, 0); + add_vector(vectors, valid, identity, inc); + + for (int i = 0; i < width; i++) { + valid.assign(width, 0); + valid[i] = 1; + add_vector(vectors, valid, identity, inc); + } + + valid.assign(width, 1); + add_vector(vectors, valid, identity, inc); + add_vector(vectors, valid, identity, dec); + add_vector(vectors, valid, identity, equal); + add_vector(vectors, valid, reverse, inc); + add_vector(vectors, valid, reverse, dec); + + for (int i = 0; i + 1 < width; i++) { + vector vals(width, 3); + valid.assign(width, 0); + valid[i] = 1; + valid[i + 1] = 1; + vals[i] = 1; + vals[i + 1] = 9; + add_vector(vectors, valid, identity, vals); + vals[i] = 5; + vals[i + 1] = 5; + add_vector(vectors, valid, identity, vals); + } + + if (width > 2) { + vector vals(width, 0); + valid.assign(width, 0); + valid[0] = 1; + valid[width - 1] = 1; + vals[0] = 2; + vals[width - 1] = 11; + add_vector(vectors, valid, identity, vals); + vals[0] = 13; + vals[width - 1] = 13; + add_vector(vectors, valid, identity, vals); + } + + return vectors; + } + + int expected_argmax(const TestVector &tv, int width, int value_width) + { + uint64_t mask = value_mask(value_width); + int best_idx = 0; + bool best_valid = tv.valid[0] != 0; + uint64_t best_value = tv.values[tv.index[0]] & mask; + + for (int k = 1; k < width; k++) { + bool cand_valid = tv.valid[k] != 0; + uint64_t cand_value = 
tv.values[tv.index[k]] & mask; + if (!best_valid && cand_valid) { + best_idx = k; + best_valid = true; + best_value = cand_value; + } else if (best_valid && cand_valid && best_value < cand_value) { + best_idx = k; + best_value = cand_value; + } + } + + return best_idx; + } + + bool fingerprint(const Candidate &cand) + { + ConstEval ce(module); + SigSpec out_sig = sigmap(SigSpec(cand.out_wire)); + SigSpec valid_sig = sigmap(cand.valid_sig); + SigSpec index_sig = sigmap(cand.index_sig); + SigSpec values_sig = sigmap(cand.values_sig); + + vector vectors = make_test_vectors(cand.width, cand.value_width); + for (auto &tv : vectors) { + ce.push(); + ce.set(valid_sig, packed_valid_const(tv.valid)); + ce.set(index_sig, packed_table_const(tv.index, cand.index_width)); + ce.set(values_sig, packed_table_const(tv.values, cand.value_width)); + + SigSpec out = out_sig; + SigSpec undef; + bool ok = ce.eval(out, undef); + ce.pop(); + if (!ok || !out.is_fully_const()) + return false; + + int actual = out.as_const().as_int(); + int expected = expected_argmax(tv, cand.width, cand.value_width); + if (actual != expected) + return false; + } + + return true; + } + + SigSpec zext(SigSpec sig, int width) + { + sig = sigmap(sig); + if (GetSize(sig) > width) + return sig.extract(0, width); + while (GetSize(sig) < width) + sig.append(State::S0); + return sig; + } + + SigSpec emit_not(Cell *anchor, SigSpec a) + { + Cell *cell = anchor; + cells_added++; + return module->Not(NEW_ID2_SUFFIX("argmax_not"), a); + } + + SigSpec emit_and(Cell *anchor, SigSpec a, SigSpec b) + { + Cell *cell = anchor; + cells_added++; + return module->And(NEW_ID2_SUFFIX("argmax_and"), a, b); + } + + SigSpec emit_or(Cell *anchor, SigSpec a, SigSpec b) + { + Cell *cell = anchor; + cells_added++; + return module->Or(NEW_ID2_SUFFIX("argmax_or"), a, b); + } + + SigSpec emit_lt(Cell *anchor, SigSpec a, SigSpec b) + { + Cell *cell = anchor; + cells_added++; + return module->Lt(NEW_ID2_SUFFIX("argmax_lt"), a, b); + } + + 
SigSpec emit_mux(Cell *anchor, SigSpec a, SigSpec b, SigSpec s) + { + Cell *cell = anchor; + cells_added++; + return module->Mux(NEW_ID2_SUFFIX("argmax_mux"), a, b, s); + } + + SigSpec emit_bmux(Cell *anchor, SigSpec a, SigSpec s) + { + Cell *cell = anchor; + cells_added++; + return module->Bmux(NEW_ID2_SUFFIX("argmax_val"), a, s); + } + + Record combine(Cell *anchor, const Record &lhs, const Record &rhs) + { + SigSpec lhs_invalid = emit_not(anchor, SigSpec(lhs.valid)); + SigSpec value_lt = emit_lt(anchor, lhs.value, rhs.value); + SigSpec valid_and_lt = emit_and(anchor, SigSpec(lhs.valid), value_lt); + SigSpec take_reason = emit_or(anchor, lhs_invalid, valid_and_lt); + SigSpec take_rhs = emit_and(anchor, SigSpec(rhs.valid), take_reason); + + Record out; + out.valid = emit_or(anchor, SigSpec(lhs.valid), SigSpec(rhs.valid))[0]; + out.value = emit_mux(anchor, lhs.value, rhs.value, take_rhs); + out.index = emit_mux(anchor, lhs.index, rhs.index, take_rhs); + return out; + } + + Record emit_tree_rec(Cell *anchor, const vector &leaves, int begin, int end) + { + log_assert(begin < end); + if (begin + 1 == end) + return leaves[begin]; + + int mid = begin + (end - begin) / 2; + Record lhs = emit_tree_rec(anchor, leaves, begin, mid); + Record rhs = emit_tree_rec(anchor, leaves, mid, end); + return combine(anchor, lhs, rhs); + } + + SigSpec emit_argmax(const Candidate &cand) + { + vector leaves; + SigSpec valid = sigmap(cand.valid_sig); + SigSpec index_map = sigmap(cand.index_sig); + SigSpec values = sigmap(cand.values_sig); + + for (int k = 0; k < cand.width; k++) { + SigSpec index = index_map.extract(k * cand.index_width, cand.index_width); + SigSpec value = emit_bmux(cand.anchor, values, index); + leaves.push_back({valid[k], value, SigSpec(Const(k, cand.index_width))}); + } + + Record root = emit_tree_rec(cand.anchor, leaves, 0, GetSize(leaves)); + return zext(root.index, cand.index_width); + } + + void disconnect_old_output(const Candidate &cand) + { + pool target_bits; + 
for (auto bit : sigmap(SigSpec(cand.out_wire))) + if (bit.wire) + target_bits.insert(bit); + + pool seen_cells; + for (auto target : target_bits) { + Cell *drv = bit_to_driver.at(target, nullptr); + if (drv == nullptr || seen_cells.count(drv)) + continue; + seen_cells.insert(drv); + + for (auto &conn : drv->connections()) { + if (!drv->output(conn.first)) + continue; + + SigSpec orig = conn.second; + SigSpec replacement = orig; + bool changed = false; + Cell *cell = drv; + Wire *dangling = module->addWire(NEW_ID2_SUFFIX("argmax_dangling"), GetSize(orig)); + for (int i = 0; i < GetSize(orig); i++) { + if (target_bits.count(sigmap(orig[i]))) { + replacement[i] = SigBit(dangling, i); + changed = true; + } + } + if (changed) + drv->setPort(conn.first, replacement); + } + } + } + + bool check_candidate(Candidate &cand, const OutputCone &cone) + { + if (cand.width < min_width || cand.width > max_width) + return false; + if (!is_power_of_two(cand.width)) + return false; + if (cand.index_width != clog2_int(cand.width)) + return false; + if (cand.value_width <= 0 || cand.value_width > 62) + return false; + + if (!cone_has_required_shape(cone, cand.value_width)) + return false; + if (!leaves_are_candidate_inputs(cone.leaves, cand)) + return false; + if (!find_anchor_driver(cand.out_wire, cand.anchor, cand.anchor_port)) + return false; + + return fingerprint(cand); + } + + bool parse_indexed_port_name(Wire *wire, std::string &base, int &index) + { + std::string name = wire->name.str(); + size_t rbrack = name.size(); + if (rbrack == 0 || name[rbrack - 1] != ']') + return false; + size_t lbrack = name.rfind('['); + if (lbrack == std::string::npos || lbrack + 1 >= rbrack - 1) + return false; + for (size_t i = lbrack + 1; i < rbrack - 1; i++) + if (!isdigit(name[i])) + return false; + base = name.substr(0, lbrack); + index = atoi(name.substr(lbrack + 1, rbrack - lbrack - 2).c_str()); + return true; + } + + vector collect_split_input_buses(const vector &inputs) + { + std::map>> 
groups; + for (auto w : inputs) { + std::string base; + int index = -1; + if (parse_indexed_port_name(w, base, index)) + groups[base].push_back({index, w}); + } + + vector buses; + for (auto &it : groups) { + auto entries = it.second; + std::sort(entries.begin(), entries.end(), + [](const std::pair &a, const std::pair &b) { + return a.first < b.first; + }); + if (entries.empty() || entries.front().first != 0) + continue; + bool contiguous = true; + int elem_width = GetSize(entries.front().second); + for (int i = 0; i < GetSize(entries); i++) { + if (entries[i].first != i || GetSize(entries[i].second) != elem_width) { + contiguous = false; + break; + } + } + if (!contiguous) + continue; + + SigSpec sig; + for (auto &entry : entries) + sig.append(SigSpec(entry.second)); + buses.push_back({sig, it.first, GetSize(entries), elem_width}); + } + + return buses; + } + + void run() + { + if (module->has_processes_warn()) + return; + + vector inputs; + vector outputs; + for (auto w : module->wires()) { + if (w->port_input) + inputs.push_back(w); + if (w->port_output && !w->port_input) + outputs.push_back(w); + } + + vector rewrites; + pool claimed_outputs; + for (auto out : outputs) { + if (claimed_outputs.count(out)) + continue; + int out_width = GetSize(out); + if (out_width < 2) + continue; + + pool cone_cells; + pool leaf_bits; + int max_cone_cells = std::max(256, max_width * 96); + int max_leaf_bits = max_width * (out_width + max_width) + max_width; + if (!get_cone(SigSpec(out), cone_cells, leaf_bits, + max_cone_cells, max_leaf_bits)) + continue; + OutputCone cone = summarize_output_cone(cone_cells, std::move(leaf_bits)); + if (!cone.saw_bmux) + continue; + + for (auto valid : inputs) { + int width = GetSize(valid); + if (width < min_width || width > max_width) + continue; + if (clog2_int(width) != out_width) + continue; + + vector index_buses; + vector values_buses; + for (auto input : inputs) { + if (input == valid) + continue; + if (GetSize(input) == width * 
out_width) + index_buses.push_back({SigSpec(input), input->name.str(), width, out_width}); + if (GetSize(input) % width == 0) + values_buses.push_back({SigSpec(input), input->name.str(), width, GetSize(input) / width}); + } + + vector split_buses = collect_split_input_buses(inputs); + for (auto bus : split_buses) { + if (bus.entries == width && bus.elem_width == out_width) + index_buses.push_back(bus); + if (bus.entries == width) + values_buses.push_back(bus); + } + + for (auto &index : index_buses) { + for (auto &values : values_buses) { + if (index.sig == values.sig) + continue; + Candidate cand; + cand.out_wire = out; + cand.valid_wire = valid; + cand.valid_sig = SigSpec(valid); + cand.index_sig = index.sig; + cand.values_sig = values.sig; + cand.index_name = index.name; + cand.values_name = values.name; + cand.width = width; + cand.index_width = out_width; + cand.value_width = values.elem_width; + if (!check_candidate(cand, cone)) + continue; + + rewrites.push_back(cand); + claimed_outputs.insert(out); + log(" %s: %s <- argmax(valid=%s, index=%s, values=%s) [N=%d, IW=%d, VW=%d]\n", + log_id(module), log_id(out), log_id(valid), index.name.c_str(), + values.name.c_str(), cand.width, cand.index_width, cand.value_width); + goto next_output; + } + } + } +next_output: + ; + } + + for (auto &cand : rewrites) { + cell = cand.anchor; + SigSpec new_out = emit_argmax(cand); + disconnect_old_output(cand); + module->connect(SigSpec(cand.out_wire), new_out); + regions_rewritten++; + } + } +}; + +struct OptArgmaxPass : public Pass +{ + OptArgmaxPass() : Pass("opt_argmax", + "detect and rewrite masked argmax loops into balanced compare trees") {} + + void help() override + { + log("\n"); + log(" opt_argmax [options] [selection]\n"); + log("\n"); + log("Detect combinational masked argmax loops of the form used by\n"); + log("read-after dependency logic and replace the serial loop-carried\n"); + log("index/update cone with a balanced tree of {valid,value,index}\n"); + 
log("comparators. Ties preserve the lower candidate index, matching a\n"); + log("strict '<' update condition; all-invalid inputs return index zero.\n"); + log("\n"); + log(" -max-width N, -max_width N\n"); + log(" maximum candidate count to consider (default 64).\n"); + log("\n"); + log(" -min-width N, -min_width N\n"); + log(" minimum candidate count to consider (default 4).\n"); + log("\n"); + } + + void execute(std::vector args, RTLIL::Design *design) override + { + log_header(design, "Executing OPT_ARGMAX pass (masked argmax rewrite).\n"); + + int max_width = 64; + int min_width = 4; + size_t argidx; + for (argidx = 1; argidx < args.size(); argidx++) { + if ((args[argidx] == "-max-width" || args[argidx] == "-max_width") && + argidx + 1 < args.size()) { + max_width = std::stoi(args[++argidx]); + continue; + } + if ((args[argidx] == "-min-width" || args[argidx] == "-min_width") && + argidx + 1 < args.size()) { + min_width = std::stoi(args[++argidx]); + continue; + } + break; + } + extra_args(args, argidx, design); + + int total_regions = 0; + int total_cells_added = 0; + for (auto module : design->selected_modules()) { + OptArgmaxWorker worker(module); + worker.max_width = max_width; + worker.min_width = min_width; + worker.run(); + total_regions += worker.regions_rewritten; + total_cells_added += worker.cells_added; + } + + log("Rewrote %d argmax region(s); emitted %d new cell(s).\n", + total_regions, total_cells_added); + + if (total_regions) + Yosys::run_pass("clean -purge"); + } +} OptArgmaxPass; + +PRIVATE_NAMESPACE_END diff --git a/tests/silimate/opt_argmax.sv b/tests/silimate/opt_argmax.sv new file mode 100644 index 000000000..eacfae99f --- /dev/null +++ b/tests/silimate/opt_argmax.sv @@ -0,0 +1,324 @@ +module opt_argmax_basic ( + input wire [15:0] sig, + input wire [15:0][3:0] sig3, + input wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] 
&& sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_w8 ( + input wire [7:0] sig, + input wire [7:0][2:0] sig3, + input wire [7:0][4:0] sig2, + output reg [2:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_w32 ( + input wire [31:0] sig, + input wire [31:0][4:0] sig3, + input wire [31:0][5:0] sig2, + output reg [4:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 32; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_flat ( + input wire [7:0] sig, + input wire [23:0] sig3, + input wire [39:0] sig2, + output reg [2:0] se_target_idx +); + function automatic [2:0] idx_at(input [2:0] pos); + idx_at = sig3[pos * 3 +: 3]; + endfunction + + function automatic [4:0] val_at(input [2:0] pos); + val_at = sig2[idx_at(pos) * 5 +: 5]; + endfunction + + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (val_at(se_target_idx) < val_at(k[2:0]))) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_value_w1 ( + input wire [7:0] sig, + input wire [7:0][2:0] sig3, + input wire [7:0] sig2, + output reg [2:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + 
end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_value_w16 ( + input wire [7:0] sig, + input wire [7:0][2:0] sig3, + input wire [7:0][15:0] sig2, + output reg [2:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_two_regions ( + input wire [7:0] sig_a, + input wire [7:0][2:0] sig3_a, + input wire [7:0][7:0] sig2_a, + input wire [7:0] sig_b, + input wire [7:0][2:0] sig3_b, + input wire [7:0][5:0] sig2_b, + output reg [2:0] idx_a, + output reg [2:0] idx_b +); + always_comb begin + idx_a = '0; + for (int k = 1; k < 8; k++) begin + if (!sig_a[idx_a] && sig_a[k]) begin + idx_a = k; + end else if (sig_a[idx_a] && sig_a[k] && + (sig2_a[sig3_a[idx_a]] < sig2_a[sig3_a[k]])) begin + idx_a = k; + end + end + + idx_b = '0; + for (int k = 1; k < 8; k++) begin + if (!sig_b[idx_b] && sig_b[k]) begin + idx_b = k; + end else if (sig_b[idx_b] && sig_b[k] && + (sig2_b[sig3_b[idx_b]] < sig2_b[sig3_b[k]])) begin + idx_b = k; + end + end + end +endmodule + +module opt_argmax_shared_consumer ( + input wire [7:0] sig, + input wire [7:0][2:0] sig3, + input wire [7:0][7:0] sig2, + input wire [2:0] salt, + output reg [2:0] se_target_idx, + output wire [2:0] also_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end + + assign also_idx = se_target_idx ^ salt; +endmodule + +module opt_argmax_tie_high ( + input wire [15:0] sig, + input wire [15:0][3:0] sig3, + input 
wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] <= sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_nonzero_default ( + input wire [15:0] sig, + input wire [15:0][3:0] sig3, + input wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = 4'd1; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_min ( + input wire [15:0] sig, + input wire [15:0][3:0] sig3, + input wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] > sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_w12 ( + input wire [11:0] sig, + input wire [11:0][3:0] sig3, + input wire [11:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 12; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_bad_index_width ( + input wire [15:0] sig, + input wire [15:0][4:0] sig3, + input wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + 
se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx][3:0]] < sig2[sig3[k][3:0]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_stress_noop ( + input wire [63:0] sel, + input wire [63:0] a, + input wire [63:0] b, + output wire [63:0] y +); + wire [63:0] mux0 = sel[0] ? a : b; + wire [63:0] mux1 = sel[1] ? mux0 : {mux0[31:0], mux0[63:32]}; + wire [63:0] mux2 = sel[2] ? mux1 : (mux1 ^ a); + wire [63:0] mux3 = sel[3] ? mux2 : (mux2 & b); + wire [63:0] mux4 = sel[4] ? mux3 : (mux3 | a); + wire [63:0] mux5 = sel[5] ? mux4 : {mux4[47:0], mux4[63:48]}; + assign y = sel[6] ? mux5 : ~mux5; +endmodule + +module opt_argmax_unrelated ( + input wire [3:0] a, + input wire [3:0] b, + input wire sel, + output wire [3:0] y +); + assign y = sel ? a : b; +endmodule + +module opt_argmax_multi_match ( + input wire [15:0] sig, + input wire [15:0][3:0] sig3, + input wire [15:0][7:0] sig2, + output reg [3:0] se_target_idx +); + always_comb begin + se_target_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!sig[se_target_idx] && sig[k]) begin + se_target_idx = k; + end else if (sig[se_target_idx] && sig[k] && + (sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin + se_target_idx = k; + end + end + end +endmodule + +module opt_argmax_multi_keep ( + input wire [3:0] a, + input wire [3:0] b, + input wire sel, + output wire [3:0] y +); + assign y = sel ? a : b; +endmodule diff --git a/tests/silimate/opt_argmax.ys b/tests/silimate/opt_argmax.ys new file mode 100644 index 000000000..79c61ebc5 --- /dev/null +++ b/tests/silimate/opt_argmax.ys @@ -0,0 +1,332 @@ +# Tests for opt_argmax. 
+# Equivalence check: opt_argmax_w8 before vs. after opt_argmax (gold = reference import).
+log -header "Small masked argmax self-equivalence"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_w8
+proc; opt_clean
+rename opt_argmax_w8 gold
+# Gate copy: same import flow, with opt_argmax run on the selected module only.
+read -sv opt_argmax.sv
+verific -import opt_argmax_w8
+proc; opt_clean
+select -module opt_argmax_w8
+opt_argmax
+select -clear
+opt_clean
+rename opt_argmax_w8 gate
+# Miter gold against gate; sat -verify makes the script fail unless every assert is proven.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Structural check: opt_argmax_basic must gain an *argmax* wire, hit the exact cell counts below, and keep no LessThan_* cells.
+log -header "Basic masked argmax structural rewrite"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_basic
+proc; opt_clean
+opt_argmax
+opt_clean
+select -assert-min 1 w:*argmax*
+select -assert-count 16 t:$bmux
+select -assert-count 15 t:$lt
+select -assert-count 29 t:$mux
+select -assert-count 30 t:$and
+select -assert-count 29 t:$or
+select -assert-count 15 t:$not
+select -assert-none c:LessThan_*
+design -reset
+log -pop
+# Equivalence check for opt_argmax_flat; the gate copy must also contain an *argmax* wire.
+log -header "Flat-bus masked argmax self-equivalence"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_flat
+proc; opt_clean
+rename opt_argmax_flat gold
+# Gate copy: run opt_argmax on the module and require at least one *argmax* wire before renaming.
+read -sv opt_argmax.sv
+verific -import opt_argmax_flat
+proc; opt_clean
+select -module opt_argmax_flat
+opt_argmax
+select -clear
+opt_clean
+select -assert-min 1 w:*argmax*
+rename opt_argmax_flat gate
+# Miter gold against gate and prove all asserts with SAT.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Structural check: opt_argmax_w8 must gain at least one *argmax* wire.
+log -header "Scaled masked argmax: 8 entries structural"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_w8
+proc; opt_clean
+opt_argmax
+opt_clean
+select -assert-min 1 w:*argmax*
+design -reset
+log -pop
+# Structural check: opt_argmax_w32 must gain at least one *argmax* wire.
+log -header "Scaled masked argmax: 32 entries structural"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_w32
+proc; opt_clean
+opt_argmax
+opt_clean
+select -assert-min 1 w:*argmax*
+design -reset
+log -pop
+# Equivalence check for opt_argmax_value_w1; the gate copy must also contain an *argmax* wire.
+log -header "Value width edge: 1-bit values"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_value_w1
+proc; opt_clean
+rename opt_argmax_value_w1 gold
+# Gate copy: run opt_argmax on the module and require at least one *argmax* wire before renaming.
+read -sv opt_argmax.sv
+verific -import opt_argmax_value_w1
+proc; opt_clean
+select -module opt_argmax_value_w1
+opt_argmax
+select -clear
+opt_clean
+select -assert-min 1 w:*argmax*
+rename opt_argmax_value_w1 gate
+# Miter gold against gate and prove all asserts with SAT.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Equivalence check for opt_argmax_value_w16; the gate copy must also contain an *argmax* wire.
+log -header "Value width edge: 16-bit values"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_value_w16
+proc; opt_clean
+rename opt_argmax_value_w16 gold
+# Gate copy: run opt_argmax on the module and require at least one *argmax* wire before renaming.
+read -sv opt_argmax.sv
+verific -import opt_argmax_value_w16
+proc; opt_clean
+select -module opt_argmax_value_w16
+opt_argmax
+select -clear
+opt_clean
+select -assert-min 1 w:*argmax*
+rename opt_argmax_value_w16 gate
+# Miter gold against gate and prove all asserts with SAT.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Equivalence check for opt_argmax_two_regions; the gate copy must contain at least two *argmax* wires.
+log -header "Same module: two independent argmax regions"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_two_regions
+proc; opt_clean
+rename opt_argmax_two_regions gold
+# Gate copy: run opt_argmax on the module and require at least two *argmax* wires before renaming.
+read -sv opt_argmax.sv
+verific -import opt_argmax_two_regions
+proc; opt_clean
+select -module opt_argmax_two_regions
+opt_argmax
+select -clear
+opt_clean
+select -assert-min 2 w:*argmax*
+rename opt_argmax_two_regions gate
+# Miter gold against gate and prove all asserts with SAT.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Equivalence check for opt_argmax_shared_consumer; the gate copy must also contain an *argmax* wire.
+log -header "Shared consumer of argmax output remains equivalent"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_shared_consumer
+proc; opt_clean
+rename opt_argmax_shared_consumer gold
+# Gate copy: run opt_argmax on the module and require at least one *argmax* wire before renaming.
+read -sv opt_argmax.sv
+verific -import opt_argmax_shared_consumer
+proc; opt_clean
+select -module opt_argmax_shared_consumer
+opt_argmax
+select -clear
+opt_clean
+select -assert-min 1 w:*argmax*
+rename opt_argmax_shared_consumer gate
+# Miter gold against gate and prove all asserts with SAT.
+miter -equiv -flatten -make_assert gold gate miter
+hierarchy -top miter
+proc; opt; memory; opt
+sat -prove-asserts -verify
+design -reset
+log -pop
+# Option check: with -max_width 8, opt_argmax_basic must not gain any *argmax* wire.
+log -header "Max width leaves argmax unchanged"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_basic
+proc; opt_clean
+opt_argmax -max_width 8
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_w12 must not gain any *argmax* wire.
+log -header "Negative: non-power-of-two candidate count"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_w12
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_bad_index_width must not gain any *argmax* wire.
+log -header "Negative: mismatched index-map width"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_bad_index_width
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_tie_high must not gain any *argmax* wire.
+log -header "Negative: strict tie behavior changed"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_tie_high
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_nonzero_default must not gain any *argmax* wire.
+log -header "Negative: nonzero all-invalid default"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_nonzero_default
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_min must not gain any *argmax* wire.
+log -header "Negative: min-selection comparator"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_min
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Negative check: opt_argmax_unrelated has exactly one $mux before and after opt_argmax, and gains no *argmax* wire.
+log -header "Negative: unrelated mux logic unchanged"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_unrelated
+proc; opt_clean
+select -assert-count 1 t:$mux
+opt_argmax
+select -assert-none w:*argmax*
+select -assert-count 1 t:$mux
+design -reset
+log -pop
+# Negative check: opt_argmax_stress_noop must not gain any *argmax* wire.
+log -header "Negative: bmux-heavy unrelated stress module"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_stress_noop
+proc; opt_clean
+opt_argmax
+select -assert-none w:*argmax*
+design -reset
+log -pop
+# Two modules imported together: opt_argmax_multi_match must gain an *argmax* wire, opt_argmax_multi_keep must not.
+log -header "Multi-module: only matching module rewrites"
+log -push
+design -reset
+verific -cfg veri_optimize_wide_selector 1
+verific -cfg db_infer_wide_muxes_post_elaboration 0
+read -sv opt_argmax.sv
+verific -import opt_argmax_multi_match opt_argmax_multi_keep
+proc; opt_clean
+opt_argmax
+opt_clean
+select -assert-min 1 opt_argmax_multi_match/w:*argmax*
+select -assert-none opt_argmax_multi_keep/w:*argmax*
+design -reset
+log -pop