Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EmbedText support #115

Merged
merged 5 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions snowflake/cortex/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,44 @@ py_test(
],
)

# Library target wrapping the _embed_text_768.py module; it pulls in the
# package-local util target and the shared telemetry target.
py_library(
    name = "embed_text_768",
    srcs = ["_embed_text_768.py"],
    deps = [
        ":util",
        "//snowflake/ml/_internal:telemetry",
    ],
)

# Test target for embed_text_768_test.py; depends on the library under test,
# the package test utilities, and the connection-parameter helpers.
py_test(
    name = "embed_text_768_test",
    srcs = ["embed_text_768_test.py"],
    deps = [
        ":embed_text_768",
        ":test_util",
        "//snowflake/ml/utils:connection_params",
    ],
)

# Library target wrapping the _embed_text_1024.py module; it pulls in the
# package-local util target and the shared telemetry target.
py_library(
    name = "embed_text_1024",
    srcs = ["_embed_text_1024.py"],
    deps = [
        ":util",
        "//snowflake/ml/_internal:telemetry",
    ],
)

# Test target for embed_text_1024_test.py; depends on the library under test,
# the package test utilities, and the connection-parameter helpers.
py_test(
    name = "embed_text_1024_test",
    srcs = ["embed_text_1024_test.py"],
    deps = [
        ":embed_text_1024",
        ":test_util",
        "//snowflake/ml/utils:connection_params",
    ],
)

