Fixes for filtering small cases and catching more larger ones with trickier signatures

This commit is contained in:
Akash Levy 2026-05-27 03:40:44 -07:00
parent 69edb27ab3
commit 6a8d800e63
3 changed files with 234 additions and 61 deletions

View File

@ -64,15 +64,28 @@ struct OptAndOrPmuxWorker
std::vector<std::vector<DataExpr>> bits;
};
struct BitContribs {
int bit_idx = -1;
SigSpec select;
std::vector<Contribution> contribs;
};
struct SelectGroup {
SigSpec select;
std::vector<BitContribs> bits;
};
Module *module;
SigMap sigmap;
dict<SigBit, DriverBit> bit_drivers;
dict<SigBit, std::vector<ConsumerBit>> bit_consumers;
pool<SigBit> observable_bits;
pool<Cell*> removed_cells;
int converted_count = 0;
static const int max_cone_bits = 100000;
static const int min_pmux_bits = 8;
OptAndOrPmuxWorker(Module *module) : module(module), sigmap(module)
{
@ -83,6 +96,7 @@ struct OptAndOrPmuxWorker
{
bit_drivers.clear();
bit_consumers.clear();
observable_bits.clear();
for (auto cell : module->cells())
{
@ -112,6 +126,22 @@ struct OptAndOrPmuxWorker
}
}
}
for (auto &conn : module->connections())
{
SigSpec lhs = conn.first;
SigSpec rhs = sigmap(conn.second);
for (int i = 0; i < std::min(GetSize(lhs), GetSize(rhs)); i++) {
SigBit lhs_bit = lhs[i];
SigBit rhs_bit = rhs[i];
if (lhs_bit.wire == nullptr)
continue;
if (lhs_bit.wire->port_output || lhs_bit.wire->get_bool_attribute(ID::keep)) {
observable_bits.insert(lhs_bit);
observable_bits.insert(rhs_bit);
}
}
}
}
bool get_driver(SigBit bit, DriverBit &driver) const
@ -135,36 +165,20 @@ struct OptAndOrPmuxWorker
return true;
}
bool feeds_or(Cell *cell) const
bool bit_has_observable_output(SigBit bit) const
{
SigSpec y = sigmap(cell->getPort(ID::Y));
for (auto bit : y) {
if (bit.wire == nullptr)
continue;
auto it = bit_consumers.find(bit);
if (it == bit_consumers.end())
continue;
bit = sigmap(bit);
if (bit.wire == nullptr)
return false;
if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep))
return true;
if (observable_bits.count(bit))
return true;
auto it = bit_consumers.find(bit);
if (it != bit_consumers.end())
for (auto &consumer : it->second)
if (!removed_cells.count(consumer.cell) && consumer.cell != cell && consumer.cell->type == ID($or))
if (!removed_cells.count(consumer.cell))
return true;
}
return false;
}
bool has_observable_output(Cell *cell) const
{
SigSpec y = sigmap(cell->getPort(ID::Y));
for (auto bit : y) {
if (bit.wire == nullptr)
continue;
if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep))
return true;
auto it = bit_consumers.find(bit);
if (it != bit_consumers.end())
for (auto &consumer : it->second)
if (!removed_cells.count(consumer.cell))
return true;
}
return false;
}
@ -358,46 +372,56 @@ struct OptAndOrPmuxWorker
return make_or_tree(cell, terms, src);
}
bool try_convert(Cell *cell)
bool collect_bit_contribs(Cell *cell, int bit_idx, BitContribs &bit_contribs) const
{
if (removed_cells.count(cell))
return false;
if (cell->type != ID($or) || cell->get_bool_attribute(ID::keep))
return false;
if (feeds_or(cell) || !has_observable_output(cell))
SigSpec y = sigmap(cell->getPort(ID::Y));
SigBit y_bit = y[bit_idx];
if (!bit_has_observable_output(y_bit))
return false;
std::vector<SigBit> terms;
pool<SigBit> seen;
int budget = max_cone_bits;
if (!collect_or_terms(y_bit, terms, seen, budget))
return false;
bit_contribs.bit_idx = bit_idx;
for (auto term : terms)
{
Contribution contrib;
TermResult result = parse_term(term, contrib);
if (result == TERM_ZERO)
continue;
if (result == TERM_FAIL)
return false;
if (bit_contribs.select.empty())
bit_contribs.select = contrib.eq.select;
else if (bit_contribs.select != contrib.eq.select)
return false;
bit_contribs.contribs.push_back(contrib);
}
return !bit_contribs.contribs.empty();
}
bool convert_group(Cell *cell, const SelectGroup &group, pool<int> &converted_bits)
{
SigSpec y = sigmap(cell->getPort(ID::Y));
int width = GetSize(y);
int width = GetSize(group.bits);
if (width == 0)
return false;
SigSpec select;
std::vector<Arm> arms;
dict<Const, int> arm_index;
for (int bit_idx = 0; bit_idx < width; bit_idx++)
for (int group_bit_idx = 0; group_bit_idx < width; group_bit_idx++)
{
std::vector<SigBit> terms;
pool<SigBit> seen;
int budget = max_cone_bits;
if (!collect_or_terms(y[bit_idx], terms, seen, budget))
return false;
for (auto term : terms)
const BitContribs &bit_contribs = group.bits[group_bit_idx];
for (auto &contrib : bit_contribs.contribs)
{
Contribution contrib;
TermResult result = parse_term(term, contrib);
if (result == TERM_ZERO)
continue;
if (result == TERM_FAIL)
return false;
if (select.empty())
select = contrib.eq.select;
else if (select != contrib.eq.select)
return false;
int arm_idx;
auto it = arm_index.find(contrib.eq.value);
if (it == arm_index.end()) {
@ -408,16 +432,19 @@ struct OptAndOrPmuxWorker
arm_idx = it->second;
}
arms[arm_idx].bits[bit_idx].push_back(contrib.data);
arms[arm_idx].bits[group_bit_idx].push_back(contrib.data);
}
}
if (GetSize(arms) < 2)
if (GetSize(arms) < 2 || GetSize(arms) * width < min_pmux_bits)
return false;
SigSpec pmux_s, pmux_b;
SigSpec pmux_y, pmux_s, pmux_b;
std::string src = cell->get_src_attribute();
for (auto &bit_contribs : group.bits)
pmux_y.append(y[bit_contribs.bit_idx]);
for (auto &arm : arms)
{
SigSpec data;
@ -431,11 +458,69 @@ struct OptAndOrPmuxWorker
log("Converting AND/OR mux %s.%s to a $pmux with %d cases and width %d.\n",
log_id(module), log_id(cell), GetSize(arms), width);
module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, cell->getPort(ID::Y), src);
module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, pmux_y, src);
for (auto &bit_contribs : group.bits)
converted_bits.insert(bit_contribs.bit_idx);
converted_count++;
return true;
}
bool try_convert(Cell *cell)
{
if (removed_cells.count(cell))
return false;
if (cell->type != ID($or) || cell->get_bool_attribute(ID::keep))
return false;
SigSpec y = sigmap(cell->getPort(ID::Y));
SigSpec a = sigmap(cell->getPort(ID::A));
SigSpec b = sigmap(cell->getPort(ID::B));
int width = GetSize(y);
if (width == 0)
return false;
std::vector<SelectGroup> groups;
for (int bit_idx = 0; bit_idx < width; bit_idx++)
{
BitContribs bit_contribs;
if (!collect_bit_contribs(cell, bit_idx, bit_contribs))
continue;
bool found = false;
for (auto &group : groups) {
if (group.select == bit_contribs.select) {
group.bits.push_back(bit_contribs);
found = true;
break;
}
}
if (!found)
groups.push_back({bit_contribs.select, {bit_contribs}});
}
pool<int> converted_bits;
for (auto &group : groups)
convert_group(cell, group, converted_bits);
if (converted_bits.empty())
return false;
SigSpec keep_a, keep_b, keep_y;
for (int bit_idx = 0; bit_idx < width; bit_idx++) {
if (converted_bits.count(bit_idx))
continue;
keep_a.append(a[bit_idx]);
keep_b.append(b[bit_idx]);
keep_y.append(y[bit_idx]);
}
std::string src = cell->get_src_attribute();
if (!keep_y.empty())
module->addOr(NEW_ID2_SUFFIX("andor_pmux_keep_or"), keep_a, keep_b, keep_y, false, src);
removed_cells.insert(cell);
module->remove(cell);
converted_count++;
return true;
}
@ -448,7 +533,8 @@ struct OptAndOrPmuxWorker
cells.push_back(cell);
for (auto cell : cells)
try_convert(cell);
if (try_convert(cell))
build_maps();
}
};
@ -466,6 +552,7 @@ struct OptAndOrPmuxPass : public Pass {
log("\n");
log("into $pmux cells. It only rewrites terms whose select conditions are\n");
log("equality comparisons against distinct constants of the same select signal.\n");
log("Very small conversions are ignored to avoid replacing tiny boolean cones.\n");
log("\n");
}

