#!/usr/bin/env python3
'''
This script solves the fuzzing problem through least-mean-square solution of
an overdetermined linear equation system.

The advantages of this method are:
- Ability to detect negative correlations (tags which require clearing bits)
- Can detect partial correlation tag <-> bit. This happens if for a small
  number of specimens a tag is said to be "1" but in fact it is not due to
  the way Vivado interprets requested features and encodes them into bitstream.
- Ease to detect tags with no corresponding bits by evaluating solution error.

The solution is computed using the Tikhonov regularization scheme to ensure
numerical stability. The parameter -a can be used to vary the regularization
factor.

By default each tag is solved separately (best results) while they can be
solved all at once (not recommended).

For each tag a vector of weights is calculated. Each weight corresponds to one
bit. Positive values indicate positive correlation and negative values negative
correlation.

Each weight vector is normalized so that maximum absolute weight is equal to
one.

The parameter -t is used to set threshold for those weights. Weights with
values above the threshold and below the "minus" threshold are output as
candidate bits.

For each weight vector a solution error is computed. If the error exceeds
the threshold specified using the -e parameter then the tag is considered to
have no bits.

The option -m can be used to filter bits found for the specified tag in all
other tags. This allows to remove bits from a "IS_BLOCK_IN_USE" type tag from
other tags responsible for enabling other features of that block.
'''
# Provenance: introduced as prjxray/lms_solver.py by commit
# 92cc9b587554877cd7cc49a845fb17f3907cba45 ("Alternative solver for the
# fuzzing problem.", Maciej Kurc, 2019-07-22).

import sys
import argparse
import itertools

import numpy as np
import numpy.linalg as linalg

# =============================================================================


def _require_fields(fields, count, file_name, line_no):
    """
    Raises ValueError when a data line has fewer than `count` fields.
    """
    if len(fields) < count:
        raise ValueError(
            "%s:%d: malformed line, expected at least %d fields: %r" %
            (file_name, line_no, count, " ".join(fields)))


def load_data(file_name, tagfilter=lambda tag: True):
    """
    Loads data generated by the segmaker.

    Lines starting with "seg" open a new segment, lines starting with "bit"
    and "tag" add to the segment opened last. Bit and tag lines which appear
    before the first segment are ignored. Segments which end up with no tags
    (after filtering) are dropped.

    Parameters
    ----------

    file_name:
        Name of the text file with data.
    tagfilter:
        A function for filtering tags. Should return True or False.

    Returns
    -------
    A list of dicts. Each contains:
     - "seg": Segment name
     - "bit": A list of bit names
     - "tag": A list of tuples (tag name, tag value)

    Raises
    ------
    ValueError
        When a seg/bit/tag line lacks a required field or when a tag value is
        not an integer.
    """

    all_segdata = []
    segdata = None

    with open(file_name, "r") as fp:
        for line_no, raw_line in enumerate(fp, start=1):
            fields = raw_line.split()
            if not fields:
                continue
            keyword = fields[0]

            # Segment tag. Closes the previous segment, keeping it only when
            # it collected at least one tag.
            if keyword.startswith("seg"):
                _require_fields(fields, 2, file_name, line_no)
                if segdata is not None and segdata["tag"]:
                    all_segdata.append(segdata)
                segdata = {"seg": fields[1], "bit": [], "tag": []}
                continue

            # Nothing to attach bits / tags to yet
            if segdata is None:
                continue

            # Bit tag
            if keyword.startswith("bit"):
                _require_fields(fields, 2, file_name, line_no)
                segdata["bit"].append(fields[1])

            # Tag tag
            elif keyword.startswith("tag"):
                _require_fields(fields, 2, file_name, line_no)
                if not tagfilter(fields[1]):
                    continue
                _require_fields(fields, 3, file_name, line_no)
                segdata["tag"].append((fields[1], int(fields[2])))

    # Store the last segment if any
    if segdata is not None and segdata["tag"]:
        all_segdata.append(segdata)

    return all_segdata


def write_segbits(file_name, all_tags, all_bits, W):
    """
    Writes solution to a raw database file.

    One line is written per row of W: the tag name followed by its candidate
    bits. Bits with a negative entry are prefixed with "!". Tags with no
    candidate bits get the "<0 candidates>" marker.

    Parameters
    ----------

    file_name:
        Name of the .rdb file.
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    W:
        Matrix with binary solution.
    """
    lines = []

    for r in range(W.shape[0]):
        bits = []
        for c in range(W.shape[1]):
            w = W[r, c]
            if w < 0:
                bits.append("!" + all_bits[c])
            elif w > 0:
                bits.append(all_bits[c])

        if not bits:
            bits = ["<0 candidates>"]

        lines.append(all_tags[r] + " " + " ".join(bits) + "\n")

    with open(file_name, "w") as fp:
        fp.writelines(lines)


