Alternative solver for the fuzzing problem.

Signed-off-by: Maciej Kurc <mkurc@antmicro.com>
This commit is contained in:
Maciej Kurc 2019-07-22 10:55:40 +02:00
parent 1752fce3b7
commit 92cc9b5875
1 changed files with 769 additions and 0 deletions

769
prjxray/lms_solver.py Executable file
View File

@ -0,0 +1,769 @@
#!/usr/bin/env python3
'''
This script solves the fuzzing problem through least-mean-square solution of
an overdetermined linear equation system.
The advantages of this method are:
- Ability to detect negative correlations (tags which require clearing bits)
- Can detect partial correlation tag <-> bit. This happens if for a small
number of specimens a tag is said to be "1" but in fact it is not due to
the way Vivado interprets requested features and encodes them into bitstream.
- Ease to detect tags with no corresponding bits by evaluating solution error.
The solution is computed using the Tikhonov regularization scheme to ensure
numerical stability. The parameter -a can be used to vary the regularization
factor.
By default each tag is solved separately (best results) while they can be
solved all at once (not recommended).
For each tag a vector of weights is calculated. Each weight corresponds to one
bit. Positive values indicate positive correlation and negative values negative
correlation.
Each weight vector is normalized so that maximum absolute weight is equal to
one.
The parameter -t is used to set threshold for those weights. Weights with
values above the threshold and below the "minus" threshold are output as
candidate bits.
For each weight vector a solution error is computed. If the error exceeds
threshold specified using the -e parameter then the tag is considered to
have no bits.
The option -m can be used to filter bits found for the specified tag in all
other tags. This allows removing bits of an "IS_BLOCK_IN_USE" type tag from
other tags responsible for enabling other features of that block.
'''
import sys
import argparse
import itertools
import numpy as np
import numpy.linalg as linalg
# =============================================================================
def load_data(file_name, tagfilter=lambda tag: True):
    """
    Loads data generated by the segmaker.

    The file is read line by line. A line starting with "seg" opens a new
    segment, "bit" lines add a bit name to the current segment and "tag" lines
    add a (name, value) pair to it. Lines that appear before the first "seg"
    line are ignored. Segments that end up with no accepted tags are dropped.

    Parameters
    ----------
    file_name:
        Name of the text file with data.
    tagfilter:
        A function for filtering tags. Should return True or False. It is
        called with the tag name; tags for which it returns False are skipped.

    Returns
    -------
    A list of dicts. Each contains:
    - "seg": Segment name
    - "bit": A list of bit names
    - "tag": A list of tuples (tag name, tag value)
    """
    all_segdata = []
    segdata = None

    def store(seg):
        # A segment without any accepted tag is not kept
        if seg is not None and seg["tag"]:
            all_segdata.append(seg)

    with open(file_name, "r") as fp:
        # Iterate the file lazily instead of loading it whole via readlines()
        for line in fp:
            line = line.strip()

            # Segment tag: close the previous segment and open a new one
            if line.startswith("seg"):
                fields = line.split()
                store(segdata)
                segdata = {"seg": fields[1], "bit": [], "tag": []}
                continue

            # Nothing to attach bits/tags to yet
            if segdata is None:
                continue

            # Bit tag
            if line.startswith("bit"):
                segdata["bit"].append(line.split()[1])

            # Tag tag
            elif line.startswith("tag"):
                fields = line.split()
                if tagfilter(fields[1]):
                    segdata["tag"].append((fields[1], int(fields[2])))

    # Store the last segment if any
    store(segdata)

    return all_segdata
def write_segbits(file_name, all_tags, all_bits, W):
    """
    Writes solution to a raw database file.

    One line is written per row of W: the tag name followed by its bits.
    A bit with a negative entry is prefixed with "!", a bit with a positive
    entry is written as-is and zero entries are omitted. A tag with no bits
    gets the "<0 candidates>" placeholder.

    Parameters
    ----------
    file_name:
        Name of the .rdb file.
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    W:
        Matrix with binary solution.
    """
    tag_count, bit_count = W.shape

    # Build the whole content first, then write it out
    records = []
    for row in range(tag_count):
        names = []
        for col in range(bit_count):
            weight = W[row, col]
            if weight < 0:
                names.append("!" + all_bits[col])
            elif weight > 0:
                names.append(all_bits[col])

        if not names:
            names.append("<0 candidates>")

        records.append(all_tags[row] + " " + " ".join(names) + "\n")

    with open(file_name, "w") as fp:
        fp.write("".join(records))
