diff --git a/zenml/cli/__init__.py b/zenml/cli/__init__.py index 057871323dc..fdc1e18fbe9 100644 --- a/zenml/cli/__init__.py +++ b/zenml/cli/__init__.py @@ -20,3 +20,5 @@ from .init import * from .version import * from .pipeline import * +from .datasource import * +from .step import * diff --git a/zenml/cli/config.py b/zenml/cli/config.py index 1211b497408..bfb1d938094 100644 --- a/zenml/cli/config.py +++ b/zenml/cli/config.py @@ -19,8 +19,9 @@ from zenml.cli.cli import cli from zenml.cli.cli import pass_config -from zenml.cli.utils import parse_unknown_options +from zenml.cli.utils import parse_unknown_options, error from zenml.core.repo.repo import Repository +from zenml.utils.print_utils import to_pretty_string @cli.group() @@ -54,6 +55,18 @@ def opt_out(config): click.echo('Opted out of analytics.') +@config.command("list") +def list_config(): + """Print the current ZenML config to the command line""" + try: + repo: Repository = Repository.get_instance() + except Exception as e: + error(e) + return + + click.echo(to_pretty_string(repo.zenml_config)) + + # Metadata Store @config.group() def metadata(): diff --git a/zenml/cli/datasource.py b/zenml/cli/datasource.py new file mode 100644 index 00000000000..35ef150d9f1 --- /dev/null +++ b/zenml/cli/datasource.py @@ -0,0 +1,47 @@ +# Copyright (c) maiot GmbH 2021. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import click +from tabulate import tabulate +from zenml.cli.cli import cli +from zenml.cli.utils import pretty_print, pass_repo +from zenml.core.repo.repo import Repository +from typing import Text + + +@cli.group() +def datasource(): + """Data source group""" + pass + + +@datasource.command("list") +@pass_repo +def list_datasources(repo: Repository): + datasources = repo.get_datasources() + + click.echo(tabulate([ds.to_config() for ds in datasources], + headers="keys")) + + +@datasource.command("get") +@click.argument('datasource_name') +@pass_repo +def get_datasource_by_name(repo: Repository, datasource_name: Text): + """ + Gets pipeline from current repository by matching a name identifier + against the data source name. + + """ + pretty_print(repo.get_datasource_by_name(datasource_name)) diff --git a/zenml/cli/pipeline.py b/zenml/cli/pipeline.py index ef010b5d488..573f2a82eaa 100644 --- a/zenml/cli/pipeline.py +++ b/zenml/cli/pipeline.py @@ -14,9 +14,14 @@ """CLI for pipelines.""" import click +from tabulate import tabulate +from typing import Text from zenml.cli.cli import cli +from zenml.cli.utils import error, pretty_print, pass_repo from zenml.core.repo.repo import Repository +from zenml.utils.yaml_utils import read_yaml +from zenml.core.pipelines.training_pipeline import TrainingPipeline @cli.group() @@ -26,8 +31,68 @@ def pipeline(): @pipeline.command('compare') -def set_metadata_store(): +@pass_repo +def compare_pipelines(repo: Repository): """Compares pipelines in repo""" click.echo('Comparing pipelines in repo: Starting app..') - repo: Repository = Repository.get_instance() repo.compare_pipelines() + + +@pipeline.command('list') +@pass_repo +def list_pipelines(repo: Repository): + """Lists pipelines in the current repository.""" + try: + pipelines = repo.get_pipelines() + + names = [p.name for p in pipelines] + types = [p.PIPELINE_TYPE for p in pipelines] + statuses = [p.get_status() for p in pipelines] + cache_enabled = [p.enable_cache for p in pipelines] + filenames = [p.file_name for p in pipelines] + + headers = ["name", "type", "cache enabled", "status", "file name"] + + click.echo(tabulate(zip(names, types, cache_enabled, + statuses, filenames), + headers=headers)) + except Exception as e: + error(e) + + +@pipeline.command('get') +@click.argument('pipeline_name') +@pass_repo +def get_pipeline_by_name(repo: Repository, pipeline_name: Text): + """ + Gets pipeline from current repository by matching a name against a + pipeline name in the repository. + """ + try: + p = repo.get_pipeline_by_name(pipeline_name) + except Exception as e: + error(e) + return + + pretty_print(p) + + +@pipeline.command('run') +@click.argument('path_to_config') +@pass_repo +def run_pipeline(path_to_config: Text): + """ + Runs pipeline specified by the given config YAML object. + + Args: + path_to_config: Path to config of the designated pipeline. + Has to be matching the YAML file name. + """ + # config has metadata store, backends and artifact store, + # so no need to specify them + try: + config = read_yaml(path_to_config) + p: TrainingPipeline = TrainingPipeline.from_config(config) + p.run() + except Exception as e: + error(e) diff --git a/zenml/cli/step.py b/zenml/cli/step.py new file mode 100644 index 00000000000..5dcd232b1b9 --- /dev/null +++ b/zenml/cli/step.py @@ -0,0 +1,39 @@ +# Copyright (c) maiot GmbH 2021. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import click +from tabulate import tabulate +from zenml.cli.cli import cli +from zenml.cli.utils import pass_repo +from zenml.core.repo.repo import Repository + + +@cli.group() +def step(): + """Steps group""" + pass + + +@step.command("list") +@pass_repo +def list_steps(repo: Repository): + step_versions = repo.get_step_versions() + name_version_data = [] + headers = ["step_name", "step_version"] + for name, version_set in step_versions.items(): + names = [name] * len(version_set) + versions = list(version_set) + name_version_data.extend(list(zip(names, versions))) + + click.echo(tabulate(name_version_data, headers=headers)) diff --git a/zenml/cli/utils.py b/zenml/cli/utils.py index d6510f3f3e4..c6225c67a83 100644 --- a/zenml/cli/utils.py +++ b/zenml/cli/utils.py @@ -17,9 +17,12 @@ from dateutil import tz from zenml.core.repo.global_config import GlobalConfig +from zenml.core.repo.repo import Repository pass_config = click.make_pass_decorator(GlobalConfig, ensure=True) +pass_repo = click.make_pass_decorator(Repository, ensure=True) + def title(text): """ @@ -82,6 +85,14 @@ def warning(text): click.echo(click.style(text, fg='yellow', bold=True)) +def pretty_print(obj): + """ + Args: + obj: + """ + click.echo(str(obj)) + + def format_date(dt, format='%Y-%m-%d %H:%M:%S'): """ Args: