Convert source and target lists to sets for faster contains check.

This commit is contained in:
Matt Guthaus 2019-01-30 11:15:47 -08:00
parent 07f4d639eb
commit 74fbe8fe63
4 changed files with 12 additions and 9 deletions

View File

@ -21,8 +21,8 @@ class grid:
""" Initialize the map and define the costs. """
# list of the source/target grid coordinates
self.source = []
self.target = []
self.source = set()
self.target = set()
self.track_width = track_width
self.track_widths = [self.track_width, self.track_width, 1.0]
@ -80,7 +80,7 @@ class grid:
else:
self.add_map(n)
self.map[n].source=value
self.source.append(n)
self.source.add(n)
def set_target(self,n,value=True):
if isinstance(n, (list,tuple,set,frozenset)):
@ -89,7 +89,7 @@ class grid:
else:
self.add_map(n)
self.map[n].target=value
self.target.append(n)
self.target.add(n)
def add_source(self,track_list,value=True):

View File

@ -62,7 +62,8 @@ class signal_grid(grid):
# We set a cost bound of the HPWL for run-time. This can be
# over-ridden if the route fails due to pruning a feasible solution.
cost_bound = detour_scale*self.cost_to_target(self.source[0])*grid.PREFERRED_COST
any_source_element = next(iter(self.source))
cost_bound = detour_scale*self.cost_to_target(any_source_element)*grid.PREFERRED_COST
# Check if something in the queue is already a source and a target!
for s in self.source:
@ -153,7 +154,8 @@ class signal_grid(grid):
Find the cheapest HPWL distance to any target point ignoring
blockages for A* search.
"""
cost = self.hpwl(source,self.target[0])
any_target_element = next(iter(self.target))
cost = self.hpwl(source,any_target_element)
for t in self.target:
cost = min(self.hpwl(source,t),cost)

View File

@ -20,8 +20,8 @@ class supply_grid(signal_grid):
def reinit(self):
""" Reinitialize everything for a new route. """
self.source = []
self.target = []
self.source = set()
self.target = set()
# Reset all the cells in the map
for p in self.map.values():
p.reset()

View File

@ -2,7 +2,8 @@
import pstats
p = pstats.Stats("profile.dat")
p.strip_dirs()
p.sort_stats("cumulative")
#p.sort_stats("cumulative")
p.sort_stats("tottime")
#p.print_stats(50)
p.print_stats()