Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-802269 Added regexp_extract,signum,substring_index,collect_list #135

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions src/main/scala/com/snowflake/snowpark/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3140,6 +3140,144 @@ object functions {
*/
def listagg(col: Column): Column = listagg(col, "", isDistinct = false)

/**
 * Extracts a capture group matched by a regular expression from a string column.
 *
 * Wraps Snowflake's REGEXP_SUBSTR. A null input yields null. If the pattern, or the
 * requested group, does not match, an empty string is returned instead of null.
 *
 * Example:
 * {{{
 * val df = Seq("A MAN A PLAN A CANAL").toDF("a")
 * df.select(regexp_extract(col("a"), "A\\W+(\\w+)", 1, 2, 1)).show()
 * // PLAN
 * }}}
 *
 * Note: earlier documentation of this function states that non-greedy tokens are
 * not supported.
 *
 * @since 1.12.1
 * @param colName Column holding the strings to search.
 * @param exp Regular expression pattern.
 * @param position Position argument forwarded to REGEXP_SUBSTR (where the search starts).
 * @param Occurences Occurrence argument forwarded to REGEXP_SUBSTR (which match to use).
 * @param grpIdx Group argument forwarded to REGEXP_SUBSTR (which capture group to return).
 * @return Column object.
 */
def regexp_extract(
    colName: Column,
    exp: String,
    position: Int,
    Occurences: Int,
    grpIdx: Int): Column = {
  // Regex parameters: "c" = case-sensitive matching, "e" = extract the sub-match.
  val regexParameters = lit("ce")
  val extracted = builtin("REGEXP_SUBSTR")(
    colName,
    lit(exp),
    lit(position),
    lit(Occurences),
    regexParameters,
    lit(grpIdx))
  // A missing match comes back as NULL; surface it as an empty string instead.
  val extractedOrEmpty = coalesce(extracted, lit(""))
  // Null inputs stay null rather than being turned into an empty string.
  when(colName.is_null, lit(null)).otherwise(extractedOrEmpty)
}

/**
 * Returns the sign of its argument by calling Snowflake's SIGN function:
 *
 *  - -1 if the argument is negative.
 *  - 1 if it is positive.
 *  - 0 if it is 0.
 *
 * Example:
 * {{{
 * val df = Seq(1, -2, 0).toDF("a")
 * df.select(signum(col("a"))).show()
 * // 1, -1, 0
 * }}}
 *
 * @since 1.12.1
 * @param colName Column whose sign is evaluated.
 * @return Column object.
 */
def signum(colName: Column): Column = builtin("SIGN")(colName)

/**
 * Returns the sign of the column with the given name. This is a convenience
 * overload that resolves the name with `col` and delegates to `signum(Column)`.
 *
 * The result is 1 for positive values, 0 for zero, -1 for negative values and
 * null for null. (Earlier documentation also lists NaN as producing 0 - TODO confirm.)
 *
 * NOTE: if string values are provided, Snowflake attempts to cast them. If the
 * cast succeeds the sign is returned; otherwise an error is thrown.
 *
 * @since 1.12.1
 * @param columnName Name of the column to calculate the sign of.
 * @return Column object.
 */
def signum(columnName: String): Column = signum(col(columnName))

/**
 * Returns the substring of `str` before `count` occurrences of the delimiter `delim`.
 *
 * If `count` is positive, everything to the left of the `count`-th delimiter (counting
 * from the left) is returned. If `count` is negative, everything to the right of the
 * `|count|`-th delimiter (counting from the right) is returned. If `count` is 0 the
 * result is an empty string. When the string contains fewer delimiters than requested,
 * the whole string is returned. The delimiter is matched literally and case-sensitively.
 *
 * The string is split on the delimiter, the required leading or trailing pieces are
 * kept, and they are joined back together with the same delimiter.
 *
 * Example:
 * {{{
 * val df = Seq("It was the best of times, it was the worst of times").toDF("a")
 * df.select(substring_index(col("a"), "was", 1))   // "It "
 * df.select(substring_index(col("a"), "was", -1))  // " the worst of times"
 * }}}
 *
 * @since 1.12.1
 * @param str Column holding the strings to cut.
 * @param delim Literal delimiter to search for.
 * @param count Number of delimiter occurrences to keep up to; the sign picks the side.
 * @return Column object.
 */
def substring_index(str: Column, delim: String, count: Int): Column = {
  val separator = lit(delim)
  val parts = callBuiltin("split", str, separator)
  val total = callBuiltin("array_size", parts)
  // `count` is known on the client, so the branch is chosen here instead of in SQL.
  val kept =
    if (count >= 0) {
      // Keep the first `count` pieces; clamp so an oversized count keeps everything.
      callBuiltin("array_slice", parts, lit(0), callBuiltin("least", lit(count), total))
    } else {
      // Keep the last |count| pieces; clamp the start at 0 so an oversized count keeps
      // everything. Adding the negative count avoids negating Int.MinValue.
      val from = callBuiltin("greatest", total + lit(count), lit(0))
      callBuiltin("array_slice", parts, from, total)
    }
  callBuiltin("array_to_string", kept, separator)
}

/**
 * Collects the values of a column into an ARRAY. This is an alias that delegates
 * to `array_agg`.
 *
 * Returns the input values, pivoted into an ARRAY. If the input is empty, an empty
 * ARRAY is returned.
 *
 * Example:
 * {{{
 * df.select(collect_list(col("amount"))).show()
 * }}}
 *
 * @since 1.10.0
 * @param c Column to be collected.
 * @return The array.
 */
def collect_list(c: Column): Column = {
  array_agg(c)
}
sfc-gh-mrojas marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Collects the values of the named column into an ARRAY. Convenience overload that
 * resolves the name with `col` and delegates to `collect_list(Column)`.
 *
 * @since 1.10.0
 * @param s Name of the column to be collected.
 * @return The array.
 */
def collect_list(s: String): Column = collect_list(col(s))

/**
* Invokes a built-in snowflake function with the specified name and arguments.
* Arguments can be of two types
Expand Down
38 changes: 38 additions & 0 deletions src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2177,7 +2177,45 @@ trait FunctionSuite extends TestData {
expected,
sort = false)
}
test("regexp_extract") {
  val data = Seq("A MAN A PLAN A CANAL").toDF("a")
  val pattern = "A\\W+(\\w+)"

  // Each occurrence of the pattern captures the word that follows an "A".
  Seq(1 -> "MAN", 2 -> "PLAN", 3 -> "CANAL").foreach {
    case (occurrence, word) =>
      checkAnswer(
        data.select(regexp_extract(col("a"), pattern, 1, occurrence, 1)),
        Seq(Row(word)),
        sort = false)
  }

  // There is no fourth occurrence in the input.
  // NOTE(review): regexp_extract coalesces a missing match to an empty string, yet
  // this expects null - confirm which of the two is intended.
  checkAnswer(
    data.select(regexp_extract(col("a"), pattern, 1, 4, 1)),
    Seq(Row(null)),
    sort = false)
}
test("signum") {
  // One positive, one negative and one zero input, checked in input order.
  val input = Seq(1, -2, 0).toDF("a")
  val expectedSigns = Seq(Row(1), Row(-1), Row(0))
  checkAnswer(input.select(signum(col("a"))), expectedSigns, sort = false)
}

test("collect_list") {
  // The aggregated array is compared through its string rendering.
  val actual = monthlySales.select(collect_list(col("amount"))).collect()(0).get(0).toString
  val expected =
    "[\n 10000,\n 400,\n 4500,\n 35000,\n 5000,\n 3000,\n 200,\n 90500,\n 6000,\n " +
      "5000,\n 2500,\n 9500,\n 8000,\n 10000,\n 800,\n 4500\n]"
  assert(actual == expected)
}
test("substring_index") {
  val df = Seq("It was the best of times, it was the worst of times").toDF("a")
  // Positive count: everything to the left of the first "was".
  checkAnswer(
    df.select(substring_index(col("a"), "was", 1)),
    Seq(Row("It ")),
    sort = false)
  // Negative count: everything to the right of the last "was".
  checkAnswer(
    df.select(substring_index(col("a"), "was", -1)),
    Seq(Row(" the worst of times")),
    sort = false)
}
}

// Concrete suite: runs every test declared in FunctionSuite with EagerSession mixed in.
class EagerFunctionSuite extends FunctionSuite with EagerSession
Expand Down