Skip to content

Commit

Permalink
raw_confidence_score property for predictions (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
kliment-slice authored Sep 23, 2024
1 parent 43cfc3d commit 3e999f2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
16 changes: 14 additions & 2 deletions src/ansys/simai/core/data/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,25 @@ def post(self) -> PredictionPostProcessings:
return self._post_processings

@property
def confidence_score(self) -> str:
def confidence_score(self) -> Optional[str]:
"""Confidence score, which is either ``high`` or ``low``.
This method blocks until the confidence score is computed.
"""
self.wait()
return self.fields["confidence_score"]
confidence_score = self.fields["confidence_score"]
if confidence_score not in ["high", "low", None]:
raise ValueError("Must be None or one of: 'high', 'low', None.")
return confidence_score

@property
def raw_confidence_score(self) -> Optional[float]:
"""Raw confidence score, a float.
This method blocks until the confidence score is computed.
"""
self.wait()
return self.fields["raw_confidence_score"]

def delete(self) -> None:
"""Remove a prediction from the server."""
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def _factory(post_processings=None, geometry=None, **kwargs) -> Prediction:
kwargs["geometry_id"] = geometry.id
kwargs.setdefault("boundary_conditions", {"Vx": 10.01, "Vy": 0.0009})
kwargs.setdefault("state", "successful")
kwargs.setdefault("confidence_score", None)
kwargs.setdefault("raw_confidence_score", None)
prediction = simai_client._prediction_directory._model_from(kwargs)
if post_processings is not None:
# If we passed post-processings as parameter,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ def test_sse_event_prediction_success(sse_mixin, prediction_factory):
assert pred.is_pending
# Mock a SSE success event
updated_record = pred.fields.copy()
updated_record.update({"state": "successful", "confidence_score": "abysmal"})
updated_record.update({"state": "successful", "confidence_score": "high"})
sse_mixin._handle_sse_event(
create_sse_event(
f'{{"target": {{"id": "{pred.id}", "type": "prediction"}}, "type": "job", "record": {json.dumps(updated_record)}}}'
)
)
assert not pred.is_pending
assert pred.is_ready
assert pred.confidence_score == "abysmal"
assert pred.confidence_score == "high"


def test_sse_event_update_prediction_failure(sse_mixin, prediction_factory):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import pytest
import responses

from ansys.simai.core.data.geometries import Geometry
Expand Down Expand Up @@ -190,3 +191,21 @@ def test_run_no_bc(simai_client, geometry_factory):
)
geometry = geometry_factory(id="geom-0")
simai_client.predictions.run(geometry.id)


@responses.activate
def test_confidence_score(prediction_factory):
"""WHEN accessing a Prediction's confidence score properties
THEN the corresponding values are returned
"""
prediction = prediction_factory(confidence_score="high", raw_confidence_score=0.94107)
empty_prediction = prediction_factory()
bad_prediction = prediction_factory(confidence_score="abysmal")

assert prediction.confidence_score == "high"
assert prediction.raw_confidence_score == 0.94107
assert empty_prediction.confidence_score is None
assert empty_prediction.raw_confidence_score is None
with pytest.raises(ValueError) as exc:
assert bad_prediction.confidence_score
assert str(exc.value) == "Must be None or one of: 'high', 'low', None."

0 comments on commit 3e999f2

Please sign in to comment.