// Replaces chains of $add/$sub and $macc cells with carry-save adder trees, reducing multi-operand // addition to logarithmic depth. ref. paper: Zimmermann, "Architectures for Adders" #include "kernel/yosys.h" #include "kernel/sigtools.h" #include "kernel/macc.h" #include USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN struct Operand { SigSpec sig; bool is_signed; bool negate; }; struct Traversal { SigMap sigmap; dict> bit_consumers; dict fanout; Traversal(Module* module) : sigmap(module) { for (auto cell : module->cells()) for (auto& conn : cell->connections()) if (cell->input(conn.first)) for (auto bit : sigmap(conn.second)) bit_consumers[bit].insert(cell); for (auto& pair : bit_consumers) fanout[pair.first] = pair.second.size(); for (auto wire : module->wires()) if (wire->port_output) for (auto bit : sigmap(SigSpec(wire))) fanout[bit]++; } }; struct Cells { pool addsub; pool alu; pool macc; static bool is_addsub(Cell* cell) { return cell->type == ID($add) || cell->type == ID($sub); } static bool is_alu(Cell* cell) { return cell->type == ID($alu); } static bool is_macc(Cell* cell) { return cell->type == ID($macc) || cell->type == ID($macc_v2); } bool empty() { return addsub.empty() && alu.empty() && macc.empty(); } Cells(Module* module) { for (auto cell : module->cells()) { if (is_addsub(cell)) addsub.insert(cell); else if (is_alu(cell)) alu.insert(cell); else if (is_macc(cell)) macc.insert(cell); } } }; struct AluInfo { Cells& cells; Traversal& traversal; bool is_subtract(Cell* cell) { SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1; } bool is_add(Cell* cell) { SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; } // Chainable cells are adds/subs with no carry usage, connected chainable // cells form chains that can be replaced with CSA trees. bool is_chainable(Cell* cell) { if (!(is_add(cell) || is_subtract(cell))) return false; for (auto bit : traversal.sigmap(cell->getPort(ID::X))) if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) return false; for (auto bit : traversal.sigmap(cell->getPort(ID::CO))) if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) return false; return true; } }; struct Rewriter { Module* module; Cells& cells; Traversal traversal; AluInfo alu_info; Rewriter(Module* module, Cells& cells) : module(module), cells(cells), traversal(module), alu_info{cells, traversal} {} Cell* sole_chainable_consumer(SigSpec sig, const pool& candidates) { Cell* consumer = nullptr; for (auto bit : sig) { if (!traversal.fanout.count(bit) || traversal.fanout[bit] != 1) return nullptr; if (!traversal.bit_consumers.count(bit) || traversal.bit_consumers[bit].size() != 1) return nullptr; Cell* c = *traversal.bit_consumers[bit].begin(); if (!candidates.count(c)) return nullptr; if (consumer == nullptr) consumer = c; else if (consumer != c) return nullptr; } return consumer; } // Find cells that consume another cell's output. dict find_parents(const pool& candidates) { dict parent_of; for (auto cell : candidates) { Cell* consumer = sole_chainable_consumer( traversal.sigmap(cell->getPort(ID::Y)), candidates); if (consumer && consumer != cell) parent_of[cell] = consumer; } return parent_of; } std::pair>, pool> invert_parent_map(const dict& parent_of) { dict> children_of; pool has_parent; for (auto& [child, parent] : parent_of) { children_of[parent].insert(child); has_parent.insert(child); } return {children_of, has_parent}; } pool collect_chain(Cell* root, const dict>& children_of) { pool chain; std::queue q; q.push(root); while (!q.empty()) { Cell* cur = q.front(); q.pop(); if (!chain.insert(cur).second) continue; auto it = children_of.find(cur); if (it != children_of.end()) for (auto child : it->second) q.push(child); } return chain; } pool internal_bits(const pool& chain) { pool bits; for (auto cell : chain) for (auto bit : traversal.sigmap(cell->getPort(ID::Y))) bits.insert(bit); return bits; } static bool overlaps(SigSpec sig, const pool& bits) { for (auto bit : sig) if (bits.count(bit)) return true; return false; } bool feeds_subtracted_port(Cell* child, Cell* parent) { bool parent_subtracts; if (parent->type == ID($sub)) parent_subtracts = true; else if (cells.is_alu(parent)) parent_subtracts = alu_info.is_subtract(parent); else return false; if (!parent_subtracts) return false; SigSpec child_y = traversal.sigmap(child->getPort(ID::Y)); SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B)); for (auto bit : child_y) for (auto pbit : parent_b) if (bit == pbit) return true; return false; } std::vector extract_chain_operands( const pool& chain, Cell* root, const dict& parent_of, int& neg_compensation ) { pool chain_bits = internal_bits(chain); dict negated; negated[root] = false; { std::queue q; q.push(root); while (!q.empty()) { Cell* cur = q.front(); q.pop(); for (auto cell : chain) { if (!parent_of.count(cell) || parent_of.at(cell) != cur) continue; if (negated.count(cell)) continue; negated[cell] = negated[cur] ^ feeds_subtracted_port(cell, cur); q.push(cell); } } } std::vector operands; neg_compensation = 0; for (auto cell : chain) { bool cell_neg; if (negated.count(cell)) cell_neg = negated[cell]; else cell_neg = false; SigSpec a = traversal.sigmap(cell->getPort(ID::A)); SigSpec b = traversal.sigmap(cell->getPort(ID::B)); bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell)); if (!overlaps(a, chain_bits)) { bool neg = cell_neg; operands.push_back({a, a_signed, neg}); if (neg) neg_compensation++; } if (!overlaps(b, chain_bits)) { bool neg = cell_neg ^ b_sub; operands.push_back({b, b_signed, neg}); if (neg) neg_compensation++; } } return operands; } bool extract_macc_operands(Cell* cell, std::vector& operands, int& neg_compensation) { Macc macc(cell); neg_compensation = 0; for (auto& term : macc.terms) { // Bail on multiplication if (GetSize(term.in_b) != 0) return false; operands.push_back({term.in_a, term.is_signed, term.do_subtract}); if (term.do_subtract) neg_compensation++; } return true; } SigSpec extend_operand(SigSpec sig, bool is_signed, int width) { if (GetSize(sig) < width) { SigBit pad; if (is_signed && GetSize(sig) > 0) pad = sig[GetSize(sig) - 1]; else pad = State::S0; sig.append(SigSpec(pad, width - GetSize(sig))); } if (GetSize(sig) > width) sig = sig.extract(0, width); return sig; } std::pair emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) { SigSpec sum = module->addWire(NEW_ID, width); SigSpec cout = module->addWire(NEW_ID, width); module->addFa(NEW_ID, a, b, c, cout, sum); SigSpec carry; carry.append(State::S0); carry.append(cout.extract(0, width - 1)); return {sum, carry}; } struct DepthSig { SigSpec sig; int depth; }; // Group ready operands into triplets and compress via full adders until two operands remain. std::pair reduce_wallace(std::vector& sigs, int width, int& fa_count) { std::vector ops; ops.reserve(sigs.size()); for (auto& s : sigs) ops.push_back({s, 0}); fa_count = 0; for (int level = 0; ops.size() > 2; level++) { log_assert(level <= 100); std::vector ready, waiting; for (auto& op : ops) { if (op.depth <= level) ready.push_back(op); else waiting.push_back(op); } if (ready.size() < 3) continue; std::vector next; size_t i = 0; while (i + 2 < ready.size()) { auto [sum, carry] = emit_fa(ready[i].sig, ready[i + 1].sig, ready[i + 2].sig, width); int d = std::max({ready[i].depth, ready[i + 1].depth,ready[i + 2].depth}) + 1; next.push_back({sum, d}); next.push_back({carry, d}); fa_count++; i += 3; } for (; i < ready.size(); i++) next.push_back(ready[i]); for (auto& op : waiting) next.push_back(op); ops = std::move(next); } log_assert(ops.size() == 2); log(" Tree depth: %d FA levels + 1 final add\n", std::max(ops[0].depth, ops[1].depth)); return {ops[0].sig, ops[1].sig}; } void replace_with_csa_tree( std::vector& operands, SigSpec result_y, int neg_compensation, const char* desc ) { int width = GetSize(result_y); std::vector extended; extended.reserve(operands.size() + 1); for (auto& op : operands) { SigSpec s = extend_operand(op.sig, op.is_signed, width); if (op.negate) s = module->Not(NEW_ID, s); extended.push_back(s); } // Add correction for negated operands (-x = ~x + 1 so 1 per negation) if (neg_compensation > 0) extended.push_back(SigSpec(neg_compensation, width)); int fa_count; auto [a, b] = reduce_wallace(extended, width, fa_count); log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n", desc, fa_count, (int)operands.size(), log_id(module)); // Emit final add module->addAdd(NEW_ID, a, b, result_y, false); } void process_chains() { pool candidates; for (auto cell : cells.addsub) candidates.insert(cell); for (auto cell : cells.alu) if (alu_info.is_chainable(cell)) candidates.insert(cell); if (candidates.empty()) return; auto parent_of = find_parents(candidates); auto [children_of, has_parent] = invert_parent_map(parent_of); pool processed; for (auto root : candidates) { if (has_parent.count(root) || processed.count(root)) continue; // Not a tree root pool chain = collect_chain(root, children_of); if (chain.size() < 2) continue; for (auto c : chain) processed.insert(c); int neg_compensation; auto operands = extract_chain_operands(chain, root, parent_of, neg_compensation); if (operands.size() < 3) continue; replace_with_csa_tree(operands, root->getPort(ID::Y), neg_compensation, "Replaced add/sub chain"); for (auto cell : chain) module->remove(cell); } } void process_maccs() { for (auto cell : cells.macc) { std::vector operands; int neg_compensation; if (!extract_macc_operands(cell, operands, neg_compensation)) continue; if (operands.size() < 3) continue; replace_with_csa_tree(operands, cell->getPort(ID::Y), neg_compensation, "Replaced $macc"); module->remove(cell); } } }; void run(Module* module) { Cells cells(module); if (cells.empty()) return; Rewriter rewriter {module, cells}; rewriter.process_chains(); rewriter.process_maccs(); } struct CsaTreePass : public Pass { CsaTreePass() : Pass("csa_tree", "convert add/sub/macc chains to carry-save adder trees") {} void help() override { // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| log("\n"); log(" csa_tree [selection]\n"); log("\n"); log("This pass replaces chains of $add/$sub cells, $alu cells (with constant\n"); log("BI/CI), and $macc/$macc_v2 cells (without multiplications) with carry-save\n"); log("adder trees using $fa cells and a single final $add.\n"); log("\n"); log("The tree uses Wallace-tree scheduling: at each level, ready operands are\n"); log("grouped into triplets and compressed via full adders, giving\n"); log("O(log_{1.5} N) depth for N input operands.\n"); log("\n"); } void execute(std::vector args, RTLIL::Design* design) override { log_header(design, "Executing CSA_TREE pass.\n"); size_t argidx; for (argidx = 1; argidx < args.size(); argidx++) break; extra_args(args, argidx, design); for (auto module : design->selected_modules()) { run(module); } } } CsaTreePass; PRIVATE_NAMESPACE_END