Skip to content

Commit 1aecf4d

Browse files
committed
Generate seeds before adding graph components
This means that if you do set the seed for the top-level Network in your model, adding or removing graphs will not affect the seed for the Ensembles in your model. Fixes #855
1 parent 55e730b commit 1aecf4d

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

nengo_gui/page.py

+11
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import nengo_gui
1313
import nengo_gui.user_action
1414
import nengo_gui.config
15+
import nengo_gui.seed_generation
1516

1617

1718
class PageSettings(object):
@@ -434,6 +435,12 @@ def build(self):
434435
with self.lock:
435436
self.building = True
436437

438+
# set all the seeds so that creating components doesn't affect
439+
# the neural model itself
440+
seeds = nengo_gui.seed_generation.define_all_seeds(self.model)
441+
for obj, s in seeds.items():
442+
obj.seed = s
443+
437444
# modify the model for the various Components
438445
for c in self.components:
439446
c.add_nengo_objects(self)
@@ -456,6 +463,10 @@ def build(self):
456463
line = nengo_gui.exec_env.determine_line_number()
457464
self.error = dict(trace=traceback.format_exc(), line=line)
458465

466+
# set the defined seeds back to None
467+
for obj in seeds:
468+
obj.seed = None
469+
459470
self.stdout += exec_env.stdout.getvalue()
460471

461472
if self.sim is not None:

nengo_gui/seed_generation.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import nengo
2+
import nengo.utils.numpy as npext
3+
import numpy as np
4+
5+
6+
def define_all_seeds(net, seeds=None):
7+
if seeds is None:
8+
seeds = {}
9+
10+
if net.seed is None:
11+
if net not in seeds:
12+
# this only happens at the very top level
13+
seeds[net] = np.random.randint(npext.maxint)
14+
rng = np.random.RandomState(seed=seeds[net])
15+
else:
16+
rng = np.random.RandomState(seed=net.seed)
17+
18+
# let's use the same algorithm as the builder, just to be consistent
19+
sorted_types = sorted(net.objects, key=lambda t: t.__name__)
20+
for obj_type in sorted_types:
21+
for obj in net.objects[obj_type]:
22+
# generate a seed for each item, so that manually setting a seed
23+
# for a particular item doesn't change the generated seed for
24+
# other items
25+
generated_seed = rng.randint(npext.maxint)
26+
if obj.seed is None:
27+
seeds[obj] = generated_seed
28+
29+
for subnet in net.networks:
30+
define_all_seeds(subnet, seeds)
31+
32+
return seeds

0 commit comments

Comments
 (0)