Merge pull request #173 from Silimate/opt_andor_pmux

Fixes for filtering small cases and catching more larger ones with tr…
This commit is contained in:
Akash Levy 2026-05-27 04:37:33 -07:00 committed by GitHub
commit f585f79d31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 262 additions and 94 deletions

View File

@ -64,15 +64,28 @@ struct OptAndOrPmuxWorker
std::vector<std::vector<DataExpr>> bits;
};
struct BitContribs {
int bit_idx = -1;
SigSpec select;
std::vector<Contribution> contribs;
};
struct SelectGroup {
SigSpec select;
std::vector<BitContribs> bits;
};
Module *module;
SigMap sigmap;
dict<SigBit, DriverBit> bit_drivers;
dict<SigBit, std::vector<ConsumerBit>> bit_consumers;
pool<SigBit> observable_bits;
pool<Cell*> removed_cells;
int converted_count = 0;
static const int max_cone_bits = 100000;
static const int min_pmux_bits = 8;
OptAndOrPmuxWorker(Module *module) : module(module), sigmap(module)
{
@ -83,6 +96,7 @@ struct OptAndOrPmuxWorker
{
bit_drivers.clear();
bit_consumers.clear();
observable_bits.clear();
for (auto cell : module->cells())
{
@ -112,6 +126,22 @@ struct OptAndOrPmuxWorker
}
}
}
for (auto &conn : module->connections())
{
SigSpec lhs = conn.first;
SigSpec rhs = sigmap(conn.second);
for (int i = 0; i < std::min(GetSize(lhs), GetSize(rhs)); i++) {
SigBit lhs_bit = lhs[i];
SigBit rhs_bit = rhs[i];
if (lhs_bit.wire == nullptr)
continue;
if (lhs_bit.wire->port_output || lhs_bit.wire->get_bool_attribute(ID::keep)) {
observable_bits.insert(lhs_bit);
observable_bits.insert(rhs_bit);
}
}
}
}
bool get_driver(SigBit bit, DriverBit &driver) const
@ -135,36 +165,20 @@ struct OptAndOrPmuxWorker
return true;
}
bool feeds_or(Cell *cell) const
bool bit_has_observable_output(SigBit bit) const
{
SigSpec y = sigmap(cell->getPort(ID::Y));
for (auto bit : y) {
if (bit.wire == nullptr)
continue;
auto it = bit_consumers.find(bit);
if (it == bit_consumers.end())
continue;
bit = sigmap(bit);
if (bit.wire == nullptr)
return false;
if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep))
return true;
if (observable_bits.count(bit))
return true;
auto it = bit_consumers.find(bit);
if (it != bit_consumers.end())
for (auto &consumer : it->second)
if (!removed_cells.count(consumer.cell) && consumer.cell != cell && consumer.cell->type == ID($or))
if (!removed_cells.count(consumer.cell))
return true;
}
return false;
}
bool has_observable_output(Cell *cell) const
{
SigSpec y = sigmap(cell->getPort(ID::Y));
for (auto bit : y) {
if (bit.wire == nullptr)
continue;
if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep))
return true;
auto it = bit_consumers.find(bit);
if (it != bit_consumers.end())
for (auto &consumer : it->second)
if (!removed_cells.count(consumer.cell))
return true;
}
return false;
}
@ -358,46 +372,56 @@ struct OptAndOrPmuxWorker
return make_or_tree(cell, terms, src);
}
bool try_convert(Cell *cell)
bool collect_bit_contribs(Cell *cell, int bit_idx, BitContribs &bit_contribs) const
{
if (removed_cells.count(cell))
return false;
if (cell->type != ID($or) || cell->get_bool_attribute(ID::keep))
return false;
if (feeds_or(cell) || !has_observable_output(cell))
SigSpec y = sigmap(cell->getPort(ID::Y));
SigBit y_bit = y[bit_idx];
if (!bit_has_observable_output(y_bit))
return false;
std::vector<SigBit> terms;
pool<SigBit> seen;
int budget = max_cone_bits;
if (!collect_or_terms(y_bit, terms, seen, budget))
return false;
bit_contribs.bit_idx = bit_idx;
for (auto term : terms)
{
Contribution contrib;
TermResult result = parse_term(term, contrib);
if (result == TERM_ZERO)
continue;
if (result == TERM_FAIL)
return false;
if (bit_contribs.select.empty())
bit_contribs.select = contrib.eq.select;
else if (bit_contribs.select != contrib.eq.select)
return false;
bit_contribs.contribs.push_back(contrib);
}
return !bit_contribs.contribs.empty();
}
bool convert_group(Cell *cell, const SelectGroup &group, pool<int> &converted_bits)
{
SigSpec y = sigmap(cell->getPort(ID::Y));
int width = GetSize(y);
int width = GetSize(group.bits);
if (width == 0)
return false;
SigSpec select;
std::vector<Arm> arms;
dict<Const, int> arm_index;
for (int bit_idx = 0; bit_idx < width; bit_idx++)
for (int group_bit_idx = 0; group_bit_idx < width; group_bit_idx++)
{
std::vector<SigBit> terms;
pool<SigBit> seen;
int budget = max_cone_bits;
if (!collect_or_terms(y[bit_idx], terms, seen, budget))
return false;
for (auto term : terms)
const BitContribs &bit_contribs = group.bits[group_bit_idx];
for (auto &contrib : bit_contribs.contribs)
{
Contribution contrib;
TermResult result = parse_term(term, contrib);
if (result == TERM_ZERO)
continue;
if (result == TERM_FAIL)
return false;
if (select.empty())
select = contrib.eq.select;
else if (select != contrib.eq.select)
return false;
int arm_idx;
auto it = arm_index.find(contrib.eq.value);
if (it == arm_index.end()) {
@ -408,16 +432,19 @@ struct OptAndOrPmuxWorker
arm_idx = it->second;
}
arms[arm_idx].bits[bit_idx].push_back(contrib.data);
arms[arm_idx].bits[group_bit_idx].push_back(contrib.data);
}
}
if (GetSize(arms) < 2)
if (GetSize(arms) < 2 || GetSize(arms) * width < min_pmux_bits)
return false;
SigSpec pmux_s, pmux_b;
SigSpec pmux_y, pmux_s, pmux_b;
std::string src = cell->get_src_attribute();
for (auto &bit_contribs : group.bits)
pmux_y.append(y[bit_contribs.bit_idx]);
for (auto &arm : arms)
{
SigSpec data;
@ -431,11 +458,69 @@ struct OptAndOrPmuxWorker
log("Converting AND/OR mux %s.%s to a $pmux with %d cases and width %d.\n",
log_id(module), log_id(cell), GetSize(arms), width);
module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, cell->getPort(ID::Y), src);
module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, pmux_y, src);
for (auto &bit_contribs : group.bits)
converted_bits.insert(bit_contribs.bit_idx);
converted_count++;
return true;
}
bool try_convert(Cell *cell)
{
if (removed_cells.count(cell))
return false;
if (cell->type != ID($or) || cell->get_bool_attribute(ID::keep))
return false;
SigSpec y = sigmap(cell->getPort(ID::Y));
SigSpec a = sigmap(cell->getPort(ID::A));
SigSpec b = sigmap(cell->getPort(ID::B));
int width = GetSize(y);
if (width == 0)
return false;
std::vector<SelectGroup> groups;
for (int bit_idx = 0; bit_idx < width; bit_idx++)
{
BitContribs bit_contribs;
if (!collect_bit_contribs(cell, bit_idx, bit_contribs))
continue;
bool found = false;
for (auto &group : groups) {
if (group.select == bit_contribs.select) {
group.bits.push_back(bit_contribs);
found = true;
break;
}
}
if (!found)
groups.push_back({bit_contribs.select, {bit_contribs}});
}
pool<int> converted_bits;
for (auto &group : groups)
convert_group(cell, group, converted_bits);
if (converted_bits.empty())
return false;
SigSpec keep_a, keep_b, keep_y;
for (int bit_idx = 0; bit_idx < width; bit_idx++) {
if (converted_bits.count(bit_idx))
continue;
keep_a.append(a[bit_idx]);
keep_b.append(b[bit_idx]);
keep_y.append(y[bit_idx]);
}
std::string src = cell->get_src_attribute();
if (!keep_y.empty())
module->addOr(NEW_ID2_SUFFIX("andor_pmux_keep_or"), keep_a, keep_b, keep_y, false, src);
removed_cells.insert(cell);
module->remove(cell);
converted_count++;
return true;
}
@ -448,7 +533,8 @@ struct OptAndOrPmuxWorker
cells.push_back(cell);
for (auto cell : cells)
try_convert(cell);
if (try_convert(cell))
build_maps();
}
};
@ -466,6 +552,7 @@ struct OptAndOrPmuxPass : public Pass {
log("\n");
log("into $pmux cells. It only rewrites terms whose select conditions are\n");
log("equality comparisons against distinct constants of the same select signal.\n");
log("Very small conversions are ignored to avoid replacing tiny boolean cones.\n");
log("\n");
}

