Merge pull request #85 from Silimate/fix_opt_balance_tree

Fix opt balance tree and wreduce
This commit is contained in:
Akash Levy 2025-09-09 05:34:27 -07:00 committed by GitHub
commit 36b753285c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1521 additions and 407 deletions

View File

@ -469,28 +469,74 @@ struct WreduceWorker
keep_bits.insert(bit);
}
for (auto c : module->selected_cells())
work_queue_cells.insert(c);
while (!work_queue_cells.empty())
{
work_queue_bits.clear();
for (auto c : work_queue_cells)
run_cell(c);
work_queue_cells.clear();
for (auto bit : work_queue_bits)
for (auto port : mi.query_ports(bit))
if (module->selected(port.cell))
work_queue_cells.insert(port.cell);
}
pool<SigSpec> complete_wires;
for (auto w : module->wires())
complete_wires.insert(mi.sigmap(w));
// Build bit_drivers_db for cells
dict<SigBit, tuple<IdString,IdString>> bit_drivers_db;
for (auto cell : module->cells()) {
for (auto conn : cell->connections()) {
if (!cell->output(conn.first)) continue;
for (int i = 0; i < GetSize(conn.second); i++) {
SigBit bit(mi.sigmap(conn.second[i]));
bit_drivers_db[bit] = tuple<IdString,IdString>(cell->name, conn.first);
}
}
}
// Build wire mapping for dependency tracking
dict<SigBit, Wire*> bit_to_wire_map;
for (auto w : module->wires())
for (auto bit : mi.sigmap(w))
bit_to_wire_map[bit] = w;
// Create unified topological sort for both cells and wires
TopoSort<IdString, RTLIL::sort_by_id_str> unified_toposort;
// Add all cells and processable wires as nodes
for (auto cell : module->selected_cells())
unified_toposort.node(cell->name);
for (auto w : module->selected_wires())
{
unified_toposort.node(w->name);
// Build edges between cells and wires based on signal flow
// In topological sort: edge(A, B) means A should be processed before B
for (auto cell : module->selected_cells()) {
for (auto &conn : cell->connections()) {
bool is_output = cell->output(conn.first);
for (auto bit : mi.sigmap(conn.second)) {
Wire *wire = bit_to_wire_map.count(bit) ? bit_to_wire_map[bit] : nullptr;
if (!wire || !unified_toposort.has_node(wire->name)) continue;
if (is_output) {
// Cell drives wire: process cell before wire
// Cell reduction may affect wire, so cell -> wire
unified_toposort.edge(cell->name, wire->name);
} else {
// Wire drives cell: process wire before cell
// Wire reduction may affect cell, so wire -> cell
unified_toposort.edge(wire->name, cell->name);
}
}
}
}
unified_toposort.analyze_loops = false;
unified_toposort.sort();
// Process cells and wires together in unified topological order
for (auto name : unified_toposort.sorted) {
Cell *c = module->cell(name);
Wire *w = module->wire(name);
if (c && module->selected(c)) {
run_cell(c);
continue;
}
if (!(w && module->selected(w)))
continue;
int unused_top_bits = 0;
if (w->port_id > 0 || count_nontrivial_wire_attrs(w) > 0)
@ -514,6 +560,7 @@ struct WreduceWorker
Wire *nw = module->addWire(module->uniquify(IdString(w->name.str() + "_wreduce")), GetSize(w) - unused_top_bits);
module->connect(nw, SigSpec(w).extract(0, GetSize(nw)));
module->swap_names(w, nw);
mi.reload_module(); // TODO: SILIMATE: CAN WE SPEED THIS UP?
}
}
};

View File