def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None):
    """
    Dumps solution results to an open file in a nice readable format.

    The output starts with a header in which bit names are written
    vertically (one character per line, "_" shown as "|"). Then one line
    follows per tag with "1" / "0" / "-" for positive / negative / zero
    entries of W, optional tag statistics, the minimum and maximum of the
    corresponding row of X and the solution error. Bits whose column of W is
    all zeros are not shown. Tags whose row of W is all zeros are marked
    with "(!)".

    Parameters
    ----------
    fp:
        An open file or stream
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    W:
        Matrix with binary solution.
    X:
        Matrix with raw solution (floats).
    E:
        Vector with solution errors.
    tag_stats:
        Tag statistics. A dict indexed by tag name holding a pair of counts
        (printed under the "#0" and "#1" header).
    """
    lines = []

    # default=0 so that an empty tag list does not raise ValueError
    pad_len = max((len(tag) for tag in all_tags), default=0)

    # Columns that have a candidate in at least one tag
    visible = [
        c for c in range(len(all_bits)) if not (W[:, c] == 0).all()
    ]

    # Header height: at least 6 lines, more if a shown bit name is longer so
    # that no name gets truncated.
    bit_len = max([6] + [len(all_bits[c]) for c in visible])

    # Bit names
    names = [all_bits[c].ljust(bit_len).replace("_", "|") for c in visible]
    prefix = " " * (pad_len + 2 + 3)
    for i in range(bit_len):
        line = prefix + "".join(name[i] for name in names)

        if i == (bit_len - 1) and tag_stats is not None:
            line += " #0 #1 "

        lines.append(line)

    # Tags and bit values
    for r in range(W.shape[0]):
        line = all_tags[r].ljust(pad_len + 1)

        # Both variants are 4 characters wide so that bit columns stay
        # aligned with the header (which is offset by pad_len + 5).
        if (W[r, :] == 0).all():
            line += "(!) "
        else:
            line += "    "

        for c in visible:
            b = W[r, c]
            if b < 0:
                line += "0"
            elif b > 0:
                line += "1"
            else:
                line += "-"

        if tag_stats is not None:
            stat = tag_stats[all_tags[r]]
            line += " %4d|%4d" % stat

        x_min = np.min(X[r, :])
        x_max = np.max(X[r, :])
        line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r])

        lines.append(line)

    # Write
    fp.writelines(line + "\n" for line in lines)
def dump_solution_to_csv(fp, all_tags, all_bits, X):
    """
    Dumps solution data to CSV.

    The first row has an empty leading cell followed by the bit names. Each
    following row holds a tag name and the values of the matching row of X
    formatted with "%+e".

    Parameters
    ----------
    fp:
        An open file or stream
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    X:
        Matrix with raw solution (floats).
    """
    column_count = X.shape[1]

    # Bits (the empty first cell sits above the tag name column)
    header_cells = [""]
    header_cells.extend(all_bits)
    fp.write(",".join(header_cells) + "\n")

    # Tags + numbers
    for row, tag in enumerate(all_tags):
        cells = [tag]
        cells.extend("%+e" % X[row, col] for col in range(column_count))
        fp.write(",".join(cells) + "\n")
