Add support for RAPIDS in TPC-H benchmarks #1218

Draft · wants to merge 7 commits into main

ci/environment-rapids.yml: 11 additions, 0 deletions (new file)
@@ -0,0 +1,11 @@
+# This is an addition to ci/environment.yml.
+# Add cudf and downgrade some pinned dependencies.
+channels:
+  - rapidsai-nightly
+  - conda-forge
+  - nvidia
+dependencies:
+  - dask-cudf =24.02
+  - dask-cuda =24.02
+  - pandas ==1.5.3 # pinned by cudf
+  - pynvml ==11.4.1 # pinned by dask-cuda
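
Note (editorial, not part of the diff): the PR does not show how this overlay is applied. Presumably the GPU environment is built by creating the base environment from ci/environment.yml and then layering this file on top, e.g. `conda env create -f ci/environment.yml` followed by `conda env update -f ci/environment-rapids.yml`.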

ci/environment.yml: 1 addition, 2 deletions
@@ -8,14 +8,13 @@ dependencies:
   # - AB_environments/AB_sample.conda.yaml
   ########################################################
 
-  - python >=3.9
+  - python >=3.9,<3.11
   - pip
   - coiled >=0.2.54
   - numpy ==1.26.2
   - pandas ==2.1.4
   - dask ==2023.12.0
   - distributed ==2023.12.0
-  - dask-expr ==0.2.8
   - dask-labextension ==7.0.0
   - dask-ml ==2023.3.24
   - fsspec ==2023.12.1
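
Note (editorial): the new `<3.11` ceiling presumably tracks the RAPIDS 24.02 packages added in ci/environment-rapids.yml, which did not yet ship Python 3.11 builds; dropping the dask-expr pin presumably avoids a conflict with the `pandas ==1.5.3` pin that cudf requires, since dask-expr needs pandas 2.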

tests/tpch/conftest.py: 38 additions, 12 deletions
@@ -17,6 +17,7 @@
 
 def pytest_addoption(parser):
     parser.addoption("--local", action="store_true", default=False, help="")
+    parser.addoption("--rapids", action="store_true", default=False, help="")
     parser.addoption("--cloud", action="store_false", dest="local", help="")
     parser.addoption("--restart", action="store_true", default=True, help="")
     parser.addoption("--no-restart", action="store_false", dest="restart", help="")
@@ -48,6 +49,11 @@ def local(request):
     return request.config.getoption("local")
 
 
+@pytest.fixture(scope="session")
+def rapids(request):
+    return request.config.getoption("rapids")
+
+
 @pytest.fixture(scope="session")
 def restart(request):
     return request.config.getoption("restart")
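
Note (editorial): given these fixtures, a single-node GPU run would presumably be launched with something like `pytest tests/tpch/test_dask.py --local --rapids`; the exact invocation is not shown in this diff.
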
@@ -186,6 +192,7 @@ def cluster_spec(scale):
 @pytest.fixture(scope="module")
 def cluster(
     local,
+    rapids,
     scale,
     module,
     dask_env_variables,
Expand All @@ -195,19 +202,38 @@ def cluster(
make_chart,
):
if local:
with LocalCluster() as cluster:
yield cluster
else:
kwargs = dict(
name=f"tpch-{module}-{scale}-{name}",
environ=dask_env_variables,
tags=github_cluster_tags,
region="us-east-2",
**cluster_spec,
)
with dask.config.set({"distributed.scheduler.worker-saturation": "inf"}):
with coiled.Cluster(**kwargs) as cluster:
if not rapids:
with LocalCluster() as cluster:
yield cluster
else:
from dask_cuda import LocalCUDACluster

with dask.config.set(
{"dataframe.backend": "cudf", "dataframe.shuffle.method": "tasks"}
):
with LocalCUDACluster(rmm_pool_size="24GB") as cluster:
yield cluster
else:
if not rapids:
kwargs = dict(
name=f"tpch-{module}-{scale}-{name}",
environ=dask_env_variables,
tags=github_cluster_tags,
region="us-east-2",
**cluster_spec,
)
with dask.config.set({"distributed.scheduler.worker-saturation": "inf"}):
with coiled.Cluster(**kwargs) as cluster:
yield cluster
else:
# should be using Coiled for this
from dask_cuda import LocalCUDACluster

with dask.config.set(
{"dataframe.backend": "cudf", "dataframe.shuffle.method": "tasks"}
):
with LocalCUDACluster(rmm_pool_size="24GB") as cluster:
yield cluster


@pytest.fixture
Expand Down
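
Note (editorial): for reviewers unfamiliar with the backend dispatch used above, here is a minimal standalone sketch of what the --rapids local path sets up. It is not part of the PR, and the parquet path is hypothetical:

    # Illustrative sketch; mirrors the fixture above rather than adding behavior.
    # Assumes a CUDA machine with the packages from ci/environment-rapids.yml.
    import dask
    import dask.dataframe as dd
    from dask_cuda import LocalCUDACluster
    from distributed import Client

    with dask.config.set(
        {"dataframe.backend": "cudf", "dataframe.shuffle.method": "tasks"}
    ):
        # LocalCUDACluster starts one worker per visible GPU; the RMM pool
        # preallocates device memory to avoid repeated CUDA allocations.
        with LocalCUDACluster(rmm_pool_size="24GB") as cluster, Client(cluster):
            df = dd.read_parquet("s3://example-bucket/tpch/lineitem")  # hypothetical path
            # With the cudf backend, partitions are cudf (GPU) DataFrames,
            # so the TPC-H queries below run unchanged on GPU.
            print(type(df._meta))  # cudf.core.dataframe.DataFrame

Because the backend is switched in configuration rather than in the query code, the TPC-H tests themselves stay identical for CPU and GPU runs.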

tests/tpch/test_dask.py: 25 additions, 23 deletions
@@ -6,10 +6,12 @@
 
 dd = pytest.importorskip("dask_expr")
 
+BLOCKSIZE = "default"
+
 
 def test_query_1(client, dataset_path, fs):
     VAR1 = datetime(1998, 9, 2)
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
 
     lineitem_filtered = lineitem_ds[lineitem_ds.l_shipdate <= VAR1]
     lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity
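
Note (editorial): `BLOCKSIZE = "default"` preserves dask's standard partition sizing while turning it into a single knob threaded through every read_parquet call; presumably this is groundwork for GPU runs, where cudf generally favors fewer, larger partitions.
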
@@ -50,11 +52,11 @@ def test_query_2(client, dataset_path, fs):
     var2 = "BRASS"
     var3 = "EUROPE"
 
-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
-    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs)
-    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs, blocksize=BLOCKSIZE)
+    nation_filtered = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_filtered = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)
+    part_filtered = dd.read_parquet(dataset_path + "part", filesystem=fs, blocksize=BLOCKSIZE)
+    partsupp_filtered = dd.read_parquet(dataset_path + "partsupp", filesystem=fs, blocksize=BLOCKSIZE)
 
     region_filtered = region_ds[(region_ds["r_name"] == var3)]
     r_n_merged = nation_filtered.merge(
@@ -118,9 +120,9 @@ def test_query_3(client, dataset_path, fs):
     var1 = datetime.strptime("1995-03-15", "%Y-%m-%d")
     var2 = "BUILDING"
 
-    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
+    lineitem_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    cutomer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)
 
     lsel = lineitem_ds.l_shipdate > var1
     osel = orders_ds.o_orderdate < var1
@@ -144,8 +146,8 @@ def test_query_4(client, dataset_path, fs):
     date1 = datetime.strptime("1993-10-01", "%Y-%m-%d")
     date2 = datetime.strptime("1993-07-01", "%Y-%m-%d")
 
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
 
     lsel = line_item_ds.l_commitdate < line_item_ds.l_receiptdate
     osel = (orders_ds.o_orderdate < date1) & (orders_ds.o_orderdate >= date2)
@@ -168,12 +170,12 @@ def test_query_5(client, dataset_path, fs):
     date1 = datetime.strptime("1994-01-01", "%Y-%m-%d")
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
 
-    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs)
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    region_ds = dd.read_parquet(dataset_path + "region", filesystem=fs, blocksize=BLOCKSIZE)
+    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)
 
     rsel = region_ds.r_name == "ASIA"
     osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2)
@@ -198,7 +200,7 @@ def test_query_6(client, dataset_path, fs):
     date2 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var3 = 24
 
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
 
     sel = (
         (line_item_ds.l_shipdate >= date1)
@@ -217,11 +219,11 @@ def test_query_7(client, dataset_path, fs):
     var1 = datetime.strptime("1995-01-01", "%Y-%m-%d")
     var2 = datetime.strptime("1997-01-01", "%Y-%m-%d")
 
-    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs)
-    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs)
-    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs)
-    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs)
-    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs)
+    nation_ds = dd.read_parquet(dataset_path + "nation", filesystem=fs, blocksize=BLOCKSIZE)
+    customer_ds = dd.read_parquet(dataset_path + "customer", filesystem=fs, blocksize=BLOCKSIZE)
+    line_item_ds = dd.read_parquet(dataset_path + "lineitem", filesystem=fs, blocksize=BLOCKSIZE)
+    orders_ds = dd.read_parquet(dataset_path + "orders", filesystem=fs, blocksize=BLOCKSIZE)
+    supplier_ds = dd.read_parquet(dataset_path + "supplier", filesystem=fs, blocksize=BLOCKSIZE)
 
     lineitem_filtered = line_item_ds[
         (line_item_ds["l_shipdate"] >= var1) & (line_item_ds["l_shipdate"] < var2)