py_library(
name = "init",
srcs = [
Expand All @@ -161,6 +199,7 @@ py_library(
deps = [
":classify_text",
":complete",
":embed_text_768",
":extract_answer",
zbloss marked this conversation as resolved.
Show resolved Hide resolved
":sentiment",
":summarize",
Expand Down
4 changes: 4 additions & 0 deletions snowflake/cortex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from snowflake.cortex._extract_answer import ExtractAnswer
from snowflake.cortex._sentiment import Sentiment
from snowflake.cortex._summarize import Summarize
from snowflake.cortex._embed_text_768 import EmbedText768
from snowflake.cortex._embed_text_1024 import EmbedText1024
from snowflake.cortex._translate import Translate

__all__ = [
"ClassifyText",
"Complete",
"CompleteOptions",
"EmbedText768",
"EmbedText1024",
"ExtractAnswer",
"Sentiment",
"Summarize",
Expand Down
54 changes: 54 additions & 0 deletions snowflake/cortex/_embed_text_1024.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Optional, Union, List

from snowflake import snowpark
from snowflake.cortex._util import (
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
SnowflakeConfigurationException,
call_sql_function,
)
from snowflake.ml._internal import telemetry


# Model names accepted by EmbedText1024, which raises
# SnowflakeConfigurationException for a model name outside this list.
SUPPORTED_MODELS: List[str] = [
    "nv-embed-qa-4",
    "multilingual-e5-large",
]


@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText1024(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_1024`` SQL function.

    Args:
        model: The embedding model to use, either as a literal model name or as
            a Column of model names. A literal name must be one of
            SUPPORTED_MODELS; Column values cannot be inspected client-side and
            are passed through unvalidated.
        text: The input text, either as a literal string or a Column of strings.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        Whatever ``call_sql_function`` returns for the given arguments: the
        embedding(s) produced by the SQL function (presumably a Column when
        Column arguments are supplied -- confirm against ``call_sql_function``).

    Raises:
        SnowflakeConfigurationException: If ``model`` is a literal string that
            is not listed in SUPPORTED_MODELS.
    """
    # Validate only literal model names. A membership test on a snowpark.Column
    # compares it to each list entry with ``==`` and then needs the truth value
    # of the result, which a Column does not provide, so Column inputs must
    # bypass this check.
    if isinstance(model, str) and model not in SUPPORTED_MODELS:
        raise SnowflakeConfigurationException(
            f"model must be one of {SUPPORTED_MODELS}"
        )

    return _embed_text_1024_impl(
        "snowflake.cortex.embed_text_1024", model, text, session=session
    )


def _embed_text_1024_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Invoke the SQL function named ``function`` with ``model`` and ``text``.

    The function name is a parameter so callers (such as tests) can point the
    call at a differently named function.

    Args:
        function: Name of the SQL function to call.
        model: Model name, as a literal string or a Column.
        text: Input text, as a literal string or a Column.
        session: The snowpark session to use; forwarded to ``call_sql_function``.

    Returns:
        The value returned by ``call_sql_function``.
    """
    # Positional SQL arguments, in the order the SQL function expects them.
    sql_args = (model, text)
    return call_sql_function(function, session, *sql_args)
53 changes: 53 additions & 0 deletions snowflake/cortex/_embed_text_768.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional, Union, List

from snowflake import snowpark
from snowflake.cortex._util import (
CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
SnowflakeConfigurationException,
call_sql_function,
)
from snowflake.ml._internal import telemetry

# Model names accepted by EmbedText768, which raises
# SnowflakeConfigurationException for a model name outside this list.
SUPPORTED_MODELS: List[str] = [
    "snowflake-arctic-embed-m",
    "e5-base-v2",
]


@telemetry.send_api_usage_telemetry(
    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
)
def EmbedText768(
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Embed text by calling the ``snowflake.cortex.embed_text_768`` SQL function.

    Args:
        model: The embedding model to use, either as a literal model name or as
            a Column of model names. A literal name must be one of
            SUPPORTED_MODELS; Column values cannot be inspected client-side and
            are passed through unvalidated.
        text: The input text, either as a literal string or a Column of strings.
        session: The snowpark session to use. Will be inferred by context if not specified.

    Returns:
        Whatever ``call_sql_function`` returns for the given arguments: the
        embedding(s) produced by the SQL function (presumably a Column when
        Column arguments are supplied -- confirm against ``call_sql_function``).

    Raises:
        SnowflakeConfigurationException: If ``model`` is a literal string that
            is not listed in SUPPORTED_MODELS.
    """
    # Validate only literal model names. A membership test on a snowpark.Column
    # compares it to each list entry with ``==`` and then needs the truth value
    # of the result, which a Column does not provide, so Column inputs must
    # bypass this check.
    if isinstance(model, str) and model not in SUPPORTED_MODELS:
        raise SnowflakeConfigurationException(
            f"model must be one of {SUPPORTED_MODELS}"
        )

    return _embed_text_768_impl(
        "snowflake.cortex.embed_text_768", model, text, session=session
    )


def _embed_text_768_impl(
    function: str,
    model: Union[str, snowpark.Column],
    text: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
) -> Union[str, snowpark.Column]:
    """Invoke the SQL function named ``function`` with ``model`` and ``text``.

    The function name is a parameter so callers (such as tests) can point the
    call at a differently named function.

    Args:
        function: Name of the SQL function to call.
        model: Model name, as a literal string or a Column.
        text: Input text, as a literal string or a Column.
        session: The snowpark session to use; forwarded to ``call_sql_function``.

    Returns:
        The value returned by ``call_sql_function``.
    """
    # Positional SQL arguments, in the order the SQL function expects them.
    sql_args = (model, text)
    return call_sql_function(function, session, *sql_args)
65 changes: 65 additions & 0 deletions snowflake/cortex/embed_text_1024_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List

import _test_util
from absl.testing import absltest

from snowflake import snowpark
from snowflake.cortex import _embed_text_1024
from snowflake.snowpark import functions, types


class EmbedTest1024Test(absltest.TestCase):
    """Checks ``_embed_text_1024_impl`` against a temporary stand-in UDF."""

    # One of the model names accepted for 1024-dimension embeddings. The
    # stand-in UDF ignores it, but the fixture should still be a valid value.
    model = "multilingual-e5-large"
    text = "|text|"

    @staticmethod
    def embed_text_1024_for_test(model: str, text: str) -> List[float]:
        """Stand-in embedding function: always returns 1024 zeros."""
        return [0.0] * 1024

    def setUp(self) -> None:
        """Open a test session and register the stand-in as a temporary UDF."""
        self._session = _test_util.create_test_session()
        functions.udf(
            self.embed_text_1024_for_test,
            name="embed_text_1024",
            session=self._session,
            return_type=types.VectorType(float, 1024),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stand-in UDF and always close the session."""
        try:
            self._session.sql("drop function embed_text_1024(string,string)").collect()
        finally:
            # Close the session even when the drop statement fails.
            self._session.close()

    def test_embed_text_1024_str(self) -> None:
        """Literal string arguments yield the stand-in's embedding."""
        res = _embed_text_1024._embed_text_1024_impl(
            "embed_text_1024",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_1024_for_test(self.model, self.text)
        # The message must be passed as the third argument; written after the
        # call as a tuple element it would be discarded.
        self.assertEqual(
            out, res, f"Expected ({type(out)}) {out}, got ({type(res)}) {res}"
        )

    def test_embed_text_1024_column(self) -> None:
        """Column arguments yield the stand-in's embedding for each row."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_1024._embed_text_1024_impl(
                "embed_text_1024",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        res = df_out.collect()[0][0]
        out = self.embed_text_1024_for_test(self.model, self.text)

        self.assertEqual(out, res)


# Run the absltest runner when this file is executed directly.
if __name__ == "__main__":
    absltest.main()
65 changes: 65 additions & 0 deletions snowflake/cortex/embed_text_768_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import List

import _test_util
from absl.testing import absltest

from snowflake import snowpark
from snowflake.cortex import _embed_text_768
from snowflake.snowpark import functions, types


class EmbedTest768Test(absltest.TestCase):
    """Checks ``_embed_text_768_impl`` against a temporary stand-in UDF."""

    model = "snowflake-arctic-embed-m"
    text = "|text|"

    @staticmethod
    def embed_text_768_for_test(model: str, text: str) -> List[float]:
        """Stand-in embedding function: always returns 768 zeros."""
        return [0.0] * 768

    def setUp(self) -> None:
        """Open a test session and register the stand-in as a temporary UDF."""
        self._session = _test_util.create_test_session()
        functions.udf(
            self.embed_text_768_for_test,
            name="embed_text_768",
            session=self._session,
            return_type=types.VectorType(float, 768),
            input_types=[types.StringType(), types.StringType()],
            is_permanent=False,
        )

    def tearDown(self) -> None:
        """Drop the stand-in UDF and always close the session."""
        try:
            self._session.sql("drop function embed_text_768(string,string)").collect()
        finally:
            # Close the session even when the drop statement fails.
            self._session.close()

    def test_embed_text_768_str(self) -> None:
        """Literal string arguments yield the stand-in's embedding."""
        res = _embed_text_768._embed_text_768_impl(
            "embed_text_768",
            self.model,
            self.text,
            session=self._session,
        )
        out = self.embed_text_768_for_test(self.model, self.text)
        # The message must be passed as the third argument; written after the
        # call as a tuple element it would be discarded.
        self.assertEqual(
            out, res, f"Expected ({type(out)}) {out}, got ({type(res)}) {res}"
        )

    def test_embed_text_768_column(self) -> None:
        """Column arguments yield the stand-in's embedding for each row."""
        df_in = self._session.create_dataframe(
            [snowpark.Row(model=self.model, text=self.text)]
        )
        df_out = df_in.select(
            _embed_text_768._embed_text_768_impl(
                "embed_text_768",
                functions.col("model"),
                functions.col("text"),
                session=self._session,
            )
        )
        res = df_out.collect()[0][0]
        out = self.embed_text_768_for_test(self.model, self.text)

        self.assertEqual(out, res)


# Run the absltest runner when this file is executed directly.
if __name__ == "__main__":
    absltest.main()