Skip to content

Commit

Permalink
Make dask a required dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
rosswhitfield committed Oct 28, 2022
1 parent 0f94024 commit f864926
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 26 deletions.
1 change: 1 addition & 0 deletions .github/workflows/conda_env/environment_minimal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ dependencies:
- pytest-timeout
- psutil
- coverage!=6.3
+  - dask<=2022.10.0
8 changes: 4 additions & 4 deletions ipsframework/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -2229,15 +2229,15 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppw=None,
"""

if use_dask:
-            if TaskPool.dask and self.serial_pool:
+            if TaskPool.dask and TaskPool.distributed and self.serial_pool:
self.dask_pool = True
if use_shifter and not self.shifter:
self.services.error("Requested to run dask within shifter but shifter not available")
-                    raise Exception("shifter not found")
+                    raise RuntimeError("shifter not found")
else:
return self.submit_dask_tasks(block, dask_nodes, dask_ppw, use_shifter, dask_worker_plugin, dask_worker_per_gpu)
-        elif not TaskPool.dask:
-            self.services.warning("Requested use_dask but cannot because import dask failed")
+        elif not TaskPool.dask or not TaskPool.distributed:
+            raise RuntimeError("Requested use_dask but cannot because import dask or distributed failed")
elif not self.serial_pool:
self.services.warning("Requested use_dask but cannot because multiple processors requested")

Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
zip_safe=True,
install_requires=[
'urllib3',
-        'configobj'
+        'configobj',
+        'dask',
+        'distributed'
]
)
3 changes: 0 additions & 3 deletions tests/helloworld/test_helloworld.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import shutil
 import json
-import pytest
from ipsframework import Framework, TaskPool


Expand Down Expand Up @@ -211,8 +210,6 @@ def test_helloworld_task_pool(tmpdir, capfd):


def test_helloworld_task_pool_dask(tmpdir, capfd):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
assert TaskPool.dask is not None

data_dir = os.path.dirname(__file__)
Expand Down
18 changes: 0 additions & 18 deletions tests/new/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ def write_basic_config_and_platform_files(tmpdir, timeout='', logfile='', errfil


def test_dask(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, value=1)

framework = Framework(config_file_list=[str(config_file)],
Expand Down Expand Up @@ -128,8 +126,6 @@ def test_dask(tmpdir):
@pytest.mark.skipif(shutil.which('shifter') is not None,
reason="This tests only works if shifter doesn't exist")
def test_dask_shifter_fail(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, value=1, shifter=True)

framework = Framework(config_file_list=[str(config_file)],
Expand Down Expand Up @@ -164,8 +160,6 @@ def test_dask_shifter_fail(tmpdir):


def test_dask_fake_shifter(tmpdir, monkeypatch):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")

shifter = tmpdir.join("shifter")
shifter.write("#!/bin/bash\necho Running $@ in shifter >> shifter.log\n$@\n")
Expand Down Expand Up @@ -238,8 +232,6 @@ def test_dask_fake_shifter(tmpdir, monkeypatch):


def test_dask_timeout(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, timeout=1, value=100)

framework = Framework(config_file_list=[str(config_file)],
Expand Down Expand Up @@ -289,8 +281,6 @@ def test_dask_timeout(tmpdir):


def test_dask_nproc(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, nproc=2, value=1)

# Running with NPROC=2 should prevent dask from running and revert to normal task pool
Expand Down Expand Up @@ -325,8 +315,6 @@ def test_dask_nproc(tmpdir):


def test_dask_logfile(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")

exe = tmpdir.join("stdouterr_write.sh")
exe.write("#!/bin/bash\necho Running $1\n>&2 echo ERROR $1\n")
Expand Down Expand Up @@ -371,8 +359,6 @@ def test_dask_logfile(tmpdir):


def test_dask_logfile_errfile(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")

exe = tmpdir.join("stdouterr_write.sh")
exe.write("#!/bin/bash\necho Running $1\n>&2 echo ERROR $1\n")
Expand Down Expand Up @@ -473,8 +459,6 @@ def test_dask_shifter_on_cori(tmpdir):


def test_dask_with_1_gpu(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, gpus=1)

framework = Framework(config_file_list=[str(config_file)],
Expand Down Expand Up @@ -514,8 +498,6 @@ def test_dask_with_1_gpu(tmpdir):


def test_dask_with_2_gpus(tmpdir):
-    pytest.importorskip("dask")
-    pytest.importorskip("distributed")
platform_file, config_file = write_basic_config_and_platform_files(tmpdir, gpus=2)

framework = Framework(config_file_list=[str(config_file)],
Expand Down

0 comments on commit f864926

Please sign in to comment.