diff --git a/doc/changelog.d/5050.fixed.md b/doc/changelog.d/5050.fixed.md new file mode 100644 index 00000000000..302054a5d53 --- /dev/null +++ b/doc/changelog.d/5050.fixed.md @@ -0,0 +1 @@ +Invalid state and TypeError bug in \`_AllowedDomainNames.valid_name\` diff --git a/src/ansys/fluent/core/services/solution_variables.py b/src/ansys/fluent/core/services/solution_variables.py index 655bef2f5b8..68c544376e9 100644 --- a/src/ansys/fluent/core/services/solution_variables.py +++ b/src/ansys/fluent/core/services/solution_variables.py @@ -327,6 +327,19 @@ def __init__(self, zone_name: str, allowed_values: List[str]): ) +class DomainError(ValueError): + """Exception class for errors in domain name.""" + + def __init__(self, domain_name: str, allowed_values: List[str]): + """Initialize DomainError.""" + self.domain_name = domain_name + super().__init__( + allowed_name_error_message( + context="domain", trial_name=domain_name, allowed_values=allowed_values + ) + ) + + class _AllowedNames: def is_valid(self, name): """Check whether a given name is valid or not.""" @@ -422,15 +435,16 @@ def valid_name(self, domain_name): Raises ------ - ZoneError + DomainError If the given domain name is invalid. """ - if not self.is_valid(domain_name): - raise ZoneError( - domain_name=domain_name, - allowed_values=self(), - ) - return self._zones_info.domain_id(domain_name) + domain_id = self._zones_info.domain_id(domain_name) + if domain_id is not None: + return domain_id + raise DomainError( + domain_name=domain_name, + allowed_values=self(), + ) class _SvarMethod: diff --git a/tests/test_solution_variables.py b/tests/test_solution_variables.py index 37ca9c588c4..7a87952c7f0 100644 --- a/tests/test_solution_variables.py +++ b/tests/test_solution_variables.py @@ -25,9 +25,64 @@ from ansys.fluent.core import examples from ansys.fluent.core.examples.downloads import download_file +from ansys.fluent.core.services.solution_variables import ( + DomainError, + _AllowedDomainNames, +) from ansys.units.variable_descriptor import VariableCatalog +class _DummyZonesInfo: + """Mock ZonesInfo for unit testing _AllowedDomainNames.""" + + def __init__(self, domains, domain_id_map): + self._domains = domains + self._domain_id_map = domain_id_map + + @property + def domains(self): + return self._domains + + def domain_id(self, name): + return self._domain_id_map.get(name, None) + + +class _DummySVInfo: + """Mock SolutionVariableInfo for unit testing _AllowedDomainNames.""" + + def __init__(self, zones_info): + self._zones_info = zones_info + + def get_zones_info(self): + return self._zones_info + + +def test_allowed_domain_names_valid_name_returns_domain_id(): + zones_info = _DummyZonesInfo(domains=["mixture"], domain_id_map={"mixture": 1}) + allowed = _AllowedDomainNames(_DummySVInfo(zones_info)) + assert allowed.valid_name("mixture") == 1 + + +def test_allowed_domain_names_valid_name_domain_id_zero(): + zones_info = _DummyZonesInfo(domains=["mixture"], domain_id_map={"mixture": 0}) + allowed = _AllowedDomainNames(_DummySVInfo(zones_info)) + assert allowed.valid_name("mixture") == 0 + + +def test_allowed_domain_names_valid_name_raises_on_missing_domain_id(): + zones_info = _DummyZonesInfo(domains=["mixture"], domain_id_map={}) + allowed = _AllowedDomainNames(_DummySVInfo(zones_info)) + with pytest.raises(DomainError): + allowed.valid_name("mixture") + + +def test_allowed_domain_names_valid_name_raises_on_invalid_domain(): + zones_info = _DummyZonesInfo(domains=["mixture"], domain_id_map={"mixture": 1}) + allowed = _AllowedDomainNames(_DummySVInfo(zones_info)) + with pytest.raises(DomainError): + allowed.valid_name("nonexistent") + + @pytest.mark.fluent_version(">=23.2") def test_solution_variables(new_solver_session): solver = new_solver_session