Added caching of simulation runs for speed

This commit is contained in:
AdvaySingh1 2026-02-17 13:38:32 -08:00
parent dc4ca2c621
commit efcabb270f
1 changed files with 44 additions and 37 deletions

View File

@ -98,7 +98,7 @@ void profileFlipFlops(Module *module, const std::string &filename, const std::st
// Configuration
static const int DEFAULT_MAX_COVER = 100; // Max candidate signals to consider
static const int DEFAULT_MIN_REGS = 10; // Min registers per clock gate
static const int DEFAULT_SIM_ITERATIONS = 10; // Random simulation iterations for pruning
static const int DEFAULT_SIM_ITERATIONS = 1000; // Random simulation iterations for pruning
struct SatClockgateWorker
{
@ -120,8 +120,8 @@ struct SatClockgateWorker
ezSatPtr ez;
SatGen satgen;
// Random number generator for simulation
std::mt19937 rng;
// Cached simulation results: [iteration][SigBit] = evaluated State
std::vector<dict<SigBit, State>> cached_sim_results;
// Statistics
int accepted_count = 0;
@ -131,7 +131,7 @@ struct SatClockgateWorker
SatClockgateWorker(Module *module, int max_cover, int min_regs, int sim_iterations)
: module(module), sigmap(module),
max_cover(max_cover), min_regs(min_regs), sim_iterations(sim_iterations),
ez(), satgen(ez.get(), &sigmap), rng(42)
ez(), satgen(ez.get(), &sigmap)
{
// Build driver and sink maps
for (auto cell : module->cells()) {
@ -156,16 +156,14 @@ struct SatClockgateWorker
ID($dffsre), ID($_DFF_P_), ID($_DFF_N_), ID($_DFFE_PP_),
ID($_DFFE_PN_), ID($_DFFE_NP_), ID($_DFFE_NN_)))
satgen.importCell(cell);
}
// Run simulation with random inputs using ConstEval
// Returns false if counterexample found (candidate is definitely invalid)
bool simulationTest(const std::vector<SigBit> &conds, SigSpec sig_d, SigSpec sig_q, bool as_enable)
{
// Pre-run simulations and cache all signal values
std::mt19937 rng(42);
cached_sim_results.resize(sim_iterations);
for (int iter = 0; iter < sim_iterations; iter++) {
ConstEval ce(module);
// Generate random values for all input ports
// Set random values for input ports
for (auto wire : module->wires()) {
if (wire->port_input) {
Const rand_val(State::S0, wire->width);
@ -174,8 +172,7 @@ struct SatClockgateWorker
ce.set(SigSpec(wire), rand_val);
}
}
// Randomize FF Q outputs (they're inputs to combinational logic)
// Set random values for FF Q outputs
for (auto cell : module->cells()) {
if (cell->type.in(ID($ff), ID($dff), ID($dffe), ID($adff), ID($adffe),
ID($sdff), ID($sdffe), ID($sdffce), ID($dffsr),
@ -189,45 +186,55 @@ struct SatClockgateWorker
}
}
// Evaluate gating condition signals
// Evaluate and cache ALL wire signals
for (auto wire : module->wires()) {
for (int i = 0; i < wire->width; i++) {
SigBit bit(wire, i);
SigSpec sig(bit);
if (ce.eval(sig))
cached_sim_results[iter][sigmap(bit)] = sig[0].data;
}
}
}
}
// Check cached simulation results for counterexamples
// Returns false if counterexample found (candidate is definitely invalid)
bool simulationTest(const std::vector<SigBit> &conds, SigSpec sig_d, SigSpec sig_q, bool as_enable)
{
for (int iter = 0; iter < sim_iterations; iter++) {
auto &cache = cached_sim_results[iter];
// Lookup gating condition from cache
bool combined_cond;
if (as_enable) {
combined_cond = false;
for (auto bit : conds) {
SigSpec sig(bit);
if (ce.eval(sig)) {
if (sig[0] == State::S1)
combined_cond = true;
}
SigBit mapped = sigmap(bit);
if (cache.count(mapped) && cache[mapped] == State::S1)
combined_cond = true;
}
} else {
combined_cond = true;
for (auto bit : conds) {
SigSpec sig(bit);
if (ce.eval(sig)) {
if (sig[0] != State::S1)
combined_cond = false;
} else {
SigBit mapped = sigmap(bit);
if (!cache.count(mapped) || cache[mapped] != State::S1)
combined_cond = false;
}
}
}
bool gating_active = as_enable ? !combined_cond : combined_cond;
// Evaluate D and Q, check if D != Q
SigSpec d_eval = sig_d;
SigSpec q_eval = sig_q;
bool d_ok = ce.eval(d_eval);
bool q_ok = ce.eval(q_eval);
// Lookup D and Q from cache, check if D != Q
bool d_ne_q = false;
if (d_ok && q_ok) {
for (int i = 0; i < sig_d.size(); i++) {
if (d_eval[i] != q_eval[i]) {
d_ne_q = true;
break;
}
for (int i = 0; i < sig_d.size(); i++) {
SigBit d_bit = sigmap(sig_d[i]);
SigBit q_bit = sigmap(sig_q[i]);
State d_val = cache.count(d_bit) ? cache[d_bit] : State::S0;
State q_val = cache.count(q_bit) ? cache[q_bit] : State::S0;
if (d_val != q_val) {
d_ne_q = true;
break;
}
}