def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None):
    """
    Dumps solution results to an open file in a nice readable format.

    Bit names are printed vertically as column headers ("_" is shown as "|").
    Columns of bits which are not a candidate for any tag are omitted. Rows
    of tags with no candidate bits are marked with "(!)".

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    W:
        Matrix with binary solution.
    X:
        Matrix with raw solution (floats).
    E:
        Vector with solution errors.
    tag_stats:
        Tag statistics.
    """
    lines = []

    # Width of the tag name column. "default" keeps an empty tag list from
    # raising.
    pad_len = max((len(tag) for tag in all_tags), default=0)

    # Bits which are not a candidate for any tag are not shown
    skip_bit = [bool((W[:, i] == 0).all()) for i in range(len(all_bits))]
    shown = [i for i in range(len(all_bits)) if not skip_bit[i]]

    # Bit names. At least 6 header rows, more if a shown name is longer so
    # that no name gets truncated.
    bit_len = max([6] + [len(all_bits[j]) for j in shown])
    names = [all_bits[j].ljust(bit_len).replace("_", "|") for j in shown]

    for i in range(bit_len):
        # Tag column + 1 space + the 4 character wide "(!)" marker column
        line = " " * (pad_len + 2 + 3)
        line += "".join(name[i] for name in names)

        # Statistics header, aligned with the " %4d|%4d" row format
        if i == (bit_len - 1) and tag_stats is not None:
            line += "   #0   #1 "

        lines.append(line)

    # Tags and bit values
    for r in range(W.shape[0]):
        line = all_tags[r].ljust(pad_len + 1)

        if (W[r, :] == 0).all():
            line += "(!) "
        else:
            line += "    "

        for c in shown:
            b = W[r, c]
            if b < 0:
                line += "0"
            elif b > 0:
                line += "1"
            else:
                line += "-"

        if tag_stats is not None:
            stat = tag_stats[all_tags[r]]
            line += " %4d|%4d" % stat

        # Range of raw weights. An empty row has no range to report.
        row = X[r, :]
        x_min = np.min(row) if row.size else float("nan")
        x_max = np.max(row) if row.size else float("nan")
        line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r])

        lines.append(line)

    # Write
    for line in lines:
        fp.write(line + "\n")


def dump_solution_to_csv(fp, all_tags, all_bits, X):
    """
    Dumps solution data to CSV.

    The first row holds an empty cell followed by bit names. Each following
    row holds a tag name followed by its weights.

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    X:
        Matrix with raw solution (floats).
    """

    # Bits
    fp.write(",".join([""] + list(all_bits)) + "\n")

    # Tags + numbers
    for r, tag in enumerate(all_tags):
        cells = [tag] + ["%+e" % X[r, c] for c in range(X.shape[1])]
        fp.write(",".join(cells) + "\n")


# =============================================================================


def build_matrices(all_tags, all_bits, segdata, bias=0.0):
    """
    Builds matrices for the linear equation system to be solved.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        A constant added to every +1 / -1 tag value stored in B.

    Returns
    -------

    Tuple with:
     - Matrix A (segments x bits). +1 when the bit is set in the segment and
       -1 otherwise.
     - Matrix B (segments x tags). +1 + bias for a tag with a positive value,
       -1 + bias for a tag with a non-positive value, 0 when the segment does
       not have the tag.
    """

    M = len(segdata)
    N = len(all_bits)
    K = len(all_tags)

    A = np.zeros((M, N), dtype=np.float64)
    B = np.zeros((M, K), dtype=np.float64)

    for r, seg in enumerate(segdata):
        # A matrix. A set makes the membership test O(1).
        set_bits = set(seg["bit"])
        for c, bit in enumerate(all_bits):
            A[r, c] = +1.0 if bit in set_bits else -1.0

        # B matrix. When a tag is repeated in a segment the last value wins.
        values = {}
        for t, x in seg["tag"]:
            values[t] = +1.0 if x > 0 else -1.0

        for c, tag in enumerate(all_tags):
            if tag in values:
                B[r, c] = values[tag] + bias

    return A, B


def compute_error(A, B, X):
    """
    Computes solution error.

    The error of a tag is the Euclidean norm of its residual, that is
    sqrt(sum((A * X - B)^2)) taken over segments. It is not divided by the
    number of segments.

    Parameters
    ----------

    A:
        Matrix A
    B:
        Matrix B
    X:
        Matrix with computed solution.

    Returns
    -------

    A vector with errors
    """

    residual = np.matmul(A, X) - B
    return np.sqrt(np.sum(np.square(residual), axis=0))


