diff --git a/passes/opt/opt_andor_pmux.cc b/passes/opt/opt_andor_pmux.cc index d034ed236..a16717287 100644 --- a/passes/opt/opt_andor_pmux.cc +++ b/passes/opt/opt_andor_pmux.cc @@ -64,15 +64,28 @@ struct OptAndOrPmuxWorker std::vector> bits; }; + /* Parsed terms for one output bit of an $or cell: the bit's index in Y, the select signal shared by all of its terms, and the terms themselves. */ struct BitContribs { + int bit_idx = -1; + SigSpec select; + std::vector<Contribution> contribs; + }; + + /* Output bits of one $or cell whose terms compare the same select signal; convert_group turns each group into at most one $pmux. */ struct SelectGroup { + SigSpec select; + std::vector<BitContribs> bits; + }; + Module *module; SigMap sigmap; dict bit_drivers; dict> bit_consumers; + pool<SigBit> observable_bits; pool removed_cells; int converted_count = 0; static const int max_cone_bits = 100000; + /* Groups with (number of arms) x (number of bits) below this are not converted. */ static const int min_pmux_bits = 8; OptAndOrPmuxWorker(Module *module) : module(module), sigmap(module) { @@ -83,6 +96,7 @@ struct OptAndOrPmuxWorker { bit_drivers.clear(); bit_consumers.clear(); + observable_bits.clear(); for (auto cell : module->cells()) { @@ -112,6 +126,22 @@ struct OptAndOrPmuxWorker } } } + + /* Record both sides of every connection whose left-hand bit lies on an output-port or keep wire. */ for (auto &conn : module->connections()) + { + SigSpec lhs = conn.first; + SigSpec rhs = sigmap(conn.second); + for (int i = 0; i < std::min(GetSize(lhs), GetSize(rhs)); i++) { + SigBit lhs_bit = lhs[i]; + SigBit rhs_bit = rhs[i]; + if (lhs_bit.wire == nullptr) + continue; + if (lhs_bit.wire->port_output || lhs_bit.wire->get_bool_attribute(ID::keep)) { + observable_bits.insert(lhs_bit); + observable_bits.insert(rhs_bit); + } + } + } } bool get_driver(SigBit bit, DriverBit &driver) const @@ -135,36 +165,20 @@ struct OptAndOrPmuxWorker return true; } - bool feeds_or(Cell *cell) const + /* Returns true when the sigmapped bit sits on an output-port or keep wire, is listed in observable_bits, or has a consumer cell that is not in removed_cells. Constant bits return false. */ bool bit_has_observable_output(SigBit bit) const { - SigSpec y = sigmap(cell->getPort(ID::Y)); - for (auto bit : y) { - if (bit.wire == nullptr) - continue; - auto it = bit_consumers.find(bit); - if (it == bit_consumers.end()) - continue; + bit = sigmap(bit); + if (bit.wire == nullptr) + return false; + if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep)) + return true; + if (observable_bits.count(bit)) + return true; + auto it = bit_consumers.find(bit); + if (it != 
bit_consumers.end()) for (auto &consumer : it->second) - if (!removed_cells.count(consumer.cell) && consumer.cell != cell && consumer.cell->type == ID($or)) + if (!removed_cells.count(consumer.cell)) return true; - } - return false; - } - - bool has_observable_output(Cell *cell) const - { - SigSpec y = sigmap(cell->getPort(ID::Y)); - for (auto bit : y) { - if (bit.wire == nullptr) - continue; - if (bit.wire->port_output || bit.wire->get_bool_attribute(ID::keep)) - return true; - auto it = bit_consumers.find(bit); - if (it != bit_consumers.end()) - for (auto &consumer : it->second) - if (!removed_cells.count(consumer.cell)) - return true; - } return false; } @@ -358,46 +372,56 @@ struct OptAndOrPmuxWorker return make_or_tree(cell, terms, src); } - bool try_convert(Cell *cell) + /* Fills bit_contribs for bit bit_idx of cell's Y port. Returns false if the bit has no observable output, collect_or_terms fails, a term yields TERM_FAIL, two terms use different select signals, or every term was TERM_ZERO. */ bool collect_bit_contribs(Cell *cell, int bit_idx, BitContribs &bit_contribs) const { - if (removed_cells.count(cell)) - return false; - if (cell->type != ID($or) || cell->get_bool_attribute(ID::keep)) - return false; - if (feeds_or(cell) || !has_observable_output(cell)) + SigSpec y = sigmap(cell->getPort(ID::Y)); + SigBit y_bit = y[bit_idx]; + + if (!bit_has_observable_output(y_bit)) return false; + std::vector terms; + pool seen; + int budget = max_cone_bits; + if (!collect_or_terms(y_bit, terms, seen, budget)) + return false; + + bit_contribs.bit_idx = bit_idx; + for (auto term : terms) + { + Contribution contrib; + TermResult result = parse_term(term, contrib); + if (result == TERM_ZERO) + continue; + if (result == TERM_FAIL) + return false; + + if (bit_contribs.select.empty()) + bit_contribs.select = contrib.eq.select; + else if (bit_contribs.select != contrib.eq.select) + return false; + + bit_contribs.contribs.push_back(contrib); + } + + return !bit_contribs.contribs.empty(); + } + + /* Builds one $pmux for the bits of group. Returns false when the group is empty, has fewer than two arms, or is smaller than min_pmux_bits; on success adds the converted bit indices to converted_bits. */ bool convert_group(Cell *cell, const SelectGroup &group, pool<int> &converted_bits) + { SigSpec y = sigmap(cell->getPort(ID::Y)); - int width = GetSize(y); + int width = GetSize(group.bits); if (width == 0) 
return false; - SigSpec select; std::vector arms; dict arm_index; - for (int bit_idx = 0; bit_idx < width; bit_idx++) + for (int group_bit_idx = 0; group_bit_idx < width; group_bit_idx++) { - std::vector terms; - pool seen; - int budget = max_cone_bits; - if (!collect_or_terms(y[bit_idx], terms, seen, budget)) - return false; - - for (auto term : terms) + const BitContribs &bit_contribs = group.bits[group_bit_idx]; + for (auto &contrib : bit_contribs.contribs) { - Contribution contrib; - TermResult result = parse_term(term, contrib); - if (result == TERM_ZERO) - continue; - if (result == TERM_FAIL) - return false; - - if (select.empty()) - select = contrib.eq.select; - else if (select != contrib.eq.select) - return false; - int arm_idx; auto it = arm_index.find(contrib.eq.value); if (it == arm_index.end()) { @@ -408,16 +432,19 @@ struct OptAndOrPmuxWorker arm_idx = it->second; } - arms[arm_idx].bits[bit_idx].push_back(contrib.data); + arms[arm_idx].bits[group_bit_idx].push_back(contrib.data); } } - if (GetSize(arms) < 2) + if (GetSize(arms) < 2 || GetSize(arms) * width < min_pmux_bits) return false; - SigSpec pmux_s, pmux_b; + SigSpec pmux_y, pmux_s, pmux_b; std::string src = cell->get_src_attribute(); + for (auto &bit_contribs : group.bits) + pmux_y.append(y[bit_contribs.bit_idx]); + for (auto &arm : arms) { SigSpec data; @@ -431,11 +458,73 @@ struct OptAndOrPmuxWorker log("Converting AND/OR mux %s.%s to a $pmux with %d cases and width %d.\n", log_id(module), log_id(cell), GetSize(arms), width); - module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, cell->getPort(ID::Y), src); + module->addPmux(NEW_ID2_SUFFIX("andor_pmux"), Const(State::S0, width), pmux_b, pmux_s, pmux_y, src); + for (auto &bit_contribs : group.bits) + converted_bits.insert(bit_contribs.bit_idx); + converted_count++; + return true; + } + + /* Tries to rewrite an $or cell: output bits are grouped by select signal, each group is offered to convert_group, and bits that were not converted are re-emitted as a narrower $or before the original cell is removed. Returns false when nothing was converted. */ bool try_convert(Cell *cell) + { + if (removed_cells.count(cell)) + return false; + if (cell->type != ID($or) || 
cell->get_bool_attribute(ID::keep)) + return false; + + SigSpec y = sigmap(cell->getPort(ID::Y)); + SigSpec a = sigmap(cell->getPort(ID::A)); + SigSpec b = sigmap(cell->getPort(ID::B)); + int width = GetSize(y); + if (width == 0) + return false; + + /* A and B may be narrower than Y on an $or cell; extend them to Y's width (sign-extending per A_SIGNED/B_SIGNED) so the per-bit indexing below stays in range. */ + a.extend_u0(width, cell->getParam(ID::A_SIGNED).as_bool()); + b.extend_u0(width, cell->getParam(ID::B_SIGNED).as_bool()); + + std::vector<SelectGroup> groups; + + for (int bit_idx = 0; bit_idx < width; bit_idx++) + { + BitContribs bit_contribs; + if (!collect_bit_contribs(cell, bit_idx, bit_contribs)) + continue; + + bool found = false; + for (auto &group : groups) { + if (group.select == bit_contribs.select) { + group.bits.push_back(bit_contribs); + found = true; + break; + } + } + if (!found) + groups.push_back({bit_contribs.select, {bit_contribs}}); + } + + pool<int> converted_bits; + for (auto &group : groups) + convert_group(cell, group, converted_bits); + + if (converted_bits.empty()) + return false; + + SigSpec keep_a, keep_b, keep_y; + for (int bit_idx = 0; bit_idx < width; bit_idx++) { + if (converted_bits.count(bit_idx)) + continue; + keep_a.append(a[bit_idx]); + keep_b.append(b[bit_idx]); + keep_y.append(y[bit_idx]); + } + + std::string src = cell->get_src_attribute(); + if (!keep_y.empty()) + module->addOr(NEW_ID2_SUFFIX("andor_pmux_keep_or"), keep_a, keep_b, keep_y, false, src); removed_cells.insert(cell); module->remove(cell); - converted_count++; return true; } @@ -448,7 +537,8 @@ struct OptAndOrPmuxWorker cells.push_back(cell); for (auto cell : cells) - try_convert(cell); + /* Re-run build_maps after each successful rewrite since cells were added and removed. */ if (try_convert(cell)) + build_maps(); } }; @@ -466,6 +556,7 @@ struct OptAndOrPmuxPass : public Pass { log("\n"); log("into $pmux cells. 
It only rewrites terms whose select conditions are\n"); log("equality comparisons against distinct constants of the same select signal.\n"); + log("Very small conversions are ignored to avoid replacing tiny boolean cones.\n"); log("\n"); } diff --git a/tests/opt/opt_andor_pmux.v b/tests/opt/opt_andor_pmux.v new file mode 100644 index 000000000..c0caaf5b3 --- /dev/null +++ b/tests/opt/opt_andor_pmux.v @@ -0,0 +1,48 @@ +module mixed_vector_decode( + input [2:0] sel0, + input [2:0] sel1, + input [3:0] a0, + input [3:0] b0, + input [3:0] c0, + input [3:0] a1, + input [3:0] b1, + input [3:0] c1, + output [3:0] y0, + output [3:0] y1 +); + assign {y1, y0} = + ({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) | + ({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) | + ({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0}); +endmodule + +module partial_vector_decode( + input [2:0] sel0, + input [2:0] sel1, + input [3:0] a0, + input [3:0] b0, + input [3:0] c0, + input [3:0] a1, + input [3:0] b1, + input [3:0] c1, + input [3:0] passthru, + output [3:0] y0, + output [3:0] y1 +); + wire [7:0] stage = + ({ {4{sel1 == 3'd1}}, {4{sel0 == 3'd1}} } & {a1, a0}) | + ({ {4{sel1 == 3'd2}}, {4{sel0 == 3'd2}} } & {b1, b0}) | + ({ {4{sel1 == 3'd3}}, {4{sel0 == 3'd3}} } & {c1, c0}); + + assign y1 = stage[7:4]; + assign y0 = stage[3:0] | passthru; +endmodule + +module tiny_decode( + input [1:0] sel, + input a, + input b, + output y +); + assign y = ((sel == 2'd1) & a) | ((sel == 2'd2) & b); +endmodule diff --git a/tests/opt/opt_andor_pmux.ys b/tests/opt/opt_andor_pmux.ys new file mode 100644 index 000000000..257b119f4 --- /dev/null +++ b/tests/opt/opt_andor_pmux.ys @@ -0,0 +1,38 @@ +read_verilog opt_andor_pmux.v + +hierarchy -top mixed_vector_decode +proc +design -save mixed_gold + +opt_andor_pmux +select -assert-count 2 t:$pmux + +design -stash mixed_gate +design -copy-from mixed_gold -as gold mixed_vector_decode +design -copy-from mixed_gate -as gate mixed_vector_decode +miter -equiv 
-flatten -make_assert -make_outputs gold gate miter +sat -verify -prove-asserts -show-inputs -show-outputs miter + +design -reset +read_verilog opt_andor_pmux.v + +hierarchy -top partial_vector_decode +proc +design -save partial_gold + +opt_andor_pmux +select -assert-count 2 t:$pmux + +design -stash partial_gate +design -copy-from partial_gold -as gold partial_vector_decode +design -copy-from partial_gate -as gate partial_vector_decode +miter -equiv -flatten -make_assert -make_outputs gold gate miter +sat -verify -prove-asserts -show-inputs -show-outputs miter + +design -reset +read_verilog opt_andor_pmux.v + +hierarchy -top tiny_decode +proc +opt_andor_pmux +select -assert-count 0 t:$pmux diff --git a/tests/various/opt_andor_pmux.v b/tests/various/opt_andor_pmux.v index 0670036fc..01f961a8f 100644 --- a/tests/various/opt_andor_pmux.v +++ b/tests/various/opt_andor_pmux.v @@ -2,40 +2,35 @@ module andor_pmux_basic ( input [2:0] sel, input [5:0] d, input a, - output [1:0] y + output [2:0] y ); - assign y = ({2{sel == 3'd1}} & d[1:0]) | - ({2{sel == 3'd3}} & {d[2] & a, d[3]}) | - ({2{sel == 3'd6}} & 2'b01); + assign y = ({3{sel == 3'd1}} & d[2:0]) | + ({3{sel == 3'd3}} & {d[3] & a, d[4], d[5]}) | + ({3{sel == 3'd6}} & 3'b001); endmodule module andor_pmux_outer_enable ( input [2:0] sel, - input [3:0] d, + input [5:0] d, input en, - output [1:0] y + output [2:0] y ); - wire [1:0] body; + wire [2:0] body; - assign body = ({2{sel == 3'd2}} & {1'b0, d[0]}) | - ({2{sel == 3'd5}} & {d[1], d[2]}) | - ({2{sel == 3'd7}} & {d[3], 1'b1}); - assign y = {2{en}} & body; + assign body = ({3{sel == 3'd2}} & {1'b0, d[0], d[1]}) | + ({3{sel == 3'd5}} & {d[2], d[3], d[4]}) | + ({3{sel == 3'd7}} & {d[5], 1'b1, d[0]}); + assign y = {3{en}} & body; endmodule module andor_pmux_duplicate ( input [1:0] sel, - input a, - input b, - input c, - input d, - input e, - input f, - output [1:0] y + input [11:0] d, + output [3:0] y ); - assign y = ({2{sel == 2'd1}} & {a, b}) | - ({2{sel == 2'd1}} & {c, 
d}) | - ({2{sel == 2'd2}} & {e, f}); + assign y = ({4{sel == 2'd1}} & d[3:0]) | + ({4{sel == 2'd1}} & d[7:4]) | + ({4{sel == 2'd2}} & d[11:8]); endmodule module andor_pmux_mixed_select_negative ( @@ -65,15 +60,15 @@ endmodule module andor_pmux_shared_subtree ( input [2:0] sel, - input [3:0] d, - input q, - output y, - output z + input [8:0] d, + input [2:0] q, + output [2:0] y, + output [2:0] z ); - wire sub = ((sel == 3'd1) & d[0]) | - ((sel == 3'd3) & d[1]); + wire [2:0] sub = ({3{sel == 3'd1}} & d[2:0]) | + ({3{sel == 3'd3}} & d[5:3]); - assign y = sub | ((sel == 3'd6) & d[2]); + assign y = sub | ({3{sel == 3'd6}} & d[8:6]); assign z = sub & q; endmodule @@ -106,12 +101,12 @@ endmodule module andor_pmux_duplicate_complex ( input [2:0] sel, - input [8:0] d, + input [11:0] d, input q, input r, - output [2:0] y + output [3:0] y ); - assign y = ({3{sel == 3'd2}} & {d[0] & q, d[1], d[2]}) | - ({3{sel == 3'd2}} & {d[3], d[4] & r, d[5]}) | - ({3{sel == 3'd5}} & {d[6], d[7] & q, d[8] & r}); + assign y = ({4{sel == 3'd2}} & {d[0] & q, d[1], d[2], d[3]}) | + ({4{sel == 3'd2}} & {d[4], d[5] & r, d[6], d[7]}) | + ({4{sel == 3'd5}} & {d[8], d[9] & q, d[10] & r, d[11]}); endmodule