# =============================================================================
def build_matrices(all_tags, all_bits, segdata, bias=0.0):
    """
    Builds matrices for the linear equation system to be solved.

    Each segdata entry yields one row. In matrix A a bit present in the
    segment is encoded as +1.0 and an absent one as -1.0. In matrix B a tag
    with a value greater than zero is encoded as +1.0 and any other value as
    -1.0, with the bias added. A tag that does not occur in the segment
    leaves 0.0 in B (no bias is added there).

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to every +/-1.0 tag value written to B.

    Returns
    -------
    Tuple with:
    - Matrix A of shape (len(segdata), len(all_bits))
    - Matrix B of shape (len(segdata), len(all_tags))
    """
    M = len(segdata)
    N = len(all_bits)
    K = len(all_tags)

    A = np.zeros((M, N), dtype=np.float64)
    B = np.zeros((M, K), dtype=np.float64)

    for r, seg in enumerate(segdata):
        # A matrix. A set gives O(1) membership tests instead of scanning
        # the segment bit list for every column.
        seg_bits = set(seg["bit"])
        for c, bit in enumerate(all_bits):
            A[r, c] = +1.0 if bit in seg_bits else -1.0

        # B matrix. dict() keeps the last value of a repeated tag, same as
        # scanning the list in order and overwriting.
        seg_tags = dict(seg["tag"])
        for c, tag in enumerate(all_tags):
            if tag in seg_tags:
                v = +1.0 if seg_tags[tag] > 0 else -1.0
                B[r, c] = v + bias

    return A, B
def compute_error(A, B, X):
    """
    Computes solution error.

    For each column k of B the error is the square root of the sum of squared
    differences between column k of A*X and column k of B. The sum is not
    divided by the number of rows.

    Parameters
    ----------
    A:
        Matrix A
    B:
        Matrix B
    X:
        Matrix with computed solution.

    Returns
    -------
    A vector with errors
    """
    predicted = np.matmul(A, X)

    # One error value per column of B
    errors = [
        np.sqrt(np.sum(np.square(predicted[:, col] - B[:, col])))
        for col in range(B.shape[1])
    ]

    return np.array(errors, dtype=np.float64)
# =============================================================================
def solve_lms(all_tags, all_bits, segdata, bias=0.0):
    """
    Solves using direct least square solution (NumPy)

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Passed to build_matrices().

    Returns
    -------
    Tuple with:
    - Solution matrix X
    - Error vector.
    """
    system = build_matrices(all_tags, all_bits, segdata, bias)
    A = system[0]
    B = system[1]

    # lstsq also returns residuals, rank and singular values. Only the
    # solution (first element) is used here.
    X = linalg.lstsq(A, B, rcond=None)[0]

    return X, compute_error(A, B, X)
def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0):
    """
    Solves using Tichonov regularization method.

    Computes X from the regularized normal equations
    (A^T A + a I) X = A^T B.

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Passed to build_matrices().
    a:
        Regularization coefficient.

    Returns
    -------
    Tuple with:
    - Solution matrix X
    - Error vector.
    """
    N = len(all_bits)

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Tikhonov regularization
    # https://en.wikipedia.org/wiki/Tikhonov_regularization
    AtA = np.matmul(A.T, A)
    AtB = np.matmul(A.T, B)

    # Solve the linear system directly rather than forming the explicit
    # inverse and multiplying; same solution, fewer rounding errors.
    X = np.linalg.solve(AtA + a * np.eye(N), AtB)

    return X, compute_error(A, B, X)
# =============================================================================
def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw):
    """
    Solves each tag separately in one-by-one fashion.

    For every tag only the segdata entries that mention that tag are handed
    to the solver. The tag name and the number of such entries are printed.

    Parameters
    ----------
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    solver:
        Solver function.
    **kw:
        Parameters to solver function.

    Returns
    -------
    Tuple with:
    - Solution matrix X
    - Error vector.
    """
    tag_count = len(all_tags)

    X = np.empty((len(all_bits), tag_count))
    E = np.empty(tag_count)

    for column, tag in enumerate(all_tags):
        # Keep only segments in which the tag occurs
        subset = []
        for data in segdata:
            if any(entry[0] == tag for entry in data["tag"]):
                subset.append(data)

        print("%s #%d" % (tag, len(subset)))

        # Single-tag solve: the result has exactly one column / one error
        solution, error = solver([tag], all_bits, subset, **kw)
        X[:, column] = solution[:, 0]
        E[column] = error[0]

    return X, E
