Adding optimization rewrite pass to utilize server with information about masked columns #443
base: main
…ptbank_filter_count_01
john-sanchez31
left a comment
Comments regarding IN and ISIN operators and a type hint
john-sanchez31
left a comment
Just fix the missing type hints and TODO docstrings, but overall LGTM! Nice job with the new dry run algorithm, impressive!
pydough/mask_server/mask_server.py
Outdated
     """

-    def __init__(self, base_url: str, token: str | None = None):
+    def __init__(self, base_url: str, server_address: str, token: str | None = None):
What would be the difference between base_url and server_address?
I think base_url is to contact the predicate server and server_address is to write the Fully Qualified Column Name as f"{server_address}/{table_path}". Is this correct?
Where will server_address be configured? This is very specific to the database instance we are connecting to. For example, metadata can be re-used for the same database on different servers, even with different engines. However, the server_address is directly associated (1:1) with the database instance.
> I think base_url is to contact the predicate server and server_address is to write the Fully Qualified Column Name as f"{server_address}/{table_path}". Is this correct?
Yes, that is correct.
> Where will server_address be configured?
When you configure/mount the MaskServerInfo class, you pass in the server_address (same place the token gets passed).
pydough/mask_server/mask_server.py
Outdated
    response: dict = item.get("response", None)
    if response is None:
        # In this case, use a dummy value as a default to indicate
        # the dry run was successful
Did you mean to indicate the dry run was unsuccessful?
No, I mean successful. I do need to adjust this slightly.
     pydop.MaskedExpressionFunctionOperator(
-        hybrid_expr.column.column_property, True
+        hybrid_expr.column.column_property,
+        node.collection.collection.table_path,
Is this the reason why we need to use the full table path in metadata?
EXACTLY (plus it's a good idea in general)
-    # PYDOUGH_ENABLE_MASK_REWRITES is set to 1.
+    # PYDOUGH_ENABLE_MASK_REWRITES is set to 1. If a masking rewrite server has
+    # been attached to the session, include the shuttles for that as well.
     if os.getenv("PYDOUGH_ENABLE_MASK_REWRITES") == "1":
Is there a reason why PYDOUGH_ENABLE_MASK_REWRITES is not in PyDoughConfigs?
Because we wanted an environment variable as a "switch"
    for idx, item in enumerate(batch):
        pyd_logger.info(
            f"({idx + 1}) {item.table_path}.{item.column_name}: {item.expression}"
        )
Should this log entry be debug level instead of info?
🤷 Rn I'm keeping everything the same logging level for simplicity. We can revise down if we think it is appropriate.
    request: ServerRequest = self.generate_request(
        batch, path, method, dry_run, hard_limit
    )
    response_json = self.connection.send_server_request(request)
In case of a predicate_server failure, users will not be able to query the database at all, not even with the MASK functions. This could be a critical point of failure for the system.
Failure vs error are very different. If there is a legitimate error with connecting to the server, my understanding was that we wanted to abort. If the server responds just fine but indicates it failed to derive an answer, then that's fine and we proceed normally.
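One way to picture the error/failure split described above is the sketch below. It is not the PR's code: `MaskServerUnavailable`, `evaluate_batch`, the `send` callable, and the `"items"` / `"ok"` response keys are all hypothetical names chosen for illustration.

```python
from typing import Callable, List, Optional


class MaskServerUnavailable(RuntimeError):
    """Hypothetical exception: the server could not be reached at all."""


def evaluate_batch(
    send: Callable[[dict], dict], request: dict
) -> List[Optional[dict]]:
    """Send one batch; return one entry per item, None where there is no answer."""
    try:
        response = send(request)
    except OSError as exc:
        # Error: a genuine connection problem, so abort the whole query.
        raise MaskServerUnavailable("could not reach the mask server") from exc

    results: List[Optional[dict]] = []
    # The "items" and "ok" keys are assumed here, not taken from the real API.
    for item in response.get("items", []):
        # Failure: the server replied but could not derive an answer for this
        # item. Record None so the caller keeps the original expression.
        results.append(item if item.get("ok") else None)
    return results
```

With this shape, a `None` entry means "proceed normally with the unrewritten expression", while only a transport-level problem stops the query.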
pydough/mask_server/mask_server.py
Outdated
-    "column_reference": f"{item.table_path}.{item.column_name}",
+    "column_ref": {
+        "kind": "fqn",
+        "value": f"{self.server_address}.{item.table_path}.{item.column_name}",
The separator should be "/". item.table_path is a composed name whose elements are separated by ".". Any element can be enclosed in double quotes or backticks and contain "." as part of its name. Additionally, any character in the name equal to the enclosure character is escaped by doubling it.
We could wait to see the real implementation before making this kind of change.
Gonna do this. The problem for table path is how to handle varying edge cases of what table_path looks like:
- db.schema.col -> db/schema/col
- "a.b"."c.d"."e.f" -> a.b/c.d/e.f
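The two edge cases above, together with the quoting rules from the earlier comment (double quotes or backticks as enclosures, a doubled enclosure character as an escape), could be handled along these lines. This is a minimal sketch, not the PR's `_split_identifier`; the function names `split_identifier` and `to_slash_path` are assumptions.

```python
def split_identifier(name: str) -> list:
    """Split a dotted identifier into its unquoted parts.

    Parts may be wrapped in double quotes or backticks; inside a quoted part,
    a doubled enclosure character stands for one literal character.
    """
    parts = []
    current = []
    quote = None  # the active enclosure character, if any
    i = 0
    while i < len(name):
        ch = name[i]
        if quote is not None:
            if ch == quote:
                if i + 1 < len(name) and name[i + 1] == quote:
                    # Doubled enclosure character: an escaped literal.
                    current.append(quote)
                    i += 1
                else:
                    quote = None  # closing enclosure
            else:
                current.append(ch)
        elif ch in ('"', "`"):
            quote = ch
        elif ch == ".":
            # An unquoted dot ends the current part.
            parts.append("".join(current))
            current = []
        else:
            current.append(ch)
        i += 1
    if quote is not None:
        raise ValueError(f"Unterminated quoted identifier: {name!r}")
    parts.append("".join(current))
    return parts


def to_slash_path(table_path: str) -> str:
    """Join the unquoted parts of a table path with '/'."""
    return "/".join(split_identifier(table_path))
```

Under these assumptions, `to_slash_path('db.schema.col')` gives `db/schema/col` and `to_slash_path('"a.b"."c.d"."e.f"')` gives `a.b/c.d/e.f`. A part that itself contains "/" would still be ambiguous after joining; the sketch does not address that.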
pydough/mask_server/mask_server.py
Outdated

    assert batch != [], "Batch cannot be empty."

    path: str = "v1/predicates/batch-evaluate"
path could be a class variable, so we don't need to pass it as a parameter to other class methods like generate_request()
Good point, moved into a class var of MaskServerInfo that gets passed into the ServerRequest by generate_request
pydough/mask_server/mask_server.py
Outdated
        self,
        batch: list[MaskServerInput],
        path: str,
        method: RequestMethod,
Including path and method in the parameters looks like an attempt to make generate_request() more general. However, given all the other specific parameters and actions, I think this method is very specific to batch-evaluate. Maybe path and method could be class properties, since they will not change for this method. If more request methods are required in the future, those paths and methods could also be part of the class.
I think method doesn't need to be a class property; it can just get baked into the method's construction of a ServerRequest instance.
pydough/mask_server/mask_server.py
Outdated
     """
-    Generate a list of server outputs from the server response.
+    Generate a list of server outputs from the server response of a
+    non-dry-run request.
What happens when the request is a dry-run? We are calling generate_result(response_json) in L174 for all batch-evaluate requests.
I didn't like the design idea of having the dry-run and the actual call in the same API path, because they are different things called at different times. We can't change that, but could it make sense to separate them on our side, at least in how we process the response?
This comment should be rolled back. The function is the same for both; the difference is that dry-runs have an empty payload for the records.
    - `DATEDIFF`
    """

    PREDICATE_OPERATORS: set[str] = {
What are the criteria for a predicate operator to be included here?
These are the operators that are actually predicates, i.e. they return a boolean.
E.g. SUBSTRING can be inside the expression, but should not be the expression itself.
E.g. we wouldn't send abs(expr + 2) to the predicate server, but we would send abs(expr + 2) < 13; likewise, we wouldn't send LOWER(expr[:5]), but we would send CONTAINS(LOWER(expr[:5]), 'a').
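The criterion above amounts to checking only the outermost operator of an expression. A toy illustration, assuming expressions are modeled as `(operator, *args)` tuples and using made-up operator names rather than the real contents of PREDICATE_OPERATORS:

```python
# Illustrative operator names only; not the actual PREDICATE_OPERATORS set.
PREDICATE_OPERATORS = {"EQ", "LT", "CONTAINS"}


def is_predicate(expr) -> bool:
    """True if the outermost operator of a (name, *args) tuple returns a boolean."""
    return isinstance(expr, tuple) and len(expr) > 0 and expr[0] in PREDICATE_OPERATORS


# abs(expr + 2): the outermost operator is ABS, so it would not be sent.
abs_sum = ("ABS", ("ADD", "expr", 2))
# abs(expr + 2) < 13: the outermost operator is LT, so it would be sent.
comparison = ("LT", abs_sum, 13)
```

Non-predicate operators such as ABS or LOWER can still appear anywhere below the top of the tree.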
    # from the earlier check.
    for inp in input_exprs:
        assert inp is not None
        result.extend(inp)
Remember that a literal string may need to be wrapped in QUOTE if it matches an operator name.
Ahhh good point. I'll do that for literal string handling.
Reminder.
pydough/mask_server/mask_server.py
Outdated
        },
        ...
    ],
    "expression_format": {"name": "linear", "version": "0.2.0"}
Suggested change:
-    "expression_format": {"name": "linear", "version": "0.2.0"}
+    "expression_format": {"name": "linear", "version": "0.2.0"},
    Mask Server and replacing the candidate expressions with the appropriate
    responses from the server.
Suggested change:
-    Mask Server and replacing the candidate expressions with the appropriate
-    responses from the server.
+    Mask Server. First send all candidates using the dry run flag, then select the best candidates to be replaced with the appropriate response from the Mask Server.
    self.processed_candidates: set[RelationalExpression] = set()
    """
    The set of all relational expressions that have already been added to
    the candidate pool at lest once. This is used to avoid adding the same
Reminder
hadia206
left a comment
Well done!
I have minor comments but overall great work!
Thanks for the efforts
     }

-    def _split_identifier(self, name: str) -> list[str]:
+    @staticmethod
Why change to static?
So it can be used from other files w/o creating an instance of this class
    # of the same expression. The responses will be stored in self.responses
    # for later lookup.
    if expr in self.candidate_visitor.candidate_pool:
        self.process_batch()
Can we add a comment to explain why batching happens here and not at the end of traversal?
    assert mask_op.masking_metadata.server_masked
    assert mask_op.masking_metadata.server_dataset_id is not None
nit: Should these be asserts or explicit checks and exceptions?
I'm treating these as assertions since if these conditions are not true, then it shouldn't have been placed in the candidate pool in the first place.
    expression. This is used to build the `heritage_tree` mapping.
    """

    def reset(self):
docstring
    input_op: pydop.MaskedExpressionFunctionOperator
    input_expr: RelationalExpression
    combined_exprs: list[str | int | float | None | bool] | None
nit: following what happens with the stack is hard. Adding an example in the code that shows the flow would be good.
    # If there are zero unmasking operators in the inputs, or more than
    # one, this expression is not a candidate.
I may have missed it, but can you explain why more than one is not a candidate?
- YEAR(x) == 2024 is fine
- MONTH(y) == 6 is fine
- (YEAR(x) == 2024) & (MONTH(y) == 6) is not fine because the predicate is on x and y, so we can't rewrite it as x IN (...) or y IN (...)
- However, we could do (YEAR(x) == 2024) & (MONTH(x) == 6), because that is a predicate on just x
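The rule in the examples above can be read as "the predicate must depend on exactly one unmasked column". The sketch below checks that on a toy expression tree; `Unmask`, `Call`, `unmasked_columns`, and `is_single_column` are hypothetical stand-ins, not PyDough's relational expression classes, and the PR's own check (which counts unmasking operators among the inputs) may differ in detail.

```python
from dataclasses import dataclass
from typing import Set, Tuple, Union


@dataclass(frozen=True)
class Unmask:
    """Hypothetical node: an UNMASK applied to one masked column."""
    column: str


@dataclass(frozen=True)
class Call:
    """Hypothetical node: an operator applied to argument expressions."""
    op: str
    args: Tuple["Expr", ...]


Expr = Union[Unmask, Call, int, str]


def unmasked_columns(expr: Expr) -> Set[str]:
    """Collect the distinct columns that appear under an UNMASK in expr."""
    if isinstance(expr, Unmask):
        return {expr.column}
    if isinstance(expr, Call):
        cols: Set[str] = set()
        for arg in expr.args:
            cols |= unmasked_columns(arg)
        return cols
    return set()  # literals reference no columns


def is_single_column(expr: Expr) -> bool:
    """True when the expression is a function of exactly one unmasked column."""
    return len(unmasked_columns(expr)) == 1
```

Only in the single-column case can the whole predicate be replaced by something of the form `x IN (...)`.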

    def contains_real_aggregate(expression) -> bool:
        """
        Check if the expression contains a real aggregate function (e.g. SUM, AVG),
why do we need this check?
Because we discovered an unfortunate bug where MIN / MAX are treated by SQLGlot as aggregations even when doing MIN(a, b, c) (which is how some dialects do LEAST / GREATEST), which is highly problematic because it would make SQLGlot do buggy stuff during filter pushdown.
        return expression


    def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
If this is new code (not copy/paste from SQLGlot), add a docstring
Nope, the only new thing is the parts with contains_real_aggregate (which is a brand new function).
Augmenting relational optimization to rewrite expressions containing an UNMASK operator when a server is mounted to the PyDough session (and the environment variable is activated):
- The new shuttles are added to the additional_shuttles list, before the masking literal comparisons shuttle.
- MaskServerCandidateShuttle is a no-op shuttle that just traverses the entire tree to find expressions that can potentially be rewritten and adds them to a pool.
- MaskServerRewriteShuttle looks for expressions in the candidate shuttle's pool, and once it finds one it sends every candidate in the pool into a batch request to the mask server, processing the output results to create the new relational node. The candidate pool is then emptied so future invocations will not re-do the same batch calculation.
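The batching behavior described in the last bullet can be sketched as follows. This is a simplified stand-in, not MaskServerRewriteShuttle itself: the `Rewriter` class, the `send_batch` callable, and the use of plain strings as expressions are all assumptions made for illustration, and the dry-run selection step is left out.

```python
from typing import Callable, Dict, List, Optional


class Rewriter:
    """Toy rewrite pass: the first pooled expression it meets triggers one batch call."""

    def __init__(
        self,
        pool: List[str],
        send_batch: Callable[[List[str]], List[Optional[str]]],
    ) -> None:
        self.pool = list(pool)        # candidates found by the earlier pass
        self.send_batch = send_batch  # hypothetical stand-in for the server call
        self.responses: Dict[str, str] = {}

    def process_batch(self) -> None:
        # Send every pooled candidate at once, then empty the pool so later
        # visits never repeat the same batch.
        batch = list(self.pool)
        self.pool.clear()
        for expr, answer in zip(batch, self.send_batch(batch)):
            if answer is not None:
                self.responses[expr] = answer

    def rewrite(self, expr: str) -> str:
        if expr in self.pool:
            self.process_batch()
        # Fall back to the original expression when there is no stored answer.
        return self.responses.get(expr, expr)
```

Because the pool is emptied after the first hit, every later candidate is served from `responses` without another round trip, and expressions the server could not answer pass through unchanged.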