Cybereason driver fix http429 tests and exception #803

Merged: 11 commits, Nov 23, 2024
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
@@ -1,34 +1,34 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.2.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: check-json
         exclude: .*devcontainer.json
       - id: trailing-whitespace
         args: [--markdown-linebreak-ext=md]
   - repo: https://github.com/ambv/black
-    rev: 24.4.2
+    rev: 24.10.0
     hooks:
       - id: black
         language: python
   - repo: https://github.com/PyCQA/pylint
-    rev: v3.2.2
+    rev: v3.3.1
     hooks:
       - id: pylint
         args:
           - --disable=duplicate-code,import-error
           - --ignore-patterns=test_
   - repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
+    rev: 7.1.1
     hooks:
       - id: flake8
         args:
           - --extend-ignore=E401,E501,W503
           - --max-line-length=90
           - --exclude=tests,test*.py
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
         name: isort (python)
@@ -43,7 +43,7 @@ repos:
           - --convention=numpy
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.7.0
+    rev: v0.8.0
     hooks:
       # Run the linter.
       - id: ruff
31 changes: 16 additions & 15 deletions msticpy/context/azure/azure_data.py
@@ -818,21 +818,22 @@ def get_network_details(
             details.network_security_group.id.split("/")[8],
         )
         nsg_rules = []
-        for nsg in nsg_details.default_security_rules:
-            rules = asdict(
-                NsgItems(
-                    rule_name=nsg.name,
-                    description=nsg.description,
-                    protocol=str(nsg.protocol),
-                    direction=str(nsg.direction),
-                    src_ports=nsg.source_port_range,
-                    dst_ports=nsg.destination_port_range,
-                    src_addrs=nsg.source_address_prefix,
-                    dst_addrs=nsg.destination_address_prefix,
-                    action=str(nsg.access),
-                ),
-            )
-            nsg_rules.append(rules)
+        if nsg_details is not None:
+            for nsg in nsg_details.default_security_rules:  # type: ignore
+                rules = asdict(
+                    NsgItems(
+                        rule_name=nsg.name,
+                        description=nsg.description,
+                        protocol=str(nsg.protocol),
+                        direction=str(nsg.direction),
+                        src_ports=nsg.source_port_range,
+                        dst_ports=nsg.destination_port_range,
+                        src_addrs=nsg.source_address_prefix,
+                        dst_addrs=nsg.destination_address_prefix,
+                        action=str(nsg.access),
+                    ),
+                )
+                nsg_rules.append(rules)
 
         nsg_df = pd.DataFrame(nsg_rules)
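
Note: the None-guard matters because the NSG lookup can come back empty; when it does, nsg_rules stays an empty list and pd.DataFrame(nsg_rules) still produces a valid (empty) frame for the code downstream. A minimal standalone sketch of the pattern, with illustrative names:

    import pandas as pd

    nsg_details = None  # stand-in for a failed or absent NSG lookup
    nsg_rules: list = []
    if nsg_details is not None:
        for rule in nsg_details.default_security_rules:
            nsg_rules.append({"rule_name": rule.name})

    nsg_df = pd.DataFrame(nsg_rules)  # empty DataFrame when nothing was found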
6 changes: 4 additions & 2 deletions msticpy/context/azure/sentinel_utils.py
@@ -15,7 +15,7 @@
 import pandas as pd
 from azure.common.exceptions import CloudError
 from azure.mgmt.core import tools as az_tools
-from typing_extensions import Self
+from typing_extensions import Dict, Self, cast
 
 from ..._version import VERSION
 from ...auth.azure_auth_core import AzureCloudConfig
@@ -339,7 +339,9 @@ def parse_resource_id(res_id: str) -> dict[str, Any]:
     """Extract components from workspace resource ID."""
     if not res_id.startswith("/"):
         res_id = f"/{res_id}"
-    res_id_parts: dict[str, Any] = az_tools.parse_resource_id(res_id)
+    res_id_parts: Dict[str, str] = cast(
+        Dict[str, str], az_tools.parse_resource_id(res_id)
+    )
     workspace_name: str | None = None
     if (
         res_id_parts.get("namespace") == "Microsoft.OperationalInsights"
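
Note: cast() is a static-typing aid only — it changes nothing at runtime and simply tells mypy what to treat the untyped return of az_tools.parse_resource_id as. A minimal sketch of the pattern, using a hypothetical stand-in for the untyped function:

    from typing import Dict, cast

    def untyped_parse(res_id: str):  # stand-in for az_tools.parse_resource_id
        return {"subscription": res_id.split("/")[2]}

    # cast() asserts the type to the checker; no conversion or check happens.
    parts = cast(Dict[str, str], untyped_parse("/subscriptions/abc123/resourceGroups/rg1"))
    print(parts["subscription"])  # abc123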
6 changes: 4 additions & 2 deletions msticpy/context/azure/sentinel_watchlists.py
@@ -224,12 +224,14 @@ def add_watchlist_item(
             axis=1,
             copy=False,
         )
-        if (current_df == item_series).all(axis=1).any() and overwrite:
+        if (current_df == item_series).all(
+            axis=1
+        ).any() and overwrite:  # type: ignore[attr-defined]
             watchlist_id: str = current_items[
                 current_items.isin(list(new_item.values())).any(axis=1)
             ]["properties.watchlistItemId"].iloc[0]
         # If not in watchlist already generate new ID
-        elif not (current_df == item_series).all(axis=1).any():
+        elif not (current_df == item_series).all(axis=1).any():  # type: ignore[attr-defined]
             watchlist_id = str(uuid4())
         else:
             err_msg = "Item already exists in the watchlist. Set overwrite = True to replace."
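
Note: the (current_df == item_series).all(axis=1).any() idiom works because comparing a DataFrame to a Series broadcasts the Series across rows, aligned on column labels: .all(axis=1) flags rows that match the candidate item in every column, and .any() reports whether any full match exists. A small self-contained example (sample data invented for illustration):

    import pandas as pd

    current_df = pd.DataFrame(
        {"ioc": ["1.2.3.4", "evil.com"], "type": ["ipaddress", "dns"]}
    )
    item_series = pd.Series({"ioc": "evil.com", "type": "dns"})

    row_matches = (current_df == item_series).all(axis=1)  # per-row equality
    print(row_matches.any())  # True - the second row matches on every column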
2 changes: 1 addition & 1 deletion msticpy/context/contextlookup.py
@@ -171,7 +171,7 @@ async def _lookup_observables_async(  # pylint:disable=too-many-arguments # noqa
     ) -> pd.DataFrame:
         """Lookup items async."""
         return await self._lookup_items_async(
-            data,
+            data,  # type: ignore[arg-type]
             item_col=obs_col,
             item_type_col=obs_type_col,
             query_type=query_type,
2 changes: 1 addition & 1 deletion msticpy/context/geoip.py
@@ -588,7 +588,7 @@ def lookup_ip(
                 geo_match = self._get_geomatch_non_public(ip_type)
             elif self._reader:
                 try:
-                    geo_match = self._reader.city(ip_input).raw
+                    geo_match = self._reader.city(ip_input).raw  # type: ignore
                 except (AddressNotFoundError, AttributeError, ValueError):
                     continue
             if geo_match:
9 changes: 6 additions & 3 deletions msticpy/context/ip_utils.py
@@ -201,8 +201,11 @@ def create_ip_record(
     ip_entity.SubscriptionId = ip_hb["SubscriptionId"]
     geoloc_entity: GeoLocation = GeoLocation()
     geoloc_entity.CountryOrRegionName = ip_hb["RemoteIPCountry"]
-    geoloc_entity.Longitude = ip_hb["RemoteIPLongitude"]
-    geoloc_entity.Latitude = ip_hb["RemoteIPLatitude"]
+    try:
+        geoloc_entity.Longitude = float(ip_hb["RemoteIPLongitude"])
+        geoloc_entity.Latitude = float(ip_hb["RemoteIPLatitude"])
+    except TypeError:
+        pass
     ip_entity.Location = geoloc_entity
 
     # If Azure network data present add this to host record
@@ -493,7 +496,7 @@ def ip_whois(
         for ip_addr in ip:
            if rate_limit:
                sleep(query_rate)
-           whois_results[ip_addr] = _whois_lookup(
+           whois_results[ip_addr] = _whois_lookup(  # type: ignore[index]
                ip_addr,
                raw=raw,
                retry_count=retry_count,
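
Note: the new try/except catches TypeError specifically because float(None) raises TypeError (a non-numeric string would raise ValueError instead), so the guard presumably targets heartbeat rows where the coordinates are missing rather than malformed. A minimal check:

    # Hypothetical heartbeat row with absent coordinates.
    ip_hb = {"RemoteIPLongitude": None, "RemoteIPLatitude": None}
    try:
        longitude = float(ip_hb["RemoteIPLongitude"])
        latitude = float(ip_hb["RemoteIPLatitude"])
    except TypeError:
        longitude = latitude = None  # leave the GeoLocation fields unset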
10 changes: 5 additions & 5 deletions msticpy/context/provider_base.py
@@ -20,7 +20,7 @@
 from asyncio import get_event_loop
 from collections.abc import Iterable as C_Iterable
 from functools import lru_cache, partial, singledispatch
-from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generator, Iterable
+from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generator, Iterable, cast
 
 import pandas as pd
 from typing_extensions import Self
@@ -368,7 +368,7 @@ def resolve_item_type(item: str) -> str:
 
     async def _lookup_items_async_wrapper(  # pylint: disable=too-many-arguments # noqa: PLR0913
         self: Self,
-        data: pd.DataFrame | dict[str, str] | list[str],
+        data: pd.DataFrame | dict[str, str] | Iterable[str],
         item_col: str | None = None,
         item_type_col: str | None = None,
         query_type: str | None = None,
@@ -395,7 +395,7 @@ async def _lookup_items_async_wrapper(  # pylint: disable=too-many-arguments # n
             If not specified the default record type for the item type
             will be returned.
         prog_counter: ProgressCounter; Optional
-            Progress Counter to display progess of IOC searches.
+            Progress Counter to display progress of IOC searches.
 
         Returns
         -------
@@ -413,7 +413,7 @@ async def _lookup_items_async_wrapper(  # pylint: disable=too-many-arguments # n
         )
         result: pd.DataFrame = await event_loop.run_in_executor(None, get_items)
         if prog_counter:
-            await prog_counter.decrement(len(data))
+            await prog_counter.decrement(len(data))  # type: ignore[arg-type]
         return result
 
 
@@ -466,7 +466,7 @@ def generate_items(
 
     if isinstance(data, C_Iterable):
        for item in data:
-           yield item, Provider.resolve_item_type(item)
+           yield cast(str, item), Provider.resolve_item_type(item)
    else:
        yield None, None
2 changes: 1 addition & 1 deletion msticpy/context/tilookup.py
@@ -233,7 +233,7 @@ async def _lookup_iocs_async(  # pylint: disable=too-many-arguments #noqa:PLR091
     ) -> pd.DataFrame:
         """Lookup IoCs async."""
         return await self._lookup_items_async(
-            data,
+            data,  # type: ignore[arg-type]
             item_col=ioc_col,
             item_type_col=ioc_type_col,
             query_type=ioc_query_type,
8 changes: 5 additions & 3 deletions msticpy/context/vtlookupv3/vtfile_behavior.py
@@ -146,7 +146,9 @@ def __init__(
             file_summary = file_summary.iloc[0]
         if isinstance(file_summary, pd.Series):
             file_summary = file_summary.to_dict()
-        self.file_summary: dict[str, Any] = file_summary or {}
+        self.file_summary: pd.Series | dict[str, pd.Timestamp] | dict[str, Any] = (
+            file_summary or {}
+        )
         self.file_id: str | Any | None = file_id or self.file_summary.get("id")
 
         self._file_behavior: dict[str, Any] = {}
@@ -406,7 +408,7 @@ def _try_match_commandlines(
                 and np.isnan(row["cmd_line"])
                 and row["name"] in cmd
             ):
-                procs_cmd.loc[idx, "cmd_line"] = cmd  # type: ignore[reportCallIssue]
+                procs_cmd.loc[idx, "cmd_line"] = cmd  # type: ignore
                 break
     for cmd in command_executions:
         for idx, row in procs_cmd.iterrows():
@@ -416,7 +418,7 @@ def _try_match_commandlines(
                 and Path(row["name"]).stem.lower() in cmd.lower()
             ):
                 weak_matches += 1
-                procs_cmd.loc[idx, "cmd_line"] = cmd  # type: ignore[reportCallIssue]
+                procs_cmd.loc[idx, "cmd_line"] = cmd  # type: ignore
                 break
 
     if weak_matches:
2 changes: 1 addition & 1 deletion msticpy/context/vtlookupv3/vtlookup.py
@@ -823,7 +823,7 @@ def _add_invalid_input_result(
         new_row["Status"] = status
         new_row["SourceIndex"] = source_idx
         new_results: pd.DataFrame = self.results.append(
-            new_row.to_dict(),
+            new_row.to_dict(),  # type: ignore[operator]
             ignore_index=True,
         )
2 changes: 1 addition & 1 deletion msticpy/context/vtlookupv3/vtlookupv3.py
@@ -1222,7 +1222,7 @@ def _get_vt_api_key() -> str | None:
 
 def timestamps_to_utcdate(data: pd.DataFrame) -> pd.DataFrame:
     """Replace Unix timestamps in VT data with Py/pandas Timestamp."""
-    columns: pd.Index[str] = data.columns
+    columns: pd.Index = data.columns
     for date_col in (
         col for col in columns if isinstance(col, str) and col.endswith("_date")
     ):
13 changes: 13 additions & 0 deletions msticpy/data/drivers/cybereason_driver.py
@@ -398,6 +398,7 @@ def _create_paginated_query_tasks(
     def __execute_query(
         self,
         body: Dict[str, Any],
+        *,
         page: int = 0,
         page_size: int = 2000,
         pagination_token: str = None,
@@ -406,6 +407,8 @@ def __execute_query(
         """
         Run query with pagination enabled.
 
+        :raises httpx.HTTPStatusError: if max_retry reached
+
         Parameters
         ----------
         body: Dict[str, Any]
@@ -416,6 +419,8 @@ def __execute_query(
             Page number to query
         pagination_token: str
             Token of the current search
+        max_retry: int
+            Maximum retries in case of API non-success response
 
         Returns
         -------
@@ -449,6 +454,14 @@ def __execute_query(
             json_result = response.json()
             status = json_result["status"]
             cur_try += 1
+
+            if cur_try >= max_retry:
+                raise httpx.HTTPStatusError(
+                    f"{status}: {json_result['message']}",
+                    request=response.request,
+                    response=response,
+                )
 
         return json_result
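
Note: this added block is the core of the HTTP 429 fix named in the PR title — once the retry budget is spent, the driver surfaces the last API status as an httpx.HTTPStatusError rather than looping indefinitely or returning a partial result. A self-contained sketch of the retry-then-raise shape (function name, endpoint handling, and the "SUCCESS" check are illustrative, not the driver's exact logic; assumes max_retry >= 1):

    import httpx

    def poll_query(client: httpx.Client, url: str, body: dict, max_retry: int = 3) -> dict:
        """Post a query, retrying while the API reports a non-success status."""
        cur_try = 0
        while cur_try < max_retry:
            response = client.post(url, json=body)
            json_result = response.json()
            status = json_result["status"]
            if status == "SUCCESS":
                return json_result
            cur_try += 1
        # Retry budget exhausted: raise the same exception type as the diff above.
        raise httpx.HTTPStatusError(
            f"{status}: {json_result['message']}",
            request=response.request,
            response=response,
        )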
9 changes: 9 additions & 0 deletions mypy.ini
@@ -143,3 +143,12 @@ ignore_missing_imports = True
 
 [mypy-requests.*]
 ignore_missing_imports = True
+
+[mypy-autogen.*]
+ignore_missing_imports = True
+
+[mypy-notebookutils.*]
+ignore_missing_imports = True
+
+[mypy-mo_sql_parsing.*]
+ignore_missing_imports = True