mirror of https://github.com/YosysHQ/yosys.git
commit
b60d2daa41
|
|
@ -75,6 +75,7 @@ struct OptArgmaxWorker
|
|||
SigSpec values_sig;
|
||||
std::string index_name;
|
||||
std::string values_name;
|
||||
bool identity_index = false;
|
||||
int width = 0;
|
||||
int index_width = 0;
|
||||
int value_width = 0;
|
||||
|
|
@ -240,9 +241,10 @@ struct OptArgmaxWorker
|
|||
for (auto bit : sigmap(cand.valid_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
for (auto bit : sigmap(cand.index_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
if (!cand.identity_index)
|
||||
for (auto bit : sigmap(cand.index_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
for (auto bit : sigmap(cand.values_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
|
|
@ -344,16 +346,16 @@ struct OptArgmaxWorker
|
|||
return vectors;
|
||||
}
|
||||
|
||||
int expected_argmax(const TestVector &tv, int width, int value_width)
|
||||
int expected_argmax(const TestVector &tv, int width, int value_width, bool identity_index)
|
||||
{
|
||||
uint64_t mask = value_mask(value_width);
|
||||
int best_idx = 0;
|
||||
bool best_valid = tv.valid[0] != 0;
|
||||
uint64_t best_value = tv.values[tv.index[0]] & mask;
|
||||
uint64_t best_value = tv.values[identity_index ? 0 : tv.index[0]] & mask;
|
||||
|
||||
for (int k = 1; k < width; k++) {
|
||||
bool cand_valid = tv.valid[k] != 0;
|
||||
uint64_t cand_value = tv.values[tv.index[k]] & mask;
|
||||
uint64_t cand_value = tv.values[identity_index ? k : tv.index[k]] & mask;
|
||||
if (!best_valid && cand_valid) {
|
||||
best_idx = k;
|
||||
best_valid = true;
|
||||
|
|
@ -372,14 +374,15 @@ struct OptArgmaxWorker
|
|||
ConstEval ce(module);
|
||||
SigSpec out_sig = sigmap(SigSpec(cand.out_wire));
|
||||
SigSpec valid_sig = sigmap(cand.valid_sig);
|
||||
SigSpec index_sig = sigmap(cand.index_sig);
|
||||
SigSpec index_sig = cand.identity_index ? SigSpec() : sigmap(cand.index_sig);
|
||||
SigSpec values_sig = sigmap(cand.values_sig);
|
||||
|
||||
vector<TestVector> vectors = make_test_vectors(cand.width, cand.value_width);
|
||||
for (auto &tv : vectors) {
|
||||
ce.push();
|
||||
ce.set(valid_sig, packed_valid_const(tv.valid));
|
||||
ce.set(index_sig, packed_table_const(tv.index, cand.index_width));
|
||||
if (!cand.identity_index)
|
||||
ce.set(index_sig, packed_table_const(tv.index, cand.index_width));
|
||||
ce.set(values_sig, packed_table_const(tv.values, cand.value_width));
|
||||
|
||||
SigSpec out = out_sig;
|
||||
|
|
@ -390,7 +393,7 @@ struct OptArgmaxWorker
|
|||
return false;
|
||||
|
||||
int actual = out.as_const().as_int();
|
||||
int expected = expected_argmax(tv, cand.width, cand.value_width);
|
||||
int expected = expected_argmax(tv, cand.width, cand.value_width, cand.identity_index);
|
||||
if (actual != expected)
|
||||
return false;
|
||||
}
|
||||
|
|
@ -481,12 +484,17 @@ struct OptArgmaxWorker
|
|||
{
|
||||
vector<Record> leaves;
|
||||
SigSpec valid = sigmap(cand.valid_sig);
|
||||
SigSpec index_map = sigmap(cand.index_sig);
|
||||
SigSpec index_map = cand.identity_index ? SigSpec() : sigmap(cand.index_sig);
|
||||
SigSpec values = sigmap(cand.values_sig);
|
||||
|
||||
for (int k = 0; k < cand.width; k++) {
|
||||
SigSpec index = index_map.extract(k * cand.index_width, cand.index_width);
|
||||
SigSpec value = emit_bmux(cand.anchor, values, index);
|
||||
SigSpec value;
|
||||
if (cand.identity_index)
|
||||
value = values.extract(k * cand.value_width, cand.value_width);
|
||||
else {
|
||||
SigSpec index = index_map.extract(k * cand.index_width, cand.index_width);
|
||||
value = emit_bmux(cand.anchor, values, index);
|
||||
}
|
||||
leaves.push_back({valid[k], value, SigSpec(Const(k, cand.index_width))});
|
||||
}
|
||||
|
||||
|
|
@ -666,6 +674,29 @@ struct OptArgmaxWorker
|
|||
values_buses.push_back(bus);
|
||||
}
|
||||
|
||||
for (auto &values : values_buses) {
|
||||
Candidate cand;
|
||||
cand.out_wire = out;
|
||||
cand.valid_wire = valid;
|
||||
cand.valid_sig = SigSpec(valid);
|
||||
cand.values_sig = values.sig;
|
||||
cand.index_name = "<identity>";
|
||||
cand.values_name = values.name;
|
||||
cand.identity_index = true;
|
||||
cand.width = width;
|
||||
cand.index_width = out_width;
|
||||
cand.value_width = values.elem_width;
|
||||
if (!check_candidate(cand, cone))
|
||||
continue;
|
||||
|
||||
rewrites.push_back(cand);
|
||||
claimed_outputs.insert(out);
|
||||
log(" %s: %s <- argmax(valid=%s, index=<identity>, values=%s) [N=%d, IW=%d, VW=%d]\n",
|
||||
log_id(module), log_id(out), log_id(valid), values.name.c_str(),
|
||||
cand.width, cand.index_width, cand.value_width);
|
||||
goto next_output;
|
||||
}
|
||||
|
||||
for (auto &index : index_buses) {
|
||||
for (auto &values : values_buses) {
|
||||
if (index.sig == values.sig)
|
||||
|
|
@ -678,6 +709,7 @@ struct OptArgmaxWorker
|
|||
cand.values_sig = values.sig;
|
||||
cand.index_name = index.name;
|
||||
cand.values_name = values.name;
|
||||
cand.identity_index = false;
|
||||
cand.width = width;
|
||||
cand.index_width = out_width;
|
||||
cand.value_width = values.elem_width;
|
||||
|
|
|
|||
|
|
@ -55,6 +55,42 @@ module opt_argmax_w32 (
|
|||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_identity_w8 (
|
||||
input wire [7:0] valid_in,
|
||||
input wire [7:0][4:0] val_in,
|
||||
output reg [2:0] best_idx
|
||||
);
|
||||
always_comb begin
|
||||
best_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!valid_in[best_idx] && valid_in[k]) begin
|
||||
best_idx = k;
|
||||
end else if (valid_in[best_idx] && valid_in[k] &&
|
||||
(val_in[best_idx] < val_in[k])) begin
|
||||
best_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_identity_w16 (
|
||||
input wire [15:0] valid_in,
|
||||
input wire [15:0][7:0] val_in,
|
||||
output reg [3:0] best_idx
|
||||
);
|
||||
always_comb begin
|
||||
best_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!valid_in[best_idx] && valid_in[k]) begin
|
||||
best_idx = k;
|
||||
end else if (valid_in[best_idx] && valid_in[k] &&
|
||||
(val_in[best_idx] < val_in[k])) begin
|
||||
best_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_flat (
|
||||
input wire [7:0] sig,
|
||||
input wire [23:0] sig3,
|
||||
|
|
|
|||
|
|
@ -74,6 +74,49 @@ sat -prove-asserts -verify
|
|||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Identity-index masked argmax self-equivalence"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_identity_w8
|
||||
proc; opt_clean
|
||||
rename opt_argmax_identity_w8 gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_identity_w8
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_identity_w8
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
rename opt_argmax_identity_w8 gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Identity-index masked argmax structural rewrite"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_identity_w16
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
select -assert-none c:*argmax_val*
|
||||
select -assert-none c:LessThan_*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Scaled masked argmax: 8 entries structural"
|
||||
log -push
|
||||
design -reset
|
||||
|
|
|
|||
Loading…
Reference in New Issue