diff --git a/passes/silimate/opt_balance_tree.cc b/passes/silimate/opt_balance_tree.cc index 5d893658a..da308e582 100644 --- a/passes/silimate/opt_balance_tree.cc +++ b/passes/silimate/opt_balance_tree.cc @@ -19,40 +19,35 @@ * */ -#include "kernel/sigtools.h" #include "kernel/yosys.h" +#include "kernel/sigtools.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN + struct OptBalanceTreeWorker { // Module and signal map - Design *design; Module *module; SigMap sigmap; - bool allow_off_chain; - int limit = -1; + // Counts of each cell type that are getting balanced dict cell_count; - // Driver data - dict> bit_drivers_db; - // Load data - dict>> bit_users_db; + + // Cells to remove + pool remove_cells; // Signal chain data structures - dict sig_chain_next; - dict sig_chain_prev; + dict sig_chain_next; + dict sig_chain_prev; pool sigbit_with_non_chain_users; - pool chain_start_cells; - pool candidate_cells; + pool chain_start_cells; + pool candidate_cells; - // Ignore signals fanout while looking ahead which chains to split. - // Post splitfanout, take that into account. - void make_sig_chain_next_prev(IdString cell_type, bool ignore_split) - { - // Mark all wires with keep attribute as having non-chain users + void make_sig_chain_next_prev(IdString cell_type) { + // Mark all wires with keep attribute or output port as having non-chain users for (auto wire : module->wires()) { - if (wire->get_bool_attribute(ID::keep)) { + if (wire->get_bool_attribute(ID::keep) || wire->port_output) { for (auto bit : sigmap(wire)) sigbit_with_non_chain_users.insert(bit); } @@ -66,40 +61,34 @@ struct OptBalanceTreeWorker { SigSpec a_sig = sigmap(cell->getPort(ID::A)); SigSpec b_sig = sigmap(cell->getPort(ID::B)); SigSpec y_sig = sigmap(cell->getPort(ID::Y)); - - // If a_sig already has a chain user, mark its bits as having non-chain users - if (sig_chain_next.count(a_sig)) { - if (!ignore_split) - for (auto a_bit : a_sig.bits()) - sigbit_with_non_chain_users.insert(a_bit); - // Otherwise, mark cell as the next in the chain relative to a_sig - } else { - if (fanout_in_range(y_sig)) { - sig_chain_next[a_sig] = cell; - } + + // If a_sig already has a chain user, mark its bits as having non-chain users + if (sig_chain_next.count(a_sig)) + for (auto a_bit : a_sig.bits()) + sigbit_with_non_chain_users.insert(a_bit); + // Otherwise, mark cell as the next in the chain relative to a_sig + else { + sig_chain_next[a_sig] = cell; } if (!b_sig.empty()) { // If b_sig already has a chain user, mark its bits as having non-chain users - if (sig_chain_next.count(b_sig)) { - if (!ignore_split) - for (auto b_bit : b_sig.bits()) - sigbit_with_non_chain_users.insert(b_bit); - // Otherwise, mark cell as the next in the chain relative to b_sig - } else { - if (fanout_in_range(y_sig)) { - sig_chain_next[b_sig] = cell; - } + if (sig_chain_next.count(b_sig)) + for (auto b_bit : b_sig.bits()) + sigbit_with_non_chain_users.insert(b_bit); + // Otherwise, mark cell as the next in the chain relative to b_sig + else { + sig_chain_next[b_sig] = cell; } } + + // Add cell as candidate + candidate_cells.insert(cell); - if (fanout_in_range(y_sig)) { - // Add cell as candidate - candidate_cells.insert(cell); - - // Mark cell as the previous in the chain relative to y_sig - sig_chain_prev[y_sig] = cell; - } + // Mark cell as the previous in the chain relative to y_sig + sig_chain_prev[y_sig] = cell; + for (auto bit : y_sig.bits()) + sig_chain_prev[bit] = cell; } // If cell is not matching type, mark all cell input signals as being non-chain users else { @@ -111,8 +100,7 @@ struct OptBalanceTreeWorker { } } - void find_chain_start_cells() - { + void find_chain_start_cells() { for (auto cell : candidate_cells) { // Log candidate cell log_debug("Considering %s (%s)\n", log_id(cell), log_id(cell->type)); @@ -121,7 +109,7 @@ struct OptBalanceTreeWorker { SigSpec a_sig = sigmap(cell->getPort(ID::A)); SigSpec b_sig = sigmap(cell->getPort(ID::B)); SigSpec prev_sig = sig_chain_prev.count(a_sig) ? a_sig : b_sig; - + // This is a start cell if there was no previous cell in the chain for a_sig or b_sig if (sig_chain_prev.count(a_sig) + sig_chain_prev.count(b_sig) != 1) { chain_start_cells.insert(cell); @@ -137,10 +125,9 @@ struct OptBalanceTreeWorker { } } - vector create_chain(Cell *start_cell) - { + vector create_chain(Cell *start_cell) { // Chain of cells - vector chain; + vector chain; // Current cell Cell *c = start_cell; @@ -163,8 +150,7 @@ struct OptBalanceTreeWorker { return chain; } - void wreduce(Cell *cell) - { + void wreduce(Cell *cell) { // If cell is arithmetic, remove leading zeros from inputs, then clean up outputs if (cell->type.in(ID($add), ID($mul))) { // Remove leading zeros from inputs @@ -176,14 +162,13 @@ struct OptBalanceTreeWorker { SigSpec inport_sig = sigmap(cell->getPort(inport)); cell->unsetPort(inport); if (cell->getParam((inport == ID::A) ? ID::A_SIGNED : ID::B_SIGNED).as_bool()) { - while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig) - 1] == State::S0 && - inport_sig[GetSize(inport_sig) - 2] == State::S0) { - inport_sig.remove(GetSize(inport_sig) - 1, 1); + while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig)-1] == State::S0 && inport_sig[GetSize(inport_sig)-2] == State::S0) { + inport_sig.remove(GetSize(inport_sig)-1, 1); bits_removed++; } } else { - while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig) - 1] == State::S0) { - inport_sig.remove(GetSize(inport_sig) - 1, 1); + while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig)-1] == State::S0) { + inport_sig.remove(GetSize(inport_sig)-1, 1); bits_removed++; } } @@ -203,8 +188,7 @@ struct OptBalanceTreeWorker { width = std::max(cell->getParam(ID::A_WIDTH).as_int(), cell->getParam(ID::B_WIDTH).as_int()) + 1; else if (cell->type == ID($mul)) width = cell->getParam(ID::A_WIDTH).as_int() + cell->getParam(ID::B_WIDTH).as_int(); - else - log_abort(); + else log_abort(); for (int i = GetSize(y_sig) - 1; i >= width; i--) { module->connect(y_sig[i], State::S0); y_sig.remove(i, 1); @@ -218,19 +202,18 @@ struct OptBalanceTreeWorker { cell->fixup_parameters(); } - bool process_chain(vector &chain) - { + void process_chain(vector &chain) { // If chain size is less than 3, no balancing needed if (GetSize(chain) < 3) - return false; + return; // Get mid, midnext (at index mid+1) and end of chain Cell *mid_cell = chain[GetSize(chain) / 2]; Cell *cell = mid_cell; // SILIMATE: Set cell to mid_cell for better naming Cell *midnext_cell = chain[GetSize(chain) / 2 + 1]; Cell *end_cell = chain.back(); - log_debug("Balancing chain of %d cells: mid=%s, midnext=%s, endcell=%s\n", GetSize(chain), log_id(mid_cell), log_id(midnext_cell), - log_id(end_cell)); + log_debug("Balancing chain of %d cells: mid=%s, midnext=%s, endcell=%s\n", + GetSize(chain), log_id(mid_cell), log_id(midnext_cell), log_id(end_cell)); // Get mid signals SigSpec mid_a_sig = sigmap(mid_cell->getPort(ID::A)); @@ -259,17 +242,17 @@ struct OptBalanceTreeWorker { sigmap.set(module); // Get subtrees - vector left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2); - vector right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end()); + vector left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2); + vector right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end()); // Recurse on subtrees process_chain(left_chain); process_chain(right_chain); - + // Width reduce left subtree for (auto c : left_chain) wreduce(c); - + // Width reduce right subtree for (auto c : right_chain) wreduce(c); @@ -279,15 +262,18 @@ struct OptBalanceTreeWorker { // Width reduce mid cell wreduce(mid_cell); - return true; } - void cleanup() - { + void cleanup() { + // Remove cells + for (auto cell : remove_cells) + module->remove(cell); + // Fix ports module->fixup_ports(); // Clear data structures + remove_cells.clear(); sig_chain_next.clear(); sig_chain_prev.clear(); sigbit_with_non_chain_users.clear(); @@ -295,130 +281,17 @@ struct OptBalanceTreeWorker { candidate_cells.clear(); } - bool fanout_in_range(SigSpec outsig) - { - // Check if output signal is "bit-split", skip if so - // This is a lookahead for the splitfanout pass that has this limitation - auto bit_users = bit_users_db[outsig[0]]; - for (int i = 0; i < GetSize(outsig); i++) { - if (bit_users_db[outsig[i]] != bit_users) { - return false; - } - } - - // Skip if fanout is above limit - if (limit != -1 && GetSize(bit_users) > limit) { - return false; - } - return true; - } - - OptBalanceTreeWorker(Design *design, Module *module, const vector cell_types, bool allow_off_chain, int limit) - : design(design), module(module), sigmap(module), allow_off_chain(allow_off_chain), limit(limit) - { - - if (allow_off_chain) { - - // Build bit_drivers_db - log("Building bit_drivers_db...\n"); - for (auto cell : module->cells()) { - for (auto conn : cell->connections()) { - if (!cell->output(conn.first)) - continue; - for (int i = 0; i < GetSize(conn.second); i++) { - SigBit bit(sigmap(conn.second[i])); - bit_drivers_db[bit] = tuple(cell->name, conn.first, i); - } - } - } - - // Build bit_users_db - log("Building bit_users_db...\n"); - for (auto cell : module->cells()) { - for (auto conn : cell->connections()) { - if (!cell->input(conn.first)) - continue; - for (int i = 0; i < GetSize(conn.second); i++) { - SigBit bit(sigmap(conn.second[i])); - if (!bit_drivers_db.count(bit)) - continue; - bit_users_db[bit].insert( - tuple(cell->name, conn.first, i - std::get<2>(bit_drivers_db[bit]))); - } - } - } - - // Build bit_users_db for output ports - log("Building bit_users_db for output ports...\n"); - for (auto wire : module->wires()) { - if (!wire->port_output) - continue; - SigSpec sig(sigmap(wire)); - for (int i = 0; i < GetSize(sig); i++) { - SigBit bit(sig[i]); - if (!bit_drivers_db.count(bit)) - continue; - bit_users_db[bit].insert( - tuple(wire->name, IdString(), i - std::get<2>(bit_drivers_db[bit]))); - } - } - - // Deselect all cells - Pass::call(design, "select -none"); - // Do for each cell type - bool has_cell_to_split = false; - for (auto cell_type : cell_types) { - // Find chains of ops - make_sig_chain_next_prev(cell_type, true); - find_chain_start_cells(); - - // For each chain, if len >= 3, select all the elements - for (auto c : chain_start_cells) { - vector chain = create_chain(c); - if (GetSize(chain) < 3) - continue; - for (auto cell : chain) { - has_cell_to_split = true; - design->select(module, cell); - } - } - // Clean up - cleanup(); - } - - // Splitfanout of selected cells - if (has_cell_to_split) - Pass::call(design, "splitfanout"); - // Reset selection for other passes - Pass::call(design, "select -clear"); - // Recreate sigmap - sigmap.set(module); - } - + OptBalanceTreeWorker(Module *module, const vector cell_types) : module(module), sigmap(module) { // Do for each cell type for (auto cell_type : cell_types) { // Find chains of ops - make_sig_chain_next_prev(cell_type, false); + make_sig_chain_next_prev(cell_type); find_chain_start_cells(); // For each chain, if len >= 3, convert to tree via "rotation" and recurse on subtrees for (auto c : chain_start_cells) { - vector chain = create_chain(c); - if (process_chain(chain)) { - // Rename cells and wires for formal check to pass as cells signals have changed functionalities post rotation - for (Cell *cell : chain) { - module->rename(cell, NEW_ID2_SUFFIX("rot_cell")); - } - for (Cell *cell : chain) { - SigSpec y_sig = sigmap(cell->getPort(ID::Y)); - if (y_sig.is_wire()) { - Wire *wire = y_sig.as_wire(); - if (wire && !wire->port_input && !wire->port_output) { - module->rename(y_sig.as_wire(), NEW_ID2_SUFFIX("rot_wire")); - } - } - } - } + vector chain = create_chain(c); + process_chain(chain); cell_count[cell_type] += GetSize(chain); } @@ -429,9 +302,8 @@ struct OptBalanceTreeWorker { }; struct OptBalanceTreePass : public Pass { - OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$xnor/$add/$mul cascades to trees") {} - void help() override - { + OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$xnor/$add/$mul cascades to trees") { } + void help() override { // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| log("\n"); log(" opt_balance_tree [options] [selection]\n"); @@ -439,47 +311,23 @@ struct OptBalanceTreePass : public Pass { log("This pass converts cascaded chains of $and/$or/$xor/$xnor/$add/$mul cells into\n"); log("trees of cells to improve timing.\n"); log("\n"); - log(" -allow-off-chain\n"); - log(" Allows matching of cells that have loads outside the chain. These cells\n"); - log(" will be replicated and balanced into a tree, but the original\n"); - log(" cell will remain, driving its original loads.\n"); - log(" -fanout_limit n\n"); - log(" Max fanout to split.\n"); - log(" -arith_only\n"); - log(" Only balance arithmetic cells.\n"); - log("\n"); } - void execute(std::vector args, RTLIL::Design *design) override - { + void execute(std::vector args, RTLIL::Design *design) override { log_header(design, "Executing OPT_BALANCE_TREE pass (cell cascades to trees).\n"); - bool allow_off_chain = false; - bool arith_only = false; + // Handle arguments size_t argidx; - int limit = -1; for (argidx = 1; argidx < args.size(); argidx++) { - if (args[argidx] == "-allow-off-chain") { - allow_off_chain = true; - continue; - } - if (args[argidx] == "-fanout_limit" && argidx + 1 < args.size()) { - limit = std::stoi(args[++argidx]); - continue; - } - if (args[argidx] == "-arith_only") { - arith_only = true; - continue; - } + // No arguments yet break; } extra_args(args, argidx, design); // Count of all cells that were packed dict cell_count; - vector cell_types = {ID($and), ID($or), ID($xor), ID($xnor), ID($add), ID($mul)}; - if (arith_only) cell_types = {ID($add), ID($mul)}; + const vector cell_types = {ID($and), ID($or), ID($xor), ID($xnor), ID($add), ID($mul)}; for (auto module : design->selected_modules()) { - OptBalanceTreeWorker worker(design, module, cell_types, allow_off_chain, limit); + OptBalanceTreeWorker worker(module, cell_types); for (auto cell : worker.cell_count) { cell_count[cell.first] += cell.second; }