Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for serverless operation #297

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
## Change History
All notable changes to the Databricks Labs Data Generator will be documented in this file.

### Version 0.4.0 Hotfix 1

#### Fixed
* Fixed issue with running on serverless environment


### Version 0.4.0

#### Changed
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ details of use and many examples.

Release notes and details of the latest changes for this specific release
can be found in the GitHub repository
[here](https://github.com/databrickslabs/dbldatagen/blob/release/v0.4.0/CHANGELOG.md)
[here](https://github.com/databrickslabs/dbldatagen/blob/release/v0.4.0post1/CHANGELOG.md)

# Installation

Expand Down
2 changes: 1 addition & 1 deletion dbldatagen/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_version(version):
return version_info


__version__ = "0.4.0" # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
__version__ = "0.4.0post1" # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
__version_info__ = get_version(__version__)


Expand Down
15 changes: 12 additions & 3 deletions dbldatagen/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,21 @@ def _setupPandas(self, pandasBatchSize):
self.logger.info("Spark version: %s", self.sparkSession.version)
if str(self.sparkSession.version).startswith("3"):
self.logger.info("Using spark 3.x")
self.sparkSession.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
try:
self.sparkSession.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
except Exception: # pylint: disable=broad-exception-caught
pass
else:
self.sparkSession.conf.set("spark.sql.execution.arrow.enabled", "true")
try:
self.sparkSession.conf.set("spark.sql.execution.arrow.enabled", "true")
except Exception: # pylint: disable=broad-exception-caught
pass

if self._batchSize is not None:
self.sparkSession.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", self._batchSize)
try:
self.sparkSession.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", self._batchSize)
except Exception: # pylint: disable=broad-exception-caught
pass

def _setupLogger(self):
"""Set up logging
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
author = 'Databricks Inc'

# The full version, including alpha/beta/rc tags
release = "0.4.0" # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion
release = "0.4.0post1" # DO NOT EDIT THIS DIRECTLY! It is managed by bumpversion

# -- General configuration ---------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion python/.bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.0
current_version = 0.4.0post1
commit = False
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+){0,1}(?P<release>\D*)(?P<build>\d*)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

setuptools.setup(
name="dbldatagen",
version="0.4.0",
version="0.4.0post1",
author="Ronan Stokes, Databricks",
description="Databricks Labs - PySpark Synthetic Data Generator",
long_description=long_description,
Expand Down
75 changes: 75 additions & 0 deletions tests/test_serverless.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest

import dbldatagen as dg


class TestSimulatedServerless:
"""Serverless operation and other forms of shared spark cloud operation often have restrictions on what
features may be used.

In this set of tests, we'll simulate some of the common restrictions found in Databricks serverless and shared
environments to ensure that common operations still work.

Serverless operations have some of the following restrictions:

- Spark config settings cannot be written

"""

@pytest.fixture(scope="class")
def serverlessSpark(self):
from unittest.mock import MagicMock

sparkSession = dg.SparkSingleton.getLocalInstance("unit tests")

oldSetMethod = sparkSession.conf.set
oldGetMethod = sparkSession.conf.get
sparkSession.conf.set = MagicMock(
side_effect=ValueError("Setting value prohibited in simulated serverless env."))
sparkSession.conf.get = MagicMock(
side_effect=ValueError("Getting value prohibited in simulated serverless env."))

yield sparkSession

sparkSession.conf.set = oldSetMethod
sparkSession.conf.get = oldGetMethod

def test_basic_data(self, serverlessSpark):
from pyspark.sql.types import FloatType, IntegerType, StringType

row_count = 1000 * 100
column_count = 10
testDataSpec = (
dg.DataGenerator(serverlessSpark, name="test_data_set1", rows=row_count, partitions=4)
.withIdOutput()
.withColumn(
"r",
FloatType(),
expr="floor(rand() * 350) * (86400 + 3600)",
numColumns=column_count,
)
.withColumn("code1", IntegerType(), minValue=100, maxValue=200)
.withColumn("code2", "integer", minValue=0, maxValue=10, random=True)
.withColumn("code3", StringType(), values=["online", "offline", "unknown"])
.withColumn(
"code4", StringType(), values=["a", "b", "c"], random=True, percentNulls=0.05
)
.withColumn(
"code5", "string", values=["a", "b", "c"], random=True, weights=[9, 1, 1]
)
)

dfTestData = testDataSpec.build()

@pytest.mark.parametrize("providerName, providerOptions", [
("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}),
("basic/user", {"rows": 100, "partitions": -1, "random": True, "dummyValues": 0})
])
def test_basic_user_table_retrieval(self, providerName, providerOptions, serverlessSpark):
ds = dg.Datasets(serverlessSpark, providerName).get(**providerOptions)
assert ds is not None, f"""expected to get dataset specification for provider `{providerName}`
with options: {providerOptions}
"""
df = ds.build()

assert df.count() >= 0
Loading