inference using URL #576

Status: Draft. Wants to merge 3 commits into base: master.
24 changes: 18 additions & 6 deletions docs/machine_learning.rst
@@ -89,30 +89,42 @@

Finally execute this whole training flow as a batch job::
    training_job = model.create_job()
    training_job.start_and_wait()


Inference
----------

When the batch job finishes successfully, the trained model can then be used
with the ``predict_random_forest`` process on the raster data cube
(or another cube with the same band structure) to classify all the pixels.

We inspect the result metadata of the training job to obtain the STAC Item URL of the trained model::


    results = training_job.get_results()
    links = results.get_metadata()['links']
    ml_model_metadata_url = [link for link in links if 'ml_model_metadata.json' in link['href']][0]['href']
Member:
Why get the model URL from the links? Isn't this an asset in the batch job results?

Also, this "ml_model_metadata.json" looks like a VITO-specific implementation detail. Can this be generalized? Or at least document why you are looking for that string?
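For reference, the distinction the comment is drawing can be sketched with plain Python on a minimal, made-up result document (the URLs, asset name, and JSON shape below are hypothetical, not what any particular backend returns): assets point at downloadable files such as the tar.gz model archive, while links carry relations like "canonical" or backend-specific metadata documents.

```python
# Hypothetical shape of batch-job result metadata (a STAC-Item-like document).
result_metadata = {
    "assets": {
        # Downloadable artifact: the serialized model archive.
        "randomforest.model.tar.gz": {
            "href": "https://backend.example/jobs/j-123/results/randomforest.model.tar.gz"
        },
    },
    "links": [
        # Relations to other documents, not direct file downloads.
        {"rel": "canonical", "href": "https://backend.example/jobs/j-123/results"},
        {"rel": "item", "href": "https://backend.example/jobs/j-123/results/ml_model_metadata.json"},
    ],
}

# Assets are keyed by name; links are a flat list of relation objects.
asset_hrefs = [asset["href"] for asset in result_metadata["assets"].values()]
link_hrefs = [link["href"] for link in result_metadata["links"]]
```

Under this (assumed) layout, only the links carry the JSON metadata document, which would explain why the snippet searches links rather than assets.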

Contributor Author:

This is indeed within the metadata, but it is different from the assets' link. The link in the assets points to a tar.gz archive; loading the model from that URL throws a JSON error, and the same happens with the canonical link.
ref: Open-EO/openeo-geopyspark-driver#562

Member:

If I understand save_ml_model correctly (https://github.com/Open-EO/openeo-processes/pull/441/files#diff-0aae18f348c05c0f26c5caee2dcb0ca4af218a55287c9cbf8a286fc35b8faae5), there should be an asset that is a STAC Item that is to be used with load_ml_model.

That 'ml_model_metadata.json' reference in the current docs is highly implementation-specific and should be avoided.

Member:

I guess we are missing some specification to allow proper extraction of the model's URL, so I started the discussion at https://github.com/Open-EO/openeo-processes/pull/441/files/b162040dea04acd12fad73166274bbff370ed29b#r1650820880

Member:

We need a better solution for

    [link for link in links if 'ml_model_metadata.json' in link['href']][0]['href']

This construct is quite brittle and implementation-specific.
Also note that https://github.com/Open-EO/openeo-processes/pull/441/files/b162040dea04acd12fad73166274bbff370ed29b#r1650872242 explicitly says that the ML Model extension is outdated and going to be replaced, so it is very likely this is going to break.
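Until the specification settles, one could at least make the lookup fail loudly instead of raising an opaque IndexError. A minimal sketch in plain Python, operating on made-up link data (the URLs and the helper name are hypothetical, and the 'ml_model_metadata.json' marker remains backend-specific):

```python
# Hypothetical link list as found in job result metadata.
metadata = {
    "links": [
        {"rel": "self", "href": "https://backend.example/jobs/j-123/results"},
        {"rel": "item", "href": "https://backend.example/jobs/j-123/results/ml_model_metadata.json"},
    ]
}


def find_model_metadata_url(links, marker="ml_model_metadata.json"):
    """Return the href of the first link containing `marker`,
    raising a descriptive error instead of an IndexError when absent."""
    try:
        return next(link["href"] for link in links if marker in link["href"])
    except StopIteration:
        raise LookupError("no link matching %r in job result metadata" % marker)


url = find_model_metadata_url(metadata["links"])
```

This does not remove the implementation-specific string matching, it only makes the failure mode explicit when a backend stops emitting that link.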

Member:

This is apparently the next-gen ML extension

I think we'd better make sure this snippet is future-proof.

    print(ml_model_metadata_url)


Next, load the model from the URL::

    model = connection.load_ml_model(id=ml_model_metadata_url)
Member:

just do:

Suggested change:

-    model = connection.load_ml_model(id=ml_model_metadata_url)
+    model = connection.load_ml_model(ml_model_metadata_url)

I think the `id` argument is not a very good name and we might change it in the future, so just avoid using it to keep these docs more robust.

Member (@soxofaan, Aug 20, 2024):
Technically, the openEO ``predict_random_forest`` process has to be used as a reducer function
inside a ``reduce_dimension`` call, but the openEO Python client library makes it
a bit easier by providing a :py:meth:`~openeo.rest.datacube.DataCube.predict_random_forest` method
directly on the :py:class:`~openeo.rest.datacube.DataCube` class, so that you can just do::

     predicted = cube.predict_random_forest(
-        model=training_job.job_id,
+        model=model,
         dimension="bands"
     )

     predicted.download("predicted.GTiff")


-We specified the model here by batch job id (string),
+We specified the model here by URL corresponding to the
Member:

This is not correct: you specify the model with `model`, which is an `MlModel` instance, not a URL.

STAC Item that implements the ml-model extension,
but it can also be specified in other ways:
as :py:class:`~openeo.rest.job.BatchJob` instance,
as URL to the corresponding STAC Item that implements the ``ml-model`` extension,
as job_id of training job (string),
or as :py:class:`~openeo.rest.mlmodel.MlModel` instance (e.g. loaded through
:py:meth:`~openeo.rest.connection.Connection.load_ml_model`).