diff --git a/experiments/timfuz/timfuz_checksub.py b/experiments/timfuz/timfuz_checksub.py index cf290d78..a1bbc5d7 100644 --- a/experiments/timfuz/timfuz_checksub.py +++ b/experiments/timfuz/timfuz_checksub.py @@ -8,7 +8,7 @@ import math from collections import OrderedDict # check for issues that may be due to round off error -STRICT = 0 +STRICT = 1 def Adi2matrix_random(A_ubd, b_ub, names): # random assignment @@ -78,7 +78,7 @@ def row_sub_syms(row, sub_json, verbose=False): print('pivot %i %s' % (n, pivot)) for subk, subv in sorted(sub_json['subs'][group].items()): oldn = row.get(subk, 0) - rown = oldn - n * subv + rown = oldn - n * (1.0* subv[0] / subv[1]) if verbose: print(" %s: %d => %d" % (subk, oldn, rown)) if rown == 0: @@ -208,12 +208,7 @@ def load_sub(fn): for name, vals in sorted(j['subs'].items()): pivot = None for k, v in vals.items(): - if STRICT: - vi = int(round(v)) - assert abs(vi - v) < delta - vals[k] = vi - else: - vals[k] = float(v) + vals[k] = float(v) # there may be more than one acceptable pivot # take the first diff --git a/experiments/timfuz/timfuz_rref.py b/experiments/timfuz/timfuz_rref.py index 1fa514b3..0ac87967 100644 --- a/experiments/timfuz/timfuz_rref.py +++ b/experiments/timfuz/timfuz_rref.py @@ -14,8 +14,6 @@ import sympy from collections import OrderedDict from fractions import Fraction -STRICT = 0 - def fracr(r): DELTA = 0.0001 @@ -85,6 +83,12 @@ def Adi2matrix(Adi, cols): return A_ub2 def Anp2matrix(Anp): + ''' + Original idea was to make into a square matrix + but this loses too much information + so now this actually isn't doing anything and should probably be eliminated + ''' + ncols = len(Anp[0]) A_ub2 = [np.zeros(ncols) for _i in range(ncols)] dst_rowi = 0 @@ -102,6 +106,19 @@ def row_np2ds(rownp, names): ret[name] = v return ret +def row_sym2dsf(rowsym, names): + '''Convert a sympy row into a dictionary of keys to (numerator, denominator) tuples''' + from sympy import fraction + + ret = {} + assert len(rowsym) == len(names), (len(rowsym), len(names)) + for namei, name in enumerate(names): + v = rowsym[namei] + if v: + (num, den) = fraction(v) + ret[name] = (int(num), int(den)) + return ret + def comb_corr_sets(state, verbose=False): print('Converting rows to integer keys') names, Anp = A_ds2np(state.Ads) @@ -130,50 +147,41 @@ def comb_corr_sets(state, verbose=False): state.pivots = {} - def row_solved(rownp, row_pivot): - for ci, c in enumerate(rownp): + def row_solved(rowsym, row_pivot): + for ci, c in enumerate(rowsym): if ci == row_pivot: continue if c != 0: return False return True - rrefnp = np.array(rref).astype(np.float64) - print('Computing groups w/ rref %u row x %u col' % (len(rrefnp), len(rrefnp[0]))) + #rrefnp = np.array(rref).astype(np.float64) + #print('Computing groups w/ rref %u row x %u col' % (len(rrefnp), len(rrefnp[0]))) #print(rrefnp) # rows that have a single 1 are okay # anything else requires substitution (unless all 0) # pivots may be fewer than the rows # remaining rows should be 0s - for row_i, (row_pivot, rownp) in enumerate(zip(pivots, rrefnp)): - rowds = row_np2ds(rownp, names) - # boring cases: solved variable, not fully ranked - #if sum(rowds.values()) == 1: - if row_solved(rownp, row_pivot): + for row_i, row_pivot in enumerate(pivots): + rowsym = rref.row(row_i) + # yipee! nothign to report + if row_solved(rowsym, row_pivot): continue # a grouping group_name = "GROUP_%u" % row_i - if STRICT: - delta = 0.001 - rowds_store = {} - for k, v in rowds.items(): - vi = int(round(v)) - error = abs(vi - v) - assert error < delta, (error, delta) - rowds_store[k] = vi - else: - rowds_store = rowds - state.subs[group_name] = rowds_store + rowdsf = row_sym2dsf(rowsym, names) + + state.subs[group_name] = rowdsf # Add the new symbol state.names.add(group_name) # Remove substituted symbols # Note: symbols may appear multiple times - state.names.difference_update(set(rowds.keys())) + state.names.difference_update(set(rowdsf.keys())) pivot_name = names[row_pivot] state.pivots[group_name] = pivot_name if verbose: - print("%s (%s): %s" % (group_name, pivot_name, rowds)) + print("%s (%s): %s" % (group_name, pivot_name, rowdsf)) return state