Project import generated by Copybara. (#17)
GitOrigin-RevId: 6f059e3232ea228cbf906c59ab4389d10a00e7c1

Co-authored-by: Snowflake Authors <[email protected]>
snowflake-provisioner and Snowflake Authors authored May 26, 2023
1 parent cd38c89 commit 86d00b9
Showing 46 changed files with 1,344 additions and 1,105 deletions.
19 changes: 11 additions & 8 deletions ci/conda_recipe/meta.yaml
@@ -24,25 +24,28 @@ requirements:
    - anyio>=3.5.0,<4
    - cloudpickle
    - fsspec>=2022.11,<=2023.1
    - numpy>=1.23,<1.24
    - numpy>=1.23,<2
    - packaging>=23.0,<24
    - pandas>=1.0.0,<2  # Limit since 2.x is not yet available in the Snowflake Anaconda Channel.
    - pyyaml>=6.0,<7
    - scikit-learn>=1.2.1,<2
    - scipy>=1.9,<2
    - snowflake-connector-python
    - snowflake-snowpark-python>=1.4.0,<2
    - sqlparse>=0.4,<1
    - typing-extensions>=4.1.0,<5
    - xgboost>=1.7.3,<2

    # conda-libmamba-solver is a conda-specific requirement and should not appear in the wheel's dependencies.
    - conda-libmamba-solver>=23.1.0,<24
  run_constrained:
    # Any dependencies required by extras should be specified here so that conda can consider the constraints when
    # installing them simultaneously. This part should stay in sync with the extra_requirements in snowml_wheel in
    # the snowflake/ml/BUILD.bazel file.
    - tensorflow>=2.9,<3
    - torchdata>=0.4,<1
    - lightgbm==3.3.5

    # TODO(snandamuri): Versions of these packages must be exactly the same between the user's workspace and the
    # snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 won't work because the Snowflake conda
    # channel only has a few allowlisted versions of scikit-learn available, so we must force users to use
    # scikit-learn versions that are available in the Snowflake conda channel. Since there is no way to specify an
    # allowlist of versions in the requirements file, we are pinning the versions here.
    - scikit-learn>=1.2.1,<2
    - xgboost==1.7.3
about:
  home: https://github.com/snowflakedb/snowflake-ml-python
  license: Apache-2.0
2 changes: 1 addition & 1 deletion codegen/codegen_rules.bzl
@@ -133,8 +133,8 @@ def autogen_tests_for_estimators(module, module_root_dir, estimator_info_list):
"//{}:{}".format(module_root_dir, e.normalized_class_name),
"//snowflake/ml/utils:connection_params",
],
timeout = "long",
legacy_create_init = 0,
shard_count = 5,
timeout = "moderate",
tags = ["autogen", "skip_mypy_check"],
)
13 changes: 4 additions & 9 deletions snowflake/ml/BUILD.bazel
@@ -40,22 +40,17 @@ snowml_wheel(
"anyio>=3.5.0,<4",
"cloudpickle", # Version range is specified by snowpark. We are implicitly depending on it.
"fsspec[http]>=2022.11,<=2023.1",
"numpy>=1.23,<1.24",
"numpy>=1.23,<2",
"packaging>=23.0,<24",
"pandas>=1.0.0,<2", # Limit since 2.x is not available in Snowflake Anaconda Channel yet.
"pyyaml>=6.0,<7",
"scikit-learn>=1.2.1,<2",
"scipy>=1.9,<2",
"snowflake-connector-python[pandas]",
"snowflake-snowpark-python>=1.4.0,<2",
"sqlparse>=0.4,<1",
"typing-extensions>=4.1.0,<5",

# TODO(snandamuri): Versions of these packages must be exactly same between user's workspace and
# snowpark sandbox. Generic definitions like scikit-learn>=1.1.0,<2 wont work because snowflake conda channel
# only has a few allowlisted versions of scikit-learn available, so we must force users to use scikit-learn
# versions that are available in the snowflake conda channel. Since there is no way to specify allow list of
# versions in the requirements file, we are pinning the versions here.
"scikit-learn>=1.2.1,<2",
"xgboost==1.7.3",
"xgboost>=1.7.3,<2",
],
version = VERSION,
deps = [
43 changes: 42 additions & 1 deletion snowflake/ml/_internal/utils/identifier.py
@@ -1,5 +1,5 @@
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, overload

# Snowflake Identifier Regex. See https://docs.snowflake.com/en/sql-reference/identifiers-syntax.html.
_SF_UNQUOTED_IDENTIFIER = "[A-Za-z_][A-Za-z0-9_$]*"
@@ -9,6 +9,7 @@
_SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)

UNQUOTED_CASE_INSENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_IDENTIFIER})$")
QUOTED_IDENTIFIER_RE = re.compile(f"^({SF_QUOTED_IDENTIFIER})$")
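# For illustration: UNQUOTED_CASE_INSENSITIVE_RE matches 'foo_bar$1' but not '1foo',
# and QUOTED_IDENTIFIER_RE matches '"foo""bar"' (a doubled quote escapes a literal quote).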


def _is_quoted(id: str) -> bool:
@@ -55,6 +56,31 @@ def remove_quote_if_quoted(id: str) -> str:
    return id


def remove_and_unescape_quote_if_quoted(id: str) -> str:
    """Remove the enclosing double quotes from id and unescape doubled quotes between them, if id is quoted.

    NOTE: See note in :meth:`_is_quoted`.

    Args:
        id: The string to be checked & treated.

    Raises:
        ValueError: If the identifier is unquoted but does not match the unquoted identifier syntax.
        ValueError: If the identifier contains a run of an odd number of quotes and thus cannot be
            unescaped; for example, '""a""' is invalid.

    Returns:
        String with quotes removed and unescaped if quoted; the original string otherwise.
    """
    if not _is_quoted(id):
        if not UNQUOTED_CASE_INSENSITIVE_RE.match(id):
            raise ValueError("Invalid id passed.")
        return id
    if not QUOTED_IDENTIFIER_RE.match(id):
        raise ValueError("Invalid id passed.")
    unquoted_id = id[1:-1]
    return unquoted_id.replace('""', '"')
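
# For example (mirroring the unit tests in identifier_test.py below):
#   remove_and_unescape_quote_if_quoted('foo')         -> 'foo'
#   remove_and_unescape_quote_if_quoted('"foo""bar"')  -> 'foo"bar'
#   remove_and_unescape_quote_if_quoted('"""foo"""')   -> '"foo"'
#   remove_and_unescape_quote_if_quoted('""foo""')     raises ValueError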


def concat_names(ids: List[str]) -> str:
"""Concatenates `ids` to form one valid id.
@@ -106,6 +132,21 @@ def parse_schema_level_object_identifier(
    return identifiers[0], identifiers[1], identifiers[2], identifiers[3]


@overload
def get_equivalent_identifier_in_the_response_pandas_dataframe(ids: None) -> None:
    ...


@overload
def get_equivalent_identifier_in_the_response_pandas_dataframe(ids: str) -> str:
    ...


@overload
def get_equivalent_identifier_in_the_response_pandas_dataframe(ids: List[str]) -> List[str]:
    ...


def get_equivalent_identifier_in_the_response_pandas_dataframe(
    ids: Optional[Union[str, List[str]]]
) -> Optional[Union[str, List[str]]]:
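The body of get_equivalent_identifier_in_the_response_pandas_dataframe is collapsed in this view. As a rough sketch only (an assumption based on the three @overload declarations and the helpers above, not the actual implementation), such a helper would dispatch on the argument type, mapping each Snowflake identifier to the column label a pandas result DataFrame would carry:

def _equivalent_identifier_sketch(ids):
    # None passes through unchanged (first overload).
    if ids is None:
        return None
    # A single identifier maps to a single label (second overload). Assumption:
    # quoted identifiers are unescaped, unquoted ones resolve to uppercase.
    if isinstance(ids, str):
        if _is_quoted(ids):
            return remove_and_unescape_quote_if_quoted(ids)
        return ids.upper()
    # A list maps element-wise (third overload).
    return [_equivalent_identifier_sketch(id) for id in ids]
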
15 changes: 15 additions & 0 deletions snowflake/ml/_internal/utils/identifier_test.py
@@ -15,6 +15,21 @@ def test_quote_not_removed(self) -> None:
        self.assertEqual('foo"', identifier.remove_quote_if_quoted('foo"'))
        self.assertEqual('foo"bar', identifier.remove_quote_if_quoted('foo"bar'))

    def test_remove_and_unescape_quote_if_quoted(self) -> None:
        self.assertEqual("foo", identifier.remove_and_unescape_quote_if_quoted('"foo"'))
        self.assertEqual('"foo"', identifier.remove_and_unescape_quote_if_quoted('"""foo"""'))
        self.assertEqual('foo"bar', identifier.remove_and_unescape_quote_if_quoted('"foo""bar"'))
        with self.assertRaises(ValueError):
            identifier.remove_and_unescape_quote_if_quoted('foo"')
        with self.assertRaises(ValueError):
            identifier.remove_and_unescape_quote_if_quoted('"bar')
        with self.assertRaises(ValueError):
            identifier.remove_and_unescape_quote_if_quoted('foo"bar')
        with self.assertRaises(ValueError):
            identifier.remove_and_unescape_quote_if_quoted('""foo""')
        with self.assertRaises(ValueError):
            identifier.remove_and_unescape_quote_if_quoted('"foo"""bar"')

    def test_plan_concat(self) -> None:
        """Test vanilla concat with no quotes."""
        self.assertEqual("demo__task1", identifier.concat_names(["demo__", "task1"]))
3 changes: 2 additions & 1 deletion snowflake/ml/metrics/BUILD.bazel
@@ -6,9 +6,10 @@ package(default_visibility = ["//visibility:public"])
py_library(
    name = "metrics",
    srcs = [
        "regression.py",
        "accuracy_score.py",
        "correlation.py",
        "covariance.py",
        "regression.py",
    ],
    deps = [
        ":init",
2 changes: 2 additions & 0 deletions snowflake/ml/metrics/__init__.py
@@ -1,7 +1,9 @@
from .accuracy_score import accuracy_score
from .correlation import correlation
from .covariance import covariance

__all__ = [
    "accuracy_score",
    "correlation",
    "covariance",
]
60 changes: 50 additions & 10 deletions snowflake/ml/metrics/_utils.py
@@ -9,7 +9,10 @@
import numpy as np

from snowflake import snowpark
from snowflake.snowpark import Session, types as snowpark_types
from snowflake.snowpark import Session, functions as F, types as T

_PROJECT = "ModelDevelopment"
_SUBPROJECT = "Metrics"


def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, str]) -> str:
@@ -47,12 +50,12 @@ def end_partition(self) -> Iterable[Tuple[bytes]]:
    dot_and_sum_accumulator = "DotAndSumAccumulator_{}".format(str(uuid4()).replace("-", "_").upper())
    session.udtf.register(
        DotAndSumAccumulator,
        output_schema=snowpark_types.StructType(
        output_schema=T.StructType(
            [
                snowpark_types.StructField("result", snowpark_types.BinaryType()),
                T.StructField("result", T.BinaryType()),
            ]
        ),
        input_types=[snowpark_types.BinaryType()],
        input_types=[T.BinaryType()],
        packages=["numpy", "cloudpickle"],
        name=dot_and_sum_accumulator,
        is_permanent=False,
@@ -159,13 +162,13 @@ def accumulate_batch_sum_and_dot_prod(self) -> None:
    sharded_dot_and_sum_computer = "ShardedDotAndSumComputer_{}".format(str(uuid4()).replace("-", "_").upper())
    session.udtf.register(
        ShardedDotAndSumComputer,
        output_schema=snowpark_types.StructType(
        output_schema=T.StructType(
            [
                snowpark_types.StructField("result", snowpark_types.BinaryType()),
                snowpark_types.StructField("part", snowpark_types.StringType()),
                T.StructField("result", T.BinaryType()),
                T.StructField("part", T.StringType()),
            ]
        ),
        input_types=[snowpark_types.ArrayType(), snowpark_types.StringType(), snowpark_types.StringType()],
        input_types=[T.ArrayType(), T.StringType(), T.StringType()],
        packages=["numpy", "cloudpickle"],
        name=sharded_dot_and_sum_computer,
        is_permanent=False,
@@ -192,12 +195,49 @@ def validate_and_return_dataframe_and_columns(
"""
input_df = df
if columns is None:
columns = [c.name for c in input_df.schema.fields if issubclass(type(c.datatype), snowpark_types._NumericType)]
columns = [c.name for c in input_df.schema.fields if issubclass(type(c.datatype), T._NumericType)]
input_df = input_df.select(columns)
else:
input_df = input_df.select(columns)
for c in input_df.schema.fields:
if not issubclass(type(c.datatype), snowpark_types._NumericType):
if not issubclass(type(c.datatype), T._NumericType):
msg = "Column: {} is not a numeric column"
raise ValueError(msg.format(c.name))
return (input_df, columns)


def weighted_sum(
    *,
    df: snowpark.DataFrame,
    sample_score_column: snowpark.Column,
    sample_weight_column: Optional[snowpark.Column] = None,
    normalize: bool = False,
    statement_params: Dict[str, str],
) -> float:
    """Weighted sum of the sample score column.

    Args:
        df: Input dataframe.
        sample_score_column: Sample score column.
        sample_weight_column: Sample weight column.
        normalize: If ``False``, return the weighted sum.
            Otherwise, return the weighted sum normalized by the total weight (i.e. the weighted average).
        statement_params: Dictionary used for tagging queries for tracking purposes.

    Returns:
        The weighted average as a float if ``normalize == True``, else the weighted sum as a float.
    """
    if normalize:
        if sample_weight_column is not None:
            res = F.sum(sample_score_column * sample_weight_column) / F.sum(  # type: ignore[arg-type, operator]
                sample_weight_column  # type: ignore[arg-type]
            )
        else:
            res = F.avg(sample_score_column)  # type: ignore[arg-type]
    elif sample_weight_column is not None:
        res = F.sum(sample_score_column * sample_weight_column)  # type: ignore[arg-type, operator]
    else:
        res = F.sum(sample_score_column)  # type: ignore[arg-type]

    return float(df.select(res).collect(statement_params=statement_params)[0][0])
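
A quick numeric check of the semantics documented above (a pure-Python sketch, independent of Snowpark; the data is made up):

scores = [1, 0, 1]         # per-row sample scores
weights = [2.0, 1.0, 1.0]  # per-row sample weights

weighted = sum(s * w for s, w in zip(scores, weights))  # 1*2 + 0*1 + 1*1 = 3.0
normalized = weighted / sum(weights)                    # 3.0 / 4.0 = 0.75
# weighted_sum(..., normalize=False) corresponds to 3.0 here;
# weighted_sum(..., normalize=True) corresponds to 0.75.
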
50 changes: 50 additions & 0 deletions snowflake/ml/metrics/accuracy_score.py
@@ -0,0 +1,50 @@
from typing import Optional

from snowflake import snowpark
from snowflake.ml._internal import telemetry
from snowflake.ml.metrics import _utils
from snowflake.snowpark import functions as F

_PROJECT = "ModelDevelopment"
_SUBPROJECT = "Metrics"


@telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
def accuracy_score(
    *,
    df: snowpark.DataFrame,
    y_true_col_name: str,
    y_pred_col_name: str,
    normalize: bool = True,
    sample_weight_col_name: Optional[str] = None,
) -> float:
    """
    Accuracy classification score.

    Note: Multilabel classification is not currently supported.

    Args:
        df: Input dataframe.
        y_true_col_name: Column name representing actual values.
        y_pred_col_name: Column name representing predicted values.
        normalize: If ``False``, return the number of correctly classified samples.
            Otherwise, return the fraction of correctly classified samples.
        sample_weight_col_name: Column name representing sample weights.

    Returns:
        If ``normalize == True``, the fraction of correctly classified
        samples (float); otherwise, the number of correctly classified samples (int).
        The best performance is 1 with ``normalize == True`` and the number
        of samples with ``normalize == False``.
    """
    # TODO: Support multilabel classification.
    score_column = F.iff(df[y_true_col_name] == df[y_pred_col_name], 1, 0)  # type: ignore[arg-type]
    return _utils.weighted_sum(
        df=df,
        sample_score_column=score_column,
        sample_weight_column=df[sample_weight_col_name] if sample_weight_col_name else None,
        normalize=normalize,
        statement_params=telemetry.get_statement_params(_PROJECT, _SUBPROJECT),
    )
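
A usage sketch for the new metric (the connection parameters, data, and column names below are hypothetical; it assumes a reachable Snowflake account):

from snowflake.ml.metrics import accuracy_score
from snowflake.snowpark import Session

connection_parameters = {
    "account": "<account>",
    "user": "<user>",
    "password": "<password>",
}  # hypothetical credentials
session = Session.builder.configs(connection_parameters).create()

# Three of the four predictions match the labels.
df = session.create_dataframe(
    [(1, 1), (0, 1), (1, 1), (0, 0)],
    schema=["LABEL", "PREDICTION"],
)
acc = accuracy_score(df=df, y_true_col_name="LABEL", y_pred_col_name="PREDICTION")
assert acc == 0.75  # fraction of correctly classified samples
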
10 changes: 10 additions & 0 deletions snowflake/ml/model/BUILD.bazel
@@ -5,13 +5,17 @@ package(default_visibility = ["//visibility:public"])
py_library(
    name = "type_hints",
    srcs = ["type_hints.py"],
    deps = [
        "//snowflake/ml/framework:framework"
    ]
)

py_library(
    name = "model_signature",
    srcs = ["model_signature.py"],
    deps = [
        ":type_hints",
        "//snowflake/ml/_internal/utils:formatting",
        "//snowflake/ml/_internal/utils:identifier",
    ],
)
@@ -22,6 +26,7 @@ py_library(
    deps = [
        "//snowflake/ml/_internal:env",
        "//snowflake/ml/_internal:env_utils",
        "//snowflake/ml/_internal/utils:formatting",
    ],
)

@@ -38,6 +43,7 @@ py_library(
    srcs = ["_model_meta.py"],
    deps = [
        ":_env",
        ":type_hints",
        ":model_signature",
        "//snowflake/ml/_internal:env",
        "//snowflake/ml/_internal:env_utils",
@@ -62,6 +68,8 @@ py_library(
    srcs = ["_deployer.py"],
    deps = [
        ":_udf_util",
        ":type_hints",
        ":model_signature",
        "//snowflake/ml/_internal/utils:identifier",
    ],
)
@@ -79,6 +87,7 @@ py_library(
name = "_model",
srcs = ["_model.py"],
deps = [
":_env",
":_model_handler",
":_model_meta",
":custom_model",
@@ -88,6 +97,7 @@ py_library(
"//snowflake/ml/model/_handlers:sklearn",
"//snowflake/ml/model/_handlers:snowmlmodel",
"//snowflake/ml/model/_handlers:xgboost",
"//snowflake/ml/framework:framework"
],
)
