diff --git a/passes/silimate/sat_clockgate.cc b/passes/silimate/sat_clockgate.cc index e8c3d1d64..4d4e08899 100644 --- a/passes/silimate/sat_clockgate.cc +++ b/passes/silimate/sat_clockgate.cc @@ -98,7 +98,7 @@ void profileFlipFlops(Module *module, const std::string &filename, const std::st // Configuration static const int DEFAULT_MAX_COVER = 100; // Max candidate signals to consider static const int DEFAULT_MIN_REGS = 10; // Min registers per clock gate -static const int DEFAULT_SIM_ITERATIONS = 10; // Random simulation iterations for pruning +static const int DEFAULT_SIM_ITERATIONS = 1000; // Random simulation iterations for pruning struct SatClockgateWorker { @@ -120,8 +120,8 @@ struct SatClockgateWorker ezSatPtr ez; SatGen satgen; - // Random number generator for simulation - std::mt19937 rng; + // Cached simulation results: [iteration][SigBit] = evaluated State + std::vector> cached_sim_results; // Statistics int accepted_count = 0; @@ -131,7 +131,7 @@ struct SatClockgateWorker SatClockgateWorker(Module *module, int max_cover, int min_regs, int sim_iterations) : module(module), sigmap(module), max_cover(max_cover), min_regs(min_regs), sim_iterations(sim_iterations), - ez(), satgen(ez.get(), &sigmap), rng(42) + ez(), satgen(ez.get(), &sigmap) { // Build driver and sink maps for (auto cell : module->cells()) { @@ -156,16 +156,14 @@ struct SatClockgateWorker ID($dffsre), ID($_DFF_P_), ID($_DFF_N_), ID($_DFFE_PP_), ID($_DFFE_PN_), ID($_DFFE_NP_), ID($_DFFE_NN_))) satgen.importCell(cell); - } - - // Run simulation with random inputs using ConstEval - // Returns false if counterexample found (candidate is definitely invalid) - bool simulationTest(const std::vector &conds, SigSpec sig_d, SigSpec sig_q, bool as_enable) - { + + // Pre-run simulations and cache all signal values + std::mt19937 rng(42); + cached_sim_results.resize(sim_iterations); for (int iter = 0; iter < sim_iterations; iter++) { ConstEval ce(module); - // Generate random values for all input ports + // Set random values for input ports for (auto wire : module->wires()) { if (wire->port_input) { Const rand_val(State::S0, wire->width); @@ -174,8 +172,7 @@ struct SatClockgateWorker ce.set(SigSpec(wire), rand_val); } } - - // Randomize FF Q outputs (they're inputs to combinational logic) + // Set random values for FF Q outputs for (auto cell : module->cells()) { if (cell->type.in(ID($ff), ID($dff), ID($dffe), ID($adff), ID($adffe), ID($sdff), ID($sdffe), ID($sdffce), ID($dffsr), @@ -189,45 +186,55 @@ struct SatClockgateWorker } } - // Evaluate gating condition signals + // Evaluate and cache ALL wire signals + for (auto wire : module->wires()) { + for (int i = 0; i < wire->width; i++) { + SigBit bit(wire, i); + SigSpec sig(bit); + if (ce.eval(sig)) + cached_sim_results[iter][sigmap(bit)] = sig[0].data; + } + } + } + } + + // Check cached simulation results for counterexamples + // Returns false if counterexample found (candidate is definitely invalid) + bool simulationTest(const std::vector &conds, SigSpec sig_d, SigSpec sig_q, bool as_enable) + { + for (int iter = 0; iter < sim_iterations; iter++) { + auto &cache = cached_sim_results[iter]; + + // Lookup gating condition from cache bool combined_cond; if (as_enable) { combined_cond = false; for (auto bit : conds) { - SigSpec sig(bit); - if (ce.eval(sig)) { - if (sig[0] == State::S1) - combined_cond = true; - } + SigBit mapped = sigmap(bit); + if (cache.count(mapped) && cache[mapped] == State::S1) + combined_cond = true; } } else { combined_cond = true; for (auto bit : conds) { - SigSpec sig(bit); - if (ce.eval(sig)) { - if (sig[0] != State::S1) - combined_cond = false; - } else { + SigBit mapped = sigmap(bit); + if (!cache.count(mapped) || cache[mapped] != State::S1) combined_cond = false; - } } } bool gating_active = as_enable ? !combined_cond : combined_cond; - // Evaluate D and Q, check if D != Q - SigSpec d_eval = sig_d; - SigSpec q_eval = sig_q; - bool d_ok = ce.eval(d_eval); - bool q_ok = ce.eval(q_eval); - + // Lookup D and Q from cache, check if D != Q bool d_ne_q = false; - if (d_ok && q_ok) { - for (int i = 0; i < sig_d.size(); i++) { - if (d_eval[i] != q_eval[i]) { - d_ne_q = true; - break; - } + for (int i = 0; i < sig_d.size(); i++) { + SigBit d_bit = sigmap(sig_d[i]); + SigBit q_bit = sigmap(sig_q[i]); + State d_val = cache.count(d_bit) ? cache[d_bit] : State::S0; + State q_val = cache.count(q_bit) ? cache[q_bit] : State::S0; + if (d_val != q_val) { + d_ne_q = true; + break; } }