Cosmos Diagnostics Handler for Python SDK (Azure#38581)
* Adding prototype of diagnostics handler

Prototype implementation of diagnostics handler

* Add additional diagnostics info

Adds additional diagnostics information to the logs.

* allow customizable logging conditions

This commit adds functionality that allows a user to plug in their own should_log method as a class method for the Cosmos HTTP logging policy. This allows for deep customization, including using their own choice of library to log data such as CPU usage.

* update samples and readme

Adds README guidance on how to use the diagnostics handler.

* updates to diagnostic handler

Added more filtering options and a disclaimer that this is a preview feature.

* beta version update

* Update diagnostics_handler_sample.py

* Update _cosmos_http_logging_policy.py

* Update _cosmos_http_logging_policy.py

* update version

* Update CHANGELOG.md

* Update _cosmos_http_logging_policy.py

* Update _cosmos_http_logging_policy.py

* Update test_cosmos_http_logging_policy.py

* Update _cosmos_http_logging_policy.py

* Update _cosmos_http_logging_policy.py

* update tests with custom logger
bambriz authored Dec 13, 2024
1 parent 5cf20b1 commit f9e8d67
Showing 16 changed files with 474 additions and 52 deletions.
9 changes: 2 additions & 7 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -1,15 +1,10 @@
## Release History

-### 4.9.1 (Unreleased)
+### 4.9.1b1 (2024-12-13)

#### Features Added
* Added change feed mode support in `query_items_change_feed`. See [PR 38105](https://github.com/Azure/azure-sdk-for-python/pull/38105)

#### Breaking Changes

#### Bugs Fixed

#### Other Changes
* Added a **Preview Feature** for adding a Diagnostics Handler to filter what diagnostics get logged. This feature is subject to significant change. See [PR 38581](https://github.com/Azure/azure-sdk-for-python/pull/38581)

### 4.9.0 (2024-11-18)

52 changes: 49 additions & 3 deletions sdk/cosmos/azure-cosmos/README.md
@@ -877,7 +877,7 @@ This library uses the standard
[logging](https://docs.python.org/3.5/library/logging.html) library for logging diagnostics.
Basic information about HTTP sessions (URLs, headers, etc.) is logged at INFO
level.

**Note: You must use the `azure.cosmos` logger name.**

Detailed DEBUG level logging, including request/response bodies and unredacted
headers, can be enabled on a client with the `logging_enable` argument:
```python
@@ -886,7 +886,7 @@ import logging
from azure.cosmos import CosmosClient

# Create a logger for the 'azure.cosmos' SDK
-logger = logging.getLogger('azure')
+logger = logging.getLogger('azure.cosmos')
logger.setLevel(logging.DEBUG)

# Configure a console output
@@ -910,7 +910,7 @@ import logging
from azure.cosmos import CosmosClient

# Create a logger for the 'azure.cosmos' SDK
-logger = logging.getLogger('azure')
+logger = logging.getLogger('azure.cosmos')
logger.setLevel(logging.DEBUG)

# Configure a file output
@@ -928,6 +928,52 @@ However, if you desire to use the CosmosHttpLoggingPolicy to obtain additional i
client = CosmosClient(URL, credential=KEY, enable_diagnostics_logging=True)
database = client.create_database(DATABASE_NAME, logger=logger)
```
**NOTICE: The following is a preview feature that is subject to significant change.**
To further customize what gets logged, you can use a **preview** diagnostics handler to filter out the logs you don't want to see.
There are several ways to use the diagnostics handler, including the following:
- Using the `CosmosDiagnosticsHandler` class, which has default behavior that can be modified.
**NOTE: The diagnostics handler is only used if the `enable_diagnostics_logging` argument is passed to the client constructor.
`CosmosDiagnosticsHandler` is a special dictionary-like type that is callable and has preset keys. Each value is expected to be a predicate over the corresponding diagnostic datum (e.g. `diagnostics_handler["duration"]` expects a function that takes an int and returns a boolean, applied to the duration of an operation).**
```python
from azure.cosmos import CosmosClient, CosmosDiagnosticsHandler
import logging
# Initialize the logger
logger = logging.getLogger('azure.cosmos')
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler('diagnostics1.output')
logger.addHandler(file_handler)
diagnostics_handler = CosmosDiagnosticsHandler()
diagnostics_handler["duration"] = lambda x: x > 2000
client = CosmosClient(URL, credential=KEY, logger=logger, diagnostics_handler=diagnostics_handler, enable_diagnostics_logging=True)

```
- Using a dictionary with the relevant functions to filter out the logs you don't want to see.
```python
# Initialize the logger
logger = logging.getLogger('azure.cosmos')
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler('diagnostics2.output')
logger.addHandler(file_handler)
diagnostics_handler = {
    "duration": lambda x: x > 2000
}
client = CosmosClient(URL, credential=KEY, logger=logger, diagnostics_handler=diagnostics_handler, enable_diagnostics_logging=True)
```
- Using a function that replaces the `should_log` method of the `CosmosHttpLoggingPolicy`; it must accept certain parameters and return a boolean. **Note: the parameters of the custom `should_log` must match those of the original `should_log` method, as shown in the sample.**
```python
# Custom should_log method; it is bound to the logging policy, so it takes self first
def should_log(self, **kwargs):
    return kwargs.get('duration') is not None and kwargs['duration'] > 2000

# Initialize the logger
logger = logging.getLogger('azure.cosmos')
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler('diagnostics3.output')
logger.addHandler(file_handler)

# Initialize the Cosmos client with custom diagnostics handler
client = CosmosClient(URL, credential=KEY, logger=logger, diagnostics_handler=should_log, enable_diagnostics_logging=True)
```
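Internally, when the custom callable defines `__get__` (i.e. is a plain function), the policy binds it to the policy instance with `types.MethodType`, which is why the sample's `should_log` takes `self` as its first parameter. A minimal sketch of that binding; the `Policy` class here is a stand-in for illustration, not the SDK type:

```python
import types

class Policy:
    """Stand-in for CosmosHttpLoggingPolicy, for illustration only."""

def should_log(self, **kwargs):
    # Mirrors the README sample: log only operations slower than 2000.
    return kwargs.get('duration') is not None and kwargs['duration'] > 2000

policy = Policy()
# The SDK performs an equivalent binding in CosmosHttpLoggingPolicy.__init__.
policy._should_log = types.MethodType(should_log, policy)

print(policy._should_log(duration=3000))  # True
print(policy._should_log(duration=100))   # False
```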

### Telemetry
Azure Core provides the ability for our Python SDKs to use OpenTelemetry with them. The only packages that need to be installed
@@ -209,7 +209,10 @@ def __init__(
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
CosmosHttpLoggingPolicy(
logger=kwargs.pop("logger", None),
enable_diagnostics_logging=kwargs.pop("enable_diagnostics_logging", False),
global_endpoint_manager=self._global_endpoint_manager,
diagnostics_handler=kwargs.pop("diagnostics_handler", None),
**kwargs
),
]
191 changes: 175 additions & 16 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_http_logging_policy.py
@@ -26,12 +26,15 @@
import json
import logging
import time
-from typing import Optional, Union, TYPE_CHECKING
+from typing import Optional, Union, Dict, Any, TYPE_CHECKING, Callable, Mapping
import types

from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import HttpLoggingPolicy

from .http_constants import HttpHeaders
from ._global_endpoint_manager import _GlobalEndpointManager
from .documents import DatabaseAccount

if TYPE_CHECKING:
from azure.core.rest import HttpRequest, HttpResponse, AsyncHttpResponse
@@ -41,6 +44,7 @@
AsyncHttpResponse as LegacyAsyncHttpResponse
)


HTTPRequestType = Union["LegacyHttpRequest", "HttpRequest"]
HTTPResponseType = Union["LegacyHttpResponse", "HttpResponse", "LegacyAsyncHttpResponse", "AsyncHttpResponse"]

@@ -55,11 +59,34 @@ class CosmosHttpLoggingPolicy(HttpLoggingPolicy):
def __init__(
self,
logger: Optional[logging.Logger] = None,
global_endpoint_manager: Optional[_GlobalEndpointManager] = None,
database_account: Optional[DatabaseAccount] = None,
*,
enable_diagnostics_logging: bool = False,
diagnostics_handler: Optional[Union[Callable, Mapping]] = None,
**kwargs
):
self._enable_diagnostics_logging = enable_diagnostics_logging
self.diagnostics_handler = diagnostics_handler
self.__request_already_logged = False
if self.diagnostics_handler and callable(self.diagnostics_handler):
if hasattr(self.diagnostics_handler, '__get__'):
self._should_log = types.MethodType(diagnostics_handler, self) # type: ignore
else:
self._should_log = self.diagnostics_handler # type: ignore
elif isinstance(self.diagnostics_handler, Mapping):
self._should_log = self._dict_should_log # type: ignore
else:
self._should_log = self._default_should_log # type: ignore
self.__global_endpoint_manager = global_endpoint_manager
self.__client_settings = self.__get_client_settings()
self.__database_account_settings: Optional[DatabaseAccount] = (database_account or
self.__get_database_account_settings())
self._resource_map = {
'docs': 'document',
'colls': 'container',
'dbs': 'database'
}
super().__init__(logger, **kwargs)
if self._enable_diagnostics_logging:
cosmos_disallow_list = ["Authorization", "ProxyAuthorization"]
@@ -69,26 +96,158 @@ def __init__(
self.allowed_header_names = set(cosmos_allow_list)

def on_request(self, request: PipelineRequest[HTTPRequestType]) -> None:
super().on_request(request)
verb = request.http_request.method
if self._enable_diagnostics_logging:
request.context["start_time"] = time.time()

url = None
if self.diagnostics_handler:
url = request.http_request.url
database_name = None
collection_name = None
resource_type = None
if url:
url_parts = url.split('/')
if 'dbs' in url_parts:
dbs_index = url_parts.index('dbs')
if dbs_index + 1 < len(url_parts):
database_name = url_parts[url_parts.index('dbs') + 1]
resource_type = self._resource_map['dbs']
if 'colls' in url_parts:
colls_index = url_parts.index('colls')
if colls_index + 1 < len(url_parts):
collection_name = url_parts[url_parts.index('colls') + 1]
resource_type = self._resource_map['colls']
if 'docs' in url_parts:
resource_type = self._resource_map['docs']
if self._should_log(verb=verb, database_name=database_name, collection_name=collection_name,
resource_type=resource_type, is_request=True):
self._log_client_settings()
self._log_database_account_settings()
super().on_request(request)
self.__request_already_logged = True

# pylint: disable=too-many-statements
def on_response(
self,
request: PipelineRequest[HTTPRequestType],
response: PipelineResponse[HTTPRequestType, HTTPResponseType], # type: ignore[override]
) -> None:
super().on_response(request, response)
if self._enable_diagnostics_logging:
http_response = response.http_response
options = response.context.options
logger = request.context.setdefault("logger", options.pop("logger", self.logger))
duration = time.time() - request.context["start_time"] if "start_time" in request.context else None
status_code = response.http_response.status_code
sub_status_str = response.http_response.headers.get("x-ms-substatus")
sub_status_code = int(sub_status_str) if sub_status_str else None
verb = request.http_request.method
http_version_obj = None
url = None
if self.diagnostics_handler:
try:
major = response.http_response.internal_response.version.major # type: ignore[attr-defined, union-attr]
minor = response.http_response.internal_response.version.minor # type: ignore[attr-defined, union-attr]
http_version_obj = f"{major}.{minor}"
except (AttributeError, TypeError):
http_version_obj = None
-try:
-    if "start_time" in request.context:
-        logger.info("Elapsed time in seconds: {}".format(time.time() - request.context["start_time"]))
-    else:
-        logger.info("Elapsed time in seconds: unknown")
-    if http_response.status_code >= 400:
-        logger.info("Response error message: %r", _format_error(http_response.text()))
-except Exception as err:  # pylint: disable=broad-except
-    logger.warning("Failed to log request: %s", repr(err))
+try:
+    url = response.http_response.internal_response.url.geturl()  # type: ignore[attr-defined, union-attr]
+except AttributeError:
+    url = str(response.http_response.internal_response.url)  # type: ignore[attr-defined, union-attr]
database_name = None
collection_name = None
resource_type = None
if url:
url_parts = url.split('/')
if 'dbs' in url_parts:
dbs_index = url_parts.index('dbs')
if dbs_index + 1 < len(url_parts):
database_name = url_parts[url_parts.index('dbs') + 1]
resource_type = self._resource_map['dbs']
if 'colls' in url_parts:
colls_index = url_parts.index('colls')
if colls_index + 1 < len(url_parts):
collection_name = url_parts[url_parts.index('colls') + 1]
resource_type = self._resource_map['colls']
if 'docs' in url_parts:
resource_type = self._resource_map['docs']
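The same resource-path parsing appears in both `on_request` and `on_response` above. A standalone sketch of that logic; `parse_resource_path` is an illustrative helper, not part of the SDK:

```python
# Mirrors the policy's URL parsing: walk the path segments and pick out the
# database name, collection name, and the most specific resource type seen.
_RESOURCE_MAP = {'docs': 'document', 'colls': 'container', 'dbs': 'database'}

def parse_resource_path(url):
    database_name = collection_name = resource_type = None
    url_parts = url.split('/')
    if 'dbs' in url_parts:
        dbs_index = url_parts.index('dbs')
        if dbs_index + 1 < len(url_parts):
            database_name = url_parts[dbs_index + 1]
            resource_type = _RESOURCE_MAP['dbs']
    if 'colls' in url_parts:
        colls_index = url_parts.index('colls')
        if colls_index + 1 < len(url_parts):
            collection_name = url_parts[colls_index + 1]
            resource_type = _RESOURCE_MAP['colls']
    if 'docs' in url_parts:
        resource_type = _RESOURCE_MAP['docs']
    return database_name, collection_name, resource_type

print(parse_resource_path(
    "https://myaccount.documents.azure.com/dbs/mydb/colls/mycoll/docs/item1"))
# ('mydb', 'mycoll', 'document')
```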


if self._should_log(duration=duration, status_code=status_code, sub_status_code=sub_status_code,
verb=verb, http_version=http_version_obj, database_name=database_name,
collection_name=collection_name, resource_type=resource_type, is_request=False):
if not self.__request_already_logged:
self._log_client_settings()
self._log_database_account_settings()
super().on_request(request)
else:
self.__request_already_logged = False
super().on_response(request, response)
if self._enable_diagnostics_logging:
http_response = response.http_response
options = response.context.options
logger = request.context.setdefault("logger", options.pop("logger", self.logger))
try:
if "start_time" in request.context:
logger.info("Elapsed time in seconds: {}".format(duration))
else:
logger.info("Elapsed time in seconds: unknown")
if http_response.status_code >= 400:
logger.info("Response error message: %r", _format_error(http_response.text()))
except Exception as err: # pylint: disable=broad-except
logger.warning("Failed to log request: %s", repr(err))

# pylint: disable=unused-argument
def _default_should_log(
self,
**kwargs
) -> bool:
return True

def _dict_should_log(self, **kwargs) -> bool:
params = {
'duration': kwargs.get('duration', None),
'status code': kwargs.get('status_code', None),
'verb': kwargs.get('verb', None),
'http version': kwargs.get('http_version', None),
'database name': kwargs.get('database_name', None),
'collection name': kwargs.get('collection_name', None),
'resource type': kwargs.get('resource_type', None)
}
for key, param in params.items():
if (param and isinstance(self.diagnostics_handler, Mapping) and key in self.diagnostics_handler
and self.diagnostics_handler[key] is not None):
if self.diagnostics_handler[key](param):
return True
return False
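As the loop above shows, a dict-based handler logs as soon as any supplied predicate returns True for a truthy diagnostic value, and the preset keys use spaces (e.g. `'status code'`, `'http version'`). A simplified, standalone sketch of these semantics; `dict_should_log` is illustrative, not the SDK method:

```python
def dict_should_log(handler, **kwargs):
    """Simplified model of CosmosHttpLoggingPolicy._dict_should_log."""
    # Map the space-separated preset keys to the keyword arguments.
    params = {
        'duration': kwargs.get('duration'),
        'status code': kwargs.get('status_code'),
        'verb': kwargs.get('verb'),
    }
    for key, value in params.items():
        predicate = handler.get(key)
        # A single matching predicate is enough to log the request.
        if value and predicate is not None and predicate(value):
            return True
    return False

# Log slow requests OR error responses.
handler = {
    "duration": lambda d: d > 2000,
    "status code": lambda s: s >= 400,
}

print(dict_should_log(handler, duration=2500, status_code=200))  # True: slow
print(dict_should_log(handler, duration=100, status_code=404))   # True: error
print(dict_should_log(handler, duration=100, status_code=200))   # False
```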

def __get_client_settings(self) -> Optional[Dict[str, Any]]:
# Place any client settings we want to log here
if self.__global_endpoint_manager:
if hasattr(self.__global_endpoint_manager, 'PreferredLocations'):
return {"Client Preferred Regions": self.__global_endpoint_manager.PreferredLocations}
return {"Client Preferred Regions": []}
return None

def __get_database_account_settings(self) -> Optional[DatabaseAccount]:
if self.__global_endpoint_manager and hasattr(self.__global_endpoint_manager, '_database_account_cache'):
return self.__global_endpoint_manager._database_account_cache # pylint: disable=protected-access
return None

def _log_client_settings(self) -> None:
self.logger.info("Client Settings:", exc_info=False)
if self.__client_settings and isinstance(self.__client_settings, dict):
self.logger.info("\tClient Preferred Regions: %s", self.__client_settings["Client Preferred Regions"],
exc_info=False)

# pylint: disable=protected-access
def _log_database_account_settings(self) -> None:
self.logger.info("Database Account Settings:", exc_info=False)
self.__database_account_settings = self.__get_database_account_settings()
if self.__database_account_settings and self.__database_account_settings.ConsistencyPolicy:
self.logger.info("\tConsistency Level: %s",
self.__database_account_settings.ConsistencyPolicy.get("defaultConsistencyLevel"),
exc_info=False)
self.logger.info("\tWritable Locations: %s", self.__database_account_settings.WritableLocations,
exc_info=False)
self.logger.info("\tReadable Locations: %s", self.__database_account_settings.ReadableLocations,
exc_info=False)
self.logger.info("\tMulti-Region Writes: %s",
self.__database_account_settings._EnableMultipleWritableLocations, exc_info=False)
@@ -57,6 +57,7 @@ def __init__(self, client):
self.refresh_needed = False
self.refresh_lock = threading.RLock()
self.last_refresh_time = 0
self._database_account_cache = None

def get_refresh_time_interval_in_ms_stub(self):
return constants._Constants.DefaultUnavailableLocationExpirationTime
@@ -125,6 +126,7 @@ def _GetDatabaseAccount(self, **kwargs):
"""
try:
database_account = self._GetDatabaseAccountStub(self.DefaultEndpoint, **kwargs)
self._database_account_cache = database_account
return database_account
# If for any reason(non-globaldb related), we are not able to get the database
# account from the above call to GetDatabaseAccount, we would try to get this
@@ -137,6 +139,7 @@
locational_endpoint = _GlobalEndpointManager.GetLocationalEndpoint(self.DefaultEndpoint, location_name)
try:
database_account = self._GetDatabaseAccountStub(locational_endpoint, **kwargs)
self._database_account_cache = database_account
return database_account
except exceptions.CosmosHttpResponseError:
pass
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/azure/cosmos/_version.py
@@ -19,4 +19,4 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

-VERSION = "4.9.1"
+VERSION = "4.9.1b1"
@@ -213,8 +213,13 @@ def __init__(
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
-            CosmosHttpLoggingPolicy(enable_diagnostics_logging=kwargs.pop("enable_diagnostics_logging", False),
-                                    **kwargs),
+            CosmosHttpLoggingPolicy(
+                logger=kwargs.pop("logger", None),
+                enable_diagnostics_logging=kwargs.pop("enable_diagnostics_logging", False),
+                global_endpoint_manager=self._global_endpoint_manager,
+                diagnostics_handler=kwargs.pop("diagnostics_handler", None),
+                **kwargs
+            ),
]

transport = kwargs.pop("transport", None)
Expand Down
Loading
