diff --git a/epynet/network.py b/epynet/network.py index efb139a..9d57120 100644 --- a/epynet/network.py +++ b/epynet/network.py @@ -6,6 +6,8 @@ from .curve import Curve from .pattern import Pattern +import itertools + class Network(object): """ self.epANET Network Simulation Class """ @@ -317,9 +319,22 @@ def solve(self, simtime=0): self.solved = True self.solved_for_simtime = simtime - def run(self): + def _run(self, **kwargs): self.reset() self.time = [] + + # simulation settings + if 'duration' in kwargs: + self.ep.ENsettimeparam(epanet2.EN_DURATION, kwargs['duration']) + if 'hydraulic_step' in kwargs: + self.ep.ENsettimeparam(epanet2.EN_HYDSTEP, kwargs['hydraulic_step']) + if 'report_step' in kwargs: + self.ep.ENsettimeparam(epanet2.EN_REPORTSTEP, kwargs['report_step']) + if 'pattern_start' in kwargs: + self.ep.ENsettimeparam(epanet2.EN_PATTERNSTART, kwargs['pattern_start']) + if 'pattern_step' in kwargs: + self.ep.ENsettimeparam(epanet2.EN_PATTERNSTEP, kwargs['pattern_step']) + # open network self.ep.ENopenH() self.ep.ENinitH(0) @@ -329,20 +344,39 @@ def run(self): self.solved = True + store_results = kwargs.get('store_results', True) + while timestep > 0: self.ep.ENrunH() timestep = self.ep.ENnextH() - self.time.append(simtime) - self.load_attributes(simtime) + + for el in itertools.chain(self.nodes, self.links): + # clear cached values + el._values = {} + + if store_results: + self.time.append(simtime) + self.load_attributes(simtime) + + # pass control to caller, to log results or compute + # custom feedback control laws. + yield simtime + simtime += timestep + self.ep.ENcloseH() + + def run(self, **kwargs): + if kwargs.get('interactive', False): + return self._run(**kwargs) + else: + simtimes = list(self._run(**kwargs)) + def load_attributes(self, simtime): for node in self.nodes: for property_name in node.properties.keys(): if property_name not in node.results.keys(): node.results[property_name] = [] - # clear cached values - node._values = {} node.results[property_name].append(node.get_property(node.properties[property_name])) node.times.append(simtime) @@ -350,8 +384,6 @@ def load_attributes(self, simtime): for property_name in link.properties.keys(): if property_name not in link.results.keys(): link.results[property_name] = [] - # clear cached values - link._values = {} link.results[property_name].append(link.get_property(link.properties[property_name])) link.times.append(simtime) diff --git a/tests/test_network.py b/tests/test_network.py index 2b81c92..965af09 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -182,3 +182,9 @@ def test11_timeseries(self): assert(isinstance(self.network.pipes["1"].velocity, float)) assert(isinstance(self.network.pipes.velocity, pd.Series)) + def test12_interactive(self): + # run network + times = [] + for t in self.network.run(interactive=True, store_results=False): + times.append(t) + assert_equal(times, [0, 3600, 7200, 10800, 14400, 18000, 21600, 25200, 28800, 32400, 36000])