# =============================================================================
def detect_candidates(X, th, norm=None):
    """
    Detects candidate bits.

    Works on a floating point copy of the transposed X. With norm="max_abs"
    each row of the copy is divided by its maximum absolute value before
    thresholding. Entries above +th are marked +1, entries below -th are
    marked -1 and all others are left at 0.

    Parameters
    ----------
    X:
        Matrix with solution
    th:
        Threshold
    norm:
        Normalization scheme. Only "max_abs" is recognized; any other value
        means no normalization.

    Returns
    -------
    A tuple with:
    - Binary solution matrix W
    - Transposed matrix X (the input values, not the normalized copy)
    """
    # Float copy so that in-place division also works for integer input and
    # the caller's matrix is left untouched.
    Xt = np.array(X.T, dtype=np.float64)
    W = np.zeros_like(Xt, dtype=int)

    # Xt.size guard: np.max cannot reduce over an empty axis
    if norm == "max_abs" and Xt.size:
        Nv = np.max(np.abs(Xt), axis=1)
        # An all-zero row would give 0/0 (NaN plus a RuntimeWarning). Such a
        # row has no candidates either way, so divide it by 1 instead.
        Nv[Nv == 0.0] = 1.0
        Xt /= Nv[:, None]

    W[Xt < -th] = -1
    W[Xt > +th] = +1

    return W, X.T
# =============================================================================
def compute_tag_stats(all_tags, segdata):
    """
    Counts occurrence of all considered tags

    A tag value greater than zero counts as a "1" occurrence, any other
    value counts as a "0" occurrence. Tags in segdata that are not listed in
    all_tags are ignored.

    Parameters
    ----------
    all_tags:
        Considered tags
    segdata:
        List of segdata used

    Returns
    -------
    A dict indexed by tag name with tuples containing 0 and 1 occurrence count.
    """
    # Every considered tag gets an entry, even if it never occurs
    counts = {tag: [0, 0] for tag in all_tags}

    # Single pass over the data instead of rescanning it once per tag
    for data in segdata:
        for name, value in data["tag"]:
            entry = counts.get(name)
            if entry is not None:
                entry[1 if value > 0 else 0] += 1

    return {tag: (entry[0], entry[1]) for tag, entry in counts.items()}
def sort_bits(bit_name):
    """
    Utility function for sorting bits.

    Splits a "<number>_<number>" bit name and returns both parts as a tuple
    of integers so that bits sort numerically rather than as text.
    """
    parts = bit_name.split("_")
    frame, offset = parts

    key = (int(frame), int(offset))
    return key
