Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query parameter filtering to main entities #626

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
18 changes: 17 additions & 1 deletion cli/medperf/commands/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,29 @@ def list(
False, "--unregistered", help="Get unregistered benchmarks"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user benchmarks"),
name: str = typer.Option(None, "--name", help="Filter by name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"),
is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"),
data_prep: int = typer.Option(None, "-d", "--data-preparation-mlcube", help="Filter by Data Preparation MLCube"),
):
"""List benchmarks"""
filters = {
"name": name,
"owner": owner,
"state": state,
"is_valid": is_valid,
"is_active": is_active,
"data_preparation_mlcube": data_prep
}

EntityList.run(
Benchmark,
fields=["UID", "Name", "Description", "State", "Approval Status", "Registered"],
fields=["UID", "Name", "Description", "Data Preparation MLCube", "State", "Approval Status", "Registered"],
unregistered=unregistered,
mine_only=mine,
**filters,
)


Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def list(
mlcube: int = typer.Option(
None, "--mlcube", "-m", help="Get datasets for a given data prep mlcube"
),
name: str = typer.Option(None, "--name", help="Filter by name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_valid: bool = typer.Option(None, "--valid/--invalid", help="Filter by valid status"),
):
"""List datasets"""
EntityList.run(
Expand All @@ -32,6 +36,10 @@ def list(
unregistered=unregistered,
mine_only=mine,
mlcube=mlcube,
name=name,
owner=owner,
state=state,
is_valid=is_valid,
)


Expand Down
3 changes: 2 additions & 1 deletion cli/medperf/commands/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ def run(
Args:
unregistered (bool, optional): Display only local unregistered results. Defaults to False.
mine_only (bool, optional): Display all registered current-user results. Defaults to False.
kwargs (dict): Additional parameters for filtering entity lists.
kwargs (dict): Additional parameters for filtering entity lists. Keys with None will be filtered out.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
entity_list = EntityList(
entity_class, fields, unregistered, mine_only, **kwargs
)
Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/mlcube/mlcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ def list(
False, "--unregistered", help="Get unregistered mlcubes"
),
mine: bool = typer.Option(False, "--mine", help="Get current-user mlcubes"),
name: str = typer.Option(None, "--name", "-n", help="Filter out by MLCube Name"),
owner: int = typer.Option(None, "--owner", help="Filter by owner ID"),
state: str = typer.Option(None, "--state", help="Filter by state (DEVELOPMENT/OPERATION)"),
is_active: bool = typer.Option(None, "--active/--inactive", help="Filter by active status"),
):
"""List mlcubes"""
EntityList.run(
Cube,
fields=["UID", "Name", "State", "Registered"],
unregistered=unregistered,
mine_only=mine,
name=name,
owner=owner,
state=state,
is_active=is_active,
)


Expand Down
8 changes: 8 additions & 0 deletions cli/medperf/commands/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def list(
benchmark: int = typer.Option(
None, "--benchmark", "-b", help="Get results for a given benchmark"
),
model: int = typer.Option(
None, "--owner", "-o", help="Get results for a given model"
),
dataset: int = typer.Option(
None, "--dataset", "-d", help="Get reuslts for a given dataset"
),
):
"""List results"""
EntityList.run(
Expand All @@ -77,6 +83,8 @@ def list(
unregistered=unregistered,
mine_only=mine,
benchmark=benchmark,
model=model,
dataset=dataset,
)


Expand Down
70 changes: 43 additions & 27 deletions cli/medperf/comms/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __get_list(
page_size=config.default_page_size,
offset=0,
binary_reduction=False,
filters={},
):
"""Retrieves a list of elements from a URL by iterating over pages until num_elements is obtained.
If num_elements is None, then iterates until all elements have been retrieved.
Expand All @@ -104,7 +105,9 @@ def __get_list(
num_elements = float("inf")

while len(el_list) < num_elements:
paginated_url = f"{url}?limit={page_size}&offset={offset}"
filters.update({"limit": page_size, "offset": offset})
query_str = "&".join([f"{k}={v}" for k, v in filters.items()])
paginated_url = f"{url}?{query_str}"
res = self.__auth_get(paginated_url)
if res.status_code != 200:
if not binary_reduction:
Expand Down Expand Up @@ -152,13 +155,13 @@ def get_current_user(self):
res = self.__auth_get(f"{self.server_url}/me/")
return res.json()

def get_benchmarks(self) -> List[dict]:
def get_benchmarks(self, filters={}) -> List[dict]:
"""Retrieves all benchmarks in the platform.

Returns:
List[dict]: all benchmarks information.
"""
bmks = self.__get_list(f"{self.server_url}/benchmarks/")
bmks = self.__get_list(f"{self.server_url}/benchmarks/", filters=filters)
return bmks

def get_benchmark(self, benchmark_uid: int) -> dict:
Expand All @@ -179,7 +182,7 @@ def get_benchmark(self, benchmark_uid: int) -> dict:
)
return res.json()

def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]:
def get_benchmark_model_associations(self, benchmark_uid: int, filters={}) -> List[int]:
"""Retrieves all the model associations of a benchmark.

Args:
Expand All @@ -188,25 +191,28 @@ def get_benchmark_model_associations(self, benchmark_uid: int) -> List[int]:
Returns:
list[int]: List of benchmark model associations
"""
assocs = self.__get_list(f"{self.server_url}/benchmarks/{benchmark_uid}/models")
assocs = self.__get_list(
f"{self.server_url}/benchmarks/{benchmark_uid}/models",
filters=filters,
)
return filter_latest_associations(assocs, "model_mlcube")

def get_user_benchmarks(self) -> List[dict]:
def get_user_benchmarks(self, filters={}) -> List[dict]:
"""Retrieves all benchmarks created by the user

Returns:
List[dict]: Benchmarks data
"""
bmks = self.__get_list(f"{self.server_url}/me/benchmarks/")
bmks = self.__get_list(f"{self.server_url}/me/benchmarks/", filters=filters)
return bmks

def get_cubes(self) -> List[dict]:
def get_cubes(self, filters={}) -> List[dict]:
"""Retrieves all MLCubes in the platform

Returns:
List[dict]: List containing the data of all MLCubes
"""
cubes = self.__get_list(f"{self.server_url}/mlcubes/")
cubes = self.__get_list(f"{self.server_url}/mlcubes/", filters=filters)
return cubes

def get_cube_metadata(self, cube_uid: int) -> dict:
Expand All @@ -227,13 +233,13 @@ def get_cube_metadata(self, cube_uid: int) -> dict:
)
return res.json()

def get_user_cubes(self) -> List[dict]:
def get_user_cubes(self, filters={}) -> List[dict]:
"""Retrieves metadata from all cubes registered by the user

Returns:
List[dict]: List of dictionaries containing the mlcubes registration information
"""
cubes = self.__get_list(f"{self.server_url}/me/mlcubes/")
cubes = self.__get_list(f"{self.server_url}/me/mlcubes/", filters=filters)
return cubes

def upload_benchmark(self, benchmark_dict: dict) -> int:
Expand Down Expand Up @@ -268,13 +274,13 @@ def upload_mlcube(self, mlcube_body: dict) -> int:
raise CommunicationRetrievalError(f"Could not upload the mlcube: {details}")
return res.json()

def get_datasets(self) -> List[dict]:
def get_datasets(self, filters={}) -> List[dict]:
"""Retrieves all datasets in the platform

Returns:
List[dict]: List of data from all datasets
"""
dsets = self.__get_list(f"{self.server_url}/datasets/")
dsets = self.__get_list(f"{self.server_url}/datasets/", filters=filters)
return dsets

def get_dataset(self, dset_uid: int) -> dict:
Expand All @@ -295,13 +301,13 @@ def get_dataset(self, dset_uid: int) -> dict:
)
return res.json()

def get_user_datasets(self) -> dict:
def get_user_datasets(self, filters={}) -> dict:
"""Retrieves all datasets registered by the user

Returns:
dict: dictionary with the contents of each dataset registration query
"""
dsets = self.__get_list(f"{self.server_url}/me/datasets/")
dsets = self.__get_list(f"{self.server_url}/me/datasets/", filters=filters)
return dsets

def upload_dataset(self, reg_dict: dict) -> int:
Expand All @@ -320,13 +326,13 @@ def upload_dataset(self, reg_dict: dict) -> int:
raise CommunicationRequestError(f"Could not upload the dataset: {details}")
return res.json()

def get_results(self) -> List[dict]:
def get_results(self, filters={}) -> List[dict]:
"""Retrieves all results

Returns:
List[dict]: List of results
"""
res = self.__get_list(f"{self.server_url}/results")
res = self.__get_list(f"{self.server_url}/results", filters=filters)
if res.status_code != 200:
log_response_error(res)
details = format_errors_dict(res.json())
Expand All @@ -351,16 +357,16 @@ def get_result(self, result_uid: int) -> dict:
)
return res.json()

def get_user_results(self) -> dict:
def get_user_results(self, filters={}) -> dict:
"""Retrieves all results registered by the user

Returns:
dict: dictionary with the contents of each result registration query
"""
results = self.__get_list(f"{self.server_url}/me/results/")
results = self.__get_list(f"{self.server_url}/me/results/", filters=filters)
return results

def get_benchmark_results(self, benchmark_id: int) -> dict:
def get_benchmark_results(self, benchmark_id: int, filters={}) -> dict:
"""Retrieves all results for a given benchmark

Args:
Expand All @@ -370,7 +376,8 @@ def get_benchmark_results(self, benchmark_id: int) -> dict:
dict: dictionary with the contents of each result in the specified benchmark
"""
results = self.__get_list(
f"{self.server_url}/benchmarks/{benchmark_id}/results"
f"{self.server_url}/benchmarks/{benchmark_id}/results",
filters=filters,
)
return results

Expand Down Expand Up @@ -472,22 +479,28 @@ def set_mlcube_association_approval(
f"Could not approve association between mlcube {mlcube_uid} and benchmark {benchmark_uid}: {details}"
)

def get_datasets_associations(self) -> List[dict]:
def get_datasets_associations(self, filters={}) -> List[dict]:
"""Get all dataset associations related to the current user

Returns:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/datasets/associations/")
assocs = self.__get_list(
f"{self.server_url}/me/datasets/associations/",
filters=filters,
)
return filter_latest_associations(assocs, "dataset")

def get_cubes_associations(self) -> List[dict]:
def get_cubes_associations(self, filters={}) -> List[dict]:
"""Get all cube associations related to the current user

Returns:
List[dict]: List containing all associations information
"""
assocs = self.__get_list(f"{self.server_url}/me/mlcubes/associations/")
assocs = self.__get_list(
f"{self.server_url}/me/mlcubes/associations/",
filters=filters,
)
return filter_latest_associations(assocs, "model_mlcube")

def set_mlcube_association_priority(
Expand Down Expand Up @@ -519,7 +532,7 @@ def update_dataset(self, dataset_id: int, data: dict):
raise CommunicationRequestError(f"Could not update dataset: {details}")
return res.json()

def get_mlcube_datasets(self, mlcube_id: int) -> dict:
def get_mlcube_datasets(self, mlcube_id: int, filters={}) -> dict:
"""Retrieves all datasets that have the specified mlcube as the prep mlcube

Args:
Expand All @@ -529,7 +542,10 @@ def get_mlcube_datasets(self, mlcube_id: int) -> dict:
dict: dictionary with the contents of each dataset
"""

datasets = self.__get_list(f"{self.server_url}/mlcubes/{mlcube_id}/datasets/")
datasets = self.__get_list(
f"{self.server_url}/mlcubes/{mlcube_id}/datasets/",
filters=filters,
)
return datasets

def get_user(self, user_id: int) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion cli/medperf/entities/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def all(
@classmethod
def __remote_all(cls: Type[EntityType], filters: dict) -> List[EntityType]:
comms_fn = cls.remote_prefilter(filters)
entity_meta = comms_fn()
entity_meta = comms_fn(filters)
entities = [cls(**meta) for meta in entity_meta]
return entities

Expand Down
1 change: 1 addition & 0 deletions cli/medperf/entities/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def remote_prefilter(filters: dict) -> callable:
comms_fn = config.comms.get_results
if "owner" in filters and filters["owner"] == get_medperf_user_data()["id"]:
comms_fn = config.comms.get_user_results
del filters["owner"]
if "benchmark" in filters and filters["benchmark"] is not None:
bmk = filters["benchmark"]

Expand Down
Loading
Loading