mirror of https://github.com/YosysHQ/yosys.git
opt_argmax pass
This commit is contained in:
parent
c7b2c16405
commit
b3ea5770cd
|
|
@ -22,11 +22,13 @@ OBJS += passes/opt/opt_lut_ins.o
|
|||
OBJS += passes/opt/opt_ffinv.o
|
||||
OBJS += passes/opt/pmux2shiftx.o
|
||||
OBJS += passes/opt/muxpack.o
|
||||
|
||||
OBJS += passes/opt/opt_addcin.o
|
||||
OBJS += passes/opt/opt_andor_pmux.o
|
||||
OBJS += passes/opt/opt_argmax.o
|
||||
OBJS += passes/opt/opt_balance_tree.o
|
||||
OBJS += passes/opt/opt_parallel_prefix.o
|
||||
OBJS += passes/opt/opt_prienc.o
|
||||
OBJS += passes/opt/opt_addcin.o
|
||||
|
||||
OBJS += passes/opt/peepopt.o
|
||||
GENFILES += passes/opt/peepopt_pm.h
|
||||
|
|
|
|||
|
|
@ -0,0 +1,775 @@
|
|||
/*
|
||||
* yosys -- Yosys Open SYnthesis Suite
|
||||
*
|
||||
* Copyright (C) 2026 Akash Levy <akash@silimate.com>
|
||||
*
|
||||
* Permission to use, copy, modify, and/or distribute this software for any
|
||||
* purpose with or without fee is hereby granted, provided that the above
|
||||
* copyright notice and this permission notice appear in all copies.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||
*
|
||||
*/
|
||||
|
||||
#include "kernel/yosys.h"
|
||||
#include "kernel/sigtools.h"
|
||||
#include "kernel/consteval.h"
|
||||
#include <cctype>
|
||||
#include <map>
|
||||
#include <queue>
|
||||
|
||||
USING_YOSYS_NAMESPACE
|
||||
PRIVATE_NAMESPACE_BEGIN
|
||||
|
||||
static int clog2_int(int x)
|
||||
{
|
||||
int r = 0;
|
||||
while ((1 << r) < x)
|
||||
r++;
|
||||
return r;
|
||||
}
|
||||
|
||||
static bool is_power_of_two(int x)
|
||||
{
|
||||
return x > 0 && (x & (x - 1)) == 0;
|
||||
}
|
||||
|
||||
static Const packed_table_const(const vector<uint64_t> &values, int elem_width)
|
||||
{
|
||||
vector<State> bits(values.size() * elem_width, State::S0);
|
||||
for (int i = 0; i < GetSize(values); i++)
|
||||
for (int b = 0; b < elem_width && b < 64; b++)
|
||||
if ((values[i] >> b) & 1ULL)
|
||||
bits[i * elem_width + b] = State::S1;
|
||||
return Const(bits);
|
||||
}
|
||||
|
||||
static Const packed_valid_const(const vector<int> &valid)
|
||||
{
|
||||
vector<State> bits(valid.size(), State::S0);
|
||||
for (int i = 0; i < GetSize(valid); i++)
|
||||
if (valid[i])
|
||||
bits[i] = State::S1;
|
||||
return Const(bits);
|
||||
}
|
||||
|
||||
struct OptArgmaxWorker
|
||||
{
|
||||
struct TestVector {
|
||||
vector<int> valid;
|
||||
vector<uint64_t> index;
|
||||
vector<uint64_t> values;
|
||||
};
|
||||
|
||||
struct Candidate {
|
||||
Wire *out_wire = nullptr;
|
||||
Wire *valid_wire = nullptr;
|
||||
SigSpec valid_sig;
|
||||
SigSpec index_sig;
|
||||
SigSpec values_sig;
|
||||
std::string index_name;
|
||||
std::string values_name;
|
||||
int width = 0;
|
||||
int index_width = 0;
|
||||
int value_width = 0;
|
||||
Cell *anchor = nullptr;
|
||||
IdString anchor_port;
|
||||
};
|
||||
|
||||
struct OutputCone {
|
||||
pool<Cell *> cells;
|
||||
pool<SigBit> leaves;
|
||||
bool saw_bmux = false;
|
||||
bool saw_lt = false;
|
||||
};
|
||||
|
||||
struct InputBus {
|
||||
SigSpec sig;
|
||||
std::string name;
|
||||
int entries = 0;
|
||||
int elem_width = 0;
|
||||
};
|
||||
|
||||
struct Record {
|
||||
SigBit valid;
|
||||
SigSpec value;
|
||||
SigSpec index;
|
||||
};
|
||||
|
||||
Module *module;
|
||||
SigMap sigmap;
|
||||
dict<SigBit, Cell *> bit_to_driver;
|
||||
pool<SigBit> input_port_bits;
|
||||
Cell *cell = nullptr;
|
||||
|
||||
int min_width = 4;
|
||||
int max_width = 64;
|
||||
int regions_rewritten = 0;
|
||||
int cells_added = 0;
|
||||
|
||||
OptArgmaxWorker(Module *module) : module(module), sigmap(module)
|
||||
{
|
||||
build_indexes();
|
||||
}
|
||||
|
||||
bool is_sequential(Cell *c)
|
||||
{
|
||||
return c->type.in(
|
||||
ID($ff), ID($dff), ID($dffe), ID($adff), ID($adffe),
|
||||
ID($sdff), ID($sdffe), ID($sdffce), ID($dffsr), ID($dffsre),
|
||||
ID($_DFF_P_), ID($_DFF_N_),
|
||||
ID($_DFFE_PP_), ID($_DFFE_PN_), ID($_DFFE_NP_), ID($_DFFE_NN_),
|
||||
ID($_DFF_PP0_), ID($_DFF_PP1_), ID($_DFF_PN0_), ID($_DFF_PN1_),
|
||||
ID($_DFF_NP0_), ID($_DFF_NP1_), ID($_DFF_NN0_), ID($_DFF_NN1_),
|
||||
ID($dlatch), ID($adlatch), ID($dlatchsr),
|
||||
ID($mem), ID($mem_v2), ID($meminit), ID($meminit_v2),
|
||||
ID($memrd), ID($memrd_v2), ID($memwr), ID($memwr_v2),
|
||||
ID($fsm),
|
||||
ID($assert), ID($assume), ID($cover), ID($live), ID($fair),
|
||||
ID($print), ID($check),
|
||||
ID($anyconst), ID($anyseq), ID($allconst), ID($allseq),
|
||||
ID($initstate));
|
||||
}
|
||||
|
||||
void build_indexes()
|
||||
{
|
||||
for (auto c : module->cells()) {
|
||||
if (is_sequential(c))
|
||||
continue;
|
||||
for (auto &conn : c->connections()) {
|
||||
if (!c->output(conn.first))
|
||||
continue;
|
||||
for (auto bit : sigmap(conn.second)) {
|
||||
if (!bit.wire)
|
||||
continue;
|
||||
auto it = bit_to_driver.find(bit);
|
||||
if (it == bit_to_driver.end())
|
||||
bit_to_driver[bit] = c;
|
||||
else if (it->second != c)
|
||||
it->second = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto w : module->wires()) {
|
||||
if (!w->port_input)
|
||||
continue;
|
||||
for (auto bit : sigmap(SigSpec(w)))
|
||||
if (bit.wire)
|
||||
input_port_bits.insert(bit);
|
||||
}
|
||||
}
|
||||
|
||||
bool get_cone(SigSpec from, pool<Cell *> &cone_cells, pool<SigBit> &leaf_bits,
|
||||
int max_cone_cells, int max_leaf_bits)
|
||||
{
|
||||
pool<SigBit> visited;
|
||||
std::queue<SigBit> worklist;
|
||||
for (auto bit : sigmap(from)) {
|
||||
if (!bit.wire)
|
||||
continue;
|
||||
if (visited.insert(bit).second)
|
||||
worklist.push(bit);
|
||||
}
|
||||
|
||||
while (!worklist.empty()) {
|
||||
SigBit bit = worklist.front();
|
||||
worklist.pop();
|
||||
|
||||
if (input_port_bits.count(bit)) {
|
||||
leaf_bits.insert(bit);
|
||||
if (GetSize(leaf_bits) > max_leaf_bits)
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
|
||||
Cell *drv = bit_to_driver.at(bit, nullptr);
|
||||
if (drv == nullptr) {
|
||||
leaf_bits.insert(bit);
|
||||
if (GetSize(leaf_bits) > max_leaf_bits)
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!cone_cells.insert(drv).second)
|
||||
continue;
|
||||
if (GetSize(cone_cells) > max_cone_cells)
|
||||
return false;
|
||||
|
||||
for (auto &conn : drv->connections()) {
|
||||
if (!drv->input(conn.first))
|
||||
continue;
|
||||
for (auto in_bit : sigmap(conn.second)) {
|
||||
if (!in_bit.wire)
|
||||
continue;
|
||||
if (visited.insert(in_bit).second)
|
||||
worklist.push(in_bit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
OutputCone summarize_output_cone(const pool<Cell *> &cone_cells, pool<SigBit> leaf_bits)
|
||||
{
|
||||
OutputCone cone;
|
||||
cone.cells = cone_cells;
|
||||
cone.leaves = std::move(leaf_bits);
|
||||
for (auto c : cone_cells) {
|
||||
cone.saw_bmux = cone.saw_bmux || c->type == ID($bmux);
|
||||
cone.saw_lt = cone.saw_lt || c->type == ID($lt);
|
||||
}
|
||||
return cone;
|
||||
}
|
||||
|
||||
bool cone_has_required_shape(const OutputCone &cone, int value_width)
|
||||
{
|
||||
return cone.saw_bmux && (cone.saw_lt || value_width == 1);
|
||||
}
|
||||
|
||||
bool leaves_are_candidate_inputs(const pool<SigBit> &leaf_bits, const Candidate &cand)
|
||||
{
|
||||
pool<SigBit> allowed;
|
||||
for (auto bit : sigmap(cand.valid_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
for (auto bit : sigmap(cand.index_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
for (auto bit : sigmap(cand.values_sig))
|
||||
if (bit.wire)
|
||||
allowed.insert(bit);
|
||||
|
||||
for (auto bit : leaf_bits)
|
||||
if (!allowed.count(bit))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool find_anchor_driver(Wire *out_wire, Cell *&anchor, IdString &anchor_port)
|
||||
{
|
||||
for (auto bit : sigmap(SigSpec(out_wire))) {
|
||||
Cell *drv = bit_to_driver.at(bit, nullptr);
|
||||
if (drv == nullptr)
|
||||
continue;
|
||||
for (auto &conn : drv->connections()) {
|
||||
if (!drv->output(conn.first))
|
||||
continue;
|
||||
for (auto out_bit : sigmap(conn.second)) {
|
||||
if (out_bit == bit) {
|
||||
anchor = drv;
|
||||
anchor_port = conn.first;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
uint64_t value_mask(int width)
|
||||
{
|
||||
if (width >= 64)
|
||||
return ~0ULL;
|
||||
return (1ULL << width) - 1;
|
||||
}
|
||||
|
||||
void add_vector(vector<TestVector> &vectors, const vector<int> &valid,
|
||||
const vector<uint64_t> &index, const vector<uint64_t> &values)
|
||||
{
|
||||
vectors.push_back({valid, index, values});
|
||||
}
|
||||
|
||||
vector<TestVector> make_test_vectors(int width, int value_width)
|
||||
{
|
||||
vector<TestVector> vectors;
|
||||
vector<uint64_t> identity(width), reverse(width), inc(width), dec(width), equal(width, 7);
|
||||
uint64_t mask = value_mask(value_width);
|
||||
for (int i = 0; i < width; i++) {
|
||||
identity[i] = i;
|
||||
reverse[i] = width - 1 - i;
|
||||
inc[i] = uint64_t(i + 1) & mask;
|
||||
dec[i] = uint64_t(width - i) & mask;
|
||||
}
|
||||
|
||||
vector<int> valid(width, 0);
|
||||
add_vector(vectors, valid, identity, inc);
|
||||
|
||||
for (int i = 0; i < width; i++) {
|
||||
valid.assign(width, 0);
|
||||
valid[i] = 1;
|
||||
add_vector(vectors, valid, identity, inc);
|
||||
}
|
||||
|
||||
valid.assign(width, 1);
|
||||
add_vector(vectors, valid, identity, inc);
|
||||
add_vector(vectors, valid, identity, dec);
|
||||
add_vector(vectors, valid, identity, equal);
|
||||
add_vector(vectors, valid, reverse, inc);
|
||||
add_vector(vectors, valid, reverse, dec);
|
||||
|
||||
for (int i = 0; i + 1 < width; i++) {
|
||||
vector<uint64_t> vals(width, 3);
|
||||
valid.assign(width, 0);
|
||||
valid[i] = 1;
|
||||
valid[i + 1] = 1;
|
||||
vals[i] = 1;
|
||||
vals[i + 1] = 9;
|
||||
add_vector(vectors, valid, identity, vals);
|
||||
vals[i] = 5;
|
||||
vals[i + 1] = 5;
|
||||
add_vector(vectors, valid, identity, vals);
|
||||
}
|
||||
|
||||
if (width > 2) {
|
||||
vector<uint64_t> vals(width, 0);
|
||||
valid.assign(width, 0);
|
||||
valid[0] = 1;
|
||||
valid[width - 1] = 1;
|
||||
vals[0] = 2;
|
||||
vals[width - 1] = 11;
|
||||
add_vector(vectors, valid, identity, vals);
|
||||
vals[0] = 13;
|
||||
vals[width - 1] = 13;
|
||||
add_vector(vectors, valid, identity, vals);
|
||||
}
|
||||
|
||||
return vectors;
|
||||
}
|
||||
|
||||
int expected_argmax(const TestVector &tv, int width, int value_width)
|
||||
{
|
||||
uint64_t mask = value_mask(value_width);
|
||||
int best_idx = 0;
|
||||
bool best_valid = tv.valid[0] != 0;
|
||||
uint64_t best_value = tv.values[tv.index[0]] & mask;
|
||||
|
||||
for (int k = 1; k < width; k++) {
|
||||
bool cand_valid = tv.valid[k] != 0;
|
||||
uint64_t cand_value = tv.values[tv.index[k]] & mask;
|
||||
if (!best_valid && cand_valid) {
|
||||
best_idx = k;
|
||||
best_valid = true;
|
||||
best_value = cand_value;
|
||||
} else if (best_valid && cand_valid && best_value < cand_value) {
|
||||
best_idx = k;
|
||||
best_value = cand_value;
|
||||
}
|
||||
}
|
||||
|
||||
return best_idx;
|
||||
}
|
||||
|
||||
bool fingerprint(const Candidate &cand)
|
||||
{
|
||||
ConstEval ce(module);
|
||||
SigSpec out_sig = sigmap(SigSpec(cand.out_wire));
|
||||
SigSpec valid_sig = sigmap(cand.valid_sig);
|
||||
SigSpec index_sig = sigmap(cand.index_sig);
|
||||
SigSpec values_sig = sigmap(cand.values_sig);
|
||||
|
||||
vector<TestVector> vectors = make_test_vectors(cand.width, cand.value_width);
|
||||
for (auto &tv : vectors) {
|
||||
ce.push();
|
||||
ce.set(valid_sig, packed_valid_const(tv.valid));
|
||||
ce.set(index_sig, packed_table_const(tv.index, cand.index_width));
|
||||
ce.set(values_sig, packed_table_const(tv.values, cand.value_width));
|
||||
|
||||
SigSpec out = out_sig;
|
||||
SigSpec undef;
|
||||
bool ok = ce.eval(out, undef);
|
||||
ce.pop();
|
||||
if (!ok || !out.is_fully_const())
|
||||
return false;
|
||||
|
||||
int actual = out.as_const().as_int();
|
||||
int expected = expected_argmax(tv, cand.width, cand.value_width);
|
||||
if (actual != expected)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
SigSpec zext(SigSpec sig, int width)
|
||||
{
|
||||
sig = sigmap(sig);
|
||||
if (GetSize(sig) > width)
|
||||
return sig.extract(0, width);
|
||||
while (GetSize(sig) < width)
|
||||
sig.append(State::S0);
|
||||
return sig;
|
||||
}
|
||||
|
||||
SigSpec emit_not(Cell *anchor, SigSpec a)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->Not(NEW_ID2_SUFFIX("argmax_not"), a);
|
||||
}
|
||||
|
||||
SigSpec emit_and(Cell *anchor, SigSpec a, SigSpec b)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->And(NEW_ID2_SUFFIX("argmax_and"), a, b);
|
||||
}
|
||||
|
||||
SigSpec emit_or(Cell *anchor, SigSpec a, SigSpec b)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->Or(NEW_ID2_SUFFIX("argmax_or"), a, b);
|
||||
}
|
||||
|
||||
SigSpec emit_lt(Cell *anchor, SigSpec a, SigSpec b)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->Lt(NEW_ID2_SUFFIX("argmax_lt"), a, b);
|
||||
}
|
||||
|
||||
SigSpec emit_mux(Cell *anchor, SigSpec a, SigSpec b, SigSpec s)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->Mux(NEW_ID2_SUFFIX("argmax_mux"), a, b, s);
|
||||
}
|
||||
|
||||
SigSpec emit_bmux(Cell *anchor, SigSpec a, SigSpec s)
|
||||
{
|
||||
Cell *cell = anchor;
|
||||
cells_added++;
|
||||
return module->Bmux(NEW_ID2_SUFFIX("argmax_val"), a, s);
|
||||
}
|
||||
|
||||
Record combine(Cell *anchor, const Record &lhs, const Record &rhs)
|
||||
{
|
||||
SigSpec lhs_invalid = emit_not(anchor, SigSpec(lhs.valid));
|
||||
SigSpec value_lt = emit_lt(anchor, lhs.value, rhs.value);
|
||||
SigSpec valid_and_lt = emit_and(anchor, SigSpec(lhs.valid), value_lt);
|
||||
SigSpec take_reason = emit_or(anchor, lhs_invalid, valid_and_lt);
|
||||
SigSpec take_rhs = emit_and(anchor, SigSpec(rhs.valid), take_reason);
|
||||
|
||||
Record out;
|
||||
out.valid = emit_or(anchor, SigSpec(lhs.valid), SigSpec(rhs.valid))[0];
|
||||
out.value = emit_mux(anchor, lhs.value, rhs.value, take_rhs);
|
||||
out.index = emit_mux(anchor, lhs.index, rhs.index, take_rhs);
|
||||
return out;
|
||||
}
|
||||
|
||||
Record emit_tree_rec(Cell *anchor, const vector<Record> &leaves, int begin, int end)
|
||||
{
|
||||
log_assert(begin < end);
|
||||
if (begin + 1 == end)
|
||||
return leaves[begin];
|
||||
|
||||
int mid = begin + (end - begin) / 2;
|
||||
Record lhs = emit_tree_rec(anchor, leaves, begin, mid);
|
||||
Record rhs = emit_tree_rec(anchor, leaves, mid, end);
|
||||
return combine(anchor, lhs, rhs);
|
||||
}
|
||||
|
||||
SigSpec emit_argmax(const Candidate &cand)
|
||||
{
|
||||
vector<Record> leaves;
|
||||
SigSpec valid = sigmap(cand.valid_sig);
|
||||
SigSpec index_map = sigmap(cand.index_sig);
|
||||
SigSpec values = sigmap(cand.values_sig);
|
||||
|
||||
for (int k = 0; k < cand.width; k++) {
|
||||
SigSpec index = index_map.extract(k * cand.index_width, cand.index_width);
|
||||
SigSpec value = emit_bmux(cand.anchor, values, index);
|
||||
leaves.push_back({valid[k], value, SigSpec(Const(k, cand.index_width))});
|
||||
}
|
||||
|
||||
Record root = emit_tree_rec(cand.anchor, leaves, 0, GetSize(leaves));
|
||||
return zext(root.index, cand.index_width);
|
||||
}
|
||||
|
||||
void disconnect_old_output(const Candidate &cand)
|
||||
{
|
||||
pool<SigBit> target_bits;
|
||||
for (auto bit : sigmap(SigSpec(cand.out_wire)))
|
||||
if (bit.wire)
|
||||
target_bits.insert(bit);
|
||||
|
||||
pool<Cell *> seen_cells;
|
||||
for (auto target : target_bits) {
|
||||
Cell *drv = bit_to_driver.at(target, nullptr);
|
||||
if (drv == nullptr || seen_cells.count(drv))
|
||||
continue;
|
||||
seen_cells.insert(drv);
|
||||
|
||||
for (auto &conn : drv->connections()) {
|
||||
if (!drv->output(conn.first))
|
||||
continue;
|
||||
|
||||
SigSpec orig = conn.second;
|
||||
SigSpec replacement = orig;
|
||||
bool changed = false;
|
||||
Cell *cell = drv;
|
||||
Wire *dangling = module->addWire(NEW_ID2_SUFFIX("argmax_dangling"), GetSize(orig));
|
||||
for (int i = 0; i < GetSize(orig); i++) {
|
||||
if (target_bits.count(sigmap(orig[i]))) {
|
||||
replacement[i] = SigBit(dangling, i);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed)
|
||||
drv->setPort(conn.first, replacement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool check_candidate(Candidate &cand, const OutputCone &cone)
|
||||
{
|
||||
if (cand.width < min_width || cand.width > max_width)
|
||||
return false;
|
||||
if (!is_power_of_two(cand.width))
|
||||
return false;
|
||||
if (cand.index_width != clog2_int(cand.width))
|
||||
return false;
|
||||
if (cand.value_width <= 0 || cand.value_width > 62)
|
||||
return false;
|
||||
|
||||
if (!cone_has_required_shape(cone, cand.value_width))
|
||||
return false;
|
||||
if (!leaves_are_candidate_inputs(cone.leaves, cand))
|
||||
return false;
|
||||
if (!find_anchor_driver(cand.out_wire, cand.anchor, cand.anchor_port))
|
||||
return false;
|
||||
|
||||
return fingerprint(cand);
|
||||
}
|
||||
|
||||
bool parse_indexed_port_name(Wire *wire, std::string &base, int &index)
|
||||
{
|
||||
std::string name = wire->name.str();
|
||||
size_t rbrack = name.size();
|
||||
if (rbrack == 0 || name[rbrack - 1] != ']')
|
||||
return false;
|
||||
size_t lbrack = name.rfind('[');
|
||||
if (lbrack == std::string::npos || lbrack + 1 >= rbrack - 1)
|
||||
return false;
|
||||
for (size_t i = lbrack + 1; i < rbrack - 1; i++)
|
||||
if (!isdigit(name[i]))
|
||||
return false;
|
||||
base = name.substr(0, lbrack);
|
||||
index = atoi(name.substr(lbrack + 1, rbrack - lbrack - 2).c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
vector<InputBus> collect_split_input_buses(const vector<Wire *> &inputs)
|
||||
{
|
||||
std::map<std::string, vector<std::pair<int, Wire *>>> groups;
|
||||
for (auto w : inputs) {
|
||||
std::string base;
|
||||
int index = -1;
|
||||
if (parse_indexed_port_name(w, base, index))
|
||||
groups[base].push_back({index, w});
|
||||
}
|
||||
|
||||
vector<InputBus> buses;
|
||||
for (auto &it : groups) {
|
||||
auto entries = it.second;
|
||||
std::sort(entries.begin(), entries.end(),
|
||||
[](const std::pair<int, Wire *> &a, const std::pair<int, Wire *> &b) {
|
||||
return a.first < b.first;
|
||||
});
|
||||
if (entries.empty() || entries.front().first != 0)
|
||||
continue;
|
||||
bool contiguous = true;
|
||||
int elem_width = GetSize(entries.front().second);
|
||||
for (int i = 0; i < GetSize(entries); i++) {
|
||||
if (entries[i].first != i || GetSize(entries[i].second) != elem_width) {
|
||||
contiguous = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!contiguous)
|
||||
continue;
|
||||
|
||||
SigSpec sig;
|
||||
for (auto &entry : entries)
|
||||
sig.append(SigSpec(entry.second));
|
||||
buses.push_back({sig, it.first, GetSize(entries), elem_width});
|
||||
}
|
||||
|
||||
return buses;
|
||||
}
|
||||
|
||||
void run()
|
||||
{
|
||||
if (module->has_processes_warn())
|
||||
return;
|
||||
|
||||
vector<Wire *> inputs;
|
||||
vector<Wire *> outputs;
|
||||
for (auto w : module->wires()) {
|
||||
if (w->port_input)
|
||||
inputs.push_back(w);
|
||||
if (w->port_output && !w->port_input)
|
||||
outputs.push_back(w);
|
||||
}
|
||||
|
||||
vector<Candidate> rewrites;
|
||||
pool<Wire *> claimed_outputs;
|
||||
for (auto out : outputs) {
|
||||
if (claimed_outputs.count(out))
|
||||
continue;
|
||||
int out_width = GetSize(out);
|
||||
if (out_width < 2)
|
||||
continue;
|
||||
|
||||
pool<Cell *> cone_cells;
|
||||
pool<SigBit> leaf_bits;
|
||||
int max_cone_cells = std::max(256, max_width * 96);
|
||||
int max_leaf_bits = max_width * (out_width + max_width) + max_width;
|
||||
if (!get_cone(SigSpec(out), cone_cells, leaf_bits,
|
||||
max_cone_cells, max_leaf_bits))
|
||||
continue;
|
||||
OutputCone cone = summarize_output_cone(cone_cells, std::move(leaf_bits));
|
||||
if (!cone.saw_bmux)
|
||||
continue;
|
||||
|
||||
for (auto valid : inputs) {
|
||||
int width = GetSize(valid);
|
||||
if (width < min_width || width > max_width)
|
||||
continue;
|
||||
if (clog2_int(width) != out_width)
|
||||
continue;
|
||||
|
||||
vector<InputBus> index_buses;
|
||||
vector<InputBus> values_buses;
|
||||
for (auto input : inputs) {
|
||||
if (input == valid)
|
||||
continue;
|
||||
if (GetSize(input) == width * out_width)
|
||||
index_buses.push_back({SigSpec(input), input->name.str(), width, out_width});
|
||||
if (GetSize(input) % width == 0)
|
||||
values_buses.push_back({SigSpec(input), input->name.str(), width, GetSize(input) / width});
|
||||
}
|
||||
|
||||
vector<InputBus> split_buses = collect_split_input_buses(inputs);
|
||||
for (auto bus : split_buses) {
|
||||
if (bus.entries == width && bus.elem_width == out_width)
|
||||
index_buses.push_back(bus);
|
||||
if (bus.entries == width)
|
||||
values_buses.push_back(bus);
|
||||
}
|
||||
|
||||
for (auto &index : index_buses) {
|
||||
for (auto &values : values_buses) {
|
||||
if (index.sig == values.sig)
|
||||
continue;
|
||||
Candidate cand;
|
||||
cand.out_wire = out;
|
||||
cand.valid_wire = valid;
|
||||
cand.valid_sig = SigSpec(valid);
|
||||
cand.index_sig = index.sig;
|
||||
cand.values_sig = values.sig;
|
||||
cand.index_name = index.name;
|
||||
cand.values_name = values.name;
|
||||
cand.width = width;
|
||||
cand.index_width = out_width;
|
||||
cand.value_width = values.elem_width;
|
||||
if (!check_candidate(cand, cone))
|
||||
continue;
|
||||
|
||||
rewrites.push_back(cand);
|
||||
claimed_outputs.insert(out);
|
||||
log(" %s: %s <- argmax(valid=%s, index=%s, values=%s) [N=%d, IW=%d, VW=%d]\n",
|
||||
log_id(module), log_id(out), log_id(valid), index.name.c_str(),
|
||||
values.name.c_str(), cand.width, cand.index_width, cand.value_width);
|
||||
goto next_output;
|
||||
}
|
||||
}
|
||||
}
|
||||
next_output:
|
||||
;
|
||||
}
|
||||
|
||||
for (auto &cand : rewrites) {
|
||||
cell = cand.anchor;
|
||||
SigSpec new_out = emit_argmax(cand);
|
||||
disconnect_old_output(cand);
|
||||
module->connect(SigSpec(cand.out_wire), new_out);
|
||||
regions_rewritten++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct OptArgmaxPass : public Pass
|
||||
{
|
||||
OptArgmaxPass() : Pass("opt_argmax",
|
||||
"detect and rewrite masked argmax loops into balanced compare trees") {}
|
||||
|
||||
void help() override
|
||||
{
|
||||
log("\n");
|
||||
log(" opt_argmax [options] [selection]\n");
|
||||
log("\n");
|
||||
log("Detect combinational masked argmax loops of the form used by\n");
|
||||
log("read-after dependency logic and replace the serial loop-carried\n");
|
||||
log("index/update cone with a balanced tree of {valid,value,index}\n");
|
||||
log("comparators. Ties preserve the lower candidate index, matching a\n");
|
||||
log("strict '<' update condition; all-invalid inputs return index zero.\n");
|
||||
log("\n");
|
||||
log(" -max-width N, -max_width N\n");
|
||||
log(" maximum candidate count to consider (default 64).\n");
|
||||
log("\n");
|
||||
log(" -min-width N, -min_width N\n");
|
||||
log(" minimum candidate count to consider (default 4).\n");
|
||||
log("\n");
|
||||
}
|
||||
|
||||
void execute(std::vector<std::string> args, RTLIL::Design *design) override
|
||||
{
|
||||
log_header(design, "Executing OPT_ARGMAX pass (masked argmax rewrite).\n");
|
||||
|
||||
int max_width = 64;
|
||||
int min_width = 4;
|
||||
size_t argidx;
|
||||
for (argidx = 1; argidx < args.size(); argidx++) {
|
||||
if ((args[argidx] == "-max-width" || args[argidx] == "-max_width") &&
|
||||
argidx + 1 < args.size()) {
|
||||
max_width = std::stoi(args[++argidx]);
|
||||
continue;
|
||||
}
|
||||
if ((args[argidx] == "-min-width" || args[argidx] == "-min_width") &&
|
||||
argidx + 1 < args.size()) {
|
||||
min_width = std::stoi(args[++argidx]);
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
extra_args(args, argidx, design);
|
||||
|
||||
int total_regions = 0;
|
||||
int total_cells_added = 0;
|
||||
for (auto module : design->selected_modules()) {
|
||||
OptArgmaxWorker worker(module);
|
||||
worker.max_width = max_width;
|
||||
worker.min_width = min_width;
|
||||
worker.run();
|
||||
total_regions += worker.regions_rewritten;
|
||||
total_cells_added += worker.cells_added;
|
||||
}
|
||||
|
||||
log("Rewrote %d argmax region(s); emitted %d new cell(s).\n",
|
||||
total_regions, total_cells_added);
|
||||
|
||||
if (total_regions)
|
||||
Yosys::run_pass("clean -purge");
|
||||
}
|
||||
} OptArgmaxPass;
|
||||
|
||||
PRIVATE_NAMESPACE_END
|
||||
|
|
@ -0,0 +1,324 @@
|
|||
module opt_argmax_basic (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][3:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_w8 (
|
||||
input wire [7:0] sig,
|
||||
input wire [7:0][2:0] sig3,
|
||||
input wire [7:0][4:0] sig2,
|
||||
output reg [2:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_w32 (
|
||||
input wire [31:0] sig,
|
||||
input wire [31:0][4:0] sig3,
|
||||
input wire [31:0][5:0] sig2,
|
||||
output reg [4:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 32; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_flat (
|
||||
input wire [7:0] sig,
|
||||
input wire [23:0] sig3,
|
||||
input wire [39:0] sig2,
|
||||
output reg [2:0] se_target_idx
|
||||
);
|
||||
function automatic [2:0] idx_at(input [2:0] pos);
|
||||
idx_at = sig3[pos * 3 +: 3];
|
||||
endfunction
|
||||
|
||||
function automatic [4:0] val_at(input [2:0] pos);
|
||||
val_at = sig2[idx_at(pos) * 5 +: 5];
|
||||
endfunction
|
||||
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(val_at(se_target_idx) < val_at(k[2:0]))) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_value_w1 (
|
||||
input wire [7:0] sig,
|
||||
input wire [7:0][2:0] sig3,
|
||||
input wire [7:0] sig2,
|
||||
output reg [2:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_value_w16 (
|
||||
input wire [7:0] sig,
|
||||
input wire [7:0][2:0] sig3,
|
||||
input wire [7:0][15:0] sig2,
|
||||
output reg [2:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_two_regions (
|
||||
input wire [7:0] sig_a,
|
||||
input wire [7:0][2:0] sig3_a,
|
||||
input wire [7:0][7:0] sig2_a,
|
||||
input wire [7:0] sig_b,
|
||||
input wire [7:0][2:0] sig3_b,
|
||||
input wire [7:0][5:0] sig2_b,
|
||||
output reg [2:0] idx_a,
|
||||
output reg [2:0] idx_b
|
||||
);
|
||||
always_comb begin
|
||||
idx_a = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig_a[idx_a] && sig_a[k]) begin
|
||||
idx_a = k;
|
||||
end else if (sig_a[idx_a] && sig_a[k] &&
|
||||
(sig2_a[sig3_a[idx_a]] < sig2_a[sig3_a[k]])) begin
|
||||
idx_a = k;
|
||||
end
|
||||
end
|
||||
|
||||
idx_b = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig_b[idx_b] && sig_b[k]) begin
|
||||
idx_b = k;
|
||||
end else if (sig_b[idx_b] && sig_b[k] &&
|
||||
(sig2_b[sig3_b[idx_b]] < sig2_b[sig3_b[k]])) begin
|
||||
idx_b = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_shared_consumer (
|
||||
input wire [7:0] sig,
|
||||
input wire [7:0][2:0] sig3,
|
||||
input wire [7:0][7:0] sig2,
|
||||
input wire [2:0] salt,
|
||||
output reg [2:0] se_target_idx,
|
||||
output wire [2:0] also_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 8; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
assign also_idx = se_target_idx ^ salt;
|
||||
endmodule
|
||||
|
||||
module opt_argmax_tie_high (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][3:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] <= sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_nonzero_default (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][3:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = 4'd1;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_min (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][3:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] > sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_w12 (
|
||||
input wire [11:0] sig,
|
||||
input wire [11:0][3:0] sig3,
|
||||
input wire [11:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 12; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_bad_index_width (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][4:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx][3:0]] < sig2[sig3[k][3:0]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_stress_noop (
|
||||
input wire [63:0] sel,
|
||||
input wire [63:0] a,
|
||||
input wire [63:0] b,
|
||||
output wire [63:0] y
|
||||
);
|
||||
wire [63:0] mux0 = sel[0] ? a : b;
|
||||
wire [63:0] mux1 = sel[1] ? mux0 : {mux0[31:0], mux0[63:32]};
|
||||
wire [63:0] mux2 = sel[2] ? mux1 : (mux1 ^ a);
|
||||
wire [63:0] mux3 = sel[3] ? mux2 : (mux2 & b);
|
||||
wire [63:0] mux4 = sel[4] ? mux3 : (mux3 | a);
|
||||
wire [63:0] mux5 = sel[5] ? mux4 : {mux4[47:0], mux4[63:48]};
|
||||
assign y = sel[6] ? mux5 : ~mux5;
|
||||
endmodule
|
||||
|
||||
module opt_argmax_unrelated (
|
||||
input wire [3:0] a,
|
||||
input wire [3:0] b,
|
||||
input wire sel,
|
||||
output wire [3:0] y
|
||||
);
|
||||
assign y = sel ? a : b;
|
||||
endmodule
|
||||
|
||||
module opt_argmax_multi_match (
|
||||
input wire [15:0] sig,
|
||||
input wire [15:0][3:0] sig3,
|
||||
input wire [15:0][7:0] sig2,
|
||||
output reg [3:0] se_target_idx
|
||||
);
|
||||
always_comb begin
|
||||
se_target_idx = '0;
|
||||
for (int k = 1; k < 16; k++) begin
|
||||
if (!sig[se_target_idx] && sig[k]) begin
|
||||
se_target_idx = k;
|
||||
end else if (sig[se_target_idx] && sig[k] &&
|
||||
(sig2[sig3[se_target_idx]] < sig2[sig3[k]])) begin
|
||||
se_target_idx = k;
|
||||
end
|
||||
end
|
||||
end
|
||||
endmodule
|
||||
|
||||
module opt_argmax_multi_keep (
|
||||
input wire [3:0] a,
|
||||
input wire [3:0] b,
|
||||
input wire sel,
|
||||
output wire [3:0] y
|
||||
);
|
||||
assign y = sel ? a : b;
|
||||
endmodule
|
||||
|
|
@ -0,0 +1,332 @@
|
|||
# Tests for opt_argmax.
|
||||
|
||||
log -header "Small masked argmax self-equivalence"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_w8
|
||||
proc; opt_clean
|
||||
rename opt_argmax_w8 gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_w8
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_w8
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
rename opt_argmax_w8 gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Basic masked argmax structural rewrite"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_basic
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
select -assert-count 16 t:$bmux
|
||||
select -assert-count 15 t:$lt
|
||||
select -assert-count 29 t:$mux
|
||||
select -assert-count 30 t:$and
|
||||
select -assert-count 29 t:$or
|
||||
select -assert-count 15 t:$not
|
||||
select -assert-none c:LessThan_*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Flat-bus masked argmax self-equivalence"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_flat
|
||||
proc; opt_clean
|
||||
rename opt_argmax_flat gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_flat
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_flat
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
rename opt_argmax_flat gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Scaled masked argmax: 8 entries structural"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_w8
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Scaled masked argmax: 32 entries structural"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_w32
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Value width edge: 1-bit values"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_value_w1
|
||||
proc; opt_clean
|
||||
rename opt_argmax_value_w1 gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_value_w1
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_value_w1
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
rename opt_argmax_value_w1 gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Value width edge: 16-bit values"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_value_w16
|
||||
proc; opt_clean
|
||||
rename opt_argmax_value_w16 gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_value_w16
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_value_w16
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
rename opt_argmax_value_w16 gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Same module: two independent argmax regions"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_two_regions
|
||||
proc; opt_clean
|
||||
rename opt_argmax_two_regions gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_two_regions
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_two_regions
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 2 w:*argmax*
|
||||
rename opt_argmax_two_regions gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Shared consumer of argmax output remains equivalent"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_shared_consumer
|
||||
proc; opt_clean
|
||||
rename opt_argmax_shared_consumer gold
|
||||
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_shared_consumer
|
||||
proc; opt_clean
|
||||
select -module opt_argmax_shared_consumer
|
||||
opt_argmax
|
||||
select -clear
|
||||
opt_clean
|
||||
select -assert-min 1 w:*argmax*
|
||||
rename opt_argmax_shared_consumer gate
|
||||
|
||||
miter -equiv -flatten -make_assert gold gate miter
|
||||
hierarchy -top miter
|
||||
proc; opt; memory; opt
|
||||
sat -prove-asserts -verify
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Max width leaves argmax unchanged"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_basic
|
||||
proc; opt_clean
|
||||
opt_argmax -max_width 8
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: non-power-of-two candidate count"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_w12
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: mismatched index-map width"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_bad_index_width
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: strict tie behavior changed"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_tie_high
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: nonzero all-invalid default"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_nonzero_default
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: min-selection comparator"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_min
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: unrelated mux logic unchanged"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_unrelated
|
||||
proc; opt_clean
|
||||
select -assert-count 1 t:$mux
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
select -assert-count 1 t:$mux
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Negative: bmux-heavy unrelated stress module"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_stress_noop
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
select -assert-none w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
|
||||
log -header "Multi-module: only matching module rewrites"
|
||||
log -push
|
||||
design -reset
|
||||
verific -cfg veri_optimize_wide_selector 1
|
||||
verific -cfg db_infer_wide_muxes_post_elaboration 0
|
||||
read -sv opt_argmax.sv
|
||||
verific -import opt_argmax_multi_match opt_argmax_multi_keep
|
||||
proc; opt_clean
|
||||
opt_argmax
|
||||
opt_clean
|
||||
select -assert-min 1 opt_argmax_multi_match/w:*argmax*
|
||||
select -assert-none opt_argmax_multi_keep/w:*argmax*
|
||||
design -reset
|
||||
log -pop
|
||||
Loading…
Reference in New Issue