From 13d2e813a794b95100572b532c8deb1d810ce8c3 Mon Sep 17 00:00:00 2001 From: Labanya Mukhopadhyay Date: Fri, 6 Dec 2024 16:37:04 -0800 Subject: [PATCH] SNOW-1690711: support for cortex sentiment, classify_text with apply Signed-off-by: Labanya Mukhopadhyay --- src/snowflake/snowpark/functions.py | 32 +++++++++++++++++++ .../modin/plugin/_internal/apply_utils.py | 4 +++ .../test_apply_snowpark_python_functions.py | 10 ++++++ 3 files changed, 46 insertions(+) diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 9977b452b7..3f40ab36bb 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -10149,3 +10149,35 @@ def snowflake_cortex_summarize(text: ColumnOrLiteralStr): sql_func_name = "snowflake.cortex.summarize" text_col = _to_col_if_lit(text, sql_func_name) return builtin(sql_func_name)(text_col) + + +def snowflake_cortex_classify_text(input: ColumnOrLiteralStr, list_of_categories): + """ + Classifies free-form text into categories that you provide. + Args: + input: A string containing the English text to be classified. + list_of_categories: Array that represents the categories. Must contain at least two and at most 100 unique + categories. Categories are case sensitive. If these requirements are not met, the function returns an error. + Returns: + Returns a string that contains a JSON object. The JSON object contains the category that the input prompt was + classified as. If invalid arguments are given, an error is returned. + """ + sql_func_name = "snowflake.cortex.classify_text" + input_col = _to_col_if_lit(input, sql_func_name) + return builtin(sql_func_name)(input_col, list_of_categories) + + +def snowflake_cortex_sentiment(text: ColumnOrLiteralStr): + """ + Returns a sentiment score for the given English-language input text. + + Args: + text: A string containing the English text for which a sentiment score should be calculated. 
+ + Returns: + A floating-point number from -1 to 1 (inclusive) indicating the level of negative or positive sentiment in the + text. Values around 0 indicate neutral sentiment. + """ + sql_func_name = "snowflake.cortex.sentiment" + text_col = _to_col_if_lit(text, sql_func_name) + return builtin(sql_func_name)(text_col) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 187f9d26c5..0935dc5c37 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -35,7 +35,9 @@ _log2, _log10, sin, + snowflake_cortex_sentiment, snowflake_cortex_summarize, + snowflake_cortex_classify_text, udf, to_variant, when, @@ -110,7 +112,9 @@ floor, trunc, sqrt, + snowflake_cortex_sentiment, snowflake_cortex_summarize, + snowflake_cortex_classify_text, } diff --git a/tests/integ/modin/test_apply_snowpark_python_functions.py b/tests/integ/modin/test_apply_snowpark_python_functions.py index 5e5911a92c..300e62904f 100644 --- a/tests/integ/modin/test_apply_snowpark_python_functions.py +++ b/tests/integ/modin/test_apply_snowpark_python_functions.py @@ -89,3 +89,13 @@ def test_apply_snowflake_cortex_summarize(): summary = s.apply(snowflake_cortex_summarize).iloc[0] # this length check is to get around the fact that this function may not be deterministic assert 0 < len(summary) < len(content) + + +@sql_count_checker(query_count=1) +def test_apply_snowflake_cortex_sentiment(): + from snowflake.snowpark.functions import snowflake_cortex_sentiment + + content = "A very very bad review!" + s = pd.Series([content]) + sentiment = s.apply(snowflake_cortex_sentiment).iloc[0] + assert -1 <= sentiment <= 1