Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code generator for otel proto #39

Merged
merged 15 commits into from
Oct 29, 2024
237 changes: 237 additions & 0 deletions compile/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
#!/usr/bin/env python3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a bit hard to understand without context. Can you add some explanation about this plug in and the standard protoc generator in the PR description? Is it simply a slimmed down version of protoc or does it do anything special?


import os
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might make more sense to have this file (and template) under scripts since it is only used when we generate from proto.
Unless that stops tests from reading this file.

import sys
from dataclasses import dataclass, field
from typing import List, Optional
from enum import IntEnum

from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
FileDescriptorProto,
FieldDescriptorProto,
EnumDescriptorProto,
EnumValueDescriptorProto,
MethodDescriptorProto,
ServiceDescriptorProto,
DescriptorProto,
)
from jinja2 import Environment, FileSystemLoader
import black
import isort.api

class WireType(IntEnum):
VARINT = 0
I64 = 1
LEN = 2
I32 = 5

@dataclass
class ProtoTypeDescriptor:
name: str
wire_type: WireType
python_type: str

proto_type_to_descriptor = {
FieldDescriptorProto.TYPE_BOOL: ProtoTypeDescriptor("bool", WireType.VARINT, "bool"),
FieldDescriptorProto.TYPE_ENUM: ProtoTypeDescriptor("enum", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_INT32: ProtoTypeDescriptor("int32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_INT64: ProtoTypeDescriptor("int64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_UINT32: ProtoTypeDescriptor("uint32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_UINT64: ProtoTypeDescriptor("uint64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_SINT32: ProtoTypeDescriptor("sint32", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_SINT64: ProtoTypeDescriptor("sint64", WireType.VARINT, "int"),
FieldDescriptorProto.TYPE_FIXED32: ProtoTypeDescriptor("fixed32", WireType.I32, "int"),
FieldDescriptorProto.TYPE_FIXED64: ProtoTypeDescriptor("fixed64", WireType.I64, "int"),
FieldDescriptorProto.TYPE_SFIXED32: ProtoTypeDescriptor("sfixed32", WireType.I32, "int"),
FieldDescriptorProto.TYPE_SFIXED64: ProtoTypeDescriptor("sfixed64", WireType.I64, "int"),
FieldDescriptorProto.TYPE_FLOAT: ProtoTypeDescriptor("float", WireType.I32, "float"),
FieldDescriptorProto.TYPE_DOUBLE: ProtoTypeDescriptor("double", WireType.I64, "float"),
FieldDescriptorProto.TYPE_STRING: ProtoTypeDescriptor("string", WireType.LEN, "str"),
FieldDescriptorProto.TYPE_BYTES: ProtoTypeDescriptor("bytes", WireType.LEN, "bytes"),
FieldDescriptorProto.TYPE_MESSAGE: ProtoTypeDescriptor("message", WireType.LEN, "bytes"),
}

@dataclass
class EnumValueTemplate:
name: str
number: int

@staticmethod
def from_descriptor(descriptor: EnumValueDescriptorProto) -> "EnumValueTemplate":
return EnumValueTemplate(
name=descriptor.name,
number=descriptor.number,
)

@dataclass
class EnumTemplate:
name: str
values: List["EnumValueTemplate"] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: EnumDescriptorProto, parent: str = "") -> "EnumTemplate":
return EnumTemplate(
name=parent + "_" + descriptor.name if parent else descriptor.name,
values=[EnumValueTemplate.from_descriptor(value) for value in descriptor.value],
)

def tag_to_repr_varint(tag: int) -> str:
out = bytearray()
while tag >= 128:
out.append((tag & 0x7F) | 0x80)
tag >>= 7
out.append(tag)
return repr(bytes(out))

@dataclass
class FieldTemplate:
name: str
number: int
tag: str
python_type: str
proto_type: str
repeated: bool
group: str
encode_presence: bool

@staticmethod
def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = None) -> "FieldTemplate":
repeated = descriptor.label == FieldDescriptorProto.LABEL_REPEATED
type_descriptor = proto_type_to_descriptor[descriptor.type]

python_type = type_descriptor.python_type
proto_type = type_descriptor.name

if repeated:
python_type = f"List[{python_type}]"
proto_type = f"repeated_{proto_type}"

tag = (descriptor.number << 3) | type_descriptor.wire_type.value
if repeated and type_descriptor.wire_type != WireType.LEN:
# Special case: repeated primitive fields are packed
# So we need to use the length-delimited wire type
tag = (descriptor.number << 3) | WireType.LEN.value
# Convert the tag to a varint representation
# Saves us from having to calculate the tag at runtime
tag = tag_to_repr_varint(tag)

# For group / oneof fields, we need to encode the presence of the field
# For message fields, we need to encode the presence of the field if it is not None
encode_presence = group is not None or proto_type == "message"

return FieldTemplate(
name=descriptor.name,
tag=tag,
number=descriptor.number,
python_type=python_type,
proto_type=proto_type,
repeated=repeated,
group=group,
encode_presence=encode_presence,
)

@dataclass
class MessageTemplate:
name: str
fields: List[FieldTemplate] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
messages: List["MessageTemplate"] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: DescriptorProto, parent: str = "") -> "MessageTemplate":
def get_group(field: FieldDescriptorProto) -> str:
return descriptor.oneof_decl[field.oneof_index].name if field.HasField("oneof_index") else None
fields = [FieldTemplate.from_descriptor(field, get_group(field)) for field in descriptor.field]
fields.sort(key=lambda field: field.number)

name = parent + "_" + descriptor.name if parent else descriptor.name
return MessageTemplate(
name=name,
fields=fields,
enums=[EnumTemplate.from_descriptor(enum, name) for enum in descriptor.enum_type],
messages=[MessageTemplate.from_descriptor(message, name) for message in descriptor.nested_type],
)

@dataclass
class MethodTemplate:
name: str
input_message: MessageTemplate
output_message: MessageTemplate

@staticmethod
def from_descriptor(descriptor: MethodDescriptorProto) -> "MethodTemplate":
return MethodTemplate(
name=descriptor.name,
input_message=MessageTemplate(name=descriptor.input_type),
output_message=MessageTemplate(name=descriptor.output_type),
)

@dataclass
class ServiceTemplate:
name: str
methods: List["MethodTemplate"] = field(default_factory=list)

@staticmethod
def from_descriptor(descriptor: ServiceDescriptorProto) -> "ServiceTemplate":
return ServiceTemplate(
name=descriptor.name,
methods=[MethodTemplate.from_descriptor(method) for method in descriptor.method],
)

@dataclass
class FileTemplate:
messages: List["MessageTemplate"] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
services: List["ServiceTemplate"] = field(default_factory=list)
name: str = ""

@staticmethod
def from_descriptor(descriptor: FileDescriptorProto) -> "FileTemplate":
return FileTemplate(
messages=[MessageTemplate.from_descriptor(message) for message in descriptor.message_type],
enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type],
services=[ServiceTemplate.from_descriptor(service) for service in descriptor.service],
name=descriptor.name,
)

