diff --git a/msticpy/data/drivers/cybereason_driver.py b/msticpy/data/drivers/cybereason_driver.py
index ad8115fb3..e414e9d52 100644
--- a/msticpy/data/drivers/cybereason_driver.py
+++ b/msticpy/data/drivers/cybereason_driver.py
@@ -6,23 +6,30 @@
 """Cybereason Driver class."""
 import datetime as dt
 import json
+import logging
 import re
-from functools import singledispatch
+from asyncio import Future
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial, singledispatch
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import httpx
 import pandas as pd
+from tqdm.auto import tqdm
 
 from ..._version import VERSION
 from ...common.exceptions import MsticpyUserConfigError
 from ...common.provider_settings import ProviderArgs, get_provider_settings
 from ...common.utility import mp_ua_header
 from ..core.query_defns import Formatters
+from ..core.query_provider_connections_mixin import _get_event_loop
 from .driver_base import DriverBase, DriverProps, QuerySource
 
 __version__ = VERSION
 __author__ = "Florian Bracq"
 
+logger = logging.getLogger(__name__)
+
 _HELP_URI = (
     "https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProviders.html"
 )
@@ -66,8 +73,9 @@ def __init__(self, **kwargs):
         """
         super().__init__(**kwargs)
         timeout = kwargs.get("timeout", 120)  # 2 minutes in milliseconds
-        max_results = min(kwargs.get("max_results", 1000), 10000)
-        page_size = min(kwargs.get("page_size", 100), 100)
+        logger.debug("Set timeout to %d", timeout)
+        max_results = min(kwargs.get("max_results", 100000), 100000)
+        logger.debug("Set maximum results to %d", max_results)
         self.base_url: str = "https://{tenant_id}.cybereason.net"
         self.auth_endpoint: str = "/login.html"
         self.req_body: Dict[str, Any] = {
@@ -77,7 +85,6 @@ def __init__(self, **kwargs):
             "perFeatureLimit": 100,
             "templateContext": "SPECIFIC",
             "queryTimeout": timeout * 1000,
-            "pagination": {"pageSize": page_size},
             "customFields": [],
         }
         self.search_endpoint: str = "/rest/visualsearch/query/simple"
@@ -96,6 +103,10 @@ def __init__(self, **kwargs):
             },
         )
+        self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True)
+        self.set_driver_property(
+            DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4)
+        )
         self._debug = kwargs.get("debug", False)
 
     def query(
         self,
@@ -118,10 +129,91 @@
             the underlying provider result if an error.
 
         """
-        data, response = self.query_with_results(query)
-        if isinstance(data, pd.DataFrame):
-            return data
-        return response
+        del query_source
+        if not self._connected:
+            raise self._create_not_connected_err(self.__class__.__name__)
+
+        page_size = min(kwargs.get("page_size", 2000), 4000)
+        logger.debug("Set page size to %d", page_size)
+        json_query = json.loads(query)
+        body = {**self.req_body, **json_query}
+
+        # The query must be executed at least once to retrieve the number
+        # of results and the pagination token.
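+        # Subsequent pages can only be requested once this token is known,
+        # which is why the remaining pages are fetched in a second pass by
+        # _exec_paginated_queries below.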
+        response = self.__execute_query(body, page_size=page_size)
+
+        total_results = response["data"]["totalResults"]
+        pagination_token = response["data"]["paginationToken"]
+        results: Dict[str, Any] = response["data"]["resultIdToElementDataMap"]
+
+        logger.debug("Retrieved %d/%d results", len(results), total_results)
+
+        df_result: Optional[pd.DataFrame] = None
+
+        if len(results) < total_results:
+            df_result = self._exec_paginated_queries(
+                body=body,
+                page_size=page_size,
+                pagination_token=pagination_token,
+                total_results=total_results,
+            )
+        else:
+            df_result = self._format_result_to_dataframe(result=response)
+
+        return df_result
+
+    def _exec_paginated_queries(
+        self,
+        body: Dict[str, Any],
+        page_size: int,
+        pagination_token: Optional[str],
+        total_results: int,
+        **kwargs,
+    ) -> pd.DataFrame:
+        """
+        Return results of paginated queries.
+
+        Parameters
+        ----------
+        body : Dict[str, Any]
+            The body of the query to execute.
+        page_size : int
+            Number of results to return per page.
+        pagination_token : Optional[str]
+            Pagination token returned by the initial query.
+        total_results : int
+            Total number of results reported by the initial query.
+
+        Additional Parameters
+        ---------------------
+        progress : bool, optional
+            Show progress bar, by default True
+        retry_on_error : bool, optional
+            Retry failed queries, by default False
+        **kwargs : Dict[str, Any]
+            Additional keyword arguments to pass to the query method.
+
+        Returns
+        -------
+        pd.DataFrame
+            The concatenated results of all the paginated queries.
+
+        Notes
+        -----
+        This method executes the specified query multiple times to retrieve
+        all the data from paginated results.
+        The queries are executed asynchronously.
+
+        """
+        progress = kwargs.pop("progress", True)
+        retry = kwargs.pop("retry_on_error", False)
+
+        query_tasks = self._create_paginated_query_tasks(
+            body=body,
+            page_size=page_size,
+            pagination_token=pagination_token,
+            total_results=total_results,
+        )
+
+        logger.info("Running %d paginated queries.", len(query_tasks))
+        event_loop = _get_event_loop()
+        return event_loop.run_until_complete(
+            self.__run_threaded_queries(query_tasks, progress, retry)
+        )
 
     def connect(
         self,
@@ -277,6 +369,136 @@ def _flatten_element_values(
             result[f"{key}.{subkey}"] = subvalues
         return result
 
+    def _create_paginated_query_tasks(
+        self,
+        body: Dict[str, Any],
+        page_size: int,
+        pagination_token: Optional[str],
+        total_results: int,
+    ) -> Dict[str, partial]:
+        """Return dictionary of partials to execute queries."""
+        # Compute the number of queries to execute (ceiling division)
+        total_pages = (total_results + page_size - 1) // page_size
+        # The first query (page 0) has to be re-run due to a bug in the
+        # Cybereason API: the first query returns fewer results than the
+        # page size when executed without a pagination token.
+        return {
+            f"{page}": partial(
+                self.__execute_query,
+                body=body,
+                page_size=page_size,
+                pagination_token=pagination_token,
+                page=page,
+            )
+            for page in range(0, total_pages)
+        }
+
+    def __execute_query(
+        self,
+        body: Dict[str, Any],
+        page: int = 0,
+        page_size: int = 2000,
+        pagination_token: Optional[str] = None,
+    ) -> Dict[str, Any]:
+        """
+        Run query with pagination enabled.
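+
+        The request is re-sent until the API reports a ``SUCCESS`` status,
+        as the endpoint may answer with an intermediate status while the
+        query is still being evaluated.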
+
+        Parameters
+        ----------
+        body: Dict[str, Any]
+            Body of the HTTP Request
+        page: int
+            Page number to query
+        page_size: int
+            Size of the page for results
+        pagination_token: str, optional
+            Token of the current search
+
+        Returns
+        -------
+        Dict[str, Any]
+
+        """
+        if pagination_token:
+            pagination = {
+                "pagination": {
+                    "pageSize": page_size,
+                    "page": page + 1,
+                    "paginationToken": pagination_token,
+                    "skip": page * page_size,
+                }
+            }
+            headers = {"Pagination-Token": pagination_token}
+        else:
+            pagination = {"pagination": {"pageSize": page_size}}
+            headers = {}
+        params = {"page": page, "itemsPerPage": page_size}
+        status = None
+        while status != "SUCCESS":
+            response = self.client.post(
+                self.search_endpoint,
+                json={**body, **pagination},
+                headers=headers,
+                params=params,
+            )
+            response.raise_for_status()
+            json_result = response.json()
+            status = json_result["status"]
+        return json_result
+
+    async def __run_threaded_queries(
+        self,
+        query_tasks: Dict[str, partial],
+        progress: bool = True,
+        retry: bool = False,
+    ) -> pd.DataFrame:
+        logger.info("Running %d threaded queries.", len(query_tasks))
+        event_loop = _get_event_loop()
+        with ThreadPoolExecutor(max_workers=4) as executor:
+            results: List[Tuple[int, pd.DataFrame]] = []
+            failed_tasks: Dict[str, Future] = {}
+            thread_tasks = {
+                query_id: event_loop.run_in_executor(executor, query_func)
+                for query_id, query_func in query_tasks.items()
+            }
+            # Iterate the id -> future mapping directly (in page order):
+            # asyncio.as_completed would not yield the original futures,
+            # which would break the association between page ids and results.
+            task_iter = thread_tasks.items()
+            if progress:
+                task_iter = tqdm(
+                    task_iter,
+                    unit="paginated-queries",
+                    desc="Running",
+                )
+            for query_id, thread_task in task_iter:
+                try:
+                    result = await thread_task
+                    df_result = self._format_result_to_dataframe(result)
+                    logger.info("Query task '%s' completed successfully.", query_id)
+                    results.append((int(query_id), df_result))
+                except Exception:  # pylint: disable=broad-except
+                    logger.warning(
+                        "Query task '%s' failed with exception", query_id, exc_info=True
+                    )
+                    failed_tasks[query_id] = thread_task
+
+            if retry and failed_tasks:
+                for query_id in failed_tasks:
+                    try:
+                        logger.info("Retrying query task '%s'", query_id)
+                        # Re-awaiting a failed future would only re-raise the
+                        # original exception, so re-submit the query instead.
+                        result = await event_loop.run_in_executor(
+                            executor, query_tasks[query_id]
+                        )
+                        df_result = self._format_result_to_dataframe(result)
+                        results.append((int(query_id), df_result))
+                    except Exception:  # pylint: disable=broad-except
+                        logger.warning(
+                            "Retried query task '%s' failed with exception",
+                            query_id,
+                            exc_info=True,
+                        )
+            # Sort the results by page number before concatenating
+            sorted_results = sorted(results, key=lambda page_df: page_df[0])
+            return pd.concat(
+                [df for _, df in sorted_results], ignore_index=True
+            )
+
     # pylint: disable=too-many-branches
     def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]:
         """
@@ -294,64 +516,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]:
             Kql ResultSet.
 
         """
-        if not self.connected:
-            self.connect(self.current_connection)
-            if not self.connected:
-                raise ConnectionError(
-                    "Source is not connected. ", "Please call connect() and retry."
-                )
-
-        if self._debug:
-            print(query)
-
-        json_query = json.loads(query)
-        body = self.req_body
-        body.update(json_query)
-        response = self.client.post(self.search_endpoint, json=body)
-
-        self._check_response_errors(response)
-
-        json_response = response.json()
-        if json_response["status"] != "SUCCESS":
-            print(
-                "Warning - query did not complete successfully.",
-                f"Status: {json_response['status']}.",
-                json_response["message"],
-            )
-            return pd.DataFrame(), json_response
-
-        data = json_response.get("data", json_response)
-        results = data.get("resultIdToElementDataMap", data)
-        total_results = data.get("totalResults", len(results))
-        guessed_results = data.get("guessedPossibleResults", len(results))
-        if guessed_results > len(results):
-            print(
-                f"Warning - query returned {total_results} out of {guessed_results}.",
-                "Check returned response.",
-            )
-        results = [
-            dict(CybereasonDriver._flatten_result(values), **{"resultId": result_id})
-            for result_id, values in results.items()
-        ]
-
-        return pd.json_normalize(results), json_response
-
-    # pylint: enable=too-many-branches
-
-    @staticmethod
-    def _check_response_errors(response):
-        """Check the response for possible errors."""
-        if response.status_code == httpx.codes.OK:
-            return
-        print(response.json()["error"]["message"])
-        if response.status_code == 401:
-            raise ConnectionRefusedError(
-                "Authentication failed - possible ", "timeout. Please re-connect."
-            )
-        # Raise an exception to handle hitting API limits
-        if response.status_code == 429:
-            raise ConnectionRefusedError("You have likely hit the API limit. ")
-        response.raise_for_status()
+        raise NotImplementedError(f"Not supported for {self.__class__.__name__}")
 
     # Parameter Formatting method
     @staticmethod
@@ -373,6 +538,18 @@ def _format_to_datetime(timestamp: int) -> Union[dt.datetime, int]:
         except TypeError:
             return timestamp
 
+    @staticmethod
+    def _format_result_to_dataframe(result: Dict[str, Any]) -> pd.DataFrame:
+        """Return a DataFrame from a Cybereason result object."""
+        df_result = [
+            dict(
+                CybereasonDriver._flatten_result(values),
+                **{"resultId": result_id},
+            )
+            for result_id, values in result["data"]["resultIdToElementDataMap"].items()
+        ]
+        return pd.json_normalize(df_result)
+
     # Retrieve configuration parameters with aliases
     @staticmethod
     def _map_config_dict_name(config_dict: Dict[str, str]):
diff --git a/tests/data/drivers/test_cybereason_driver.py b/tests/data/drivers/test_cybereason_driver.py
index c8684a00f..8e1e4a687 100644
--- a/tests/data/drivers/test_cybereason_driver.py
+++ b/tests/data/drivers/test_cybereason_driver.py
@@ -50,7 +50,9 @@
                     }
                 },
             }
-        }
+        },
+        "paginationToken": None,
+        "totalResults": 1,
     },
     "status": "SUCCESS",
     "message": "",
@@ -58,6 +60,85 @@
     "expectedResults": 0,
     "failures": 0,
 }
 
+_CR_PAGINATED_RESULT = [
+    {
+        "data": {
+            "resultIdToElementDataMap": {
+                "id1": {
+                    "simpleValues": {
+                        "osType": {"totalValues": 1, "values": ["WINDOWS"]},
+                        "totalMemory": {
+                            "totalValues": 1,
+                            "values": ["8589463552"],
+                        },
+                        "group": {
+                            "totalValues": 1,
+                            "values": ["00000000-0000-0000-0000-000000000000"],
+                        },
+                        "osVersionType": {
+                            "totalValues": 1,
+                            "values": ["Windows_10"],
+                        },
+                    },
+                    "elementValues": {
+                        "users": {
+                            "totalValues": 5,
+                            "elementValues": [],
+                            "totalSuspicious": 0,
+                            "totalMalicious": 0,
+                            "guessedTotal": 0,
+                        }
+                    },
+                }
+            },
+            "paginationToken": None,
+            "totalResults": 2,
+        },
+        "status": "SUCCESS",
+        "message": "",
+        "expectedResults": 0,
+        "failures": 0,
+    },
+    {
+        "data": {
+            "resultIdToElementDataMap": {
+                "id2": {
+                    "simpleValues": {
"osType": {"totalValues": 1, "values": ["WINDOWS"]}, + "totalMemory": { + "totalValues": 1, + "values": ["8589463552"], + }, + "group": { + "totalValues": 1, + "values": ["00000000-0000-0000-0000-000000000000"], + }, + "osVersionType": { + "totalValues": 1, + "values": ["Windows_10"], + }, + }, + "elementValues": { + "users": { + "totalValues": 5, + "elementValues": [], + "totalSuspicious": 0, + "totalMalicious": 0, + "guessedTotal": 0, + } + }, + } + }, + "paginationToken": None, + "totalResults": 2, + }, + "status": "SUCCESS", + "message": "", + "expectedResults": 0, + "failures": 0, + }, +] + _CR_QUERY = { "query": """ { @@ -132,9 +213,9 @@ def _cr_pre_checks(driver: CybereasonDriver): @respx.mock def test_connect(driver): """Test connect.""" - connect = respx.post(re.compile(r"https://.*.cybereason.net/login.html")).respond( - 200 - ) + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login\.html") + ).respond(200) with custom_mp_config(MP_PATH): driver.connect() check.is_true(connect.called) @@ -144,19 +225,49 @@ def test_connect(driver): @respx.mock def test_query(driver): """Test query calling returns data in expected format.""" - connect = respx.post(re.compile(r"https://.*.cybereason.net/login.html")).respond( - 200 - ) + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login\.html") + ).respond(200) query = respx.post( - re.compile(r"https://.*.cybereason.net/rest/visualsearch/query/simple") + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ) ).respond(200, json=_CR_RESULT) with custom_mp_config(MP_PATH): + driver.connect() data = driver.query('{"test": "test"}') check.is_true(connect.called or driver.connected) check.is_true(query.called) check.is_instance(data, pd.DataFrame) +@respx.mock +def test_paginated_query(driver): + """Test query calling returns data in expected format.""" + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login.html") + ).respond(200) + query1 = respx.post( + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ), + params={"page": 0}, + ).respond(200, json=_CR_PAGINATED_RESULT[0]) + query2 = respx.post( + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ), + params={"page": 1}, + ).respond(200, json=_CR_PAGINATED_RESULT[1]) + with custom_mp_config(MP_PATH): + driver.connect() + data = driver.query('{"test": "test"}', page_size=1) + check.is_true(connect.called or driver.connected) + check.is_true(query1.called) + check.is_true(query2.called) + check.is_instance(data, pd.DataFrame) + + def test_custom_param_handler(driver): """Test query formatter returns data in expected format.""" query = _CR_QUERY.get("query", "")