Skip to content

Commit

Permalink
Fixed model cleanup on cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Dec 10, 2024
1 parent cec41d5 commit bec7a1b
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 9 deletions.
10 changes: 10 additions & 0 deletions superduper/backends/base/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def initialize(self, with_compute: bool = False):
self.vector_search.initialize()
self.crontab.initialize()
self.cdc.initialize()

def drop_component(self, uuid: str):
"""Drop component and its services rom the cluster.
:param uuid: Component uuid.
"""
try:
del self.cache[uuid]
except KeyError:
pass
6 changes: 6 additions & 0 deletions superduper/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ def db(self, value: 'Datalayer'):
:param value: ``Datalayer`` instance.
"""
self._db = value

def drop_component(self, uuid: str):
"""Drop the component from compute.
:param uuid: Component uuid.
"""
4 changes: 2 additions & 2 deletions superduper/backends/local/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def put_bytes(

with open(path, 'wb') as f:
f.write(serialized)
os.chmod(path, 0o777)
os.chmod(path, 0x777)

def get_bytes(self, file_id: str) -> bytes:
"""
Expand Down Expand Up @@ -117,7 +117,7 @@ def put_file(self, file_path: str, file_id: str):
shutil.copytree(file_path, save_path)
else:
shutil.copy(file_path, save_path)
os.chmod(save_path, 0o777)
os.chmod(save_path, 0x777)
return file_id

def get_file(self, file_id: str) -> str:
Expand Down
6 changes: 2 additions & 4 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def remove(

for v in sorted(versions_in_use):
self.metadata.hide_component_version(type_id, identifier, v)

else:
logging.warn('aborting.')

Expand Down Expand Up @@ -664,10 +665,7 @@ def _remove_component_version(
type_id, identifier, version=version, allow_hidden=force
)
component.cleanup(self)
try:
del self.cluster.cache[component.uuid]
except KeyError:
pass
self.cluster.drop_component(component.uuid)

self._delete_artifacts(r['uuid'], info)
self.metadata.delete_component_version(type_id, identifier, version=version)
Expand Down
7 changes: 7 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,13 @@ def __post_init__(self, db, example):
if not self.identifier:
raise Exception('_Predictor identifier must be non-empty')

def cleanup(self, db: "Datalayer") -> None:
"""Clean up when the model is deleted.
:param db: Data layer instance to process.
"""
db.cluster.compute.drop_component(self.uuid)

@property
def inputs(self) -> Inputs:
"""Instance of `Inputs` to represent model params."""
Expand Down
4 changes: 2 additions & 2 deletions superduper/ext/llm/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class RetrievalPrompt(QueryModel):
prompt_introduction: str = PROMPT_INTRODUCTION
join: str = "\n---\n"

def __post_init__(self, db):
def __post_init__(self, db, example):
assert 'prompt' in self.select.variables
return super().__post_init__(db)
return super().__post_init__(db, example)

@property
def inputs(self):
Expand Down
3 changes: 2 additions & 1 deletion superduper/rest/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def _add_templates(self, db):

existing = db.show('template')
for t in self.templates:
if t in existing:

if t in existing or t is None:
logging.info(f'Found existing template: {t}')
continue
logging.info(f'Applying template: {t}')
Expand Down
1 change: 1 addition & 0 deletions superduper/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def ls():

def __getattr__(name: str):
import re
breakpoint()

if not re.match('.*[0-9]+\.[0-9]+\.[0-9]+.*', name):
assert name in TEMPLATES, f'{name} not in supported templates {TEMPLATES}'
Expand Down

0 comments on commit bec7a1b

Please sign in to comment.