def main():
request = plugin.CodeGeneratorRequest()
request.ParseFromString(sys.stdin.buffer.read())

response = plugin.CodeGeneratorResponse()
# needed since metrics.proto uses proto3 optional fields
response.supported_features = plugin.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

template_env = Environment(loader=FileSystemLoader(f"{os.path.dirname(os.path.realpath(__file__))}/templates"))
jinja_body_template = template_env.get_template("template.py.jinja2")

for proto_file in request.proto_file:
file_name = proto_file.name.replace('.proto', '.py')
file_descriptor_proto = proto_file

file_template = FileTemplate.from_descriptor(file_descriptor_proto)

code = jinja_body_template.render(file_template=file_template)
code = isort.api.sort_code_string(
code = code,
show_diff=False,
profile="black",
combine_as_imports=True,
lines_after_imports=2,
quiet=True,
force_grid_wrap=2,
)
code = black.format_str(
src_contents=code,
mode=black.Mode(),
)

response_file = response.file.add()
response_file.name = file_name
response_file.content = code

sys.stdout.buffer.write(response.SerializeToString())

if __name__ == '__main__':
main()
44 changes: 44 additions & 0 deletions compile/templates/template.py.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Generated by the protoc compiler with a custom plugin. DO NOT EDIT!
# sources: {{ file_template.name }}

from snowflake.telemetry._internal.serialize import (
Enum,
ProtoSerializer,
)
from typing import List, Optional

{% for enum in file_template.enums %}
class {{ enum.name }}(Enum):
{%- for value in enum.values %}
{{ value.name }} = {{ value.number }}
{%- endfor %}
{% endfor %}