@ -34,317 +34,303 @@ struct OptBalanceTreeWorker {
// Counts of each cell type that are getting balanced
dict<IdString, int> cell_count;
// Cells to remove
pool<Cell*> remove_cells;
// Signal chain data structures
dict<SigSpec, Cell*> sig_chain_next;
dict<SigSpec, Cell*> sig_chain_prev;
pool<SigBit> sigbit_with_non_chain_users;
pool<Cell*> chain_start_cells;
pool<Cell*> candidate_cells;
void make_sig_chain_next_prev(IdString cell_type) {
// Mark all wires with keep attribute or output port as having non-chain users
for (auto wire : module->wires()) {
if (wire->get_bool_attribute(ID::keep) || wire->port_output) {
for (auto bit : sigmap(wire))
sigbit_with_non_chain_users.insert(bit);
}
}
// Iterate over all cells in module
for (auto cell : module->cells()) {
// If cell matches and not marked as keep
if (cell->type == cell_type && !cell->get_bool_attribute(ID::keep)) {
// Get signals for cell ports
SigSpec a_sig = sigmap(cell->getPort(ID::A));
SigSpec b_sig = sigmap(cell->getPort(ID::B));
SigSpec y_sig = sigmap(cell->getPort(ID::Y));
// If a_sig already has a chain user, mark its bits as having non-chain users
if (sig_chain_next.count(a_sig))
for (auto a_bit : a_sig.bits())
sigbit_with_non_chain_users.insert(a_bit);
// Otherwise, mark cell as the next in the chain relative to a_sig
else {
sig_chain_next[a_sig] = cell;
}
if (!b_sig.empty()) {
// If b_sig already has a chain user, mark its bits as having non-chain users
if (sig_chain_next.count(b_sig))
for (auto b_bit : b_sig.bits())
sigbit_with_non_chain_users.insert(b_bit);
// Otherwise, mark cell as the next in the chain relative to b_sig
else {
sig_chain_next[b_sig] = cell;
}
}
// Add cell as candidate
candidate_cells.insert(cell);
// Mark cell as the previous in the chain relative to y_sig
sig_chain_prev[y_sig] = cell;
for (auto bit : y_sig.bits())
sig_chain_prev[bit] = cell;
}
// If cell is not matching type, mark all cell input signals as being non-chain users
else {
for (auto conn : cell->connections())
if (cell->input(conn.first))
for (auto bit : sigmap(conn.second))
sigbit_with_non_chain_users.insert(bit);
}
}
// Type/width filter for candidate cells.
// Returns true only when `cell` has type `cell_type` and its Y_WIDTH parameter
// is greater than or equal to both its A_WIDTH and its B_WIDTH parameters.
bool is_right_type(Cell* cell, IdString cell_type) {
	// Wrong cell type: reject before touching any width parameter
	if (cell->type != cell_type)
		return false;
	// Output must be at least as wide as input A ...
	if (cell->getParam(ID::Y_WIDTH).as_int() < cell->getParam(ID::A_WIDTH).as_int())
		return false;
	// ... and at least as wide as input B
	return cell->getParam(ID::Y_WIDTH).as_int() >= cell->getParam(ID::B_WIDTH).as_int();
}
void find_chain_start_cells() {
for (auto cell : candidate_cells) {
// Log candidate cell
log_debug("Considering %s (%s)\n", log_id(cell), log_id(cell->type));
// Create a balanced binary tree from a vector of source signals
SigSpec create_balanced_tree(vector<SigSpec> &sources, IdString cell_type, Cell* cell) {
// Base case: if we have no sources, return an empty signal
if (sources.size() == 0)
return SigSpec();
// Get signals for cell ports
SigSpec a_sig = sigmap(cell->getPort(ID::A));
SigSpec b_sig = sigmap(cell->getPort(ID::B));
SigSpec prev_sig = sig_chain_prev.count(a_sig) ? a_sig : b_sig;
// Base case: if we have only one source, return it
if (sources.size() == 1)
return sources[0];
// Base case: if we have two sources, create a single cell
if (sources.size() == 2) {
// Create a new cell of the same type
Cell* new_cell = module->addCell(NEW_ID2_SUFFIX("tree"), cell_type);
// This is a start cell if there was no previous cell in the chain for a_sig or b_sig
if (sig_chain_prev.count(a_sig) + sig_chain_prev.count(b_sig) != 1) {
chain_start_cells.insert(cell);
continue;
}
// If any bits in previous cell signal have non-chain users, this is a start cell
for (auto bit : prev_sig.bits())
if (sigbit_with_non_chain_users.count(bit)) {
chain_start_cells.insert(cell);
continue;
}
// Copy attributes from reference cell
new_cell->attributes = cell->attributes;
// Create output wire
int out_width = cell->getParam(ID::Y_WIDTH).as_int();
if (cell_type == ID($add))
out_width = max(sources[0].size(), sources[1].size()) + 1;
else if (cell_type == ID($mul))
out_width = sources[0].size() + sources[1].size();
Wire* out_wire = module->addWire(NEW_ID2_SUFFIX("tree_out"), out_width);
// Connect ports and fix up parameters
new_cell->setPort(ID::A, sources[0]);
new_cell->setPort(ID::B, sources[1]);
new_cell->setPort(ID::Y, out_wire);
new_cell->fixup_parameters();
new_cell->setParam(ID::A_SIGNED, cell->getParam(ID::A_SIGNED));
new_cell->setParam(ID::B_SIGNED, cell->getParam(ID::B_SIGNED));
// Update count and return output wire
cell_count[cell_type]++;
return out_wire;
}
}
// Collect the run of cells that starts at `start_cell`.
// Starting from `start_cell`, each step looks up the sigmapped Y signal of the
// current cell in sig_chain_next. The walk ends when that signal has no entry,
// or when the cell found there is itself listed in chain_start_cells (that
// cell is NOT appended). A null `start_cell` yields an empty vector.
vector<Cell*> create_chain(Cell *start_cell) {
	vector<Cell*> cells;
	for (Cell *cur = start_cell; cur != nullptr; ) {
		// The current cell always belongs to the chain
		cells.push_back(cur);
		// Follow the output signal to the cell registered as its chain successor
		const SigSpec out_sig = sigmap(cur->getPort(ID::Y));
		if (!sig_chain_next.count(out_sig))
			break;
		cur = sig_chain_next.at(out_sig);
		// A successor that starts its own chain terminates this one
		if (chain_start_cells.count(cur))
			break;
	}
	return cells;
}
// Narrow the ports of a single cell in place.
// For $add and $mul cells:
//   1. Each of the A and B ports is re-read through sigmap and constant-zero
//      bits are stripped from its top; the matching *_WIDTH parameter is then
//      set to the new size.
//   2. Y is re-read through sigmap; every Y bit at index >= the computed
//      result width is connected to constant zero and dropped from the port,
//      and Y_WIDTH is set to the new size.
// For every cell type, fixup_parameters() is called at the end.
void wreduce(Cell *cell) {
	// If cell is arithmetic, remove leading zeros from inputs, then clean up outputs
	if (cell->type.in(ID($add), ID($mul))) {
		// Remove leading zeros from inputs
		for (auto inport : {ID::A, ID::B}) {
			// Record number of bits removed
			int bits_removed = 0;
			// Parameter names that belong to the port being processed
			IdString inport_signed = (inport == ID::A) ? ID::A_SIGNED : ID::B_SIGNED;
			IdString inport_width = (inport == ID::A) ? ID::A_WIDTH : ID::B_WIDTH;
			SigSpec inport_sig = sigmap(cell->getPort(inport));
			cell->unsetPort(inport);
			if (cell->getParam(inport_signed).as_bool()) {
				// Port flagged signed: drop the top bit only while it AND the bit
				// below it are both zero, so one zero bit is left on top
				while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig)-1] == State::S0 && inport_sig[GetSize(inport_sig)-2] == State::S0) {
					inport_sig.remove(GetSize(inport_sig)-1, 1);
					bits_removed++;
				}
			} else {
				// Port not flagged signed: drop every zero bit at the top
				// (this can shrink the signal down to zero bits)
				while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig)-1] == State::S0) {
					inport_sig.remove(GetSize(inport_sig)-1, 1);
					bits_removed++;
				}
			}
			cell->setPort(inport, inport_sig);
			cell->setParam(inport_width, GetSize(inport_sig));
			log_debug("Width reduced %s/%s by %d bits\n", log_id(cell), log_id(inport), bits_removed);
		}
		// Record number of bits removed from output
		int bits_removed = 0;
		// Remove unnecessary bits from output
		SigSpec y_sig = sigmap(cell->getPort(ID::Y));
		cell->unsetPort(ID::Y);
		// Result width derived from the input widths that were just written:
		// max(A_WIDTH, B_WIDTH) + 1 for $add, A_WIDTH + B_WIDTH for $mul
		int width;
		if (cell->type == ID($add))
			width = std::max(cell->getParam(ID::A_WIDTH).as_int(), cell->getParam(ID::B_WIDTH).as_int()) + 1;
		else if (cell->type == ID($mul))
			width = cell->getParam(ID::A_WIDTH).as_int() + cell->getParam(ID::B_WIDTH).as_int();
		else log_abort();
		// NOTE(review): the dropped Y bits are tied to constant zero regardless of
		// A_SIGNED/B_SIGNED; for a signed operation with a negative result these
		// bits would presumably need sign extension instead - confirm.
		for (int i = GetSize(y_sig) - 1; i >= width; i--) {
			module->connect(y_sig[i], State::S0);
			y_sig.remove(i, 1);
			bits_removed++;
		}
		cell->setPort(ID::Y, y_sig);
		cell->setParam(ID::Y_WIDTH, GetSize(y_sig));
		log_debug("Width reduced %s/Y by %d bits\n", log_id(cell), bits_removed);
	}
	cell->fixup_parameters();
}
bool process_chain(vector<Cell*> &chain) {
// If chain size is less than 3, no balancing needed
if (GetSize(chain) < 3)
return false;
// Get mid, midnext (at index mid+1) and end of chain
Cell *mid_cell = chain[GetSize(chain) / 2];
Cell *cell = mid_cell; // SILIMATE: Set cell to mid_cell for better naming
Cell *midnext_cell = chain[GetSize(chain) / 2 + 1];
Cell *end_cell = chain.back();
log_debug("Balancing chain of %d cells: mid=%s, midnext=%s, endcell=%s\n",
GetSize(chain), log_id(mid_cell), log_id(midnext_cell), log_id(end_cell));
// Get mid signals
SigSpec mid_a_sig = sigmap(mid_cell->getPort(ID::A));
SigSpec mid_b_sig = sigmap(mid_cell->getPort(ID::B));
SigSpec mid_non_chain_sig = sig_chain_prev.count(mid_a_sig) ? mid_b_sig : mid_a_sig;
IdString mid_non_chain_port = sig_chain_prev.count(mid_a_sig) ? ID::B : ID::A;
// Get midnext signals
SigSpec midnext_a_sig = sigmap(midnext_cell->getPort(ID::A));
SigSpec midnext_b_sig = sigmap(midnext_cell->getPort(ID::B));
IdString midnext_chain_port = sig_chain_prev.count(midnext_a_sig) ? ID::A : ID::B;
// Get output signal
SigSpec end_y_sig = sigmap(end_cell->getPort(ID::Y));
// Create new mid wire
Wire *mid_wire = module->addWire(NEW_ID2_SUFFIX("mid"), GetSize(end_y_sig)); // SILIMATE: Improve the naming
// Perform rotation
mid_cell->setPort(mid_non_chain_port, mid_wire);
mid_cell->setPort(ID::Y, end_y_sig);
midnext_cell->setPort(midnext_chain_port, mid_non_chain_sig);
end_cell->setPort(ID::Y, mid_wire);
// Recreate sigmap
sigmap.set(module);
// Get subtrees
vector<Cell*> left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2);
vector<Cell*> right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end());
// Recurse on subtrees
process_chain(left_chain);
process_chain(right_chain);
// Width reduce left subtree
for (auto c : left_chain)
wreduce(c);
// Recursive case: split sources into two groups and create subtrees
int mid = (sources.size() + 1) / 2;
vector<SigSpec> left_sources(sources.begin(), sources.begin() + mid);
vector<SigSpec> right_sources(sources.begin() + mid, sources.end());
// Width reduce right subtree
for (auto c : right_chain)
wreduce(c);
// Recreate sigmap
sigmap.set(module);
// Width reduce mid cell
wreduce(mid_cell);
return true;
}
void cleanup() {
// Remove cells
for (auto cell : remove_cells)
module->remove(cell);
// Fix ports
module->fixup_ports();
// Clear data structures
remove_cells.clear();
sig_chain_next.clear();
sig_chain_prev.clear();
sigbit_with_non_chain_users.clear();
chain_start_cells.clear();
candidate_cells.clear();
SigSpec left_tree = create_balanced_tree(left_sources, cell_type, cell);
SigSpec right_tree = create_balanced_tree(right_sources, cell_type, cell);
// Create a cell to combine the two subtrees
Cell* new_cell = module->addCell(NEW_ID2_SUFFIX("tree"), cell_type);
// Copy attributes from reference cell
new_cell->attributes = cell->attributes;
// Create output wire
int out_width = cell->getParam(ID::Y_WIDTH).as_int();
if (cell_type == ID($add))
out_width = max(left_tree.size(), right_tree.size()) + 1;
else if (cell_type == ID($mul))
out_width = left_tree.size() + right_tree.size();
Wire* out_wire = module->addWire(NEW_ID2_SUFFIX("tree_out"), out_width);
// Connect ports and fix up parameters
new_cell->setPort(ID::A, left_tree);
new_cell->setPort(ID::B, right_tree);
new_cell->setPort(ID::Y, out_wire);
new_cell->fixup_parameters();
new_cell->setParam(ID::A_SIGNED, cell->getParam(ID::A_SIGNED));
new_cell->setParam(ID::B_SIGNED, cell->getParam(ID::B_SIGNED));
// Update count and return output wire
cell_count[cell_type]++;
return out_wire;
}
OptBalanceTreeWorker(Module *module, const vector<IdString> cell_types) : module(module), sigmap(module) {
// Do for each cell type
for (auto cell_type : cell_types) {
// Find chains of ops
make_sig_chain_next_prev(cell_type);
find_chain_start_cells();
// Index all of the nets in the module
dict<SigSpec, Cell*> sig_to_driver;
dict<SigSpec, pool<Cell*>> sig_to_sink;
for (auto cell : module->selected_cells())
{
for (auto &conn : cell->connections())
{
if (cell->output(conn.first))
sig_to_driver[sigmap(conn.second)] = cell;
// For each chain, if len >= 3, convert to tree via "rotation" and recurse on subtrees
for (auto c : chain_start_cells) {
vector<Cell*> chain = create_chain(c);
bool processed = process_chain(chain);
if (processed) {
// Rename cells and wires for formal check to pass as cells signals have changed functionalities post rotation
for (Cell *cell : chain) {
module->rename(cell, NEW_ID2_SUFFIX("rot_cell"));
if (cell->input(conn.first))
{
SigSpec sig = sigmap(conn.second);
if (sig_to_sink.count(sig) == 0)
sig_to_sink[sig] = pool<Cell*>();
sig_to_sink[sig].insert(cell);
}
for (Cell *cell : chain) {
SigSpec y_sig = sigmap(cell->getPort(ID::Y));
if (y_sig.is_wire()) {
Wire *wire = y_sig.as_wire();
if (wire && !wire->port_input && !wire->port_output) {
module->rename(y_sig.as_wire(), NEW_ID2_SUFFIX("rot_wire"));
}
}
}
}
// Need to check if any wires connect to module ports
pool<SigSpec> input_port_sigs;
pool<SigSpec> output_port_sigs;
for (auto wire : module->selected_wires())
if (wire->port_input || wire->port_output) {
SigSpec sig = sigmap(wire);
for (auto bit : sig) {
if (wire->port_input)
input_port_sigs.insert(bit);
if (wire->port_output)
output_port_sigs.insert(bit);
}
}
cell_count[cell_type] += GetSize(chain);
}
// Actual logic starts here
pool<Cell*> consumed_cells;
for (auto cell : module->selected_cells())
{
// If consumed or not the correct type, skip
if (consumed_cells.count(cell) || !is_right_type(cell, cell_type))
continue;
// Clean up
cleanup();
// BFS, following all chains until they hit a cell of a different type
// Pick the longest one
auto y = sigmap(cell->getPort(ID::Y));
pool<Cell*> sinks;
pool<Cell*> current_loads = sig_to_sink[y];
pool<Cell*> next_loads;
while (!current_loads.empty())
{
// Find each sink and see what they are
for (auto x : current_loads)
{
// If not the correct type, don't follow any further
// (but add the originating cell to the list of sinks)
if (!is_right_type(x, cell_type))
{
sinks.insert(cell);
continue;
}
auto xy = sigmap(x->getPort(ID::Y));
// If this signal drives a port, add it to the sinks
// (even though it may not be the end of a chain)
for (auto bit : xy) {
if (output_port_sigs.count(bit) && !consumed_cells.count(x)) {
sinks.insert(x);
break;
}
}
// Search signal's fanout
auto& next = sig_to_sink[xy];
for (auto z : next)
next_loads.insert(z);
}
// If we couldn't find any downstream loads, stop.
// Create a reduction for each of the max-length chains we found
if (next_loads.empty())
{
for (auto s : current_loads)
{
// Not one of our gates? Don't follow any further
if (!is_right_type(s, cell_type))
continue;
sinks.insert(s);
}
break;
}
// Otherwise, continue down the chain
current_loads = next_loads;
next_loads.clear();
}
// We have our list of sinks, now go tree balance the chains
for (auto head_cell : sinks)
{
// Avoid duplication if we already were covered
if (consumed_cells.count(head_cell))
continue;
// Get sources of the chain
dict<SigSpec, int> sources;
dict<SigSpec, bool> signeds;
int inner_cells = 0;
std::deque<Cell*> bfs_queue = {head_cell};
while (bfs_queue.size())
{
Cell* x = bfs_queue.front();
bfs_queue.pop_front();
for (auto port: {ID::A, ID::B}) {
auto sig = sigmap(x->getPort(port));
Cell* drv = sig_to_driver[sig];
bool drv_ok = drv && is_right_type(drv, cell_type);
for (auto bit : sig) {
if (input_port_sigs.count(bit) && !consumed_cells.count(drv)) {
drv_ok = false;
break;
}
}
if (drv_ok) {
inner_cells++;
bfs_queue.push_back(drv);
} else {
sources[sig]++;
signeds[sig] = x->getParam(port == ID::A ? ID::A_SIGNED : ID::B_SIGNED).as_bool();
}
}
}
if (inner_cells)
{
// Create a tree
log(" Creating tree for %d sources and %d inner cells...\n", GetSize(sources), inner_cells);
// Build a vector of all source signals
vector<SigSpec> source_signals;
vector<bool> signed_flags;
for (auto &source : sources) {
for (int i = 0; i < source.second; i++) {
source_signals.push_back(source.first);
signed_flags.push_back(signeds[source.first]);
}
}
// If not all signed flags are the same, do not balance
if (!std::all_of(signed_flags.begin(), signed_flags.end(), [&](bool flag) { return flag == signed_flags[0]; })) {
continue;
}
// Create the balanced tree
SigSpec tree_output = create_balanced_tree(source_signals, cell_type, head_cell);
// Connect the tree output to the head cell's output
SigSpec head_output = sigmap(head_cell->getPort(ID::Y));
int connect_width = std::min(head_output.size(), tree_output.size());
log(" Connecting %s to %s\n", log_signal(head_output), log_signal(tree_output));
log_flush();
module->connect(head_output.extract(0, connect_width), tree_output.extract(0, connect_width));
if (head_output.size() > tree_output.size()) {
SigBit sext_bit = head_cell->getParam(ID::A_SIGNED).as_bool() ? head_output[connect_width - 1] : State::S0;
module->connect(head_output.extract(connect_width, head_output.size() - connect_width), SigSpec(sext_bit, head_output.size() - connect_width));
}
// Mark consumed cell for removal
consumed_cells.insert(head_cell);
}
}
}
// Remove all consumed cells, which now have been replaced by trees
for (auto cell : consumed_cells)
module->remove(cell);
}
}
};
struct OptBalanceTreePass : public Pass {
OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$xnor/$add/$mul cascades to trees") { }
OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$add/$mul cascades to trees") { }
void help() override {
// |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
log("\n");
log(" opt_balance_tree [options] [selection]\n");
log("\n");
log("This pass converts cascaded chains of $and/$or/$xor/$xnor/$add/$mul cells into\n");
log("This pass converts cascaded chains of $and/$or/$xor/$add/$mul cells into\n");
log("trees of cells to improve timing.\n");
log("\n");
log(" -arith\n");
log(" only convert arithmetic cells.\n");
log("\n");
}
void execute(std::vector<std::string> args, RTLIL::Design *design) override {
log_header(design, "Executing OPT_BALANCE_TREE pass (cell cascades to trees).\n");
// Handle arguments
size_t argidx;
vector<IdString> cell_types = {ID($and), ID($or), ID($xor), ID($add), ID($mul)};
for (argidx = 1; argidx < args.size(); argidx++) {
// No arguments yet
if (args[argidx] == "-arith") {
cell_types = {ID($add), ID($mul)};
continue;
}
break;
}
extra_args(args, argidx, design);
// Count of all cells that were packed
dict<IdString, int> cell_count;
const vector<IdString> cell_types = {ID($and), ID($or), ID($xor), ID($xnor), ID($add), ID($mul)};
for (auto module : design->selected_modules()) {
OptBalanceTreeWorker worker(module, cell_types);
for (auto cell : worker.cell_count) {
@ -355,6 +341,9 @@ struct OptBalanceTreePass : public Pass {
// Log stats
for (auto cell_type : cell_types)
log("Converted %d %s cells into trees.\n", cell_count[cell_type], log_id(cell_type));
// Clean up
Yosys::run_pass("clean -purge");
}
} OptBalanceTreePass;

View File

@ -11,7 +11,7 @@ prep
design -save gold
alumacc
opt_clean
select -assert-count 1 t:$macc_v2
select -assert-count 2 t:$macc_v2
maccmap -unmap
design -copy-from gold -as gold gate
equiv_make gold gate equiv

File diff suppressed because it is too large Load Diff