Skip to content

Commit

Permalink
memcpy approach fastest benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jopel committed Oct 18, 2024
1 parent 5d474c7 commit 077f82b
Show file tree
Hide file tree
Showing 21 changed files with 992 additions and 1,265 deletions.
6 changes: 3 additions & 3 deletions benchmark/benchmark_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def get_traces_data() -> Sequence[_Span]:
def test_bm_serialize_logs_data(state):
logs_data = get_logs_data()
while state:
bytes(encode_logs(logs_data))
encode_logs(logs_data)

@benchmark.register
def test_bm_pb2_serialize_logs_data(state):
Expand All @@ -309,7 +309,7 @@ def test_bm_pb2_serialize_logs_data(state):
def test_bm_serialize_metrics_data(state):
metrics_data = get_metrics_data()
while state:
bytes(encode_metrics(metrics_data))
encode_metrics(metrics_data)

@benchmark.register
def test_bm_pb2_serialize_metrics_data(state):
Expand All @@ -321,7 +321,7 @@ def test_bm_pb2_serialize_metrics_data(state):
def test_bm_serialize_traces_data(state):
traces_data = get_traces_data()
while state:
bytes(encode_spans(traces_data))
encode_spans(traces_data)

@benchmark.register
def test_bm_pb2_serialize_traces_data(state):
Expand Down
113 changes: 39 additions & 74 deletions compile/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@

import os
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Union
from typing import List, Optional
from enum import IntEnum

from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
FileDescriptorProto,
FieldDescriptorProto,
OneofDescriptorProto,
EnumDescriptorProto,
EnumValueDescriptorProto,
MethodDescriptorProto,
Expand All @@ -33,26 +31,25 @@ class ProtoTypeDescriptor:
name: str
wire_type: WireType
python_type: str
default: str

