58 commits
4d6488c
Initial implementations of candidate vs rewrite shuttle
knassre-bodo Oct 9, 2025
5369379
Initial implementation of predicate server integration working on cry…
knassre-bodo Oct 9, 2025
36cab6e
WIP adding to lookup table
knassre-bodo Oct 9, 2025
ed6650c
Rewriting the rest of the filter count queries
knassre-bodo Oct 9, 2025
cc2bbed
Moving server address into mask server info setup
knassre-bodo Oct 9, 2025
a6d4b29
[RUN ALL]
knassre-bodo Oct 9, 2025
beadb15
Adding more tests
knassre-bodo Oct 10, 2025
1b4bcac
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 14, 2025
5ea82f1
Switching up relational shuttle handling for simplification
knassre-bodo Oct 15, 2025
f0f512c
Minor adjustments to file placement
knassre-bodo Oct 15, 2025
54ecef1
Moved some logic from rewrite shuttle to candidate visitor
knassre-bodo Oct 15, 2025
557aaeb
Added more tests
knassre-bodo Oct 15, 2025
6b109d9
Added rewrite shuttle docstrings/comments
knassre-bodo Oct 16, 2025
1377916
Adding remaining documentation
knassre-bodo Oct 16, 2025
891c472
Removing dead rule
knassre-bodo Oct 16, 2025
7d7580b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 16, 2025
62db4bf
[RUN ALL]
knassre-bodo Oct 16, 2025
c9f6a59
[RUN ALL]
knassre-bodo Oct 16, 2025
7c37110
Adding logging to keep track of the batch requests sent
knassre-bodo Oct 26, 2025
127244f
Ensuring non-predicate sub-expressions are not sent to the server [RU…
knassre-bodo Oct 26, 2025
1f2dc6d
Ensuring non-predicate sub-expressions are not sent to the server [RU…
knassre-bodo Oct 26, 2025
2864e4a
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Oct 26, 2025
b278f9b
Adding date/datetime/timestamp literal handling tests [RUN CI]
knassre-bodo Oct 26, 2025
74b8824
Initial implementation added, as well as early testing
knassre-bodo Oct 27, 2025
7de68bd
Added bubbleprop tests and more warning log tests [RUN CI]
knassre-bodo Oct 28, 2025
d4d2b29
Added docstrings/comments
knassre-bodo Oct 28, 2025
dcbb69c
Added new operators support, need to add new tests for datetime, quar…
knassre-bodo Oct 30, 2025
feabd8a
Added more tests, handled predicate pushdown bug with least/greatest,…
knassre-bodo Oct 30, 2025
940dd16
Added remaining tests [RUN CI]
knassre-bodo Oct 31, 2025
10f40da
Merge branch 'kian/mask_server_rewrite' into kian/unmask_warning_logs
knassre-bodo Oct 31, 2025
a6f6a37
Predicate server revisions with new API
knassre-bodo Nov 5, 2025
af10c5b
JSON request/response reformatting WIP
knassre-bodo Nov 16, 2025
0371ec5
Adding four-phase algorithm, need to implement step #3
knassre-bodo Nov 19, 2025
3996ced
Updating rewrite handling, need to add DP algorithm
knassre-bodo Nov 19, 2025
29e0e3f
Finishing implementation of min cover set
knassre-bodo Nov 21, 2025
f9c05b2
Added edge case tests for selection algorithm
knassre-bodo Nov 21, 2025
4f274fd
Minor test adjustment
knassre-bodo Nov 21, 2025
18379ef
Minor test adjustment
knassre-bodo Nov 21, 2025
f512f8b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Nov 24, 2025
90f0671
Resolving conflicts [RUN ALL]
knassre-bodo Nov 24, 2025
f6a571b
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Nov 26, 2025
b728348
Added the FQN slash handling
knassre-bodo Nov 26, 2025
8e03b04
Revisions, QUOTE operator handling, docstrings/documentation [RUN ALL]
knassre-bodo Dec 2, 2025
a3c79cf
Fixing mask server tests [RUN ALL]
knassre-bodo Dec 3, 2025
32d7ee2
API-based revisions overhaul WIP
knassre-bodo Dec 10, 2025
0ed7303
Mask server working, need to iron out kinks with 'retail_transactions…
knassre-bodo Dec 18, 2025
7e98a09
More documentation
knassre-bodo Dec 22, 2025
28c7478
[RUN CI]
knassre-bodo Dec 22, 2025
af7089e
More tests after TS fixed, still need to iterate and remove prints [R…
knassre-bodo Dec 23, 2025
f619356
Edge case debugging WIP
knassre-bodo Dec 23, 2025
8a5f82d
Adding PYDOUGH_MASK_SERVER_PATH to CI
knassre-bodo Dec 23, 2025
aaf9af1
Resolving conflicts
knassre-bodo Dec 23, 2025
4e0cf6f
Added more tests after mask server debugging, identified more bugs in…
knassre-bodo Dec 29, 2025
3cdf093
Bugfixes and breaking up batches to be max length of 16
knassre-bodo Jan 6, 2026
a01a4a8
Merge branch 'main' into kian/mask_server_rewrite
knassre-bodo Jan 6, 2026
d821d1a
[RUN ALL]
knassre-bodo Jan 6, 2026
c0c77e0
Testing fixes [RUN CI][RUN SF_MASKED]
knassre-bodo Jan 6, 2026
962a164
Resolving conflicts
knassre-bodo Jan 7, 2026
1 change: 1 addition & 0 deletions .github/workflows/pr_testing.yml
@@ -231,6 +231,7 @@ jobs:
SF_NONE_USERNAME: ${{ secrets.SF_NONE_USERNAME }}
SF_NONE_PASSWORD: ${{ secrets.SF_NONE_PASSWORD }}
SF_MASKED_ACCOUNT: ${{ secrets.SF_MASKED_ACCOUNT }}
PYDOUGH_MASK_SERVER_PATH: ${{ secrets.PYDOUGH_MASK_SERVER_PATH }}
with:
python-versions: ${{ github.event_name == 'workflow_dispatch'
&& needs.get-py-ver-matrix.outputs.matrix
3 changes: 3 additions & 0 deletions .github/workflows/sf_masked_testing.yml
@@ -22,6 +22,8 @@ on:
required: true
SF_MASKED_ACCOUNT:
required: true
PYDOUGH_MASK_SERVER_PATH:
required: true

jobs:
sf-tests:
@@ -39,6 +41,7 @@ jobs:
SF_NONE_USERNAME: ${{ secrets.SF_NONE_USERNAME }}
SF_NONE_PASSWORD: ${{ secrets.SF_NONE_PASSWORD }}
SF_MASKED_ACCOUNT: ${{ secrets.SF_MASKED_ACCOUNT }}
PYDOUGH_MASK_SERVER_PATH: ${{ secrets.PYDOUGH_MASK_SERVER_PATH }}

steps:
- uses: actions/checkout@v4
2 changes: 2 additions & 0 deletions documentation/metadata.md
@@ -157,6 +157,8 @@ Properties of this type use the type string "masked table column" and include al
- `protect protocol` (required): a Python format string, in the same format as `unprotect protocol`, used to describe how the data was originally masked. This can be used to generate masked values consistent with the encryption scheme, allowing operations such as comparisons between masked data.
- `protected data type` (optional): same as `data type`, except referring to the type of the data when it is protected, whereas `data type` refers to the raw unprotected column. If omitted, it is assumed that the data type is the same between the unprotected vs protected data.
- `server masked` (optional): a boolean flag indicating whether the column was masked on a server that is attached to PyDough. If `true`, PyDough can use it to optimize queries by rewriting predicates and expressions to avoid unmasking the data.
- `server dataset id` (optional): a string that must be provided if `server masked` is `true`, giving the `dataset id` value used to look up this column on the remote server when rewriting predicates.
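
A minimal sketch of the rule in the last bullet, assuming the property metadata is available as a plain Python dict. The helper name `check_server_masking` and the sample values (`customer_name`, `retail_dataset`) are hypothetical and are not part of PyDough; only the key names `server masked` and `server dataset id` and the type string come from the text above.

```python
def check_server_masking(prop: dict) -> None:
    """Raise ValueError if `server masked` is true but no dataset id is given."""
    server_masked = prop.get("server masked", False)
    dataset_id = prop.get("server dataset id")
    if server_masked and not isinstance(dataset_id, str):
        raise ValueError(
            "`server dataset id` must be a string when `server masked` is true"
        )


# Hypothetical property metadata; only the two server keys matter here.
ok_prop = {
    "name": "customer_name",
    "type": "masked table column",
    "server masked": True,
    "server dataset id": "retail_dataset",
}
check_server_masking(ok_prop)  # passes silently
```

A property that omits `server masked` (or sets it to `false`) passes the check without a dataset id, which matches both keys being optional.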


Example of the structure of the metadata for a masked table column property where the string data is masked by moving the first character to the end, and unmasked by moving it back to the beginning:

26 changes: 26 additions & 0 deletions pydough/configs/session.py
@@ -19,6 +19,8 @@
existing state.
"""

from typing import TYPE_CHECKING, Union

from pydough.database_connectors import (
DatabaseContext,
DatabaseDialect,
@@ -30,6 +32,9 @@

from .pydough_configs import PyDoughConfigs

if TYPE_CHECKING:
from pydough.mask_server import MaskServerInfo


class PyDoughSession:
"""
@@ -50,6 +55,7 @@ def __init__(self) -> None:
connection=empty_connection, dialect=DatabaseDialect.ANSI
)
self._error_builder: PyDoughErrorBuilder = PyDoughErrorBuilder()
self._mask_server: MaskServerInfo | None = None

@property
def metadata(self) -> GraphMetadata | None:
@@ -131,6 +137,26 @@ def error_builder(self, builder: PyDoughErrorBuilder) -> None:
"""
self._error_builder = builder

@property
def mask_server(self) -> Union["MaskServerInfo", None]:
"""
Get the active mask server information.

Returns:
The active mask server information.
"""
return self._mask_server

@mask_server.setter
def mask_server(self, server_info: Union["MaskServerInfo", None]) -> None:
"""
Set the active mask server information.

Args:
`server_info`: The mask server information to set.
"""
self._mask_server = server_info

def connect_database(self, database_name: str, **kwargs) -> DatabaseContext:
"""
Create a new DatabaseContext and register it in the session. This returns
290 changes: 290 additions & 0 deletions pydough/conversion/masking_critical_detection_visitor.py
@@ -0,0 +1,290 @@
"""
Logic for detecting mask/unmask calls within the final relational plan that will
cause a critical logical error if the user does not have permission to make the
call, and logging warnings for those calls.
"""

__all__ = ["MaskingCriticalDetectionVisitor"]

import pydough.pydough_operators as pydop
from pydough.logger import get_logger
from pydough.relational import (
Aggregate,
CallExpression,
ColumnReference,
CorrelatedReference,
EmptySingleton,
Filter,
GeneratedTable,
Join,
Limit,
LiteralExpression,
Project,
RelationalExpression,
RelationalExpressionVisitor,
RelationalNode,
RelationalRoot,
RelationalVisitor,
Scan,
WindowCallExpression,
)
from pydough.relational.rel_util import add_input_name


class MaskingCriticalDetectionExpressionVisitor(RelationalExpressionVisitor):
"""
A visitor to detect mask/unmask calls within expressions based on which
columns from the input relational node depend on mask/unmask calls. After
calling accept with an expression, the stack will contain a singleton list
with a set of (column_name, is_unmask) tuples representing the table
columns that the expression depends on via mask/unmask calls, and whether
they depend on masking or unmasking.
"""

def __init__(self) -> None:
self.input_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {}
self.stack: list[set[tuple[str, bool]]] = []

def reset(self) -> None:
self.stack = []
Contributor

Should it also reset self.input_dependencies?


def visit_call_expression(self, expr: CallExpression) -> None:
# Aggregate the dependencies from all input expressions
dependencies: set[tuple[str, bool]] = set()
for input_expr in expr.inputs:
input_expr.accept(self)
dependencies.update(self.stack.pop())
# If this call expression is a mask/unmask operation, add the relevant
# column dependency.
if isinstance(expr.op, pydop.MaskedExpressionFunctionOperator):
dependencies.add(
(
f"{expr.op.table_path}.{expr.op.masking_metadata.column_name}",
expr.op.is_unmask,
)
)
self.stack.append(dependencies)

def visit_window_expression(self, expr: WindowCallExpression) -> None:
# Aggregate the dependencies from all input, partition, and order
# expressions.
dependencies: set[tuple[str, bool]] = set()
for input_expr in expr.inputs:
input_expr.accept(self)
dependencies.update(self.stack.pop())
for partition_expr in expr.partition_inputs:
partition_expr.accept(self)
dependencies.update(self.stack.pop())
for order_expr in expr.order_inputs:
order_expr.expr.accept(self)
dependencies.update(self.stack.pop())
self.stack.append(dependencies)

def visit_column_reference(self, column_reference: ColumnReference) -> None:
# Retrieve the dependencies for this column from the input dependencies.
self.stack.append(self.input_dependencies.get(column_reference, set()))

def visit_correlated_reference(
self, correlated_reference: CorrelatedReference
) -> None:
# Correlated references have no dependencies on masking/unmasking.
self.stack.append(set())

def visit_literal_expression(self, literal_expression: LiteralExpression) -> None:
# Literal expressions have no dependencies on masking/unmasking.
self.stack.append(set())


class MaskingCriticalDetectionVisitor(RelationalVisitor):
"""
The main visitor which traverses the relational tree, inferring which
columns depend on mask/unmask calls, propagating them upward through
the plan, and logging warnings for any mask/unmask calls that are
critical to the output of the query.
"""

def __init__(self) -> None:
self.critical_mask_columns: set[str] = set()
"""
The set of fully qualified column names where a MASK operation on the
column is critical to the output of the query.
"""

self.critical_unmask_columns: set[str] = set()

Contributor

Suggested change

"""
The set of fully qualified column names where an UNMASK operation on the
column is critical to the output of the query.
"""

self.expression_visitor = MaskingCriticalDetectionExpressionVisitor()
Contributor

Do we need type hint for this?

"""
The expression visitor used to detect mask/unmask dependencies within
expressions.
"""

self.stack: list[dict[RelationalExpression, set[tuple[str, bool]]]] = []
"""
The stack of input dependency mappings for each relational node visited.
Each mapping corresponds to a relational node from one of the inputs to
the current node, and maps each output expression of the node to the set
of (column_name, is_unmask) tuples that the expression depends on.
"""

def reset(self) -> None:
self.critical_mask_columns.clear()
self.critical_unmask_columns.clear()
self.stack.clear()
self.expression_visitor.reset()

def visit_inputs(self, node: RelationalNode) -> None:
"""
Generic logic to visit all input nodes of a relational node, and build
the input dependency mapping for the current node that will be given to
the expression visitor so it knows which column references have
dependencies on mask/unmask calls.
"""
input_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {}

# Loop over all of the input nodes and recursively visit them,
# extracting their dependencies from the stack.
for idx, input_node in enumerate(node.inputs):
input_node.accept(self)
dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = (
self.stack.pop()
)
# If the node has 1 input, then its dependencies are what should
# be used.
if len(node.inputs) == 1:
input_dependencies = dependencies

# Otherwise, we need to map the dependencies to the appropriate
# input alias for the node.
else:
alias: str | None = node.default_input_aliases[idx]
for expr, deps in dependencies.items():
input_dependencies[add_input_name(expr, alias)] = deps

# Register the unmask/mask call dependencies from all inputs to this
# node with the expression visitor.
self.expression_visitor.input_dependencies = input_dependencies

def find_critical_dependencies(self, expr: RelationalExpression) -> None:
"""
Takes in an expression used in a critical manner (join condition,
filter condition, aggregate key, or ordering key for a root/limit), and
feeds it to the expression visitor to determine if it has any mask/unmask
call dependencies. If it does, the relevant columns are added to the
critical mask/unmask column sets.

Args:
`expr`: The expression to analyze for any mask/unmask dependencies.
"""
expr.accept(self.expression_visitor)
expr_dependencies: set[tuple[str, bool]] = self.expression_visitor.stack.pop()
for col_name, is_unmask in expr_dependencies:
if is_unmask:
self.critical_unmask_columns.add(col_name)
else:
self.critical_mask_columns.add(col_name)

def add_output_dependencies(self, node: RelationalNode) -> None:
"""
Uses the expression visitor to determine the mask/unmask dependencies
for each output expression of the given relational node, and pushes the
resulting mapping onto the stack.

Args:
`node`: The relational node whose output columns are having their
dependencies determined.
"""
out_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {}
for name, expr in node.columns.items():
expr.accept(self.expression_visitor)
out_dependencies[ColumnReference(name, expr.data_type)] = (
self.expression_visitor.stack.pop()
)
self.stack.append(out_dependencies)

def log_critical_calls(self) -> None:
"""
Logs warnings for all critical mask/unmask calls detected during the
traversal of the tree.

This should be called once after the visitor has traversed the entire
relational plan.
"""
logger = get_logger()
for column in self.critical_mask_columns:
logger.warning(
f"Query will not produce a valid output unless user has permission to mask column `{column}`"
)
for column in self.critical_unmask_columns:
logger.warning(
f"Query will not produce a valid output unless user has permission to unmask column `{column}`"
)

# Clean up the visitor afterwards, to avoid accidentally logging a
# duplicate.
self.reset()

def visit_project(self, project: Project) -> None:
# Projects simply propagate dependencies from their inputs.
self.visit_inputs(project)
self.add_output_dependencies(project)

def visit_filter(self, filter: Filter) -> None:
# Filter nodes propagate dependencies from their inputs, but also
# analyze their condition for critical dependencies.
self.visit_inputs(filter)
self.find_critical_dependencies(filter.condition)
self.add_output_dependencies(filter)

def visit_join(self, join: Join) -> None:
# Join nodes propagate dependencies from their inputs, but also
# analyze their condition for critical dependencies.
self.visit_inputs(join)
self.find_critical_dependencies(join.condition)
self.add_output_dependencies(join)

def visit_aggregate(self, aggregate: Aggregate) -> None:
# Aggregate nodes propagate dependencies from their inputs, but also
# analyze their aggregation keys for critical dependencies.
self.visit_inputs(aggregate)
for agg_key in aggregate.keys.values():
self.find_critical_dependencies(agg_key)
self.add_output_dependencies(aggregate)

def visit_limit(self, limit: Limit) -> None:
# Limit nodes propagate dependencies from their inputs, but also
# analyze their ordering keys for critical dependencies.
self.visit_inputs(limit)
for order_expr in limit.orderings:
self.find_critical_dependencies(order_expr.expr)
self.add_output_dependencies(limit)

def visit_root(self, root: RelationalRoot) -> None:
# Root nodes propagate dependencies from their inputs, but also
# analyze their ordering keys for critical dependencies.
self.visit_inputs(root)
for order_expr in root.orderings:
self.find_critical_dependencies(order_expr.expr)
self.add_output_dependencies(root)

def visit_scan(self, scan: Scan) -> None:
# Scan nodes have no inputs, so they propagate dependencies based on
# their columns relative to an empty input.
self.expression_visitor.input_dependencies = {}
self.add_output_dependencies(scan)

def visit_generated_table(self, generated_table: GeneratedTable) -> None:
# GeneratedTable nodes have no inputs, so they propagate dependencies based on
# their columns relative to an empty input.
self.expression_visitor.input_dependencies = {}
self.add_output_dependencies(generated_table)

def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> None:
# Empty singletons have no inputs, so they propagate dependencies based
# on their columns relative to an empty input.
self.expression_visitor.input_dependencies = {}
self.add_output_dependencies(empty_singleton)
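
The propagation idea in this file can be illustrated with a much smaller toy model. The sketch below is not the PR's implementation: `Column`, `Call`, and `dependencies` are hypothetical stand-ins for the relational expression classes, the table name `db.customers` is made up, and plain recursion replaces the visitor stack. It only shows how a `(column_name, is_unmask)` pair attached to an UNMASK call lower in the plan ends up flagged as critical once a filter condition refers to it.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Column:
    """Toy stand-in for a column reference."""
    name: str


@dataclass(frozen=True)
class Call:
    """Toy stand-in for a function call expression."""
    op: str
    args: tuple
    # Set only for MASK/UNMASK calls: (qualified column name, is_unmask).
    mask_info: Optional[tuple] = None


def dependencies(expr, input_deps: dict) -> set:
    """Collect the (column, is_unmask) pairs an expression depends on."""
    if isinstance(expr, Column):
        # Column references inherit whatever the input node recorded.
        return set(input_deps.get(expr, set()))
    if isinstance(expr, Call):
        deps = set()
        for arg in expr.args:
            deps |= dependencies(arg, input_deps)
        if expr.mask_info is not None:
            deps.add(expr.mask_info)
        return deps
    return set()  # literals carry no mask/unmask dependencies


# Project: name_plain = UNMASK(name) on the hypothetical table db.customers.
unmask = Call("UNMASK", (Column("name"),), ("db.customers.name", True))
project_out = {Column("name_plain"): dependencies(unmask, {})}

# Filter above the project: name_plain == "Alice" is a critical use.
condition = Call("EQU", (Column("name_plain"), "Alice"))
critical = dependencies(condition, project_out)
critical_unmask = {col for col, is_unmask in critical if is_unmask}
critical_mask = {col for col, is_unmask in critical if not is_unmask}
print(critical_unmask)  # {'db.customers.name'}
```

A column that is unmasked but only projected to the output would appear in `project_out` without ever reaching `critical`, which is the distinction the warning logs rely on.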
4 changes: 2 additions & 2 deletions pydough/conversion/masking_shuttles.py
@@ -63,7 +63,7 @@ def rewrite_masked_literal_comparison(
# literal in a call to MASK by toggling is_unmask to False.
masked_literal = CallExpression(
pydop.MaskedExpressionFunctionOperator(
- call_arg.op.masking_metadata, False
+ call_arg.op.masking_metadata, call_arg.op.table_path, False
),
call_arg.data_type,
[literal_arg],
@@ -83,7 +83,7 @@ def rewrite_masked_literal_comparison(
[
CallExpression(
pydop.MaskedExpressionFunctionOperator(
- call_arg.op.masking_metadata, False
+ call_arg.op.masking_metadata, call_arg.op.table_path, False
),
call_arg.data_type,
[LiteralExpression(v, inner_type)],
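
As a rough illustration of why wrapping the literal in MASK is a valid rewrite, the sketch below compares the two forms of an equality predicate on plain strings. It assumes the toy scheme described in `documentation/metadata.md` (mask by moving the first character to the end) and that the masking function is deterministic and one-to-one; the function names `mask` and `unmask` and the sample data are hypothetical and unrelated to the PyDough operators.

```python
def mask(value: str) -> str:
    """Toy protect protocol: move the first character to the end."""
    return value[1:] + value[:1]


def unmask(value: str) -> str:
    """Toy unprotect protocol: move the last character back to the front."""
    return value[-1:] + value[:-1]


# Values as they would be stored in the masked column.
stored = [mask(name) for name in ["alice", "bob", "carol"]]

# Original predicate: unmask every stored value, then compare.
slow = [unmask(v) == "bob" for v in stored]

# Rewritten predicate: mask the literal once, compare masked values.
masked_literal = mask("bob")
fast = [v == masked_literal for v in stored]

assert slow == fast
```

The rewritten form calls the masking function once per literal instead of unmasking once per row, and it never exposes the unmasked column values.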