Skip to content

Commit e4289dd

Browse files
committed
Encode additional parameters in Component.data and store in 1 table
1 parent a830a79 commit e4289dd

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

46 files changed

+572
-589
lines changed

.github/workflows/ci_code.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,7 @@ jobs:
103103
104104
- name: Unit Testing
105105
run: |
106+
sqlite3 test.db "create table t(f int); drop table t;"
106107
make unit_testing pytest_arguments="--cov=superduper --cov-report=xml" SUPERDUPER_CONFIG=test/configs/${{ matrix.config }}
107108
108109
- name: Usecase Testing

plugins/sql/superduper_sql/data_backend.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -359,7 +359,10 @@ class SQLDatabackend(IbisDataBackend):
359359

360360
def __init__(self, uri, plugin, flavour=None):
361361
super().__init__(uri, plugin, flavour)
362-
self._create_sqlalchemy_engine()
362+
if 'sqlite://./' in uri:
363+
self._create_sqlalchemy_engine(uri.replace('./', '//'))
364+
else:
365+
self._create_sqlalchemy_engine(uri)
363366
self.sm = sessionmaker(bind=self.alchemy_engine)
364367

365368
@property
@@ -374,6 +377,8 @@ def update(self, table, condition, key, value):
374377
with self.sm() as session:
375378
metadata = MetaData()
376379

380+
assert table in self.list_tables()
381+
377382
metadata.reflect(bind=session.bind)
378383
table = Table(table, metadata, autoload_with=session.bind)
379384

@@ -422,16 +427,16 @@ def delete(self, table, condition):
422427
except NoSuchTableError:
423428
raise exceptions.NotFound("Table", table)
424429

425-
def _create_sqlalchemy_engine(self):
430+
def _create_sqlalchemy_engine(self, uri):
426431
with self.connection_manager.get_connection() as conn:
427-
self.alchemy_engine = create_engine(self.uri, creator=lambda: conn.con)
432+
self.alchemy_engine = create_engine(uri, creator=lambda: conn.con)
428433
if not self._test_engine():
429434
logging.warn(
430435
"Unable to reuse the ibis connection "
431436
"to create the SQLAlchemy engine. "
432437
"Creating a new connection with the URI."
433438
)
434-
self.alchemy_engine = create_engine(self.uri)
439+
self.alchemy_engine = create_engine(uri)
435440

436441
def _test_engine(self):
437442
"""Test the engine."""

superduper/backends/base/data_backend.py

Lines changed: 11 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -882,9 +882,9 @@ def insert(self, table, documents):
882882
except exceptions.NotFound:
883883
pid = None
884884

885-
if ('uuid' == pid or not pid) and "uuid" in documents[0]:
885+
if table in {'Component', 'Deployment'}:
886886
for r in documents:
887-
self[table, r['identifier'], r['uuid']] = r
887+
self[table, r['component'], r['identifier'], r['uuid']] = r
888888
ids.append(r['uuid'])
889889
elif pid:
890890
pid = self.primary_id(table)
@@ -944,13 +944,7 @@ def do_test(r):
944944
return False
945945
return True
946946

947-
tables = self.get_many('Table', query.table, '*')
948-
if not tables:
949-
raise exceptions.NotFound("Table", query.table)
950-
951-
is_component = max(tables, key=lambda x: x['version'])['is_component']
952-
953-
if not is_component:
947+
if query.table not in {'Component', 'Deployment'}:
954948
pid = self.primary_id(query.table)
955949

956950
if pid in filter_kwargs:
@@ -966,31 +960,33 @@ def do_test(r):
966960
else:
967961

