Fix BitOpTree optimization to consider polarity of frozen node (#3445) (#3459)

* Tests: add a test to another failing case of #3445

* Consider polarity as lsb in BitOpTree optimization.
This commit is contained in:
Yutetsu TAKATSUKASA 2022-06-01 09:26:16 +09:00 committed by GitHub
parent 0c53d19113
commit d64f979f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 28 deletions

View File

@ -111,6 +111,15 @@ class ConstBitOpTreeVisitor final : public VNVisitor {
BitPolarityEntry() = default;
};
struct FrozenNodeInfo final { // Context when a frozen node is found
bool m_polarity;
int m_lsb;
bool operator<(const FrozenNodeInfo& other) const {
if (m_lsb != other.m_lsb) return m_lsb < other.m_lsb;
return m_polarity < other.m_polarity;
}
};
class Restorer final { // Restore the original state unless disableRestore() is called
ConstBitOpTreeVisitor& m_visitor;
const size_t m_polaritiesSize;
@ -299,8 +308,8 @@ class ConstBitOpTreeVisitor final : public VNVisitor {
LeafInfo* m_leafp = nullptr; // AstConst or AstVarRef that currently looking for
const AstNode* const m_rootp; // Root of this AST subtree
std::vector<std::pair<AstNode*, int>>
m_frozenNodes; // Nodes that cannot be optimized, int is lsb
std::vector<std::pair<AstNode*, FrozenNodeInfo>>
m_frozenNodes; // Nodes that cannot be optimized
std::vector<BitPolarityEntry> m_bitPolarities; // Polarity of bits found during iterate()
std::vector<std::unique_ptr<VarInfo>> m_varInfos; // VarInfo for each variable, [0] is nullptr
@ -488,7 +497,7 @@ class ConstBitOpTreeVisitor final : public VNVisitor {
restorer.restoreNow();
// Reach past a cast then add to frozen nodes to be added to final reduction
if (const AstCCast* const castp = VN_CAST(opp, CCast)) opp = castp->lhsp();
m_frozenNodes.emplace_back(opp, m_lsb);
m_frozenNodes.emplace_back(opp, FrozenNodeInfo{m_polarity, m_lsb});
m_failed = origFailed;
continue;
}
@ -653,21 +662,21 @@ public:
}
}
std::map<int, std::vector<AstNode*>> frozenNodes; // Group by LSB
std::map<FrozenNodeInfo, std::vector<AstNode*>> frozenNodes; // Group by FrozenNodeInfo
// Check if frozen terms are clean or not
for (const std::pair<AstNode*, int>& termAndLsb : visitor.m_frozenNodes) {
AstNode* const termp = termAndLsb.first;
for (const auto& frozenInfo : visitor.m_frozenNodes) {
AstNode* const termp = frozenInfo.first;
// Comparison operators are clean
if ((VN_IS(termp, Eq) || VN_IS(termp, Neq) || VN_IS(termp, Lt) || VN_IS(termp, Lte)
|| VN_IS(termp, Gt) || VN_IS(termp, Gte))
&& termAndLsb.second == 0) {
&& frozenInfo.second.m_lsb == 0) {
hasCleanTerm = true;
} else {
// Otherwise, conservatively assume the frozen term is dirty
hasDirtyTerm = true;
UINFO(9, "Dirty frozen term: " << termp << endl);
}
frozenNodes[termAndLsb.second].push_back(termp);
frozenNodes[frozenInfo.second].push_back(termp);
}
// Figure out if a final negation is required
@ -679,7 +688,8 @@ public:
// Add size of reduction tree to op count
resultOps += termps.size() - 1;
for (const auto& lsbAndNodes : frozenNodes) {
if (lsbAndNodes.first > 0) ++resultOps; // Needs AstShiftR
if (lsbAndNodes.first.m_lsb > 0) ++resultOps; // Needs AstShiftR
if (!lsbAndNodes.first.m_polarity) ++resultOps; // Needs AstNot
resultOps += lsbAndNodes.second.size();
}
// Add final polarity flip in Xor tree
@ -690,8 +700,9 @@ public:
if (debug() >= 9) { // LCOV_EXCL_START
cout << "Bitop tree considered: " << endl;
for (AstNode* const termp : termps) termp->dumpTree("Reduced term: ");
for (const std::pair<AstNode*, int>& termp : visitor.m_frozenNodes)
termp.first->dumpTree("Frozen term with lsb " + std::to_string(termp.second)
for (const std::pair<AstNode*, FrozenNodeInfo>& termp : visitor.m_frozenNodes)
termp.first->dumpTree("Frozen term with lsb " + std::to_string(termp.second.m_lsb)
+ " polarity " + std::to_string(termp.second.m_polarity)
+ ": ");
cout << "Needs flipping: " << needsFlip << endl;
cout << "Needs cleaning: " << needsCleaning << endl;
@ -735,16 +746,22 @@ public:
resultp = reduce(resultp, termp);
}
// Add any frozen terms to the reduction
for (auto&& lsbAndNodes : frozenNodes) {
for (auto&& nodes : frozenNodes) {
// nodes.second has same lsb and polarity
AstNode* termp = nullptr;
for (AstNode* const itemp : lsbAndNodes.second) {
for (AstNode* const itemp : nodes.second) {
termp = reduce(termp, itemp->unlinkFrBack());
}
if (lsbAndNodes.first > 0) { // LSB is not 0, so shiftR
if (nodes.first.m_lsb > 0) { // LSB is not 0, so shiftR
AstNodeDType* const dtypep = termp->dtypep();
termp = new AstShiftR{termp->fileline(), termp,
new AstConst(termp->fileline(), AstConst::WidthedValue{},
termp->width(), lsbAndNodes.first)};
termp->width(), nodes.first.m_lsb)};
termp->dtypep(dtypep);
}
if (!nodes.first.m_polarity) { // Polarity is inverted, so append Not
AstNodeDType* const dtypep = termp->dtypep();
termp = new AstNot{termp->fileline(), termp};
termp->dtypep(dtypep);
}
resultp = reduce(resultp, termp);

View File

@ -19,7 +19,7 @@ execute(
);
if ($Self->{vlt}) {
file_grep($Self->{stats}, qr/Optimizations, Const bit op reduction\s+(\d+)/i, 10);
file_grep($Self->{stats}, qr/Optimizations, Const bit op reduction\s+(\d+)/i, 11);
}
ok(1);
1;

View File

@ -4,6 +4,11 @@
// any use, without warranty, 2021 Yutetsu TAKATSUKASA.
// SPDX-License-Identifier: CC0-1.0
// This function always returns 0, so safe to take bitwise OR with any value.
// Calling this function stops constant folding as Verialtor does not know
// what this function returns.
import "DPI-C" context function int fake_dependency();
module t(/*AUTOARG*/
// Inputs
clk
@ -57,7 +62,7 @@ module t(/*AUTOARG*/
$write("[%0t] cyc==%0d crc=%x sum=%x\n", $time, cyc, crc, sum);
if (crc !== 64'hc77bb9b3784ea091) $stop;
// What checksum will we end up with (above print should match)
`define EXPECTED_SUM 64'h194081987b76c71c
`define EXPECTED_SUM 64'hdccb9e7b8b638233
if (sum !== `EXPECTED_SUM) $stop;
$write("*-* All Finished *-*\n");
@ -120,11 +125,6 @@ module bug3182(in, out);
input wire [4:0] in;
output wire out;
// This function always returns 0, so safe to take bitwise OR with any value.
// Calling this function stops constant folding as Verialtor does not know
// what this function returns.
import "DPI-C" context function int fake_dependency();
logic [4:0] bit_source;
/* verilator lint_off WIDTH */
@ -147,16 +147,18 @@ endmodule
// Bug #3445
// An unoptimized node is kept as frozen node, but its LSB were not saved.
// An unoptimized node is kept as frozen node, but its LSB and polarity were not saved.
// AST of RHS of result0 looks as below:
// AND(SHIFTR(AND(WORDSEL(ARRAYSEL(VARREF)), WORDSEL(ARRAYSEL(VARREF)))), 32'd11)
// ~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~
// Two of WORDSELs are frozen nodes. They are under SHIFTR of 11 bits.
//
// Fixing #3445 needs to
// 1. Take AstShiftR into op count when diciding optimizable or not
// (result0 in the test)
// 1. Take AstShiftR and AstNot into op count when diciding optimizable or not
// (result0 and result2 in the test)
// 2. Insert AstShiftR if LSB of the frozen node is not 0 (result1 in the test)
// 3. Insert AstNot if polarity of the frozen node is false (resutl3 in the
// test)
module bug3445(input wire clk, input wire [31:0] in, output wire out);
logic [127:0] d;
always_ff @(posedge clk)
@ -174,20 +176,30 @@ module bug3445(input wire clk, input wire [31:0] in, output wire out);
logic i;
logic [41:0] j;
} packed_struct;
packed_struct st[2];
packed_struct st[4];
// This is always 1'b0, but Verilator cannot notice it.
// This signal helps to reveal wrong optimization of result2 and result3.
logic zero;
always_ff @(posedge clk) begin
st[0] <= d;
st[1] <= st[0];
st[2] <= st[1];
st[3] <= st[2];
zero <= fake_dependency() > 0;
end
logic result0, result1;
logic result0, result1, result2, result3;
always_ff @(posedge clk) begin
// Cannot optimize further.
result0 <= (st[0].g[0] & st[0].h[0]) & (in[0] == 1'b0);
// There are redundant !in[0] terms. They should be simplified.
result1 <= (!in[0] & (st[1].g[0] & st[1].h[0])) & ((in[0] == 1'b0) & !in[0]);
// Cannot optimize further.
result2 <= !(st[2].g[0] & st[2].h[0]) & (zero == 1'b0);
// There are redundant zero terms. They should be simplified.
result3 <= (!zero & !(st[3].g[0] & st[3].h[0])) & ((zero == 1'b0) & !zero);
end
assign out = result0 ^ result1;
assign out = result0 ^ result1 ^ (result2 | result3);
endmodule