Skip to content

Commit

Permalink
Raise an error when a variable is not found in the reference sample (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yias authored Dec 9, 2024
1 parent 9d80677 commit b2fb5a6
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 38 deletions.
106 changes: 80 additions & 26 deletions src/ansys/simai/core/data/model_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __validate_value(self, val: float):
if val < 0 and self.position != "absolute":
raise InvalidArguments(
f"{self.__class__.__name__}: 'value' must be a positive number when the position is not 'absolute'.",
) from None
)

def __set_length(self, lgth: float):
self.__validate_length(lgth)
Expand All @@ -105,7 +105,7 @@ def __validate_length(self, lgth: float):
if not lgth > 0:
raise InvalidArguments(
f"{self.__class__.__name__}: 'length' must be a positive number.",
) from None
)

position: Literal["relative_to_min", "relative_to_max", "relative_to_center", "absolute"]
value: float = property(__get_value, __set_value)
Expand Down Expand Up @@ -289,10 +289,7 @@ class ModelConfiguration:

project: "Optional[Project]" = None
build_on_top: bool = False
input: ModelInput = field(default_factory=lambda: ModelInput())
output: ModelOutput = field(default_factory=lambda: ModelOutput())
domain_of_analysis: DomainOfAnalysis = field(default_factory=lambda: DomainOfAnalysis())
pp_input: PostProcessInput = field(default_factory=lambda: PostProcessInput())

def __set_gc(self, gcs: list[GlobalCoefficientDefinition]):
verified_gcs = []
Expand All @@ -302,7 +299,7 @@ def __set_gc(self, gcs: list[GlobalCoefficientDefinition]):
if self.project is None:
raise ProcessingError(
f"{self.__class__.__name__}: a project must be defined for setting global coefficients."
) from None
)