968962
if not filter_kwargs:
969-
keys = self.keys(query.table, '*', '*')
963+
keys = self.keys(query.table, '*', '*', '*')
970964
docs = [self[k] for k in keys]
971965
elif set(filter_kwargs.keys()) == {'uuid'}:
972-
keys = self.keys(query.table, '*', filter_kwargs['uuid']['value'])
966+
keys = self.keys(query.table, '*', '*', filter_kwargs['uuid']['value'])
973967
docs = [self[k] for k in keys]
974968
elif set(filter_kwargs.keys()) == {'identifier'}:
975969
assert filter_kwargs['identifier']['op'] == '=='
976-
977970
keys = self.keys(query.table, filter_kwargs['identifier']['value'], '*')
978971
docs = [self[k] for k in keys]
979-
elif set(filter_kwargs.keys()) == {'identifier', 'uuid'}:
972+
elif set(filter_kwargs.keys()) == {'component', 'identifier', 'uuid'}:
980973
assert filter_kwargs['identifier']['op'] == '=='
981974
assert filter_kwargs['uuid']['op'] == '=='
975+
assert filter_kwargs['component']['op'] == '=='
982976

983977
r = self[
984978
query.table,
979+
filter_kwargs['component']['value'],
985980
filter_kwargs['identifier']['value'],
986981
filter_kwargs['uuid']['value'],
987982
]
988983
if r is None:
989984
docs = []
990985
else:
991986
docs = [r]
992-
elif set(filter_kwargs.keys()) == {'identifier', 'version'}:
987+
elif set(filter_kwargs.keys()) == {'component', 'identifier', 'version'}:
993988
assert filter_kwargs['identifier']['op'] == '=='
989+
assert filter_kwargs['component']['op'] == '=='
994990
assert filter_kwargs['version']['op'] == '=='
995991

996992
keys = self.keys(query.table, filter_kwargs['identifier']['value'], '*')
@@ -999,8 +995,7 @@ def do_test(r):
999995
r for r in docs if r['version'] == filter_kwargs['version']['value']
1000996
]
1001997
else:
1002-
1003-
keys = self.keys(query.table, '*', '*')
998+
keys = self.keys(query.table, '*', '*', '*')
1004999
docs = [self[k] for k in keys]
10051000
docs = [r for r in docs if do_test(r)]
10061001

superduper/backends/local/vector_search.py

Lines changed: 23 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -42,38 +42,31 @@ def build_tool(self, component):
4242

4343
def initialize(self):
4444
"""Initialize the vector search."""
45-
components = []
46-
from superduper import VectorIndex
47-
48-
for cls in self.db.show('Table'):
49-
t = self.db.load('Table', identifier=cls)
50-
if t.is_component and t.cls is not None:
51-
if issubclass(t.cls, VectorIndex):
52-
components.append(t.identifier)
45+
t = self.db['Deployment']
46+
components = (
47+
t.filter(t['parent'] == 'VectorIndex')
48+
.select('component', 'uuid', 'identifier')
49+
.execute()
50+
)
51+
5352
for component in components:
5453
try:
55-
for identifier in self.db.show(component):
56-
try:
57-
vector_index = self.db.load(component, identifier=identifier)
58-
self.put_component(component, vector_index.uuid)
59-
vectors = vector_index.get_vectors()
60-
vectors = [VectorItem(**vector) for vector in vectors]
61-
self.get_tool(vector_index.uuid).add(vectors)
62-
63-
except FileNotFoundError:
64-
logging.error(
65-
f'Could not load vector index: {identifier} '
66-
'Is the artifact store correctly configured?'
67-
)
68-
continue
69-
except TypeError as e:
70-
import traceback
71-
72-
logging.error(
73-
f'Could not load vector index: {identifier} ' f'{e}'
74-
)
75-
logging.error(traceback.format_exc())
76-
continue
54+
self.put_component(component['component'], uuid=component['uuid'])
55+
56+
except FileNotFoundError:
57+
logging.error(
58+
f'Could not load vector index: {component["identifier"]} '
59+
'Is the artifact store correctly configured?'
60+
)
61+
continue
62+
except TypeError as e:
63+
import traceback
64+
65+
logging.error(
66+
f'Could not load vector index: {component["identifier"]} ' f'{e}'
67+
)
68+
logging.error(traceback.format_exc())
69+
continue
7770
except exceptions.NotFound:
7871
pass
7972

