mirror of https://github.com/openXC7/prjxray.git
994 lines
24 KiB
Python
Executable File
994 lines
24 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
'''
|
|
This script solves the fuzzing problem through least-mean-square solution of
|
|
an overdetermined linear equation system.
|
|
|
|
The advantages of this method are:
|
|
- Ability to detect negative correlations (tags which require clearing bits)
|
|
- Can detect partial correlation tag <-> bit. This happens if for a small
|
|
number of specimens a tag is said to be "1" but in fact it is not due to
|
|
the way Vivado interprets requested features and encodes them into bitstream.
|
|
- Easy detection of tags with no corresponding bits by evaluating the solution error.
|
|
|
|
The solution is computed using the Tikhonov regularization scheme to ensure
|
|
numerical stability. The parameter -a can be used to vary the regularization
|
|
factor.
|
|
|
|
By default each tag is solved separately (best results) while they can be
|
|
solved all at once (not recommended).
|
|
|
|
For each tag a vector of weights is calculated. Each weight corresponds to one
|
|
bit. Positive values indicate positive correlation and negative values negative
|
|
correlation.
|
|
|
|
Each weight vector is normalized so that maximum absolute weight is equal to
|
|
one.
|
|
|
|
The parameter -t is used to set threshold for those weights. Weights with
|
|
values above the threshold and below the "minus" threshold are output as
|
|
candidate bits.
|
|
|
|
For each weight vector a solution error is computed. If the error exceeds
|
|
threshold specified using the -e parameter then the tag is considered to
|
|
have no bits.
|
|
|
|
The option -m can be used to filter bits found for the specified tag in all
|
|
other tags. This allows removing the bits of an "IS_BLOCK_IN_USE" type tag from
|
|
other tags responsible for enabling other features of that block.
|
|
'''
|
|
import sys
|
|
import os
|
|
import argparse
|
|
import itertools
|
|
import json
|
|
|
|
import numpy as np
|
|
import numpy.linalg as linalg
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def load_data(file_name, tagfilter=lambda tag: True, address_map=None):
    """
    Loads data generated by the segmaker.

    The file is a sequence of "seg <name>" headers, each followed by
    "bit <bit name>" and "tag <tag name> <int value>" lines. Lines before the
    first segment header are ignored. Segments that end up without any
    (unfiltered) tag are dropped.

    Parameters
    ----------

    file_name:
        Name of the text file with data.
    tagfilter:
        A function for filtering tags. Should return True or False.
    address_map:
        A dict indexed by tuples (address, offset) containing a list
        of tile names. When given, a segment name "<hex address>_<offset>"
        found in the map is replaced by its tile name(s) joined with "_or_".

    Returns
    -------
    A list of dicts. Each contains:
     - "seg": Segment name, prefixed with "<file_name>:"
     - "bit": A list of bit names
     - "tag": A list of tuples (tag name, tag value)
    """

    segdata = None
    all_segdata = []

    with open(file_name, "r") as fp:
        for line in fp:
            line = line.strip()

            # Segment header: finish the previous segment, start a new one
            if line.startswith("seg"):
                if segdata is not None and len(segdata["tag"]):
                    all_segdata.append(segdata)

                segname = line.split()[1]

                # Map segment address to tile name
                if address_map is not None:
                    parts = segname.split("_")
                    address = (
                        int(parts[0], base=16),
                        int(parts[1]),
                    )
                    if address in address_map:
                        segname = "_or_".join(address_map[address])

                # Prefix with the file name so segments stay unique
                segname = file_name + ":" + segname

                segdata = {"seg": segname, "bit": [], "tag": []}
                continue

            # Nothing to attach bits/tags to before the first header
            if segdata is None:
                continue

            # Bit line
            if line.startswith("bit"):
                segdata["bit"].append(line.split()[1])

            # Tag line
            elif line.startswith("tag"):
                fields = line.split()
                if tagfilter(fields[1]):
                    segdata["tag"].append((
                        fields[1],
                        int(fields[2]),
                    ))

    # Store the last segment if any
    if segdata is not None and len(segdata["tag"]):
        all_segdata.append(segdata)

    return all_segdata