View File

@ -0,0 +1,48 @@
module mixed_vector_decode(
input [2:0] sel0,
input [2:0] sel1,
input [3:0] a0,
input [3:0] b0,
input [3:0] c0,
input [3:0] a1,
input [3:0] b1,
input [3:0] c1,
output [3:0] y0,
output [3:0] y1
);
assign {y1, y0} =
({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) |
({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) |
({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0});
endmodule
module partial_vector_decode(
input [2:0] sel0,
input [2:0] sel1,
input [3:0] a0,
input [3:0] b0,
input [3:0] c0,
input [3:0] a1,
input [3:0] b1,
input [3:0] c1,
input [3:0] passthru,
output [3:0] y0,
output [3:0] y1
);
wire [7:0] stage =
({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) |
({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) |
({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0});
assign y1 = stage[7:4];
assign y0 = stage[3:0] | passthru;
endmodule
module tiny_decode(
input [1:0] sel,
input a,
input b,
output y
);
assign y = ((sel == 2'd1) & a) | ((sel == 2'd2) & b);
endmodule

View File

@ -0,0 +1,38 @@
read_verilog opt_andor_pmux.v
hierarchy -top mixed_vector_decode
proc
design -save mixed_gold
opt_andor_pmux
select -assert-count 2 t:$pmux
design -stash mixed_gate
design -copy-from mixed_gold -as gold mixed_vector_decode
design -copy-from mixed_gate -as gate mixed_vector_decode
miter -equiv -flatten -make_assert -make_outputs gold gate miter
sat -verify -prove-asserts -show-inputs -show-outputs miter
design -reset
read_verilog opt_andor_pmux.v
hierarchy -top partial_vector_decode
proc
design -save partial_gold
opt_andor_pmux
select -assert-count 2 t:$pmux
design -stash partial_gate
design -copy-from partial_gold -as gold partial_vector_decode
design -copy-from partial_gate -as gate partial_vector_decode
miter -equiv -flatten -make_assert -make_outputs gold gate miter
sat -verify -prove-asserts -show-inputs -show-outputs miter
design -reset
read_verilog opt_andor_pmux.v
hierarchy -top tiny_decode
proc
opt_andor_pmux
select -assert-count 0 t:$pmux