From 0bdaf0b1571f8535b799f154fcc60bf8e705ce75 Mon Sep 17 00:00:00 2001 From: Zachary Bloss Date: Mon, 9 Sep 2024 12:22:45 -0400 Subject: [PATCH] Add EmbedText support (#115) Adds support for both SNOWFLAKE.CORTEX.EMBED_TEXT_768 & SNOWFLAKE.CORTEX.EMBED_TEXT_1024 methods into the cortex python sdk. --- snowflake/cortex/BUILD.bazel | 40 +++++++++++++ snowflake/cortex/__init__.py | 4 ++ snowflake/cortex/_embed_text_1024.py | 42 +++++++++++++ snowflake/cortex/_embed_text_768.py | 43 ++++++++++++++ snowflake/cortex/embed_text_1024_test.py | 65 +++++++++++++++++++++ snowflake/cortex/embed_text_768_test.py | 65 +++++++++++++++++++++ snowflake/cortex/package_visibility_test.py | 6 ++ 7 files changed, 265 insertions(+) create mode 100644 snowflake/cortex/_embed_text_1024.py create mode 100644 snowflake/cortex/_embed_text_768.py create mode 100644 snowflake/cortex/embed_text_1024_test.py create mode 100644 snowflake/cortex/embed_text_768_test.py diff --git a/snowflake/cortex/BUILD.bazel b/snowflake/cortex/BUILD.bazel index 9a0c6c9d..1c65b740 100644 --- a/snowflake/cortex/BUILD.bazel +++ b/snowflake/cortex/BUILD.bazel @@ -153,6 +153,44 @@ py_test( ], ) +py_library( + name = "embed_text_768", + srcs = ["_embed_text_768.py"], + deps = [ + ":util", + "//snowflake/ml/_internal:telemetry", + ], +) + +py_test( + name = "embed_text_768_test", + srcs = ["embed_text_768_test.py"], + deps = [ + ":embed_text_768", + ":test_util", + "//snowflake/ml/utils:connection_params", + ], +) + +py_library( + name = "embed_text_1024", + srcs = ["_embed_text_1024.py"], + deps = [ + ":util", + "//snowflake/ml/_internal:telemetry", + ], +) + +py_test( + name = "embed_text_1024_test", + srcs = ["embed_text_1024_test.py"], + deps = [ + ":embed_text_1024", + ":test_util", + "//snowflake/ml/utils:connection_params", + ], +) + py_library( name = "init", srcs = [ @@ -161,6 +199,8 @@ py_library( deps = [ ":classify_text", ":complete", + ":embed_text_768", + ":embed_text_1024", ":extract_answer", ":sentiment", ":summarize", diff --git a/snowflake/cortex/__init__.py b/snowflake/cortex/__init__.py index 1ee01368..947b2d77 100644 --- a/snowflake/cortex/__init__.py +++ b/snowflake/cortex/__init__.py @@ -3,12 +3,16 @@ from snowflake.cortex._extract_answer import ExtractAnswer from snowflake.cortex._sentiment import Sentiment from snowflake.cortex._summarize import Summarize +from snowflake.cortex._embed_text_768 import EmbedText768 +from snowflake.cortex._embed_text_1024 import EmbedText1024 from snowflake.cortex._translate import Translate __all__ = [ "ClassifyText", "Complete", "CompleteOptions", + "EmbedText768", + "EmbedText1024", "ExtractAnswer", "Sentiment", "Summarize", diff --git a/snowflake/cortex/_embed_text_1024.py b/snowflake/cortex/_embed_text_1024.py new file mode 100644 index 00000000..9462a801 --- /dev/null +++ b/snowflake/cortex/_embed_text_1024.py @@ -0,0 +1,42 @@ +from typing import Optional, Union + +from snowflake import snowpark +from snowflake.cortex._util import ( + CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + call_sql_function, +) +from snowflake.ml._internal import telemetry + + +@telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, +) +def EmbedText1024( + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[list[float], snowpark.Column]: + """TextEmbed calls into the LLM inference service to embed the text. + + Args: + model: A Column of strings representing the model to use for embedding. The value + of the strings must be within the SUPPORTED_MODELS list. + text: A Column of strings representing input text. + session: The snowpark session to use. Will be inferred by context if not specified. + + Returns: + A column of vectors containing embeddings. + """ + + return _embed_text_1024_impl( + "snowflake.cortex.embed_text_1024", model, text, session=session + ) + + +def _embed_text_1024_impl( + function: str, + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[list[float], snowpark.Column]: + return call_sql_function(function, session, model, text) diff --git a/snowflake/cortex/_embed_text_768.py b/snowflake/cortex/_embed_text_768.py new file mode 100644 index 00000000..78838ff0 --- /dev/null +++ b/snowflake/cortex/_embed_text_768.py @@ -0,0 +1,43 @@ +from typing import Optional, Union, List + +from snowflake import snowpark +from snowflake.cortex._util import ( + CORTEX_FUNCTIONS_TELEMETRY_PROJECT, + SnowflakeConfigurationException, + call_sql_function, +) +from snowflake.ml._internal import telemetry + + +@telemetry.send_api_usage_telemetry( + project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT, +) +def EmbedText768( + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[list[float], snowpark.Column]: + """TextEmbed calls into the LLM inference service to embed the text. + + Args: + model: A Column of strings representing the model to use for embedding. The value + of the strings must be within the SUPPORTED_MODELS list. + text: A Column of strings representing input text. + session: The snowpark session to use. Will be inferred by context if not specified. + + Returns: + A column of vectors containing embeddings. + """ + + return _embed_text_768_impl( + "snowflake.cortex.embed_text_768", model, text, session=session + ) + + +def _embed_text_768_impl( + function: str, + model: Union[str, snowpark.Column], + text: Union[str, snowpark.Column], + session: Optional[snowpark.Session] = None, +) -> Union[list[float], snowpark.Column]: + return call_sql_function(function, session, model, text) diff --git a/snowflake/cortex/embed_text_1024_test.py b/snowflake/cortex/embed_text_1024_test.py new file mode 100644 index 00000000..d2724d09 --- /dev/null +++ b/snowflake/cortex/embed_text_1024_test.py @@ -0,0 +1,65 @@ +from typing import List + +import _test_util +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.cortex import _embed_text_1024 +from snowflake.snowpark import functions, types + + +class EmbedTest1024Test(absltest.TestCase): + model = "snowflake-arctic-embed-m" + text = "|text|" + + @staticmethod + def embed_text_1024_for_test(model: str, text: str) -> List[float]: + return [0.0] * 1024 + + def setUp(self) -> None: + self._session = _test_util.create_test_session() + functions.udf( + self.embed_text_1024_for_test, + name="embed_text_1024", + session=self._session, + return_type=types.VectorType(float, 1024), + input_types=[types.StringType(), types.StringType()], + is_permanent=False, + ) + + def tearDown(self) -> None: + self._session.sql("drop function embed_text_1024(string,string)").collect() + self._session.close() + + def test_embed_text_1024_str(self) -> None: + res = _embed_text_1024._embed_text_1024_impl( + "embed_text_1024", + self.model, + self.text, + session=self._session, + ) + out = self.embed_text_1024_for_test(self.model, self.text) + self.assertEqual( + out, res + ), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" + + def test_embed_text_1024_column(self) -> None: + df_in = self._session.create_dataframe( + [snowpark.Row(model=self.model, text=self.text)] + ) + df_out = df_in.select( + _embed_text_1024._embed_text_1024_impl( + "embed_text_1024", + functions.col("model"), + functions.col("text"), + session=self._session, + ) + ) + res = df_out.collect()[0][0] + out = self.embed_text_1024_for_test(self.model, self.text) + + self.assertEqual(out, res) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/cortex/embed_text_768_test.py b/snowflake/cortex/embed_text_768_test.py new file mode 100644 index 00000000..c07249ab --- /dev/null +++ b/snowflake/cortex/embed_text_768_test.py @@ -0,0 +1,65 @@ +from typing import List + +import _test_util +from absl.testing import absltest + +from snowflake import snowpark +from snowflake.cortex import _embed_text_768 +from snowflake.snowpark import functions, types + + +class EmbedTest768Test(absltest.TestCase): + model = "snowflake-arctic-embed-m" + text = "|text|" + + @staticmethod + def embed_text_768_for_test(model: str, text: str) -> List[float]: + return [0.0] * 768 + + def setUp(self) -> None: + self._session = _test_util.create_test_session() + functions.udf( + self.embed_text_768_for_test, + name="embed_text_768", + session=self._session, + return_type=types.VectorType(float, 768), + input_types=[types.StringType(), types.StringType()], + is_permanent=False, + ) + + def tearDown(self) -> None: + self._session.sql("drop function embed_text_768(string,string)").collect() + self._session.close() + + def test_embed_text_768_str(self) -> None: + res = _embed_text_768._embed_text_768_impl( + "embed_text_768", + self.model, + self.text, + session=self._session, + ) + out = self.embed_text_768_for_test(self.model, self.text) + self.assertEqual( + out, res + ), f"Expected ({type(out)}) {out}, got ({type(res)}) {res}" + + def test_embed_text_768_column(self) -> None: + df_in = self._session.create_dataframe( + [snowpark.Row(model=self.model, text=self.text)] + ) + df_out = df_in.select( + _embed_text_768._embed_text_768_impl( + "embed_text_768", + functions.col("model"), + functions.col("text"), + session=self._session, + ) + ) + res = df_out.collect()[0][0] + out = self.embed_text_768_for_test(self.model, self.text) + + self.assertEqual(out, res) + + +if __name__ == "__main__": + absltest.main() diff --git a/snowflake/cortex/package_visibility_test.py b/snowflake/cortex/package_visibility_test.py index 1addaa09..98655da8 100644 --- a/snowflake/cortex/package_visibility_test.py +++ b/snowflake/cortex/package_visibility_test.py @@ -16,6 +16,12 @@ def test_complete_visible(self) -> None: def test_extract_answer_visible(self) -> None: self.assertTrue(callable(cortex.ExtractAnswer)) + def test_embed_text_768_visible(self) -> None: + self.assertTrue(callable(cortex.EmbedText768)) + + def test_embed_text_1024_visible(self) -> None: + self.assertTrue(callable(cortex.EmbedText1024)) + def test_sentiment_visible(self) -> None: self.assertTrue(callable(cortex.Sentiment))