|
|
|
|
|
|
def write_segbits(file_name, all_tags, all_bits, W):
    """
    Writes the binary solution to a raw database (.rdb) file.

    One line is written per row of W: the tag name followed by its candidate
    bits. A bit with a negative weight is written as "!<bit>", a bit with a
    positive weight as "<bit>", and a tag with no candidates at all gets the
    placeholder "<0 candidates>".

    Parameters
    ----------

    file_name:
        Name of the .rdb file (overwritten).
    all_tags:
        Tag names, one per row of W.
    all_bits:
        Bit names, one per column of W.
    W:
        Matrix with binary solution.
    """
    n_rows, n_cols = W.shape[0], W.shape[1]

    def _bit_tokens(row):
        # Yield the textual form of every non-zero weight of the row
        for col in range(n_cols):
            weight = W[row, col]
            if weight < 0:
                yield "!" + all_bits[col]
            elif weight > 0:
                yield all_bits[col]

    output = []
    for row in range(n_rows):
        tokens = list(_bit_tokens(row)) or ["<0 candidates>"]
        output.append(" ".join([all_tags[row]] + tokens) + "\n")

    with open(file_name, "w") as fp:
        fp.writelines(output)
|
|
|
|
|
|
def dump_results(fp, all_tags, all_bits, W, X, E, tag_stats=None):
    """
    Dumps solution results to an open file in a nice readable format.

    The output starts with a header listing the names of the displayed bits
    vertically (one character per line, "_" drawn as "|"). Bits that are zero
    in every row of W are not displayed. Then, for each tag, one line holds:
    the tag name, "(!)" if the tag has no candidate bits, one character per
    displayed bit ("1" = positive, "0" = negative, "-" = no correlation),
    the optional 0/1 occurrence counts and the minimum / maximum raw weight
    together with the solution error.

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags (rows of W, X and E).
    all_bits:
        List of considered bits (columns of W and X).
    W:
        Matrix with binary solution.
    X:
        Matrix with raw solution (floats), shaped like W.
    E:
        Vector with solution errors.
    tag_stats:
        Optional dict indexed by tag name holding (count of 0s, count of 1s).
    """
    lines = []

    # Width of the tag name column; default=0 keeps an empty tag list
    # (e.g. when every tag turned out to be constant) from crashing.
    pad = max((len(tag) for tag in all_tags), default=0)

    # Bits not set for any tag are left out of the table
    skip_bit = [bool((W[:, i] == 0).all()) for i in range(len(all_bits))]
    names = [bit for bit, skip in zip(all_bits, skip_bit) if not skip]

    # Bit names are printed vertically. Use at least 6 rows (as before) and
    # enough rows for the longest displayed name so nothing gets truncated.
    bit_len = max([6] + [len(name) for name in names])
    names = [name.ljust(bit_len).replace("_", "|") for name in names]

    # Bit names
    for i in range(bit_len):
        # pad + 1 (tag column) + 4 ("(!) " marker) == pad + 2 + 3
        line = " " * (pad + 2 + 3) + "".join(name[i] for name in names)
        if i == bit_len - 1 and tag_stats is not None:
            line += " #0 #1 "
        lines.append(line)

    # Tags and bit values
    for r in range(W.shape[0]):
        line = all_tags[r].ljust(pad + 1)
        line += "(!) " if (W[r, :] == 0).all() else "    "

        for c in range(W.shape[1]):
            if skip_bit[c]:
                continue
            b = W[r, c]
            if b < 0:
                line += "0"
            elif b > 0:
                line += "1"
            else:
                line += "-"

        if tag_stats is not None:
            # tuple() so that list-valued statistics format as well
            line += " %4d|%4d" % tuple(tag_stats[all_tags[r]])

        x_min = np.min(X[r, :])
        x_max = np.max(X[r, :])
        line += " lo=%+.3f hi=%+.3f e=%.3f" % (x_min, x_max, E[r])

        lines.append(line)

    lines.append("")

    # Write
    fp.writelines(line + "\n" for line in lines)
