Skip to content

Commit

Permalink
Reformatted static API and renamed new method
Browse files Browse the repository at this point in the history
  • Loading branch information
mturk24 committed Oct 29, 2024
1 parent 1cdc0dd commit 5015254
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions cleanlab_studio/studio/enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def _response_timestamp_to_datetime(timestamp_string: str) -> datetime:


@lru_cache(maxsize=None)
def _get_online_inference():
from cleanlab_studio.utils.data_enrichment.enrich import online_inference
return online_inference
def _get_run_online():
from cleanlab_studio.utils.data_enrichment.enrich import run_online

return run_online


class EnrichmentProject:
Expand Down Expand Up @@ -349,9 +350,11 @@ def list_all_jobs(self) -> List[EnrichmentJob]:
id=job["id"],
status=job["status"],
created_at=_response_timestamp_to_datetime(job["created_at"]),
updated_at=_response_timestamp_to_datetime(job["updated_at"])
if job["updated_at"]
else None,
updated_at=(
_response_timestamp_to_datetime(job["updated_at"])
if job["updated_at"]
else None
),
enrichment_options=EnrichmentOptions(**enrichment_options_dict), # type: ignore
average_trustworthiness_score=job["average_trustworthiness_score"],
job_type=job["type"],
Expand Down Expand Up @@ -406,7 +409,7 @@ def resume(self) -> JSONDict:
latest_job = self._get_latest_job()
return api.resume_enrichment_job(api_key=self._api_key, job_id=latest_job["id"])

def online_inference(
def run_online(
self,
data: Union[pd.DataFrame, List[dict]],
options: EnrichmentOptions,
Expand All @@ -423,8 +426,8 @@ def online_inference(
Returns:
Dict[str, Any]: A dictionary containing information about the enrichment job and the enriched dataset.
"""
online_inference = _get_online_inference()
job_info = online_inference(data, options, new_column_name, None)
run_online = _get_run_online()
job_info = run_online(data, options, new_column_name, self._api_key)
return job_info


Expand Down Expand Up @@ -639,7 +642,3 @@ def _handle_replacements_and_extraction_pattern(
else:
raise ValueError(REGEX_PARAMETER_ERROR_MESSAGE)
return extraction_pattern, replacements




0 comments on commit 5015254

Please sign in to comment.