diff --git a/rslp/forest_loss_driver/__init__.py b/rslp/forest_loss_driver/__init__.py index 30e884d8..41967b3a 100644 --- a/rslp/forest_loss_driver/__init__.py +++ b/rslp/forest_loss_driver/__init__.py @@ -1,12 +1,11 @@ """Forest loss driver classification project.""" from .predict_pipeline import ( - PredictPipelineConfig, predict_pipeline, select_best_images_pipeline, ) workflows = { - "predict": (PredictPipelineConfig, predict_pipeline), - "select_best_images": (None, select_best_images_pipeline), + "predict": predict_pipeline, + "select_best_images": select_best_images_pipeline, } diff --git a/rslp/landsat_vessels/__init__.py b/rslp/landsat_vessels/__init__.py index eee14976..f1fa0212 100644 --- a/rslp/landsat_vessels/__init__.py +++ b/rslp/landsat_vessels/__init__.py @@ -3,5 +3,5 @@ from .predict_pipeline import predict_pipeline workflows = { - "predict": (None, predict_pipeline), + "predict": predict_pipeline, } diff --git a/rslp/main.py b/rslp/main.py index 9664c94f..da0cd0e8 100644 --- a/rslp/main.py +++ b/rslp/main.py @@ -21,7 +21,7 @@ def main() -> None: args = parser.parse_args(args=sys.argv[1:3]) module = importlib.import_module(f"rslp.{args.project}") - workflow_fn = module.workflows[args.workflow][1] + workflow_fn = module.workflows[args.workflow] jsonargparse.CLI(workflow_fn, args=sys.argv[3:]) diff --git a/rslp/maldives_ecosystem_mapping/__init__.py b/rslp/maldives_ecosystem_mapping/__init__.py index d8b0260c..013e4108 100644 --- a/rslp/maldives_ecosystem_mapping/__init__.py +++ b/rslp/maldives_ecosystem_mapping/__init__.py @@ -1,15 +1,10 @@ """Maldives ecosystem mapping project.""" -from rslp.config import BaseTrainPipelineConfig - -from .data_pipeline import DataPipelineConfig, data_pipeline +from .data_pipeline import data_pipeline from .predict_pipeline import maxar_predict_pipeline, sentinel2_predict_pipeline -from .train_pipeline import maxar_train_pipeline, sentinel2_train_pipeline workflows = { - "data": (DataPipelineConfig, data_pipeline), - "train_maxar": (BaseTrainPipelineConfig, maxar_train_pipeline), - "train_sentinel2": (BaseTrainPipelineConfig, sentinel2_train_pipeline), - "predict_maxar": (None, maxar_predict_pipeline), - "predict_sentinel2": (None, sentinel2_predict_pipeline), + "data": data_pipeline, + "predict_maxar": maxar_predict_pipeline, + "predict_sentinel2": sentinel2_predict_pipeline, } diff --git a/rslp/maldives_ecosystem_mapping/train_pipeline.py b/rslp/maldives_ecosystem_mapping/train_pipeline.py deleted file mode 100644 index 8040fa3e..00000000 --- a/rslp/maldives_ecosystem_mapping/train_pipeline.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Model training pipeline for Maldives ecosystem mapping project.""" - -from rslp.config import BaseTrainPipelineConfig -from rslp.launch_beaker import launch_job - - -def maxar_train_pipeline(config: BaseTrainPipelineConfig) -> None: - """Run the training pipeline. - - Args: - config: the model training config. - """ - launch_job("data/maldives_ecosystem_mapping/config.yaml", mode="fit") - - -def sentinel2_train_pipeline(config: BaseTrainPipelineConfig) -> None: - """Run the training pipeline. - - Args: - config: the model training config. - """ - launch_job("data/maldives_ecosystem_mapping/config_sentinel2.yaml", mode="fit") diff --git a/rslp/sentinel2_vessels/__init__.py b/rslp/sentinel2_vessels/__init__.py index 96e74130..1af966a0 100644 --- a/rslp/sentinel2_vessels/__init__.py +++ b/rslp/sentinel2_vessels/__init__.py @@ -3,5 +3,5 @@ from .predict_pipeline import predict_pipeline workflows = { - "predict": (None, predict_pipeline), + "predict": predict_pipeline, }