Skip to content

Commit fec2b31

Browse files
feat(optimizer)!: Annotate type for snowflake SEARCH function (#5985)
* feat(optimizer)!: Annotate type for snowflake SEARCH function
* fix: Fix parsing and AST generation
* fix: Address review comments
* fix: modified tests
* fix(tests)!: Fixed failing tests
1 parent 74a13f2 commit fec2b31

File tree

6 files changed

+125
-1
lines changed

6 files changed

+125
-1
lines changed

sqlglot/dialects/bigquery.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,8 @@ class Parser(parser.Parser):
867867
"FROM_HEX": exp.Unhex.from_arg_list,
868868
"WEEK": lambda args: exp.WeekStart(this=exp.var(seq_get(args, 0))),
869869
}
870+
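# BigQuery's SEARCH takes named arguments (json_scope, analyzer, analyzer_options). Drop the
# inherited SEARCH builder so they are not routed into the wrong exp.Search arg slots; the call
# then falls back to an Anonymous function.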
871+
FUNCTIONS.pop("SEARCH")
870872

871873
FUNCTION_PARSERS = {
872874
**parser.Parser.FUNCTION_PARSERS,

sqlglot/dialects/snowflake.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,33 @@ def _build_if_from_zeroifnull(args: t.List) -> exp.If:
168168
return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
169169

170170

171+
def _build_search(args: t.List) -> exp.Search:
    """Build an `exp.Search` node from the parsed arguments of a SEARCH call.

    The first two arguments are taken positionally. The optional `ANALYZER` and
    `SEARCH_MODE` arguments are matched by (case-insensitive) name, so they may
    be supplied in either order.

    Args:
        args: The parsed argument expressions of the SEARCH call.

    Returns:
        An `exp.Search` whose `analyzer` / `search_mode` args hold the matching
        `exp.Kwarg` node, or None when that argument was not supplied.
    """
    named_args: t.Dict[str, t.Optional[exp.Kwarg]] = {"analyzer": None, "search_mode": None}

    # Only the third and fourth arguments are inspected for the named options.
    for arg in args[2:4]:
        if not isinstance(arg, exp.Kwarg):
            continue

        name = arg.this.name.lower()
        if name in named_args:
            # Keep the whole Kwarg node, not just its value.
            named_args[name] = arg

    # All four keys are always passed, even when the optional ones are None.
    return exp.Search(
        this=seq_get(args, 0),
        expression=seq_get(args, 1),
        analyzer=named_args["analyzer"],
        search_mode=named_args["search_mode"],
    )
196+
197+
171198
# https://docs.snowflake.com/en/sql-reference/functions/nullifzero
172199
def _build_if_from_nullifzero(args: t.List) -> exp.If:
173200
cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
@@ -622,6 +649,13 @@ class Snowflake(Dialect):
622649
exp.ParseUrl,
623650
exp.ParseIp,
624651
},
652+
exp.DataType.Type.DECIMAL: {
653+
exp.RegexpCount,
654+
},
655+
exp.DataType.Type.BOOLEAN: {
656+
*Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.BOOLEAN],
657+
exp.Search,
658+
},
625659
}
626660

627661
ANNOTATORS = {
@@ -834,6 +868,7 @@ class Parser(parser.Parser):
834868
"ZEROIFNULL": _build_if_from_zeroifnull,
835869
"LIKE": _build_like(exp.Like),
836870
"ILIKE": _build_like(exp.ILike),
871+
"SEARCH": _build_search,
837872
}
838873
FUNCTIONS.pop("PREDICT")
839874

sqlglot/expressions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7443,6 +7443,19 @@ class StrPosition(Func):
74437443
}
74447444

74457445

7446+
# SEARCH function, supported by both Snowflake and BigQuery.
# Snowflake: https://docs.snowflake.com/en/sql-reference/functions/search
# BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/search_functions#search
class Search(Func):
    arg_types = {
        # Required arguments.
        "this": True,  # data_to_search / search_data
        "expression": True,  # search_query / search_string
        # Optional arguments; which ones apply depends on the dialect.
        "json_scope": False,  # BigQuery only: JSON_VALUES | JSON_KEYS | JSON_KEYS_AND_VALUES
        "analyzer": False,  # Both dialects: analyzer / ANALYZER
        "analyzer_options": False,  # BigQuery only: analyzer_options_values
        "search_mode": False,  # Snowflake only: OR | AND
    }
7457+
7458+
74467459
class StrToDate(Func):
74477460
arg_types = {"this": True, "format": False, "safe": False}
74487461

tests/dialects/test_bigquery.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,24 @@ def test_bigquery(self):
7979
self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')")
8080
self.validate_identity("FOO(values)")
8181
self.validate_identity("STRUCT(values AS value)")
82+
83+
self.validate_identity("SELECT SEARCH(data_to_search, 'search_query')")
84+
self.validate_identity(
85+
"SELECT SEARCH(data_to_search, 'search_query', json_scope => 'JSON_KEYS_AND_VALUES')"
86+
)
87+
self.validate_identity(
88+
"SELECT SEARCH(data_to_search, 'search_query', analyzer => 'PATTERN_ANALYZER')"
89+
)
90+
self.validate_identity(
91+
"SELECT SEARCH(data_to_search, 'search_query', analyzer_options => 'analyzer_options_values')"
92+
)
93+
self.validate_identity(
94+
"SELECT SEARCH(data_to_search, 'search_query', json_scope => 'JSON_VALUES', analyzer => 'LOG_ANALYZER')"
95+
)
96+
self.validate_identity(
97+
"SELECT SEARCH(data_to_search, 'search_query', analyzer => 'PATTERN_ANALYZER', analyzer_options => 'options')"
98+
)
99+
82100
self.validate_identity("ARRAY_AGG(x IGNORE NULLS LIMIT 1)")
83101
self.validate_identity("ARRAY_AGG(x IGNORE NULLS ORDER BY x LIMIT 1)")
84102
self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY x LIMIT 1)")

tests/dialects/test_snowflake.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2463,7 +2463,43 @@ def test_regexp_substr(self, logger):
24632463
"REGEXP_EXTRACT_ALL(subject, pattern)",
24642464
)
24652465

2466-
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l')")
2466+
self.validate_identity("SELECT SEARCH((play, line), 'dream')")
2467+
self.validate_identity("SELECT SEARCH(line, 'king', ANALYZER => 'UNICODE_ANALYZER')")
2468+
self.validate_identity("SELECT SEARCH(character, 'king queen', SEARCH_MODE => 'AND')")
2469+
self.validate_identity(
2470+
"SELECT SEARCH(line, 'king', ANALYZER => 'UNICODE_ANALYZER', SEARCH_MODE => 'OR')"
2471+
)
2472+
2473+
# AST validation tests - verify argument mapping
2474+
ast = self.validate_identity("SELECT SEARCH(line, 'king')")
2475+
search_ast = ast.find(exp.Search)
2476+
self.assertEqual(list(search_ast.args), ["this", "expression", "analyzer", "search_mode"])
2477+
self.assertIsNone(search_ast.args.get("analyzer"))
2478+
self.assertIsNone(search_ast.args.get("search_mode"))
2479+
2480+
ast = self.validate_identity("SELECT SEARCH(line, 'king', ANALYZER => 'UNICODE_ANALYZER')")
2481+
search_ast = ast.find(exp.Search)
2482+
self.assertIsNotNone(search_ast.args.get("analyzer"))
2483+
self.assertIsNone(search_ast.args.get("search_mode"))
2484+
2485+
ast = self.validate_identity("SELECT SEARCH(character, 'king queen', SEARCH_MODE => 'AND')")
2486+
search_ast = ast.find(exp.Search)
2487+
self.assertIsNone(search_ast.args.get("analyzer"))
2488+
self.assertIsNotNone(search_ast.args.get("search_mode"))
2489+
2490+
# Test with arguments in different order (search_mode first, then analyzer)
2491+
ast = self.validate_identity(
2492+
"SELECT SEARCH(line, 'king', SEARCH_MODE => 'AND', ANALYZER => 'PATTERN_ANALYZER')",
2493+
"SELECT SEARCH(line, 'king', ANALYZER => 'PATTERN_ANALYZER', SEARCH_MODE => 'AND')",
2494+
)
2495+
search_ast = ast.find(exp.Search)
2496+
self.assertEqual(list(search_ast.args), ["this", "expression", "analyzer", "search_mode"])
2497+
analyzer = search_ast.args.get("analyzer")
2498+
self.assertIsNotNone(analyzer)
2499+
search_mode = search_ast.args.get("search_mode")
2500+
self.assertIsNotNone(search_mode)
2501+
2502+
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l ')")
24672503
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l', 1)")
24682504
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l', 1, 'i')")
24692505

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2219,6 +2219,26 @@ BOOLEAN;
22192219
STARTSWITH(tbl.bin_col, NULL);
22202220
BOOLEAN;
22212221

2222+
# dialect: snowflake
2223+
SEARCH(line, 'king');
2224+
BOOLEAN;
2225+
2226+
# dialect: snowflake
2227+
SEARCH((play, line), 'dream');
2228+
BOOLEAN;
2229+
2230+
# dialect: snowflake
2231+
SEARCH(line, 'king', ANALYZER => 'UNICODE_ANALYZER');
2232+
BOOLEAN;
2233+
2234+
# dialect: snowflake
2235+
SEARCH(line, 'king', SEARCH_MODE => 'OR');
2236+
BOOLEAN;
2237+
2238+
# dialect: snowflake
2239+
SEARCH(line, 'king', ANALYZER => 'UNICODE_ANALYZER', SEARCH_MODE => 'AND');
2240+
BOOLEAN;
2241+
22222242
# dialect: snowflake
22232243
STRTOK_TO_ARRAY('a,b,c', ',');
22242244
ARRAY;

0 commit comments

Comments
 (0)