diff --git a/kraken/kraken.py b/kraken/kraken.py index a41c41f5..ed01e9a8 100644 --- a/kraken/kraken.py +++ b/kraken/kraken.py @@ -694,8 +694,8 @@ def show(ctx, metadata_version, model_id): """ Retrieves model metadata from the repository. """ - from htrmopo import get_description from htrmopo.util import iso15924_to_name, iso639_3_to_name + from kraken.repo import get_description from kraken.lib.util import is_printable, make_printable def _render_creators(creators): @@ -716,15 +716,13 @@ def _render_metrics(metrics): metadata_version = None try: - desc = get_description(model_id, version=metadata_version) + desc = get_description(model_id, + version=metadata_version, + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) except ValueError as e: logger.error(e) ctx.exit(1) - if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: - logger.error('Record exists but is not a kraken-compatible model') - ctx.exit(1) - if desc.version == 'v0': chars = [] combining = [] @@ -777,19 +775,13 @@ def list_models(ctx): """ Lists models in the repository. 
""" - from htrmopo import get_listing - from collections import defaultdict + from kraken.repo import get_listing from kraken.lib.progress import KrakenProgressBar with KrakenProgressBar() as progress: download_task = progress.add_task('Retrieving model list', total=0, visible=True if not ctx.meta['verbose'] else False) - repository = get_listing(lambda total, advance: progress.update(download_task, total=total, advance=advance)) - # aggregate models under their concept DOI - concepts = defaultdict(list) - for item in repository.values(): - # both got the same DOI information - record = item['v0'] if item['v0'] else item['v1'] - concepts[record.concept_doi].append(record.doi) + repository = get_listing(callback=lambda total, advance: progress.update(download_task, total=total, advance=advance), + filter_fn=lambda record: getattr(record, 'software_name', None) == 'kraken' or 'kraken_pytorch' in record.keywords) table = Table(show_header=True) table.add_column('DOI', justify="left", no_wrap=True) @@ -797,13 +789,7 @@ def list_models(ctx): table.add_column('model type', justify="left", no_wrap=False) table.add_column('keywords', justify="left", no_wrap=False) - for k, v in concepts.items(): - records = [repository[x]['v1'] if 'v1' in repository[x] else repository[x]['v0'] for x in v] - records = filter(lambda record: getattr(record, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in record.keywords, records) - records = sorted(records, key=lambda x: x.publication_date, reverse=True) - if not len(records): - continue - + for k, records in repository.items(): t = Tree(k) [t.add(x.doi) for x in records] table.add_row(t, @@ -812,7 +798,6 @@ def list_models(ctx): Group(*[''] + ['; '.join(x.keywords) for x in records])) print(table) - ctx.exit(0) @cli.command('get') @@ -822,20 +807,29 @@ def get(ctx, model_id): """ Retrieves a model from the repository. 
""" - from kraken import repo + import glob + + from htrmopo import get_model, get_description + from kraken.lib.progress import KrakenDownloadProgressBar try: - os.makedirs(click.get_app_dir(APP_NAME)) - except OSError: - pass + desc = get_description(model_id) + except ValueError as e: + logger.error(e) + ctx.exit(1) + + print(desc) + if getattr(desc, 'software_name', None) != 'kraken' or 'kraken_pytorch' not in desc.keywords: + logger.error('Record exists but is not a kraken-compatible model') + ctx.exit(1) with KrakenDownloadProgressBar() as progress: download_task = progress.add_task('Processing', total=0, visible=True if not ctx.meta['verbose'] else False) - filename = repo.get_model(model_id, click.get_app_dir(APP_NAME), - lambda total, advance: progress.update(download_task, total=total, advance=advance)) - message(f'Model name: {filename}') - ctx.exit(0) + model_dir = get_model(model_id, + lambda total, advance: progress.update(download_task, total=total, advance=advance)) + model_candidates = list(filter(lambda x: x.suffix == '.mlmodel', model_dir.iter_dir())) + message(f'Model dir: {model_dir} (model files: {model_candidates})') if __name__ == '__main__': diff --git a/kraken/repo.py b/kraken/repo.py new file mode 100644 index 00000000..f283deb7 --- /dev/null +++ b/kraken/repo.py @@ -0,0 +1,87 @@ +# +# Copyright 2015 Benjamin Kiessling +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
"""
kraken.repo
~~~~~~~~~~~

