Improvement to opt_balance_tree

This commit is contained in:
Akash Levy 2026-06-01 17:56:44 -07:00
parent d1ec970f86
commit 9cc69a3c49
2 changed files with 355 additions and 4 deletions

View File

@ -34,6 +34,14 @@ struct OptBalanceTreeWorker {
// Counts of each cell type that are getting balanced
dict<IdString, int> cell_count;
int sliced_add_count = 0;
// Bit-level connectivity index consulted by the sliced-add matcher helpers.
struct SlicedAddContext {
// For each bit seen on a cell output port: the cell owning that port.
dict<SigBit, Cell*> bit_to_driver;
// For each bit seen on a cell output port: its index within that port's signal
// (0 means the bit is the LSB of the driver's output).
dict<SigBit, int> bit_to_driver_index;
// For each bit seen on a cell input port: the set of cells reading it.
dict<SigBit, pool<Cell*>> bit_to_sink;
// Bits belonging to wires that are flagged as module output ports.
pool<SigBit> output_port_sigs;
};
// Check if cell is of the right type and has matching input/output widths
// Only allow cells with "natural" output widths (no truncation) to prevent
@ -66,6 +74,28 @@ struct OptBalanceTreeWorker {
return y_width >= natural_width;
}
// Returns true only for a non-null cell that is_right_type() accepts as an
// $add and whose A_SIGNED and B_SIGNED parameters are both false.
bool is_unsigned_add(Cell *cell)
{
	if (cell == nullptr)
		return false;
	if (!is_right_type(cell, ID($add)))
		return false;
	// Signedness is checked last, and B only if A is unsigned.
	if (cell->getParam(ID::A_SIGNED).as_bool())
		return false;
	return !cell->getParam(ID::B_SIGNED).as_bool();
}
bool is_nonzero(const SigSpec &sig)
{
for (auto bit : sig)
if (bit != State::S0)
return true;
return false;
}
// Builds a left-shifted copy of `sig`: `offset` constant-zero bits in the
// low positions, followed by the bits of `sig`.
SigSpec shift_summand(const SigSpec &sig, int offset)
{
	SigSpec result = SigSpec(State::S0, offset);
	result.append(sig);
	return result;
}
// Create a balanced binary tree from a vector of source signals
SigSpec create_balanced_tree(vector<SigSpec> &sources, IdString cell_type, Cell* cell) {
// Base case: if we have no sources, return an empty signal
@ -140,25 +170,198 @@ struct OptBalanceTreeWorker {
return out_wire;
}
bool full_child_output_at(const SigSpec &sig, int pos, Cell *&child, int &child_width,
SlicedAddContext &ctx)
{
child = nullptr;
child_width = 0;
if (pos >= GetSize(sig))
return false;
SigBit bit = sig[pos];
auto driver_it = ctx.bit_to_driver.find(bit);
if (driver_it == ctx.bit_to_driver.end())
return false;
Cell *candidate = driver_it->second;
if (!is_unsigned_add(candidate))
return false;
auto index_it = ctx.bit_to_driver_index.find(bit);
if (index_it == ctx.bit_to_driver_index.end() || index_it->second != 0)
return false;
SigSpec y = sigmap(candidate->getPort(ID::Y));
child_width = GetSize(y);
if (pos + child_width > GetSize(sig))
return false;
for (int i = 0; i < child_width; i++)
if (sig[pos + i] != y[i])
return false;
child = candidate;
return true;
}
// Returns true if `bit` has a recorded driver and that driver is an unsigned
// $add. Callers use this after full_child_output_at() has failed, i.e. to
// detect adder output bits that are not part of a complete embedded output.
bool bit_is_partial_add_output(SigBit bit, SlicedAddContext &ctx)
{
	auto found = ctx.bit_to_driver.find(bit);
	return found != ctx.bit_to_driver.end() && is_unsigned_add(found->second);
}
bool extract_sliced_operand(const SigSpec &sig, int base_offset, vector<SigSpec> &summands,
pool<Cell*> &cluster, pool<Cell*> &visiting, SlicedAddContext &ctx, bool &saw_sliced_edge)
{
for (int i = 0; i < GetSize(sig); )
{
Cell *child = nullptr;
int child_width = 0;
if (full_child_output_at(sig, i, child, child_width, ctx))
{
if (i != 0 || child_width != GetSize(sig))
saw_sliced_edge = true;
if (!extract_sliced_add(child, base_offset + i, summands, cluster, visiting, ctx, saw_sliced_edge))
return false;
i += child_width;
continue;
}
if (bit_is_partial_add_output(sig[i], ctx))
return false;
SigSpec leaf;
int leaf_start = i;
while (i < GetSize(sig))
{
Cell *next_child = nullptr;
int next_child_width = 0;
if (full_child_output_at(sig, i, next_child, next_child_width, ctx))
break;
if (bit_is_partial_add_output(sig[i], ctx))
return false;
leaf.append(sig[i]);
i++;
}
if (is_nonzero(leaf))
summands.push_back(shift_summand(leaf, base_offset + leaf_start));
}
return true;
}
// Flattens the unsigned $add `cell` into `summands`, processing port A and
// then port B, with every summand shifted by `base_offset`.
//
// `cluster` accumulates every adder visited; `visiting` holds the adders on
// the current recursion path so that re-entering one of them is rejected.
// Returns false if `cell` is not an unsigned $add, is already on the current
// path, or one of its operands cannot be decomposed.
bool extract_sliced_add(Cell *cell, int base_offset, vector<SigSpec> &summands,
		pool<Cell*> &cluster, pool<Cell*> &visiting, SlicedAddContext &ctx, bool &saw_sliced_edge)
{
	if (!is_unsigned_add(cell))
		return false;
	if (visiting.count(cell))
		return false;

	visiting.insert(cell);
	cluster.insert(cell);

	auto walk_port = [&](IdString port) {
		SigSpec operand = sigmap(cell->getPort(port));
		return extract_sliced_operand(operand, base_offset, summands, cluster, visiting, ctx, saw_sliced_edge);
	};

	if (!walk_port(ID::A))
		return false;
	if (!walk_port(ID::B))
		return false;

	// Leave the recursion path only after both operands succeeded.
	visiting.erase(cell);
	return true;
}
// Returns true if any bit of `cell`'s Y output is read by a different cell
// that is an unsigned $add. try_sliced_add_tree() uses this to reject cells
// that are not the last adder of their chain.
//
// The sink index is queried with find() rather than operator[], so that this
// read-only check does not insert an empty pool for every bit without sinks.
bool has_downstream_add_sink(Cell *cell, SlicedAddContext &ctx)
{
	const SigSpec y = sigmap(cell->getPort(ID::Y));
	for (auto bit : y)
	{
		auto sinks_it = ctx.bit_to_sink.find(bit);
		if (sinks_it == ctx.bit_to_sink.end())
			continue; // bit has no recorded readers
		for (auto sink : sinks_it->second)
			if (sink != cell && is_unsigned_add(sink))
				return true;
	}
	return false;
}
// Returns true if any non-head adder in `cluster` has an output bit that is
// observed outside the cluster: either the bit belongs to a module output
// port, or it is read by a cell that is not a cluster member. The head cell
// is exempt because its output is the result that gets reconnected.
//
// The sink index is queried with find() rather than operator[], so that this
// read-only check does not insert an empty pool for every bit without sinks.
bool sliced_cluster_has_external_fanout(Cell *head_cell, pool<Cell*> &cluster, SlicedAddContext &ctx)
{
	for (auto cell : cluster)
	{
		if (cell == head_cell)
			continue;

		const SigSpec y = sigmap(cell->getPort(ID::Y));
		for (auto bit : y)
		{
			if (ctx.output_port_sigs.count(bit))
				return true;

			auto sinks_it = ctx.bit_to_sink.find(bit);
			if (sinks_it == ctx.bit_to_sink.end())
				continue; // bit has no recorded readers
			for (auto sink : sinks_it->second)
				if (!cluster.count(sink))
					return true;
		}
	}
	return false;
}
// Attempts to replace a chain of unsigned $add cells ending at `head_cell`,
// in which intermediate sums are embedded as slices of wider operands, by a
// single balanced adder tree.
//
// The rewrite is performed only if `head_cell` is an unconsumed unsigned $add
// with no downstream unsigned $add reader, the flattening succeeds, at least
// one sliced edge was seen, the cluster has more than one cell, there are more
// than two summands, and no non-head cluster cell has external fanout.
// On success the tree output is connected to the head's Y (zero-padded if the
// tree is narrower), all cluster cells are added to `consumed_cells`,
// `sliced_add_count` is incremented and true is returned.
bool try_sliced_add_tree(Cell *head_cell, pool<Cell*> &consumed_cells, SlicedAddContext &ctx)
{
	if (!is_unsigned_add(head_cell))
		return false;
	if (consumed_cells.count(head_cell))
		return false;
	if (has_downstream_add_sink(head_cell, ctx))
		return false;

	vector<SigSpec> summands;
	pool<Cell*> cluster;
	pool<Cell*> visiting;
	bool saw_sliced_edge = false;

	const bool flattened = extract_sliced_add(head_cell, 0, summands, cluster, visiting, ctx, saw_sliced_edge);
	if (!flattened)
		return false;

	const bool worthwhile = saw_sliced_edge && GetSize(cluster) > 1 && GetSize(summands) > 2;
	if (!worthwhile)
		return false;

	if (sliced_cluster_has_external_fanout(head_cell, cluster, ctx))
		return false;

	log_debug(" Creating sliced add tree for %s with %d summands and %d cells...\n",
			log_id(head_cell), GetSize(summands), GetSize(cluster));

	SigSpec tree_output = create_balanced_tree(summands, ID($add), head_cell);
	SigSpec head_output = sigmap(head_cell->getPort(ID::Y));

	// Drive the overlapping low bits of the head output from the tree.
	const int head_width = GetSize(head_output);
	const int tree_width = GetSize(tree_output);
	const int connect_width = std::min(head_width, tree_width);
	module->connect(head_output.extract(0, connect_width), tree_output.extract(0, connect_width));

	// Any remaining high bits of the head output are tied to zero.
	const int pad_width = head_width - connect_width;
	if (pad_width > 0)
		module->connect(head_output.extract(connect_width, pad_width), SigSpec(State::S0, pad_width));

	for (auto member : cluster)
		consumed_cells.insert(member);

	sliced_add_count++;
	return true;
}
OptBalanceTreeWorker(Module *module, const vector<IdString> cell_types) : module(module), sigmap(module) {
// Do for each cell type
for (auto cell_type : cell_types) {
// Index all of the nets in the module
dict<SigSpec, Cell*> sig_to_driver;
dict<SigSpec, pool<Cell*>> sig_to_sink;
SlicedAddContext sliced_add_ctx;
for (auto cell : module->selected_cells())
{
for (auto &conn : cell->connections())
{
if (cell->output(conn.first))
sig_to_driver[sigmap(conn.second)] = cell;
SigSpec sig = sigmap(conn.second);
if (cell->output(conn.first)) {
sig_to_driver[sig] = cell;
for (int i = 0; i < GetSize(sig); i++) {
sliced_add_ctx.bit_to_driver[sig[i]] = cell;
sliced_add_ctx.bit_to_driver_index[sig[i]] = i;
}
}
if (cell->input(conn.first))
{
SigSpec sig = sigmap(conn.second);
if (sig_to_sink.count(sig) == 0)
sig_to_sink[sig] = pool<Cell*>();
sig_to_sink[sig].insert(cell);
for (auto bit : sig)
sliced_add_ctx.bit_to_sink[bit].insert(cell);
}
}
}
@ -172,13 +375,19 @@ struct OptBalanceTreeWorker {
for (auto bit : sig) {
if (wire->port_input)
input_port_sigs.insert(bit);
if (wire->port_output)
if (wire->port_output) {
output_port_sigs.insert(bit);
sliced_add_ctx.output_port_sigs.insert(bit);
}
}
}
// Actual logic starts here
pool<Cell*> consumed_cells;
if (cell_type == ID($add))
for (auto cell : module->selected_cells())
try_sliced_add_tree(cell, consumed_cells, sliced_add_ctx);
for (auto cell : module->selected_cells())
{
// If consumed or not the correct type, skip
@ -362,16 +571,20 @@ struct OptBalanceTreePass : public Pass {
// Count of all cells that were packed
dict<IdString, int> cell_count;
int sliced_add_count = 0;
for (auto module : design->selected_modules()) {
OptBalanceTreeWorker worker(module, cell_types);
for (auto cell : worker.cell_count) {
cell_count[cell.first] += cell.second;
}
sliced_add_count += worker.sliced_add_count;
}
// Log stats
for (auto cell_type : cell_types)
log("Converted %d %s cells into trees.\n", cell_count[cell_type], log_id(cell_type));
if (std::find(cell_types.begin(), cell_types.end(), ID($add)) != cell_types.end())
log("Converted %d sliced $add chains into trees.\n", sliced_add_count);
// Clean up
Yosys::run_pass("clean -purge");

View File

@ -42,6 +42,144 @@ log -pop
# Test 31
# Four unsigned $add cells form a chain in which add0's full 17-bit result is
# placed above b[13:0] inside add1's A operand. The pass must stay equivalent
# and must replace all four original adders with tree cells.
log -header "Sliced shifted ADD chain"
log -push
design -reset
read_verilog -icells <<EOF
module top (
input wire [15:0] a,
input wire [15:0] b,
input wire [15:0] c,
input wire [15:0] d,
input wire [15:0] e,
output wire [63:0] y
);
wire [16:0] s0;
wire [31:0] s1;
wire [47:0] s2;
// s0 is later embedded into a wider operand with b[13:0] below it.
// The sliced-add matcher should still flatten and rebalance the whole chain.
\$add #(.A_WIDTH(16), .B_WIDTH(16), .Y_WIDTH(17), .A_SIGNED(0), .B_SIGNED(0))
add0 (.A(a), .B({14'b0, b[15:14]}), .Y(s0));
\$add #(.A_WIDTH(31), .B_WIDTH(31), .Y_WIDTH(32), .A_SIGNED(0), .B_SIGNED(0))
add1 (.A({s0, b[13:0]}), .B({15'b0, c}), .Y(s1));
\$add #(.A_WIDTH(32), .B_WIDTH(32), .Y_WIDTH(48), .A_SIGNED(0), .B_SIGNED(0))
add2 (.A(s1), .B({16'b0, d}), .Y(s2));
\$add #(.A_WIDTH(48), .B_WIDTH(48), .Y_WIDTH(64), .A_SIGNED(0), .B_SIGNED(0))
add3 (.A(s2), .B({32'b0, e}), .Y(y));
endmodule
EOF
check -assert
design -save preopt
# First prove the pass preserves behaviour ...
equiv_opt -assert opt_balance_tree
# ... then rerun it on the saved design to inspect the resulting cells.
design -load preopt
opt_balance_tree
# None of the original adders may remain; a tree cell named after add3 must exist.
select -assert-count 0 c:add0 c:add1 c:add2 c:add3
select -assert-count 1 c:add3_tree_3
design -reset
log -pop
# Test 32
# Same sliced embedding as Test 31, but the intermediate sum s0 is also routed
# to the module output `tap`. Because s0 is observed outside the chain, add0
# is expected to survive the pass.
log -header "Sliced shifted ADD chain with external intermediate fanout"
log -push
design -reset
read_verilog -icells <<EOF
module top (
input wire [15:0] a,
input wire [15:0] b,
input wire [15:0] c,
output wire [31:0] y,
output wire [16:0] tap
);
wire [16:0] s0;
\$add #(.A_WIDTH(16), .B_WIDTH(16), .Y_WIDTH(17), .A_SIGNED(0), .B_SIGNED(0))
add0 (.A(a), .B({14'b0, b[15:14]}), .Y(s0));
\$add #(.A_WIDTH(31), .B_WIDTH(31), .Y_WIDTH(32), .A_SIGNED(0), .B_SIGNED(0))
add1 (.A({s0, b[13:0]}), .B({15'b0, c}), .Y(y));
assign tap = s0;
endmodule
EOF
check -assert
design -save preopt
# Equivalence must hold, and on a fresh copy add0 must still be present.
equiv_opt -assert opt_balance_tree
design -load preopt
opt_balance_tree
select -assert-count 1 c:add0
design -reset
log -pop
# Test 33
# Same structure as Test 32 without the tap, but both adders are instantiated
# with A_SIGNED=1 and B_SIGNED=1. The sliced-add rewrite only accepts unsigned
# adders, so add0 is expected to survive the pass.
log -header "Signed sliced ADD chain is skipped"
log -push
design -reset
read_verilog -icells <<EOF
module top (
input wire signed [15:0] a,
input wire signed [15:0] b,
input wire signed [15:0] c,
output wire signed [31:0] y
);
wire signed [16:0] s0;
\$add #(.A_WIDTH(16), .B_WIDTH(16), .Y_WIDTH(17), .A_SIGNED(1), .B_SIGNED(1))
add0 (.A(a), .B({14'b0, b[15:14]}), .Y(s0));
\$add #(.A_WIDTH(31), .B_WIDTH(31), .Y_WIDTH(32), .A_SIGNED(1), .B_SIGNED(1))
add1 (.A({s0, b[13:0]}), .B({15'b0, c}), .Y(y));
endmodule
EOF
check -assert
design -save preopt
# Equivalence must hold, and on a fresh copy add0 must still be present.
equiv_opt -assert opt_balance_tree
design -load preopt
opt_balance_tree
select -assert-count 1 c:add0
design -reset
log -pop
# Test 34
# add1's A operand uses only s0[16:1], i.e. add0's output without its LSB.
# An adder output that is not embedded in full, starting at bit 0, cannot be
# flattened, so add0 is expected to survive the pass.
log -header "High-slice-only ADD dependency is skipped"
log -push
design -reset
read_verilog -icells <<EOF
module top (
input wire [15:0] a,
input wire [15:0] b,
input wire [15:0] c,
output wire [31:0] y
);
wire [16:0] s0;
\$add #(.A_WIDTH(16), .B_WIDTH(16), .Y_WIDTH(17), .A_SIGNED(0), .B_SIGNED(0))
add0 (.A(a), .B({14'b0, b[15:14]}), .Y(s0));
\$add #(.A_WIDTH(30), .B_WIDTH(30), .Y_WIDTH(32), .A_SIGNED(0), .B_SIGNED(0))
add1 (.A({s0[16:1], b[13:0]}), .B({14'b0, c}), .Y(y));
endmodule
EOF
check -assert
design -save preopt
# Equivalence must hold, and on a fresh copy add0 must still be present.
equiv_opt -assert opt_balance_tree
design -load preopt
opt_balance_tree
select -assert-count 1 c:add0
design -reset
log -pop
# Test 2
log -header "AND chain with intermediate outputs"
log -push