self.project.verify_gc_formula(
gc_unit.formula, self.input.boundary_conditions, self.output.surface
Expand Down Expand Up @@ -330,6 +327,76 @@ def __get_build_preset(self):

build_preset = property(__get_build_preset, __set_build_preset)

def __validate_variables(self, vars_to_validate: list[str], var_type: str):
sample_metadata = self.project.sample.fields.get("extracted_metadata")
var_fields = dict_get(sample_metadata, var_type, "fields", default=[])
var_fields_name = {fd.get("name") for fd in var_fields}
unknown_variables = set(vars_to_validate) - var_fields_name
if unknown_variables:
raise ProcessingError(
f"{self.__class__.__name__}: {var_type} variables {', '.join(unknown_variables)} do not exist in the reference sample."
)

def __validate_surface_variables(self, vars_to_validate: list[str]):
sample_metadata = self.project.sample.fields.get("extracted_metadata")
if not sample_metadata.get("surface"):
raise ProcessingError(
"No surface field is found in the reference sample. A surface field is required to use surface variables."
)
self.__validate_variables(vars_to_validate, "surface")

def __validate_volume_variables(self, vars_to_validate: list[str]):
sample_metadata = self.project.sample.fields.get("extracted_metadata")
if not sample_metadata.get("volume"):
raise ProcessingError(
"No volume field is found in the reference sample. A volume field is required to use volume variables."
)
self.__validate_variables(vars_to_validate, "volume")

def __set_input(self, model_input: ModelInput):
if not model_input:
raise InvalidArguments(
"Invalid value for input; input should be an instance of ModelInput."
)
if model_input.surface:
self.__validate_surface_variables(model_input.surface)
self.__dict__["input"] = model_input

def __get_input(self):
return self.__dict__.get("input")

input = property(__get_input, __set_input)

def __set_output(self, model_output: ModelOutput):
if not model_output:
raise InvalidArguments(
"Invalid value for output; output should be an instance of ModelOutput."
)
if model_output.surface:
self.__validate_surface_variables(model_output.surface)

if model_output.volume:
self.__validate_volume_variables(model_output.volume)

self.__dict__["output"] = model_output

def __get_output(self):
return self.__dict__.get("output")

output = property(__get_output, __set_output)

def __set_pp_input(self, pp_input: PostProcessInput):
if not pp_input:
pp_input = PostProcessInput()
if pp_input.surface:
self.__validate_surface_variables(pp_input.surface)
self.__dict__["pp_input"] = pp_input

def __get_pp_input(self):
return self.__dict__.get("pp_input")

pp_input = property(__get_pp_input, __set_pp_input)

def __init__(
self,
project: "Project",
Expand Down Expand Up @@ -408,44 +475,31 @@ def _to_payload(self):
bcs = {bc_name: {} for bc_name in self.input.boundary_conditions}

sample_metadata = self.project.sample.fields.get("extracted_metadata")
surface_fields = dict_get(sample_metadata, "surface", "fields", default=[])
volume_fields = dict_get(sample_metadata, "volume", "fields", default=[])

surface_input_fld = []
if self.input.surface is not None:
surface_input_fld = [
fd
for fd in sample_metadata.get("surface").get("fields")
if fd.get("name") in self.input.surface
fd for fd in surface_fields if fd.get("name") in self.input.surface
]

surface_fld = []
if self.output.surface is not None:
surface_fld = [
fd
for fd in sample_metadata.get("surface").get("fields")
if fd.get("name") in self.output.surface
]
surface_fld = [fd for fd in surface_fields if fd.get("name") in self.output.surface]

volume_fld = []
if self.output.volume:
if not sample_metadata.get("volume"):
raise ProcessingError(
"No volume file was found in the reference sample. A volume file is required to use volume variables."
) from None
volume_fld = [
fd
for fd in sample_metadata["volume"].get("fields")
if fd.get("name") in self.output.volume
]
volume_fld = [fd for fd in volume_fields if fd.get("name") in self.output.volume]

gcs = []
if self.global_coefficients is not None:
gcs = [asdict(gc) for gc in self.global_coefficients]

surface_pp_input_fld = []
if self.pp_input.surface is not None:
suface_fields = dict_get(sample_metadata, "surface", "fields", default=[])
surface_pp_input_fld = [
fd for fd in suface_fields if fd.get("name") in self.pp_input.surface
fd for fd in surface_fields if fd.get("name") in self.pp_input.surface
]

flds = {
Expand Down Expand Up @@ -487,7 +541,7 @@ def compute_global_coefficient(self) -> List[float]:
if self.project is None:
raise ProcessingError(
f"{self.__class__.__name__}: a project must be a defined for computing the global coefficient formula."
) from None
)

return [
self.project.compute_gc_formula(
Expand Down
38 changes: 38 additions & 0 deletions tests/test_model_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,41 @@ def test_build_preset_error(simai_client):
)

assert f"{list(SupportedBuildPresets)}" in excinfo.value


def test_model_input_not_none(simai_client):
"""WHEN ModelConfiguration.input gets a None value
THEN an InvalidArgument is raised."""

raw_project = {
"id": "xX007Xx",
"name": "fifi",
}

project = simai_client._project_directory._model_from(raw_project)

bld_conf = ModelConfiguration(
project=project,
)

with pytest.raises(InvalidArguments):
bld_conf.input = None


def test_model_output_not_none(simai_client):
"""WHEN ModelConfiguration.input gets a None value
THEN an InvalidArgument is raised."""

raw_project = {
"id": "xX007Xx",
"name": "fifi",
}

project = simai_client._project_directory._model_from(raw_project)

bld_conf = ModelConfiguration(
project=project,
)

with pytest.raises(InvalidArguments):
bld_conf.output = None
58 changes: 46 additions & 12 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,12 @@ def test_throw_error_when_volume_is_missing_from_sample(simai_client):
THEN a ProcessingError is thrown.
"""

sample_raw = deepcopy(SAMPLE_RAW)

raw_project = {
"id": MODEL_RAW["project_id"],
"name": "fifi",
"sample": SAMPLE_RAW,
"sample": sample_raw,
}

raw_project["sample"]["extracted_metadata"].pop("volume")
Expand All @@ -520,17 +522,15 @@ def test_throw_error_when_volume_is_missing_from_sample(simai_client):
model_output = ModelOutput(surface=[], volume=["Velocity_0"])
global_coefficients = []

model_conf = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
continuous=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
)

with pytest.raises(ProcessingError):
model_conf._to_payload()
_ = ModelConfiguration._from_payload(
project=project,
build_preset="debug",
continuous=False,
input=model_input,
output=model_output,
global_coefficients=global_coefficients,
)


@responses.activate
Expand Down Expand Up @@ -561,7 +561,7 @@ def test_post_process_input(simai_client):
project: Project = simai_client._project_directory._model_from(raw_project)
project.verify_gc_formula = Mock()

pp_input = PostProcessInput(surface=["Temperature_1"])
pp_input = PostProcessInput(surface=["TurbulentViscosity"])

model_conf_dict = deepcopy(MODEL_CONF_RAW)
model_conf_dict["fields"].pop("volume")
Expand Down Expand Up @@ -672,3 +672,37 @@ def test_failed_build_with_resolution(simai_client):
with pytest.raises(ApiClientError) as e:
simai_client.models.build(new_conf)
assert "This is a resolution." in str(e.value)


@responses.activate
def test_throw_error_when_unknown_variables(simai_client):
"""WHEN input/output/pp_input variables are not found in the reference sample
THEN a ProcessingError is raised.
"""

raw_project = {
"id": MODEL_RAW["project_id"],
"name": "pp_newnew",
"sample": SAMPLE_RAW,
}

project: Project = simai_client._project_directory._model_from(raw_project)

mdl_config = ModelConfiguration(project=project)

unknown_vars = ["abc1", "abc2"]

with pytest.raises(ProcessingError) as e:
mdl_config.input = ModelInput(surface=unknown_vars)
for ukn_var in unknown_vars:
assert ukn_var in str(e.value)

with pytest.raises(ProcessingError) as e:
mdl_config.output = ModelOutput(surface=unknown_vars)
for ukn_var in unknown_vars:
assert ukn_var in str(e.value)

with pytest.raises(ProcessingError) as e:
mdl_config.output = ModelOutput(volume=unknown_vars)
for ukn_var in unknown_vars:
assert ukn_var in str(e.value)

0 comments on commit b2fb5a6

Please sign in to comment.