View File

@ -0,0 +1,48 @@
module mixed_vector_decode(
input [2:0] sel0,
input [2:0] sel1,
input [3:0] a0,
input [3:0] b0,
input [3:0] c0,
input [3:0] a1,
input [3:0] b1,
input [3:0] c1,
output [3:0] y0,
output [3:0] y1
);
assign {y1, y0} =
({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) |
({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) |
({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0});
endmodule
module partial_vector_decode(
input [2:0] sel0,
input [2:0] sel1,
input [3:0] a0,
input [3:0] b0,
input [3:0] c0,
input [3:0] a1,
input [3:0] b1,
input [3:0] c1,
input [3:0] passthru,
output [3:0] y0,
output [3:0] y1
);
wire [7:0] stage =
({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) |
({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) |
({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0});
assign y1 = stage[7:4];
assign y0 = stage[3:0] | passthru;
endmodule
module tiny_decode(
input [1:0] sel,
input a,
input b,
output y
);
assign y = ((sel == 2'd1) & a) | ((sel == 2'd2) & b);
endmodule

View File

@ -0,0 +1,38 @@
read_verilog opt_andor_pmux.v
hierarchy -top mixed_vector_decode
proc
design -save mixed_gold
opt_andor_pmux
select -assert-count 2 t:$pmux
design -stash mixed_gate
design -copy-from mixed_gold -as gold mixed_vector_decode
design -copy-from mixed_gate -as gate mixed_vector_decode
miter -equiv -flatten -make_assert -make_outputs gold gate miter
sat -verify -prove-asserts -show-inputs -show-outputs miter
design -reset
read_verilog opt_andor_pmux.v
hierarchy -top partial_vector_decode
proc
design -save partial_gold
opt_andor_pmux
select -assert-count 2 t:$pmux
design -stash partial_gate
design -copy-from partial_gold -as gold partial_vector_decode
design -copy-from partial_gate -as gate partial_vector_decode
miter -equiv -flatten -make_assert -make_outputs gold gate miter
sat -verify -prove-asserts -show-inputs -show-outputs miter
design -reset
read_verilog opt_andor_pmux.v
hierarchy -top tiny_decode
proc
opt_andor_pmux
select -assert-count 0 t:$pmux

View File

@ -2,40 +2,35 @@ module andor_pmux_basic (
input [2:0] sel,
input [5:0] d,
input a,
output [1:0] y
output [2:0] y
);
assign y = ({2{sel == 3'd1}} & d[1:0]) |
({2{sel == 3'd3}} & {d[2] & a, d[3]}) |
({2{sel == 3'd6}} & 2'b01);
assign y = ({3{sel == 3'd1}} & d[2:0]) |
({3{sel == 3'd3}} & {d[3] & a, d[4], d[5]}) |
({3{sel == 3'd6}} & 3'b001);
endmodule
module andor_pmux_outer_enable (
input [2:0] sel,
input [3:0] d,
input [5:0] d,
input en,
output [1:0] y
output [2:0] y
);
wire [1:0] body;
wire [2:0] body;
assign body = ({2{sel == 3'd2}} & {1'b0, d[0]}) |
({2{sel == 3'd5}} & {d[1], d[2]}) |
({2{sel == 3'd7}} & {d[3], 1'b1});
assign y = {2{en}} & body;
assign body = ({3{sel == 3'd2}} & {1'b0, d[0], d[1]}) |
({3{sel == 3'd5}} & {d[2], d[3], d[4]}) |
({3{sel == 3'd7}} & {d[5], 1'b1, d[0]});
assign y = {3{en}} & body;
endmodule
module andor_pmux_duplicate (
input [1:0] sel,
input a,
input b,
input c,
input d,
input e,
input f,
output [1:0] y
input [11:0] d,
output [3:0] y
);
assign y = ({2{sel == 2'd1}} & {a, b}) |
({2{sel == 2'd1}} & {c, d}) |
({2{sel == 2'd2}} & {e, f});
assign y = ({4{sel == 2'd1}} & d[3:0]) |
({4{sel == 2'd1}} & d[7:4]) |
({4{sel == 2'd2}} & d[11:8]);
endmodule
module andor_pmux_mixed_select_negative (
@ -65,15 +60,15 @@ endmodule
module andor_pmux_shared_subtree (
input [2:0] sel,
input [3:0] d,
input q,
output y,
output z
input [8:0] d,
input [2:0] q,
output [2:0] y,
output [2:0] z
);
wire sub = ((sel == 3'd1) & d[0]) |
((sel == 3'd3) & d[1]);
wire [2:0] sub = ({3{sel == 3'd1}} & d[2:0]) |
({3{sel == 3'd3}} & d[5:3]);
assign y = sub | ((sel == 3'd6) & d[2]);
assign y = sub | ({3{sel == 3'd6}} & d[8:6]);
assign z = sub & q;
endmodule
@ -106,12 +101,12 @@ endmodule
module andor_pmux_duplicate_complex (
input [2:0] sel,
input [8:0] d,
input [11:0] d,
input q,
input r,
output [2:0] y
output [3:0] y
);
assign y = ({3{sel == 3'd2}} & {d[0] & q, d[1], d[2]}) |
({3{sel == 3'd2}} & {d[3], d[4] & r, d[5]}) |
({3{sel == 3'd5}} & {d[6], d[7] & q, d[8] & r});
assign y = ({4{sel == 3'd2}} & {d[0] & q, d[1], d[2], d[3]}) |
({4{sel == 3'd2}} & {d[4], d[5] & r, d[6], d[7]}) |
({4{sel == 3'd5}} & {d[8], d[9] & q, d[10] & r, d[11]});
endmodule