Instead of writing result to dynamic table, call a UDF to write stream
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
sfc-gh-mvashishtha committed Dec 12, 2024
1 parent 6cb6e16 commit ef5957b
Showing 4 changed files with 74 additions and 25 deletions.
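
In short, the demo no longer materializes the Kafka/static join into a dynamic table; it now filters the stream and hands each row to a UDF-backed streaming writer that returns a cancellable AsyncJob. A condensed before/after sketch, reusing the names and session setup from the demo script below:

# Before: join the Kafka stream to a static table and materialize a dynamic table.
joined = kafka_ingest_df.join(static_df, on='KEY')
joined.create_or_replace_dynamic_table(
    'dynamic_join_result',
    warehouse=session.connection.warehouse,
    lag='1 hour',
)

# After: project and filter the stream, then write it through write_stream_udf;
# toTable() returns an AsyncJob that the demo cancels when it is done.
streaming_query = (
    kafka_ingest_df
    .select(col("id"), col("timestamp"), col("name"))
    .filter(col("price") > 100.0)
    .writeStream
    .toTable("dynamic_join_result")
)
streaming_query.cancel()
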
59 changes: 39 additions & 20 deletions snowpark_streaming_demo.py
@@ -1,8 +1,9 @@
from snowflake.snowpark.session import Session
from snowflake.snowpark.functions import parse_json, col
from snowflake.snowpark.types import StructType, MapType, StructField, StringType
from snowflake.snowpark.types import StructType, MapType, StructField, StringType, IntegerType, FloatType, TimestampType
import logging; logging.getLogger("snowflake.snowpark").setLevel(logging.DEBUG)
import pandas as pd
from snowflake.snowpark.async_job import AsyncJob


# Function to generate random JSON data
@@ -34,6 +35,16 @@ def generate_json_data():
static_df = session.table("static_df")


kafka_event_schema = StructType(
[
StructField(column_identifier="ID", datatype=IntegerType()),
StructField(column_identifier="NAME", datatype=StringType()),
StructField(column_identifier="PRICE", datatype=FloatType()),
StructField(column_identifier="TIMESTAMP", datatype=TimestampType()),
]
)


# Subscribe to 1 topic
kafka_ingest_df = (
session
@@ -42,29 +53,37 @@ def generate_json_data():
.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
.option("topic", "topic1")
.option("partition_id", 1)
.schema(
StructType(
[
StructField(column_identifier="KEY", datatype=StringType()),
StructField(column_identifier="STREAM_VALUE", datatype=StringType())
]
)
)
.schema(kafka_event_schema)
.load()
)

# Join kafka ingest to static table, and write result to dynamic table.
joined = kafka_ingest_df.join(static_df, on='KEY')
joined.create_or_replace_dynamic_table(
'dynamic_join_result',
warehouse=session.connection.warehouse,
lag='1 hour',

)
RESULT_TABLE_NAME = "dynamic_join_result";

transformed_df = kafka_ingest_df \
.select(col("id"), col("timestamp"), col("name")) \
.filter(col("price") > 100.0)


"""
This query looks like
SELECT write_stream_udf('dynamic_join_result', "id", "timestamp", "name")
FROM (SELECT id,
name,
price,
timestamp
FROM ( TABLE (my_streaming_udtf('host1:port1,host2:port2', 'topic1', 1
:: INT
) )))
WHERE ( "price" > 100.0 )
"""

streaming_query: AsyncJob = transformed_df \
.writeStream \
.toTable(RESULT_TABLE_NAME)

streaming_query.cancel()

# Clean up dynamic table.
drop_result = session.connection.cursor().execute('DROP DYNAMIC TABLE dynamic_join_result;')
assert drop_result is not None

# # Write streaming dataframe to output data sink
# sink_query = (
3 changes: 2 additions & 1 deletion src/snowflake/snowpark/dataframe_reader.py
@@ -1000,12 +1000,13 @@ def load(self) -> DataFrame:
bootstrap_servers = self._cur_options["kafka.bootstrap.servers".upper()]
topic = self._cur_options["topic".upper()]
partition_id = self._cur_options["partition_id".upper()]

self._session.custom_package_usage_config['force_push'] = True
self._session.custom_package_usage_config['enabled'] = True
self._session.add_import(snowflake.snowpark.kafka_ingest_udtf.__file__, import_path="snowflake.snowpark.kafka_ingest_udtf")
self._session.add_packages(["python-confluent-kafka"])

self._session.sql("create or replace stage mystage").collect()

kafka_udtf = udtf(
KafkaFetch,
output_schema=self._user_schema,
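
The hunk is truncated here, but judging from the SQL shown in the demo's docstring (TABLE(my_streaming_udtf(...))), load() presumably finishes by wrapping the registered UDTF in a table-function call. A rough sketch of that remainder, where kafka_udtf and the option variables come from the hunk above and the use of session.table_function is an assumption, not the committed code:

from snowflake.snowpark.functions import lit  # wrap the option values as literals

# Assumed continuation (illustrative only): call the UDTF registered above with
# the reader options, yielding rows shaped by self._user_schema.
return self._session.table_function(
    kafka_udtf(lit(bootstrap_servers), lit(topic), lit(partition_id))
)
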
35 changes: 32 additions & 3 deletions src/snowflake/snowpark/dataframe_writer.py
@@ -8,12 +8,14 @@

import snowflake.snowpark # for forward references of type hints
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
from snowflake.snowpark.write_stream_to_table import write_stream_to_table
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CopyIntoLocationNode,
SaveMode,
SnowflakeCreateTable,
TableCreationSource,
)
from snowflake.snowpark.async_job import AsyncJob
from snowflake.snowpark._internal.ast.utils import (
build_expr_from_snowpark_column_or_col_name,
debug_check_missing_ast,
@@ -40,8 +42,9 @@
warning,
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.types import StringType
from snowflake.snowpark.column import Column, _to_col_if_str
from snowflake.snowpark.functions import sql_expr
from snowflake.snowpark.functions import sql_expr, udf, lit, col
from snowflake.snowpark.mock._connection import MockServerConnection
from snowflake.snowpark.row import Row

@@ -917,6 +920,32 @@ def parquet(

saveAsTable = save_as_table


class DataStreamWriter(DataFrameWriter):
def start(self):
raise NotImplementedError("cannot write a data stream yet.")
def toTable(self, table_name: str) -> AsyncJob:
self._dataframe.session.custom_package_usage_config['force_push'] = True
self._dataframe.session.custom_package_usage_config['enabled'] = True
self._dataframe.session.add_import(snowflake.snowpark.write_stream_to_table.__file__, import_path="snowflake.snowpark.write_stream_to_table")
self._dataframe.session.sql("create or replace stage mystage").collect()


write_stream_udf = udf(
write_stream_to_table,
input_types=
[
StringType(),
*(f.datatype for f in self._dataframe.schema.fields)
],
is_permanent=True,
replace=True,
name='write_stream_udf',
stage_location="@mystage"
)

return self._dataframe.select(write_stream_udf(
lit(table_name),
*(
col(f.name)
for f in self._dataframe.schema.fields
)
)).collect_nowait()
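
The UDF body itself lives in snowflake/snowpark/write_stream_to_table.py, which is not among the files changed in this commit, so it is not shown here. From the registration above its call shape can be inferred: the first argument is the target table name, followed by one argument per column of the source DataFrame. A purely hypothetical stub illustrating that shape, not the committed implementation:

# Hypothetical stub only: mirrors input_types = [StringType(), *field datatypes].
def write_stream_to_table(table_name, *row_values):
    # The real UDF presumably inserts row_values into table_name;
    # this stub just documents the expected call signature.
    return f"queued 1 row for {table_name}"
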
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/kafka_ingest_udtf.py
@@ -68,7 +68,7 @@ def process(self, bootstrap_servers: str, topic: str, partition_id: int):
# logging.info(f"Received message: {msg.value().decode('utf-8')}")
# yield (msg.value().decode('utf-8'),)

yield (str(i), str(generate_json_data()))
yield tuple(generate_json_data().values())

except:
logging.error("Consumer Error")
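
This last change makes the UDTF yield a tuple of the generated record's values rather than a (key, json-string) pair, so the emitted rows line up positionally with the user-supplied schema (ID, NAME, PRICE, TIMESTAMP in the demo). Illustrative only, assuming generate_json_data() returns a dict in that column order:

from datetime import datetime

# Assumed shape of generate_json_data()'s output; the actual generator is
# defined earlier in kafka_ingest_udtf.py and is not shown in this diff.
record = {"id": 1, "name": "widget", "price": 101.5, "timestamp": datetime.now()}
row = tuple(record.values())  # -> (1, 'widget', 101.5, datetime(...)), matching the schema
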
