Skip to content

Commit 463478f

Browse files
authored
fix: sub module conflict error (#295)
* fix: sub module name conflict error
1 parent b83fbae commit 463478f

File tree

10 files changed

+66
-13
lines changed

10 files changed

+66
-13
lines changed

tests/parser-cases/foo.bar.thrift

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include "foo/bar.thrift"

tests/parser-cases/foo/bar.thrift

Whitespace-only changes.

tests/parser-cases/include.thrift

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
include "included.thrift"
2+
include "include/included_1.thrift"
23

34
const included.Timestamp datetime = 1422009523
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include "included_2.thrift"

tests/parser-cases/include/included_2.thrift

Whitespace-only changes.

tests/test_loader.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,13 @@ def test_load_struct():
5151
def test_load_union():
5252
assert storm_tt.JavaObjectArg.__base__ == TPayload
5353
assert storm.JavaObjectArg.thrift_spec == \
54-
storm_tt.JavaObjectArg.thrift_spec
54+
storm_tt.JavaObjectArg.thrift_spec
5555

5656

5757
def test_load_exc():
5858
assert ab_tt.PersonNotExistsError.__base__ == TException
5959
assert ab.PersonNotExistsError.thrift_spec == \
60-
ab_tt.PersonNotExistsError.thrift_spec
60+
ab_tt.PersonNotExistsError.thrift_spec
6161

6262

6363
def test_load_service():
@@ -70,4 +70,6 @@ def test_load_include():
7070
g = load("parent.thrift")
7171

7272
ts = g.Greet.thrift_spec
73-
assert ts[1][2] == b.Hello and ts[2][0] == TType.I64 and ts[3][2] == b.Code
73+
assert (ts[1][2].thrift_spec == b.Hello.thrift_spec and
74+
ts[2][0] == TType.I64 and
75+
ts[3][2]._NAMES_TO_VALUES == b.Code._NAMES_TO_VALUES)

tests/test_parser.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
2+
import sys
33
import threading
44

55
import pytest
@@ -36,8 +36,26 @@ def test_constants():
3636

3737
def test_include():
3838
thrift = load('parser-cases/include.thrift', include_dirs=[
39-
'./parser-cases'])
39+
'./parser-cases'], module_name='include_thrift')
4040
assert thrift.datetime == 1422009523
41+
assert sys.modules['include_thrift'] is not None
42+
assert sys.modules['included_thrift'] is not None
43+
assert sys.modules['include.included_1_thrift'] is not None
44+
assert sys.modules['include.included_2_thrift'] is not None
45+
46+
47+
def test_include_with_module_name_prefix():
48+
load('parser-cases/include.thrift', module_name='parser_cases.include_thrift')
49+
assert sys.modules['parser_cases.include_thrift'] is not None
50+
assert sys.modules['parser_cases.included_thrift'] is not None
51+
assert sys.modules['parser_cases.include.included_1_thrift'] is not None
52+
assert sys.modules['parser_cases.include.included_2_thrift'] is not None
53+
54+
55+
def test_include_conflict():
56+
with pytest.raises(ThriftParserError) as excinfo:
57+
load('parser-cases/foo.bar.thrift', module_name='foo.bar_thrift')
58+
assert 'Module name conflict between' in str(excinfo.value)
4159

4260

4361
def test_cpp_include():
@@ -295,6 +313,9 @@ def test_thrift_meta():
295313

296314

297315
def test_load_fp():
316+
from thriftpy2.parser import threadlocal
317+
threadlocal.__dict__.clear()
318+
298319
thrift = None
299320
with open('parser-cases/shared.thrift') as thrift_fp:
300321
thrift = load_fp(thrift_fp, 'shared_thrift')

thriftpy2/parser/__init__.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import types
1515

1616
from .parser import parse, parse_fp, threadlocal, _cast
17-
from .exc import ThriftParserError
17+
from .exc import ThriftParserError, ThriftModuleNameConflict
1818
from ..thrift import TPayloadMeta
1919

2020

@@ -41,12 +41,21 @@ def load(path,
4141
# add sub modules to sys.modules recursively
4242
if real_module:
4343
sys.modules[module_name] = thrift
44-
sub_modules = thrift.__thrift_meta__["includes"][:]
45-
while sub_modules:
46-
module = sub_modules.pop()
47-
if module not in sys.modules:
48-
sys.modules[module.__name__] = module
49-
sub_modules.extend(module.__thrift_meta__["includes"])
44+
include_thrifts = thrift.__thrift_meta__["includes"][:]
45+
while include_thrifts:
46+
include_thrift = include_thrifts.pop()
47+
registered_thrift = sys.modules.get(include_thrift.__thrift_module_name__)
48+
if registered_thrift is None:
49+
sys.modules[include_thrift.__thrift_module_name__] = include_thrift
50+
if hasattr(include_thrift, "__thrift_meta__"):
51+
include_thrifts.extend(
52+
include_thrift.__thrift_meta__["includes"][:])
53+
else:
54+
if registered_thrift.__thrift_file__ != include_thrift.__thrift_file__:
55+
raise ThriftModuleNameConflict(
56+
'Module name conflict between "%s" and "%s"' %
57+
(registered_thrift.__thrift_file__, include_thrift.__thrift_file__)
58+
)
5059
return thrift
5160

5261

thriftpy2/parser/exc.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class ThriftParserError(Exception):
1010
pass
1111

1212

13+
class ThriftModuleNameConflict(ThriftParserError):
14+
pass
15+
16+
1317
class ThriftLexerError(ThriftParserError):
1418
pass
1519

thriftpy2/parser/parser.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,21 @@ def p_include(p):
6262
for include_dir in replace_include_dirs:
6363
path = os.path.join(include_dir, p[2])
6464
if os.path.exists(path):
65-
child = parse(path)
65+
thrift_file_name_module = os.path.basename(thrift.__thrift_file__)
66+
if thrift_file_name_module.endswith(".thrift"):
67+
thrift_file_name_module = thrift_file_name_module[:-7] + "_thrift"
68+
module_prefix = str(thrift.__name__).rstrip(thrift_file_name_module)
69+
70+
child_rel_path = os.path.relpath(str(path), os.path.dirname(thrift.__thrift_file__))
71+
child_module_name = str(child_rel_path).replace(os.sep, ".").replace(".thrift", "_thrift")
72+
child_module_name = module_prefix + child_module_name
73+
74+
child = parse(path, module_name=child_module_name)
75+
child_include_module_name = os.path.basename(path)
76+
if child_include_module_name.endswith(".thrift"):
77+
child_include_module_name = child_include_module_name[:-7]
78+
setattr(child, '__name__', child_include_module_name)
79+
setattr(child, '__thrift_module_name__', child_module_name)
6680
setattr(thrift, child.__name__, child)
6781
_add_thrift_meta('includes', child)
6882
return

0 commit comments

Comments
 (0)