class UDPJobFactory:
    """
    Batch job factory based on a parameterized process definition
    (e.g. a user-defined process (UDP) or a remote process definition),
    to be used together with :py:class:`MultiBackendJobManager`.
    """

    def __init__(
        self, process_id: str, *, namespace: Union[str, None] = None, parameter_defaults: Optional[dict] = None
    ):
        """
        :param process_id: id of the process to instantiate jobs from.
        :param namespace: process namespace: ``None`` for a user-defined process
            on the backend itself, or an ``http(s)://`` URL pointing to a
            "Remote Process Definition Extension" resource.
        :param parameter_defaults: fallback argument values, used when a
            parameter is not provided through the dataframe row.
        """
        self._process_id = process_id
        self._namespace = namespace
        self._parameter_defaults = parameter_defaults or {}
        # Lazily fetched (then reused) remote process definition.
        # Note: instance-level caching instead of `functools.lru_cache` on the
        # method, because the latter keeps the instance alive for the lifetime
        # of the cache (ruff rule B019).
        self._remote_process_definition: Optional[dict] = None

    def _get_process_definition(self, connection: Connection) -> dict:
        """Resolve the process definition: remote URL or UDP on the backend."""
        if isinstance(self._namespace, str) and re.match(r"https?://", self._namespace):
            # "namespace" is a URL: remote process definition.
            return self._get_remote_process_definition()
        elif self._namespace is None:
            # Default namespace: user-defined process on the backend itself.
            return connection.user_defined_process(self._process_id).describe()
        else:
            raise NotImplementedError(
                f"Unsupported process definition source udp_id={self._process_id!r} namespace={self._namespace!r}"
            )

    def _get_remote_process_definition(self) -> dict:
        """
        Get process definition based on "Remote Process Definition Extension" spec
        https://github.com/Open-EO/openeo-api/tree/draft/extensions/remote-process-definition

        The definition is fetched at most once and cached on the instance.

        :raises ValueError: when the resource does not contain a (valid)
            definition for the requested process id.
        """
        if self._remote_process_definition is None:
            assert isinstance(self._namespace, str) and re.match(r"https?://", self._namespace)
            resp = requests.get(url=self._namespace)
            resp.raise_for_status()
            data = resp.json()
            if isinstance(data, list):
                # Handle process listing: filter out right process
                processes = [p for p in data if p.get("id") == self._process_id]
                if len(processes) != 1:
                    raise ValueError(f"Process {self._process_id!r} not found at {self._namespace}")
                (data,) = processes

            # Check for required fields of a process definition
            if isinstance(data, dict) and "id" in data and "process_graph" in data:
                self._remote_process_definition = data
            else:
                raise ValueError(f"Invalid process definition at {self._namespace}")

        return self._remote_process_definition

    def start_job(self, row: pd.Series, connection: Connection, **_) -> BatchJob:
        """
        Implementation of the `start_job` callable interface for MultiBackendJobManager:
        Create and start a job based on given dataframe row

        :param row: The row in the pandas dataframe that stores the jobs state and other tracked data.
        :param connection: The connection to the backend.

        :return: The started job.

        :raises ValueError: when a required (non-optional) process parameter
            can not be found in the row nor in the constructor defaults.
        """

        process_definition = self._get_process_definition(connection=connection)
        parameters = process_definition.get("parameters", [])
        arguments = {}
        for parameter in parameters:
            name = parameter["name"]
            if name in row.index:
                # Highest priority: value from dataframe row
                value = row[name]
            elif name in self._parameter_defaults:
                # Fallback on default values from constructor
                value = self._parameter_defaults[name]
            else:
                if parameter.get("optional", False):
                    # Leave argument out: backend applies the parameter's own default.
                    continue
                raise ValueError(f"Missing required parameter {name!r} for process {self._process_id!r}")

            # TODO: validation or normalization based on the parameter's "schema"?
            # Some pandas/numpy data types need a bit of conversion for JSON encoding
            if isinstance(value, numpy.integer):
                value = int(value)
            elif isinstance(value, numpy.number):
                value = float(value)

            arguments[name] = value

        cube = connection.datacube_from_process(process_id=self._process_id, namespace=self._namespace, **arguments)

        # Default title/description are built from the actual arguments of this
        # job (not the parameter definitions), so they describe this concrete job.
        title = row.get("title", f"UDP {self._process_id!r} with {repr_truncate(arguments)}")
        description = row.get(
            "description", f"UDP {self._process_id!r} (namespace {self._namespace}) with {arguments}"
        )
        job = connection.create_job(cube, title=title, description=description)

        return job

    def __call__(self, *arg, **kwargs) -> BatchJob:
        """Syntactic sugar for calling `start_job` directly."""
        return self.start_job(*arg, **kwargs)
