diff --git a/Makefile b/Makefile index fc5172cc0..8c5524d42 100644 --- a/Makefile +++ b/Makefile @@ -953,6 +953,7 @@ MK_TEST_DIRS += tests/verilog # Tests that don't generate .mk SH_TEST_DIRS = +SH_TEST_DIRS += tests/arith_tree SH_TEST_DIRS += tests/simple SH_TEST_DIRS += tests/simple_abc9 SH_TEST_DIRS += tests/hana diff --git a/kernel/wallace_tree.h b/kernel/wallace_tree.h new file mode 100644 index 000000000..eb3513803 --- /dev/null +++ b/kernel/wallace_tree.h @@ -0,0 +1,112 @@ +/** + * Wallace tree utilities for multi-operand addition using carry-save adders + * + * Terminology: + * - compressor: $fa viewed as reducing 3 inputs to 2 outputs (sum + shifted carry) (3:2 compressor) + * - level: A stage of parallel compression operations + * - depth: Maximum number of 3:2 compressor levels from any input to a signal + * + * References: + * - "Binary Adder Architectures for Cell-Based VLSI and their Synthesis" (https://iis-people.ee.ethz.ch/~zimmi/publications/adder_arch.pdf) + * - "A Suggestion for a Fast Multiplier" (https://www.ece.ucdavis.edu/~vojin/CLASSES/EEC280/Web-page/papers/Arithmetic/Wallace_mult.pdf) + */ + +#ifndef WALLACE_TREE_H +#define WALLACE_TREE_H + +#include "kernel/sigtools.h" +#include "kernel/yosys.h" + +YOSYS_NAMESPACE_BEGIN + +inline std::pair emit_fa(Module *module, SigSpec a, SigSpec b, SigSpec c, int width) +{ + SigSpec sum = module->addWire(NEW_ID, width); + SigSpec cout = module->addWire(NEW_ID, width); + + module->addFa(NEW_ID, a, b, c, cout, sum); + + SigSpec carry; + carry.append(State::S0); + carry.append(cout.extract(0, width - 1)); + return {sum, carry}; +} + +/** + * wallace_reduce_scheduled() - Reduce multiple operands to two using a Wallace tree + * @module: The Yosys module to which the compressors will be added + * @sigs: Vector of input signals (operands) to be reduced + * @width: Target bit-width to which all operands will be zero-extended + * @compressor_count: Optional pointer to return the number of $fa cells emitted + * + * Return: The final two reduced operands, that are to be fed into an adder + */ +inline std::pair wallace_reduce_scheduled(Module *module, std::vector &sigs, int width, int *compressor_count = nullptr) +{ + struct DepthSig { + SigSpec sig; + int depth; + }; + + for (auto &s : sigs) + s.extend_u0(width); + + std::vector operands; + operands.reserve(sigs.size()); + for (auto &s : sigs) + operands.push_back({s, 0}); + + // Number of $fa's emitted + if (compressor_count) + *compressor_count = 0; + + // Only compress operands ready at current level + for (int level = 0; operands.size() > 2; level++) { + // Partition operands into ready and waiting + std::vector ready, waiting; + for (auto &op : operands) { + if (op.depth <= level) + ready.push_back(op); + else + waiting.push_back(op); + } + + if (ready.size() < 3) + continue; + + // Apply compressors to ready operands + std::vector compressed; + size_t i = 0; + while (i + 2 < ready.size()) { + auto [sum, carry] = emit_fa(module, ready[i].sig, ready[i + 1].sig, ready[i + 2].sig, width); + int new_depth = std::max({ready[i].depth, ready[i + 1].depth, ready[i + 2].depth}) + 1; + compressed.push_back({sum, new_depth}); + compressed.push_back({carry, new_depth}); + if (compressor_count) + (*compressor_count)++; + i += 3; + } + // Uncompressed operands pass through to next level + for (; i < ready.size(); i++) + compressed.push_back(ready[i]); + // Merge compressed with waiting operands + for (auto &op : waiting) + compressed.push_back(op); + + operands = std::move(compressed); + } + + if (operands.size() == 0) + return {SigSpec(State::S0, width), SigSpec(State::S0, width)}; + else if (operands.size() == 1) + return {operands[0].sig, SigSpec(State::S0, width)}; + else { + log_assert(operands.size() == 2); + log(" Wallace tree depth: %d levels of $fa + 1 final $add\n", std::max(operands[0].depth, operands[1].depth)); + return {operands[0].sig, operands[1].sig}; + } +} + +YOSYS_NAMESPACE_END + +#endif diff --git a/passes/techmap/Makefile.inc b/passes/techmap/Makefile.inc index 083778d3c..eccad8998 100644 --- a/passes/techmap/Makefile.inc +++ b/passes/techmap/Makefile.inc @@ -55,6 +55,7 @@ OBJS += passes/techmap/extractinv.o OBJS += passes/techmap/cellmatch.o OBJS += passes/techmap/clockgate.o OBJS += passes/techmap/constmap.o +OBJS += passes/techmap/arith_tree.o endif ifeq ($(DISABLE_SPAWN),0) diff --git a/passes/techmap/arith_tree.cc b/passes/techmap/arith_tree.cc new file mode 100644 index 000000000..9494fa958 --- /dev/null +++ b/passes/techmap/arith_tree.cc @@ -0,0 +1,426 @@ +/** + * Replaces chains of $add/$sub and $macc cells with carry-save adder trees + * + * Terminology: + * - parent: Cells that consume another cell's output + * - chainable: Adds/subs with no carry-out usage + * - chain: Connected path of chainable cells + */ + +#include "kernel/macc.h" +#include "kernel/sigtools.h" +#include "kernel/wallace_tree.h" +#include "kernel/yosys.h" + +#include + +USING_YOSYS_NAMESPACE +PRIVATE_NAMESPACE_BEGIN + +struct Operand { + SigSpec sig; + bool is_signed; + bool negate; +}; + +struct Traversal { + SigMap sigmap; + dict> bit_consumers; + dict fanout; + Traversal(Module *module) : sigmap(module) + { + for (auto cell : module->cells()) + for (auto &conn : cell->connections()) + if (cell->input(conn.first)) + for (auto bit : sigmap(conn.second)) + bit_consumers[bit].insert(cell); + + for (auto &pair : bit_consumers) + fanout[pair.first] = pair.second.size(); + + for (auto wire : module->wires()) + if (wire->port_output) + for (auto bit : sigmap(SigSpec(wire))) + fanout[bit]++; + } +}; + +struct Cells { + pool addsub; + pool alu; + pool macc; + + static bool is_addsub(Cell *cell) { return cell->type == ID($add) || cell->type == ID($sub); } + + static bool is_alu(Cell *cell) { return cell->type == ID($alu); } + + static bool is_macc(Cell *cell) { return cell->type == ID($macc) || cell->type == ID($macc_v2); } + + bool empty() { return addsub.empty() && alu.empty() && macc.empty(); } + + Cells(Module *module) + { + for (auto cell : module->cells()) { + if (is_addsub(cell)) + addsub.insert(cell); + else if (is_alu(cell)) + alu.insert(cell); + else if (is_macc(cell)) + macc.insert(cell); + } + } +}; + +struct AluInfo { + Cells &cells; + Traversal &traversal; + bool is_subtract(Cell *cell) + { + SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); + SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); + return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1; + } + + bool is_add(Cell *cell) + { + SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); + SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); + return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; + } + + bool is_chainable(Cell *cell) + { + if (!(is_add(cell) || is_subtract(cell))) + return false; + + for (auto bit : traversal.sigmap(cell->getPort(ID::X))) + if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) + return false; + for (auto bit : traversal.sigmap(cell->getPort(ID::CO))) + if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) + return false; + + return true; + } +}; + +struct Rewriter { + Module *module; + Cells &cells; + Traversal traversal; + AluInfo alu_info; + + Rewriter(Module *module, Cells &cells) : module(module), cells(cells), traversal(module), alu_info{cells, traversal} {} + + Cell *sole_chainable_consumer(SigSpec sig, const pool &candidates) + { + Cell *consumer = nullptr; + for (auto bit : sig) { + if (!traversal.fanout.count(bit) || traversal.fanout[bit] != 1) + return nullptr; + if (!traversal.bit_consumers.count(bit) || traversal.bit_consumers[bit].size() != 1) + return nullptr; + + Cell *c = *traversal.bit_consumers[bit].begin(); + if (!candidates.count(c)) + return nullptr; + + if (consumer == nullptr) + consumer = c; + else if (consumer != c) + return nullptr; + } + return consumer; + } + + dict find_parents(const pool &candidates) + { + dict parent_of; + for (auto cell : candidates) { + Cell *consumer = sole_chainable_consumer(traversal.sigmap(cell->getPort(ID::Y)), candidates); + if (consumer && consumer != cell) + parent_of[cell] = consumer; + } + return parent_of; + } + + std::pair>, pool> invert_parent_map(const dict &parent_of) + { + dict> children_of; + pool has_parent; + for (auto &[child, parent] : parent_of) { + children_of[parent].insert(child); + has_parent.insert(child); + } + return {children_of, has_parent}; + } + + pool collect_chain(Cell *root, const dict> &children_of) + { + pool chain; + std::queue q; + q.push(root); + while (!q.empty()) { + Cell *cur = q.front(); + q.pop(); + if (!chain.insert(cur).second) + continue; + auto it = children_of.find(cur); + if (it != children_of.end()) + for (auto child : it->second) + q.push(child); + } + return chain; + } + + pool internal_bits(const pool &chain) + { + pool bits; + for (auto cell : chain) + for (auto bit : traversal.sigmap(cell->getPort(ID::Y))) + bits.insert(bit); + return bits; + } + + static bool overlaps(SigSpec sig, const pool &bits) + { + for (auto bit : sig) + if (bits.count(bit)) + return true; + return false; + } + + bool feeds_subtracted_port(Cell *child, Cell *parent) + { + bool parent_subtracts; + if (parent->type == ID($sub)) + parent_subtracts = true; + else if (cells.is_alu(parent)) + parent_subtracts = alu_info.is_subtract(parent); + else + return false; + + if (!parent_subtracts) + return false; + + // Check if any bit of child's Y connects to parent's B + SigSpec child_y = traversal.sigmap(child->getPort(ID::Y)); + SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B)); + for (auto bit : child_y) + for (auto pbit : parent_b) + if (bit == pbit) + return true; + return false; + } + + std::vector extract_chain_operands(const pool &chain, Cell *root, const dict &parent_of, int &neg_compensation) + { + pool chain_bits = internal_bits(chain); + + // Propagate negation flags through chain + dict negated; + negated[root] = false; + { + std::queue q; + q.push(root); + while (!q.empty()) { + Cell *cur = q.front(); + q.pop(); + for (auto cell : chain) { + if (!parent_of.count(cell) || parent_of.at(cell) != cur) + continue; + if (negated.count(cell)) + continue; + negated[cell] = negated[cur] ^ feeds_subtracted_port(cell, cur); + q.push(cell); + } + } + } + + // Extract leaf operands + std::vector operands; + neg_compensation = 0; + + for (auto cell : chain) { + bool cell_neg = negated.count(cell) ? negated[cell] : false; + + SigSpec a = traversal.sigmap(cell->getPort(ID::A)); + SigSpec b = traversal.sigmap(cell->getPort(ID::B)); + bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); + bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); + bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell)); + + // Only add operands not produced by other chain cells + if (!overlaps(a, chain_bits)) { + operands.push_back({a, a_signed, cell_neg}); + if (cell_neg) + neg_compensation++; + } + if (!overlaps(b, chain_bits)) { + bool neg = cell_neg ^ b_sub; + operands.push_back({b, b_signed, neg}); + if (neg) + neg_compensation++; + } + } + return operands; + } + + bool extract_macc_operands(Cell *cell, std::vector &operands, int &neg_compensation) + { + Macc macc(cell); + neg_compensation = 0; + + for (auto &term : macc.terms) { + // Bail on multiplication + if (GetSize(term.in_b) != 0) + return false; + operands.push_back({term.in_a, term.is_signed, term.do_subtract}); + if (term.do_subtract) + neg_compensation++; + } + return true; + } + + SigSpec extend_operand(SigSpec sig, bool is_signed, int width) + { + if (GetSize(sig) < width) { + SigBit pad; + if (is_signed && GetSize(sig) > 0) + pad = sig[GetSize(sig) - 1]; + else + pad = State::S0; + sig.append(SigSpec(pad, width - GetSize(sig))); + } + if (GetSize(sig) > width) + sig = sig.extract(0, width); + return sig; + } + + void replace_with_carry_save_tree(std::vector &operands, SigSpec result_y, int neg_compensation, const char *desc) + { + int width = GetSize(result_y); + std::vector extended; + extended.reserve(operands.size() + 1); + + for (auto &op : operands) { + SigSpec s = extend_operand(op.sig, op.is_signed, width); + if (op.negate) + s = module->Not(NEW_ID, s); + extended.push_back(s); + } + + // Add correction for negated operands (-x = ~x + 1 so 1 per negation) + if (neg_compensation > 0) + extended.push_back(SigSpec(neg_compensation, width)); + + int compressor_count; + auto [a, b] = wallace_reduce_scheduled(module, extended, width, &compressor_count); + log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n", desc, compressor_count, (int)operands.size(), log_id(module)); + + // Emit final add + module->addAdd(NEW_ID, a, b, result_y, false); + } + + void process_chains() + { + pool candidates; + for (auto cell : cells.addsub) + candidates.insert(cell); + for (auto cell : cells.alu) + if (alu_info.is_chainable(cell)) + candidates.insert(cell); + + if (candidates.empty()) + return; + + auto parent_of = find_parents(candidates); + auto [children_of, has_parent] = invert_parent_map(parent_of); + + pool to_remove; + for (auto root : candidates) { + if (has_parent.count(root) || to_remove.count(root)) + continue; // Not a tree root + + pool chain = collect_chain(root, children_of); + if (chain.size() < 2) + continue; + + int neg_compensation; + auto operands = extract_chain_operands(chain, root, parent_of, neg_compensation); + if (operands.size() < 3) + continue; + + for (auto c : chain) + to_remove.insert(c); + + replace_with_carry_save_tree(operands, root->getPort(ID::Y), neg_compensation, "Replaced add/sub chain"); + } + + for (auto cell : to_remove) + module->remove(cell); + } + + void process_maccs() + { + for (auto cell : cells.macc) { + std::vector operands; + int neg_compensation; + if (!extract_macc_operands(cell, operands, neg_compensation)) + continue; + if (operands.size() < 3) + continue; + + replace_with_carry_save_tree(operands, cell->getPort(ID::Y), neg_compensation, "Replaced $macc"); + module->remove(cell); + } + } +}; + +void run(Module *module) +{ + Cells cells(module); + + if (cells.empty()) + return; + + Rewriter rewriter{module, cells}; + rewriter.process_chains(); + rewriter.process_maccs(); +} + +struct ArithTreePass : public Pass { + ArithTreePass() : Pass("arith_tree", "convert add/sub/macc chains to carry-save adder trees") {} + + void help() override + { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + log(" arith_tree [selection]\n"); + log("\n"); + log("This pass replaces chains of $add/$sub cells, $alu cells (with constant\n"); + log("BI/CI), and $macc/$macc_v2 cells (without multiplications) with carry-save\n"); + log("adder trees using $fa cells and a single final $add.\n"); + log("\n"); + log("The tree uses Wallace-tree scheduling: at each level, ready operands are\n"); + log("grouped into triplets and compressed via full adders, giving\n"); + log("O(log_{1.5} N) depth for N input operands.\n"); + log("\n"); + } + + void execute(std::vector args, RTLIL::Design *design) override + { + log_header(design, "Executing ARITH_TREE pass.\n"); + + size_t argidx; + for (argidx = 1; argidx < args.size(); argidx++) + break; + extra_args(args, argidx, design); + + for (auto module : design->selected_modules()) { + run(module); + } + } +} ArithTreePass; + +PRIVATE_NAMESPACE_END diff --git a/passes/techmap/booth.cc b/passes/techmap/booth.cc index 11ff71b29..c0bad784a 100644 --- a/passes/techmap/booth.cc +++ b/passes/techmap/booth.cc @@ -58,6 +58,7 @@ synth -top my_design -booth #include "kernel/sigtools.h" #include "kernel/yosys.h" #include "kernel/macc.h" +#include "kernel/wallace_tree.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN @@ -317,36 +318,6 @@ struct BoothPassWorker { } } - SigSig WallaceSum(int width, std::vector summands) - { - for (auto &s : summands) - s.extend_u0(width); - - while (summands.size() > 2) { - std::vector new_summands; - int i; - for (i = 0; i < (int) summands.size() - 2; i += 3) { - SigSpec x = module->addWire(NEW_ID, width); - SigSpec y = module->addWire(NEW_ID, width); - BuildBitwiseFa(module, NEW_ID.str(), summands[i], summands[i + 1], - summands[i + 2], x, y); - new_summands.push_back(y); - new_summands.push_back({x.extract(0, width - 1), State::S0}); - } - - new_summands.insert(new_summands.begin(), summands.begin() + i, summands.end()); - - std::swap(summands, new_summands); - } - - if (!summands.size()) - return SigSig(SigSpec(width, State::S0), SigSpec(width, State::S0)); - else if (summands.size() == 1) - return SigSig(summands[0], SigSpec(width, State::S0)); - else - return SigSig(summands[0], summands[1]); - } - /* Build Multiplier. ------------------------- @@ -415,16 +386,16 @@ struct BoothPassWorker { // Later on yosys will clean up unused constants // DebugDumpAlignPP(aligned_pp); - SigSig wtree_sum = WallaceSum(z_sz, aligned_pp); + auto [wtree_a, wtree_b] = wallace_reduce_scheduled(module, aligned_pp, z_sz); // Debug code: Dump out the csa trees // DumpCSATrees(debug_csa_trees); // Build the CPA to do the final accumulation. - log_assert(wtree_sum.second[0] == State::S0); + log_assert(wtree_b[0] == State::S0); if (mapped_cpa) - BuildCPA(module, wtree_sum.first, {State::S0, wtree_sum.second.extract_end(1)}, Z); + BuildCPA(module, wtree_a, wtree_b, Z); else - module->addAdd(NEW_ID, wtree_sum.first, {wtree_sum.second.extract_end(1), State::S0}, Z); + module->addAdd(NEW_ID, wtree_a, wtree_b, Z); } /* diff --git a/techlibs/common/synth.cc b/techlibs/common/synth.cc index 0623bf43d..daebd789a 100644 --- a/techlibs/common/synth.cc +++ b/techlibs/common/synth.cc @@ -67,6 +67,10 @@ struct SynthPass : public ScriptPass { log(" -booth\n"); log(" run the booth pass to map $mul to Booth encoded multipliers\n"); log("\n"); + log(" -arith_tree\n"); + log(" run the arith_tree pass to convert $add/$sub chains and $macc cells to\n"); + log(" carry-save adder trees.\n"); + log("\n"); log(" -noalumacc\n"); log(" do not run 'alumacc' pass. i.e. keep arithmetic operators in\n"); log(" their direct form ($add, $sub, etc.).\n"); @@ -108,7 +112,7 @@ struct SynthPass : public ScriptPass { } string top_module, fsm_opts, memory_opts, abc; - bool autotop, flatten, noalumacc, nofsm, noabc, noshare, flowmap, booth, hieropt, relative_share; + bool autotop, flatten, noalumacc, nofsm, noabc, noshare, flowmap, booth, arith_tree, hieropt, relative_share; int lut; std::vector techmap_maps; @@ -127,6 +131,7 @@ struct SynthPass : public ScriptPass { noshare = false; flowmap = false; booth = false; + arith_tree = false; hieropt = false; relative_share = false; abc = "abc"; @@ -187,7 +192,10 @@ struct SynthPass : public ScriptPass { booth = true; continue; } - + if (args[argidx] == "-arith_tree") { + arith_tree = true; + continue; + } if (args[argidx] == "-nordff") { memory_opts += " -nordff"; continue; @@ -289,6 +297,8 @@ struct SynthPass : public ScriptPass { run("booth", " (if -booth)"); if (!noalumacc) run("alumacc", " (unless -noalumacc)"); + if (arith_tree || help_mode) + run("arith_tree", " (if -arith_tree)"); if (!noshare) run("share", " (unless -noshare)"); run("opt" + hieropt_flag); @@ -301,7 +311,7 @@ struct SynthPass : public ScriptPass { run("memory_map"); run("opt -full"); if (help_mode) { - run(techmap_cmd, " (unless -extra-map)"); + run(techmap_cmd, " (unless -extra-map)"); run(techmap_cmd + " -map +/techmap.v -map ", " (if -extra-map)"); } else { std::string techmap_opts; diff --git a/tests/arith_tree/arith_tree_add_chains.ys b/tests/arith_tree/arith_tree_add_chains.ys new file mode 100644 index 000000000..f293ed9da --- /dev/null +++ b/tests/arith_tree/arith_tree_add_chains.ys @@ -0,0 +1,197 @@ +read_verilog <