|
|
|
|
|
|
def dump_solution_to_csv(fp, all_tags, all_bits, X):
    """
    Writes the raw (floating point) solution as CSV.

    The first row is a header made of an empty cell followed by the bit
    names. Every further row holds a tag name followed by that tag's weights,
    each formatted with "%+e".

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags, one per row of X.
    all_bits:
        List of considered bits, one per column of X.
    X:
        Matrix with raw solution (floats).
    """
    # Header: leading empty cell, then the bit names
    fp.write(",".join([""] + list(all_bits)) + "\n")

    # One row per tag
    n_cols = X.shape[1]
    for row, tag in enumerate(all_tags):
        cells = [tag]
        cells.extend("%+e" % X[row, col] for col in range(n_cols))
        fp.write(",".join(cells) + "\n")
|
|
|
|
|
|
def dump_correlation_report(
        fp, all_tags, all_bits, W, C, correlation_exceptions):
    """
    Writes a human-readable report of imperfect tag <-> bit correlations.

    Tags without exceptions (100% correlation) are skipped. For every other
    tag its name is written, followed by one entry per bit that has
    exceptions (in the order of all_bits). Each entry gives the sign of the
    weight in W, the correlation factor from C as a percentage, and every
    specimen in which the bit did not have the expected value.

    Parameters
    ----------

    fp:
        An open file or stream
    all_tags:
        List of considered tags (rows of W and C).
    all_bits:
        List of considered bits (columns of W and C).
    W:
        Matrix with binary solution.
    C:
        Matrix with correlation factors (0..1).
    correlation_exceptions:
        Dict indexed by tag then bit, holding lists of
        (actual value, expected value, segment name) tuples.
    """
    for tag_idx, tag in enumerate(all_tags):
        tag_exceptions = correlation_exceptions[tag]

        # No exceptions (100% correlation)
        if not tag_exceptions:
            continue

        fp.write(tag + "\n")

        for bit_idx, bit in enumerate(all_bits):
            if bit in tag_exceptions:
                weight = W[tag_idx, bit_idx]
                sign = "+" if weight > 0 else "-"
                percent = C[tag_idx, bit_idx] * 100.0
                fp.write(
                    " bit %s: (%s) %.1f%%\n" % (bit.ljust(6), sign, percent))

                # Counter-factual cases
                for actual, expected, seg_name in tag_exceptions[bit]:
                    fp.write(
                        " is %d, should be %d - %s\n" %
                        (actual, expected, seg_name))

        fp.write("\n")
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def build_matrices(all_tags, all_bits, segdata, bias=0.0):
    """
    Builds matrices for the linear equation system to be solved.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to the target value (+1 / -1) of every tag that is
        present in a specimen. Entries of tags absent from a specimen stay 0.

    Returns
    -------

    A tuple (A, B) of float64 matrices:
     - A (specimens x bits): +1 where the bit is set, -1 otherwise.
     - B (specimens x tags): +1 + bias for a tag value > 0, -1 + bias for a
       tag value <= 0, 0 when the tag does not occur in the specimen.
    """

    M = len(segdata)
    N = len(all_bits)
    K = len(all_tags)

    A = np.zeros((M, N), dtype=np.float64)
    B = np.zeros((M, K), dtype=np.float64)

    for r, data in enumerate(segdata):

        # A matrix row. A set gives O(1) membership tests instead of
        # scanning the specimen's bit list for every considered bit.
        present = set(data["bit"])
        A[r, :] = [+1.0 if bit in present else -1.0 for bit in all_bits]

        # B matrix row. dict() keeps the last value of a repeated tag,
        # matching the previous "last match wins" scan.
        values = dict(data["tag"])
        for c, tag in enumerate(all_tags):
            if tag in values:
                v = +1.0 if values[tag] > 0 else -1.0
                B[r, c] = v + bias

    return A, B
|
|
|
|
|
|
def compute_error(A, B, X):
    """
    Computes the solution error for every column of B.

    The error of column k is the Euclidean norm of its residual,
    sqrt(sum((A @ X[:, k] - B[:, k]) ** 2)), summed over all rows
    (specimens).

    Parameters
    ----------

    A:
        Matrix A
    B:
        Matrix B
    X:
        Matrix with computed solution.

    Returns
    -------

    A float64 vector with one error per column of B.
    """
    residual = np.matmul(A, X) - B

    # One Euclidean norm per residual column
    per_column = [np.sqrt(np.sum(np.square(column))) for column in residual.T]

    return np.array(per_column, dtype=np.float64).reshape(B.shape[1])
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def solve_lms(all_tags, all_bits, segdata, bias=0.0):
    """
    Solves using direct least square solution (NumPy).

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to the +1/-1 target value of tags present in a
        specimen (passed to build_matrices()).

    Returns
    -------

    Tuple with:
     - Solution matrix X (bits x tags)
     - Error vector (one entry per tag)
    """

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Solve. Only the solution is needed; lstsq's residuals, rank and
    # singular values are discarded (errors are computed below).
    X = linalg.lstsq(A, B, rcond=None)[0]

    return X, compute_error(A, B, X)
