[GENERAL SUPPORT]: scheduler + batch question #3247

VMLC-PV opened this issue Jan 17, 2025 · 6 comments

VMLC-PV commented Jan 17, 2025

Question

I am trying to write a scheduler that runs simulations in batches while still using AxClient, and I cannot make it work properly.
I modified the Scheduler tutorial to run the jobs in a multiprocessing pool and write each result to a file in a tmp folder, from which I can later fetch the results.
I replaced my simulation with the Branin function so that other people can test it (see code below).
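
One detail of this file-based handoff: if the poller checks for a result file while a worker is still writing it, it can find a file that exists but is not yet complete. A minimal sketch of one way to avoid that, assuming the same one-JSON-file-per-job layout as in the code below (the helper names `write_result_atomically` and `read_result_if_ready` are hypothetical and not part of Ax):

```python
import json
import os
import tempfile
from typing import Dict, Optional


def write_result_atomically(tmp_dir: str, job_id: str, result: Dict[str, float]) -> str:
    """Write `<job_id>.json` so that readers never observe a partial file."""
    os.makedirs(tmp_dir, exist_ok=True)
    final_path = os.path.join(tmp_dir, f"{job_id}.json")
    # Write to a temporary file in the same directory first ...
    fd, partial_path = tempfile.mkstemp(dir=tmp_dir, suffix=".partial")
    with os.fdopen(fd, "w") as fp:
        json.dump(result, fp)
    # ... then rename it into place. On POSIX the rename is atomic when both
    # paths are on the same filesystem.
    os.replace(partial_path, final_path)
    return final_path


def read_result_if_ready(tmp_dir: str, job_id: str) -> Optional[Dict[str, float]]:
    """Return the parsed result, or None if the job has not finished yet."""
    path = os.path.join(tmp_dir, f"{job_id}.json")
    if not os.path.exists(path):
        return None
    with open(path) as fp:
        return json.load(fp)
```

With this layout, a status check only needs to test whether `<job_id>.json` exists, because the file appears under that name only once it is fully written.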

Basically when I set the init_seconds_between_polls in
options=SchedulerOptions(run_trials_in_batches=True,init_seconds_between_polls=0.1,trial_type=TrialType.BATCH_TRIAL,batch_size=4),
to a small value (i.e. shorter than the run time of most processes) it fails because of the following:
Scheduler: MetricFetchE INFO: Because branin is an objective, marking trial 19 as TrialStatus.FAILED.

which I guess comes from #L2025, and I don't understand why this happens.
If I set init_seconds_between_polls=4 so that it polls only after all jobs are done, things seem to work. I don't understand what I am missing.
In my real-life case I don't know how long I need to wait before polling, so I expected that setting init_seconds_between_polls to a small value would simply check often whether all jobs in the batch are done and then proceed, but that is not what happens.

Please provide any relevant code snippet if applicable.

# Standard library
import json
import os
import random
import time
import uuid
from collections import defaultdict
from typing import Any, Dict, Iterable, NamedTuple, Set, Union

# Third-party
import numpy as np
import pandas as pd
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from torch.multiprocessing import Pool, set_start_method

# Ax (explicit names instead of `from ax import *`)
from ax import BatchTrial, Models, Objective, OptimizationConfig
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.data import Data
from ax.core.metric import Metric, MetricFetchE, MetricFetchResult
from ax.core.runner import Runner
from ax.core.trial import Trial
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.transforms.log import Log
from ax.modelbridge.transforms.remove_fixed import RemoveFixed
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.unit_x import UnitX
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.scheduler import Scheduler, SchedulerOptions, TrialType
from ax.utils.common.result import Err, Ok
from ax.utils.measurement.synthetic_functions import branin

try: # needed for multiprocessing when using pytorch
    set_start_method('spawn')
except RuntimeError:
    pass

class MockJob(NamedTuple):
    """Dummy class to represent a job scheduled on `MockJobQueue`."""

    id: int
    parameters: Dict[str, Union[str, float, int, bool]]

    def run(self, job_id, parameter, tmp_dir = None):

        x1 = parameter['x1']
        x2 = parameter['x2']
        branin_data = branin(x1, x2)
        time.sleep(random.uniform(0.1,3))
        res_dic = {'branin':branin_data}
        print('Job ID:',job_id, 'Parameters:',parameter, 'Results:',res_dic)
        # save the results in tmp folder with the job_id in json format
        if tmp_dir is not None:
            if not os.path.exists(tmp_dir):
                os.makedirs(tmp_dir)
            with open(os.path.join(tmp_dir,str(job_id)+'.json'), 'w') as fp:
                json.dump(res_dic, fp)
            
        

class MockJobQueueClient:
        """Dummy class to represent a job queue where the Ax `Scheduler` will
        deploy trial evaluation runs during optimization.
        """

        jobs: Dict[str, MockJob] = {}

        def __init__(self, pool = None, tmp_dir = None ):
            self.pool = pool
            self.tmp_dir = tmp_dir

        def schedule_job_with_parameters(
            self, parameters: Dict[str, Union[str, float, int, bool]]
        ) -> str:
            """Schedules an evaluation job with given parameters and returns job ID."""
            # Code to actually schedule the job and produce an ID would go here;
            job_id = str(uuid.uuid4())
            # Register the job, then submit its run to the worker pool.
            job = MockJob(job_id, parameters)
            self.jobs[job_id] = job
            self.pool.apply_async(job.run, args=(job_id, parameters, self.tmp_dir))

            return job_id

        def get_job_status(self, job_id: str) -> TrialStatus:
            """Get status of the job by a given ID. For simplicity of the example,
            return an Ax `TrialStatus`.
            """
            job = self.jobs[job_id]
            # check if job_id.json exists in the tmp directory
            if os.path.exists(os.path.join(self.tmp_dir,str(job_id)+'.json')):
                #load the results
                with open(os.path.join(self.tmp_dir,str(job_id)+'.json'), 'r') as fp:
                    res_dic = json.load(fp)

                # check is nan in res_dic
                for key in res_dic.keys():
                    if np.isnan(res_dic[key]):
                        return TrialStatus.FAILED
                    
                return TrialStatus.COMPLETED
            else:
                return TrialStatus.RUNNING

        def get_outcome_value_for_completed_job(self, job_id: str) -> Dict[str, float]:
            """Get evaluation results for a given completed job."""
            job = self.jobs[job_id]
            # In a real external system, this would retrieve real relevant outcomes and
            # not a synthetic function value.
            # check if job_id.json exists in the tmp directory
            if os.path.exists(os.path.join(self.tmp_dir,str(job_id)+'.json')):
                #load the results
                with open(os.path.join(self.tmp_dir,str(job_id)+'.json'), 'r') as fp:
                    res_dic = json.load(fp)
                # delete file
                os.remove(os.path.join(self.tmp_dir,str(job_id)+'.json'))
                # print('WE ARE DELETING THE FILE')
                return res_dic
            else:
                raise ValueError('The job is not completed yet')



def get_mock_job_queue_client(MOCK_JOB_QUEUE_CLIENT) -> MockJobQueueClient:
        """Obtain the singleton job queue instance."""
        return MOCK_JOB_QUEUE_CLIENT


class MockJobRunner(Runner):  # Deploys trials to external system.

    def __init__(self, pool = None, tmp_dir = None):
        self.pool = pool
        self.tmp_dir = tmp_dir
        self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)

    def _get_mock_job_queue_client(self) -> MockJobQueueClient:
        """Obtain the singleton job queue instance."""
        return self.MOCK_JOB_QUEUE_CLIENT
    
    def run(self, trial: BaseTrial) -> Dict[str, Any]:
        """Deploys a trial based on custom runner subclass implementation.

        Args:
            trial: The trial to deploy.

        Returns:
            Dict of run metadata from the deployment process.
        """
        if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
            raise ValueError("This runner only handles `Trial`.")

        mock_job_queue = self._get_mock_job_queue_client()

        run_metadata = []
        if isinstance(trial, BatchTrial):
            for arm in trial.arms:
                job_id = mock_job_queue.schedule_job_with_parameters(
                    parameters=arm.parameters
                )
                # This run metadata will be attached to trial as `trial.run_metadata`
                # by the base `Scheduler`.
                arm.run_metadata = {"job_id": job_id}
        else:
            job_id = mock_job_queue.schedule_job_with_parameters(
                parameters=trial.arm.parameters
            )

        # This run metadata will be attached to trial as `trial.run_metadata`
        # by the base `Scheduler`.
        return {"job_id": job_id}

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Checks the status of any non-terminal trials and returns their
        indices as a mapping from TrialStatus to a list of indices. Required
        for runners used with Ax ``Scheduler``.

        NOTE: Does not need to handle waiting between polling calls while trials
        are running; this function should just perform a single poll.

        Args:
            trials: Trials to poll.

        Returns:
            A dictionary mapping TrialStatus to a list of trial indices that have
            the respective status at the time of the polling. This does not need to
            include trials that at the time of polling already have a terminal
            (ABANDONED, FAILED, COMPLETED) status (but it may).
        """
        status_dict = defaultdict(set)
        for trial in trials:
            mock_job_queue = self._get_mock_job_queue_client()
            status = mock_job_queue.get_job_status(
                job_id=trial.run_metadata.get("job_id")
            )
            status_dict[status].add(trial.index)

        return status_dict
    
class BraninForMockJobMetric(Metric):  # Pulls data for trial from external system.
    def __init__(self, name = None, pool = None, tmp_dir = None, **kwargs):
        self.pool = pool
        self.tmp_dir = tmp_dir
        self.MOCK_JOB_QUEUE_CLIENT = MockJobQueueClient(pool = self.pool, tmp_dir = self.tmp_dir)
        super().__init__(name=name, **kwargs)

    def _get_mock_job_queue_client(self) -> MockJobQueueClient:
        """Obtain the singleton job queue instance."""
        return self.MOCK_JOB_QUEUE_CLIENT

    def fetch_trial_data(self, trial: BaseTrial) -> MetricFetchResult:
        """Obtains data for a given trial by fetching it from the mock job queue."""
        if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
            raise ValueError("This metric only handles `Trial`.")

        try:
            mock_job_queue = self._get_mock_job_queue_client()

            # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
            if isinstance(trial, BatchTrial):
                lst_df_dict = []
                for arm in trial.arms:
                    # branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                    #     job_id=trial.run_metadata.get("job_id")
                    # )
                    # arm.run_metadata.get("job_id")
                    branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                        job_id=arm.run_metadata.get("job_id")
                    )
                    name_ = list(branin_data.keys())[0]
                    df_dict = {
                        "trial_index": trial.index,
                        "metric_name": self.name,
                        "arm_name": arm.name,
                        "mean": branin_data.get(self.name),
                        # Can be set to 0.0 if function is known to be noiseless
                        # or to an actual value when SEM is known. Setting SEM to
                        # `None` results in Ax assuming unknown noise and inferring
                        # noise level from data.
                        "sem": None,
                    }
                    lst_df_dict.append(df_dict)
                return Ok(value=Data(df=pd.DataFrame.from_records(lst_df_dict)))
            else:
                # Single-arm `Trial`: the job ID is stored in the trial-level
                # run metadata returned by `MockJobRunner.run`.
                branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                    job_id=trial.run_metadata.get("job_id")
                )
                df_dict = {
                    "trial_index": trial.index,
                    "metric_name": self.name,
                    "arm_name": trial.arm.name,
                    "mean": branin_data.get(self.name),
                    # Can be set to 0.0 if function is known to be noiseless
                    # or to an actual value when SEM is known. Setting SEM to
                    # `None` results in Ax assuming unknown noise and inferring
                    # noise level from data.
                    "sem": None,
                }
                return Ok(value=Data(df=pd.DataFrame.from_records([df_dict])))
        except Exception as e:
            return Err(
                MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
            )


def create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list):
        """ Create a generation strategy for the optimization process using the models and the number of batches and batch sizes. See ax documentation for more details: https://ax.dev/tutorials/generation_strategy.html

        Returns
        -------
        GenerationStrategy
            The generation strategy for the optimization process

        Raises
        ------
        ValueError
            If the model is not a string or a Models enum
        """        

        steps = []
        for i, model in enumerate(models):
            if type(model) == str:
                model = Models[model]
            elif isinstance(model, Models):
                model = model
            else:
                raise ValueError('Model must be a string or a Models enum')
            steps.append(GenerationStep(
                model=model,
                num_trials=n_batches[i]*batch_size[i],
                max_parallelism=min(max_parallelism,batch_size[i]),
                model_kwargs= model_kwargs_list[i],
                model_gen_kwargs= model_gen_kwargs_list[i],
            ))

        gs = GenerationStrategy(steps=steps, )

        return gs

from ax.service.utils.report_utils import exp_to_df
def main():
    print('----We are Starting the Branin Test----')
    print(branin(0.1,0.1))
    tmp_dir = os.path.join(os.getcwd(),'.tmp_dir')
    print('tmp_dir:',tmp_dir)
    models = ['SOBOL','BOTORCH_MODULAR']
    n_batches = [1,4]
    batch_size = [1,4]
    max_parallelism = 4
    model_gen_kwargs_list =[{},{}]
    model_kwargs_list = [{},{'torch_device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),'torch_dtype': torch.double,'botorch_acqf_class':qLogNoisyExpectedImprovement,'transforms':[RemoveFixed, Log,UnitX, StandardizeY],},] 

    parameter_space = [
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5, 10],
            "value_type": "float",
            "log_scale": False,
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0, 15],
            "value_type": "float",
            "log_scale": False,
        },
    ]
    enforce_sequential_optimization = False
    gs = create_generation_strategy(models, n_batches, batch_size, max_parallelism, model_kwargs_list, model_gen_kwargs_list)

    ax_client = AxClient(generation_strategy=gs, enforce_sequential_optimization=enforce_sequential_optimization)

    objectives_ = {'branin':ObjectiveProperties(minimize=True)}

    ax_client.create_experiment(
            name='test_branin',
            parameters=parameter_space,
            # objectives=objectives_,
            
        )
    q = Pool(4)
    obj = Objective(metric=BraninForMockJobMetric(name='branin', pool = q, tmp_dir = tmp_dir), minimize=True)

    ax_client.experiment.optimization_config=OptimizationConfig(objective=obj)

    # create runner
    runner = MockJobRunner(pool = q, tmp_dir = tmp_dir)
    ax_client.experiment.runner = runner

    n = 0
    total_trials = sum(np.asarray(n_batches)*np.asarray(batch_size))
    n_step_points = np.cumsum(np.asarray(n_batches)*np.asarray(batch_size))
    scheduler = Scheduler(
        experiment=ax_client.experiment,
        generation_strategy=ax_client.generation_strategy,
        options=SchedulerOptions(run_trials_in_batches=True,init_seconds_between_polls=0.1,trial_type=TrialType.BATCH_TRIAL,batch_size=4),
        )
    
    while n < total_trials:
        # check the current batch size
        
        curr_batch_size = batch_size[np.argmax(n_step_points>n)]
        n += curr_batch_size
        if n > total_trials:
            curr_batch_size = curr_batch_size - (n-total_trials)

        scheduler.run_n_trials(max_trials=4)
    
    q.close()
    q.join()

    df = exp_to_df(ax_client.experiment)
    print(df)

if __name__ == '__main__':
    main()
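
In the repro above, `MockJobRunner.run` returns only the last arm's `job_id` as trial-level run metadata, and `poll_trial_status` checks only that one job, while `fetch_trial_data` reads a result file for every arm. A batch can therefore be reported as COMPLETED while some of its arms are still running, which may be one source of the fetch failures. A minimal sketch of a batch-level status check, assuming all job IDs of a batch are kept together (the `Status` enum is a stand-in for Ax's `TrialStatus`, and `batch_status` is a hypothetical helper):

```python
import json
import math
import os
from enum import Enum
from typing import Iterable


class Status(Enum):
    """Stand-in for ax.core.base_trial.TrialStatus (only the values used here)."""

    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"


def batch_status(tmp_dir: str, job_ids: Iterable[str]) -> Status:
    """Report COMPLETED only once every job in the batch has written its result."""
    results = []
    for job_id in job_ids:
        path = os.path.join(tmp_dir, f"{job_id}.json")
        if not os.path.exists(path):
            # At least one arm is still running, so the whole batch is running.
            return Status.RUNNING
        with open(path) as fp:
            results.append(json.load(fp))
    # All results are present; a NaN in any of them fails the batch.
    for res in results:
        if any(math.isnan(value) for value in res.values()):
            return Status.FAILED
    return Status.COMPLETED
```

In the runner this would mean returning the full list of job IDs from `run`, for example under a hypothetical `"job_ids"` key, and passing that list to the check inside `poll_trial_status`.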

VMLC-PV added the question (Further information is requested) label on Jan 17, 2025
Cesar-Cardoso self-assigned this on Jan 17, 2025

@Cesar-Cardoso
Contributor

Hello there! I was unable to successfully run your repro, it just hangs waiting for trial completion.

Based on your description of the issue however, it seems that the scheduler is expecting data that doesn't exist yet for your running trial. The message you're seeing along with the trial being failed comes from

Ax/ax/service/scheduler.py

Lines 2043 to 2060 in 3c68bd3

# If the fetch failure was for a metric in the optimization config (an
# objective or constraint) the trial as failed
optimization_config = self.experiment.optimization_config
if (
    optimization_config is not None
    and metric_name in optimization_config.metrics.keys()
):
    status = self._mark_err_trial_status(
        trial=self.experiment.trials[trial_index],
        metric_name=metric_name,
        metric_fetch_e=metric_fetch_e,
    )
    self.logger.warning(
        f"MetricFetchE INFO: Because {metric_name} is an objective, "
        f"marking trial {trial_index} as {status}."
    )
    self._num_trials_bad_due_to_err += 1
    continue

If you could provide a smaller repro for the issue, I can take a deeper look. Otherwise I'd try one of the following:

  1. Override is_available_while_running() for your metric. If it's True and your trial is RUNNING then the scheduler should just try fetching the metric again later.
  2. Use TrialStatus.STAGED instead of TrialStatus.RUNNING in get_job_status(). The scheduler shouldn't attempt data fetching for staged trials, so you shouldn't run into this issue. Once the trial is completed the data fetching should succeed as you see when setting a large value of init_seconds_between_polls.
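
The two suggestions can be sketched as follows. The `Metric` and `TrialStatus` classes below are small stand-ins so that the snippet runs without Ax installed; in the repro the real `ax.core.metric.Metric` and `ax.core.base_trial.TrialStatus` would be used instead, and the `result_path` argument is a hypothetical simplification of the repro's `job_id` lookup:

```python
import os
from enum import Enum, auto


class Metric:
    """Stand-in for ax.core.metric.Metric; only the hook discussed here."""

    @classmethod
    def is_available_while_running(cls) -> bool:
        return False


class TrialStatus(Enum):
    """Stand-in for ax.core.base_trial.TrialStatus (subset of values)."""

    STAGED = auto()
    RUNNING = auto()
    COMPLETED = auto()


class BraninForMockJobMetric(Metric):
    # Suggestion 1: declare that this metric may be fetched while the trial
    # is still running.
    @classmethod
    def is_available_while_running(cls) -> bool:
        return True


def get_job_status(result_path: str) -> TrialStatus:
    # Suggestion 2: report STAGED (instead of RUNNING) until the result
    # file has been written.
    if os.path.exists(result_path):
        return TrialStatus.COMPLETED
    return TrialStatus.STAGED
```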

@VMLC-PV
Author

VMLC-PV commented Jan 26, 2025

Hey, thanks for the suggestions.

I tried overriding is_available_while_running to return True, and also returning TrialStatus.STAGED instead of TrialStatus.RUNNING, but neither seems to work. It still looks like fetch_trial_data is trying to fetch non-completed trials.

I could get it to run by keeping is_available_while_running set to False and adding a sleep statement in fetch_trial_data, like so:

def fetch_trial_data(self, trial: BaseTrial) -> MetricFetchResult:
        """Obtains data for a given trial by fetching it from the mock job queue."""
        if not isinstance(trial, Trial) and not isinstance(trial, BatchTrial):
            raise ValueError("This metric only handles `Trial`.")

        try:
            mock_job_queue = self._get_mock_job_queue_client()

            # Here we leverage the "job_id" metadata created by `MockJobRunner.run`.
            if isinstance(trial, BatchTrial):
                lst_df_dict = []
                for arm in trial.arms:
                    job_id = arm.run_metadata.get("job_id")
                    while not os.path.exists(os.path.join(self.tmp_dir,str(job_id)+'.json')):
                        time.sleep(.1)
                    # branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                    #     job_id=trial.run_metadata.get("job_id")
                    # )
                    # arm.run_metadata.get("job_id")

                    branin_data = mock_job_queue.get_outcome_value_for_completed_job(
                        job_id=arm.run_metadata.get("job_id")
                    )
...

But again I don't get why this is necessary.
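
Note that the `while not os.path.exists(...)` loop above blocks forever if a job crashes before writing its file. A minimal sketch of the same wait with an upper bound (the helper name `wait_for_file` and its default values are assumptions chosen for illustration):

```python
import os
import time


def wait_for_file(path: str, timeout: float = 30.0, poll_interval: float = 0.1) -> bool:
    """Poll until `path` exists; return False if `timeout` seconds pass first."""
    deadline = time.monotonic() + timeout
    while not os.path.exists(path):
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval)
    return True
```

Inside `fetch_trial_data`, a `False` return could then be turned into an `Err(MetricFetchE(...))` instead of hanging the scheduler.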

As for providing a smaller repro, I am not sure what you mean by that.

@Cesar-Cardoso
Contributor

I see. We recently introduced the concept of recoverable errors for metrics in #3262, for cases like this one where we don't want to fail a trial just because a metric failed to fetch once. It just requires adding the exception type you're seeing in the Scheduler to metric.recoverable_exceptions.
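
As a rough illustration of the idea (not Ax's actual implementation): a metric declares which exception types are recoverable, and a failed fetch with one of those types leads to a retry on a later poll instead of a failed trial. The `Metric` stand-in, the shape of `recoverable_exceptions` as a set of exception types, and the `is_recoverable` helper are all assumptions made for this sketch. `ValueError` is used because that is what the repro's `get_outcome_value_for_completed_job` raises for an unfinished job.

```python
from typing import Set, Type


class Metric:
    """Stand-in base class; assumed to hold a set of recoverable exception types."""

    recoverable_exceptions: Set[Type[Exception]] = set()


class BraninForMockJobMetric(Metric):
    # "Result file not there yet" surfaces as ValueError in the repro.
    recoverable_exceptions = {ValueError}


def is_recoverable(metric_cls: Type[Metric], exc: Exception) -> bool:
    """True if a fetch failure caused by `exc` should be retried later."""
    return isinstance(exc, tuple(metric_cls.recoverable_exceptions))
```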

This will be available in the next Ax release.

@VMLC-PV
Author

VMLC-PV commented Jan 28, 2025

Thanks for the help.
Any idea when the next release is planned for?

@Balandat
Contributor

We'll do a maintenance release soon - likely within the next week or so but no promises :)

@VMLC-PV
Author

VMLC-PV commented Jan 29, 2025

Great! I will retest all this after the release and close the issue if everything works.
Thanks for the help!
