Skip to content

Commit

Permalink
chore: lint
Browse files Browse the repository at this point in the history
  • Loading branch information
keithZmudzinski committed Sep 12, 2024
1 parent f2971f4 commit e32234e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@
_get_opentelemetry_values,
unwrap,
)
from opentelemetry.semconv._incubating.attributes.db_attributes import (
DB_COLLECTION_NAME,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.semconv._incubating.attributes.db_attributes import DB_COLLECTION_NAME
from opentelemetry.trace import SpanKind, TracerProvider, get_tracer

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -288,11 +290,11 @@ def wrapped_connection(
self.get_connection_attributes(connection=connection, kwargs=kwargs)
return get_traced_connection_proxy(connection, self)

def get_connection_attributes(self, connection, kwargs={}):
def get_connection_attributes(self, connection, kwargs=None):
# Populate span fields using kwargs and connection
for key, value in self.connection_attributes.items():
# First set from kwargs
if value in kwargs:
if kwargs and value in kwargs:
self.connection_props[key] = kwargs.get(value)

# Then override from connection object
Expand Down Expand Up @@ -405,7 +407,9 @@ def _populate_span(
def get_span_name(self, statement):
operation_name = self.get_operation_name(statement)
collection_name = CursorTracer.get_collection_name(statement)
return " ".join(name for name in (operation_name, collection_name) if name)
return " ".join(
name for name in (operation_name, collection_name) if name
)

def get_operation_name(self, statement):
# Strip leading comments so we get the operation name.
Expand All @@ -414,9 +418,11 @@ def get_operation_name(self, statement):
@staticmethod
def get_collection_name(statement):
collection_name = ""
match = re.search(r"\b(?:FROM|JOIN|INTO|UPDATE|TABLE)\s+([\w`']+)", statement)
match = re.search(
r"\b(?:FROM|JOIN|INTO|UPDATE|TABLE)\s+([\w`']+)", statement
)
if match:
collection_name = match.group(1).strip('`\'')
collection_name = match.group(1).strip("`'")

return collection_name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from opentelemetry import trace as trace_api
from opentelemetry.instrumentation import dbapi
from opentelemetry.sdk import resources
from opentelemetry.semconv._incubating.attributes.db_attributes import (
DB_COLLECTION_NAME,
)
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.semconv._incubating.attributes.db_attributes import DB_COLLECTION_NAME
from opentelemetry.test.test_base import TestBase


Expand Down Expand Up @@ -67,9 +69,7 @@ def test_span_succeeded(self):
self.assertEqual(
span.attributes[SpanAttributes.DB_STATEMENT], expected_query
)
self.assertEqual(
span.attributes[DB_COLLECTION_NAME], "test_table"
)
self.assertEqual(span.attributes[DB_COLLECTION_NAME], "test_table")
self.assertFalse("db.statement.parameters" in span.attributes)
self.assertEqual(span.attributes[SpanAttributes.DB_USER], "testuser")
self.assertEqual(
Expand Down

0 comments on commit e32234e

Please sign in to comment.