Wrappers around the htrmopo reference implementation implementing
kraken-specific filtering.
"""
import logging
import warnings
from pathlib import Path
from collections import defaultdict
from typing import IO, Any, Dict, List, Union, cast, Optional, TypeVar, Iterable, Literal

from collections.abc import Callable

from htrmopo import get_description as mopo_get_description
from htrmopo import get_listing as mopo_get_listing
from htrmopo.record import v0RepositoryRecord, v1RepositoryRecord


_v0_or_v1_Record = TypeVar('_v0_or_v1_Record', v0RepositoryRecord, v1RepositoryRecord)


def get_description(model_id: str,
                    callback: Callable[..., Any] = lambda: None,
                    version: Optional[Literal['v0', 'v1']] = None,
                    filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> _v0_or_v1_Record:
    """
    Filters the output of htrmopo.get_description with a custom function.

    Args:
        model_id: model DOI
        callback: Progress callback, passed through unchanged to
                  htrmopo.get_description.
        version: Metadata schema version to request (`v0` or `v1`), passed
                 through unchanged to htrmopo.get_description. `None` leaves
                 the choice to htrmopo.
        filter_fn: Function called to filter the retrieved record. It receives
                   the record and returns a truthy value to accept it. `None`
                   disables filtering.

    Returns:
        The record returned by htrmopo.get_description.

    Raises:
        ValueError: if `filter_fn` rejects the retrieved record.
    """
    desc = mopo_get_description(model_id, callback, version)
    # The parameter is typed Optional, so None must mean "accept everything"
    # instead of failing with a TypeError on the call below.
    if filter_fn is not None and not filter_fn(desc):
        raise ValueError(f'Record {model_id} exists but is not a valid kraken record')
    return desc


def get_listing(callback: Callable[[int, int], Any] = lambda total, advance: None,
                from_date: Optional[str] = None,
                filter_fn: Optional[Callable[[_v0_or_v1_Record], bool]] = lambda x: True) -> Dict[str, List[_v0_or_v1_Record]]:
    """
    Returns a filtered representation of the model repository grouped by
    concept DOI.

    Args:
        callback: Progress callback, passed through unchanged to
                  htrmopo.get_listing.
        from_date: Passed through unchanged to htrmopo.get_listing.
        filter_fn: Function called for each record object. Records for which
                   it returns a falsy value are dropped. `None` disables
                   filtering.

    Returns:
        A plain dictionary mapping concept DOIs to a list containing one
        record object per deposit, sorted by publication date with the most
        recent first. The record of the highest available schema version
        that passes the filter is retained.
    """
    if filter_fn is None:
        def filter_fn(record):
            return True

    repository = mopo_get_listing(callback, from_date)
    # aggregate models under their concept DOI
    concepts = defaultdict(list)
    for item in repository.values():
        # Filter records here. Empty schema slots are skipped before the
        # filter runs so that a missing v1 entry can neither crash the filter
        # nor hide a valid v0 record.
        # NOTE(review): assumes htrmopo may store None for an absent schema
        # version -- confirm against htrmopo.get_listing.
        accepted = {schema: rec for schema, rec in item.items()
                    if rec is not None and filter_fn(rec)}
        # both got the same DOI information, prefer the newer schema
        record = accepted.get('v1', accepted.get('v0'))
        if record is not None:
            concepts[record.concept_doi].append(record)

    # Return a plain dict so lookups of unknown DOIs by callers raise KeyError
    # instead of silently inserting empty groups.
    return {doi: sorted(records, key=lambda rec: rec.publication_date, reverse=True)
            for doi, records in concepts.items()}