Skip to content

Commit

Permalink
move to conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Nov 19, 2024
1 parent 2ed8a4d commit 902e4e0
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 25 deletions.
8 changes: 0 additions & 8 deletions tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@
from tests.common import EventCounterCallback


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)


def test_callbacks_map_to_events():
# callback methods must be 1:1 mapping with events
# exception for private methods
Expand Down
8 changes: 0 additions & 8 deletions tests/callbacks/test_loggers_across_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@
)


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)


@pytest.mark.parametrize('logger_cls', get_cbs_and_marks(loggers=True))
@pytest.mark.parametrize('callback_cls', get_cbs_and_marks(callbacks=True))
@pytest.mark.filterwarnings('ignore::UserWarning')
Expand Down
8 changes: 8 additions & 0 deletions tests/fixtures/autouse_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,11 @@ def remove_run_name_env_var():
os.environ['COMPOSER_RUN_NAME'] = composer_run_name
if run_name is not None:
os.environ['RUN_NAME'] = run_name


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)
8 changes: 0 additions & 8 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@
)


@pytest.fixture(autouse=True)
def setup_mlflow_tracking(monkeypatch, tmp_path):
mlflow = pytest.importorskip('mlflow')
tracking_uri = str(tmp_path / 'mlruns')
monkeypatch.setenv(mlflow.environment_variables.MLFLOW_TRACKING_URI.name, tracking_uri)
os.makedirs(tracking_uri, exist_ok=True)


def _get_latest_mlflow_run(experiment_name, tracking_uri=None):
pytest.importorskip('mlflow')
from mlflow import MlflowClient
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ def test_busy_wait_for_local_rank_zero(tmp_path):
end_time = time.time()
total_time = end_time - start_time
gathered_times = dist.all_gather_object(total_time)
assert os.listdir(gathered_tmp_path) == []
assert os.listdir(gathered_tmp_path) == ['mlruns']
assert len(gathered_times) == 2
assert abs(gathered_times[0] - gathered_times[1]) < 0.1

0 comments on commit 902e4e0

Please sign in to comment.