|
|
|
|
|
|
def solve_tichonov(all_tags, all_bits, segdata, bias=0.0, a=0.0):
    """
    Solves using Tikhonov regularization method.

    Minimizes ||A X - B||^2 + a * ||X||^2, whose minimizer satisfies the
    regularized normal equations (A^T A + a I) X = A^T B.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    bias:
        Constant added to the +1/-1 target value of tags present in a
        specimen (passed to build_matrices()).
    a:
        Regularization coefficient.

    Returns
    -------

    Tuple with:
     - Solution matrix X (bits x tags)
     - Error vector (one entry per tag)

    Raises
    ------

    numpy.linalg.LinAlgError if (A^T A + a I) is singular, e.g. with a=0
    and a rank deficient A.
    """

    N = len(all_bits)

    # Build matrices
    A, B = build_matrices(all_tags, all_bits, segdata, bias)

    # Tikhonov regularization
    # https://en.wikipedia.org/wiki/Tikhonov_regularization
    AtA = np.matmul(A.T, A)
    AtB = np.matmul(A.T, B)

    # Solve the system directly rather than forming an explicit inverse:
    # cheaper and numerically more accurate.
    X = np.linalg.solve(AtA + a * np.eye(N), AtB)

    return X, compute_error(A, B, X)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def solve_onebyone(all_tags, all_bits, segdata, solver=solve_lms, **kw):
    """
    Solves each tag separately in one-by-one fashion.

    For every tag, only the specimens in which that tag appears (with any
    value) are handed to `solver`, which is run for that single tag. The tag
    name and the number of such specimens are printed as progress output.

    Parameters
    ----------

    all_tags:
        List of considered tags.
    all_bits:
        List of considered bits.
    segdata:
        List of segdata used.
    solver:
        Solver function, called as solver([tag], all_bits, segdata, **kw)
        and returning (X, E).
    **kw:
        Parameters to solver function.

    Returns
    -------

    Tuple with:
     - Solution matrix X (bits x tags), column i belonging to all_tags[i]
     - Error vector, one entry per tag
    """
    num_bits, num_tags = len(all_bits), len(all_tags)
    X = np.empty((num_bits, num_tags))
    E = np.empty((num_tags))

    def _specimens_with(name):
        # Specimens that carry the tag `name`, regardless of its value
        return [
            d for d in segdata if any(entry[0] == name for entry in d["tag"])
        ]

    for column, tag in enumerate(all_tags):
        subset = _specimens_with(tag)
        print("%s #%d" % (tag, len(subset)))

        X_single, E_single = solver([tag], all_bits, subset, **kw)
        X[:, column] = X_single[:, 0]
        E[column] = E_single[0]

    return X, E
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def detect_candidates(X, th, norm=None):
    """
    Detects candidate bits.

    Parameters
    ----------

    X:
        Matrix with solution (bits x tags)
    th:
        Threshold
    norm:
        Normalization scheme. None compares the raw weights against the
        threshold. "max_abs" first scales each tag's weight vector so that
        its largest absolute weight is one; an all-zero vector stays zero.

    Returns
    -------

    A tuple with:
     - Binary solution matrix W (tags x bits): -1 where the (normalized)
       weight is below -th, +1 where it is above +th, 0 otherwise.
     - Transposed matrix X (tags x bits), NOT normalized.
    """

    # Float copy, so the in-place normalization also works for int input
    Xt = np.array(X.T, dtype=float)
    W = np.zeros_like(Xt, dtype=int)

    if norm == "max_abs":
        # initial=0.0 makes the reduction valid when there are no bits
        Nv = np.max(np.abs(Xt), axis=1, initial=0.0)
        # An all-zero row would give 0/0 = NaN (plus a RuntimeWarning);
        # divide it by one instead so it simply stays zero.
        Nv[Nv == 0.0] = 1.0
        Xt /= Nv[:, None]

    W[Xt < -th] = -1
    W[Xt > +th] = +1

    return W, X.T
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def compute_bit_correlations(tags_to_solve, bits_to_solve, segdata, W):
    """
    Basing on solution given in the matrix W returns a matrix C with
    correlation coefficients of each bit.

    Also returns a dict of dicts indexed by tag names and bit names with
    correlation exceptions - concrete specimen names where the correlation
    does not occur.

    Tag values are interpreted as booleans (value > 0), consistently with
    build_matrices() and compute_tag_stats(). For a negative weight the
    expected bit value is the inverted tag value.

    Parameters
    ----------

    tags_to_solve:
        List of tags (rows of W).
    bits_to_solve:
        List of bits (columns of W).
    segdata:
        List of segdata used.
    W:
        Matrix with binary solution.

    Returns
    -------

    A tuple (C, exceptions):
     - C: float matrix shaped like W. C[i, j] is the fraction of specimens
       carrying tag i in which bit j has the expected value. It stays 0 where
       W is 0 or where no specimen carries the tag.
     - exceptions: dict tag -> dict bit -> list of
       (bit value, expected value, segment name) tuples.
    """

    C = np.zeros_like(W, dtype=float)
    exceptions = {}

    for i, tag in enumerate(tags_to_solve):
        exceptions[tag] = {}

        # (tag value as 0/1, set of bits, segment name) for every specimen
        # carrying the tag. The first occurrence of the tag is used.
        specimens = []
        for data in segdata:
            values = [v for t, v in data["tag"] if t == tag]
            if values:
                specimens.append(
                    (1 if values[0] > 0 else 0, set(data["bit"]), data["seg"]))

        # Nothing to correlate against (avoids a division by zero)
        if not specimens:
            continue

        # Compute bit correlation
        for j, bit in enumerate(bits_to_solve):
            w = W[i, j]

            # No correlation with that bit
            if w == 0:
                continue

            corr_sum = 0
            for vt, bits, seg in specimens:
                vb = 1 if bit in bits else 0

                # Negative correlation: the bit should be the inverted tag
                if w < 0:
                    vt = 1 - vt

                if vt == vb:
                    corr_sum += 1
                else:
                    exceptions[tag].setdefault(bit, []).append((vb, vt, seg))

            # Store correlation
            C[i, j] = corr_sum / len(specimens)

    return C, exceptions
