timfuz: use symbolic math to generate constants

Signed-off-by: John McMaster <johndmcmaster@gmail.com>
This commit is contained in:
John McMaster 2018-08-24 13:12:33 -07:00
parent 8c5436916c
commit 71dad3d09c
2 changed files with 35 additions and 32 deletions

View File

@ -8,7 +8,7 @@ import math
from collections import OrderedDict
# check for issues that may be due to round off error
STRICT = 0
STRICT = 1
def Adi2matrix_random(A_ubd, b_ub, names):
# random assignment
@ -78,7 +78,7 @@ def row_sub_syms(row, sub_json, verbose=False):
print('pivot %i %s' % (n, pivot))
for subk, subv in sorted(sub_json['subs'][group].items()):
oldn = row.get(subk, 0)
rown = oldn - n * subv
rown = oldn - n * (1.0* subv[0] / subv[1])
if verbose:
print(" %s: %d => %d" % (subk, oldn, rown))
if rown == 0:
@ -208,12 +208,7 @@ def load_sub(fn):
for name, vals in sorted(j['subs'].items()):
pivot = None
for k, v in vals.items():
if STRICT:
vi = int(round(v))
assert abs(vi - v) < delta
vals[k] = vi
else:
vals[k] = float(v)
vals[k] = float(v)
# there may be more than one acceptable pivot
# take the first

View File

@ -14,8 +14,6 @@ import sympy
from collections import OrderedDict
from fractions import Fraction
STRICT = 0
def fracr(r):
DELTA = 0.0001
@ -85,6 +83,12 @@ def Adi2matrix(Adi, cols):
return A_ub2
def Anp2matrix(Anp):
'''
Original idea was to make into a square matrix
but this loses too much information
so now this actually isn't doing anything and should probably be eliminated
'''
ncols = len(Anp[0])
A_ub2 = [np.zeros(ncols) for _i in range(ncols)]
dst_rowi = 0
@ -102,6 +106,19 @@ def row_np2ds(rownp, names):
ret[name] = v
return ret
def row_sym2dsf(rowsym, names):
'''Convert a sympy row into a dictionary of keys to (numerator, denominator) tuples'''
from sympy import fraction
ret = {}
assert len(rowsym) == len(names), (len(rowsym), len(names))
for namei, name in enumerate(names):
v = rowsym[namei]
if v:
(num, den) = fraction(v)
ret[name] = (int(num), int(den))
return ret
def comb_corr_sets(state, verbose=False):
print('Converting rows to integer keys')
names, Anp = A_ds2np(state.Ads)
@ -130,50 +147,41 @@ def comb_corr_sets(state, verbose=False):
state.pivots = {}
def row_solved(rownp, row_pivot):
for ci, c in enumerate(rownp):
def row_solved(rowsym, row_pivot):
for ci, c in enumerate(rowsym):
if ci == row_pivot:
continue
if c != 0:
return False
return True
rrefnp = np.array(rref).astype(np.float64)
print('Computing groups w/ rref %u row x %u col' % (len(rrefnp), len(rrefnp[0])))
#rrefnp = np.array(rref).astype(np.float64)
#print('Computing groups w/ rref %u row x %u col' % (len(rrefnp), len(rrefnp[0])))
#print(rrefnp)
# rows that have a single 1 are okay
# anything else requires substitution (unless all 0)
# pivots may be fewer than the rows
# remaining rows should be 0s
for row_i, (row_pivot, rownp) in enumerate(zip(pivots, rrefnp)):
rowds = row_np2ds(rownp, names)
# boring cases: solved variable, not fully ranked
#if sum(rowds.values()) == 1:
if row_solved(rownp, row_pivot):
for row_i, row_pivot in enumerate(pivots):
rowsym = rref.row(row_i)
# yipee! nothign to report
if row_solved(rowsym, row_pivot):
continue
# a grouping
group_name = "GROUP_%u" % row_i
if STRICT:
delta = 0.001
rowds_store = {}
for k, v in rowds.items():
vi = int(round(v))
error = abs(vi - v)
assert error < delta, (error, delta)
rowds_store[k] = vi
else:
rowds_store = rowds
state.subs[group_name] = rowds_store
rowdsf = row_sym2dsf(rowsym, names)
state.subs[group_name] = rowdsf
# Add the new symbol
state.names.add(group_name)
# Remove substituted symbols
# Note: symbols may appear multiple times
state.names.difference_update(set(rowds.keys()))
state.names.difference_update(set(rowdsf.keys()))
pivot_name = names[row_pivot]
state.pivots[group_name] = pivot_name
if verbose:
print("%s (%s): %s" % (group_name, pivot_name, rowds))
print("%s (%s): %s" % (group_name, pivot_name, rowdsf))
return state