Skip to content

Commit cbe365d

Browse files
committed
Unify struct and pointer to struct handling, abstract null check in ir_ops
1 parent fed6af1 commit cbe365d

File tree

5 files changed

+144
-129
lines changed

5 files changed

+144
-129
lines changed

pythonbpf/expr/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .expr_pass import eval_expr, handle_expr, get_operand_value
22
from .type_normalization import convert_to_bool, get_base_type_and_depth
3-
from .ir_ops import deref_to_depth
3+
from .ir_ops import deref_to_depth, access_struct_field
44
from .call_registry import CallHandlerRegistry
55
from .vmlinux_registry import VmlinuxHandlerRegistry
66

@@ -10,6 +10,7 @@
1010
"convert_to_bool",
1111
"get_base_type_and_depth",
1212
"deref_to_depth",
13+
"access_struct_field",
1314
"get_operand_value",
1415
"CallHandlerRegistry",
1516
"VmlinuxHandlerRegistry",

pythonbpf/expr/expr_pass.py

Lines changed: 18 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from pythonbpf.type_deducer import ctypes_to_ir, is_ctypes
88
from .call_registry import CallHandlerRegistry
9+
from .ir_ops import deref_to_depth, access_struct_field
910
from .type_normalization import (
1011
convert_to_bool,
1112
handle_comparator,
1213
get_base_type_and_depth,
13-
deref_to_depth,
1414
)
1515
from .vmlinux_registry import VmlinuxHandlerRegistry
1616
from ..vmlinux_parser.dependency_node import Field
@@ -77,89 +77,6 @@ def _handle_attribute_expr(
7777
logger.info(
7878
f"Variable type: {var_type}, Variable ptr: {var_ptr}, Variable Metadata: {var_metadata}"
7979
)
80-
# Check if this is a pointer to a struct (from map lookup)
81-
if (
82-
isinstance(var_type, ir.PointerType)
83-
and var_metadata
84-
and isinstance(var_metadata, str)
85-
):
86-
if var_metadata in structs_sym_tab:
87-
logger.info(
88-
f"Handling pointer to struct {var_metadata} from map lookup"
89-
)
90-
91-
if func is None:
92-
raise ValueError(
93-
f"func parameter required for null-safe pointer access to {var_name}.{attr_name}"
94-
)
95-
96-
# Load the pointer value (ptr<struct>)
97-
struct_ptr = builder.load(var_ptr)
98-
99-
# Create blocks for null check
100-
null_check_block = builder.block
101-
not_null_block = func.append_basic_block(
102-
name=f"{var_name}_not_null"
103-
)
104-
merge_block = func.append_basic_block(name=f"{var_name}_merge")
105-
106-
# Check if pointer is null
107-
null_ptr = ir.Constant(struct_ptr.type, None)
108-
is_not_null = builder.icmp_signed("!=", struct_ptr, null_ptr)
109-
logger.info(f"Inserted null check for pointer {var_name}")
110-
111-
builder.cbranch(is_not_null, not_null_block, merge_block)
112-
113-
# Not-null block: Access the field
114-
builder.position_at_end(not_null_block)
115-
116-
# Get struct metadata
117-
metadata = structs_sym_tab[var_metadata]
118-
struct_ptr = builder.bitcast(
119-
struct_ptr, metadata.ir_type.as_pointer()
120-
)
121-
122-
if attr_name not in metadata.fields:
123-
raise ValueError(
124-
f"Field '{attr_name}' not found in struct '{var_metadata}'"
125-
)
126-
127-
# GEP to field
128-
field_gep = metadata.gep(builder, struct_ptr, attr_name)
129-
130-
# Load field value
131-
field_val = builder.load(field_gep)
132-
field_type = metadata.field_type(attr_name)
133-
134-
logger.info(
135-
f"Loaded field {attr_name} from struct pointer, type: {field_type}"
136-
)
137-
138-
# Branch to merge
139-
not_null_after_load = builder.block
140-
builder.branch(merge_block)
141-
142-
# Merge block: PHI node for the result
143-
builder.position_at_end(merge_block)
144-
phi = builder.phi(field_type, name=f"{var_name}_{attr_name}")
145-
146-
# If null, return zero/default value
147-
if isinstance(field_type, ir.IntType):
148-
zero_value = ir.Constant(field_type, 0)
149-
elif isinstance(field_type, ir.PointerType):
150-
zero_value = ir.Constant(field_type, None)
151-
elif isinstance(field_type, ir.ArrayType):
152-
# For arrays, we can't easily create a zero constant
153-
# This case is tricky - for now, just use undef
154-
zero_value = ir.Constant(field_type, ir.Undefined)
155-
else:
156-
zero_value = ir.Constant(field_type, ir.Undefined)
157-
158-
phi.add_incoming(zero_value, null_check_block)
159-
phi.add_incoming(field_val, not_null_after_load)
160-
161-
logger.info(f"Created PHI node for {var_name}.{attr_name}")
162-
return phi, field_type
16380
if (
16481
hasattr(var_metadata, "__module__")
16582
and var_metadata.__module__ == "vmlinux"
@@ -180,13 +97,23 @@ def _handle_attribute_expr(
18097
)
18198
return None
18299

183-
# Regular user-defined struct
184-
metadata = structs_sym_tab.get(var_metadata)
185-
if metadata and attr_name in metadata.fields:
186-
gep = metadata.gep(builder, var_ptr, attr_name)
187-
val = builder.load(gep)
188-
field_type = metadata.field_type(attr_name)
189-
return val, field_type
100+
if var_metadata in structs_sym_tab:
101+
return access_struct_field(
102+
builder,
103+
var_ptr,
104+
var_type,
105+
var_metadata,
106+
expr.attr,
107+
structs_sym_tab,
108+
func,
109+
)
110+
else:
111+
logger.error(f"Struct metadata for '{var_name}' not found")
112+
else:
113+
logger.error(f"Undefined variable '{var_name}' for attribute access")
114+
else:
115+
logger.error("Unsupported attribute base expression type")
116+
190117
return None
191118

192119

pythonbpf/expr/ir_ops.py

Lines changed: 95 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,108 @@ def deref_to_depth(func, builder, val, target_depth):
1717

1818
# dereference with null check
1919
pointee_type = cur_type.pointee
20-
null_check_block = builder.block
21-
not_null_block = func.append_basic_block(name=f"deref_not_null_{depth}")
22-
merge_block = func.append_basic_block(name=f"deref_merge_{depth}")
2320

24-
null_ptr = ir.Constant(cur_type, None)
25-
is_not_null = builder.icmp_signed("!=", cur_val, null_ptr)
26-
logger.debug(f"Inserted null check for pointer at depth {depth}")
21+
def load_op(builder, ptr):
22+
return builder.load(ptr)
2723

28-
builder.cbranch(is_not_null, not_null_block, merge_block)
24+
cur_val = _null_checked_operation(
25+
func, builder, cur_val, load_op, pointee_type, f"deref_{depth}"
26+
)
27+
cur_type = pointee_type
28+
logger.debug(f"Dereferenced to depth {depth}, type: {pointee_type}")
29+
return cur_val
2930

30-
builder.position_at_end(not_null_block)
31-
dereferenced_val = builder.load(cur_val)
32-
logger.debug(f"Dereferenced to depth {depth - 1}, type: {pointee_type}")
33-
builder.branch(merge_block)
3431

35-
builder.position_at_end(merge_block)
36-
phi = builder.phi(pointee_type, name=f"deref_result_{depth}")
32+
def _null_checked_operation(func, builder, ptr, operation, result_type, name_prefix):
33+
"""
34+
Generic null-checked operation on a pointer.
35+
"""
36+
curr_block = builder.block
37+
not_null_block = func.append_basic_block(name=f"{name_prefix}_not_null")
38+
merge_block = func.append_basic_block(name=f"{name_prefix}_merge")
3739

38-
zero_value = (
39-
ir.Constant(pointee_type, 0)
40-
if isinstance(pointee_type, ir.IntType)
41-
else ir.Constant(pointee_type, None)
42-
)
43-
phi.add_incoming(zero_value, null_check_block)
40+
# Null check
41+
null_ptr = ir.Constant(ptr.type, None)
42+
is_not_null = builder.icmp_signed("!=", ptr, null_ptr)
43+
builder.cbranch(is_not_null, not_null_block, merge_block)
4444

45-
phi.add_incoming(dereferenced_val, not_null_block)
45+
# Not-null path: execute operation
46+
builder.position_at_end(not_null_block)
47+
result = operation(builder, ptr)
48+
not_null_after = builder.block
49+
builder.branch(merge_block)
4650

47-
# Continue with phi result
48-
cur_val = phi
49-
cur_type = pointee_type
50-
return cur_val
51+
# Merge with PHI
52+
builder.position_at_end(merge_block)
53+
phi = builder.phi(result_type, name=f"{name_prefix}_result")
54+
55+
# Null fallback value
56+
if isinstance(result_type, ir.IntType):
57+
null_val = ir.Constant(result_type, 0)
58+
elif isinstance(result_type, ir.PointerType):
59+
null_val = ir.Constant(result_type, None)
60+
else:
61+
null_val = ir.Constant(result_type, ir.Undefined)
62+
63+
phi.add_incoming(null_val, curr_block)
64+
phi.add_incoming(result, not_null_after)
65+
66+
return phi
5167

5268

53-
def deref_struct_ptr(
54-
func, builder, struct_ptr, struct_metadata, field_name, structs_sym_tab
69+
def access_struct_field(
70+
builder, var_ptr, var_type, var_metadata, field_name, structs_sym_tab, func=None
5571
):
56-
"""Dereference a pointer to a struct type."""
57-
return deref_to_depth(func, builder, struct_ptr, 1)
72+
"""
73+
Access a struct field - automatically returns value or pointer based on field type.
74+
"""
75+
# Get struct metadata
76+
metadata = (
77+
structs_sym_tab.get(var_metadata)
78+
if isinstance(var_metadata, str)
79+
else var_metadata
80+
)
81+
if not metadata or field_name not in metadata.fields:
82+
raise ValueError(f"Field '{field_name}' not found in struct")
83+
84+
field_type = metadata.field_type(field_name)
85+
is_ptr_to_struct = isinstance(var_type, ir.PointerType) and isinstance(
86+
var_metadata, str
87+
)
88+
89+
# Get struct pointer
90+
struct_ptr = builder.load(var_ptr) if is_ptr_to_struct else var_ptr
91+
92+
# Decide: load value or return pointer?
93+
should_load = not isinstance(field_type, ir.ArrayType)
94+
95+
# Define the field access operation
96+
def field_access_op(builder, ptr):
97+
typed_ptr = builder.bitcast(ptr, metadata.ir_type.as_pointer())
98+
field_ptr = metadata.gep(builder, typed_ptr, field_name)
99+
return builder.load(field_ptr) if should_load else field_ptr
100+
101+
# Handle null check for pointer-to-struct
102+
if is_ptr_to_struct:
103+
if func is None:
104+
raise ValueError("func required for null-safe struct pointer access")
105+
106+
if should_load:
107+
result_type = field_type
108+
else:
109+
result_type = field_type.as_pointer()
110+
111+
result = _null_checked_operation(
112+
func,
113+
builder,
114+
struct_ptr,
115+
field_access_op,
116+
result_type,
117+
f"field_{field_name}",
118+
)
119+
return result, field_type
120+
121+
# No null check needed
122+
field_ptr = metadata.gep(builder, struct_ptr, field_name)
123+
result = builder.load(field_ptr) if should_load else field_ptr
124+
return result, field_type

pythonbpf/helper/helper_utils.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pythonbpf.expr import (
66
get_operand_value,
77
eval_expr,
8+
access_struct_field,
89
)
910

1011
logger = logging.getLogger(__name__)
@@ -135,7 +136,7 @@ def get_or_create_ptr_from_arg(
135136
and field_type.element.width == 8
136137
):
137138
ptr, sz = get_char_array_ptr_and_size(
138-
arg, builder, local_sym_tab, struct_sym_tab
139+
arg, builder, local_sym_tab, struct_sym_tab, func
139140
)
140141
if not ptr:
141142
raise ValueError("Failed to get char array pointer from struct field")
@@ -266,7 +267,9 @@ def get_buffer_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
266267
)
267268

268269

269-
def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab):
270+
def get_char_array_ptr_and_size(
271+
buf_arg, builder, local_sym_tab, struct_sym_tab, func=None
272+
):
270273
"""Get pointer to char array and its size."""
271274

272275
# Struct field: obj.field
@@ -277,11 +280,11 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
277280
if not (local_sym_tab and var_name in local_sym_tab):
278281
raise ValueError(f"Variable '{var_name}' not found")
279282

280-
struct_type = local_sym_tab[var_name].metadata
281-
if not (struct_sym_tab and struct_type in struct_sym_tab):
282-
raise ValueError(f"Struct type '{struct_type}' not found")
283+
struct_ptr, struct_type, struct_metadata = local_sym_tab[var_name]
284+
if not (struct_sym_tab and struct_metadata in struct_sym_tab):
285+
raise ValueError(f"Struct type '{struct_metadata}' not found")
283286

284-
struct_info = struct_sym_tab[struct_type]
287+
struct_info = struct_sym_tab[struct_metadata]
285288
if field_name not in struct_info.fields:
286289
raise ValueError(f"Field '{field_name}' not found")
287290

@@ -292,8 +295,25 @@ def get_char_array_ptr_and_size(buf_arg, builder, local_sym_tab, struct_sym_tab)
292295
)
293296
return None, 0
294297

