From 900cefea011125af09d499e14bc1ddaa20610da5 Mon Sep 17 00:00:00 2001 From: FlorianBracq <97248273+FlorianBracq@users.noreply.github.com> Date: Fri, 29 Sep 2023 20:03:06 +0200 Subject: [PATCH 1/4] Intsights api update (#710) * Updating to v3 * Update parsing logic * Update tests --------- Co-authored-by: Ian Hellen --- msticpy/context/tiproviders/intsights.py | 41 ++++++++++++------------ tests/context/test_tiproviders.py | 35 +++++++++----------- 2 files changed, 36 insertions(+), 40 deletions(-) diff --git a/msticpy/context/tiproviders/intsights.py b/msticpy/context/tiproviders/intsights.py index b1c094e7c..b5bc1f620 100644 --- a/msticpy/context/tiproviders/intsights.py +++ b/msticpy/context/tiproviders/intsights.py @@ -46,42 +46,42 @@ class IntSights(HttpTIProvider): _QUERIES = { "ipv4": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "ipv6": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "dns": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "url": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "md5_hash": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "sha1_hash": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "sha256_hash": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), "email": _IntSightsParams( - path="/public/v2/iocs/ioc-by-value", + path="/public/v3/iocs/ioc-by-value", params={"iocValue": "{observable}"}, headers=_DEF_HEADERS, ), @@ -111,27 +111,28 @@ def parse_results(self, response: Dict) -> Tuple[bool, ResultSeverity, Any]: ): return False, ResultSeverity.information, "Not found." - if response["RawResult"]["Whitelist"] == "True": + if response["RawResult"].get("whitelisted", False): return False, ResultSeverity.information, "Whitelisted." - sev = response["RawResult"]["Severity"] + sev = response["RawResult"].get("severity", "Low") result_dict = { - "threat_actors": response["RawResult"]["RelatedThreatActors"], - "geolocation": response["RawResult"].get("Geolocation", ""), + "threat_actors": response["RawResult"].get("relatedThreatActors", ""), + "geolocation": response["RawResult"].get("geolocation", None), "response_code": response["Status"], - "tags": response["RawResult"]["Tags"] + response["RawResult"]["SystemTags"], - "malware": response["RawResult"]["RelatedMalware"], - "campaigns": response["RawResult"]["RelatedCampaigns"], - "sources": response["RawResult"]["Sources"], - "score": response["RawResult"]["Score"], + "tags": response["RawResult"].get("tags", []) + + response["RawResult"].get("SystemTags", []), + "malware": response["RawResult"].get("relatedMalware", []), + "campaigns": response["RawResult"].get("relatedCampaigns", []), + "score": response["RawResult"].get("score", 0), "first_seen": dt.datetime.strptime( - response["RawResult"]["FirstSeen"], "%Y-%m-%dT%H:%M:%S.%fZ" + response["RawResult"].get("firstSeen", None), "%Y-%m-%dT%H:%M:%S.%fZ" ), "last_seen": dt.datetime.strptime( - response["RawResult"]["LastSeen"], "%Y-%m-%dT%H:%M:%S.%fZ" + response["RawResult"].get("lastSeen", None), "%Y-%m-%dT%H:%M:%S.%fZ" ), "last_update": dt.datetime.strptime( - response["RawResult"]["LastUpdate"], "%Y-%m-%dT%H:%M:%S.%fZ" + response["RawResult"].get("lastUpdateDate", None), + "%Y-%m-%dT%H:%M:%S.%fZ", ), } diff --git a/tests/context/test_tiproviders.py b/tests/context/test_tiproviders.py index 2b9d722b8..f2166cd4a 100644 --- a/tests/context/test_tiproviders.py +++ b/tests/context/test_tiproviders.py @@ -885,32 +885,27 @@ def _get_riskiq_classification(): "https://api.ti.insight.rapid7.com": { "ioc_param": "params", "response": { - "Value": "124.5.6.7", - "Type": "IpAddresses", - "Score": 42, - "Severity": "Medium", - "Whitelist": False, - "FirstSeen": dt.datetime.strftime( + "value": "124.5.6.7", + "type": "IpAddresses", + "score": 42, + "severity": "Medium", + "whitelist": False, + "firstSeen": dt.datetime.strftime( dt.datetime.now(), "%Y-%m-%dT%H:%M:%S.%fZ" ), - "LastSeen": dt.datetime.strftime( + "lastSeen": dt.datetime.strftime( dt.datetime.now(), "%Y-%m-%dT%H:%M:%S.%fZ" ), - "LastUpdate": dt.datetime.strftime( + "lastUpdateDate": dt.datetime.strftime( dt.datetime.now(), "%Y-%m-%dT%H:%M:%S.%fZ" ), - "Sources": [ - {"ConfidenceLevel": 2, "Name": "Source A"}, - {"ConfidenceLevel": 1, "Name": "Source B"}, - {"ConfidenceLevel": 1, "Name": "Source C"}, - {"ConfidenceLevel": 3, "Name": "Source D"}, - ], - "SystemTags": ["bot", "malware related"], - "Geolocation": "FR", - "RelatedMalware": ["malware1"], - "RelatedCampaigns": ["Campaign A"], - "RelatedThreatActors": ["Threat Actor 00"], - "Tags": ["tag"], + "systemTags": ["bot", "malware related"], + "geolocation": "FR", + "relatedMalware": ["malware1"], + "relatedCampaigns": ["Campaign A"], + "relatedThreatActors": ["Threat Actor 00"], + "tags": ["tag"], + "whitelisted": False, }, }, "https://cti.api.crowdsec.net": { From aef43f2c897dd79cde3c2a853b3245728965d5c4 Mon Sep 17 00:00:00 2001 From: hackeT <40039738+Tatsuya-hasegawa@users.noreply.github.com> Date: Sat, 30 Sep 2023 03:43:27 +0900 Subject: [PATCH 2/4] Fix m365d/mde hunting query options (#702) * fix unpassed time_column param in m365d_hunting * comment out unused start and end param in m365d/mde hunting * delete, not comment out --------- Co-authored-by: Ian Hellen --- msticpy/data/queries/m365d/kql_m365_hunting.yaml | 8 +------- msticpy/data/queries/mde/kql_mdatp_hunting.yaml | 6 ------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/msticpy/data/queries/m365d/kql_m365_hunting.yaml b/msticpy/data/queries/m365d/kql_m365_hunting.yaml index 69dfb3f36..54e27c7c3 100644 --- a/msticpy/data/queries/m365d/kql_m365_hunting.yaml +++ b/msticpy/data/queries/m365d/kql_m365_hunting.yaml @@ -8,12 +8,6 @@ defaults: metadata: data_source: 'hunting_queries' parameters: - start: - description: Query start time - type: datetime - end: - description: Query end time - type: datetime add_query_items: description: Additional query clauses type: str @@ -413,7 +407,7 @@ sources: makeset(Command), count(), min({time_column}) by AccountName, DeviceName, DeviceId | order by AccountName asc - | where min_Timestamp > ago(1d) + | where min_{time_column} > ago(1d) {add_query_items}' uri: "https://github.com/microsoft/WindowsDefenderATP-Hunting-Queries/blob/master/Lateral%20Movement/ServiceAccountsPerformingRemotePS.txt" accessibility_persistence: diff --git a/msticpy/data/queries/mde/kql_mdatp_hunting.yaml b/msticpy/data/queries/mde/kql_mdatp_hunting.yaml index 0e785ef8d..46f3fc04b 100644 --- a/msticpy/data/queries/mde/kql_mdatp_hunting.yaml +++ b/msticpy/data/queries/mde/kql_mdatp_hunting.yaml @@ -8,12 +8,6 @@ defaults: metadata: data_source: 'hunting_queries' parameters: - start: - description: Query start time - type: datetime - end: - description: Query end time - type: datetime add_query_items: description: Additional query clauses type: str From fdc2a5a819093f2c5d4fd2c83d014532881219a7 Mon Sep 17 00:00:00 2001 From: FlorianBracq <97248273+FlorianBracq@users.noreply.github.com> Date: Fri, 29 Sep 2023 21:21:26 +0200 Subject: [PATCH 3/4] Cybereason pagination support + multi-threading (#707) * Add flags for multi-threading * Remove code for function query_with_results * Add function to format CR API result to a DF * Adding methods for paginated queries * Updating tests * Escape domain name for regular expressions * Sanitizing domain names in regular expressions * Fix mypy error --------- Co-authored-by: Ian Hellen --- msticpy/data/drivers/cybereason_driver.py | 309 +++++++++++++++---- tests/data/drivers/test_cybereason_driver.py | 127 +++++++- 2 files changed, 362 insertions(+), 74 deletions(-) diff --git a/msticpy/data/drivers/cybereason_driver.py b/msticpy/data/drivers/cybereason_driver.py index ad8115fb3..e414e9d52 100644 --- a/msticpy/data/drivers/cybereason_driver.py +++ b/msticpy/data/drivers/cybereason_driver.py @@ -6,23 +6,30 @@ """Cybereason Driver class.""" import datetime as dt import json +import logging import re -from functools import singledispatch +from asyncio import as_completed, Future +from concurrent.futures import ThreadPoolExecutor +from functools import partial, singledispatch from typing import Any, Dict, List, Optional, Tuple, Union import httpx import pandas as pd +from tqdm.auto import tqdm from ..._version import VERSION from ...common.exceptions import MsticpyUserConfigError from ...common.provider_settings import ProviderArgs, get_provider_settings from ...common.utility import mp_ua_header from ..core.query_defns import Formatters +from ..core.query_provider_connections_mixin import _get_event_loop from .driver_base import DriverBase, DriverProps, QuerySource __version__ = VERSION __author__ = "Florian Bracq" +logger = logging.getLogger(__name__) + _HELP_URI = ( "https://msticpy.readthedocs.io/en/latest/data_acquisition/DataProviders.html" ) @@ -66,8 +73,9 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) timeout = kwargs.get("timeout", 120) # 2 minutes in milliseconds - max_results = min(kwargs.get("max_results", 1000), 10000) - page_size = min(kwargs.get("page_size", 100), 100) + logger.debug("Set timeout to %d", timeout) + max_results = min(kwargs.get("max_results", 100000), 100000) + logger.debug("Set maximum results to %d", max_results) self.base_url: str = "https://{tenant_id}.cybereason.net" self.auth_endpoint: str = "/login.html" self.req_body: Dict[str, Any] = { @@ -77,7 +85,6 @@ def __init__(self, **kwargs): "perFeatureLimit": 100, "templateContext": "SPECIFIC", "queryTimeout": timeout * 1000, - "pagination": {"pageSize": page_size}, "customFields": [], } self.search_endpoint: str = "/rest/visualsearch/query/simple" @@ -96,6 +103,10 @@ def __init__(self, **kwargs): }, ) + self.set_driver_property(DriverProps.SUPPORTS_THREADING, value=True) + self.set_driver_property( + DriverProps.MAX_PARALLEL, value=kwargs.get("max_threads", 4) + ) self._debug = kwargs.get("debug", False) def query( @@ -118,10 +129,91 @@ def query( the underlying provider result if an error. """ - data, response = self.query_with_results(query) - if isinstance(data, pd.DataFrame): - return data - return response + del query_source + if not self._connected: + raise self._create_not_connected_err(self.__class__.__name__) + + page_size = min(kwargs.get("page_size", 2000), 4000) + logger.debug("Set page size to %d", page_size) + json_query = json.loads(query) + body = {**self.req_body, **json_query} + + # The query must be executed at least once to retrieve the number + # of results and the pagination token. + response = self.__execute_query(body, page_size=page_size) + + total_results = response["data"]["totalResults"] + pagination_token = response["data"]["paginationToken"] + results: Dict[str, Any] = response["data"]["resultIdToElementDataMap"] + + logger.debug("Retrieved %d/%d results", len(results), total_results) + + df_result: pd.DataFrame = None + + if len(results) < total_results: + df_result = self._exec_paginated_queries( + body=body, + page_size=page_size, + pagination_token=pagination_token, + total_results=total_results, + ) + else: + df_result = self._format_result_to_dataframe(result=response) + + return df_result + + def _exec_paginated_queries( + self, + body: Dict[str, Any], + page_size: int, + pagination_token: str, + total_results: int, + **kwargs, + ) -> pd.DataFrame: + """ + Return results of paginated queries. + + Parameters + ---------- + body : Dict[str, Any] + The body of the query to execute. + + Additional Parameters + ---------- + progress: bool, optional + Show progress bar, by default True + retry_on_error: bool, optional + Retry failed queries, by default False + **kwargs : Dict[str, Any] + Additional keyword arguments to pass to the query method. + + Returns + ------- + pd.DataFrame + The concatenated results of all the paginated queries. + + Notes + ----- + This method executes the specified query multiple times to retrieve + all the data from paginated results. + The queries are executed asynchronously. + + """ + progress = kwargs.pop("progress", True) + retry = kwargs.pop("retry_on_error", False) + + query_tasks = self._create_paginated_query_tasks( + body=body, + page_size=page_size, + pagination_token=pagination_token, + total_results=total_results, + ) + + logger.info("Running %s paginated queries.", len(query_tasks)) + event_loop = _get_event_loop() + return event_loop.run_until_complete( + self.__run_threaded_queries(query_tasks, progress, retry) + ) def connect( self, @@ -277,6 +369,136 @@ def _flatten_element_values( result[f"{key}.{subkey}"] = subvalues return result + def _create_paginated_query_tasks( + self, + body: Dict[str, Any], + page_size: int, + pagination_token: str, + total_results: int, + ) -> Dict[str, partial]: + """Return dictionary of partials to execute queries.""" + # Compute the number of queries to execute + total_pages = total_results // page_size + 1 + # The first query (page 0) as to be re-run due to a bug in + # Cybereason API. The first query returns less results than the page size + # when executed without a pagination token. + return { + f"{page}": partial( + self.__execute_query, + body=body, + page_size=page_size, + pagination_token=pagination_token, + page=page, + ) + for page in range(0, total_pages) + } + + def __execute_query( + self, + body: Dict[str, Any], + page: int = 0, + page_size: int = 2000, + pagination_token: str = None, + ) -> Dict[str, Any]: + """ + Run query with pagination enabled. + + Parameters + ---------- + body: Dict[str, Any] + Body of the HTTP Request + page_size: int + Size of the page for results + page: int + Page number to query + pagination_token: str + Token of the current search + + Returns + ------- + Dict[str, Any] + + """ + if pagination_token: + pagination = { + "pagination": { + "pageSize": page_size, + "page": page + 1, + "paginationToken": pagination_token, + "skip": page * page_size, + } + } + headers = {"Pagination-Token": pagination_token} + else: + pagination = {"pagination": {"pageSize": page_size}} + headers = {} + params = {"page": page, "itemsPerPage": page_size} + status = None + while status != "SUCCESS": + response = self.client.post( + self.search_endpoint, + json={**body, **pagination}, + headers=headers, + params=params, + ) + response.raise_for_status() + json_result = response.json() + status = json_result["status"] + return json_result + + async def __run_threaded_queries( + self, + query_tasks: Dict[str, partial], + progress: bool = True, + retry: bool = False, + ) -> pd.DataFrame: + logger.info("Running %d threaded queries.", len(query_tasks)) + event_loop = _get_event_loop() + with ThreadPoolExecutor(max_workers=4) as executor: + results: List[pd.DataFrame] = [] + failed_tasks: Dict[str, Future] = {} + thread_tasks = { + query_id: event_loop.run_in_executor(executor, query_func) + for query_id, query_func in query_tasks.items() + } + if progress: + task_iter = tqdm( + as_completed(thread_tasks.values()), + unit="paginated-queries", + desc="Running", + ) + else: + task_iter = as_completed(thread_tasks.values()) + ids_and_tasks = dict(zip(thread_tasks, task_iter)) + for query_id, thread_task in ids_and_tasks.items(): + try: + result = await thread_task + df_result = self._format_result_to_dataframe(result) + logger.info("Query task '%s' completed successfully.", query_id) + results.append(df_result) + except Exception: # pylint: disable=broad-except + logger.warning( + "Query task '%s' failed with exception", query_id, exc_info=True + ) + failed_tasks[query_id] = thread_task + + if retry and failed_tasks: + for query_id, thread_task in failed_tasks.items(): + try: + logger.info("Retrying query task '%s'", query_id) + result = await thread_task + df_result = self._format_result_to_dataframe(result) + results.append(df_result) + except Exception: # pylint: disable=broad-except + logger.warning( + "Retried query task '%s' failed with exception", + query_id, + exc_info=True, + ) + # Sort the results by the order of the tasks + results = [result for _, result in sorted(zip(thread_tasks, results))] + return pd.concat(results, ignore_index=True) + # pylint: disable=too-many-branches def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: """ @@ -294,64 +516,7 @@ def query_with_results(self, query: str, **kwargs) -> Tuple[pd.DataFrame, Any]: Kql ResultSet. """ - if not self.connected: - self.connect(self.current_connection) - if not self.connected: - raise ConnectionError( - "Source is not connected. ", "Please call connect() and retry." - ) - - if self._debug: - print(query) - - json_query = json.loads(query) - body = self.req_body - body.update(json_query) - response = self.client.post(self.search_endpoint, json=body) - - self._check_response_errors(response) - - json_response = response.json() - if json_response["status"] != "SUCCESS": - print( - "Warning - query did not complete successfully.", - f"Status: {json_response['status']}.", - json_response["message"], - ) - return pd.DataFrame(), json_response - - data = json_response.get("data", json_response) - results = data.get("resultIdToElementDataMap", data) - total_results = data.get("totalResults", len(results)) - guessed_results = data.get("guessedPossibleResults", len(results)) - if guessed_results > len(results): - print( - f"Warning - query returned {total_results} out of {guessed_results}.", - "Check returned response.", - ) - results = [ - dict(CybereasonDriver._flatten_result(values), **{"resultId": result_id}) - for result_id, values in results.items() - ] - - return pd.json_normalize(results), json_response - - # pylint: enable=too-many-branches - - @staticmethod - def _check_response_errors(response): - """Check the response for possible errors.""" - if response.status_code == httpx.codes.OK: - return - print(response.json()["error"]["message"]) - if response.status_code == 401: - raise ConnectionRefusedError( - "Authentication failed - possible ", "timeout. Please re-connect." - ) - # Raise an exception to handle hitting API limits - if response.status_code == 429: - raise ConnectionRefusedError("You have likely hit the API limit. ") - response.raise_for_status() + raise NotImplementedError(f"Not supported for {self.__class__.__name__}") # Parameter Formatting method @staticmethod @@ -373,6 +538,18 @@ def _format_to_datetime(timestamp: int) -> Union[dt.datetime, int]: except TypeError: return timestamp + @staticmethod + def _format_result_to_dataframe(result: Dict[str, Any]) -> pd.DataFrame: + """Return a dataframe from a cybereason result object.""" + df_result = [ + dict( + CybereasonDriver._flatten_result(values), + **{"resultId": result_id}, + ) + for result_id, values in result["data"]["resultIdToElementDataMap"].items() + ] + return pd.json_normalize(df_result) + # Retrieve configuration parameters with aliases @staticmethod def _map_config_dict_name(config_dict: Dict[str, str]): diff --git a/tests/data/drivers/test_cybereason_driver.py b/tests/data/drivers/test_cybereason_driver.py index c8684a00f..8e1e4a687 100644 --- a/tests/data/drivers/test_cybereason_driver.py +++ b/tests/data/drivers/test_cybereason_driver.py @@ -50,7 +50,9 @@ } }, } - } + }, + "paginationToken": None, + "totalResults": 1, }, "status": "SUCCESS", "message": "", @@ -58,6 +60,85 @@ "failures": 0, } +_CR_PAGINATED_RESULT = [ + { + "data": { + "resultIdToElementDataMap": { + "id1": { + "simpleValues": { + "osType": {"totalValues": 1, "values": ["WINDOWS"]}, + "totalMemory": { + "totalValues": 1, + "values": ["8589463552"], + }, + "group": { + "totalValues": 1, + "values": ["00000000-0000-0000-0000-000000000000"], + }, + "osVersionType": { + "totalValues": 1, + "values": ["Windows_10"], + }, + }, + "elementValues": { + "users": { + "totalValues": 5, + "elementValues": [], + "totalSuspicious": 0, + "totalMalicious": 0, + "guessedTotal": 0, + } + }, + } + }, + "paginationToken": None, + "totalResults": 2, + }, + "status": "SUCCESS", + "message": "", + "expectedResults": 0, + "failures": 0, + }, + { + "data": { + "resultIdToElementDataMap": { + "id2": { + "simpleValues": { + "osType": {"totalValues": 1, "values": ["WINDOWS"]}, + "totalMemory": { + "totalValues": 1, + "values": ["8589463552"], + }, + "group": { + "totalValues": 1, + "values": ["00000000-0000-0000-0000-000000000000"], + }, + "osVersionType": { + "totalValues": 1, + "values": ["Windows_10"], + }, + }, + "elementValues": { + "users": { + "totalValues": 5, + "elementValues": [], + "totalSuspicious": 0, + "totalMalicious": 0, + "guessedTotal": 0, + } + }, + } + }, + "paginationToken": None, + "totalResults": 2, + }, + "status": "SUCCESS", + "message": "", + "expectedResults": 0, + "failures": 0, + }, +] + _CR_QUERY = { "query": """ { @@ -132,9 +213,9 @@ def _cr_pre_checks(driver: CybereasonDriver): @respx.mock def test_connect(driver): """Test connect.""" - connect = respx.post(re.compile(r"https://.*.cybereason.net/login.html")).respond( - 200 - ) + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login\.html") + ).respond(200) with custom_mp_config(MP_PATH): driver.connect() check.is_true(connect.called) @@ -144,19 +225,49 @@ def test_connect(driver): @respx.mock def test_query(driver): """Test query calling returns data in expected format.""" - connect = respx.post(re.compile(r"https://.*.cybereason.net/login.html")).respond( - 200 - ) + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login\.html") + ).respond(200) query = respx.post( - re.compile(r"https://.*.cybereason.net/rest/visualsearch/query/simple") + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ) ).respond(200, json=_CR_RESULT) with custom_mp_config(MP_PATH): + driver.connect() data = driver.query('{"test": "test"}') check.is_true(connect.called or driver.connected) check.is_true(query.called) check.is_instance(data, pd.DataFrame) +@respx.mock +def test_paginated_query(driver): + """Test query calling returns data in expected format.""" + connect = respx.post( + re.compile(r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/login.html") + ).respond(200) + query1 = respx.post( + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ), + params={"page": 0}, + ).respond(200, json=_CR_PAGINATED_RESULT[0]) + query2 = respx.post( + re.compile( + r"^https://[a-zA-Z0-9\-]+\.cybereason\.net/rest/visualsearch/query/simple" + ), + params={"page": 1}, + ).respond(200, json=_CR_PAGINATED_RESULT[1]) + with custom_mp_config(MP_PATH): + driver.connect() + data = driver.query('{"test": "test"}', page_size=1) + check.is_true(connect.called or driver.connected) + check.is_true(query1.called) + check.is_true(query2.called) + check.is_instance(data, pd.DataFrame) + + def test_custom_param_handler(driver): """Test query formatter returns data in expected format.""" query = _CR_QUERY.get("query", "") From 15fb44b900577fe67c4bd9e501ab93291ce8e686 Mon Sep 17 00:00:00 2001 From: hackeT <40039738+Tatsuya-hasegawa@users.noreply.github.com> Date: Sat, 30 Sep 2023 04:45:39 +0900 Subject: [PATCH 4/4] Add bearer token auth to splunk driver (#708) * add token auth to splunk driver and fix splunk port value type * fix flask8 error * fix flask8 error * fix flask8 error * Fixing some linting errors in splunk_driver.py --------- Co-authored-by: Ian Hellen --- .../data_acquisition/SplunkProvider.rst | 34 +++++++++++++++---- msticpy/data/drivers/splunk_driver.py | 23 +++++++++---- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/docs/source/data_acquisition/SplunkProvider.rst b/docs/source/data_acquisition/SplunkProvider.rst index 39819663b..7f439dc75 100644 --- a/docs/source/data_acquisition/SplunkProvider.rst +++ b/docs/source/data_acquisition/SplunkProvider.rst @@ -38,7 +38,7 @@ The settings in the file should look like the following: Splunk: Args: host: splunk_host - port: 8089 + port: '8089' username: splunk_user password: [PLACEHOLDER] @@ -54,7 +54,7 @@ to a Key Vault secret using the MSTICPy configuration editor. Splunk: Args: host: splunk_host - port: 8089 + port: '8089' username: splunk_user password: KeyVault: @@ -67,8 +67,13 @@ Parameter Description host (string) The host name (the default is 'localhost'). username (string) The Splunk account username, which is used to authenticate the Splunk instance. password (string) The password for the Splunk account. +splunkToken (string) The Authorization Bearer Token created in the Splunk. =========== =========================================================================================================================== +The username and password are needed for user account authentication. +On the other hand, splunkToken is needed for Token authentication. +The user auth method has a priority to token auth method if both username and splunkToken are set. + Optional configuration parameters: @@ -106,11 +111,11 @@ in msticpy config file. For more information on how to create new user with appropriate roles and permissions, follow the Splunk documents: -`Securing the Spunk platform `__ +`Securing the Spunk platform `__ and -`About users and roles `__. +`About users and roles `__ The user should have permission to at least run its own searches or more depending upon the actions to be performed by user. @@ -120,10 +125,20 @@ require the following details to specify while connecting: - host = "localhost" (Splunk server FQDN hostname to connect, for locally installed splunk, you can specify localhost) -- port = 8089 (Splunk REST API ) +- port = "8089" (Splunk REST API) - username = "admin" (username to connect to Splunk instance) - password = "yourpassword" (password of the user specified in username) +On the other hand, you can use the authentification token to connect. + +`Create authentication token `__ + +- host = "localhost" (Splunk server FQDN hostname to connect, for locally + installed splunk, you can specify localhost) +- port = "8089" (Splunk REST API) +- splunkToken = "" (token can be used instead of username/password) + + Once you have details, you can specify it in ``msticpyconfig.yaml`` as described earlier. @@ -146,6 +161,11 @@ as parameters to connect. qry_prov.connect(host=, username=, password=) +OR + +.. code:: ipython3 + + qry_prov.connect(host=, splunkToken=) Listing available queries @@ -217,7 +237,7 @@ For more information, see (default value is: | head 100) end: datetime (optional) Query end time - (default value is: 08/26/2017:00:00:00) + (default value is: current time + 1 day) index: str (optional) Splunk index name (default value is: \*) @@ -229,7 +249,7 @@ For more information, see (default value is: \*) start: datetime (optional) Query start time - (default value is: 08/25/2017:00:00:00) + (default value is: current time - 1 day) timeformat: str (optional) Datetime format to use in Splunk query (default value is: "%Y-%m-%d %H:%M:%S.%6N") diff --git a/msticpy/data/drivers/splunk_driver.py b/msticpy/data/drivers/splunk_driver.py index 3b4c4a6e9..0754027ad 100644 --- a/msticpy/data/drivers/splunk_driver.py +++ b/msticpy/data/drivers/splunk_driver.py @@ -35,14 +35,14 @@ ) from imp_err __version__ = VERSION -__author__ = "Ashwin Patil" +__author__ = "Ashwin Patil, Tatsuya Hasegawa" logger = logging.getLogger(__name__) SPLUNK_CONNECT_ARGS = { "host": "(string) The host name (the default is 'localhost').", - "port": "(integer) The port number (the default is 8089).", + "port": "(string) The port number (the default is '8089').", "http_scheme": "('https' or 'http') The scheme for accessing the service " + "(the default is 'https').", "verify": "(Boolean) Enable (True) or disable (False) SSL verrification for " @@ -60,6 +60,7 @@ "username": "(string) The Splunk account username, which is used to " + "authenticate the Splunk instance.", "password": "(string) The password for the Splunk account.", + "splunkToken": "(string) The Authorization Bearer Token created in the Splunk.", } @@ -67,8 +68,8 @@ class SplunkDriver(DriverBase): """Driver to connect and query from Splunk.""" - _SPLUNK_REQD_ARGS = ["host", "username", "password"] - _CONNECT_DEFAULTS: Dict[str, Any] = {"port": 8089} + _SPLUNK_REQD_ARGS = ["host"] + _CONNECT_DEFAULTS: Dict[str, Any] = {"port": "8089"} _TIME_FORMAT = '"%Y-%m-%d %H:%M:%S.%6N"' def __init__(self, **kwargs): @@ -79,6 +80,7 @@ def __init__(self, **kwargs): self._connected = False if kwargs.get("debug", False): logger.setLevel(logging.DEBUG) + self._required_params = self._SPLUNK_REQD_ARGS self.set_driver_property( DriverProps.PUBLIC_ATTRS, @@ -142,7 +144,7 @@ def connect(self, connection_str: Optional[str] = None, **kwargs): help_uri="https://msticpy.readthedocs.io/en/latest/DataProviders.html", ) from err self._connected = True - print("connected") + print("Connected.") def _get_connect_args( self, connection_str: Optional[str], **kwargs @@ -172,12 +174,19 @@ def _get_connect_args( elif isinstance(verify_opt, bool): cs_dict["verify"] = verify_opt - missing_args = set(self._SPLUNK_REQD_ARGS) - cs_dict.keys() + # Different required parameters for the REST API authentication method + # between user/pass and authorization bearer token + if "username" in cs_dict: + self._required_params = ["host", "username", "password"] + else: + self._required_params = ["host", "splunkToken"] + + missing_args = set(self._required_params) - cs_dict.keys() if missing_args: raise MsticpyUserConfigError( "One or more connection parameters missing for Splunk connector", ", ".join(missing_args), - f"Required parameters are {', '.join(self._SPLUNK_REQD_ARGS)}", + f"Required parameters are {', '.join(self._required_params)}", "All parameters:", *[f"{arg}: {desc}" for arg, desc in SPLUNK_CONNECT_ARGS.items()], title="no Splunk connection parameters",