|
|
|
|
|
|
def compute_tag_stats(all_tags, segdata):
    """
    Counts how often each considered tag occurs with value 0 and with value 1.

    A tag value greater than zero counts as a 1, any other value as a 0.
    Every occurrence is counted, including repeated occurrences within one
    specimen.

    Parameters
    ----------

    all_tags:
        Considered tags
    segdata:
        List of segdata used

    Returns
    -------

    A dict indexed by tag name with tuples (count of 0s, count of 1s).
    """
    stats = {}
    for tag in all_tags:
        values = [v for data in segdata for t, v in data["tag"] if t == tag]
        ones = sum(1 for v in values if v > 0)
        stats[tag] = (len(values) - ones, ones)
    return stats
|
|
|
|
|
|
def sort_bits(bit_name):
    """
    Sort key for bit names of the form "<frame>_<offset>".

    Returns the tuple (frame, offset) of integers so that bits are ordered
    numerically rather than lexicographically. Raises ValueError when the
    name is not exactly two integer fields separated by "_".
    """
    frame, offset = map(int, bit_name.split("_"))
    return frame, offset
|
|
|
|
|
|
def build_address_map(tilegrid_file):
    """
    Loads the tilegrid and generates a map (baseaddr, offset) -> tile name(s).

    Only tiles with a non-empty "bits" entry that contains a "CLB_IO_CLK"
    bus are considered. The bus "baseaddr" is parsed as hexadecimal and its
    "offset" as decimal.

    Parameters
    ----------

    tilegrid_file:
        Path to the tilegrid.json file.

    Returns
    -------

    A dict mapping (baseaddr, offset) integer tuples to lists of tile names,
    in tilegrid order.
    """
    with open(tilegrid_file, "r") as fp:
        tilegrid = json.load(fp)

    address_map = {}
    for tile_name, tile_data in tilegrid.items():
        # Skip tiles without bits, with empty bits, or without the bus
        if "bits" not in tile_data or not len(tile_data["bits"]):
            continue
        bits = tile_data["bits"]
        if "CLB_IO_CLK" not in bits:
            continue

        bus = bits["CLB_IO_CLK"]
        key = (int(bus["baseaddr"], 16), int(bus["offset"]))
        address_map.setdefault(key, []).append(tile_name)

    return address_map
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
class FileOrStream(object):
    """
    Context manager yielding either a newly opened file or a given stream.

    A file name of None or "-" selects `stream` (sys.stdout by default),
    which is returned as-is and never closed. Any other file name is opened
    for writing on entry and closed on exit.
    """

    def __init__(self, file_name, stream=sys.stdout):
        self.file_name = file_name
        self.stream = stream
        self.fp = None  # file opened by __enter__, if any

    def __enter__(self):
        if self.file_name is not None and self.file_name != "-":
            self.fp = open(self.file_name, "w")
            return self.fp
        return self.stream

    def __exit__(self, exc_typ, exc_val, exc_tb):
        # Only close what was opened here; the stream is left alone.
        if self.fp is not None:
            self.fp.close()
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
def main():
    """
    The main.

    Loads segmaker output files and solves the tag <-> bit correlation
    problem. It writes the segbits database (-o) and, optionally, the
    numerical solution as CSV (-x) and a bit correlation report (-r). A
    result table is always printed to stdout.
    """

    # Parse arguments
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument(
        "files",
        nargs="*",
        type=str,
        help="Input file(s) generated by segmaker")
    parser.add_argument(
        "-o",
        type=str,
        default="segbits.rdb",
        help="Output database file (def. segbits.rdb)")
    parser.add_argument(
        "-f",
        type=str,
        default=None,
        help="Tag filter. Processes only tags containing the specified text")
    parser.add_argument(
        "-t", type=float, default=0.95, help="Candidate threshold (def. 0.95)")
    parser.add_argument(
        "-e",
        type=float,
        default=0.1,
        # Tags whose solution error is *greater* than this are rejected
        help=
        "Solution error threshold above which a tag is rejected (def. 0.1)")
    parser.add_argument(
        "-a",
        type=float,
        default=0.01,
        help="Regularization coefficient (def. 0.01)")
    parser.add_argument(
        "--all",
        action="store_true",
        help="Solve all tags at once (may give worse results)")
    parser.add_argument(
        "-x",
        type=str,
        default=None,
        help="A CSV file name to Write the numerical solution to")
    parser.add_argument(
        "-r",
        type=str,
        default=None,
        help=
        "A text file name to write bit correlation report to. Specify '-' for stdout"
    )

    parser.add_argument(
        "-m",
        type=str,
        default=None,
        help="Mask bits found for this feature in all other features")

    parser.add_argument("-b", type=float, default=0.0, help="Bias")

    parser.add_argument("-no_0", action="store_true", help="Do not output 0s")
    parser.add_argument("-no_1", action="store_true", help="Do not output 1s")

    args = parser.parse_args()

    # Build (baseaddr, offset) -> tile name map. It is only used to give
    # segments readable names, so run without it when the database location
    # is not configured (instead of failing inside os.path.join(None, ...)).
    database_root = os.getenv("XRAY_DATABASE_DIR")
    database_name = os.getenv("XRAY_DATABASE")
    if database_root is None or database_name is None:
        print(
            "XRAY_DATABASE_DIR / XRAY_DATABASE not set, "
            "segment addresses will not be mapped to tile names",
            file=sys.stderr)
        address_map = None
    else:
        tilegrid_file = os.path.join(
            database_root, database_name, "tilegrid.json")
        address_map = build_address_map(tilegrid_file)

    # Compute threshold
    th = args.t

    # Load and filter segdata
    def tagfilter(tag):
        return args.f is None or args.f in tag

    segdata = []
    for name in args.files:
        print(name)
        segdata.extend(load_data(name, tagfilter, address_map))

    # Make list of all bits
    all_bits = set()
    for seg in segdata:
        all_bits |= set(seg["bit"])
    all_bits = sorted(all_bits, key=sort_bits)

    # Detect bits that are always set
    const1_bits = set(all_bits)
    for seg in segdata:
        const1_bits &= set(seg["bit"])

    # Make list of all tags
    all_tags = sorted({tag for seg in segdata for tag, _ in seg["tag"]})

    # Count 0s and 1s for each tag in a single pass over the data
    counts = {tag: [0, 0] for tag in all_tags}
    for seg in segdata:
        for tag, val in seg["tag"]:
            counts[tag][1 if val > 0 else 0] += 1
    tag_count = {tag: tuple(c) for tag, c in counts.items()}

    # Identify const0 and const1 tags
    const_tags = {}
    for tag in all_tags:
        count0, count1 = tag_count[tag]
        if count0 == 0:
            const_tags[tag] = 1
        if count1 == 0:
            const_tags[tag] = 0

    const0_tags = [t for t, v in const_tags.items() if v == 0]
    const1_tags = [t for t, v in const_tags.items() if v == 1]

    # Print config
    print("# segs:", len(segdata))
    print("# tags:", len(all_tags))
    print("# bits:", len(all_bits))
    print("threshold: %.2f" % th)

    if len(segdata) == 0:
        print("No data!")
        sys.exit(-1)

    if len(all_tags) == 0:
        print("No tags!")
        sys.exit(-1)

    if len(all_bits) == 0:
        print("No bits!")
        sys.exit(-1)

    if len(const1_bits):
        # Sorted for a deterministic output (sets have no stable order)
        print(
            "const 1 bits: " + ", ".join(sorted(const1_bits, key=sort_bits)))

    if len(const0_tags):
        print("const 0 tags: " + ", ".join(const0_tags))
    if len(const1_tags):
        print("const 1 tags: " + ", ".join(const1_tags))

    # Data to solve
    tags_to_solve = [t for t in all_tags if t not in const_tags]
    bits_to_solve = [b for b in all_bits if b not in const1_bits]

    # Statistics, (count of 0s, count of 1s) per solved tag. Same content
    # as compute_tag_stats(tags_to_solve, segdata), reusing the counts above.
    tag_stats = {tag: tag_count[tag] for tag in tags_to_solve}

    # Solve
    print("Solving...")
    if args.all:
        X, E = solve_tichonov(
            tags_to_solve, bits_to_solve, segdata, bias=args.b, a=args.a)
    else:
        X, E = solve_onebyone(
            tags_to_solve,
            bits_to_solve,
            segdata,
            solver=solve_tichonov,
            bias=args.b,
            a=args.a)

    # Detect candidate bits
    W, X = detect_candidates(X, th, norm="max_abs")

    # Mask: clear, in every other tag, the bits whose value equals the one
    # found for each masking tag. Masking tags are processed in order, so a
    # later masking tag sees the effect of earlier ones.
    if args.m is not None:
        print("Masking out %s" % args.m)
        for i, tag in enumerate(tags_to_solve):
            if args.m not in tag:
                continue
            for r in range(len(tags_to_solve)):
                if r == i:
                    continue
                W[r, W[r, :] == W[i, :]] = 0

    # Reject 0s and/or 1s
    if args.no_0:
        W[W < 0] = 0
    if args.no_1:
        W[W > 0] = 0

    # Reject tags with error greater than threshold
    W[E > args.e, :] = 0

    # Compute correlation
    C, correlation_exceptions = compute_bit_correlations(
        tags_to_solve, bits_to_solve, segdata, W)

    # Write segbits
    write_segbits(args.o, tags_to_solve, bits_to_solve, W)

    # Dump to CSV
    if args.x is not None:
        with open(args.x, "w") as fp:
            dump_solution_to_csv(fp, tags_to_solve, bits_to_solve, X)

    # Dump results
    dump_results(sys.stdout, tags_to_solve, bits_to_solve, W, X, E, tag_stats)

    # Dump correlation report
    if args.r is not None:
        if args.r != "-":
            print("Dumping bit correlation report to '{}'".format(args.r))
        with FileOrStream(args.r, sys.stdout) as fp:
            dump_correlation_report(
                fp, tags_to_solve, bits_to_solve, W, C, correlation_exceptions)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|