Skip to content

Commit

Permalink
Fixed model cleanup on cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Dec 11, 2024
1 parent cec41d5 commit 8114119
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Deprecate vanilla `DataType`
- Remove `_Encodable` from project
- Add model component cleanup

#### New Features & Functionality

Expand Down
10 changes: 10 additions & 0 deletions superduper/backends/base/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def initialize(self, with_compute: bool = False):
self.vector_search.initialize()
self.crontab.initialize()
self.cdc.initialize()

def drop_component(self, uuid: str):
"""Drop component and its services rom the cluster.
:param uuid: Component uuid.
"""
try:
del self.cache[uuid]
except KeyError:
pass
6 changes: 6 additions & 0 deletions superduper/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ def db(self, value: 'Datalayer'):
:param value: ``Datalayer`` instance.
"""
self._db = value

def drop_component(self, uuid: str):
"""Drop the component from compute.
:param uuid: Component uuid.
"""
6 changes: 2 additions & 4 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def remove(

for v in sorted(versions_in_use):
self.metadata.hide_component_version(type_id, identifier, v)

else:
logging.warn('aborting.')

Expand Down Expand Up @@ -664,10 +665,7 @@ def _remove_component_version(
type_id, identifier, version=version, allow_hidden=force
)
component.cleanup(self)
try:
del self.cluster.cache[component.uuid]
except KeyError:
pass
self.cluster.drop_component(component.uuid)

self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)
Expand Down
7 changes: 7 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@ def __post_init__(self, db, example):
if not self.identifier:
raise Exception('_Predictor identifier must be non-empty')

def cleanup(self, db: "Datalayer") -> None:
"""Clean up when the model is deleted.
:param db: Data layer instance to process.
"""
db.cluster.compute.drop_component(self.uuid)

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down
4 changes: 2 additions & 2 deletions superduper/ext/llm/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class RetrievalPrompt(QueryModel):
prompt_introduction: str = PROMPT_INTRODUCTION
join: str = "\n---\n"

def __post_init__(self, db):
def __post_init__(self, db, example):
assert 'prompt' in self.select.variables
return super().__post_init__(db)
return super().__post_init__(db, example)

@property
def inputs(self):
Expand Down
4 changes: 3 additions & 1 deletion superduper/rest/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,16 @@ def _add_templates(self, db):

existing = db.show('template')
for t in self.templates:
if t in existing:
if t in existing or t is None:
logging.info(f'Found existing template: {t}')
continue
logging.info(f'Applying template: {t}')

import os

if os.path.exists(t):
from superduper import Template

t = Template.read(t)
else:
t = templates.get(t)
Expand Down
2 changes: 2 additions & 0 deletions superduper/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def ls():
def __getattr__(name: str):
import re

breakpoint()

if not re.match('.*[0-9]+\.[0-9]+\.[0-9]+.*', name):
assert name in TEMPLATES, f'{name} not in supported templates {TEMPLATES}'
file = TEMPLATES[name].split('/')[-1]
Expand Down

0 comments on commit 8114119

Please sign in to comment.