295-
struct_ptr = local_sym_tab[var_name].var
296-
field_ptr = struct_info.gep(builder, struct_ptr, field_name)
298+
# Check if char array
299+
if not (
300+
isinstance(field_type, ir.ArrayType)
301+
and isinstance(field_type.element, ir.IntType)
302+
and field_type.element.width == 8
303+
):
304+
logger.warning("Field is not a char array")
305+
return None, 0
306+
307+
# Get field pointer (automatically handles null checks!)
308+
field_ptr, _ = access_struct_field(
309+
builder,
310+
struct_ptr,
311+
struct_type,
312+
struct_metadata,
313+
field_name,
314+
struct_sym_tab,
315+
func,
316+
)
297317

298318
# GEP to first element: [N x i8]* -> i8*
299319
buf_ptr = builder.gep(

pythonbpf/helper/printk_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def _prepare_expr_args(expr, func, module, builder, local_sym_tab, struct_sym_ta
222222
# Special case: struct field char array needs pointer to first element
223223
if isinstance(expr, ast.Attribute):
224224
char_array_ptr, _ = get_char_array_ptr_and_size(
225-
expr, builder, local_sym_tab, struct_sym_tab
225+
expr, builder, local_sym_tab, struct_sym_tab, func
226226
)
227227
if char_array_ptr:
228228
return char_array_ptr

0 commit comments

Comments
 (0)