diff --git a/pyrobuf/compile.py b/pyrobuf/compile.py index 2b29dee..8e93ce6 100644 --- a/pyrobuf/compile.py +++ b/pyrobuf/compile.py @@ -5,6 +5,7 @@ from setuptools import setup from Cython.Build import cythonize +from pathlib import Path from jinja2 import Environment, PackageLoader from pyrobuf.parse_proto import Parser, Proto3Parser @@ -24,18 +25,20 @@ class Compiler(object): t_pyx = _env.get_template('proto_pyx.tmpl') def __init__(self, sources, out="out", build="build", install=False, - proto3=False, force=False, package=None, includes=None, + proto3=False, force=False, package=None, module_name=None, includes=None, clean=False): self.sources = sources self.out = out + if module_name: + self.out = os.path.join(out, module_name) self.build = build self.install = install self.force = force self.package = package - self.includes = includes or [] + self.includes = [os.path.normpath(inc) for inc in includes or []] self.clean = clean here = os.path.dirname(os.path.abspath(__file__)) - self.include_path = [os.path.join(here, 'src'), self.out] + self.include_path = [os.path.join(here, 'src'), out] self._generated = set() self._messages = [] self._pyx_files = [] @@ -45,6 +48,9 @@ def __init__(self, sources, out="out", build="build", install=False, else: self.parser = Parser + if module_name: + self.parser.module_name = module_name+'.' + @classmethod def parse_cli_args(cls): parser = argparse.ArgumentParser( @@ -65,6 +71,8 @@ def parse_cli_args(cls): help="force install") parser.add_argument('--package', type=str, default=None, help="name of package to compile to") + parser.add_argument('--module_name', type=str, default=None, + help="name of module to compile to") parser.add_argument('--include', action='append', help="add directory to includes path") parser.add_argument('--clean', action='store_true', @@ -73,7 +81,7 @@ def parse_cli_args(cls): return cls(args.sources, out=args.out_dir, build=args.build_dir, install=args.install, proto3=args.proto3, force=args.force, - package=args.package, includes=args.include, + package=args.package, module_name=args.module_name, includes=args.include, clean=args.clean) def compile(self): @@ -110,6 +118,9 @@ def extend(self, dist): def _compile_spec(self): try: os.makedirs(self.out) + if self.parser.module_name: + filename = Path(self.out).joinpath('__init__.py') + filename.touch(exist_ok=True) except _FileExistsError: pass @@ -150,6 +161,9 @@ def _generate(self, filename): if self.package is None: self._write(name, msg_def) + if (directory) in self.includes: + self._pyx_files.pop() + print("skip building {0}".format(filename)) def _write(self, name, msg_def): name_pxd = "{}_proto.pxd".format(name) diff --git a/pyrobuf/parse_proto.py b/pyrobuf/parse_proto.py index 83bef25..61ca89b 100644 --- a/pyrobuf/parse_proto.py +++ b/pyrobuf/parse_proto.py @@ -20,19 +20,21 @@ class Parser(object): ('EXTENSION', r'extensions\s+(\d+)\s+to\s+(\d+|max)\s*;'), ('ONEOF', r'oneof\s+([A-Za-z_][0-9A-Za-z_]*)'), ('MODIFIER', r'(optional|required|repeated)'), - ('FIELD', r'([A-Za-z][0-9A-Za-z_]*)\s+([A-Za-z][0-9A-Za-z_]*)\s*=\s*(\d+)'), + ('FIELD', r'(google.protobuf.)?([A-Za-z][0-9A-Za-z_]*)\s+([A-Za-z][0-9A-Za-z_]*)\s*=\s*(\d+)'), ('MAP_FIELD', r'map<([A-Za-z][0-9A-Za-z_]+),\s*([A-Za-z][0-9A-Za-z_]+)>\s+([A-Za-z][0-9A-Za-z_]*)\s*=\s*(\d+)'), ('DEFAULT', r'default\s*='), ('PACKED', r'packed\s*=\s*(true|false)'), ('DEPRECATED', r'deprecated\s*=\s*(true|false)'), - ('CUSTOM', r'(\([A-Za-z][0-9A-Za-z_]*\).[A-Za-z][0-9A-Za-z_]*)\s*='), + ('CUSTOM', r'(\([A-Za-z][0-9A-Za-z_]*\)(?:.[A-Za-z][0-9A-Za-z_]*)?)\s*='), ('LBRACKET', r'\['), ('RBRACKET', r'\]\s*;'), ('LBRACE', r'\{'), + ('KEY_VALUE', r'\:'), ('RBRACE', r'\}\s*;{0,1}'), ('COMMA', r','), ('SKIP', r'\s'), ('SEMICOLON', r';'), + ('HEXVALUE', r'(0x[0-9A-Fa-f]+)'), ('NUMERIC', r'(-?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)'), ('STRING', r'("(?:\\.|[^"\\])*"|\'(?:\\.|[^"\\])*\')'), ('BOOLEAN', r'(true|false)'), @@ -57,6 +59,7 @@ class Parser(object): 'SEMICOLON', 'ENUM', 'LBRACE', + 'KEY_VALUE', 'RBRACE', 'EXTENSION', 'ONEOF', @@ -136,6 +139,7 @@ class Parser(object): token_regex = '|'.join('(?P<%s>%s)' % pair for pair in tokens) get_token = re.compile(token_regex).match token_getter = {key: re.compile(val).match for key, val in tokens} + module_name = '' def __init__(self, string): self.string = string @@ -176,12 +180,13 @@ def tokenize(self, disabled_token_types): def parse(self, cython_info=True, fname='', includes=None, disabled_tokens=()): self.verify_parsable_tokens() tokens = self.tokenize(disabled_tokens) - rep = {'imports': [], 'messages': [], 'enums': []} + rep = {'imports': [], 'messages': [], 'enums': [], 'module_name': self.module_name} enums = {} imported = {'messages': {}, 'enums': {}} messages = {} includes = includes or [] scope = {} + previous = self.LBrace(-1) for token in tokens: if token.token_type == 'OPTION': @@ -459,7 +464,7 @@ def _parse_field(self, field, tokens): elif token.token_type == 'DEPRECATED': field.deprecated = token.value elif token.token_type == 'CUSTOM': - if self._parse_custom(field, tokens): + if self._parse_custom(field, tokens, token.name): return elif token.token_type == 'COMMA': continue @@ -480,8 +485,12 @@ def _parse_default(self, field, tokens): # This will get updated later field.default = token.full_name return + elif token.token_type == 'HEXVALUE': + assert field.type in self.scalars.difference({'bool', 'enum'}), \ + "attempting to set hex value as default for non-numeric field on line {}: '{}'".format( + token.line + 1, self.lines[token.line]) elif token.token_type == 'NUMERIC': - assert field.type in self.scalars, \ + assert field.type in self.scalars.difference({'bool', 'enum'}), \ "attempting to set numeric as default for non-numeric field on line {}: '{}'".format( token.line + 1, self.lines[token.line]) if field.type not in self.floats: @@ -498,39 +507,60 @@ def _parse_default(self, field, tokens): field.default = token.value - def _parse_custom(self, field, tokens): + def _parse_custom(self, field, tokens, custom_name): """Parse a custom option and return whether or not we hit the closing RBRACKET""" + + custom_name= custom_name[1:-1] # remove () + field.options = dict() token = next(tokens) if token.token_type == 'STRING': - field.value = token.value + field.options[custom_name] = token.value for token in tokens: if token.token_type == 'STRING': - field.value += token.value + field.options[custom_name] += token.value continue elif token.token_type == 'COMMA': return False else: assert token.token_type == 'RBRACKET' return True + elif token.token_type == 'ENUM_FIELD': + field.options[custom_name] = token.name else: - assert token.token_type in {'NUMERIC', 'BOOLEAN'}, "unexpected custom option value on line {}: '{}'".format( - token.line + 1, self.lines[token.line]) - field.value = token.value + if token.token_type == 'LBRACE': + field.options[custom_name] = dict() + while token: + token = next(tokens) + if token.token_type == 'RBRACE': + return False + elif token.token_type == 'KEY_VALUE': + continue + elif hasattr(token,'name'): + key = token.name + elif hasattr(token,'value'): + field.options[custom_name][key]=token.value + else: + assert token.token_type in {'NUMERIC', 'BOOLEAN'}, "unexpected custom option value on line {}: '{}'".format( + token.line + 1, self.lines[token.line]) + field.options[custom_name] = token.value return False def _parse_enum(self, current, tokens, scope, current_message=None): token = next(tokens) assert token.token_type == 'LBRACE', "missing opening brace on line {}: '{}'".format( token.line + 1, self.lines[token.line]) + previous = self.LBrace(-1) + setDefault = False - for num, token in enumerate(tokens): + for token in tokens: if token.token_type == 'ENUM_FIELD': - if num == 0: + if (setDefault is False): if self.syntax == 3: assert token.value == 0, "expected zero as first enum element on line {}, got {}: '{}'".format( token.line + 1, token.value, self.lines[token.line]) current.default = token + setDefault = True token.full_name = "{}_{}".format(current.full_name, token.name) @@ -551,6 +581,7 @@ def _parse_enum(self, current, tokens, scope, current_message=None): token.token_type, token.line + 1, self.lines[token.line]) return current + previous = token raise Exception("unexpected EOF on line {}: '{}'".format( token.line + 1, self.lines[token.line])) @@ -560,7 +591,7 @@ def _parse_enum_field(self, field, tokens): if token.token_type == 'LBRACKET': for token in tokens: if token.token_type == 'CUSTOM': - if self._parse_custom(field, tokens): + if self._parse_custom(field, tokens, token.name): return elif token.token_type == 'COMMA': continue @@ -584,9 +615,11 @@ def _parse_extend(self, current, tokens): return current def add_cython_info(self, message): + count = 0 for index, field in message.fields.items(): - field.bitmap_idx = (index - 1) // 64 - field.bitmap_mask = 1 << ((index - 1) % 64) + field.bitmap_idx = count // 64 + field.bitmap_mask = 1 << (count % 64) + count += 1 field.list_type = self.list_type_map.get(field.type, 'TypedList') field.fixed_width = (field.type in { 'float', 'double', 'fixed32', 'sfixed32', 'fixed64', 'sfixed64' @@ -675,10 +708,13 @@ def __init__(self, line, value): class Field(Token): token_type = 'FIELD' - def __init__(self, line, ftype, name, index): + def __init__(self, line, prefix, ftype, name, index): self.line = line self.modifier = None - self.type = ftype + if prefix: # google.protobuf. + self.type = 'GoogleProtobuf' + ftype + else: + self.type = ftype self.name = name self.index = int(index) self.default = None @@ -807,6 +843,13 @@ def __init__(self, line, value): self.line = line self.value = float(value) + class HexValue(Token): + token_type = 'HEXVALUE' + + def __init__(self, line, value): + self.line = line + self.value = int(value, 16) + class String(Token): token_type = 'STRING' @@ -827,6 +870,12 @@ class LBrace(Token): def __init__(self, line): self.line = line + class KEY_VALUE(Token): + token_type = 'KEY_VALUE' + + def __init__(self, line): + self.line = line + class RBrace(Token): token_type = 'RBRACE' diff --git a/pyrobuf/protobuf/templates/proto_pxd.tmpl b/pyrobuf/protobuf/templates/proto_pxd.tmpl index 4712856..288aa3b 100644 --- a/pyrobuf/protobuf/templates/proto_pxd.tmpl +++ b/pyrobuf/protobuf/templates/proto_pxd.tmpl @@ -8,7 +8,7 @@ from pyrobuf_util cimport * import json {%- for import in imports %} -from {{ import }}_proto cimport * +from {{module_name}}{{ import }}_proto cimport * {%- endfor %} {%- macro classdef(message) %} diff --git a/pyrobuf/protobuf/templates/proto_pyx.tmpl b/pyrobuf/protobuf/templates/proto_pyx.tmpl index de238b0..09e4de7 100644 --- a/pyrobuf/protobuf/templates/proto_pyx.tmpl +++ b/pyrobuf/protobuf/templates/proto_pyx.tmpl @@ -7,12 +7,14 @@ from pyrobuf_util cimport * import base64 import json import warnings +import enum {%- for import in imports %} -from {{ import }}_proto cimport * +from {{module_name}}{{ import }}_proto cimport * {%- endfor %} -{%- macro message_enum_fields_def(enum) %} +{%- macro enum_fields_def(enum) %} +class {{enum.full_name}}(enum.IntEnum): {%- for field in enum.fields.values() %} {{ field.name }} = _{{ field.full_name }} {%- endfor %} @@ -997,10 +999,6 @@ cdef class {{ message.full_name }}: yield self.{{ field.name }} {%- endfor %} - {% for message_enum_name, message_enum in message.enums.items() %} - {{ message_enum_fields_def(message_enum) }} - {% endfor %} - def Setters(self): """ Iterator over functions to set the fields in a message. @@ -1014,6 +1012,10 @@ cdef class {{ message.full_name }}: yield setter {%- endfor %} + {% for message_enum_name, message_enum in message.enums.items() %} +{{ enum_fields_def(message_enum) }} + {% endfor %} + {% for message_name, message_message in message.messages.items() %} {{ classdef(message_message) }} {% endfor %} @@ -1026,12 +1028,6 @@ class DecodeError(Exception): {{ classdef(message) }} {%- endfor %} -{%- macro enum_fields_def(enum) %} -{%- for field in enum.fields.values() %} -{{ field.name }} = _{{ field.full_name }} -{%- endfor %} -{%- endmacro %} - {%- for enum in enums %} {{ enum_fields_def(enum) }} {%- endfor %} diff --git a/setup.py b/setup.py index 56aadf6..7b5b1bc 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ import sys -VERSION = "0.9.3" +VERSION = "0.9.3.16" HERE = os.path.dirname(os.path.abspath(__file__)) PYROBUF_DEFS_PXI = "pyrobuf_defs.pxi" PYROBUF_LIST_PXD = "pyrobuf_list.pxd" diff --git a/tests/conftest.py b/tests/conftest.py index 46a33ca..0335937 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,15 @@ def pytest_sessionstart(session): # Insert built messages into path build = os.path.join(here, 'build') lib_path = os.path.join(build, "lib.{0}-{1}".format(get_platform(), - sys.version[0:3])) + sys.version[0:4])) + # for import with full_name (used in templates) if lib_path not in sys.path: sys.path.insert(0, lib_path) + + if compiler.parser.module_name: + # for short import wihtout pyrogen prefix + lib_path = os.path.join(lib_path,compiler.parser.module_name) + if lib_path not in sys.path: + sys.path.insert(0, lib_path) + diff --git a/tests/proto/test_custom_options.proto b/tests/proto/test_custom_options.proto index d7429b5..b0683ae 100644 --- a/tests/proto/test_custom_options.proto +++ b/tests/proto/test_custom_options.proto @@ -9,4 +9,5 @@ message TestCustomOptions { VALUE1 = 1; } optional CustomEnum field3 = 3 [ default = VALUE1, (my_options).custom1 = 4 ]; + optional double field4 = 4 [ (my_option)= 3.4 ]; } \ No newline at end of file diff --git a/tests/proto/test_many_fields.proto b/tests/proto/test_many_fields.proto index 2050e96..3b8d52a 100644 --- a/tests/proto/test_many_fields.proto +++ b/tests/proto/test_many_fields.proto @@ -127,4 +127,9 @@ message TestManyFields { optional int64 field126 = 126; optional int64 field127 = 127; optional int64 field128 = 128; -} \ No newline at end of file +} + +message TestUnusedFieldIndex { + optional int64 field64 = 64; + optional int64 field65 = 65; +} diff --git a/tests/test_has_field_many.py b/tests/test_has_field_many.py index 05b9ee0..3ab6f3e 100644 --- a/tests/test_has_field_many.py +++ b/tests/test_has_field_many.py @@ -19,3 +19,11 @@ def test_has_field(): for field in test.Fields(): assert not test.HasField(field) + +def test_has_field_withUnusedIndex(): + from test_many_fields_proto import TestUnusedFieldIndex + test = TestUnusedFieldIndex() + + # Assert HasField false on clean message + for field in test.Fields(): + assert not test.HasField(field) diff --git a/tests/test_imported_enums.py b/tests/test_imported_enums.py index 6598bf1..6721ca1 100644 --- a/tests/test_imported_enums.py +++ b/tests/test_imported_enums.py @@ -2,27 +2,27 @@ # These can't be imported until the test_imported_enums_proto module has been built. -CLOSE = None -MSG_ONE = None -ExposesInternalEnumConstantsMessage = None +Status = None +MessageID = None +ExposesInternalEnumConstantsMessageinternal_enum = None UsesImportedEnumsMessage = None class ImportedEnumsTest(unittest.TestCase): @classmethod def setUpClass(cls): - global CLOSE, MSG_ONE, ExposesInternalEnumConstantsMessage, UsesImportedEnumsMessage - from test_multi_messages_toplevel_enums_proto import MSG_ONE, CLOSE - from test_imported_enums_proto import UsesImportedEnumsMessage, ExposesInternalEnumConstantsMessage + global Status, MessageID, ExposesInternalEnumConstantsMessageinternal_enum, UsesImportedEnumsMessage + from test_multi_messages_toplevel_enums_proto import Status, MessageID + from test_imported_enums_proto import UsesImportedEnumsMessage, ExposesInternalEnumConstantsMessageinternal_enum def test_message_id_has_default_of_msg_one(self): message = UsesImportedEnumsMessage() - self.assertEqual(message.message_id, MSG_ONE) + self.assertEqual(message.message_id, MessageID.MSG_ONE) def test_status_has_default_of_close(self): message = UsesImportedEnumsMessage() - self.assertEqual(message.status, CLOSE) + self.assertEqual(message.status, Status.CLOSE) def test_internal_enum_constants_exposed(self): - self.assertEqual(ExposesInternalEnumConstantsMessage.INTERNAL, 0) - self.assertEqual(ExposesInternalEnumConstantsMessage.EXTERNAL, 1) + self.assertEqual(ExposesInternalEnumConstantsMessageinternal_enum.INTERNAL, 0) + self.assertEqual(ExposesInternalEnumConstantsMessageinternal_enum.EXTERNAL, 1) diff --git a/tests/test_merge_from.py b/tests/test_merge_from.py index f9720a5..78b9eb2 100644 --- a/tests/test_merge_from.py +++ b/tests/test_merge_from.py @@ -1,6 +1,7 @@ import unittest Test = None +TestEnumField = None TestSs1 = None TestSs3 = None @@ -8,8 +9,8 @@ class MergeFromTest(unittest.TestCase): @classmethod def setUpClass(cls): - global Test, TestSs1, TestSs3 - from test_message_proto import Test, TestSs1, TestSs3 + global Test, TestEnumField, TestSs1, TestSs3 + from test_message_proto import Test, TestEnumField, TestSs1, TestSs3 def test_merge_from_wrong_type_raises_type_error(self): dest = Test() @@ -49,9 +50,9 @@ def test_merge_from_does_set_scalar_field_that_is_set_in_source(self): def test_merge_from_does_set_enum_field_that_is_set_in_source(self): source = Test() dest = Test() - source.enum_field = Test.TEST_ENUM_FIELD_2 + source.enum_field = TestEnumField.TEST_ENUM_FIELD_2 dest.MergeFrom(source) - self.assertEqual(dest.enum_field, Test.TEST_ENUM_FIELD_2) + self.assertEqual(dest.enum_field, TestEnumField.TEST_ENUM_FIELD_2) def test_merge_from_does_merge_message_field_that_is_set_in_source(self): source = TestSs3()