Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raw_confidence_score property for predictions #87

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/ansys/simai/core/data/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,26 @@ def post(self) -> PredictionPostProcessings:
return self._post_processings

@property
def confidence_score(self) -> Optional[str]:
    """Confidence score, which is either ``high`` or ``low``.

    This property blocks until the confidence score is computed.

    Returns:
        ``"high"``, ``"low"``, or ``None`` when the server provided
        no score for this prediction.

    Raises:
        ValueError: If the server returned a score outside the
            allowed set of values.
    """
    self.wait()
    confidence_score = self.fields["confidence_score"]
    # None means "no score computed"; anything else must be a known label.
    if confidence_score is not None and confidence_score not in ("high", "low"):
        raise ValueError("Must be None or one of: 'high', 'low'.")
    return confidence_score

@property
def raw_confidence_score(self) -> Optional[float]:
    """Raw confidence score as a float, rounded to two decimal places.

    This property blocks until the confidence score is computed.

    Returns:
        The raw score rounded to 2 decimal places, or the stored value
        unchanged (typically ``None``) when it is not a float.
    """
    self.wait()
    raw_score = self.fields["raw_confidence_score"]
    # Non-float values (e.g. None when no score exists) pass through untouched.
    return round(raw_score, 2) if isinstance(raw_score, float) else raw_score

def delete(self) -> None:
"""Remove a prediction from the server."""
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def _factory(post_processings=None, geometry=None, **kwargs) -> Prediction:
kwargs["geometry_id"] = geometry.id
kwargs.setdefault("boundary_conditions", {"Vx": 10.01, "Vy": 0.0009})
kwargs.setdefault("state", "successful")
kwargs.setdefault("confidence_score", None)
kwargs.setdefault("raw_confidence_score", None)
prediction = simai_client._prediction_directory._model_from(kwargs)
if post_processings is not None:
# If we passed post-processings as parameter,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ def test_sse_event_prediction_success(sse_mixin, prediction_factory):
assert pred.is_pending
# Mock a SSE success event
updated_record = pred.fields.copy()
updated_record.update({"state": "successful", "confidence_score": "abysmal"})
updated_record.update({"state": "successful", "confidence_score": "high"})
sse_mixin._handle_sse_event(
create_sse_event(
f'{{"target": {{"id": "{pred.id}", "type": "prediction"}}, "type": "job", "record": {json.dumps(updated_record)}}}'
)
)
assert not pred.is_pending
assert pred.is_ready
assert pred.confidence_score == "abysmal"
assert pred.confidence_score == "high"


def test_sse_event_update_prediction_failure(sse_mixin, prediction_factory):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import pytest
import responses

from ansys.simai.core.data.geometries import Geometry
Expand Down Expand Up @@ -190,3 +191,20 @@ def test_run_no_bc(simai_client, geometry_factory):
)
geometry = geometry_factory(id="geom-0")
simai_client.predictions.run(geometry.id)


@responses.activate
def test_confidence_score(prediction_factory):
    """WHEN accessing a Prediction's confidence score properties
    THEN the corresponding values are returned, and an invalid
    server-provided label raises a ValueError.
    """
    # A prediction with a valid label and a raw score: raw score is
    # exposed rounded to two decimal places.
    scored = prediction_factory(confidence_score="high", raw_confidence_score=0.94107)
    assert scored.confidence_score == "high"
    assert scored.raw_confidence_score == 0.94

    # A prediction with no scores at all: both properties yield None.
    unscored = prediction_factory()
    assert unscored.confidence_score is None
    assert unscored.raw_confidence_score is None

    # An out-of-vocabulary label is rejected on access.
    invalid = prediction_factory(confidence_score="abysmal")
    with pytest.raises(ValueError) as exc:
        assert invalid.confidence_score
    assert str(exc.value) == "Must be None or one of: 'high', 'low'."
Loading