diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a794c8889..e957ab5c3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.2.0 + rev: v5.0.0 hooks: - id: check-yaml - id: check-json @@ -8,19 +8,19 @@ repos: - id: trailing-whitespace args: [--markdown-linebreak-ext=md] - repo: https://github.com/ambv/black - rev: 24.4.2 + rev: 24.10.0 hooks: - id: black language: python - repo: https://github.com/PyCQA/pylint - rev: v3.2.2 + rev: v3.3.1 hooks: - id: pylint args: - --disable=duplicate-code,import-error - --ignore-patterns=test_ - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.1.1 hooks: - id: flake8 args: @@ -28,7 +28,7 @@ repos: - --max-line-length=90 - --exclude=tests,test*.py - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: isort (python) @@ -43,7 +43,7 @@ repos: - --convention=numpy - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.7.0 + rev: v0.8.0 hooks: # Run the linter. 
- id: ruff diff --git a/msticpy/context/azure/azure_data.py b/msticpy/context/azure/azure_data.py index 8c048aef9..2eb3cbd68 100644 --- a/msticpy/context/azure/azure_data.py +++ b/msticpy/context/azure/azure_data.py @@ -818,21 +818,22 @@ def get_network_details( details.network_security_group.id.split("/")[8], ) nsg_rules = [] - for nsg in nsg_details.default_security_rules: - rules = asdict( - NsgItems( - rule_name=nsg.name, - description=nsg.description, - protocol=str(nsg.protocol), - direction=str(nsg.direction), - src_ports=nsg.source_port_range, - dst_ports=nsg.destination_port_range, - src_addrs=nsg.source_address_prefix, - dst_addrs=nsg.destination_address_prefix, - action=str(nsg.access), - ), - ) - nsg_rules.append(rules) + if nsg_details is not None: + for nsg in nsg_details.default_security_rules: # type: ignore + rules = asdict( + NsgItems( + rule_name=nsg.name, + description=nsg.description, + protocol=str(nsg.protocol), + direction=str(nsg.direction), + src_ports=nsg.source_port_range, + dst_ports=nsg.destination_port_range, + src_addrs=nsg.source_address_prefix, + dst_addrs=nsg.destination_address_prefix, + action=str(nsg.access), + ), + ) + nsg_rules.append(rules) nsg_df = pd.DataFrame(nsg_rules) diff --git a/msticpy/context/azure/sentinel_utils.py b/msticpy/context/azure/sentinel_utils.py index cc1c6530f..179b659c2 100644 --- a/msticpy/context/azure/sentinel_utils.py +++ b/msticpy/context/azure/sentinel_utils.py @@ -15,7 +15,7 @@ import pandas as pd from azure.common.exceptions import CloudError from azure.mgmt.core import tools as az_tools -from typing_extensions import Self +from typing_extensions import Dict, Self, cast from ..._version import VERSION from ...auth.azure_auth_core import AzureCloudConfig @@ -339,7 +339,9 @@ def parse_resource_id(res_id: str) -> dict[str, Any]: """Extract components from workspace resource ID.""" if not res_id.startswith("/"): res_id = f"/{res_id}" - res_id_parts: dict[str, Any] = 
az_tools.parse_resource_id(res_id) + res_id_parts: Dict[str, str] = cast( + Dict[str, str], az_tools.parse_resource_id(res_id) + ) workspace_name: str | None = None if ( res_id_parts.get("namespace") == "Microsoft.OperationalInsights" diff --git a/msticpy/context/azure/sentinel_watchlists.py b/msticpy/context/azure/sentinel_watchlists.py index df8af7e80..a730366e3 100644 --- a/msticpy/context/azure/sentinel_watchlists.py +++ b/msticpy/context/azure/sentinel_watchlists.py @@ -224,12 +224,14 @@ def add_watchlist_item( axis=1, copy=False, ) - if (current_df == item_series).all(axis=1).any() and overwrite: + if (current_df == item_series).all( + axis=1 + ).any() and overwrite: # type: ignore[attr-defined] watchlist_id: str = current_items[ current_items.isin(list(new_item.values())).any(axis=1) ]["properties.watchlistItemId"].iloc[0] # If not in watchlist already generate new ID - elif not (current_df == item_series).all(axis=1).any(): + elif not (current_df == item_series).all(axis=1).any(): # type: ignore[attr-defined] watchlist_id = str(uuid4()) else: err_msg = "Item already exists in the watchlist. Set overwrite = True to replace." 
diff --git a/msticpy/context/contextlookup.py b/msticpy/context/contextlookup.py index f197fbbfa..631f2f352 100644 --- a/msticpy/context/contextlookup.py +++ b/msticpy/context/contextlookup.py @@ -171,7 +171,7 @@ async def _lookup_observables_async( # pylint:disable=too-many-arguments # noqa ) -> pd.DataFrame: """Lookup items async.""" return await self._lookup_items_async( - data, + data, # type: ignore[arg-type] item_col=obs_col, item_type_col=obs_type_col, query_type=query_type, diff --git a/msticpy/context/geoip.py b/msticpy/context/geoip.py index 1243db783..ed1b0dd87 100644 --- a/msticpy/context/geoip.py +++ b/msticpy/context/geoip.py @@ -588,7 +588,7 @@ def lookup_ip( geo_match = self._get_geomatch_non_public(ip_type) elif self._reader: try: - geo_match = self._reader.city(ip_input).raw + geo_match = self._reader.city(ip_input).raw # type: ignore except (AddressNotFoundError, AttributeError, ValueError): continue if geo_match: diff --git a/msticpy/context/ip_utils.py b/msticpy/context/ip_utils.py index 0f8796a12..3744530eb 100644 --- a/msticpy/context/ip_utils.py +++ b/msticpy/context/ip_utils.py @@ -201,8 +201,11 @@ def create_ip_record( ip_entity.SubscriptionId = ip_hb["SubscriptionId"] geoloc_entity: GeoLocation = GeoLocation() geoloc_entity.CountryOrRegionName = ip_hb["RemoteIPCountry"] - geoloc_entity.Longitude = ip_hb["RemoteIPLongitude"] - geoloc_entity.Latitude = ip_hb["RemoteIPLatitude"] + try: + geoloc_entity.Longitude = float(ip_hb["RemoteIPLongitude"]) + geoloc_entity.Latitude = float(ip_hb["RemoteIPLatitude"]) + except TypeError: + pass ip_entity.Location = geoloc_entity # If Azure network data present add this to host record @@ -493,7 +496,7 @@ def ip_whois( for ip_addr in ip: if rate_limit: sleep(query_rate) - whois_results[ip_addr] = _whois_lookup( + whois_results[ip_addr] = _whois_lookup( # type: ignore[index] ip_addr, raw=raw, retry_count=retry_count, diff --git a/msticpy/context/provider_base.py b/msticpy/context/provider_base.py index 
b7348849e..1f5bed6c5 100644 --- a/msticpy/context/provider_base.py +++ b/msticpy/context/provider_base.py @@ -20,7 +20,7 @@ from asyncio import get_event_loop from collections.abc import Iterable as C_Iterable from functools import lru_cache, partial, singledispatch -from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generator, Iterable +from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generator, Iterable, cast import pandas as pd from typing_extensions import Self @@ -368,7 +368,7 @@ def resolve_item_type(item: str) -> str: async def _lookup_items_async_wrapper( # pylint: disable=too-many-arguments # noqa: PLR0913 self: Self, - data: pd.DataFrame | dict[str, str] | list[str], + data: pd.DataFrame | dict[str, str] | Iterable[str], item_col: str | None = None, item_type_col: str | None = None, query_type: str | None = None, @@ -395,7 +395,7 @@ async def _lookup_items_async_wrapper( # pylint: disable=too-many-arguments # n If not specified the default record type for the IoitemC type will be returned. prog_counter: ProgressCounter; Optional - Progress Counter to display progess of IOC searches. + Progress Counter to display progress of IOC searches. 
Returns ------- @@ -413,7 +413,7 @@ async def _lookup_items_async_wrapper( # pylint: disable=too-many-arguments # n ) result: pd.DataFrame = await event_loop.run_in_executor(None, get_items) if prog_counter: - await prog_counter.decrement(len(data)) + await prog_counter.decrement(len(data)) # type: ignore[arg-type] return result @@ -466,7 +466,7 @@ def generate_items( if isinstance(data, C_Iterable): for item in data: - yield item, Provider.resolve_item_type(item) + yield cast(str, item), Provider.resolve_item_type(item) else: yield None, None diff --git a/msticpy/context/tilookup.py b/msticpy/context/tilookup.py index 7d02fa0be..869975ea0 100644 --- a/msticpy/context/tilookup.py +++ b/msticpy/context/tilookup.py @@ -233,7 +233,7 @@ async def _lookup_iocs_async( # pylint: disable=too-many-arguments #noqa:PLR091 ) -> pd.DataFrame: """Lookup IoCs async.""" return await self._lookup_items_async( - data, + data, # type: ignore[arg-type] item_col=ioc_col, item_type_col=ioc_type_col, query_type=ioc_query_type, diff --git a/msticpy/context/vtlookupv3/vtfile_behavior.py b/msticpy/context/vtlookupv3/vtfile_behavior.py index 41a548af8..4d9694839 100644 --- a/msticpy/context/vtlookupv3/vtfile_behavior.py +++ b/msticpy/context/vtlookupv3/vtfile_behavior.py @@ -146,7 +146,9 @@ def __init__( file_summary = file_summary.iloc[0] if isinstance(file_summary, pd.Series): file_summary = file_summary.to_dict() - self.file_summary: dict[str, Any] = file_summary or {} + self.file_summary: pd.Series | dict[str, pd.Timestamp] | dict[str, Any] = ( + file_summary or {} + ) self.file_id: str | Any | None = file_id or self.file_summary.get("id") self._file_behavior: dict[str, Any] = {} @@ -406,7 +408,7 @@ def _try_match_commandlines( and np.isnan(row["cmd_line"]) and row["name"] in cmd ): - procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore[reportCallIssue] + procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore break for cmd in command_executions: for idx, row in procs_cmd.iterrows(): @@ 
-416,7 +418,7 @@ def _try_match_commandlines( and Path(row["name"]).stem.lower() in cmd.lower() ): weak_matches += 1 - procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore[reportCallIssue] + procs_cmd.loc[idx, "cmd_line"] = cmd # type: ignore break if weak_matches: diff --git a/msticpy/context/vtlookupv3/vtlookup.py b/msticpy/context/vtlookupv3/vtlookup.py index 4d1dd9848..5762db6e1 100644 --- a/msticpy/context/vtlookupv3/vtlookup.py +++ b/msticpy/context/vtlookupv3/vtlookup.py @@ -823,7 +823,7 @@ def _add_invalid_input_result( new_row["Status"] = status new_row["SourceIndex"] = source_idx new_results: pd.DataFrame = self.results.append( - new_row.to_dict(), + new_row.to_dict(), # type: ignore[operator] ignore_index=True, ) diff --git a/msticpy/context/vtlookupv3/vtlookupv3.py b/msticpy/context/vtlookupv3/vtlookupv3.py index 352cecca1..58c4a168b 100644 --- a/msticpy/context/vtlookupv3/vtlookupv3.py +++ b/msticpy/context/vtlookupv3/vtlookupv3.py @@ -1222,7 +1222,7 @@ def _get_vt_api_key() -> str | None: def timestamps_to_utcdate(data: pd.DataFrame) -> pd.DataFrame: """Replace Unix timestamps in VT data with Py/pandas Timestamp.""" - columns: pd.Index[str] = data.columns + columns: pd.Index = data.columns for date_col in ( col for col in columns if isinstance(col, str) and col.endswith("_date") ): diff --git a/msticpy/data/drivers/cybereason_driver.py b/msticpy/data/drivers/cybereason_driver.py index 2750b010d..7334e1c94 100644 --- a/msticpy/data/drivers/cybereason_driver.py +++ b/msticpy/data/drivers/cybereason_driver.py @@ -398,6 +398,7 @@ def _create_paginated_query_tasks( def __execute_query( self, body: Dict[str, Any], + *, page: int = 0, page_size: int = 2000, pagination_token: str = None, @@ -406,6 +407,8 @@ def __execute_query( """ Run query with pagination enabled. 
+ :raises httpx.HTTPStatusError: if max_retry reached + + Parameters ---------- body: Dict[str, Any] @@ -416,6 +419,8 @@ Page number to query pagination_token: str Token of the current search + max_retry: int + Maximum retries in case of API non-success response Returns ------- @@ -449,6 +454,14 @@ json_result = response.json() status = json_result["status"] cur_try += 1 + + if cur_try >= max_retry: + raise httpx.HTTPStatusError( + f"{status}: {json_result['message']}", + request=response.request, + response=response, + ) + + return json_result async def __run_threaded_queries( diff --git a/mypy.ini b/mypy.ini index e31cc92db..021d7e8a0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -143,3 +143,12 @@ ignore_missing_imports = True [mypy-requests.*] ignore_missing_imports = True + +[mypy-autogen.*] +ignore_missing_imports = True + +[mypy-notebookutils.*] +ignore_missing_imports = True + +[mypy-mo_sql_parsing.*] +ignore_missing_imports = True diff --git a/tests/data/drivers/test_cybereason_driver.py b/tests/data/drivers/test_cybereason_driver.py index 0dc8f8d66..317b91aad 100644 --- a/tests/data/drivers/test_cybereason_driver.py +++ b/tests/data/drivers/test_cybereason_driver.py @@ -12,6 +12,7 @@ import pytest import pytest_check as check import respx +import httpx from msticpy.data.core.query_defns import Formatters from msticpy.data.drivers.cybereason_driver import CybereasonDriver @@ -140,6 +141,46 @@ }, ] +_CR_PARTIAL_SUCCESS_RESULT = { + "data": { + "resultIdToElementDataMap": {}, + "suspicionsMap": {}, + "evidenceMap": {}, + "totalResults": 0, + "totalPossibleResults": 0, + "guessedPossibleResults": 0, + "queryLimits": { + "totalResultLimit": 1000, + "perGroupLimit": 100, + "perFeatureLimit": 100, + "groupingFeature": { + "elementInstanceType": "Process", + "featureName": "imageFileHash", + }, + "sortInGroupFeature": None, + }, + "queryTerminated": False, + "pathResultCounts": None, + "guids": [], + "paginationToken": None, 
"executionUUID": None, + "quapiMeasurementData": { + "timeToGetGuids": [], + "timeToGetData": [], + "timeToGetAdditionalData": [], + "totalQuapiQueryTime": [], + "startTime": [], + "endTime": [], + }, + }, + "status": "PARTIAL_SUCCESS", + "hidePartialSuccess": False, + "message": "Received Non-OK status code HTTP/1.1 500 Internal Server Error", + "expectedResults": 0, + "failures": 0, + "failedServersInfo": None, +} + _CR_QUERY = { "query": """ { @@ -241,6 +282,31 @@ def test_query(driver): check.is_true(query.called) check.is_instance(data, pd.DataFrame) +@respx.mock +def test_partial_success_query(driver): + """Test query calling returns data in expected format.""" + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login\.html") + ).respond(200) + query = respx.post( + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ) + ) + query.side_effect = [ + httpx.Response(200, json=_CR_PARTIAL_SUCCESS_RESULT), + httpx.Response(200, json=_CR_PARTIAL_SUCCESS_RESULT), + httpx.Response(200, json=_CR_PARTIAL_SUCCESS_RESULT), + httpx.Response(200, json=_CR_PARTIAL_SUCCESS_RESULT), + ] + with custom_mp_config(MP_PATH): + with pytest.raises(httpx.HTTPStatusError, match=r"PARTIAL_SUCCESS:.*"): + driver.connect() + driver.query('{"test": "test"}') + + check.is_true(connect.called or driver.connected) + check.is_true(query.called) + check.equal(query.call_count, 3) @respx.mock def test_paginated_query(driver):