Skip to content

Commit

Permalink
Simplified create_classes
Browse files Browse the repository at this point in the history
  • Loading branch information
rlb131 committed Jul 19, 2023
1 parent 0b2eba4 commit 2cca9f6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 30 deletions.
10 changes: 5 additions & 5 deletions sirepo_bluesky/sirepo_ophyd.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ def read(self):
return propagation_read


def create_classes(sirepo_data, connection, create_objects=True, extra_model_fields=[]):
def create_classes(connection, create_objects=True, extra_model_fields=[]):
classes = {}
objects = {}
data = copy.deepcopy(sirepo_data)
data = copy.deepcopy(connection.data)

sim_type = connection.sim_type

Expand Down Expand Up @@ -485,11 +485,11 @@ def create_classes(sirepo_data, connection, create_objects=True, extra_model_fie
cpt_class = SirepoSignal

if "type" in el and el["type"] not in ["undulator", "intensityReport"]:
sirepo_dict = sirepo_data["models"][model_field][i]
sirepo_dict = connection.data["models"][model_field][i]
elif sim_type == "madx" and model_field in ["rpnVariables", "commands"]:
sirepo_dict = sirepo_data["models"][model_field][i]
sirepo_dict = connection.data["models"][model_field][i]
else:
sirepo_dict = sirepo_data["models"][model_field]
sirepo_dict = connection.data["models"][model_field]

components[k] = Cpt(
cpt_class,
Expand Down
34 changes: 14 additions & 20 deletions sirepo_bluesky/tests/test_bl_elements_as_ophyd_objs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def test_beamline_elements_as_ophyd_objects(srw_tes_simulation):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)

for name, obj in objects.items():
pprint.pprint(obj.read())
Expand All @@ -29,7 +29,7 @@ def test_beamline_elements_as_ophyd_objects(srw_tes_simulation):


def test_empty_simulation(srw_empty_simulation):
classes, objects = create_classes(srw_empty_simulation.data, connection=srw_empty_simulation)
classes, objects = create_classes(connection=srw_empty_simulation)
globals().update(**objects)

assert not srw_empty_simulation.data["models"]["beamline"]
Expand All @@ -39,7 +39,7 @@ def test_empty_simulation(srw_empty_simulation):

@pytest.mark.parametrize("method", ["set", "put"])
def test_beamline_elements_set_put(srw_tes_simulation, method):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)
globals().update(**objects)

i = 0
Expand All @@ -66,7 +66,7 @@ def test_beamline_elements_set_put(srw_tes_simulation, method):

@pytest.mark.parametrize("method", ["set", "put"])
def test_crl_calculation(srw_chx_simulation, method):
classes, objects = create_classes(srw_chx_simulation.data, connection=srw_chx_simulation)
classes, objects = create_classes(connection=srw_chx_simulation)
globals().update(**objects)

params_before = copy.deepcopy(crl1.tipRadius._sirepo_dict) # noqa F821
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_crl_calculation(srw_chx_simulation, method):

@pytest.mark.parametrize("method", ["set", "put"])
def test_crystal_calculation(srw_tes_simulation, method):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)
globals().update(**objects)

params_before = copy.deepcopy(mono_crystal1.energy._sirepo_dict) # noqa F821
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_crystal_calculation(srw_tes_simulation, method):

@pytest.mark.parametrize("method", ["set", "put"])
def test_grazing_angle_calculation(srw_tes_simulation, method):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)
globals().update(**objects)

params_before = copy.deepcopy(toroid.grazingAngle._sirepo_dict) # noqa F821
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_grazing_angle_calculation(srw_tes_simulation, method):


def test_beamline_elements_simple_connection(srw_basic_simulation):
classes, objects = create_classes(srw_basic_simulation.data, connection=srw_basic_simulation)
classes, objects = create_classes(connection=srw_basic_simulation)

for name, obj in objects.items():
pprint.pprint(obj.read())
Expand All @@ -203,7 +203,6 @@ def test_beamline_elements_simple_connection(srw_basic_simulation):

def test_srw_source_with_run_engine(RE, db, srw_ari_simulation, num_steps=5):
classes, objects = create_classes(
srw_ari_simulation.data,
connection=srw_ari_simulation,
extra_model_fields=["undulator", "intensityReport"],
)
Expand Down Expand Up @@ -259,7 +258,7 @@ def test_srw_source_with_run_engine(RE, db, srw_ari_simulation, num_steps=5):


def test_srw_propagation_with_run_engine(RE, db, srw_chx_simulation, num_steps=5):
classes, objects = create_classes(srw_chx_simulation.data, connection=srw_chx_simulation)
classes, objects = create_classes(connection=srw_chx_simulation)
globals().update(**objects)

postPropagation.hrange_mod.kind = "hinted" # noqa F821
Expand All @@ -284,7 +283,7 @@ def test_srw_propagation_with_run_engine(RE, db, srw_chx_simulation, num_steps=5


def test_shadow_with_run_engine(RE, db, shadow_tes_simulation, num_steps=5):
classes, objects = create_classes(shadow_tes_simulation.data, connection=shadow_tes_simulation)
classes, objects = create_classes(connection=shadow_tes_simulation)
globals().update(**objects)

aperture.horizontalSize.kind = "hinted" # noqa F821
Expand Down Expand Up @@ -329,7 +328,7 @@ def test_shadow_with_run_engine(RE, db, shadow_tes_simulation, num_steps=5):


def test_beam_statistics_report_only(RE, db, shadow_tes_simulation):
classes, objects = create_classes(shadow_tes_simulation.data, connection=shadow_tes_simulation)
classes, objects = create_classes(connection=shadow_tes_simulation)
globals().update(**objects)

bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation)
Expand Down Expand Up @@ -369,7 +368,7 @@ def test_beam_statistics_report_only(RE, db, shadow_tes_simulation):