{% macro render_message(message) %}
def {{ message.name }}(
{%- for field in message.fields %}
{{ field.name }}: Optional[{{ field.python_type }}] = None,
{%- endfor %}
) -> bytes:
proto_serializer = ProtoSerializer()
{%- for field in message.fields %}
if {{ field.name }}{% if field.encode_presence %} is not None{% endif %}: {% if field.group %}# oneof group {{ field.group }}{% endif %}
proto_serializer.serialize_{{ field.proto_type }}({{ field.tag }}, {{ field.name }})
{%- endfor %}
return proto_serializer.out

{% for nested_enum in message.enums %}
class {{ nested_enum.name }}(Enum):
{%- for value in nested_enum.values %}
{{ value.name }} = {{ value.number }}
{%- endfor %}
{% endfor %}

{% for nested_message in message.messages %}
{{ render_message(nested_message) }}
{% endfor %}
{% endmacro %}

{% for message in file_template.messages %}
{{ render_message(message) }}
{% endfor %}
63 changes: 63 additions & 0 deletions scripts/proto_codegen.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/bin/bash
#
# Regenerate python code from OTLP protos in
# https://github.com/open-telemetry/opentelemetry-proto
#
# To use, update PROTO_REPO_BRANCH_OR_COMMIT variable below to a commit hash or
# tag in opentelemtry-proto repo that you want to build off of. Then, just run
# this script to update the proto files. Commit the changes as well as any
# fixes needed in the OTLP exporter.
#
# Optional envars:
# PROTO_REPO_DIR - the path to an existing checkout of the opentelemetry-proto repo

# Pinned commit/branch/tag for the current version used in opentelemetry-proto python package.
PROTO_REPO_BRANCH_OR_COMMIT="v1.2.0"

set -e

PROTO_REPO_DIR=${PROTO_REPO_DIR:-"/tmp/opentelemetry-proto"}
# root of opentelemetry-python repo
repo_root="$(git rev-parse --show-toplevel)"
venv_dir="/tmp/proto_codegen_venv"

# run on exit even if crash
cleanup() {
echo "Deleting $venv_dir"
rm -rf $venv_dir
}
trap cleanup EXIT

echo "Creating temporary virtualenv at $venv_dir using $(python3 --version)"
python3 -m venv $venv_dir
source $venv_dir/bin/activate
python -m pip install protobuf Jinja2 grpcio-tools black isort
echo 'python -m grpc_tools.protoc --version'
python -m grpc_tools.protoc --version

# Clone the proto repo if it doesn't exist
if [ ! -d "$PROTO_REPO_DIR" ]; then
git clone https://github.com/open-telemetry/opentelemetry-proto.git $PROTO_REPO_DIR
fi

# Pull in changes and switch to requested branch
(
cd $PROTO_REPO_DIR
git fetch --all
git checkout $PROTO_REPO_BRANCH_OR_COMMIT
# pull if PROTO_REPO_BRANCH_OR_COMMIT is not a detached head
git symbolic-ref -q HEAD && git pull --ff-only || true
)

cd $repo_root/src/snowflake/telemetry/_internal

# clean up old generated code
find opentelemetry/proto/ -regex ".*\.py?" -exec rm {} +

# generate proto code for all protos
all_protos=$(find $PROTO_REPO_DIR/ -iname "*.proto")
python -m grpc_tools.protoc \
-I $PROTO_REPO_DIR \
--plugin=protoc-gen-custom-plugin=$repo_root/compile/plugin.py \
--custom-plugin_out=. \
$all_protos
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Generated by the protoc compiler with a custom plugin. DO NOT EDIT!
# sources: opentelemetry/proto/collector/logs/v1/logs_service.proto

from typing import (
List,
Optional,
)

from snowflake.telemetry._internal.serialize import (
Enum,
ProtoSerializer,
)


def ExportLogsServiceRequest(
resource_logs: Optional[List[bytes]] = None,
) -> bytes:
proto_serializer = ProtoSerializer()
if resource_logs:
proto_serializer.serialize_repeated_message(b"\n", resource_logs)
return proto_serializer.out


def ExportLogsServiceResponse(
partial_success: Optional[bytes] = None,
) -> bytes:
proto_serializer = ProtoSerializer()
if partial_success is not None:
proto_serializer.serialize_message(b"\n", partial_success)
return proto_serializer.out


def ExportLogsPartialSuccess(
rejected_log_records: Optional[int] = None,
error_message: Optional[str] = None,
) -> bytes:
proto_serializer = ProtoSerializer()
if rejected_log_records:
proto_serializer.serialize_int64(b"\x08", rejected_log_records)
if error_message:
proto_serializer.serialize_string(b"\x12", error_message)
return proto_serializer.out
Loading
Loading