From 4a35c0ab87aaf70667e205329effe712c3331da9 Mon Sep 17 00:00:00 2001 From: Akash Levy Date: Tue, 9 Jun 2026 01:57:11 -0700 Subject: [PATCH] opt_argmax fixes --- passes/opt/opt_argmax.cc | 56 ++++++++++++++++++++++++++++-------- tests/silimate/opt_argmax.sv | 36 +++++++++++++++++++++++ tests/silimate/opt_argmax.ys | 43 +++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 12 deletions(-) diff --git a/passes/opt/opt_argmax.cc b/passes/opt/opt_argmax.cc index 8c685e902..59f7bde76 100644 --- a/passes/opt/opt_argmax.cc +++ b/passes/opt/opt_argmax.cc @@ -75,6 +75,7 @@ struct OptArgmaxWorker SigSpec values_sig; std::string index_name; std::string values_name; + bool identity_index = false; int width = 0; int index_width = 0; int value_width = 0; @@ -240,9 +241,10 @@ struct OptArgmaxWorker for (auto bit : sigmap(cand.valid_sig)) if (bit.wire) allowed.insert(bit); - for (auto bit : sigmap(cand.index_sig)) - if (bit.wire) - allowed.insert(bit); + if (!cand.identity_index) + for (auto bit : sigmap(cand.index_sig)) + if (bit.wire) + allowed.insert(bit); for (auto bit : sigmap(cand.values_sig)) if (bit.wire) allowed.insert(bit); @@ -344,16 +346,16 @@ struct OptArgmaxWorker return vectors; } - int expected_argmax(const TestVector &tv, int width, int value_width) + int expected_argmax(const TestVector &tv, int width, int value_width, bool identity_index) { uint64_t mask = value_mask(value_width); int best_idx = 0; bool best_valid = tv.valid[0] != 0; - uint64_t best_value = tv.values[tv.index[0]] & mask; + uint64_t best_value = tv.values[identity_index ? 0 : tv.index[0]] & mask; for (int k = 1; k < width; k++) { bool cand_valid = tv.valid[k] != 0; - uint64_t cand_value = tv.values[tv.index[k]] & mask; + uint64_t cand_value = tv.values[identity_index ? k : tv.index[k]] & mask; if (!best_valid && cand_valid) { best_idx = k; best_valid = true; @@ -372,14 +374,15 @@ struct OptArgmaxWorker ConstEval ce(module); SigSpec out_sig = sigmap(SigSpec(cand.out_wire)); SigSpec valid_sig = sigmap(cand.valid_sig); - SigSpec index_sig = sigmap(cand.index_sig); + SigSpec index_sig = cand.identity_index ? SigSpec() : sigmap(cand.index_sig); SigSpec values_sig = sigmap(cand.values_sig); vector vectors = make_test_vectors(cand.width, cand.value_width); for (auto &tv : vectors) { ce.push(); ce.set(valid_sig, packed_valid_const(tv.valid)); - ce.set(index_sig, packed_table_const(tv.index, cand.index_width)); + if (!cand.identity_index) + ce.set(index_sig, packed_table_const(tv.index, cand.index_width)); ce.set(values_sig, packed_table_const(tv.values, cand.value_width)); SigSpec out = out_sig; @@ -390,7 +393,7 @@ struct OptArgmaxWorker return false; int actual = out.as_const().as_int(); - int expected = expected_argmax(tv, cand.width, cand.value_width); + int expected = expected_argmax(tv, cand.width, cand.value_width, cand.identity_index); if (actual != expected) return false; } @@ -481,12 +484,17 @@ struct OptArgmaxWorker { vector leaves; SigSpec valid = sigmap(cand.valid_sig); - SigSpec index_map = sigmap(cand.index_sig); + SigSpec index_map = cand.identity_index ? SigSpec() : sigmap(cand.index_sig); SigSpec values = sigmap(cand.values_sig); for (int k = 0; k < cand.width; k++) { - SigSpec index = index_map.extract(k * cand.index_width, cand.index_width); - SigSpec value = emit_bmux(cand.anchor, values, index); + SigSpec value; + if (cand.identity_index) + value = values.extract(k * cand.value_width, cand.value_width); + else { + SigSpec index = index_map.extract(k * cand.index_width, cand.index_width); + value = emit_bmux(cand.anchor, values, index); + } leaves.push_back({valid[k], value, SigSpec(Const(k, cand.index_width))}); } @@ -666,6 +674,29 @@ struct OptArgmaxWorker values_buses.push_back(bus); } + for (auto &values : values_buses) { + Candidate cand; + cand.out_wire = out; + cand.valid_wire = valid; + cand.valid_sig = SigSpec(valid); + cand.values_sig = values.sig; + cand.index_name = ""; + cand.values_name = values.name; + cand.identity_index = true; + cand.width = width; + cand.index_width = out_width; + cand.value_width = values.elem_width; + if (!check_candidate(cand, cone)) + continue; + + rewrites.push_back(cand); + claimed_outputs.insert(out); + log(" %s: %s <- argmax(valid=%s, index=, values=%s) [N=%d, IW=%d, VW=%d]\n", + log_id(module), log_id(out), log_id(valid), values.name.c_str(), + cand.width, cand.index_width, cand.value_width); + goto next_output; + } + for (auto &index : index_buses) { for (auto &values : values_buses) { if (index.sig == values.sig) @@ -678,6 +709,7 @@ struct OptArgmaxWorker cand.values_sig = values.sig; cand.index_name = index.name; cand.values_name = values.name; + cand.identity_index = false; cand.width = width; cand.index_width = out_width; cand.value_width = values.elem_width; diff --git a/tests/silimate/opt_argmax.sv b/tests/silimate/opt_argmax.sv index eacfae99f..f02c4c3ca 100644 --- a/tests/silimate/opt_argmax.sv +++ b/tests/silimate/opt_argmax.sv @@ -55,6 +55,42 @@ module opt_argmax_w32 ( end endmodule +module opt_argmax_identity_w8 ( + input wire [7:0] valid_in, + input wire [7:0][4:0] val_in, + output reg [2:0] best_idx +); + always_comb begin + best_idx = '0; + for (int k = 1; k < 8; k++) begin + if (!valid_in[best_idx] && valid_in[k]) begin + best_idx = k; + end else if (valid_in[best_idx] && valid_in[k] && + (val_in[best_idx] < val_in[k])) begin + best_idx = k; + end + end + end +endmodule + +module opt_argmax_identity_w16 ( + input wire [15:0] valid_in, + input wire [15:0][7:0] val_in, + output reg [3:0] best_idx +); + always_comb begin + best_idx = '0; + for (int k = 1; k < 16; k++) begin + if (!valid_in[best_idx] && valid_in[k]) begin + best_idx = k; + end else if (valid_in[best_idx] && valid_in[k] && + (val_in[best_idx] < val_in[k])) begin + best_idx = k; + end + end + end +endmodule + module opt_argmax_flat ( input wire [7:0] sig, input wire [23:0] sig3, diff --git a/tests/silimate/opt_argmax.ys b/tests/silimate/opt_argmax.ys index 79c61ebc5..bf96482f8 100644 --- a/tests/silimate/opt_argmax.ys +++ b/tests/silimate/opt_argmax.ys @@ -74,6 +74,49 @@ sat -prove-asserts -verify design -reset log -pop +log -header "Identity-index masked argmax self-equivalence" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_argmax.sv +verific -import opt_argmax_identity_w8 +proc; opt_clean +rename opt_argmax_identity_w8 gold + +read -sv opt_argmax.sv +verific -import opt_argmax_identity_w8 +proc; opt_clean +select -module opt_argmax_identity_w8 +opt_argmax +select -clear +opt_clean +select -assert-min 1 w:*argmax* +rename opt_argmax_identity_w8 gate + +miter -equiv -flatten -make_assert gold gate miter +hierarchy -top miter +proc; opt; memory; opt +sat -prove-asserts -verify +design -reset +log -pop + +log -header "Identity-index masked argmax structural rewrite" +log -push +design -reset +verific -cfg veri_optimize_wide_selector 1 +verific -cfg db_infer_wide_muxes_post_elaboration 0 +read -sv opt_argmax.sv +verific -import opt_argmax_identity_w16 +proc; opt_clean +opt_argmax +opt_clean +select -assert-min 1 w:*argmax* +select -assert-none c:*argmax_val* +select -assert-none c:LessThan_* +design -reset +log -pop + log -header "Scaled masked argmax: 8 entries structural" log -push design -reset