From 947f8ff6e479d0fb593253242f7353d8f1d4d6cd Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:13:56 +0200 Subject: [PATCH] linting --- template/quickstart.ipynb | 90 ++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 43 deletions(-) diff --git a/template/quickstart.ipynb b/template/quickstart.ipynb index 91737c9..afcf2ca 100644 --- a/template/quickstart.ipynb +++ b/template/quickstart.ipynb @@ -71,7 +71,7 @@ " # Pull required modules from this example\n", " !git clone -b main https://github.com/zenml-io/zenml\n", " !cp -r zenml/examples/quickstart/* .\n", - " !rm -rf zenml\n" + " !rm -rf zenml" ] }, { @@ -84,6 +84,7 @@ "!zenml integration install sklearn -y\n", "\n", "import IPython\n", + "\n", "IPython.Application.instance().kernel.do_shutdown(restart=True)" ] }, @@ -145,28 +146,22 @@ "outputs": [], "source": [ "# Do the imports at the top\n", - "from typing_extensions import Annotated\n", - "from sklearn.datasets import load_breast_cancer\n", - "\n", "import random\n", - "import pandas as pd\n", - "from zenml import step, pipeline, Model, get_step_context\n", - "from zenml.client import Client\n", - "from zenml.logger import get_logger\n", + "from typing import List, Optional\n", "from uuid import UUID\n", "\n", - "from typing import Optional, List\n", - "\n", - "from zenml import pipeline\n", - "\n", + "import pandas as pd\n", + "from sklearn.datasets import load_breast_cancer\n", "from steps import (\n", " data_loader,\n", " data_preprocessor,\n", " data_splitter,\n", + " inference_preprocessor,\n", " model_evaluator,\n", - " inference_preprocessor\n", ")\n", - "\n", + "from typing_extensions import Annotated\n", + "from zenml import Model, get_step_context, pipeline, step\n", + "from zenml.client import Client\n", "from zenml.logger import get_logger\n", "\n", "logger = get_logger(__name__)\n", @@ -205,7 +200,7 @@ "@step\n", "def data_loader_simplified(\n", " random_state: int, is_inference: bool = False, target: str = \"target\"\n", - ") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset \n", + ") -> Annotated[pd.DataFrame, \"dataset\"]: # We name the dataset\n", " \"\"\"Dataset reader step.\"\"\"\n", " dataset = load_breast_cancer(as_frame=True)\n", " inference_size = int(len(dataset.target) * 0.05)\n", @@ -218,7 +213,7 @@ " dataset.drop(inference_subset.index, inplace=True)\n", " dataset.reset_index(drop=True, inplace=True)\n", " logger.info(f\"Dataset with {len(dataset)} records loaded!\")\n", - " return dataset\n" + " return dataset" ] }, { @@ -291,7 +286,7 @@ " normalize: Optional[bool] = None,\n", " drop_columns: Optional[List[str]] = None,\n", " target: Optional[str] = \"target\",\n", - " random_state: int = 17\n", + " random_state: int = 17,\n", "):\n", " \"\"\"Feature engineering pipeline.\"\"\"\n", " # Link all the steps together by calling them and passing the output\n", @@ -402,7 +397,6 @@ "from zenml.environment import Environment\n", "from zenml.zen_stores.rest_zen_store import RestZenStore\n", "\n", - "\n", "if not isinstance(client.zen_store, RestZenStore):\n", " # Only spin up a local Dashboard in case you aren't already connected to a remote server\n", " if Environment.in_google_colab():\n", @@ -479,7 +473,9 @@ "outputs": [], "source": [ "# Get artifact version from our run\n", - "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\"dataset_trn\"] \n", + "dataset_trn_artifact_version_via_run = run.steps[\"data_preprocessor\"].outputs[\n", + " \"dataset_trn\"\n", + "]\n", "\n", "# Get latest version from client directly\n", "dataset_trn_artifact_version = client.get_artifact_version(\"dataset_trn\")\n", @@ -498,7 +494,9 @@ "source": [ "# Fetch the rest of the artifacts\n", "dataset_tst_artifact_version = client.get_artifact_version(\"dataset_tst\")\n", - "preprocessing_pipeline_artifact_version = client.get_artifact_version(\"preprocess_pipeline\")" + "preprocessing_pipeline_artifact_version = client.get_artifact_version(\n", + " \"preprocess_pipeline\"\n", + ")" ] }, { @@ -576,7 +574,9 @@ "def model_trainer(\n", " dataset_trn: pd.DataFrame,\n", " model_type: str = \"sgd\",\n", - ") -> Annotated[ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)]:\n", + ") -> Annotated[\n", + " ClassifierMixin, ArtifactConfig(name=\"sklearn_classifier\", is_model_artifact=True)\n", + "]:\n", " \"\"\"Configure and train a model on the training dataset.\"\"\"\n", " target = \"target\"\n", " if model_type == \"sgd\":\n", @@ -584,7 +584,7 @@ " elif model_type == \"rf\":\n", " model = RandomForestClassifier()\n", " else:\n", - " raise ValueError(f\"Unknown model type {model_type}\") \n", + " raise ValueError(f\"Unknown model type {model_type}\")\n", "\n", " logger.info(f\"Training model {model}...\")\n", "\n", @@ -592,7 +592,7 @@ " dataset_trn.drop(columns=[target]),\n", " dataset_trn[target],\n", " )\n", - " return model\n" + " return model" ] }, { @@ -630,14 +630,14 @@ " min_train_accuracy: float = 0.0,\n", " min_test_accuracy: float = 0.0,\n", "):\n", - " \"\"\"Model training pipeline.\"\"\" \n", + " \"\"\"Model training pipeline.\"\"\"\n", " if train_dataset_id is None or test_dataset_id is None:\n", - " # If we dont pass the IDs, this will run the feature engineering pipeline \n", + " # If we dont pass the IDs, this will run the feature engineering pipeline\n", " dataset_trn, dataset_tst = feature_engineering()\n", " else:\n", " # Load the datasets from an older pipeline\n", " dataset_trn = client.get_artifact_version(name_id_or_prefix=train_dataset_id)\n", - " dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id) \n", + " dataset_tst = client.get_artifact_version(name_id_or_prefix=test_dataset_id)\n", "\n", " trained_model = model_trainer(\n", " dataset_trn=dataset_trn,\n", @@ -676,7 +676,7 @@ "training(\n", " model_type=\"rf\",\n", " train_dataset_id=dataset_trn_artifact_version.id,\n", - " test_dataset_id=dataset_tst_artifact_version.id\n", + " test_dataset_id=dataset_tst_artifact_version.id,\n", ")\n", "\n", "rf_run = client.get_pipeline(\"training\").last_run" @@ -693,7 +693,7 @@ "sgd_run = training(\n", " model_type=\"sgd\",\n", " train_dataset_id=dataset_trn_artifact_version.id,\n", - " test_dataset_id=dataset_tst_artifact_version.id\n", + " test_dataset_id=dataset_tst_artifact_version.id,\n", ")\n", "\n", "sgd_run = client.get_pipeline(\"training\").last_run" @@ -717,7 +717,9 @@ "outputs": [], "source": [ "# The evaluator returns a float value with the accuracy\n", - "rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\"model_evaluator\"].output.load()" + "rf_run.steps[\"model_evaluator\"].output.load() > sgd_run.steps[\n", + " \"model_evaluator\"\n", + "].output.load()" ] }, { @@ -776,7 +778,7 @@ "training_configured(\n", " model_type=\"sgd\",\n", " train_dataset_id=dataset_trn_artifact_version.id,\n", - " test_dataset_id=dataset_tst_artifact_version.id\n", + " test_dataset_id=dataset_tst_artifact_version.id,\n", ")" ] }, @@ -798,7 +800,7 @@ "training_configured(\n", " model_type=\"rf\",\n", " train_dataset_id=dataset_trn_artifact_version.id,\n", - " test_dataset_id=dataset_tst_artifact_version.id\n", + " test_dataset_id=dataset_tst_artifact_version.id,\n", ")" ] }, @@ -848,7 +850,9 @@ "rf_zenml_model_version = client.get_model_version(\"breast_cancer_classifier\", \"rf\")\n", "\n", "# We can now load our classifier directly as well\n", - "random_forest_classifier = rf_zenml_model_version.get_artifact(\"sklearn_classifier\").load()\n", + "random_forest_classifier = rf_zenml_model_version.get_artifact(\n", + " \"sklearn_classifier\"\n", + ").load()\n", "\n", "random_forest_classifier" ] @@ -956,7 +960,7 @@ "\n", " predictions = pd.Series(predictions, name=\"predicted\")\n", "\n", - " return predictions\n" + " return predictions" ] }, { @@ -983,18 +987,18 @@ " random_state = 42\n", " target = \"target\"\n", "\n", - " df_inference = data_loader(\n", - " random_state=random_state, is_inference=True\n", - " )\n", + " df_inference = data_loader(random_state=random_state, is_inference=True)\n", " df_inference = inference_preprocessor(\n", " dataset_inf=df_inference,\n", " # We use the preprocess pipeline from the feature engineering pipeline\n", - " preprocess_pipeline=client.get_artifact_version(name_id_or_prefix=preprocess_pipeline_id),\n", + " preprocess_pipeline=client.get_artifact_version(\n", + " name_id_or_prefix=preprocess_pipeline_id\n", + " ),\n", " target=target,\n", " )\n", " inference_predict(\n", " dataset_inf=df_inference,\n", - " )\n" + " )" ] }, { @@ -1018,7 +1022,7 @@ "# Lets add some metadata to the model to make it identifiable\n", "pipeline_settings[\"model\"] = Model(\n", " name=\"breast_cancer_classifier\",\n", - " version=\"production\", # We can pass in the stage name here!\n", + " version=\"production\", # We can pass in the stage name here!\n", " license=\"Apache 2.0\",\n", " description=\"A breast cancer classifier\",\n", " tags=[\"breast_cancer\", \"classifier\"],\n", @@ -1039,9 +1043,7 @@ "# Let's run it again to make sure we have two versions\n", "# We need to pass in the ID of the preprocessing done in the feature engineering pipeline\n", "# in order to avoid training-serving skew\n", - "inference_configured(\n", - " preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id\n", - ")" + "inference_configured(preprocess_pipeline_id=preprocessing_pipeline_artifact_version.id)" ] }, { @@ -1061,7 +1063,9 @@ "outputs": [], "source": [ "# Fetch production model\n", - "production_model_version = client.get_model_version(\"breast_cancer_classifier\", \"production\")\n", + "production_model_version = client.get_model_version(\n", + " \"breast_cancer_classifier\", \"production\"\n", + ")\n", "\n", "# Get the predictions artifact\n", "production_model_version.get_artifact(\"predictions\").load()"