Skip to content

Commit

Permalink
Add extra tests and fixes to QueryProvider, DriverBase and (as)sync q…
Browse files Browse the repository at this point in the history
…uery handling (#777)

* Updated conda reqs files for new packages (#758)

* Updated conda reqs files for new packages

* Updating version to 2.11

* Fix for missing label in RDAP data

* Fix for missing 'label' item in RDAP data

* Modifying method _check_environment to be a class method

* Fix deprecation of datetime.utcnow method

* Expand test coverage for base driver

* Remove implementation from abstract method connect

* Expand test coverage for data provider

* Remove method exec_query from QueryProviderProtocol as the protocol cannot be instantiated

* Mark method _get_query_options as abstract in QueryProviderProtocol, as the protocol cannot be instantiated,
+ remove the implementation

* Update _exec_queries_threaded to be a static method, to align with _exec_synchronous_queries

* Handle corner cases when query_tasks are all failing

* Fixing logic flaw in _exec_queries_threaded retry logic

* Fix logic issue in _calc_split_ranges

* Extend coverage for query_provider_connections_mixin

* Adding abstract method _get_query_options to QueryProviderConnectionsMixin

* Use typing types to be compatible with python 3.8

* Move implementation of exec_query to QueryProviderConnectionsMixin
  • Loading branch information
FlorianBracq authored May 28, 2024
1 parent 07a2f0d commit c1df176
Show file tree
Hide file tree
Showing 6 changed files with 530 additions and 101 deletions.
41 changes: 3 additions & 38 deletions msticpy/data/core/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__( # noqa: MC0001
# pylint: enable=import-outside-toplevel
setattr(self.__class__, "_add_pivots", add_data_queries_to_entities)

data_environment, self.environment_name = self._check_environment(
data_environment, self.environment_name = QueryProvider._check_environment(
data_environment
)

Expand Down Expand Up @@ -139,8 +139,9 @@ def __init__( # noqa: MC0001
self._query_time = QueryTime(units="day")
logger.info("Initialization complete.")

@classmethod
def _check_environment(
self, data_environment
cls, data_environment
) -> Tuple[Union[str, DataEnvironment], str]:
"""Check environment against known names."""
if isinstance(data_environment, str):
Expand Down Expand Up @@ -212,42 +213,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs):
logger.info("Adding query pivot functions")
self._add_pivots(lambda: self._query_time.timespan)

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
    """
    Execute simple query string.

    Parameters
    ----------
    query : str
        The raw query text to run against the connected provider.
    use_connections : Union[str, List[str]]
        Connection name(s) to run the query against.
        # NOTE(review): listed here but never read in this body —
        # presumably consumed by _exec_additional_connections; confirm.

    Other Parameters
    ----------------
    query_options : Dict[str, Any]
        Additional options passed to query driver.
    kwargs : Dict[str, Any]
        Additional options passed to query driver.

    Returns
    -------
    Union[pd.DataFrame, Any]
        Query results - a DataFrame if successful
        or a KqlResult if unsuccessful.

    """
    # If no explicit query_options dict was supplied (or it was empty),
    # treat all remaining keyword arguments as driver options.
    query_options = kwargs.pop("query_options", {}) or kwargs
    query_source = kwargs.pop("query_source", None)

    # Log a truncated preview at INFO; full text only at DEBUG.
    logger.info("Executing query '%s...'", query[:40])
    logger.debug("Full query: %s", query)
    logger.debug("Query options: %s", query_options)
    if not self._additional_connections:
        # Single-connection fast path: run directly on the primary driver.
        return self._query_provider.query(
            query, query_source=query_source, **query_options
        )
    # Multiple connections attached: fan the query out to all of them.
    # NOTE(review): query_options was popped above and is NOT forwarded
    # here — confirm _exec_additional_connections expects the raw kwargs.
    return self._exec_additional_connections(query, **kwargs)

@property
def query_time(self):
"""Return the default QueryTime control for queries."""
Expand Down
135 changes: 100 additions & 35 deletions msticpy/data/core/query_provider_connections_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Query Provider additional connection methods."""
import asyncio
import logging
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -37,23 +38,61 @@ class QueryProviderProtocol(Protocol):
_additional_connections: Dict[str, Any]
_query_provider: DriverBase

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
"""Execute a query against the provider."""
...

# fmt: off
@staticmethod
@abstractmethod
def _get_query_options(
params: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
...
# fmt: on
"""Return any kwargs not already in params."""


# pylint: disable=super-init-not-called
class QueryProviderConnectionsMixin(QueryProviderProtocol):
"""Mixin additional connection handling QueryProvider class."""

@staticmethod
@abstractmethod
def _get_query_options(
    params: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Return any kwargs not already in params.

    Abstract: concrete QueryProvider subclasses must supply the
    implementation; this mixin only declares the contract.

    Parameters
    ----------
    params : Dict[str, Any]
        Parameters already bound to the query.
    kwargs : Dict[str, Any]
        All keyword arguments supplied by the caller.

    Returns
    -------
    Dict[str, Any]
        The subset of `kwargs` whose keys do not appear in `params`,
        i.e. the options to pass through to the query driver.

    """

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
    """
    Execute a simple query string against the provider.

    Parameters
    ----------
    query : str
        The raw query text to execute.
    use_connections : Union[str, List[str]]
        Connection name(s) to run the query against.

    Other Parameters
    ----------------
    query_options : Dict[str, Any]
        Additional options passed to query driver.
    kwargs : Dict[str, Any]
        Additional options passed to query driver.

    Returns
    -------
    Union[pd.DataFrame, Any]
        Query results - a DataFrame if successful
        or a KqlResult if unsuccessful.

    """
    # An explicit (non-empty) query_options dict wins; otherwise every
    # remaining keyword argument is treated as a driver option.
    options: Dict[str, Any] = kwargs.pop("query_options", {}) or kwargs
    source = kwargs.pop("query_source", None)

    # Truncated preview at INFO level; full query text only at DEBUG.
    logger.info("Executing query '%s...'", query[:40])
    logger.debug("Full query: %s", query)
    logger.debug("Query options: %s", options)

    if self._additional_connections:
        # Fan the query out to every attached connection.
        return self._exec_additional_connections(query, **kwargs)
    # Single-connection fast path: run directly on the primary driver.
    return self._query_provider.query(query, query_source=source, **options)

def add_connection(
self,
connection_str: Optional[str] = None,
Expand Down Expand Up @@ -159,8 +198,16 @@ def _exec_additional_connections(self, query, **kwargs) -> pd.DataFrame:
if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING):
logger.info("Running threaded queries.")
event_loop = _get_event_loop()
max_workers: int = self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
return event_loop.run_until_complete(
self._exec_queries_threaded(query_tasks, progress, retry)
self._exec_queries_threaded(
query_tasks,
progress,
retry,
max_workers,
)
)

# standard synchronous execution
Expand Down Expand Up @@ -238,8 +285,16 @@ def _exec_split_query(
if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING):
logger.info("Running threaded queries.")
event_loop = _get_event_loop()
max_workers: int = self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
return event_loop.run_until_complete(
self._exec_queries_threaded(query_tasks, progress, retry)
self._exec_queries_threaded(
query_tasks,
progress,
retry,
max_workers,
)
)

# or revert to standard synchronous execution
Expand Down Expand Up @@ -285,7 +340,11 @@ def _exec_synchronous_queries(
results.append(query_task())
except MsticpyDataQueryError:
print(f"Query {con_name} failed.")
return pd.concat(results)
if results:
return pd.concat(results)

logger.warning("All queries failed.")
return pd.DataFrame()

def _create_split_queries(
self,
Expand Down Expand Up @@ -318,29 +377,26 @@ def _create_split_queries(
logger.info("Split query into %s chunks", len(split_queries))
return split_queries

@staticmethod
async def _exec_queries_threaded(
self,
query_tasks: Dict[str, partial],
progress: bool = True,
retry: bool = False,
max_workers: int = 4,
) -> pd.DataFrame:
"""Return results of multiple queries run as threaded tasks."""
logger.info("Running threaded queries for %d connections.", len(query_tasks))

event_loop = _get_event_loop()

with ThreadPoolExecutor(
max_workers=self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
) as executor:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# add the additional connections
thread_tasks = {
query_id: event_loop.run_in_executor(executor, query_func)
for query_id, query_func in query_tasks.items()
}
results: List[pd.DataFrame] = []
failed_tasks: Dict[str, asyncio.Future] = {}
failed_tasks_ids: List[str] = []
if progress:
task_iter = tqdm(
asyncio.as_completed(thread_tasks.values()),
Expand All @@ -360,24 +416,33 @@ async def _exec_queries_threaded(
"Query task '%s' failed with exception",
query_id,
)
failed_tasks[query_id] = thread_task

if retry and failed_tasks:
for query_id, thread_task in failed_tasks.items():
try:
logger.info("Retrying query task '%s'", query_id)
result = await thread_task
results.append(result)
except Exception: # pylint: disable=broad-except
logger.warning(
"Retried query task '%s' failed with exception",
query_id,
exc_info=True,
)
# Sort the results by the order of the tasks
results = [result for _, result in sorted(zip(thread_tasks, results))]

return pd.concat(results, ignore_index=True)
# Reusing thread task would result in:
# RuntimeError: cannot reuse already awaited coroutine
# A new task should be queued
failed_tasks_ids.append(query_id)

# Sort the results by the order of the tasks
results = [result for _, result in sorted(zip(thread_tasks, results))]

if retry and failed_tasks_ids:
failed_results: pd.DataFrame = (
await QueryProviderConnectionsMixin._exec_queries_threaded(
{
failed_tasks_id: query_tasks[failed_tasks_id]
for failed_tasks_id in failed_tasks_ids
},
progress=progress,
retry=False,
max_workers=max_workers,
)
)
if not failed_results.empty:
results.append(failed_results)
if results:
return pd.concat(results, ignore_index=True)

logger.warning("All queries failed.")
return pd.DataFrame()


def _get_event_loop() -> asyncio.AbstractEventLoop:
Expand Down
1 change: 0 additions & 1 deletion msticpy/data/drivers/driver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs):
Connect to a data source
"""
return None

@abc.abstractmethod
def query(
Expand Down
Loading

0 comments on commit c1df176

Please sign in to comment.