superduper/base/annotations.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def decorated(self, *, context: str = '', job: bool = False, **kwargs):
4949

5050
if job:
5151
return Job(
52-
component=self.__class__.__name__,
52+
parent_component=self.__class__.__name__,
5353
identifier=self.identifier,
5454
uuid=self.uuid,
5555
method=f.__name__,

superduper/base/apply.py

Lines changed: 36 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -251,8 +251,8 @@ def _apply(
251251
deprecated_context: str | None = None,
252252
):
253253

254-
if object.status == STATUS_UNINITIALIZED:
255-
object.status, object.details = pending_status()
254+
# if object.status == STATUS_UNINITIALIZED:
255+
# object.status, object.details = pending_status()
256256

257257
processed_components = processed_components or set()
258258
if context is None:
@@ -311,7 +311,6 @@ def wrapper(child):
311311
if current.hash == object.hash:
312312
apply_status = 'same'
313313
object.version = current.version
314-
object.status, object.details = running_status()
315314
elif current.uuid == object.uuid:
316315
apply_status = 'update'
317316
object.version = current.version
@@ -361,16 +360,23 @@ def wrapper(child):
361360
jobs=list(job_events.values()),
362361
context=context,
363362
)
364-
# TODO - add multiple services (potentially empty)
365-
for service in object.services:
366-
put_events[f'{object.huuid}/{service}'] = PutComponent(
367-
component=object.component,
368-
identifier=object.identifier,
369-
uuid=object.uuid,
370-
context=context,
371-
version=object.version,
372-
service=service,
373-
)
363+
filter = getattr(object, 'filter_deployment', None)
364+
if filter:
365+
filter = getattr(object, filter)
366+
367+
branch, parent = object.get_branch_and_parent()
368+
put_events[object.huuid] = PutComponent(
369+
component=object.component,
370+
identifier=object.identifier,
371+
uuid=object.uuid,
372+
context=context,
373+
version=object.version,
374+
services=object.services,
375+
branch=branch,
376+
parent=parent,
377+
tags={tag: getattr(object, tag) for tag in object.tags},
378+
filter=filter,
379+
)
374380

375381
elif apply_status == 'breaking':
376382
metadata_event = Create(
@@ -387,15 +393,23 @@ def wrapper(child):
387393
jobs=list(job_events.values()),
388394
context=context,
389395
)
390-
for service in object.services:
391-
put_events[f'{object.huuid}/{service}'] = PutComponent(
392-
component=object.component,
393-
identifier=object.identifier,
394-
uuid=object.uuid,
395-
context=context,
396-
version=object.version,
397-
service=service,
398-
)
396+
filter = getattr(object, 'filter_deployment', None)
397+
if filter:
398+
filter = getattr(object, filter)
399+
400+
branch, parent = object.get_branch_and_parent()
401+
put_events[object.huuid] = PutComponent(
402+
component=object.component,
403+
identifier=object.identifier,
404+
uuid=object.uuid,
405+
context=context,
406+
version=object.version,
407+
services=object.services,
408+
branch=branch,
409+
tags={tag: getattr(object, tag) for tag in object.tags},
410+
parent=parent,
411+
filter=filter,
412+
)
399413

400414
d = db['Deployment']
401415
assert deprecated_context is not None

superduper/base/base.py

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,26 @@
2020
class _UniqueRegistry:
2121
def __init__(self, d):
2222
self.d = d
23+
24+
def __contains__(self, item):
25+
if '.' in item:
26+
return item in self.d
27+
else:
28+
return item in [k.split('.')[-1] for k in self.d]
29+
30+
def drop(self):
31+
for k in list(self.d.keys()):
32+
del self.d[k]
33+
2334
def __getitem__(self, item):
24-
return self.d[item]
35+
if '.' in item:
36+
return self.d[item]
37+
else:
38+
try:
39+
key = next(k for k in self.d if k.split('.')[-1] == item)
40+
except StopIteration:
41+
raise KeyError(item)
42+
return self.d[key]
2543

