Further CSA cleanup.

This commit is contained in:
nella 2026-04-01 13:30:34 +02:00 committed by nella
parent 847a8941e9
commit 135812ab02
1 changed files with 13 additions and 11 deletions

View File

@ -1,5 +1,10 @@
/** /**
* Replaces chains of $add/$sub and $macc cells with carry-save adder trees * Replaces chains of $add/$sub and $macc cells with carry-save adder trees
*
* Terminology:
* - parent: Cells that consume another cell's output
* - chainable: Adds/subs with no carry-out usage
* - chain: Connected path of chainable cells
*/ */
#include "kernel/macc.h" #include "kernel/macc.h"
@ -83,8 +88,6 @@ struct AluInfo {
return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0;
} }
// Chainable cells are adds/subs with no carry usage, connected chainable
// cells form chains that can be replaced with CSA trees.
bool is_chainable(Cell *cell) bool is_chainable(Cell *cell)
{ {
if (!(is_add(cell) || is_subtract(cell))) if (!(is_add(cell) || is_subtract(cell)))
@ -130,7 +133,6 @@ struct Rewriter {
return consumer; return consumer;
} }
// Find cells that consume another cell's output.
dict<Cell *, Cell *> find_parents(const pool<Cell *> &candidates) dict<Cell *, Cell *> find_parents(const pool<Cell *> &candidates)
{ {
dict<Cell *, Cell *> parent_of; dict<Cell *, Cell *> parent_of;
@ -201,6 +203,7 @@ struct Rewriter {
if (!parent_subtracts) if (!parent_subtracts)
return false; return false;
// Check if any bit of child's Y connects to parent's B
SigSpec child_y = traversal.sigmap(child->getPort(ID::Y)); SigSpec child_y = traversal.sigmap(child->getPort(ID::Y));
SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B)); SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B));
for (auto bit : child_y) for (auto bit : child_y)
@ -214,6 +217,8 @@ struct Rewriter {
int &neg_compensation) int &neg_compensation)
{ {
pool<SigBit> chain_bits = internal_bits(chain); pool<SigBit> chain_bits = internal_bits(chain);
// Propagate negation flags through chain
dict<Cell *, bool> negated; dict<Cell *, bool> negated;
negated[root] = false; negated[root] = false;
{ {
@ -233,15 +238,12 @@ struct Rewriter {
} }
} }
// Extract leaf operands
std::vector<Operand> operands; std::vector<Operand> operands;
neg_compensation = 0; neg_compensation = 0;
for (auto cell : chain) { for (auto cell : chain) {
bool cell_neg; bool cell_neg = negated.count(cell) ? negated[cell] : false;
if (negated.count(cell))
cell_neg = negated[cell];
else
cell_neg = false;
SigSpec a = traversal.sigmap(cell->getPort(ID::A)); SigSpec a = traversal.sigmap(cell->getPort(ID::A));
SigSpec b = traversal.sigmap(cell->getPort(ID::B)); SigSpec b = traversal.sigmap(cell->getPort(ID::B));
@ -249,10 +251,10 @@ struct Rewriter {
bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool();
bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell)); bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell));
// Only add operands not produced by other chain cells
if (!overlaps(a, chain_bits)) { if (!overlaps(a, chain_bits)) {
bool neg = cell_neg; operands.push_back({a, a_signed, cell_neg});
operands.push_back({a, a_signed, neg}); if (cell_neg)
if (neg)
neg_compensation++; neg_compensation++;
} }
if (!overlaps(b, chain_bits)) { if (!overlaps(b, chain_bits)) {