# =============================================================================


def solve_lms(all_tags, all_bits, segdata, bias=0.0):
    """
    Solves using direct least square solution (NumPy)

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        A constant added to every tag value. See build_matrices().

    Returns
    -------

    Tuple with:
     - Solution matrix X
     - Error vector.
    """

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Solve. Only the solution is of interest here.
    X = linalg.lstsq(A, B, rcond=None)[0]

    return X, compute_error(A, B, X)


def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0):
    """
    Solves using Tichonov regularization method.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        A constant added to every tag value. See build_matrices().
    a:
        Regularization coefficient.

    Returns
    -------

    Tuple with:
     - Solution matrix X
     - Error vector.

    Raises
    ------
    numpy.linalg.LinAlgError
        When the regularized system is singular (possible for a = 0).
    """

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Tikhonov regularization
    # https://en.wikipedia.org/wiki/Tikhonov_regularization
    #
    # Solves (A^T A + a I) X = A^T B directly instead of computing an
    # explicit inverse.
    AtA = np.matmul(A.T, A)
    AtB = np.matmul(A.T, B)
    X = linalg.solve(AtA + a * np.eye(A.shape[1]), AtB)

    return X, compute_error(A, B, X)


# =============================================================================


def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw):
    """
    Solves each tag separately in one-by-one fashion.

    For each tag only the segments which contain that tag are used. The tag
    name and the number of such segments is printed to stdout.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    solver:
        Solver function.
    **kw:
        Parameters to solver function.

    Returns
    -------

    Tuple with:
     - Solution matrix X
     - Error vector.

    """

    X = np.empty((len(all_bits), len(all_tags)))
    E = np.empty((len(all_tags)))

    # Tag names of each segment, built once instead of once per tag
    seg_tags = [set(t[0] for t in data["tag"]) for data in segdata]

    for i, tag in enumerate(all_tags):

        tag_segdata = [
            data for data, names in zip(segdata, seg_tags) if tag in names
        ]
        print("%s #%d" % (tag, len(tag_segdata)))

        X1, E1 = solver([tag], all_bits, tag_segdata, **kw)

        X[:, i] = X1[:, 0]
        E[i] = E1[0]

    return X, E


# =============================================================================


def detect_candidates(X, th, norm=None):
    """
    Detects candidate bits.

    Parameters
    ----------

    X:
        Matrix with solution
    th:
        Threshold
    norm:
        Normalization scheme. With "max_abs" each weight vector is divided
        by its maximum absolute value before thresholding. Any other value
        means no normalization.

    Returns
    -------

    A tuple with:
     - Binary solution matrix W. -1 for weights below -th, +1 for weights
       above +th and 0 otherwise.
     - Transposed matrix X (not normalized)

    """

    # A float copy, the input must not be modified by the normalization
    Xt = np.array(X.T, dtype=np.float64)
    W = np.zeros_like(Xt, dtype=int)

    if norm == "max_abs" and Xt.size:
        Nv = np.max(np.abs(Xt), axis=1, keepdims=True)
        # An all-zero weight vector stays all-zero instead of becoming NaN
        Nv[Nv == 0] = 1.0
        Xt /= Nv

    W[Xt < -th] = -1
    W[Xt > +th] = +1

    return W, X.T


# =============================================================================


def compute_tag_stats(all_tags, segdata):
    """
    Counts occurrence of all considered tags

    Parameters
    ----------

    all_tags:
        Considered tags
    segdata:
        List of segdata used

    Returns
    -------

    A dict indexed by tag name with tuples containing 0 and 1 occurrence count.

    """

    counts = {tag: [0, 0] for tag in all_tags}

    for data in segdata:
        for t, v in data["tag"]:
            if t in counts:
                counts[t][1 if v > 0 else 0] += 1

    return {tag: (c[0], c[1]) for tag, c in counts.items()}


def sort_bits(bit_name):
    """
    Utility function for sorting bits.

    Splits a "<frame>_<offset>" bit name into a tuple of two integers so
    that bits sort numerically and not lexicographically.
    """

    frm, ofs = bit_name.split("_")
    return (
        int(frm),
        int(ofs),
    )


# =============================================================================