# =============================================================================
def main():
    """
    The main.

    Parses command line arguments, loads segmaker data from all input files,
    removes constant tags and always-set bits from the problem, solves it
    (all tags at once with --all, otherwise tag by tag), thresholds the
    solution into candidate bits, applies the optional masking and 0/1
    rejection, rejects tags with too large solution error and finally writes
    the database file, the optional CSV file and a readable dump to stdout.

    Exits with code -1 when there is no data, no tags or no bits.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument(
        "files",
        nargs="*",
        type=str,
        help="Input file(s) generated by segmaker")
    parser.add_argument(
        "-o",
        type=str,
        default="segbits.rdb",
        help="Output database file (def. segbits.rdb)")
    parser.add_argument(
        "-f",
        type=str,
        default=None,
        help="Tag filter. Processes only tags containing the specified text")
    parser.add_argument(
        "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)")
    # Tags are rejected when their error is ABOVE this threshold (see the
    # rejection loop below), so the help text says "above".
    parser.add_argument(
        "-e",
        type=float,
        default=0.1,
        help="RMS error threshold above which a tag is rejected (def. 0.1)")
    parser.add_argument(
        "-a",
        type=float,
        default=0.01,
        help="Regularization coefficient (def. 0.01)")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Solve all tags at once (may give worse results)")
    parser.add_argument(
        "-x",
        type=str,
        default=None,
        help="A CSV file name to Write the numerical solution to")
    parser.add_argument(
        "-m",
        type=str,
        default=None,
        help="Mask bits found for this feature in all other features")
    parser.add_argument("-b", type=float, default=0.0, help="Bias")
    parser.add_argument("-no_0", action="store_true", help="Do not output 0s")
    parser.add_argument("-no_1", action="store_true", help="Do not output 1s")

    args = parser.parse_args()

    # Compute threshold
    th = args.t

    # Load and filter segdata
    def tagfilter(tag):
        if args.f is None:
            return True
        return args.f in tag

    segdata = []
    for name in args.files:
        print(name)
        segdata.extend(load_data(name, tagfilter))

    # Make list of all bits
    all_bits = set()
    for seg in segdata:
        all_bits |= set(seg["bit"])
    all_bits = sorted(all_bits, key=sort_bits)

    # Detect bits that are always set
    const1_bits = set(all_bits)
    for seg in segdata:
        const1_bits &= set(seg["bit"])

    # Make list of all tags
    all_tags = set()
    for seg in segdata:
        all_tags |= {tag[0] for tag in seg["tag"]}
    all_tags = sorted(all_tags)

    # Count 0s and 1s for each tag
    tag_count = {}
    for seg in segdata:
        for tag, val in seg["tag"]:
            counts = tag_count.setdefault(tag, [0, 0])
            if val > 0:
                counts[1] += 1
            else:
                counts[0] += 1

    # Identify const0 and const1 tags
    const_tags = {}
    for tag in all_tags:
        if tag_count[tag][0] == 0:
            const_tags[tag] = 1
        if tag_count[tag][1] == 0:
            const_tags[tag] = 0

    const0_tags = [t for t, v in const_tags.items() if v == 0]
    const1_tags = [t for t, v in const_tags.items() if v == 1]

    # Print config
    print("# segs:", len(segdata))
    print("# tags:", len(all_tags))
    print("# bits:", len(all_bits))
    print("threshold: %.2f" % th)

    # sys.exit() instead of the interactive-only exit() helper from "site"
    if not segdata:
        print("No data!")
        sys.exit(-1)
    if not all_tags:
        print("No tags!")
        sys.exit(-1)
    if not all_bits:
        print("No bits!")
        sys.exit(-1)

    if const1_bits:
        # Sorted so that the message does not depend on set iteration order
        print("const 1 bits: " + ", ".join(sorted(const1_bits, key=sort_bits)))
    if const0_tags:
        print("const 0 tags: " + ", ".join(const0_tags))
    if const1_tags:
        print("const 1 tags: " + ", ".join(const1_tags))

    # Data to solve: constant tags and always-set bits are left out
    tags_to_solve = [tag for tag in all_tags if tag not in const_tags]
    bits_to_solve = [bit for bit in all_bits if bit not in const1_bits]

    # Statistics
    tag_stats = compute_tag_stats(tags_to_solve, segdata)

    # Solve
    print("Solving...")

    if args.all:
        X, E = solve_tichonov(
            tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a)
    else:
        X, E = solve_onebyone(
            tags_to_solve,
            bits_to_solve,
            segdata,
            solver=solve_tichonov,
            bias=args.b,
            a=args.a)

    # Detect candidate bits
    W, X = detect_candidates(X, th, norm="max_abs")

    # Mask
    if args.m is not None:
        print("Masking out %s" % args.m)

        tags = [t for t in tags_to_solve if args.m in t]
        for tag in tags:
            i = tags_to_solve.index(tag)
            for r in range(len(tags_to_solve)):
                if r == i:
                    continue
                # Clear every entry of row r that equals the masking tag's
                # entry in the same column.
                same = W[r, :] == W[i, :]
                W[r, same] = 0

    # Reject 0s and/or 1s
    if args.no_0:
        W[W < 0] = 0
    if args.no_1:
        W[W > 0] = 0

    # Reject tags with error greater than threshold
    for r in range(X.shape[0]):
        if E[r] > args.e:
            W[r, :] = 0

    # Write segbits
    write_segbits(args.o, tags_to_solve, bits_to_solve, W)

    # Dump to CSV
    if args.x is not None:
        with open(args.x, "w") as fp:
            dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X)

    # Dump
    dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats)
# =============================================================================
# Run the solver only when this file is executed as a script
if __name__ == "__main__":
    main()