diff --git a/src/V3Randomize.cpp b/src/V3Randomize.cpp index fd1165341..25f13a704 100644 --- a/src/V3Randomize.cpp +++ b/src/V3Randomize.cpp @@ -3479,7 +3479,8 @@ class RandomizeVisitor final : public VNVisitor { } } - // Replace AstDist with weighted bucket selection via AstConstraintIf chain + // Replace AstDist with weighted bucket selection via AstConstraintIf chain. + // Supports both constant and variable weight expressions. void lowerDistConstraints(AstTask* taskp, AstNode* constrItemsp) { for (AstNode *nextip, *itemp = constrItemsp; itemp; itemp = nextip) { nextip = itemp->nextp(); @@ -3492,36 +3493,75 @@ class RandomizeVisitor final : public VNVisitor { struct BucketInfo final { AstNodeExpr* rangep; - uint64_t effectiveWeight; + AstNodeExpr* weightExprp; // Effective weight as AST expression }; std::vector buckets; - uint64_t totalWeight = 0; for (AstDistItem* ditemp = distp->itemsp(); ditemp; ditemp = VN_AS(ditemp->nextp(), DistItem)) { - const AstConst* const weightp = VN_CAST(ditemp->weightp(), Const); - if (!weightp) continue; - const uint64_t w = weightp->toUQuad(); - if (w == 0) continue; + // Skip compile-time zero weights + if (const AstConst* const constp = VN_CAST(ditemp->weightp(), Const)) { + if (constp->toUQuad() == 0) continue; + } - uint64_t effectiveW = w; - // := is per-value weight, so multiply by range size + // Clone and extend weight to 64-bit + AstNodeExpr* weightExprp + = new AstExtend{fl, ditemp->weightp()->cloneTreePure(false), 64}; + + // := is per-value weight; for ranges multiply by range size if (!ditemp->isWhole()) { if (const AstInsideRange* const irp = VN_CAST(ditemp->rangep(), InsideRange)) { const AstConst* const lop = VN_CAST(irp->lhsp(), Const); const AstConst* const hip = VN_CAST(irp->rhsp(), Const); - if (lop && hip && hip->toUQuad() >= lop->toUQuad()) - effectiveW = w * (hip->toUQuad() - lop->toUQuad() + 1); + AstNodeExpr* rangeSizep; + if (lop && hip) { + const uint64_t rangeSize = hip->toUQuad() - lop->toUQuad() + 1; + rangeSizep = new AstConst{fl, AstConst::Unsized64{}, rangeSize}; + } else { + // Variable range bounds: (hi - lo + 1) at runtime + rangeSizep = new AstAdd{ + fl, new AstConst{fl, AstConst::Unsized64{}, 1}, + new AstSub{ + fl, new AstExtend{fl, irp->rhsp()->cloneTreePure(false), 64}, + new AstExtend{fl, irp->lhsp()->cloneTreePure(false), 64}}}; + rangeSizep->dtypeSetUInt64(); + } + weightExprp = new AstMul{fl, weightExprp, rangeSizep}; + weightExprp->dtypeSetUInt64(); } } - buckets.push_back({ditemp->rangep(), effectiveW}); - totalWeight += effectiveW; + buckets.push_back({ditemp->rangep(), weightExprp}); } - if (buckets.empty() || totalWeight == 0) continue; + if (buckets.empty()) continue; - const std::string bucketName = "__Vdist_bucket" + cvtToStr(m_distNum++); + // Build totalWeight expression: w[0] + w[1] + ... + w[N-1] + AstNodeExpr* totalWeightExprp = nullptr; + for (auto& bucket : buckets) { + if (!totalWeightExprp) { + totalWeightExprp = bucket.weightExprp->cloneTreePure(false); + } else { + totalWeightExprp = new AstAdd{fl, totalWeightExprp, + bucket.weightExprp->cloneTreePure(false)}; + totalWeightExprp->dtypeSetUInt64(); + } + } + + // Store totalWeight in temp var (evaluated once, used twice) + const int distId = m_distNum++; + const std::string totalName = "__Vdist_total" + cvtToStr(distId); + AstVar* const totalVarp + = new AstVar{fl, VVarType::BLOCKTEMP, totalName, taskp->findUInt64DType()}; + totalVarp->noSubst(true); + totalVarp->lifetime(VLifetime::AUTOMATIC_EXPLICIT); + totalVarp->funcLocal(true); + taskp->addStmtsp(totalVarp); + taskp->addStmtsp( + new AstAssign{fl, new AstVarRef{fl, totalVarp, VAccess::WRITE}, totalWeightExprp}); + + // bucketVar = (rand64() % totalWeight) + 1 + const std::string bucketName = "__Vdist_bucket" + cvtToStr(distId); AstVar* const bucketVarp = new AstVar{fl, VVarType::BLOCKTEMP, bucketName, taskp->findUInt64DType()}; bucketVarp->noSubst(true); @@ -3529,23 +3569,31 @@ class RandomizeVisitor final : public VNVisitor { bucketVarp->funcLocal(true); taskp->addStmtsp(bucketVarp); - // bucketVar = (rand64() % totalWeight) + 1 AstNodeExpr* randp = new AstRand{fl, nullptr, false}; randp->dtypeSetUInt64(); taskp->addStmtsp(new AstAssign{ fl, new AstVarRef{fl, bucketVarp, VAccess::WRITE}, - new AstAdd{fl, new AstConst{fl, AstConst::Unsized64{}, 1}, - new AstModDiv{fl, randp, - new AstConst{fl, AstConst::Unsized64{}, totalWeight}}}}); + new AstAdd{ + fl, new AstConst{fl, AstConst::Unsized64{}, 1}, + new AstModDiv{fl, randp, new AstVarRef{fl, totalVarp, VAccess::READ}}}}); - // Build if/else chain keyed on cumulative weights + // Build cumulative sum expressions forward: cumSum[i] = w[0]+...+w[i] + std::vector cumSums; + AstNodeExpr* runningSump = nullptr; + for (size_t i = 0; i < buckets.size(); ++i) { + if (!runningSump) { + runningSump = buckets[i].weightExprp->cloneTreePure(false); + } else { + runningSump = new AstAdd{fl, runningSump, + buckets[i].weightExprp->cloneTreePure(false)}; + runningSump->dtypeSetUInt64(); + } + cumSums.push_back(runningSump->cloneTreePure(true)); + } + + // Build ConstraintIf chain backward (last bucket is unconditional default) AstNode* chainp = nullptr; - uint64_t cumWeight = totalWeight; - for (int i = static_cast(buckets.size()) - 1; i >= 0; --i) { - cumWeight -= buckets[i].effectiveWeight; - const uint64_t thisCumWeight = cumWeight + buckets[i].effectiveWeight; - AstNodeExpr* constraintExprp; if (const AstInsideRange* const irp = VN_CAST(buckets[i].rangep, InsideRange)) { AstNodeExpr* const exprCopy1p = distp->exprp()->cloneTreePure(false); @@ -3574,8 +3622,7 @@ class RandomizeVisitor final : public VNVisitor { chainp = thenp; } else { AstNodeExpr* const condp - = new AstLte{fl, new AstVarRef{fl, bucketVarp, VAccess::READ}, - new AstConst{fl, AstConst::Unsized64{}, thisCumWeight}}; + = new AstLte{fl, new AstVarRef{fl, bucketVarp, VAccess::READ}, cumSums[i]}; chainp = new AstConstraintIf{fl, condp, thenp, chainp}; } } diff --git a/src/V3Width.cpp b/src/V3Width.cpp index 614b58f48..de290ab6b 100644 --- a/src/V3Width.cpp +++ b/src/V3Width.cpp @@ -3088,22 +3088,10 @@ class WidthVisitor final : public VNVisitor { iterateCheck(nodep, "Dist Item", itemp, CONTEXT_DET, FINAL, subDTypep, EXTEND_EXP); } - // Keep AstDist for V3Randomize if all items have const weights and simple ranges - if (m_constraintp) { - bool canLower = true; - for (const AstDistItem* itemp = nodep->itemsp(); itemp; - itemp = VN_AS(itemp->nextp(), DistItem)) { - if (!VN_IS(itemp->weightp(), Const) - || (!VN_IS(itemp->rangep(), Const) && !VN_IS(itemp->rangep(), InsideRange))) { - canLower = false; - break; - } - } - if (canLower) return; - } + // Inside a constraint, V3Randomize handles dist lowering with proper weights + if (m_constraintp) return; - // Outside constraint or complex items: lower to inside (ignores weights) - nodep->v3warn(CONSTRAINTIGN, "Constraint expression ignored (imperfect distribution)"); + // Outside constraint: lower to inside (ignores weights) AstNodeExpr* newp = nullptr; for (AstDistItem* itemp = nodep->itemsp(); itemp; itemp = VN_AS(itemp->nextp(), DistItem)) { diff --git a/test_regress/t/t_constraint_dist_weight.v b/test_regress/t/t_constraint_dist_weight.v index 60ad8e450..93d0b8214 100644 --- a/test_regress/t/t_constraint_dist_weight.v +++ b/test_regress/t/t_constraint_dist_weight.v @@ -25,11 +25,25 @@ class DistZeroWeight; constraint c { x dist { 8'd0 := 0, 8'd1 := 1, 8'd2 := 1 }; } endclass +class DistVarWeight; + rand bit [7:0] x; + int w1, w2; + constraint c { x dist { 8'd0 := w1, 8'd255 := w2 }; } +endclass + +class DistVarWeightRange; + rand bit [7:0] x; + int w1, w2; + constraint c { x dist { [8'd0:8'd9] :/ w1, [8'd10:8'd19] :/ w2 }; } +endclass + module t; initial begin DistScalar sc; DistRange rg; DistZeroWeight zw; + DistVarWeight vw; + DistVarWeightRange vwr; int count_high; int count_range_high; int total; @@ -70,6 +84,33 @@ module t; `check_range(zw.x, 1, 2); end + // Variable := scalar weights: w1=1, w2=3 => expect ~75% for value 255 + vw = new; + vw.w1 = 1; + vw.w2 = 3; + count_high = 0; + repeat (total) begin + `checkd(vw.randomize(), 1); + if (vw.x == 8'd255) count_high++; + else `checkd(vw.x, 0); + end + `check_range(count_high, 1200, 1800); + + // Variable :/ range weights: w1=1, w2=3 => expect ~75% in [10:19] + vwr = new; + vwr.w1 = 1; + vwr.w2 = 3; + count_range_high = 0; + repeat (total) begin + `checkd(vwr.randomize(), 1); + if (vwr.x >= 8'd10 && vwr.x <= 8'd19) count_range_high++; + else if (vwr.x > 8'd9) begin + $write("%%Error: x=%0d outside valid range [0:19]\n", vwr.x); + `stop; + end + end + `check_range(count_range_high, 1200, 1800); + $write("*-* All Finished *-*\n"); $finish; end