2644
def __setitem__(self, key, value):
2745
matching = [x for x in self.d.keys() if x.split('.')[-1] == key.split('.')[-1]]
@@ -108,7 +126,7 @@ class Base(metaclass=BaseMeta):
108126
"""Base class for all superduper classes."""
109127

110128
verbosity: t.ClassVar[int] = 0
111-
set_post_init: t.ClassVar[t.Sequence[str]] = ()
129+
primary_id: t.ClassVar[str] = 'uuid'
112130

113131
def __init_subclass__(cls):
114132
full_path = f"{cls.__module__}.{cls.__name__}"
@@ -151,16 +169,14 @@ def class_schema(cls):
151169

152170
@lazy_classproperty
153171
def table(cls):
154-
from superduper import Component
155172
from superduper.components.table import Table
156-
from superduper.misc.importing import isreallyinstance
157173

174+
if cls.__name__ == 'Base':
175+
return None
158176
return Table(
159177
identifier=cls.__name__,
160178
fields=cls.class_schema.fields,
161-
path=cls.__module__ + '.' + cls.__name__,
162-
primary_id='uuid',
163-
is_component=isreallyinstance(cls, Component),
179+
primary_id=cls.primary_id,
164180
)
165181

166182
@staticmethod
@@ -298,10 +314,8 @@ def decode(cls, r, db: t.Optional['Datalayer'] = None):
298314
"""
299315
from superduper.base.document import Document
300316

301-
if '_path' in r:
302-
from superduper.misc.importing import import_object
303-
304-
cls = import_object(r['_path'])
317+
if 'component' in r:
318+
cls = REGISTRY[r['component']]
305319

306320
r = Document.decode(r, schema=cls.class_schema, db=db)
307321
return cls.from_dict(r, db=db)
@@ -334,6 +348,7 @@ def encode(
334348
del r[k]
335349
if r is None:
336350
r = self.dict(metadata=context.metadata)
351+
r['component'] = self.__class__.__name__
337352

338353
if not context.defaults:
339354
for k, v in list(r.items()):
@@ -443,7 +458,6 @@ def dict(self, metadata: bool = True) -> t.Dict[str, t.Any]:
443458
from superduper import Document
444459

445460
r = asdict(self)
446-
r['_path'] = self.__class__.__module__ + '.' + self.__class__.__name__
447461
if metadata:
448462
metadata = getattr(self, 'metadata', {})
449463
for k, v in metadata.items():

superduper/base/datalayer.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from superduper.backends.base.data_backend import BaseDataBackend
1212
from superduper.base import apply, exceptions
1313
from superduper.base.artifacts import ArtifactStore
14-
from superduper.base.base import Base
14+
from superduper.base.base import REGISTRY, Base
1515
from superduper.base.config import Config
1616
from superduper.base.datatype import BaseType, ComponentType
1717
from superduper.base.document import Document
@@ -95,8 +95,6 @@ def insert(self, items: t.List[Base]):
9595
"""
9696
table = self.pre_insert(items)
9797
data = [x.dict() for x in items]
98-
for r in data:
99-
del r['_path']
10098
return self[table.identifier].insert(data)
10199

102100
def replace(self, condition: t.Dict, item: Base):
@@ -309,7 +307,9 @@ def pre_insert(
309307
return self.metadata.create(type(items[0]))
310308

311309
def _post_query(self, table: str, ids: t.Sequence[str], type_: str):
312-
if table in metaclasses or self.metadata.is_component(table):
310+
if table in metaclasses or (
311+
table in REGISTRY and issubclass(REGISTRY[table], Component)
312+
):
313313
return
314314
if (
315315
not table.startswith(CFG.output_prefix)

0 commit comments

Comments
 (0)