diff --git a/caveclient/materializationengine.py b/caveclient/materializationengine.py
index a74f5f92..25e83fc1 100644
--- a/caveclient/materializationengine.py
+++ b/caveclient/materializationengine.py
@@ -53,7 +53,7 @@ def deserialize_query_response(response):
         )
     else:
         raise ValueError(
-            f'Unknown response type: {response.headers.get("Content-Type")}'
+            f"Unknown response type: {response.headers.get('Content-Type')}"
         )


@@ -228,6 +228,42 @@ def __init__(
         self._tables = None
         self._views = None

+    @cached(cache=TTLCache(maxsize=1, ttl=60 * 60 * 1))
+    def available_versions(self, datastack_name=None) -> list[int]:
+        """Get the available versions for this materialization client."""
+        return sorted(self.get_versions(expired=False, datastack_name=datastack_name))
+
+    def _materialization_available(
+        self, materialization_version, datastack_name
+    ) -> bool:
+        "Check if a materialization version is available to query."
+        return materialization_version in self.available_versions(
+            datastack_name=datastack_name
+        )
+
+    def _assign_datastack(self, datastack_name: Optional[str]) -> str:
+        """Assign the datastack name to the client."""
+        if datastack_name is None:
+            return self.datastack_name
+        else:
+            return datastack_name
+
+    def _assign_version(self, version: Optional[int]) -> int:
+        """Assign the version to the client."""
+        if version is None:
+            return self.version
+        else:
+            return int(version)
+
+    def _assign_desired_resolution(
+        self, desired_resolution: Optional[Iterable[float]]
+    ) -> Iterable[float]:
+        """Assign the desired resolution to the client."""
+        if desired_resolution is None:
+            return self.desired_resolution
+        else:
+            return desired_resolution
+
     @property
     def datastack_name(self):
         """The name of the datastack."""
@@ -315,8 +351,7 @@ def get_versions(self, datastack_name=None, expired=False) -> dict:
         dict
             Dictionary of versions available
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         url = self._endpoints["versions"].format_map(endpoint_mapping)
@@ -343,10 +378,8 @@ def get_tables(self, datastack_name=None, version: Optional[int] = None) -> list
         list
             List of table names
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["version"] = version
@@ -357,10 +390,31 @@ def get_tables(self, datastack_name=None, version: Optional[int] = None) -> list
         return response.json()

     def get_annotation_count(self, table_name: str, datastack_name=None, version=None):
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        """
+        Get the count of annotations in a table for a given datastack and version
+
+        Parameters
+        ----------
+        table_name : str
+            Name of the table to get the count for
+        datastack_name : str or None, optional
+            Name of the datastack, by default None.
+            If None, uses the one specified in the client.
+        version : int or None, optional
+            The version of the datastack to query. If None, will query the client
+            `version`, which defaults to the most recent version.
+
+        Returns
+        -------
+        int
+            Count of annotations in the table
+        """
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
+        if not self._materialization_available(version, datastack_name=datastack_name):
+            raise ValueError(
+                f"Annotation count must use a materialized version ({self.available_versions(datastack_name)})."
+            )
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["table_name"] = table_name
@@ -390,10 +444,8 @@ def get_version_metadata(
         dict
             Dictionary of metadata about the version
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
@@ -406,6 +458,7 @@ def get_version_metadata(
         d["expires_on"] = convert_timestamp(d["expires_on"])
         return d

+    @cached(cache=TTLCache(maxsize=50, ttl=60 * 60 * 24))
     def get_timestamp(
         self, version: Optional[int] = None, datastack_name: str = None
     ) -> datetime:
@@ -424,6 +477,8 @@ def get_timestamp(
         datetime.datetime
             Datetime when the materialization version was frozen.
         """
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
         meta = self.get_version_metadata(version=version, datastack_name=datastack_name)
         return convert_timestamp(meta["time_stamp"])

@@ -444,8 +499,7 @@ def get_versions_metadata(self, datastack_name=None, expired=False) -> list[dict
            List of metadata dictionaries

         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         url = self._endpoints["versions_metadata"].format_map(endpoint_mapping)
@@ -485,10 +539,8 @@ def get_table_metadata(
             Metadata dictionary for table

         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["table_name"] = table_name
@@ -739,8 +791,18 @@ def query_table(
             A pandas dataframe of results of query

         """
-        if desired_resolution is None:
-            desired_resolution = self.desired_resolution
+        desired_resolution = self._assign_desired_resolution(desired_resolution)
+        datastack_name = self._assign_datastack(datastack_name)
+        materialization_version = self._assign_version(materialization_version)
+        if not self._materialization_available(
+            materialization_version=materialization_version,
+            datastack_name=datastack_name,
+        ):
+            # Treat as timestamp query on the materialization version's timestamp
+            timestamp = self.get_timestamp(
+                version=materialization_version, datastack_name=datastack_name
+            )
+            materialization_version = None
         if timestamp is not None:
             if materialization_version is not None:
                 raise ValueError("cannot specify timestamp and materialization version")
@@ -769,102 +831,107 @@ def query_table(
                 random_sample=random_sample,
                 log_warning=log_warning,
             )
-        if materialization_version is None:
-            materialization_version = self.version
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-
-        tables, suffix_map = self._resolve_merge_reference(
-            merge_reference, table, datastack_name, materialization_version
-        )
+        else:
+            tables, suffix_map = self._resolve_merge_reference(
+                merge_reference, table, datastack_name, materialization_version
+            )

-        url, data, query_args, encoding = self._format_query_components(
-            datastack_name,
-            materialization_version,
-            tables,
-            select_columns,
-            suffix_map,
-            {table: filter_in_dict} if filter_in_dict is not None else None,
-            {table: filter_out_dict} if filter_out_dict is not None else None,
-            {table: filter_equal_dict} if filter_equal_dict is not None else None,
-            {table: filter_greater_dict} if filter_greater_dict is not None else None,
-            {table: filter_less_dict} if filter_less_dict is not None else None,
-            {table: filter_greater_equal_dict}
-            if filter_greater_equal_dict is not None
-            else None,
-            {table: filter_less_equal_dict}
-            if filter_less_equal_dict is not None
-            else None,
-            {table: filter_spatial_dict} if filter_spatial_dict is not None else None,
-            {table: filter_regex_dict} if filter_regex_dict is not None else None,
-            return_df,
-            True,
-            offset,
-            limit,
-            desired_resolution,
-            random_sample=random_sample,
-        )
-        if get_counts:
-            query_args["count"] = True
+            url, data, query_args, encoding = self._format_query_components(
+                datastack_name,
+                materialization_version,
+                tables,
+                select_columns,
+                suffix_map,
+                {table: filter_in_dict} if filter_in_dict is not None else None,
+                {table: filter_out_dict} if filter_out_dict is not None else None,
+                {table: filter_equal_dict} if filter_equal_dict is not None else None,
+                {table: filter_greater_dict}
+                if filter_greater_dict is not None
+                else None,
+                {table: filter_less_dict} if filter_less_dict is not None else None,
+                {table: filter_greater_equal_dict}
+                if filter_greater_equal_dict is not None
+                else None,
+                {table: filter_less_equal_dict}
+                if filter_less_equal_dict is not None
+                else None,
+                {table: filter_spatial_dict}
+                if filter_spatial_dict is not None
+                else None,
+                {table: filter_regex_dict} if filter_regex_dict is not None else None,
+                return_df,
+                True,
+                offset,
+                limit,
+                desired_resolution,
+                random_sample=random_sample,
+            )
+            if get_counts:
+                query_args["count"] = True

-        response = self.session.post(
-            url,
-            data=json.dumps(data, cls=BaseEncoder),
-            headers={"Content-Type": "application/json", "Accept-Encoding": encoding},
-            params=query_args,
-            stream=~return_df,
-        )
-        self.raise_for_status(response, log_warning=log_warning)
-        if return_df:
-            with warnings.catch_warnings():
-                warnings.simplefilter(action="ignore", category=FutureWarning)
-                warnings.simplefilter(action="ignore", category=DeprecationWarning)
-                df = deserialize_query_response(response)
-                if desired_resolution is not None:
-                    if not response.headers.get("dataframe_resolution", None):
-                        if len(desired_resolution) != 3:
-                            raise ValueError(
-                                "desired resolution needs to be of length 3, for xyz"
+            response = self.session.post(
+                url,
+                data=json.dumps(data, cls=BaseEncoder),
+                headers={
+                    "Content-Type": "application/json",
+                    "Accept-Encoding": encoding,
+                },
+                params=query_args,
+                stream=~return_df,
+            )
+            self.raise_for_status(response, log_warning=log_warning)
+            if return_df:
+                with warnings.catch_warnings():
+                    warnings.simplefilter(action="ignore", category=FutureWarning)
+                    warnings.simplefilter(action="ignore", category=DeprecationWarning)
+                    df = deserialize_query_response(response)
+                    if desired_resolution is not None:
+                        if not response.headers.get("dataframe_resolution", None):
+                            if len(desired_resolution) != 3:
+                                raise ValueError(
+                                    "desired resolution needs to be of length 3, for xyz"
+                                )
+                            vox_res = self.get_table_metadata(
+                                table,
+                                datastack_name,
+                                materialization_version,
+                                log_warning=False,
+                            )["voxel_resolution"]
+                            df = convert_position_columns(
+                                df, vox_res, desired_resolution
                             )
-                        vox_res = self.get_table_metadata(
-                            table,
-                            datastack_name,
-                            materialization_version,
-                            log_warning=False,
-                        )["voxel_resolution"]
-                        df = convert_position_columns(df, vox_res, desired_resolution)
-            if metadata:
-                attrs = self._assemble_attributes(
-                    tables,
-                    filters={
-                        "inclusive": filter_in_dict,
-                        "exclusive": filter_out_dict,
-                        "equal": filter_equal_dict,
-                        "greater": filter_greater_dict,
-                        "less": filter_less_dict,
-                        "greater_equal": filter_greater_equal_dict,
-                        "less_equal": filter_less_equal_dict,
-                        "spatial": filter_spatial_dict,
-                        "regex": filter_regex_dict,
-                    },
-                    select_columns=select_columns,
-                    offset=offset,
-                    limit=limit,
-                    live_query=timestamp is not None,
-                    timestamp=string_format_timestamp(timestamp),
-                    materialization_version=materialization_version,
-                    desired_resolution=response.headers.get(
-                        "dataframe_resolution", desired_resolution
-                    ),
-                    column_names=response.headers.get("column_names", None),
-                )
-                df.attrs.update(attrs)
-            if split_positions:
-                return df
+                if metadata:
+                    attrs = self._assemble_attributes(
+                        tables,
+                        filters={
+                            "inclusive": filter_in_dict,
+                            "exclusive": filter_out_dict,
+                            "equal": filter_equal_dict,
+                            "greater": filter_greater_dict,
+                            "less": filter_less_dict,
+                            "greater_equal": filter_greater_equal_dict,
+                            "less_equal": filter_less_equal_dict,
+                            "spatial": filter_spatial_dict,
+                            "regex": filter_regex_dict,
+                        },
+                        select_columns=select_columns,
+                        offset=offset,
+                        limit=limit,
+                        live_query=timestamp is not None,
+                        timestamp=string_format_timestamp(timestamp),
+                        materialization_version=materialization_version,
+                        desired_resolution=response.headers.get(
+                            "dataframe_resolution", desired_resolution
+                        ),
+                        column_names=response.headers.get("column_names", None),
+                    )
+                    df.attrs.update(attrs)
+                if split_positions:
+                    return df
+                else:
+                    return concatenate_position_columns(df, inplace=True)
             else:
-                return concatenate_position_columns(df, inplace=True)
-        else:
-            return response.json()
+                return response.json()

     @_check_version_compatibility(
         kwarg_use_constraints={
@@ -973,12 +1040,15 @@ def join_query(
             a pandas dataframe of results of query

         """
-        if materialization_version is None:
-            materialization_version = self.version
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if desired_resolution is None:
-            desired_resolution = self.desired_resolution
+        materialization_version = self._assign_version(materialization_version)
+        datastack_name = self._assign_datastack(datastack_name)
+        desired_resolution = self._assign_desired_resolution(desired_resolution)
+        if not self._materialization_available(
+            materialization_version, datastack_name=datastack_name
+        ):
+            raise ValueError(
+                f"Cannot use `join_query` for a non-materialized version. Please use a materialized version ({self.available_versions(datastack_name)}) or live_live_query."
+            )
         url, data, query_args, encoding = self._format_query_components(
             datastack_name,
             materialization_version,
@@ -1178,7 +1248,7 @@ def _update_rootids(self, df: pd.DataFrame, timestamp: datetime, future_map: dic
                 all_svid_lengths.append(n_svids)
                 logger.info(f"{sv_col} has {n_svids} to update")
                 all_svids = np.append(all_svids, svids[~is_latest_root])
-        logger.info(f"num zero svids: {np.sum(all_svids==0)}")
+        logger.info(f"num zero svids: {np.sum(all_svids == 0)}")
         logger.info(f"all_svids dtype {all_svids.dtype}")
         logger.info(f"all_svid_lengths {all_svid_lengths}")
         with MyTimeIt("get_roots"):
@@ -1224,8 +1294,7 @@ def ingest_annotation_table(
         dict
             Status code of response from server
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
@@ -1258,8 +1327,7 @@ def lookup_supervoxel_ids(
         dict
             Status code of response from server
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)

         if annotation_ids is not None:
             data = {"annotation_ids": annotation_ids}
@@ -1386,10 +1454,8 @@ def live_query(
         if self.cg_client is None:
             raise ValueError("You must have a cg_client to run live_query")

-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if desired_resolution is None:
-            desired_resolution = self.desired_resolution
+        datastack_name = self._assign_datastack(datastack_name)
+        desired_resolution = self._assign_desired_resolution(desired_resolution)
         with MyTimeIt("find_mat_version"):
             # we want to find the most recent materialization
             # in which the timestamp given is in the future
@@ -1882,10 +1948,9 @@ def get_tables_metadata(
         :
             Metadata dictionary for table
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
+
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["version"] = version
@@ -2070,10 +2135,11 @@ def live_live_query(
             "Deprecation: this method is to facilitate beta testing of this feature, \
                 it will likely get removed in future versions. "
         )
+        datastack_name = self._assign_datastack(datastack_name)
+        desired_resolution = self._assign_desired_resolution(desired_resolution)
+
         timestamp = convert_timestamp(timestamp)
         return_df = True
-        if datastack_name is None:
-            datastack_name = self.datastack_name

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
@@ -2119,8 +2185,6 @@ def live_live_query(
             data["limit"] = limit
         if suffixes is not None:
             data["suffixes"] = suffixes
-        if desired_resolution is None:
-            desired_resolution = self.desired_resolution
         if Version(str(self.api_version)) >= Version("3"):
             if desired_resolution is not None:
                 data["desired_resolution"] = desired_resolution
@@ -2221,10 +2285,13 @@ def get_views(
         list
             List of views
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if version is None:
-            version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        version = self._assign_version(version)
+        if not self._materialization_available(version, datastack_name=datastack_name):
+            raise ValueError(
+                f"Materialization version must not be expired for views. "
+                f"Available versions: {self.available_versions(datastack_name)}"
+            )
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["version"] = version
@@ -2258,10 +2325,16 @@ def get_view_metadata(
         dict
             Metadata of view
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if materialization_version is None:
-            materialization_version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        materialization_version = self._assign_version(materialization_version)
+
+        if not self._materialization_available(
+            materialization_version, datastack_name=datastack_name
+        ):
+            raise ValueError(
+                f"Materialization version must not be expired for view metadata query. "
+                f"Available versions: {self.available_versions(datastack_name)}"
+            )

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["view_name"] = view_name
@@ -2298,10 +2371,15 @@ def get_view_schema(
         dict
             Schema of view.
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if materialization_version is None:
-            materialization_version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        materialization_version = self._assign_version(materialization_version)
+        if not self._materialization_available(
+            materialization_version, datastack_name=datastack_name
+        ):
+            raise ValueError(
+                f"Materialization version must not be expired for view schema query. "
+                f"Available versions: {self.available_versions(datastack_name)}"
+            )

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["view_name"] = view_name
@@ -2334,10 +2412,16 @@ def get_view_schemas(
         dict
             Schema of view.
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
-        if materialization_version is None:
-            materialization_version = self.version
+        datastack_name = self._assign_datastack(datastack_name)
+        materialization_version = self._assign_version(materialization_version)
+
+        if not self._materialization_available(
+            materialization_version, datastack_name=datastack_name
+        ):
+            raise ValueError(
+                f"Materialization version must not be expired for view schema query. "
+                f"Available versions: {self.available_versions(datastack_name)}"
+            )

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
@@ -2446,12 +2530,16 @@ def query_view(
             A pandas dataframe of results of query

         """
-        if desired_resolution is None:
-            desired_resolution = self.desired_resolution
-        if materialization_version is None:
-            materialization_version = self.version
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)
+        materialization_version = self._assign_version(materialization_version)
+        desired_resolution = self._assign_desired_resolution(desired_resolution)
+        if not self._materialization_available(
+            materialization_version, datastack_name=datastack_name
+        ):
+            raise ValueError(
+                "Materialization version must not be expired for view query. "
+                f"Available versions: {self.available_versions(datastack_name=datastack_name)}"
+            )

         url, data, query_args, encoding = self._format_query_components(
             datastack_name,
@@ -2552,8 +2640,7 @@ def get_unique_string_values(
         dict[str]
             A dictionary of column names and their unique values
         """
-        if datastack_name is None:
-            datastack_name = self.datastack_name
+        datastack_name = self._assign_datastack(datastack_name)

         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
diff --git a/caveclient/tools/testing.py b/caveclient/tools/testing.py
index b4759922..1b9e5923 100644
--- a/caveclient/tools/testing.py
+++ b/caveclient/tools/testing.py
@@ -27,6 +27,7 @@
 TEST_LOCAL_SERVER = os.environ.get("TEST_LOCAL_SERVER", "https://local.cave.com")
 TEST_DATASTACK = os.environ.get("TEST_DATASTACK", "test_stack")
 DEFAULT_MATERIALIZATION_VERSONS = [1, 2]
+DEFAULT_EXPIRED_MATERIALIZATION_VERSONS = [3, 4]

 DEFAULT_MATERIALIZATION_VERSION_METADATA = {
     "time_stamp": "2024-06-05T10:10:01.203215",
@@ -254,6 +255,7 @@ def CAVEclientMock(
     materialization_server_version: str = DEFAULT_MATERIALIZATION_SERVER_VERSON,
     materialization_api_versions: Optional[list] = None,
     available_materialization_versions: Optional[list] = None,
+    expired_materialization_versions: Optional[list] = None,
     set_version: Optional[int] = None,
     set_version_metadata: Optional[dict] = None,
     json_service: bool = False,
@@ -300,6 +302,10 @@ def CAVEclientMock(
     available_materialization_versions : list, optional
         List of materialization database versions that the materialization client thinks exists, by default None.
         If None, returns the value in DEFAULT_MATERIALIZATION_VERSONS.
+    expired_materialization_versions : list, optional
+        List of materialization database versions that the materialization client thinks are expired, by default None.
+        If None, returns the value in DEFAULT_EXPIRED_MATERIALIZATION_VERSONS.
+        If the same value is in both available and expired, it is treated as available.
     materialization_api_versions : list, optional
         List of materialization API versions that the materialization client thinks exists, by default None.
         If None, returns the value in MATERIALIZATION_API_VERSIONS.
@@ -396,6 +402,8 @@ def fancier_test_client():
         info_file = default_info(local_server)
     if available_materialization_versions is None:
         available_materialization_versions = DEFAULT_MATERIALIZATION_VERSONS
+    if expired_materialization_versions is None:
+        expired_materialization_versions = DEFAULT_EXPIRED_MATERIALIZATION_VERSONS
     if set_version_metadata is None:
         set_version_metadata = DEFAULT_MATERIALIZATION_VERSION_METADATA
     if chunkedgraph_api_versions is None:
@@ -404,6 +412,14 @@ def fancier_test_client():
         materialization_api_versions = MATERIALIZATION_API_VERSIONS
     if schema_api_versions is None:
         schema_api_versions = SCHEMA_API_VERSIONS
+    all_materialization_versions = sorted(
+        list(
+            set(
+                available_materialization_versions
+                + expired_materialization_versions
+            )
+        )
+    )

     @responses.activate()
     def mockedCAVEclient():
@@ -479,11 +495,17 @@ def mockedCAVEclient():
             responses.add(
                 responses.GET,
                 mat_version_list_url,
-                json=available_materialization_versions,
+                json=all_materialization_versions,
                 status=200,
                 match=[query_param_matcher({"expired": True})],
             )
-
+            responses.add(
+                responses.GET,
+                mat_version_list_url,
+                json=available_materialization_versions,
+                status=200,
+                match=[query_param_matcher({"expired": False})],
+            )
             if set_version is not None:
                 mat_mapping["version"] = set_version
                 version_metadata_url = mat_endpoints["version_metadata"].format_map(
@@ -533,6 +555,15 @@ def mockedCAVEclient():
             client.chunkedgraph
         if materialization:
             client.materialize
+            responses.add(
+                responses.GET,
+                mat_version_list_url,
+                json=available_materialization_versions,
+                status=200,
+                match=[query_param_matcher({"expired": False})],
+            )
+            client.materialize.available_versions(datastack_name=datastack_name)
+            client.materialize.version
         if json_service:
             client.state
         if skeleton_service:
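Reviewer sketch (not part of the patch): the mock now serves both the expired=True and expired=False flavors of the versions endpoint and primes materialize.available_versions() during setup. Roughly how a test might construct it; keyword names mirror the conftest fixtures and should be checked against the actual CAVEclientMock signature and defaults:

from caveclient.tools.testing import CAVEclientMock, TEST_DATASTACK

client = CAVEclientMock(
    datastack_name=TEST_DATASTACK,
    materialization=True,
    available_materialization_versions=[1, 2],
    expired_materialization_versions=[3, 4],
)

# Setup already hit the mocked `expired=False` endpoint, so the cached helper
# should report only the unexpired versions; get_versions(expired=True) would
# need its own registered response inside a test's @responses.activate block.
assert client.materialize.available_versions(datastack_name=TEST_DATASTACK) == [1, 2]
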
f"{datastack_dict['local_server']}/materialize/api/v3/datastack/{datastack_name}/version/{materialization_version}/views/{view_name}/metadata" # Mock the response @@ -892,7 +884,7 @@ def test_get_view_metadata(myclient): @responses.activate def test_get_unique_string_values(myclient): - datastack_name = "test_datastack" + datastack_name = datastack_dict["datastack_name"] table_name = "test_table" url = f"{datastack_dict['local_server']}/materialize/api/v3/datastack/{datastack_name}/table/{table_name}/unique_string_values"