#!/usr/bin/env python3 # -*- coding: utf-8 -*- # # Copyright (C) 2017-2020 The Project X-Ray Authors. # # Use of this source code is governed by a ISC-style # license that can be found in the LICENSE file or at # https://opensource.org/licenses/ISC # # SPDX-License-Identifier: ISC ''' This script solves the fuzzing problem through least-mean-square solution of an overdetermined linear equation system. The advantages of this method are: - Ability to detect negative correlations (tags which require clearing bits) - Can detect partial correlation tag <-> bit. This happens if for a small number of specimens a tag is said to be "1" but in fact it is not due to the way Vivado interprets requested features and encodes them into bitstream. - Ease to detect tags with no corresponding bits by evaluating solution error. The solution is computed using the Tikhonov regularization scheme to ensure numerical stability. The parameter -a can be used to vary the regularization factor. By default each tag is solved separately (best results) while they can be solved all at once (not recommended). For each tag a vector of weights is calculated. Each weight corresponds to one bit. Positive values indicate positive correlation and negative values negative correlation. Each weight vector is normalized so that maximum absolute weight is equal to one. The parameter -t is used to set threshold for those weights. Weights with values above the threshold and below the "minus" threshold are output as candidate bits. For each weight vector a solution error is computed. If the error exceeds threshold specified using the -e parameter then the tag is considered to have no bits. The option -m can be used to filter bits found for the specified tag in all other tags. This allows to remove bits from a "IS_BLOCK_IN_USE" type tag from other tags responsible for enabling other features of that block. ''' import sys import os import argparse import itertools import json import numpy as np import numpy.linalg as linalg from prjxray.util import OpenSafeFile # ============================================================================= def load_data(file_name, tagfilter=lambda tag: True, address_map=None): """ Loads data generated by the segmaker. Parameters ---------- file_name: Name of the text file with data. tagfilter: A function for filtering tags. Should reqturn True or False. address_map: A dict indexed by tuples (address, offset) containing a list of tile names. Returns ------- A list of dicts. Each contains: - "seg": Segment name - "bit": A list of bit names - "tag": A list of tuples (tag name, tag value) """ segdata = None all_segdata = [] with OpenSafeFile(file_name, "r") as fp: for line in fp.readlines(): line = line.strip() # Segment tag if line.startswith("seg"): fields = line.split() if segdata is not None: if len(segdata["tag"]): all_segdata.append(segdata) segdata = None segname = fields[1] # Map segment address to tile name if address_map is not None: address = segname.split("_") address = ( int(address[0], base=16), int(address[1]), ) if address in address_map: segname = "_or_".join(address_map[address]) # Append file name segname = file_name + ":" + segname # Append segdata segdata = {"seg": segname, "bit": [], "tag": []} if segdata is None: continue # Bit tag if line.startswith("bit"): fields = line.split() segdata["bit"].append(fields[1]) # Tag tag if line.startswith("tag"): fields = line.split() if not tagfilter(fields[1]): continue segdata["tag"].append(( fields[1], int(fields[2]), )) # Store the last segment if any if segdata is not None: if len(segdata["tag"]): all_segdata.append(segdata) return all_segdata def write_segbits(file_name, all_tags, all_bits, W): """ Writes solution to a raw database file. Parameters ---------- file_name: Name of the .rdb file. all_tags: List of considered tags. all_bits: List of considered bits. W: Matrix with binary solution. """ lines = [] for r in range(W.shape[0]): bits = [] for c in range(W.shape[1]): w = W[r, c] if w < 0: bits.append("!" + all_bits[c]) if w > 0: bits.append(all_bits[c]) if len(bits) == 0: bits = ["<0 candidates>"] lines.append(all_tags[r] + " " + " ".join(bits) + "\n") with OpenSafeFile(file_name, "w") as fp: for line in lines: fp.write(line) def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None): """ Dumps solution results to an open file in a nice readable format. Parameters ---------- fp: An open file or stream all_tags: List of considered tags. all_bits: List of considered bits. W: Matrix with binary solution. X: Matrix with raw solution (floats). E: Vector with solution errors. tag_stats: Tag statistics. """ lines = [] pad_len = max([len(tag) for tag in all_tags]) skip_bit = [] for i in range(len(all_bits)): skip_bit.append((W[:, i] == 0).all()) # Bit names bit_len = 6 for i in range(bit_len): line = " " * (pad_len + 2 + 3) for j in range(len(all_bits)): if skip_bit[j]: continue bname = all_bits[j].ljust(bit_len).replace("_", "|") line += bname[i] if i == (bit_len - 1): if tag_stats is not None: line += " #0 #1 " lines.append(line) # Tags and bit values pad = max([len(tag) for tag in all_tags]) for r in range(W.shape[0]): line = all_tags[r].ljust(pad + 1) if (W[r, :] == 0).all(): line += "(!) " else: line += " " for c in range(W.shape[1]): if skip_bit[c]: continue b = W[r, c] if b < 0: line += "0" elif b > 0: line += "1" else: line += "-" if tag_stats is not None: stat = tag_stats[all_tags[r]] line += " %4d|%4d" % stat x_min = np.min(X[r, :]) x_max = np.max(X[r, :]) line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r]) lines.append(line) lines.append("") # Write for line in lines: fp.write(line + "\n") def dump_solution_to_csv(fp, all_tags, all_bits, X): """ Dumps solution data to CSV. Parameters ---------- fp: An open file or stream all_tags: List of considered tags. all_bits: List of considered bits. X: Matrix with raw solution (floats). """ # Bits line = "," for bit in all_bits: line += bit + "," fp.write(line[:-1] + "\n") # Tags + numbers for r, tag in enumerate(all_tags): line = tag + "," for c in range(X.shape[1]): line += "%+e," % X[r, c] fp.write(line[:-1] + "\n") def dump_correlation_report( fp, all_tags, all_bits, W, C, correlation_exceptions): for i, tag in enumerate(all_tags): # No exceptions (100% correlation) if len(correlation_exceptions[tag]) == 0: continue fp.write(tag + "\n") for j, bit in enumerate(all_bits): if bit not in correlation_exceptions[tag]: continue c = C[i, j] w = W[i, j] # Dump bit correlation factor sgn = "+" if w > 0 else "-" fp.write(" bit %s: (%s) %.1f%%\n" % (bit.ljust(6), sgn, c * 100.0)) # Dump counter-factual cases e = correlation_exceptions[tag][bit] for x, y, ex in e: fp.write(" is %d, should be %d - %s\n" % (x, y, ex)) fp.write("\n") # ============================================================================= def build_matrices(all_tags, all_bits, segdata, bias=0.0): """ Builds matrices for the linear equation system to be solved. Parameters ---------- all_tags: List of considered tags. all_bits: List of considered bits. segdata: List of segdata used. bias: T.B.D. """ M = len(segdata) N = len(all_bits) K = len(all_tags) A = np.zeros((M, N), dtype=np.float64) B = np.zeros((M, K), dtype=np.float64) # A matrix for r, c in itertools.product(range(M), range(N)): if all_bits[c] in segdata[r]["bit"]: A[r, c] = +1.0 else: A[r, c] = -1.0 # B matrix for r, c in itertools.product(range(M), range(K)): for t, x in segdata[r]["tag"]: if t == all_tags[c]: v = +1.0 if x > 0 else -1.0 B[r, c] = v + bias return A, B def compute_error(A, B, X): """ Computes solution error. Parameters ---------- A: Matrix A B: Matrix B X: Matrix with computed solution. Returns ------- A vector with errors """ K = B.shape[1] # Compute error Bx = np.matmul(A, X) E = np.empty((K)) for k in range(K): E[k] = np.sqrt(np.sum(np.square(Bx[:, k] - B[:, k]))) return E # ============================================================================= def solve_lms(all_tags, all_bits, segdata, bias=0.0): """ Solves using direct least square solution (NumPy) Parameters ---------- all_tags: List of considered tags. all_bits: List of considered bits. segdata: List of segdata used. bias: T.B.D. """ # Build matrices A, B = build_matrices(all_tags, all_bits, segdata, bias) # Solve X, res, r, s = linalg.lstsq(A, B, rcond=None) return X, compute_error(A, B, X) def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0): """ Solves using Tichonov regularization method. Parameters ---------- all_tags: List of considered tags. all_bits: List of considered bits. segdata: List of segdata used. bias: T.B.D. a: Regularization coefficient. Returns ------- Tuple with: - Solution matrix X - Error vector. """ M = len(segdata) N = len(all_bits) K = len(all_tags) # Build matrices A, B = build_matrices(all_tags, all_bits, segdata, bias) # Tikhonov regularization # https://en.wikipedia.org/wiki/Tikhonov_regularization AtA = np.matmul(A.T, A) AtB = np.matmul(A.T, B) X = np.matmul(np.linalg.inv(AtA + a * np.eye(N)), AtB) return X, compute_error(A, B, X) # ============================================================================= def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw): """ Solves each tag separately in one-by-one fashion. Parameters ---------- all_tags: List of considered tags. all_bits: List of considered bits. segdata: List of segdata used. solver: Solver function. **kw: Parameters to solver function. Returns ------- Tuple with: - Solution matrix X - Error vector. """ X = np.empty((len(all_bits), len(all_tags))) E = np.empty((len(all_tags))) for i, tag in enumerate(all_tags): tag_segdata = [ data for data in segdata if tag in [t[0] for t in data["tag"]] ] print("%s #%d" % (tag, len(tag_segdata))) X1, E1 = solver([tag], all_bits, tag_segdata, **kw) X[:, i] = X1[:, 0] E[i] = E1[0] return X, E # ============================================================================= def detect_candidates(X, th, norm=None): """ Detects candidate bits. Parameters ---------- X: Matrix with solution th: Threshold norm: Normalization scheme. See code. Returns ------- A tuple with: - Binary solution matrix W - Transposed matrix X """ Xt = np.array(X.T) W = np.zeros_like(Xt, dtype=int) if norm == "max_abs": Nv = np.max(np.abs(Xt), axis=1) Xt /= np.tile(Nv[:, None], (1, Xt.shape[1])) W[Xt < -th] = -1 W[Xt > +th] = +1 return W, X.T # ============================================================================= def compute_bit_correlations(tags_to_solve, bits_to_solve, segdata, W): """ Basing on solution given in the matrix W returns a matrix C with correlation coefficients of each bit. Also returns a dict of dicts indexed by tag names and bit names with correlation exceptions - concrete specimen names where the correlation does not occur. """ C = np.zeros_like(W, dtype=float) exceptions = {} for i, tag in enumerate(tags_to_solve): # Filter data for this tag tag_segdata = [ data for data in segdata if tag in [t[0] for t in data["tag"]] ] exceptions[tag] = {} # Compute bit correlation for j, bit in enumerate(bits_to_solve): w = W[i, j] # No correlation with that bit if w == 0: continue corr_sum = 0 corr_count = 0 # Compute for one bit for k, data in enumerate(tag_segdata): bits = data["bit"] vt = [v for t, v in data["tag"] if t == tag][0] vb = 1 if bit in bits else 0 # Negative correlation if w < 0: vt = int(1 - vt) else: vt = int(vt) # Correlates if vt == vb: corr_sum += 1 # Does not correlate else: if bit not in exceptions[tag]: exceptions[tag][bit] = [] exceptions[tag][bit].append(( vb, vt, data["seg"], )) corr_count += 1 # Store correlation C[i, j] = corr_sum / corr_count return C, exceptions def compute_tag_stats(all_tags, segdata): """ Counts occurrence of all considered tags Parameters ---------- all_tags: Considered tags segdata: List of segdata used Returns ------- A dict indexed by tag name with tuples containing 0 and 1 occurrence count. """ stats = {} for i, tag in enumerate(all_tags): count0 = 0 count1 = 0 for data in segdata: for t, v in data["tag"]: if t == tag: if v > 0: count1 += 1 else: count0 += 1 stats[tag] = ( count0, count1, ) return stats def sort_bits(bit_name): """ Utility function for sorting bits. """ frm, ofs = bit_name.split("_") return ( int(frm), int(ofs), ) def build_address_map(tilegrid_file): """ Loads the tilegrid and generates a map (baseaddr, offset) -> tile name(s). Parameters ---------- tilegrid_file: The tilegrid.json file/ Returns ------- A dict with lists of tile names. """ address_map = {} # Load tilegrid with OpenSafeFile(tilegrid_file, "r") as fp: tilegrid = json.load(fp) # Loop over tiles for tile_name, tile_data in tilegrid.items(): # No bits or bits empty if "bits" not in tile_data: continue if not len(tile_data["bits"]): continue bits = tile_data["bits"] # No bus if "CLB_IO_CLK" not in bits: continue bus = bits["CLB_IO_CLK"] # Make the address as integers baseaddr = int(bus["baseaddr"], 16) offset = int(bus["offset"]) address = ( baseaddr, offset, ) # Add tile to the map if address not in address_map: address_map[address] = [] address_map[address].append(tile_name) return address_map # ============================================================================= class FileOrStream(object): def __init__(self, file_name, stream=sys.stdout): self.file_name = file_name self.stream = stream self.fp = None def __enter__(self): if self.file_name is None: return self.stream if self.file_name == "-": return self.stream self.fp = open(self.file_name, "w") return self.fp def __exit__(self, exc_typ, exc_val, exc_tb): if self.fp is not None: self.fp.close() # ============================================================================= def main(): """ The main. """ # Parse arguments parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument( "files", nargs="*", type=str, help="Input file(s) generated by segmaker") parser.add_argument( "-o", type=str, default="segbits.rdb", help="Output database file (def. segbits.rdb)") parser.add_argument( "-f", type=str, default=None, help="Tag filter. Processes only tags containing the specified text") parser.add_argument( "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)") parser.add_argument( "-e", type=float, default=0.1, help="RMS error threshold below which a tag is rejected (def. 0.1)") parser.add_argument( "-a", type=float, default=0.01, help="Regularization coefficient (def. 0.01)") parser.add_argument( "--all", action="store_true", help="Solve all tags at once (may give worse results)") parser.add_argument( "-x", type=str, default=None, help="A CSV file name to Write the numerical solution to") parser.add_argument( "-r", type=str, default=None, help= "A text file name to write bit correlation report to. Specify '-' for stdout" ) parser.add_argument( "-m", type=str, default=None, help="Mask bits found for this feature in all other features") parser.add_argument("-b", type=float, default=0.0, help="Bias") parser.add_argument("-no_0", action="store_true", help="Do not output 0s") parser.add_argument("-no_1", action="store_true", help="Do not output 1s") args = parser.parse_args() # Build (baseaddr, offset) -> tile name map database_dir = os.path.join( os.getenv("XRAY_DATABASE_DIR"), os.getenv("XRAY_DATABASE"), os.getenv("XRAY_FABRIC")) tilegrid_file = os.path.join(database_dir, "tilegrid.json") address_map = build_address_map(tilegrid_file) # Compute threshold th = args.t # Load and filter segdata segdata = [] def tagfilter(tag): if args.f is None: return True return args.f in tag for name in args.files: print(name) segdata.extend(load_data(name, tagfilter, address_map)) # Make list of all bits all_bits = set() for seg in segdata: all_bits |= set(seg["bit"]) all_bits = sorted(list(all_bits), key=sort_bits) # Detect bits that are always set const1_bits = set(all_bits) for seg in segdata: const1_bits &= set(seg["bit"]) # Make list of all tags all_tags = set() for seg in segdata: all_tags |= set([tag[0] for tag in seg["tag"]]) all_tags = sorted(list(all_tags)) # Count 0s and 1s for each tag tag_count = {} for seg in segdata: for tag, val in seg["tag"]: if tag not in tag_count: tag_count[tag] = [0, 0] if val > 0: tag_count[tag][1] += 1 else: tag_count[tag][0] += 1 # Identify const0 and const1 tags const_tags = {} for tag in all_tags: if tag_count[tag][0] == 0: const_tags[tag] = 1 if tag_count[tag][1] == 0: const_tags[tag] = 0 const0_tags = [t for t, v in const_tags.items() if v == 0] const1_tags = [t for t, v in const_tags.items() if v == 1] # Print config print("# segs:", len(segdata)) print("# tags:", len(all_tags)) print("# bits:", len(all_bits)) print("threshold: %.2f" % th) if len(segdata) == 0: print("No data!") exit(-1) if len(all_tags) == 0: print("No tags!") exit(-1) if len(all_bits) == 0: print("No bits!") exit(-1) if len(const1_bits): print("const 1 bits: " + ", ".join(const1_bits)) if len(const0_tags): print("const 0 tags: " + ", ".join(const0_tags)) if len(const1_tags): print("const 1 tags: " + ", ".join(const1_tags)) # Data to solve tags_to_solve = list(all_tags) bits_to_solve = list(all_bits) for tag in const_tags.keys(): tags_to_solve.remove(tag) for bit in const1_bits: bits_to_solve.remove(bit) # Statistics tag_stats = compute_tag_stats(tags_to_solve, segdata) # Solve print("Solving...") if args.all: X, E = solve_tichonov( tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a) else: X, E = solve_onebyone( tags_to_solve, bits_to_solve, segdata, solver=solve_tichonov, bias=args.b, a=args.a) # Detect candidate bits W, X = detect_candidates(X, th, norm="max_abs") # Mask if args.m is not None: print("Masking out %s" % args.m) tags = [t for t in tags_to_solve if args.m in t] for tag in tags: i = tags_to_solve.index(tag) for r in range(len(tags_to_solve)): if r == i: continue for c in range(len(bits_to_solve)): if W[r, c] == W[i, c]: W[r, c] = 0 # Reject 0s and/or 1s if args.no_0: W[W < 0] = 0 if args.no_1: W[W > 0] = 0 # Reject tags with error greater than threshold for r in range(X.shape[0]): if E[r] > args.e: W[r, :] = 0 # Compute correlation C, correlation_exceptions = compute_bit_correlations( tags_to_solve, bits_to_solve, segdata, W) # Write segbits write_segbits(args.o, tags_to_solve, bits_to_solve, W) # Dump to CSV if args.x is not None: with OpenSafeFile(args.x, "w") as fp: dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X) # Dump results dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats) # Dump correlation report if args.r is not None: if args.r != "-": print("Dumping bit correlation report to '{}'".format(args.r)) with FileOrStream(args.r, sys.stdout) as fp: dump_correlation_report( fp, tags_to_solve, bits_to_solve, W, C, correlation_exceptions) # ============================================================================= if __name__ == "__main__": main()