diff --git a/kauldron/xm/_src/experiment.py b/kauldron/xm/_src/experiment.py index 2bccc4d8..9578b5cb 100644 --- a/kauldron/xm/_src/experiment.py +++ b/kauldron/xm/_src/experiment.py @@ -129,9 +129,11 @@ def __post_init__(self): new_sweep = new_sweep.replace_with_jobs_provider(self.jobs_provider) object.__setattr__(self, "sweep_info", new_sweep) - def launch(self) -> xm_abc.XManagerExperiment: + def launch( + self, existing_xp: xm_abc.XManagerExperiment | None = None + ) -> xm_abc.XManagerExperiment: """Launch the experiment.""" - with self.create_experiment() as xp: + with self._maybe_create_experiment(existing_xp) as xp: xp.context.add_config_file( file_content=epy.pretty_repr(self.resolved_jobs), description="Jobs", @@ -170,10 +172,20 @@ def launch(self) -> xm_abc.XManagerExperiment: workdir=dir_builder.xp_dir, executor=xm_abc.Borg(requirements=requirements), ) - # TODO(epot): Support Custom auxiliaries return xp + @contextlib.contextmanager + def _maybe_create_experiment( + self, existing_xp: xm_abc.XManagerExperiment | None + ) -> Iterator[xm_abc.XManagerExperiment]: + """Returns either the existing experiment or create a new one.""" + if existing_xp is not None: + yield existing_xp + else: + with self.create_experiment() as xp: + yield xp + @contextlib.contextmanager def create_experiment(self) -> Iterator[xm_abc.XManagerExperiment]: """Wrapper around `xm_abc.create_experiment`.""" diff --git a/kauldron/xm/_src/kauldron_utils.py b/kauldron/xm/_src/kauldron_utils.py index 4a5029c8..15f43ca7 100644 --- a/kauldron/xm/_src/kauldron_utils.py +++ b/kauldron/xm/_src/kauldron_utils.py @@ -91,6 +91,7 @@ def from_module( module: str | types.ModuleType, *, overrides: dict[str, Any] | None = None, + config_parameter: str | None = None, ) -> KauldronJobs: """Create a `KauldronJobs` from a config module.""" if isinstance(module, str): @@ -98,9 +99,14 @@ def from_module( elif not isinstance(module, types.ModuleType): raise TypeError(f"Expected module. Got: {type(module)}") + if config_parameter is None: + config = module.get_config() + else: + config = module.get_config(config_parameter) return cls( module=module, - config=module.get_config(), + config=config, + config_parameter=config_parameter, overrides=overrides or {}, )