def main():
    """
    The main.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument(
        "files",
        nargs="*",
        type=str,
        help="Input file(s) generated by segmaker")
    parser.add_argument(
        "-o",
        type=str,
        default="segbits.rdb",
        help="Output database file (def. segbits.rdb)")
    parser.add_argument(
        "-f",
        type=str,
        default=None,
        help="Tag filter. Processes only tags containing the specified text")
    parser.add_argument(
        "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)")
    parser.add_argument(
        "-e",
        type=float,
        default=0.1,
        help="Solution error threshold above which a tag is rejected "
        "(def. 0.1)")
    parser.add_argument(
        "-a",
        type=float,
        default=0.01,
        help="Regularization coefficient (def. 0.01)")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Solve all tags at once (may give worse results)")
    parser.add_argument(
        "-x",
        type=str,
        default=None,
        help="A CSV file name to Write the numerical solution to")

    parser.add_argument(
        "-m",
        type=str,
        default=None,
        help="Mask bits found for this feature in all other features")

    parser.add_argument("-b", type=float, default=0.0, help="Bias")

    parser.add_argument("-no_0", action="store_true", help="Do not output 0s")
    parser.add_argument("-no_1", action="store_true", help="Do not output 1s")

    args = parser.parse_args()

    # Compute threshold
    th = args.t

    # Load and filter segdata
    segdata = []

    def tagfilter(tag):
        if args.f is None:
            return True
        return args.f in tag

    for name in args.files:
        print(name)
        segdata.extend(load_data(name, tagfilter))

    # Make list of all bits
    all_bits = set()
    for seg in segdata:
        all_bits |= set(seg["bit"])
    all_bits = sorted(all_bits, key=sort_bits)

    # Detect bits that are always set
    const1_bits = set(all_bits)
    for seg in segdata:
        const1_bits &= set(seg["bit"])

    # Make list of all tags
    all_tags = set()
    for seg in segdata:
        all_tags |= set(tag[0] for tag in seg["tag"])
    all_tags = sorted(all_tags)

    # Count 0s and 1s for each tag
    tag_count = compute_tag_stats(all_tags, segdata)

    # Identify const0 and const1 tags
    const_tags = {}
    for tag in all_tags:
        count0, count1 = tag_count[tag]
        if count0 == 0:
            const_tags[tag] = 1
        if count1 == 0:
            const_tags[tag] = 0

    const0_tags = [t for t, v in const_tags.items() if v == 0]
    const1_tags = [t for t, v in const_tags.items() if v == 1]

    # Print config
    print("# segs:", len(segdata))
    print("# tags:", len(all_tags))
    print("# bits:", len(all_bits))
    print("threshold: %.2f" % th)

    if not segdata:
        print("No data!")
        sys.exit(-1)

    if not all_tags:
        print("No tags!")
        sys.exit(-1)

    if not all_bits:
        print("No bits!")
        sys.exit(-1)

    if const1_bits:
        # Sorted so that the output does not depend on set ordering
        print(
            "const 1 bits: " + ", ".join(sorted(const1_bits, key=sort_bits)))

    if const0_tags:
        print("const 0 tags: " + ", ".join(const0_tags))
    if const1_tags:
        print("const 1 tags: " + ", ".join(const1_tags))

    # Data to solve. Constant tags and constant bits carry no information.
    tags_to_solve = [tag for tag in all_tags if tag not in const_tags]
    bits_to_solve = [bit for bit in all_bits if bit not in const1_bits]

    # Statistics
    tag_stats = compute_tag_stats(tags_to_solve, segdata)

    # Solve
    print("Solving...")
    if args.all:
        X, E = solve_tichonov(
            tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a)
    else:
        X, E = solve_onebyone(
            tags_to_solve,
            bits_to_solve,
            segdata,
            solver=solve_tichonov,
            bias=args.b,
            a=args.a)

    # Detect candidate bits
    W, X = detect_candidates(X, th, norm="max_abs")

    # Mask
    if args.m is not None:
        print("Masking out %s" % args.m)
        tags = [t for t in tags_to_solve if args.m in t]
        for tag in tags:
            i = tags_to_solve.index(tag)
            # Clear every entry of other tags equal to the one of the
            # masking tag. The row of the masking tag itself is kept.
            same = W == W[i, :]
            same[i, :] = False
            W[same] = 0

    # Reject 0s and/or 1s
    if args.no_0:
        W[W < 0] = 0
    if args.no_1:
        W[W > 0] = 0

    # Reject tags with error greater than threshold
    W[np.asarray(E) > args.e, :] = 0

    # Write segbits
    write_segbits(args.o, tags_to_solve, bits_to_solve, W)

    # Dump to CSV
    if args.x is not None:
        with open(args.x, "w") as fp:
            dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X)

    # Dump
    dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats)


# =============================================================================

if __name__ == "__main__":
    main()