proto_type_to_descriptor = {
FieldDescriptorProto.TYPE_BOOL: ProtoTypeDescriptor("bool", WireType.VARINT, "bool", "False"),
FieldDescriptorProto.TYPE_ENUM: ProtoTypeDescriptor("enum", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_INT32: ProtoTypeDescriptor("int32", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_INT64: ProtoTypeDescriptor("int64", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_UINT32: ProtoTypeDescriptor("uint32", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_UINT64: ProtoTypeDescriptor("uint64", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_SINT32: ProtoTypeDescriptor("sint32", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_SINT64: ProtoTypeDescriptor("sint64", WireType.VARINT, "int", "0"),
FieldDescriptorProto.TYPE_FIXED32: ProtoTypeDescriptor("fixed32", WireType.I32, "int", "0"),
FieldDescriptorProto.TYPE_FIXED64: ProtoTypeDescriptor("fixed64", WireType.I64, "int", "0"),
FieldDescriptorProto.TYPE_SFIXED32: ProtoTypeDescriptor("sfixed32", WireType.I32, "int", "0"),
FieldDescriptorProto.TYPE_SFIXED64: ProtoTypeDescriptor("sfixed64", WireType.I64, "int", "0"),
FieldDescriptorProto.TYPE_FLOAT: ProtoTypeDescriptor("float", WireType.I32, "float", "0.0"),
FieldDescriptorProto.TYPE_DOUBLE: ProtoTypeDescriptor("double", WireType.I64, "float", "0.0"),
FieldDescriptorProto.TYPE_STRING: ProtoTypeDescriptor("string", WireType.LEN, "str", '""'),
FieldDescriptorProto.TYPE_BYTES: ProtoTypeDescriptor("bytes", WireType.LEN, "bytes", 'b""'),
FieldDescriptorProto.TYPE_MESSAGE: ProtoTypeDescriptor("message", WireType.LEN, "MessageMarshaler", 'None'),
FieldDescriptorProto.TYPE_BOOL: ProtoTypeDescriptor("bool", WireType.VARINT, "bool"),
FieldDescriptorProto.TYPE_ENUM: ProtoTypeDescriptor("enum", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_INT32: ProtoTypeDescriptor("int32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_INT64: ProtoTypeDescriptor("int64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_UINT32: ProtoTypeDescriptor("uint32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_UINT64: ProtoTypeDescriptor("uint64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_SINT32: ProtoTypeDescriptor("sint32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_SINT64: ProtoTypeDescriptor("sint64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_FIXED32: ProtoTypeDescriptor("fixed32", WireType.I32, "int"),
FieldDescriptorProto.TYPE_FIXED64: ProtoTypeDescriptor("fixed64", WireType.I64, "int"),
FieldDescriptorProto.TYPE_SFIXED32: ProtoTypeDescriptor("sfixed32", WireType.I32, "int"),
FieldDescriptorProto.TYPE_SFIXED64: ProtoTypeDescriptor("sfixed64", WireType.I64, "int"),
FieldDescriptorProto.TYPE_FLOAT: ProtoTypeDescriptor("float", WireType.I32, "float"),
FieldDescriptorProto.TYPE_DOUBLE: ProtoTypeDescriptor("double", WireType.I64, "float"),
FieldDescriptorProto.TYPE_STRING: ProtoTypeDescriptor("string", WireType.LEN, "str"),
FieldDescriptorProto.TYPE_BYTES: ProtoTypeDescriptor("bytes", WireType.LEN, "bytes"),
FieldDescriptorProto.TYPE_MESSAGE: ProtoTypeDescriptor("message", WireType.LEN, "bytes"),
}

@dataclass
Expand All @@ -73,9 +70,9 @@ class EnumTemplate:
values: List["EnumValueTemplate"] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: EnumDescriptorProto) -> "EnumTemplate":
def from_descriptor(descriptor: EnumDescriptorProto, parent: str = "") -> "EnumTemplate":
return EnumTemplate(
name=descriptor.name,
name=parent + "_" + descriptor.name if parent else descriptor.name,
values=[EnumValueTemplate.from_descriptor(value) for value in descriptor.value],
)

Expand All @@ -95,21 +92,13 @@ class FieldTemplate:
python_type: str
proto_type: str
repeated: bool
default: str
group: str
encode_presence: bool

@staticmethod
def from_descriptor(descriptor: FieldDescriptorProto) -> "FieldTemplate":
def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = None) -> "FieldTemplate":
repeated = descriptor.label == FieldDescriptorProto.LABEL_REPEATED
type_descriptor = proto_type_to_descriptor[descriptor.type]

if descriptor.HasField("oneof_index"):
default = None
elif repeated:
# In python, default field values are shared across all instances of the class
# So we should not use mutable objects like list() as default values
default = None
else:
default = type_descriptor.default

python_type = type_descriptor.python_type
proto_type = type_descriptor.name
Expand All @@ -127,65 +116,41 @@ def from_descriptor(descriptor: FieldDescriptorProto) -> "FieldTemplate":
# Saves us from having to calculate the tag at runtime
tag = tag_to_repr_varint(tag)

# For group / oneof fields, we need to encode the presence of the field
# For message fields, we need to encode the presence of the field if it is not None
encode_presence = group is not None or proto_type == "message"

return FieldTemplate(
name=descriptor.name,
tag=tag,
number=descriptor.number,
python_type=python_type,
proto_type=proto_type,
repeated=repeated,
default=default,
)

@dataclass
class OneOfTemplate:
name: str
fields: List[FieldTemplate] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: OneofDescriptorProto, fields: List[FieldDescriptorProto]) -> "OneOfTemplate":

fields = [FieldTemplate.from_descriptor(field) for field in fields]
# Sort the fields by number in descending order to follow "last one wins" semantics
fields.sort(key=lambda field: field.number, reverse=True)

return OneOfTemplate(
name=descriptor.name,
fields=fields,
group=group,
encode_presence=encode_presence,
)

@dataclass
class MessageTemplate:
name: str
fields: List[Union["FieldTemplate", "OneOfTemplate"]] = field(default_factory=list)
fields: List[FieldTemplate] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
messages: List["MessageTemplate"] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: DescriptorProto) -> "MessageTemplate":
fields = []
oneofs_map = defaultdict(list)
for field in descriptor.field:
if field.HasField("oneof_index"):
oneofs_map[field.oneof_index].append(field)
else:
fields.append(field)

# Sort the fields by number in descending order, since we serialize in reverse order
fields = [FieldTemplate.from_descriptor(field) for field in fields]
oneofs = [OneOfTemplate.from_descriptor(descriptor.oneof_decl[oneof_index], fields) for oneof_index, fields in oneofs_map.items()]
fields += oneofs
def sort_key(field: Union[FieldTemplate, OneOfTemplate]):
if isinstance(field, FieldTemplate):
return field.number
return field.fields[0].number
fields.sort(key=sort_key, reverse=True)
def from_descriptor(descriptor: DescriptorProto, parent: str = "") -> "MessageTemplate":
def get_group(field: FieldDescriptorProto) -> str:
return descriptor.oneof_decl[field.oneof_index].name if field.HasField("oneof_index") else None
fields = [FieldTemplate.from_descriptor(field, get_group(field)) for field in descriptor.field]
fields.sort(key=lambda field: field.number)

name = parent + "_" + descriptor.name if parent else descriptor.name
return MessageTemplate(
name=descriptor.name,
name=name,
fields=fields,
enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type],
messages=[MessageTemplate.from_descriptor(message) for message in descriptor.nested_type],
enums=[EnumTemplate.from_descriptor(enum, name) for enum in descriptor.enum_type],
messages=[MessageTemplate.from_descriptor(message, name) for message in descriptor.nested_type],
)

@dataclass
Expand Down
50 changes: 12 additions & 38 deletions compile/templates/template.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from snowflake.telemetry._internal.serialize import (
Enum,
ProtoSerializer,
MessageMarshaler,
)
from typing import List
from typing import List, Optional

{% for enum in file_template.enums %}
class {{ enum.name }}(Enum):
Expand All @@ -17,52 +16,27 @@ class {{ enum.name }}(Enum):
{% endfor %}

{% macro render_message(message) %}
class {{ message.name }}(MessageMarshaler):
def __init__(
self,
{% for field in message.fields|reverse %}
{% if field.fields is defined %}
{% for oneof_field in field.fields|reverse %}
{{ oneof_field.name }}: {{ oneof_field.python_type }} = {{ oneof_field.default }},
{% endfor %}
{% else %}
{{ field.name }}: {{ field.python_type }} = {{ field.default }},
{% endif %}
{% endfor %}
):
{%- for field in message.fields|reverse %}
{%- if field.fields is defined %}
{%- for oneof_field in field.fields|reverse %}
self.{{ oneof_field.name }} = {{ oneof_field.name }}
{%- endfor %}
{%- else %}
self.{{ field.name }} = {{ field.name }}
{%- endif %}
{%- endfor %}

def write_to(self, proto_serializer: ProtoSerializer) -> None:
def {{ message.name }}(
{%- for field in message.fields %}
{%- if field.fields is defined %}
# oneof group {{ field.name }}
{%- for oneof_field in field.fields %}
{% if loop.index != 1 %}el{% endif %}if self.{{ oneof_field.name }} is not None:
proto_serializer.serialize_{{ oneof_field.proto_type }}({{ oneof_field.tag }}, self.{{ oneof_field.name }})
{{ field.name }}: Optional[{{ field.python_type }}] = None,
{%- endfor %}
{%- else %}
if self.{{ field.name }}: proto_serializer.serialize_{{ field.proto_type }}({{ field.tag }}, self.{{ field.name }})
{%- endif %}
) -> bytes:
proto_serializer = ProtoSerializer()
{%- for field in message.fields %}
if {{ field.name }}{% if field.encode_presence %} is not None{% endif %}: {% if field.group %}# oneof group {{ field.group }}{% endif %}
proto_serializer.serialize_{{ field.proto_type }}({{ field.tag }}, {{ field.name }})
{%- endfor %}
return proto_serializer.out

{% for nested_enum in message.enums %}
class {{ nested_enum.name }}(Enum):
class {{ nested_enum.name }}(Enum):
{%- for value in nested_enum.values %}
{{ value.name }} = {{ value.number }}
{{ value.name }} = {{ value.number }}
{%- endfor %}
{% endfor %}

{% for nested_message in message.messages %}
{%- set nested_message_result = render_message(nested_message) -%}
{{ nested_message_result | indent(4) }}
{{ render_message(nested_message) }}
{% endfor %}
{% endmacro %}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def export(self, batch: typing.Sequence[_logs.LogData]) -> export.LogExportResul
def _serialize_logs_data(batch: typing.Sequence[_logs.LogData]) -> bytes:
# pylint gets confused by protobuf-generated code, that's why we must
# disable the no-member check below.
return bytes(LogsData(resource_logs=encode_logs(batch).resource_logs))
return encode_logs(batch)

def shutdown(self):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def export(
def _serialize_metrics_data(data: MetricsData) -> bytes:
# pylint gets confused by protobuf-generated code, that's why we must
# disable the no-member check below.
return bytes(PB2MetricsData(resource_metrics=encode_metrics(data).resource_metrics))
return encode_metrics(data)

def force_flush(self, timeout_millis: float = 10_000) -> bool:
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _serialize_traces_data(
) -> bytes:
# pylint gets confused by protobuf-generated code, that's why we must
# disable the no-member check below.
return bytes(TracesData(resource_spans=encode_spans(sdk_spans).resource_spans))
return encode_spans(sdk_spans)

def shutdown(self) -> None:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
from opentelemetry.sdk._logs import LogData


def encode_logs(batch: Sequence[LogData]) -> ExportLogsServiceRequest:
return ExportLogsServiceRequest(resource_logs=_encode_resource_logs(batch))

def encode_logs(batch: Sequence[LogData]) -> bytes:
return bytes(ExportLogsServiceRequest(resource_logs=_encode_resource_logs(batch))
)

def _encode_log(log_data: LogData) -> PB2LogRecord:
span_id = (
Expand Down
Loading

0 comments on commit 077f82b

Please sign in to comment.