opt_argmax fixes

This commit is contained in:
Akash Levy 2026-06-09 01:57:11 -07:00
parent b3ea5770cd
commit 4a35c0ab87
3 changed files with 123 additions and 12 deletions

View File

@ -75,6 +75,7 @@ struct OptArgmaxWorker
SigSpec values_sig;
std::string index_name;
std::string values_name;
bool identity_index = false;
int width = 0;
int index_width = 0;
int value_width = 0;
@ -240,9 +241,10 @@ struct OptArgmaxWorker
for (auto bit : sigmap(cand.valid_sig))
if (bit.wire)
allowed.insert(bit);
for (auto bit : sigmap(cand.index_sig))
if (bit.wire)
allowed.insert(bit);
if (!cand.identity_index)
for (auto bit : sigmap(cand.index_sig))
if (bit.wire)
allowed.insert(bit);
for (auto bit : sigmap(cand.values_sig))
if (bit.wire)
allowed.insert(bit);
@ -344,16 +346,16 @@ struct OptArgmaxWorker
return vectors;
}
int expected_argmax(const TestVector &tv, int width, int value_width)
int expected_argmax(const TestVector &tv, int width, int value_width, bool identity_index)
{
uint64_t mask = value_mask(value_width);
int best_idx = 0;
bool best_valid = tv.valid[0] != 0;
uint64_t best_value = tv.values[tv.index[0]] & mask;
uint64_t best_value = tv.values[identity_index ? 0 : tv.index[0]] & mask;
for (int k = 1; k < width; k++) {
bool cand_valid = tv.valid[k] != 0;
uint64_t cand_value = tv.values[tv.index[k]] & mask;
uint64_t cand_value = tv.values[identity_index ? k : tv.index[k]] & mask;
if (!best_valid && cand_valid) {
best_idx = k;
best_valid = true;
@ -372,14 +374,15 @@ struct OptArgmaxWorker
ConstEval ce(module);
SigSpec out_sig = sigmap(SigSpec(cand.out_wire));
SigSpec valid_sig = sigmap(cand.valid_sig);
SigSpec index_sig = sigmap(cand.index_sig);
SigSpec index_sig = cand.identity_index ? SigSpec() : sigmap(cand.index_sig);
SigSpec values_sig = sigmap(cand.values_sig);
vector<TestVector> vectors = make_test_vectors(cand.width, cand.value_width);
for (auto &tv : vectors) {
ce.push();
ce.set(valid_sig, packed_valid_const(tv.valid));
ce.set(index_sig, packed_table_const(tv.index, cand.index_width));
if (!cand.identity_index)
ce.set(index_sig, packed_table_const(tv.index, cand.index_width));
ce.set(values_sig, packed_table_const(tv.values, cand.value_width));
SigSpec out = out_sig;
@ -390,7 +393,7 @@ struct OptArgmaxWorker
return false;
int actual = out.as_const().as_int();
int expected = expected_argmax(tv, cand.width, cand.value_width);
int expected = expected_argmax(tv, cand.width, cand.value_width, cand.identity_index);
if (actual != expected)
return false;
}
@ -481,12 +484,17 @@ struct OptArgmaxWorker
{
vector<Record> leaves;
SigSpec valid = sigmap(cand.valid_sig);
SigSpec index_map = sigmap(cand.index_sig);
SigSpec index_map = cand.identity_index ? SigSpec() : sigmap(cand.index_sig);
SigSpec values = sigmap(cand.values_sig);
for (int k = 0; k < cand.width; k++) {
SigSpec index = index_map.extract(k * cand.index_width, cand.index_width);
SigSpec value = emit_bmux(cand.anchor, values, index);
SigSpec value;
if (cand.identity_index)
value = values.extract(k * cand.value_width, cand.value_width);
else {
SigSpec index = index_map.extract(k * cand.index_width, cand.index_width);
value = emit_bmux(cand.anchor, values, index);
}
leaves.push_back({valid[k], value, SigSpec(Const(k, cand.index_width))});
}
@ -666,6 +674,29 @@ struct OptArgmaxWorker
values_buses.push_back(bus);
}
for (auto &values : values_buses) {
Candidate cand;
cand.out_wire = out;
cand.valid_wire = valid;
cand.valid_sig = SigSpec(valid);
cand.values_sig = values.sig;
cand.index_name = "<identity>";
cand.values_name = values.name;
cand.identity_index = true;
cand.width = width;
cand.index_width = out_width;
cand.value_width = values.elem_width;
if (!check_candidate(cand, cone))
continue;
rewrites.push_back(cand);
claimed_outputs.insert(out);
log(" %s: %s <- argmax(valid=%s, index=<identity>, values=%s) [N=%d, IW=%d, VW=%d]\n",
log_id(module), log_id(out), log_id(valid), values.name.c_str(),
cand.width, cand.index_width, cand.value_width);
goto next_output;
}
for (auto &index : index_buses) {
for (auto &values : values_buses) {
if (index.sig == values.sig)
@ -678,6 +709,7 @@ struct OptArgmaxWorker
cand.values_sig = values.sig;
cand.index_name = index.name;
cand.values_name = values.name;
cand.identity_index = false;
cand.width = width;
cand.index_width = out_width;
cand.value_width = values.elem_width;

View File

@ -55,6 +55,42 @@ module opt_argmax_w32 (
end
endmodule
module opt_argmax_identity_w8 (
input wire [7:0] valid_in,
input wire [7:0][4:0] val_in,
output reg [2:0] best_idx
);
always_comb begin
best_idx = '0;
for (int k = 1; k < 8; k++) begin
if (!valid_in[best_idx] && valid_in[k]) begin
best_idx = k;
end else if (valid_in[best_idx] && valid_in[k] &&
(val_in[best_idx] < val_in[k])) begin
best_idx = k;
end
end
end
endmodule
module opt_argmax_identity_w16 (
input wire [15:0] valid_in,
input wire [15:0][7:0] val_in,
output reg [3:0] best_idx
);
always_comb begin
best_idx = '0;
for (int k = 1; k < 16; k++) begin
if (!valid_in[best_idx] && valid_in[k]) begin
best_idx = k;
end else if (valid_in[best_idx] && valid_in[k] &&
(val_in[best_idx] < val_in[k])) begin
best_idx = k;
end
end
end
endmodule
module opt_argmax_flat (
input wire [7:0] sig,
input wire [23:0] sig3,

View File

@ -74,6 +74,49 @@ sat -prove-asserts -verify
design -reset
log -pop
log -header "Identity-index masked argmax self-equivalence"
log -push
design -reset
verific -cfg veri_optimize_wide_selector 1
verific -cfg db_infer_wide_muxes_post_elaboration 0
read -sv opt_argmax.sv
verific -import opt_argmax_identity_w8
proc; opt_clean
rename opt_argmax_identity_w8 gold
read -sv opt_argmax.sv
verific -import opt_argmax_identity_w8
proc; opt_clean
select -module opt_argmax_identity_w8
opt_argmax
select -clear
opt_clean
select -assert-min 1 w:*argmax*
rename opt_argmax_identity_w8 gate
miter -equiv -flatten -make_assert gold gate miter
hierarchy -top miter
proc; opt; memory; opt
sat -prove-asserts -verify
design -reset
log -pop
log -header "Identity-index masked argmax structural rewrite"
log -push
design -reset
verific -cfg veri_optimize_wide_selector 1
verific -cfg db_infer_wide_muxes_post_elaboration 0
read -sv opt_argmax.sv
verific -import opt_argmax_identity_w16
proc; opt_clean
opt_argmax
opt_clean
select -assert-min 1 w:*argmax*
select -assert-none c:*argmax_val*
select -assert-none c:LessThan_*
design -reset
log -pop
log -header "Scaled masked argmax: 8 entries structural"
log -push
design -reset