support dist weight as a variable case
This commit is contained in:
parent
2acc674a7a
commit
d6946e0100
|
|
@ -3479,7 +3479,8 @@ class RandomizeVisitor final : public VNVisitor {
|
|||
}
|
||||
}
|
||||
|
||||
// Replace AstDist with weighted bucket selection via AstConstraintIf chain
|
||||
// Replace AstDist with weighted bucket selection via AstConstraintIf chain.
|
||||
// Supports both constant and variable weight expressions.
|
||||
void lowerDistConstraints(AstTask* taskp, AstNode* constrItemsp) {
|
||||
for (AstNode *nextip, *itemp = constrItemsp; itemp; itemp = nextip) {
|
||||
nextip = itemp->nextp();
|
||||
|
|
@ -3492,36 +3493,75 @@ class RandomizeVisitor final : public VNVisitor {
|
|||
|
||||
struct BucketInfo final {
|
||||
AstNodeExpr* rangep;
|
||||
uint64_t effectiveWeight;
|
||||
AstNodeExpr* weightExprp; // Effective weight as AST expression
|
||||
};
|
||||
std::vector<BucketInfo> buckets;
|
||||
uint64_t totalWeight = 0;
|
||||
|
||||
for (AstDistItem* ditemp = distp->itemsp(); ditemp;
|
||||
ditemp = VN_AS(ditemp->nextp(), DistItem)) {
|
||||
const AstConst* const weightp = VN_CAST(ditemp->weightp(), Const);
|
||||
if (!weightp) continue;
|
||||
const uint64_t w = weightp->toUQuad();
|
||||
if (w == 0) continue;
|
||||
// Skip compile-time zero weights
|
||||
if (const AstConst* const constp = VN_CAST(ditemp->weightp(), Const)) {
|
||||
if (constp->toUQuad() == 0) continue;
|
||||
}
|
||||
|
||||
uint64_t effectiveW = w;
|
||||
// := is per-value weight, so multiply by range size
|
||||
// Clone and extend weight to 64-bit
|
||||
AstNodeExpr* weightExprp
|
||||
= new AstExtend{fl, ditemp->weightp()->cloneTreePure(false), 64};
|
||||
|
||||
// := is per-value weight; for ranges multiply by range size
|
||||
if (!ditemp->isWhole()) {
|
||||
if (const AstInsideRange* const irp = VN_CAST(ditemp->rangep(), InsideRange)) {
|
||||
const AstConst* const lop = VN_CAST(irp->lhsp(), Const);
|
||||
const AstConst* const hip = VN_CAST(irp->rhsp(), Const);
|
||||
if (lop && hip && hip->toUQuad() >= lop->toUQuad())
|
||||
effectiveW = w * (hip->toUQuad() - lop->toUQuad() + 1);
|
||||
AstNodeExpr* rangeSizep;
|
||||
if (lop && hip) {
|
||||
const uint64_t rangeSize = hip->toUQuad() - lop->toUQuad() + 1;
|
||||
rangeSizep = new AstConst{fl, AstConst::Unsized64{}, rangeSize};
|
||||
} else {
|
||||
// Variable range bounds: (hi - lo + 1) at runtime
|
||||
rangeSizep = new AstAdd{
|
||||
fl, new AstConst{fl, AstConst::Unsized64{}, 1},
|
||||
new AstSub{
|
||||
fl, new AstExtend{fl, irp->rhsp()->cloneTreePure(false), 64},
|
||||
new AstExtend{fl, irp->lhsp()->cloneTreePure(false), 64}}};
|
||||
rangeSizep->dtypeSetUInt64();
|
||||
}
|
||||
weightExprp = new AstMul{fl, weightExprp, rangeSizep};
|
||||
weightExprp->dtypeSetUInt64();
|
||||
}
|
||||
}
|
||||
|
||||
buckets.push_back({ditemp->rangep(), effectiveW});
|
||||
totalWeight += effectiveW;
|
||||
buckets.push_back({ditemp->rangep(), weightExprp});
|
||||
}
|
||||
|
||||
if (buckets.empty() || totalWeight == 0) continue;
|
||||
if (buckets.empty()) continue;
|
||||
|
||||
const std::string bucketName = "__Vdist_bucket" + cvtToStr(m_distNum++);
|
||||
// Build totalWeight expression: w[0] + w[1] + ... + w[N-1]
|
||||
AstNodeExpr* totalWeightExprp = nullptr;
|
||||
for (auto& bucket : buckets) {
|
||||
if (!totalWeightExprp) {
|
||||
totalWeightExprp = bucket.weightExprp->cloneTreePure(false);
|
||||
} else {
|
||||
totalWeightExprp = new AstAdd{fl, totalWeightExprp,
|
||||
bucket.weightExprp->cloneTreePure(false)};
|
||||
totalWeightExprp->dtypeSetUInt64();
|
||||
}
|
||||
}
|
||||
|
||||
// Store totalWeight in temp var (evaluated once, used twice)
|
||||
const int distId = m_distNum++;
|
||||
const std::string totalName = "__Vdist_total" + cvtToStr(distId);
|
||||
AstVar* const totalVarp
|
||||
= new AstVar{fl, VVarType::BLOCKTEMP, totalName, taskp->findUInt64DType()};
|
||||
totalVarp->noSubst(true);
|
||||
totalVarp->lifetime(VLifetime::AUTOMATIC_EXPLICIT);
|
||||
totalVarp->funcLocal(true);
|
||||
taskp->addStmtsp(totalVarp);
|
||||
taskp->addStmtsp(
|
||||
new AstAssign{fl, new AstVarRef{fl, totalVarp, VAccess::WRITE}, totalWeightExprp});
|
||||
|
||||
// bucketVar = (rand64() % totalWeight) + 1
|
||||
const std::string bucketName = "__Vdist_bucket" + cvtToStr(distId);
|
||||
AstVar* const bucketVarp
|
||||
= new AstVar{fl, VVarType::BLOCKTEMP, bucketName, taskp->findUInt64DType()};
|
||||
bucketVarp->noSubst(true);
|
||||
|
|
@ -3529,23 +3569,31 @@ class RandomizeVisitor final : public VNVisitor {
|
|||
bucketVarp->funcLocal(true);
|
||||
taskp->addStmtsp(bucketVarp);
|
||||
|
||||
// bucketVar = (rand64() % totalWeight) + 1
|
||||
AstNodeExpr* randp = new AstRand{fl, nullptr, false};
|
||||
randp->dtypeSetUInt64();
|
||||
taskp->addStmtsp(new AstAssign{
|
||||
fl, new AstVarRef{fl, bucketVarp, VAccess::WRITE},
|
||||
new AstAdd{fl, new AstConst{fl, AstConst::Unsized64{}, 1},
|
||||
new AstModDiv{fl, randp,
|
||||
new AstConst{fl, AstConst::Unsized64{}, totalWeight}}}});
|
||||
new AstAdd{
|
||||
fl, new AstConst{fl, AstConst::Unsized64{}, 1},
|
||||
new AstModDiv{fl, randp, new AstVarRef{fl, totalVarp, VAccess::READ}}}});
|
||||
|
||||
// Build if/else chain keyed on cumulative weights
|
||||
// Build cumulative sum expressions forward: cumSum[i] = w[0]+...+w[i]
|
||||
std::vector<AstNodeExpr*> cumSums;
|
||||
AstNodeExpr* runningSump = nullptr;
|
||||
for (size_t i = 0; i < buckets.size(); ++i) {
|
||||
if (!runningSump) {
|
||||
runningSump = buckets[i].weightExprp->cloneTreePure(false);
|
||||
} else {
|
||||
runningSump = new AstAdd{fl, runningSump,
|
||||
buckets[i].weightExprp->cloneTreePure(false)};
|
||||
runningSump->dtypeSetUInt64();
|
||||
}
|
||||
cumSums.push_back(runningSump->cloneTreePure(true));
|
||||
}
|
||||
|
||||
// Build ConstraintIf chain backward (last bucket is unconditional default)
|
||||
AstNode* chainp = nullptr;
|
||||
uint64_t cumWeight = totalWeight;
|
||||
|
||||
for (int i = static_cast<int>(buckets.size()) - 1; i >= 0; --i) {
|
||||
cumWeight -= buckets[i].effectiveWeight;
|
||||
const uint64_t thisCumWeight = cumWeight + buckets[i].effectiveWeight;
|
||||
|
||||
AstNodeExpr* constraintExprp;
|
||||
if (const AstInsideRange* const irp = VN_CAST(buckets[i].rangep, InsideRange)) {
|
||||
AstNodeExpr* const exprCopy1p = distp->exprp()->cloneTreePure(false);
|
||||
|
|
@ -3574,8 +3622,7 @@ class RandomizeVisitor final : public VNVisitor {
|
|||
chainp = thenp;
|
||||
} else {
|
||||
AstNodeExpr* const condp
|
||||
= new AstLte{fl, new AstVarRef{fl, bucketVarp, VAccess::READ},
|
||||
new AstConst{fl, AstConst::Unsized64{}, thisCumWeight}};
|
||||
= new AstLte{fl, new AstVarRef{fl, bucketVarp, VAccess::READ}, cumSums[i]};
|
||||
chainp = new AstConstraintIf{fl, condp, thenp, chainp};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3088,22 +3088,10 @@ class WidthVisitor final : public VNVisitor {
|
|||
iterateCheck(nodep, "Dist Item", itemp, CONTEXT_DET, FINAL, subDTypep, EXTEND_EXP);
|
||||
}
|
||||
|
||||
// Keep AstDist for V3Randomize if all items have const weights and simple ranges
|
||||
if (m_constraintp) {
|
||||
bool canLower = true;
|
||||
for (const AstDistItem* itemp = nodep->itemsp(); itemp;
|
||||
itemp = VN_AS(itemp->nextp(), DistItem)) {
|
||||
if (!VN_IS(itemp->weightp(), Const)
|
||||
|| (!VN_IS(itemp->rangep(), Const) && !VN_IS(itemp->rangep(), InsideRange))) {
|
||||
canLower = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (canLower) return;
|
||||
}
|
||||
// Inside a constraint, V3Randomize handles dist lowering with proper weights
|
||||
if (m_constraintp) return;
|
||||
|
||||
// Outside constraint or complex items: lower to inside (ignores weights)
|
||||
nodep->v3warn(CONSTRAINTIGN, "Constraint expression ignored (imperfect distribution)");
|
||||
// Outside constraint: lower to inside (ignores weights)
|
||||
AstNodeExpr* newp = nullptr;
|
||||
for (AstDistItem* itemp = nodep->itemsp(); itemp;
|
||||
itemp = VN_AS(itemp->nextp(), DistItem)) {
|
||||
|
|
|
|||
|
|
@ -25,11 +25,25 @@ class DistZeroWeight;
|
|||
constraint c { x dist { 8'd0 := 0, 8'd1 := 1, 8'd2 := 1 }; }
|
||||
endclass
|
||||
|
||||
class DistVarWeight;
|
||||
rand bit [7:0] x;
|
||||
int w1, w2;
|
||||
constraint c { x dist { 8'd0 := w1, 8'd255 := w2 }; }
|
||||
endclass
|
||||
|
||||
class DistVarWeightRange;
|
||||
rand bit [7:0] x;
|
||||
int w1, w2;
|
||||
constraint c { x dist { [8'd0:8'd9] :/ w1, [8'd10:8'd19] :/ w2 }; }
|
||||
endclass
|
||||
|
||||
module t;
|
||||
initial begin
|
||||
DistScalar sc;
|
||||
DistRange rg;
|
||||
DistZeroWeight zw;
|
||||
DistVarWeight vw;
|
||||
DistVarWeightRange vwr;
|
||||
int count_high;
|
||||
int count_range_high;
|
||||
int total;
|
||||
|
|
@ -70,6 +84,33 @@ module t;
|
|||
`check_range(zw.x, 1, 2);
|
||||
end
|
||||
|
||||
// Variable := scalar weights: w1=1, w2=3 => expect ~75% for value 255
|
||||
vw = new;
|
||||
vw.w1 = 1;
|
||||
vw.w2 = 3;
|
||||
count_high = 0;
|
||||
repeat (total) begin
|
||||
`checkd(vw.randomize(), 1);
|
||||
if (vw.x == 8'd255) count_high++;
|
||||
else `checkd(vw.x, 0);
|
||||
end
|
||||
`check_range(count_high, 1200, 1800);
|
||||
|
||||
// Variable :/ range weights: w1=1, w2=3 => expect ~75% in [10:19]
|
||||
vwr = new;
|
||||
vwr.w1 = 1;
|
||||
vwr.w2 = 3;
|
||||
count_range_high = 0;
|
||||
repeat (total) begin
|
||||
`checkd(vwr.randomize(), 1);
|
||||
if (vwr.x >= 8'd10 && vwr.x <= 8'd19) count_range_high++;
|
||||
else if (vwr.x > 8'd9) begin
|
||||
$write("%%Error: x=%0d outside valid range [0:19]\n", vwr.x);
|
||||
`stop;
|
||||
end
|
||||
end
|
||||
`check_range(count_range_high, 1200, 1800);
|
||||
|
||||
$write("*-* All Finished *-*\n");
|
||||
$finish;
|
||||
end
|
||||
|
|
|
|||
Loading…
Reference in New Issue