def test_beam_statistics_report_and_watchpoint(RE, db, shadow_tes_simulation):
classes, objects = create_classes(shadow_tes_simulation.data, connection=shadow_tes_simulation)
classes, objects = create_classes(connection=shadow_tes_simulation)
globals().update(**objects)

bsr = BeamStatisticsReport(name="bsr", connection=shadow_tes_simulation)
Expand Down Expand Up @@ -404,7 +403,7 @@ def test_beam_statistics_report_and_watchpoint(RE, db, shadow_tes_simulation):
def test_mad_x_elements_set_put(madx_resr_storage_ring_simulation, method):
connection = madx_resr_storage_ring_simulation
data = connection.data
classes, objects = create_classes(data, connection=connection)
classes, objects = create_classes(connection=connection)
globals().update(**objects)

for i, (k, v) in enumerate(objects.items()):
Expand All @@ -426,8 +425,7 @@ def test_mad_x_elements_set_put(madx_resr_storage_ring_simulation, method):

def test_mad_x_elements_simple_connection(madx_bl2_triplet_tdc_simulation):
connection = madx_bl2_triplet_tdc_simulation
data = connection.data
classes, objects = create_classes(data, connection=connection)
classes, objects = create_classes(connection=connection)
for name, obj in objects.items():
pprint.pprint(obj.read())

Expand All @@ -439,8 +437,7 @@ def test_mad_x_elements_simple_connection(madx_bl2_triplet_tdc_simulation):

def test_madx_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation):
connection = madx_bl2_triplet_tdc_simulation
data = connection.data
classes, objects = create_classes(data, connection=connection)
classes, objects = create_classes(connection=connection)
globals().update(**objects)

madx_flyer = MADXFlyer(
Expand Down Expand Up @@ -477,7 +474,6 @@ def test_madx_variables_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation)
connection = madx_bl2_triplet_tdc_simulation
data = connection.data
classes, objects = create_classes(
data,
connection=connection,
extra_model_fields=["rpnVariables"],
)
Expand Down Expand Up @@ -512,7 +508,6 @@ def test_madx_commands_with_run_engine(RE, db, madx_bl2_triplet_tdc_simulation):
connection = madx_bl2_triplet_tdc_simulation
data = connection.data
classes, objects = create_classes(
data,
connection=connection,
extra_model_fields=["commands"],
)
Expand Down Expand Up @@ -548,7 +543,6 @@ def test_madx_variables_and_commands_with_run_engine(RE, db, madx_bl2_triplet_td
connection = madx_bl2_triplet_tdc_simulation
data = connection.data
classes, objects = create_classes(
data,
connection=connection,
extra_model_fields=["rpnVariables", "commands"],
)
Expand Down
10 changes: 5 additions & 5 deletions sirepo_bluesky/tests/test_stateless_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_stateless_compute_crl_characteristics_basic(srw_chx_simulation, RE):
assert diff, "The browser request and expected response match, but are expected to be different."
pprint.pprint(diff)

classes, objects = create_classes(srw_chx_simulation.data, connection=srw_chx_simulation)
classes, objects = create_classes(connection=srw_chx_simulation)

crl1 = objects["crl1"]

Expand All @@ -86,7 +86,7 @@ def test_stateless_compute_crl_characteristics_basic(srw_chx_simulation, RE):


def test_stateless_compute_crystal_orientation_basic(srw_tes_simulation, RE):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)
globals().update(**objects)
mc1 = mono_crystal1 # noqa

Expand Down Expand Up @@ -184,7 +184,7 @@ def test_stateless_compute_crystal_orientation_basic(srw_tes_simulation, RE):

@vcr.use_cassette(f"{cassette_location}/test_crl_characteristics.yml")
def test_stateless_compute_crl_characteristics_advanced(srw_chx_simulation, tmp_path):
classes, objects = create_classes(srw_chx_simulation.data, connection=srw_chx_simulation)
classes, objects = create_classes(connection=srw_chx_simulation)

_generate_test_crl_file(tmp_path / "test_crl_characteristics.json", objects["crl1"], srw_chx_simulation)

Expand All @@ -202,7 +202,7 @@ def test_stateless_compute_crl_characteristics_advanced(srw_chx_simulation, tmp_

@vcr.use_cassette(f"{cassette_location}/test_crystal_characteristics.yml")
def test_stateless_compute_crystal_advanced(srw_tes_simulation, tmp_path):
classes, objects = create_classes(srw_tes_simulation.data, connection=srw_tes_simulation)
classes, objects = create_classes(connection=srw_tes_simulation)

_generate_test_crystal_file(
tmp_path / "test_compute_crystal.json", objects["mono_crystal1"], srw_tes_simulation
Expand All @@ -226,7 +226,7 @@ def test_stateless_compute_crystal_advanced(srw_tes_simulation, tmp_path):


def test_stateless_compute_with_RE(RE, srw_chx_simulation, db):
classes, objects = create_classes(srw_chx_simulation.data, connection=srw_chx_simulation)
classes, objects = create_classes(connection=srw_chx_simulation)
globals().update(**objects)
crl1.tipRadius.kind = "hinted" # noqa
sample.duration.kind = "hinted" # noqa
Expand Down

0 comments on commit 2cca9f6

Please sign in to comment.