@pytest.fixture
def con120(requests_mock) -> openeo.Connection:
    """Connection to a mocked openEO 1.2.0 backend."""
    requests_mock.get(OPENEO_BACKEND, json=build_capabilities(api_version="1.2.0"))
    return openeo.Connection(OPENEO_BACKEND)


class TestUDPJobFactory:
    @pytest.fixture
    def dummy_backend(self, requests_mock, con120) -> DummyBackend:
        """Dummy backend wired up to the `con120` connection."""
        return DummyBackend(requests_mock=requests_mock, connection=con120)

    @pytest.fixture(autouse=True)
    def remote_process_definitions(self, requests_mock):
        """Mock the remote process definition documents used by these tests."""
        three_plus_five = {
            "id": "3plus5",
            "process_graph": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True},
        }
        increment = {
            "id": "increment",
            "parameters": [
                {"name": "data", "schema": {"type": "number"}},
                {"name": "increment", "schema": {"type": "number"}, "optional": True, "default": 1},
            ],
            "process_graph": {
                "process_id": "add",
                "arguments": {"x": {"from_parameter": "data"}, "y": {"from_parameter": "increment"}},
                "result": True,
            },
        }
        requests_mock.get("https://remote.test/3plus5.json", json=three_plus_five)
        requests_mock.get("https://remote.test/increment.json", json=increment)

    def _expected_batch_jobs(self, node_id: str, process_id: str, namespace: str, arguments: dict) -> dict:
        """Expected `DummyBackend.batch_jobs` state after creating a single job."""
        return {
            "job-000": {
                "job_id": "job-000",
                "pg": {
                    node_id: {
                        "process_id": process_id,
                        "namespace": namespace,
                        "arguments": arguments,
                        "result": True,
                    }
                },
                "status": "created",
            }
        }

    def test_minimal(self, con120, dummy_backend):
        """Bare minimum: just start a job, no parameters/arguments"""
        factory = UDPJobFactory(process_id="3plus5", namespace="https://remote.test/3plus5.json")

        job = factory.start_job(row=pd.Series({"foo": 123}), connection=con120)

        assert isinstance(job, BatchJob)
        assert dummy_backend.batch_jobs == self._expected_batch_jobs(
            node_id="3plus51",
            process_id="3plus5",
            namespace="https://remote.test/3plus5.json",
            arguments={},
        )

    def test_basic(self, con120, dummy_backend):
        """Basic parameterized UDP job generation"""
        factory = UDPJobFactory(process_id="increment", namespace="https://remote.test/increment.json")

        job = factory.start_job(row=pd.Series({"data": 123}), connection=con120)

        assert isinstance(job, BatchJob)
        assert dummy_backend.batch_jobs == self._expected_batch_jobs(
            node_id="increment1",
            process_id="increment",
            namespace="https://remote.test/increment.json",
            arguments={"data": 123},
        )

    @pytest.mark.parametrize(
        ["parameter_defaults", "row", "expected_arguments"],
        [
            (None, {"data": 123}, {"data": 123}),
            (None, {"data": 123, "increment": 5}, {"data": 123, "increment": 5}),
            ({"increment": 5}, {"data": 123}, {"data": 123, "increment": 5}),
            ({"increment": 5}, {"data": 123, "increment": 1000}, {"data": 123, "increment": 1000}),
        ],
    )
    def test_basic_parameterization(self, con120, dummy_backend, parameter_defaults, row, expected_arguments):
        """Basic parameterized UDP job generation"""
        factory = UDPJobFactory(
            process_id="increment",
            namespace="https://remote.test/increment.json",
            parameter_defaults=parameter_defaults,
        )

        job = factory.start_job(row=pd.Series(row), connection=con120)

        assert isinstance(job, BatchJob)
        assert dummy_backend.batch_jobs == self._expected_batch_jobs(
            node_id="increment1",
            process_id="increment",
            namespace="https://remote.test/increment.json",
            arguments=expected_arguments,
        )