Skip to content

Commit

Permalink
Merge pull request #15 from maiot-io/zenml-cli
Browse files Browse the repository at this point in the history
Zenml cli
  • Loading branch information
htahir1 authored Jan 19, 2021
2 parents 7e3b792 + ef135a0 commit 8d6ae70
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 3 deletions.
2 changes: 2 additions & 0 deletions zenml/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
from .init import *
from .version import *
from .pipeline import *
from .datasource import *
from .step import *
15 changes: 14 additions & 1 deletion zenml/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@

from zenml.cli.cli import cli
from zenml.cli.cli import pass_config
from zenml.cli.utils import parse_unknown_options
from zenml.cli.utils import parse_unknown_options, error
from zenml.core.repo.repo import Repository
from zenml.utils.print_utils import to_pretty_string


@cli.group()
Expand Down Expand Up @@ -54,6 +55,18 @@ def opt_out(config):
click.echo('Opted out of analytics.')


@config.command("list")
def list_config():
"""Print the current ZenML config to the command line"""
try:
repo: Repository = Repository.get_instance()
except Exception as e:
error(e)
return

click.echo(to_pretty_string(repo.zenml_config))


# Metadata Store
@config.group()
def metadata():
Expand Down
47 changes: 47 additions & 0 deletions zenml/cli/datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) maiot GmbH 2021. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import click
from tabulate import tabulate
from zenml.cli.cli import cli
from zenml.cli.utils import pretty_print, pass_repo
from zenml.core.repo.repo import Repository
from typing import Text


@cli.group()
def datasource():
"""Data source group"""
pass


@datasource.command("list")
@pass_repo
def list_datasources(repo: Repository):
datasources = repo.get_datasources()

click.echo(tabulate([ds.to_config() for ds in datasources],
headers="keys"))


@datasource.command("get")
@click.argument('datasource_name')
@pass_repo
def get_datasource_by_name(repo: Repository, datasource_name: Text):
"""
Gets pipeline from current repository by matching a name identifier
against the data source name.
"""
pretty_print(repo.get_datasource_by_name(datasource_name))
69 changes: 67 additions & 2 deletions zenml/cli/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
"""CLI for pipelines."""

import click
from tabulate import tabulate
from typing import Text

from zenml.cli.cli import cli
from zenml.cli.utils import error, pretty_print, pass_repo
from zenml.core.repo.repo import Repository
from zenml.utils.yaml_utils import read_yaml
from zenml.core.pipelines.training_pipeline import TrainingPipeline


@cli.group()
Expand All @@ -26,8 +31,68 @@ def pipeline():


@pipeline.command('compare')
def set_metadata_store():
@pass_repo
def compare_pipelines(repo: Repository):
"""Compares pipelines in repo"""
click.echo('Comparing pipelines in repo: Starting app..')
repo: Repository = Repository.get_instance()
repo.compare_pipelines()


@pipeline.command('list')
@pass_repo
def list_pipelines(repo: Repository):
"""Lists pipelines in the current repository."""
try:
pipelines = repo.get_pipelines()

names = [p.name for p in pipelines]
types = [p.PIPELINE_TYPE for p in pipelines]
statuses = [p.get_status() for p in pipelines]
cache_enabled = [p.enable_cache for p in pipelines]
filenames = [p.file_name for p in pipelines]

headers = ["name", "type", "cache enabled", "status", "file name"]

click.echo(tabulate(zip(names, types, cache_enabled,
statuses, filenames),
headers=headers))
except Exception as e:
error(e)


@pipeline.command('get')
@click.argument('pipeline_name')
@pass_repo
def get_pipeline_by_name(repo: Repository, pipeline_name: Text):
"""
Gets pipeline from current repository by matching a name against a
pipeline name in the repository.
"""
try:
p = repo.get_pipeline_by_name(pipeline_name)
except Exception as e:
error(e)
return

pretty_print(p)


@pipeline.command('run')
@click.argument('path_to_config')
@pass_repo
def run_pipeline(path_to_config: Text):
"""
Runs pipeline specified by the given config YAML object.
Args:
path_to_config: Path to config of the designated pipeline.
Has to be matching the YAML file name.
"""
# config has metadata store, backends and artifact store,
# so no need to specify them
try:
config = read_yaml(path_to_config)
p: TrainingPipeline = TrainingPipeline.from_config(config)
p.run()
except Exception as e:
error(e)
39 changes: 39 additions & 0 deletions zenml/cli/step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) maiot GmbH 2021. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

import click
from tabulate import tabulate
from zenml.cli.cli import cli
from zenml.cli.utils import pass_repo
from zenml.core.repo.repo import Repository


@cli.group()
def step():
"""Steps group"""
pass


@step.command("list")
@pass_repo
def list_steps(repo: Repository):
step_versions = repo.get_step_versions()
name_version_data = []
headers = ["step_name", "step_version"]
for name, version_set in step_versions.items():
names = [name] * len(version_set)
versions = list(version_set)
name_version_data.extend(list(zip(names, versions)))

click.echo(tabulate(name_version_data, headers=headers))
11 changes: 11 additions & 0 deletions zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from dateutil import tz

from zenml.core.repo.global_config import GlobalConfig
from zenml.core.repo.repo import Repository

pass_config = click.make_pass_decorator(GlobalConfig, ensure=True)

pass_repo = click.make_pass_decorator(Repository, ensure=True)


def title(text):
"""
Expand Down Expand Up @@ -82,6 +85,14 @@ def warning(text):
click.echo(click.style(text, fg='yellow', bold=True))


def pretty_print(obj):
"""
Args:
obj:
"""
click.echo(str(obj))


def format_date(dt, format='%Y-%m-%d %H:%M:%S'):
"""
Args:
Expand Down

0 comments on commit 8d6ae70

Please sign in to comment.