Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cybereason pagination support + multi-threading #707

Merged
merged 11 commits into from
Sep 29, 2023
309 changes: 243 additions & 66 deletions msticpy/data/drivers/cybereason_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,30 @@
"""Cybereason Driver class."""
import datetime as dt
import json
import logging
import re
from functools import singledispatch
from asyncio import as_completed, Future
from concurrent.futures import ThreadPoolExecutor
from functools import partial, singledispatch
from typing import Any, Dict, List, Optional, Tuple, Union

import httpx
import pandas as pd
from tqdm.auto import tqdm

from ..._version import VERSION
from ...common.exceptions import MsticpyUserConfigError
from ...common.provider_settings import ProviderArgs, get_provider_settings
from ...common.utility import mp_ua_header
from ..core.query_defns import Formatters
from ..core.query_provider_connections_mixin import _get_event_loop
from .driver_base import DriverBase, DriverProps, QuerySource

__version__ = VERSION
__author__ = "Florian Bracq"

logger = logging.getLogger(__name__)

_HELP_URI = (
"https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProviders.html"
)
Expand Down Expand Up @@ -66,8 +73,9 @@ def __init__(self, **kwargs):
"""
super().__init__(**kwargs)
timeout = kwargs.get("timeout", 120) # 2 minutes in milliseconds
max_results = min(kwargs.get("max_results", 1000), 10000)
page_size = min(kwargs.get("page_size", 100), 100)
logger.debug("Set timeout to %d", timeout)
max_results = min(kwargs.get("max_results", 100000), 100000)
logger.debug("Set maximum results to %d", max_results)
self.base_url: str = "https://{tenant_id}.cybereason.net"
self.auth_endpoint: str = "/login.html"
self.req_body: Dict[str, Any] = {
Expand All @@ -77,7 +85,6 @@ def __init__(self, **kwargs):
"perFeatureLimit": 100,
"templateContext": "SPECIFIC",
"queryTimeout": timeout * 1000,
"pagination": {"pageSize": page_size},
"customFields": [],
}
self.search_endpoint: str = "/rest/visualsearch/query/simple"
Expand All @@ -96,6 +103,10 @@ def __init__(self, **kwargs):
},
)

self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True)
self.set_driver_property(
DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4)
)
self._debug = kwargs.get("debug", False)

def query(
Expand All @@ -118,10 +129,91 @@ def query(
the underlying provider result if an error.

"""
data, response = self.query_with_results(query)
if isinstance(data, pd.DataFrame):
return data
return response
del query_source
if not self._connected:
raise self._create_not_connected_err(self.__class__.__name__)

page_size = min(kwargs.get("page_size", 2000), 4000)
logger.debug("Set page size to %d", page_size)
json_query = json.loads(query)
body = {**self.req_body, **json_query}

# The query must be executed at least once to retrieve the number
# of results and the pagination token.
response = self.__execute_query(body, page_size=page_size)

total_results = response["data"]["totalResults"]
pagination_token = response["data"]["paginationToken"]
results: Dict[str, Any] = response["data"]["resultIdToElementDataMap"]

logger.debug("Retrieved %d/%d results", len(results), total_results)

df_result: pd.DataFrame = None

if len(results) < total_results:
df_result = self._exec_paginated_queries(
body=body,
page_size=page_size,
pagination_token=pagination_token,
total_results=total_results,
)
else:
df_result = self._format_result_to_dataframe(result=response)

return df_result

def _exec_paginated_queries(
self,
body: Dict[str, Any],
page_size: int,
pagination_token: str,
total_results: int,
**kwargs,
) -> pd.DataFrame:
"""
Return results of paginated queries.

Parameters
----------
body : Dict[str, Any]
The body of the query to execute.

Additional Parameters
----------
progress: bool, optional
Show progress bar, by default True
retry_on_error: bool, optional
Retry failed queries, by default False
**kwargs : Dict[str, Any]
Additional keyword arguments to pass to the query method.

Returns
-------
pd.DataFrame
The concatenated results of all the paginated queries.

Notes
-----
This method executes the specified query multiple times to retrieve
all the data from paginated results.
The queries are executed asynchronously.

"""
progress = kwargs.pop("progress", True)
retry = kwargs.pop("retry_on_error", False)

query_tasks = self._create_paginated_query_tasks(
body=body,
page_size=page_size,
pagination_token=pagination_token,
total_results=total_results,
)

logger.info("Running %s paginated queries.", len(query_tasks))
event_loop = _get_event_loop()
return event_loop.run_until_complete(
self.__run_threaded_queries(query_tasks, progress, retry)
)

def connect(
self,
Expand Down Expand Up @@ -277,6 +369,136 @@ def _flatten_element_values(
result[f"{key}.{subkey}"] = subvalues
return result

def _create_paginated_query_tasks(
self,
body: Dict[str, Any],
page_size: int,
pagination_token: str,
total_results: int,
) -> Dict[str, partial]:
"""Return dictionary of partials to execute queries."""
# Compute the number of queries to execute
total_pages = total_results // page_size + 1
# The first query (page 0) as to be re-run due to a bug in
# Cybereason API. The first query returns less results than the page size
# when executed without a pagination token.
return {
f"{page}": partial(
self.__execute_query,
body=body,
page_size=page_size,
pagination_token=pagination_token,
page=page,
)
for page in range(0, total_pages)
}

def __execute_query(
self,
body: Dict[str, Any],
page: int = 0,
page_size: int = 2000,
pagination_token: str = None,
) -> Dict[str, Any]:
"""
Run query with pagination enabled.

Parameters
----------
body: Dict[str, Any]
Body of the HTTP Request
page_size: int
Size of the page for results
page: int
Page number to query
pagination_token: str
Token of the current search

Returns
-------
Dict[str, Any]

"""
if pagination_token:
pagination = {
"pagination": {
"pageSize": page_size,
"page": page + 1,
"paginationToken": pagination_token,
"skip": page * page_size,
}
}
headers = {"Pagination-Token": pagination_token}
else:
pagination = {"pagination": {"pageSize": page_size}}
headers = {}
params = {"page": page, "itemsPerPage": page_size}
status = None
while status != "SUCCESS":
response = self.client.post(
self.search_endpoint,
json={**body, **pagination},
headers=headers,
params=params,
)
response.raise_for_status()
json_result = response.json()
status = json_result["status"]
return json_result

async def __run_threaded_queries(
self,
query_tasks: Dict[str, partial],
progress: bool = True,
retry: bool = False,
) -> pd.DataFrame:
logger.info("Running %d threaded queries.", len(query_tasks))
event_loop = _get_event_loop()
with ThreadPoolExecutor(max_workers=4) as executor:
results: List[pd.DataFrame] = []
failed_tasks: Dict[str, Future] = {}
thread_tasks = {
query_id: event_loop.run_in_executor(executor, query_func)
for query_id, query_func in query_tasks.items()
}
if progress:
task_iter = tqdm(
as_completed(thread_tasks.values()),
unit="paginated-queries",
desc="Running",
)
else:
task_iter = as_completed(thread_tasks.values())
ids_and_tasks = dict(zip(thread_tasks, task_iter))
for query_id, thread_task in ids_and_tasks.items():
try:
result = await thread_task
df_result = self._format_result_to_dataframe(result)
logger.info("Query task '%s' completed successfully.", query_id)
results.append(df_result)
except Exception: # pylint: disable=broad-except
logger.warning(
"Query task '%s' failed with exception", query_id, exc_info=True
)
failed_tasks[query_id] = thread_task

if retry and failed_tasks:
for query_id, thread_task in failed_tasks.items():
try:
logger.info("Retrying query task '%s'", query_id)
result = await thread_task
df_result = self._format_result_to_dataframe(result)
results.append(df_result)
except Exception: # pylint: disable=broad-except
logger.warning(
"Retried query task '%s' failed with exception",
query_id,
exc_info=True,
)
# Sort the results by the order of the tasks
results = [result for _, result in sorted(zip(thread_tasks, results))]
return pd.concat(results, ignore_index=True)

# pylint: disable=too-many-branches
def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]:
"""
Expand All @@ -294,64 +516,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]:
Kql ResultSet.

"""
if not self.connected:
self.connect(self.current_connection)
if not self.connected:
raise ConnectionError(
"Source is not connected. ", "Please call connect() and retry."
)

if self._debug:
print(query)

json_query = json.loads(query)
body = self.req_body
body.update(json_query)
response = self.client.post(self.search_endpoint, json=body)

self._check_response_errors(response)

json_response = response.json()
if json_response["status"] != "SUCCESS":
print(
"Warning - query did not complete successfully.",
f"Status: {json_response['status']}.",
json_response["message"],
)
return pd.DataFrame(), json_response

data = json_response.get("data", json_response)
results = data.get("resultIdToElementDataMap", data)
total_results = data.get("totalResults", len(results))
guessed_results = data.get("guessedPossibleResults", len(results))
if guessed_results > len(results):
print(
f"Warning - query returned {total_results} out of {guessed_results}.",
"Check returned response.",
)
results = [
dict(CybereasonDriver._flatten_result(values), **{"resultId": result_id})
for result_id, values in results.items()
]

return pd.json_normalize(results), json_response

# pylint: enable=too-many-branches

@staticmethod
def _check_response_errors(response):
"""Check the response for possible errors."""
if response.status_code == httpx.codes.OK:
return
print(response.json()["error"]["message"])
if response.status_code == 401:
raise ConnectionRefusedError(
"Authentication failed - possible ", "timeout. Please re-connect."
)
# Raise an exception to handle hitting API limits
if response.status_code == 429:
raise ConnectionRefusedError("You have likely hit the API limit. ")
response.raise_for_status()
raise NotImplementedError(f"Not supported for {self.__class__.__name__}")

# Parameter Formatting method
@staticmethod
Expand All @@ -373,6 +538,18 @@ def _format_to_datetime(timestamp: int) -> Union[dt.datetime, int]:
except TypeError:
return timestamp

@staticmethod
def _format_result_to_dataframe(result: Dict[str, Any]) -> pd.DataFrame:
"""Return a dataframe from a cybereason result object."""
df_result = [
dict(
CybereasonDriver._flatten_result(values),
**{"resultId": result_id},
)
for result_id, values in result["data"]["resultIdToElementDataMap"].items()
]
return pd.json_normalize(df_result)

# Retrieve configuration parameters with aliases
@staticmethod
def _map_config_dict_name(config_dict: Dict[str, str]):
Expand Down
Loading