diff --git a/passes/opt/opt_balance_tree.cc b/passes/opt/opt_balance_tree.cc index d6450c735..003dfc89c 100644 --- a/passes/opt/opt_balance_tree.cc +++ b/passes/opt/opt_balance_tree.cc @@ -34,6 +34,14 @@ struct OptBalanceTreeWorker { // Counts of each cell type that are getting balanced dict cell_count; + /* Number of sliced add clusters rewritten by try_sliced_add_tree (incremented once per successful rewrite). */ int sliced_add_count = 0; + + /* Bit-level connectivity index consumed by the sliced-add helpers. As filled in by the constructor: bit_to_driver maps a sigmapped bit to the selected cell whose output port carries it, bit_to_driver_index is the position of that bit within that output port, bit_to_sink collects the selected cells that read the bit on an input port, and output_port_sigs holds bits collected from wires flagged port_output. */ struct SlicedAddContext { + dict bit_to_driver; + dict bit_to_driver_index; + dict> bit_to_sink; + pool output_port_sigs; + }; // Check if cell is of the right type and has matching input/output widths // Only allow cells with "natural" output widths (no truncation) to prevent @@ -66,6 +74,28 @@ struct OptBalanceTreeWorker { return y_width >= natural_width; } + /* True when cell is non-null, passes is_right_type for $add, and has both the A_SIGNED and B_SIGNED parameters false. */ bool is_unsigned_add(Cell *cell) + { + return cell && is_right_type(cell, ID($add)) && + !cell->getParam(ID::A_SIGNED).as_bool() && + !cell->getParam(ID::B_SIGNED).as_bool(); + } + + /* True when at least one bit of sig differs from the constant State::S0. */ bool is_nonzero(const SigSpec &sig) + { + for (auto bit : sig) + if (bit != State::S0) + return true; + return false; + } + + /* Builds a signal made of offset constant State::S0 bits with sig appended after them. Presumably the zeros land on the low-order side, i.e. this is a left shift by offset - confirm against SigSpec::append. */ SigSpec shift_summand(const SigSpec &sig, int offset) + { + SigSpec shifted(State::S0, offset); + shifted.append(sig); + return shifted; + } + // Create a balanced binary tree from a vector of source signals SigSpec create_balanced_tree(vector &sources, IdString cell_type, Cell* cell) { // Base case: if we have no sources, return an empty signal @@ -140,25 +170,198 @@ struct OptBalanceTreeWorker { return out_wire; } + /* Tests whether the complete Y output of an unsigned add cell sits in sig starting at index pos. Fails when pos is past the end of sig, when sig[pos] has no recorded driver, when that driver fails is_unsigned_add, when sig[pos] is not bit 0 of the driver output, when the output would run past the end of sig, or when any later bit does not line up with Y. On success child receives the driver and child_width the size of its Y; on failure child stays nullptr, while child_width may already hold the candidate width. */ bool full_child_output_at(const SigSpec &sig, int pos, Cell *&child, int &child_width, + SlicedAddContext &ctx) + { + child = nullptr; + child_width = 0; + if (pos >= GetSize(sig)) + return false; + + SigBit bit = sig[pos]; + auto driver_it = ctx.bit_to_driver.find(bit); + if (driver_it == ctx.bit_to_driver.end()) + return false; + + Cell *candidate = driver_it->second; + if (!is_unsigned_add(candidate)) + return false; + + auto index_it = ctx.bit_to_driver_index.find(bit); + if (index_it == ctx.bit_to_driver_index.end() || index_it->second != 0) + return 
false; + SigSpec y = sigmap(candidate->getPort(ID::Y)); + child_width = GetSize(y); + if (pos + child_width > GetSize(sig)) + return false; + + for (int i = 0; i < child_width; i++) + if (sig[pos + i] != y[i]) + return false; + + child = candidate; + return true; + } + + /* True when the recorded driver of bit passes is_unsigned_add. The function itself does not test for partial coverage; the callers visible here only consult it after full_child_output_at has rejected the same position, which is what makes a hit mean a partially used add output. */ bool bit_is_partial_add_output(SigBit bit, SlicedAddContext &ctx) + { + auto driver_it = ctx.bit_to_driver.find(bit); + if (driver_it == ctx.bit_to_driver.end()) + return false; + return is_unsigned_add(driver_it->second); + } + + /* Decomposes one adder operand into summands. Scanning sig from index 0: where a full add output begins, recurse into that cell with offset base_offset + i and skip past it, setting saw_sliced_edge when the output does not span the whole operand; abort with false on a bit driven by an unsigned add that is not the start of a full output; otherwise gather the run of following non-add bits into a leaf and push it, shifted by base_offset + leaf_start, unless every bit of the leaf is State::S0. Returns true when the whole operand was consumed. */ bool extract_sliced_operand(const SigSpec &sig, int base_offset, vector &summands, + pool &cluster, pool &visiting, SlicedAddContext &ctx, bool &saw_sliced_edge) + { + for (int i = 0; i < GetSize(sig); ) + { + Cell *child = nullptr; + int child_width = 0; + if (full_child_output_at(sig, i, child, child_width, ctx)) + { + if (i != 0 || child_width != GetSize(sig)) + saw_sliced_edge = true; + if (!extract_sliced_add(child, base_offset + i, summands, cluster, visiting, ctx, saw_sliced_edge)) + return false; + i += child_width; + continue; + } + + if (bit_is_partial_add_output(sig[i], ctx)) + return false; + + SigSpec leaf; + int leaf_start = i; + while (i < GetSize(sig)) + { + Cell *next_child = nullptr; + int next_child_width = 0; + if (full_child_output_at(sig, i, next_child, next_child_width, ctx)) + break; + if (bit_is_partial_add_output(sig[i], ctx)) + return false; + leaf.append(sig[i]); + i++; + } + + if (is_nonzero(leaf)) + summands.push_back(shift_summand(leaf, base_offset + leaf_start)); + } + + return true; + } + + /* Recursively collects the summands feeding cell. Returns false when cell fails is_unsigned_add or is already in visiting (it is currently being expanded further up the recursion). Otherwise records cell in visiting and cluster, expands ports A and B at base_offset, and removes cell from visiting again on success only; on failure visiting, cluster and summands keep their partial contents. */ bool extract_sliced_add(Cell *cell, int base_offset, vector &summands, + pool &cluster, pool &visiting, SlicedAddContext &ctx, bool &saw_sliced_edge) + { + if (!is_unsigned_add(cell) || visiting.count(cell)) + return false; + + visiting.insert(cell); + cluster.insert(cell); + + for (IdString port : {ID::A, ID::B}) { + SigSpec sig = sigmap(cell->getPort(port)); + if (!extract_sliced_operand(sig, base_offset, summands, cluster, visiting, ctx, 
saw_sliced_edge)) + return false; + } + + visiting.erase(cell); + return true; + } + + /* True when any bit of the sigmapped Y output of cell is read by a different cell that passes is_unsigned_add. try_sliced_add_tree uses this to skip cells that feed another adder. NOTE(review): the lookup uses operator[] on bit_to_sink, which presumably creates an empty entry for bits without sinks - confirm. */ bool has_downstream_add_sink(Cell *cell, SlicedAddContext &ctx) + { + SigSpec y = sigmap(cell->getPort(ID::Y)); + for (auto bit : y) + for (auto sink : ctx.bit_to_sink[bit]) + if (sink != cell && is_unsigned_add(sink)) + return true; + return false; + } + + /* True when a cluster cell other than head_cell has a Y bit that is either in ctx.output_port_sigs or read by a cell that is not a member of cluster. The output of head_cell itself is not examined. */ bool sliced_cluster_has_external_fanout(Cell *head_cell, pool &cluster, SlicedAddContext &ctx) + { + for (auto cell : cluster) + { + if (cell == head_cell) + continue; + + SigSpec y = sigmap(cell->getPort(ID::Y)); + for (auto bit : y) + { + if (ctx.output_port_sigs.count(bit)) + return true; + for (auto sink : ctx.bit_to_sink[bit]) + if (!cluster.count(sink)) + return true; + } + } + + return false; + } + + /* Attempts to rebuild the adders feeding head_cell as one balanced tree. Bails out with false when head_cell fails is_unsigned_add, is already in consumed_cells, or feeds another unsigned add; when extraction fails; when no sliced edge was seen, the cluster has at most one cell, or there are at most two summands; or when an inner cell has external fanout. Otherwise builds the tree with create_balanced_tree, connects the low min(head width, tree width) bits of the head output to the tree output, ties any remaining head output bits to State::S0, adds every cluster cell to consumed_cells and bumps sliced_add_count. NOTE(review): head_cell is not removed here; presumably consumed cells are deleted elsewhere in the worker - not visible in this chunk. */ bool try_sliced_add_tree(Cell *head_cell, pool &consumed_cells, SlicedAddContext &ctx) + { + if (!is_unsigned_add(head_cell) || consumed_cells.count(head_cell) || has_downstream_add_sink(head_cell, ctx)) + return false; + + vector summands; + pool cluster, visiting; + bool saw_sliced_edge = false; + if (!extract_sliced_add(head_cell, 0, summands, cluster, visiting, ctx, saw_sliced_edge)) + return false; + if (!saw_sliced_edge || GetSize(cluster) <= 1 || GetSize(summands) <= 2) + return false; + if (sliced_cluster_has_external_fanout(head_cell, cluster, ctx)) + return false; + + log_debug(" Creating sliced add tree for %s with %d summands and %d cells...\n", + log_id(head_cell), GetSize(summands), GetSize(cluster)); + + SigSpec tree_output = create_balanced_tree(summands, ID($add), head_cell); + SigSpec head_output = sigmap(head_cell->getPort(ID::Y)); + int connect_width = std::min(head_output.size(), tree_output.size()); + module->connect(head_output.extract(0, connect_width), tree_output.extract(0, connect_width)); + if (head_output.size() > tree_output.size()) + module->connect(head_output.extract(connect_width, head_output.size() - connect_width), + SigSpec(State::S0, head_output.size() - connect_width)); + + for (auto 
cell : cluster) + consumed_cells.insert(cell); + sliced_add_count++; + return true; + } + OptBalanceTreeWorker(Module *module, const vector cell_types) : module(module), sigmap(module) { // Do for each cell type for (auto cell_type : cell_types) { // Index all of the nets in the module dict sig_to_driver; dict> sig_to_sink; + SlicedAddContext sliced_add_ctx; for (auto cell : module->selected_cells()) { for (auto &conn : cell->connections()) { - if (cell->output(conn.first)) - sig_to_driver[sigmap(conn.second)] = cell; + SigSpec sig = sigmap(conn.second); + if (cell->output(conn.first)) { + sig_to_driver[sig] = cell; + for (int i = 0; i < GetSize(sig); i++) { + sliced_add_ctx.bit_to_driver[sig[i]] = cell; + sliced_add_ctx.bit_to_driver_index[sig[i]] = i; + } + } if (cell->input(conn.first)) { - SigSpec sig = sigmap(conn.second); if (sig_to_sink.count(sig) == 0) sig_to_sink[sig] = pool(); sig_to_sink[sig].insert(cell); + for (auto bit : sig) + sliced_add_ctx.bit_to_sink[bit].insert(cell); } } } @@ -172,13 +375,19 @@ struct OptBalanceTreeWorker { for (auto bit : sig) { if (wire->port_input) input_port_sigs.insert(bit); - if (wire->port_output) + if (wire->port_output) { output_port_sigs.insert(bit); + sliced_add_ctx.output_port_sigs.insert(bit); + } } } // Actual logic starts here pool consumed_cells; + /* Try the sliced add rewrite on every selected cell first; cells it places in consumed_cells are presumably skipped by the loop that follows - confirm. */ if (cell_type == ID($add)) + for (auto cell : module->selected_cells()) + try_sliced_add_tree(cell, consumed_cells, sliced_add_ctx); + for (auto cell : module->selected_cells()) { // If consumed or not the correct type, skip @@ -362,16 +571,20 @@ struct OptBalanceTreePass : public Pass { // Count of all cells that were packed dict cell_count; + int sliced_add_count = 0; for (auto module : design->selected_modules()) { OptBalanceTreeWorker worker(module, cell_types); for (auto cell : worker.cell_count) { cell_count[cell.first] += cell.second; } + sliced_add_count += worker.sliced_add_count; } // Log stats for (auto cell_type : cell_types) log("Converted %d %s cells into 
trees.\n", cell_count[cell_type], log_id(cell_type)); + if (std::find(cell_types.begin(), cell_types.end(), ID($add)) != cell_types.end()) + log("Converted %d sliced $add chains into trees.\n", sliced_add_count); // Clean up Yosys::run_pass("clean -purge"); diff --git a/tests/opt/opt_balance_tree.ys b/tests/opt/opt_balance_tree.ys index 6f8b8b711..e82923bb0 100644 --- a/tests/opt/opt_balance_tree.ys +++ b/tests/opt/opt_balance_tree.ys @@ -42,6 +42,144 @@ log -pop +# Test 31 +log -header "Sliced shifted ADD chain" +log -push +design -reset +read_verilog -icells <