Skip to content

Commit

Permalink
Rename model version to a model (#13)
Browse files Browse the repository at this point in the history
* OSSK-342

* bump zenml version
  • Loading branch information
avishniakov authored Jan 19, 2024
1 parent 324a456 commit 6dbf10f
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,5 @@ jobs:
with:
stack-name: ${{ matrix.stack-name }}
python-version: ${{ matrix.python-version }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/update-quickstart-from-template' }}
ref-zenml: ${{ inputs.ref-zenml || 'feature/OSSK-342-rename-model-version-to-a-model' }}
ref-template: ${{ inputs.ref-template || github.ref }}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ dmypy.json

*.zen
.vscode
.local
2 changes: 1 addition & 1 deletion template/configs/inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: "breast_cancer_classifier"
version: "production"
license: Apache 2.0
Expand Down
2 changes: 1 addition & 1 deletion template/configs/training_rf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: breast_cancer_classifier
version: rf
license: Apache 2.0
Expand Down
2 changes: 1 addition & 1 deletion template/configs/training_sgd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ settings:
- pyarrow

# configuration of the Model Control Plane
model_version:
model:
name: breast_cancer_classifier
version: sgd
license: Apache 2.0
Expand Down
4 changes: 2 additions & 2 deletions template/pipelines/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ def inference(random_state: str, target: str):
target: Name of target column in dataset.
"""
# Get the production model artifact
model = get_pipeline_context().model_version.get_artifact("sklearn_classifier")
model = get_pipeline_context().model.get_artifact("sklearn_classifier")

# Get the preprocess pipeline artifact associated with this version
preprocess_pipeline = get_pipeline_context().model_version.get_artifact(
preprocess_pipeline = get_pipeline_context().model.get_artifact(
"preprocess_pipeline"
)

Expand Down
20 changes: 10 additions & 10 deletions template/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@
"\n",
"import random\n",
"import pandas as pd\n",
"from zenml import step, ExternalArtifact, pipeline, ModelVersion, get_step_context\n",
"from zenml import step, ExternalArtifact, pipeline, Model, get_step_context\n",
"from zenml.client import Client\n",
"from zenml.logger import get_logger\n",
"from uuid import UUID\n",
Expand Down Expand Up @@ -729,7 +729,7 @@
"all the models produced as you develop your experiments and use-cases. Luckily, ZenML offers a *Model Control Plane*,\n",
"which is a central register of all your ML models.\n",
"\n",
"You can easily create a ZenML `Model` and associate it with your pipelines using the `ModelVersion` object:"
"You can easily create a ZenML Model and associate it with your pipelines using the `Model` object:"
]
},
{
Expand All @@ -742,7 +742,7 @@
"pipeline_settings = {}\n",
"\n",
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model_version\"] = ModelVersion(\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" license=\"Apache 2.0\",\n",
" description=\"A breast cancer classifier\",\n",
Expand All @@ -758,7 +758,7 @@
"outputs": [],
"source": [
"# Let's train the SGD model and set the version name to \"sgd\"\n",
"pipeline_settings[\"model_version\"].version = \"sgd\"\n",
"pipeline_settings[\"model\"].version = \"sgd\"\n",
"\n",
"# the `with_options` method allows us to pass in pipeline settings\n",
"# and returns a configured pipeline\n",
Expand All @@ -780,7 +780,7 @@
"outputs": [],
"source": [
"# Let's train the RF model and set the version name to \"rf\"\n",
"pipeline_settings[\"model_version\"].version = \"rf\"\n",
"pipeline_settings[\"model\"].version = \"rf\"\n",
"\n",
"# the `with_options` method allows us to pass in pipeline settings\n",
"# and returns a configured pipeline\n",
Expand Down Expand Up @@ -939,11 +939,11 @@
"@step\n",
"def inference_predict(dataset_inf: pd.DataFrame) -> Annotated[pd.Series, \"predictions\"]:\n",
" \"\"\"Predictions step\"\"\"\n",
" # Get the model_version\n",
" model_version = get_step_context().model_version\n",
" # Get the model\n",
" model = get_step_context().model\n",
"\n",
" # run prediction from memory\n",
" predictor = model_version.load_artifact(\"sklearn_classifier\")\n",
" predictor = model.load_artifact(\"sklearn_classifier\")\n",
" predictions = predictor.predict(dataset_inf)\n",
"\n",
" predictions = pd.Series(predictions, name=\"predicted\")\n",
Expand Down Expand Up @@ -994,7 +994,7 @@
"id": "c7afe7be",
"metadata": {},
"source": [
"The way to load the right model is to pass in the `production` stage into the `ModelVersion` config this time.\n",
"The way to load the right model is to pass in the `production` stage into the `Model` config this time.\n",
"This will ensure to always load the production model, decoupled from all other pipelines:"
]
},
Expand All @@ -1008,7 +1008,7 @@
"pipeline_settings = {\"enable_cache\": False}\n",
"\n",
"# Lets add some metadata to the model to make it identifiable\n",
"pipeline_settings[\"model_version\"] = ModelVersion(\n",
"pipeline_settings[\"model\"] = Model(\n",
" name=\"breast_cancer_classifier\",\n",
" version=\"production\", # We can pass in the stage name here!\n",
" license=\"Apache 2.0\",\n",
Expand Down
2 changes: 1 addition & 1 deletion template/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(
with open(pipeline_args["config_path"], "r") as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
zenml_model = client.get_model_version(
config["model_version"]["name"], config["model_version"]["version"]
config["model"]["name"], config["model"]["version"]
)
preprocess_pipeline_artifact = zenml_model.get_artifact("preprocess_pipeline")

Expand Down
12 changes: 6 additions & 6 deletions template/steps/model_promoter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,26 @@ def model_promoter(accuracy: float, stage: str = "production") -> bool:
is_promoted = True

# Get the model in the current context
current_model_version = get_step_context().model_version
current_model = get_step_context().model

# Get the model that is in the production stage
client = Client()
try:
stage_model_version = client.get_model_version(
current_model_version.name, stage
stage_model = client.get_model_version(
current_model.name, stage
)
# We compare their metrics
prod_accuracy = (
stage_model_version.get_artifact("sklearn_classifier")
stage_model.get_artifact("sklearn_classifier")
.run_metadata["test_accuracy"]
.value
)
if float(accuracy) > float(prod_accuracy):
# If current model has better metrics, we promote it
is_promoted = True
current_model_version.set_stage(stage, force=True)
current_model.set_stage(stage, force=True)
except KeyError:
# If no such model exists, current one is promoted
is_promoted = True
current_model_version.set_stage(stage, force=True)
current_model.set_stage(stage, force=True)
return is_promoted

0 comments on commit 6dbf10f

Please sign in to comment.