Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add extra tests and fixes to QueryProvider, DriverBase and (as)sync query handling #777

Merged
merged 17 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
1349425
Updated conda reqs files for new packages (#758)
FlorianBracq May 11, 2024
35656a4
Modifying method _check_environment to be a class method
FlorianBracq May 11, 2024
d086091
Fix deprecation of datetime.utcnow method
FlorianBracq May 11, 2024
e542e0a
Expand test coverage for base driver
FlorianBracq May 11, 2024
cfdcf80
Remove implementation from abstract method connect
FlorianBracq May 11, 2024
5f9c227
Expand test coverage for data provider
FlorianBracq May 11, 2024
201ea8b
Remove method exec_query from QueryProviderProtocol as it cannot be i…
FlorianBracq May 12, 2024
9a9d02c
Set method _get_query_options as abstract from QueryProviderProtocol …
FlorianBracq May 12, 2024
3f5b66a
Update _exec_queries_threaded to be a static method, to align with _e…
FlorianBracq May 12, 2024
8e0b00b
Handle corner cases when query_tasks are all failing
FlorianBracq May 12, 2024
36b4299
Fixing logic flaw in _exec_queries_threaded retry logic
FlorianBracq May 12, 2024
b6e11f2
Fix logic issue in _calc_split_ranges
FlorianBracq May 12, 2024
4cc2e80
Extend coverage for query_provider_connections_mixin
FlorianBracq May 12, 2024
6104b93
Merge branch 'main' of https://github.com/microsoft/msticpy into extr…
FlorianBracq May 12, 2024
5993595
Adding abstract method _get_query_options to QueryProviderConnections…
FlorianBracq May 12, 2024
46797e2
Use typing types to be compatible with python 3.8
FlorianBracq May 12, 2024
29acc7f
Move implementation of exec_query to QueryProviderConnectionsMixin
FlorianBracq May 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 3 additions & 38 deletions msticpy/data/core/data_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__( # noqa: MC0001
# pylint: enable=import-outside-toplevel
setattr(self.__class__, "_add_pivots", add_data_queries_to_entities)

data_environment, self.environment_name = self._check_environment(
data_environment, self.environment_name = QueryProvider._check_environment(
data_environment
)

Expand Down Expand Up @@ -139,8 +139,9 @@ def __init__( # noqa: MC0001
self._query_time = QueryTime(units="day")
logger.info("Initialization complete.")

@classmethod
def _check_environment(
self, data_environment
cls, data_environment
) -> Tuple[Union[str, DataEnvironment], str]:
"""Check environment against known names."""
if isinstance(data_environment, str):
Expand Down Expand Up @@ -212,42 +213,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs):
logger.info("Adding query pivot functions")
self._add_pivots(lambda: self._query_time.timespan)

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
ianhelle marked this conversation as resolved.
Show resolved Hide resolved
"""
Execute simple query string.

Parameters
----------
query : str
[description]
use_connections : Union[str, List[str]]

Other Parameters
----------------
query_options : Dict[str, Any]
Additional options passed to query driver.
kwargs : Dict[str, Any]
Additional options passed to query driver.

Returns
-------
Union[pd.DataFrame, Any]
Query results - a DataFrame if successful
or a KqlResult if unsuccessful.

"""
query_options = kwargs.pop("query_options", {}) or kwargs
query_source = kwargs.pop("query_source", None)

logger.info("Executing query '%s...'", query[:40])
logger.debug("Full query: %s", query)
logger.debug("Query options: %s", query_options)
if not self._additional_connections:
return self._query_provider.query(
query, query_source=query_source, **query_options
)
return self._exec_additional_connections(query, **kwargs)

@property
def query_time(self):
"""Return the default QueryTime control for queries."""
Expand Down
135 changes: 100 additions & 35 deletions msticpy/data/core/query_provider_connections_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Query Provider additional connection methods."""
import asyncio
import logging
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
Expand Down Expand Up @@ -37,23 +38,61 @@ class QueryProviderProtocol(Protocol):
_additional_connections: Dict[str, Any]
_query_provider: DriverBase

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
"""Execute a query against the provider."""
...

# fmt: off
@staticmethod
@abstractmethod
def _get_query_options(
params: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
...
# fmt: on
"""Return any kwargs not already in params."""


# pylint: disable=super-init-not-called
class QueryProviderConnectionsMixin(QueryProviderProtocol):
"""Mixin additional connection handling QueryProvider class."""

@staticmethod
@abstractmethod
def _get_query_options(
    params: Dict[str, Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    # Abstract hook: concrete QueryProvider subclasses must supply the
    # logic for separating driver/query options from already-bound params.
    """Return any kwargs not already in params."""

def exec_query(self, query: str, **kwargs) -> Union[pd.DataFrame, Any]:
    """
    Execute simple query string.

    Parameters
    ----------
    query : str
        The query string to pass to the underlying query driver.

    Other Parameters
    ----------------
    query_options : Dict[str, Any]
        Additional options passed to query driver.
    query_source : Any
        Optional query source metadata passed through to the driver.
        # NOTE(review): popped from kwargs below but was undocumented.
    kwargs : Dict[str, Any]
        Additional options passed to query driver.

    Returns
    -------
    Union[pd.DataFrame, Any]
        Query results - a DataFrame if successful
        or a KqlResult if unsuccessful.

    """
    # If no explicit query_options dict was supplied (or it is empty/falsy),
    # fall back to treating the remaining kwargs as the driver options.
    query_options = kwargs.pop("query_options", {}) or kwargs
    query_source = kwargs.pop("query_source", None)

    # Log a truncated preview at INFO; full text only at DEBUG to keep logs small.
    logger.info("Executing query '%s...'", query[:40])
    logger.debug("Full query: %s", query)
    logger.debug("Query options: %s", query_options)
    # Single-connection fast path: run directly against the primary driver.
    if not self._additional_connections:
        return self._query_provider.query(
            query, query_source=query_source, **query_options
        )
    # Fan the query out across the primary plus all additional connections.
    return self._exec_additional_connections(query, **kwargs)

def add_connection(
self,
connection_str: Optional[str] = None,
Expand Down Expand Up @@ -159,8 +198,16 @@ def _exec_additional_connections(self, query, **kwargs) -> pd.DataFrame:
if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING):
logger.info("Running threaded queries.")
event_loop = _get_event_loop()
max_workers: int = self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
return event_loop.run_until_complete(
self._exec_queries_threaded(query_tasks, progress, retry)
self._exec_queries_threaded(
query_tasks,
progress,
retry,
max_workers,
)
)

# standard synchronous execution
Expand Down Expand Up @@ -238,8 +285,16 @@ def _exec_split_query(
if self._query_provider.get_driver_property(DriverProps.SUPPORTS_THREADING):
logger.info("Running threaded queries.")
event_loop = _get_event_loop()
max_workers: int = self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
return event_loop.run_until_complete(
self._exec_queries_threaded(query_tasks, progress, retry)
self._exec_queries_threaded(
query_tasks,
progress,
retry,
max_workers,
)
)

# or revert to standard synchronous execution
Expand Down Expand Up @@ -285,7 +340,11 @@ def _exec_synchronous_queries(
results.append(query_task())
except MsticpyDataQueryError:
print(f"Query {con_name} failed.")
return pd.concat(results)
if results:
return pd.concat(results)

logger.warning("All queries failed.")
return pd.DataFrame()

def _create_split_queries(
self,
Expand Down Expand Up @@ -318,29 +377,26 @@ def _create_split_queries(
logger.info("Split query into %s chunks", len(split_queries))
return split_queries

@staticmethod
async def _exec_queries_threaded(
self,
query_tasks: Dict[str, partial],
progress: bool = True,
retry: bool = False,
max_workers: int = 4,
) -> pd.DataFrame:
"""Return results of multiple queries run as threaded tasks."""
logger.info("Running threaded queries for %d connections.", len(query_tasks))

event_loop = _get_event_loop()

with ThreadPoolExecutor(
max_workers=self._query_provider.get_driver_property(
DriverProps.MAX_PARALLEL
)
) as executor:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# add the additional connections
thread_tasks = {
query_id: event_loop.run_in_executor(executor, query_func)
for query_id, query_func in query_tasks.items()
}
results: List[pd.DataFrame] = []
failed_tasks: Dict[str, asyncio.Future] = {}
failed_tasks_ids: List[str] = []
if progress:
task_iter = tqdm(
asyncio.as_completed(thread_tasks.values()),
Expand All @@ -360,24 +416,33 @@ async def _exec_queries_threaded(
"Query task '%s' failed with exception",
query_id,
)
failed_tasks[query_id] = thread_task

if retry and failed_tasks:
for query_id, thread_task in failed_tasks.items():
try:
logger.info("Retrying query task '%s'", query_id)
result = await thread_task
results.append(result)
except Exception: # pylint: disable=broad-except
logger.warning(
"Retried query task '%s' failed with exception",
query_id,
exc_info=True,
)
# Sort the results by the order of the tasks
results = [result for _, result in sorted(zip(thread_tasks, results))]

return pd.concat(results, ignore_index=True)
# Reusing thread task would result in:
# RuntimeError: cannot reuse already awaited coroutine
# A new task should be queued
failed_tasks_ids.append(query_id)

# Sort the results by the order of the tasks
results = [result for _, result in sorted(zip(thread_tasks, results))]

if retry and failed_tasks_ids:
failed_results: pd.DataFrame = (
await QueryProviderConnectionsMixin._exec_queries_threaded(
{
failed_tasks_id: query_tasks[failed_tasks_id]
for failed_tasks_id in failed_tasks_ids
},
progress=progress,
retry=False,
max_workers=max_workers,
)
)
if not failed_results.empty:
results.append(failed_results)
if results:
return pd.concat(results, ignore_index=True)

logger.warning("All queries failed.")
return pd.DataFrame()


def _get_event_loop() -> asyncio.AbstractEventLoop:
Expand Down
1 change: 0 additions & 1 deletion msticpy/data/drivers/driver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def connect(self, connection_str: Optional[str] = None, **kwargs):
Connect to a data source
"""
return None

@abc.abstractmethod
def query(
Expand Down
Loading
Loading