Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast

import mlflow
from mlflow.entities import Experiment, Run
from mlflow.entities import Experiment, LifecycleStage, Run
from mlflow.exceptions import MlflowException
from mlflow.store.db.db_types import DATABASE_ENGINES

import zenml
Expand Down Expand Up @@ -195,6 +196,24 @@ def prepare_step_run(self, info: "StepRunInfo") -> None:
experiment_name=experiment_name, run_name=info.run_name
)

# Validate that the run exists before attempting to resume it
if run_id:
try:
run = mlflow.get_run(run_id)
if run.info.lifecycle_stage == LifecycleStage.DELETED:
logger.warning(
f"Run with id {run_id} is in 'deleted' state. "
"Creating a new run instead."
)
run_id = None
except MlflowException as e:
# The lookup failed (most likely the run no longer exists on the MLflow server, though any MlflowException lands here); fall back to a new run
logger.warning(
f"Run with id {run_id} not found in MLflow tracking server. "
f"Creating a new run instead. Error: {e}"
)
run_id = None

tags = settings.tags.copy()
tags.update(self._get_internal_tags())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
import os
from contextlib import ExitStack as does_not_raise
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4

import pytest
from mlflow.exceptions import MlflowException
from pydantic import ValidationError

from zenml.enums import StackComponentType
Expand Down Expand Up @@ -245,3 +247,69 @@ def test_mlflow_experiment_tracker_set_config(local_stack: Stack) -> None:
assert os.environ[DATABRICKS_PASSWORD] == "password"
assert os.environ[DATABRICKS_TOKEN] == "token1234"
assert os.environ[DATABRICKS_HOST] == "https://databricks.com"


@patch("mlflow.start_run")
@patch("mlflow.get_run")
@patch("mlflow.get_experiment_by_name")
@patch("mlflow.set_experiment")
def test_mlflow_experiment_tracker_handles_missing_run(
    mock_set_experiment: MagicMock,
    mock_get_experiment: MagicMock,
    mock_get_run: MagicMock,
    mock_start_run: MagicMock,
) -> None:
    """Checks that a stale run ID falls back to starting a fresh MLflow run.

    Regression test for issue #4207: `mlflow.get_run` is patched to raise an
    `MlflowException` for the run ID the tracker hands back, and the test
    asserts that `prepare_step_run` still completes and calls
    `mlflow.start_run` with `run_id=None`.

    Args:
        mock_set_experiment: Patched `mlflow.set_experiment`.
        mock_get_experiment: Patched `mlflow.get_experiment_by_name`.
        mock_get_run: Patched `mlflow.get_run`, configured to raise.
        mock_start_run: Patched `mlflow.start_run`, inspected afterwards.
    """
    # The experiment lookup succeeds, the run lookup does not.
    mock_get_experiment.return_value = MagicMock(
        experiment_id="test_experiment_id"
    )
    mock_get_run.side_effect = MlflowException("RESOURCE_DOES_NOT_EXIST")

    tracker = MLFlowExperimentTracker(
        name="test_tracker",
        id=uuid4(),
        config=MLFlowExperimentTrackerConfig(
            tracking_uri="file:///tmp/mlflow",
        ),
        flavor="mlflow",
        type=StackComponentType.EXPERIMENT_TRACKER,
        user=uuid4(),
        created=datetime.now(),
        updated=datetime.now(),
    )

    # Stand-in for the StepRunInfo passed to prepare_step_run.
    step_info = MagicMock(
        run_name="test_run",
        pipeline_step_name="test_step",
    )
    step_info.pipeline.name = "test_pipeline"

    fake_settings = MagicMock(
        experiment_name=None,
        tags={},
        nested=False,
    )

    # The tracker reports a run ID that the (mocked) server cannot find;
    # preparing the step must not propagate the lookup error.
    with patch.object(
        tracker, "get_run_id", return_value="stale_run_id"
    ), patch.object(tracker, "get_settings", return_value=fake_settings):
        tracker.prepare_step_run(step_info)

    # Exactly one run was started, and it is a new one (no run_id passed on).
    mock_start_run.assert_called_once()
    start_kwargs = mock_start_run.call_args.kwargs
    assert start_kwargs["run_id"] is None, (
        "Expected run_id to be None when run doesn't exist"
    )
    assert start_kwargs["run_name"] == "test_run"
Loading