Commit 17651e2

Formalising dlt schema sync
1 parent f33f65f

File tree: 21 files changed, +526 / -230 lines

dlt/common/destination/client.py

Lines changed: 132 additions & 29 deletions
@@ -33,11 +33,14 @@
 from dlt.common.metrics import LoadJobMetrics
 from dlt.common.normalizers.naming import NamingConvention

-from dlt.common.schema import Schema, TSchemaTables
+from dlt.common.schema import Schema, TSchemaTables, TSchemaDrop
 from dlt.common.schema.typing import (
+    C_DLT_ID,
     C_DLT_LOAD_ID,
     TLoaderReplaceStrategy,
     TTableFormat,
+    TTableSchemaColumns,
+    TPartialTableSchema,
 )
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.destination.exceptions import (
@@ -539,36 +542,120 @@ def update_stored_schema(
         )
         return expected_update

-    def update_stored_schema_destructively(
+    def update_dlt_schema(
         self,
-    ) -> None:
-        """
-        Compare the schema we think we should have (`self.schema`)
-        with what actually exists in the destination, and drop any
-        columns that disappeared.
-        """
-        for table in self.schema.data_tables():
-            table_name = table["name"]
-
-            actual_columns = self._get_actual_columns(table_name)
-            schema_columns = self.schema.get_table_columns(table_name)
-            dropped_columns = set(schema_columns.keys()) - set(actual_columns)
-            if dropped_columns:
-                for dropped_col in dropped_columns:
-                    if schema_columns[dropped_col].get("increment"):
-                        logger.warning(
-                            "An incremental field is being removed from schema."
-                            "You should unset the"
-                            " incremental with `incremental=dlt.sources.incremental.EMPTY`"
-                        )
-                self.schema.drop_columns(table_name, list(dropped_columns))
-
-    def _get_actual_columns(self, table_name: str) -> List[str]:  # noqa: B027, optional override
-        """
-        Return a list of column names that currently exist in the
-        destination for `table_name`.
+        table_names: Optional[Iterable[str]] = None,
+        dry_run: bool = False,
+    ) -> Optional[TSchemaDrop]:
+        """Synchronizes the dlt schema with the destination storage.
+
+        Compares the schema we think we should have (`self.schema`) with what actually exists
+        in the destination, and drops any tables and/or columns that disappeared.
+
+        Args:
+            table_names (Iterable[str], optional): Check only the listed tables. Defaults to None, which checks all data tables.
+            dry_run (bool): When True, only report what would be dropped without modifying the schema. Defaults to False.
+
+        Returns:
+            Optional[TSchemaDrop]: The drops that were (or, with `dry_run`, would be) applied to the schema.
         """
-        pass
+        from dlt.destinations.sql_client import WithSqlClient
+
+        if not (isinstance(self, WithTableReflection) and isinstance(self, WithSqlClient)):
+            raise NotImplementedError
+
+        def _diff_between_actual_and_dlt_schema(
+            table_name: str, actual_col_names: set[str], disregard_dlt_columns: bool = True
+        ) -> TPartialTableSchema:
+            """Returns a partial table schema containing columns that exist in the dlt schema
+            but are missing from the actual table. Skips dlt internal columns by default.
+            """
+            col_schemas = self.schema.get_table_columns(table_name)
+
+            # Map escaped -> original names (actual_col_names are escaped)
+            escaped_to_original = {
+                self.sql_client.escape_column_name(col, quote=False): col
+                for col in col_schemas.keys()
+            }
+            dropped_col_names = set(escaped_to_original.keys()) - actual_col_names
+
+            if not dropped_col_names:
+                return {}
+
+            partial_table: TPartialTableSchema = {"name": table_name, "columns": {}}
+
+            for esc_name in dropped_col_names:
+                orig_name = escaped_to_original[esc_name]
+
+                # Athena doesn't have dlt columns in actual columns. Don't drop them anyway.
+                if disregard_dlt_columns and orig_name in [C_DLT_ID, C_DLT_LOAD_ID]:
+                    continue
+
+                col_schema = col_schemas[orig_name]
+                if col_schema.get("increment"):
+                    # We can warn within the for loop,
+                    # since there's only one incremental field per table
+                    logger.warning(
+                        f"An incremental field {orig_name} is being removed from the schema."
+                        " You should unset the incremental with"
+                        " `incremental=dlt.sources.incremental.EMPTY`"
+                    )
+                partial_table["columns"][orig_name] = col_schema
+
+            return partial_table if partial_table["columns"] else {}
+
+        tables = table_names if table_names else self.schema.data_table_names()
+
+        table_drops: TSchemaDrop = {}  # entire tables to drop
+        column_drops: TSchemaDrop = {}  # parts of tables to drop, as partial tables
+
+        # 1. Detect what needs to be dropped
+        for table_name in tables:
+            _, actual_col_schemas = list(self.get_storage_tables([table_name]))[0]
+
+            # no actual column schemas ->
+            # table doesn't exist ->
+            # we take the entire table schema as a schema drop
+            if not actual_col_schemas:
+                table = self.schema.get_table(table_name)
+                table_drops[table_name] = table
+                continue
+
+            # actual column schemas present ->
+            # we compare actual schemas with dlt ones ->
+            # we take the difference as a partial table
+            else:
+                partial_table = _diff_between_actual_and_dlt_schema(
+                    table_name,
+                    set(actual_col_schemas.keys()),
+                )
+                if partial_table:
+                    column_drops[table_name] = partial_table
+
+        # 2. For entire table drops, make sure no orphaned child tables remain
+        for table_name in table_drops.copy():
+            child_tables = self.schema.get_child_tables(table_name)
+            orphaned_table_names: List[str] = []
+            for child_table in child_tables:
+                if child_table["name"] not in table_drops:
+                    orphaned_table_names.append(child_table["name"])
+            if orphaned_table_names:
+                table_drops.pop(table_name)
+                logger.warning(
+                    f"Removing table '{table_name}' from the dlt schema would leave orphan"
+                    f" table(s): {', '.join(repr(t) for t in orphaned_table_names)}. Drop these"
+                    " child tables in the destination and sync the dlt schema again."
+                )
+
+        # 3. If it's not a dry run, we actually drop from the dlt schema
+        if not dry_run:
+            for table_name in table_drops:
+                self.schema.tables.pop(table_name)
+            for table_name, partial_table in column_drops.items():
+                col_names = list(partial_table["columns"])
+                self.schema.drop_columns(table_name, col_names)
+
+        return {**table_drops, **column_drops}

     def prepare_load_table(self, table_name: str) -> PreparedTableSchema:
         """Prepares a table schema to be loaded by filling missing hints and doing other modifications required by given destination.
@@ -639,6 +726,22 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         pass


+class WithTableReflection(ABC):
+    @abstractmethod
+    def get_storage_tables(
+        self, table_names: Iterable[str]
+    ) -> Iterable[Tuple[str, TTableSchemaColumns]]:
+        """Uses INFORMATION_SCHEMA to retrieve table and column information for the tables in `table_names`.
+        Table names should be normalized according to the naming convention and will be further converted to the desired
+        casing in order to (in most cases) create a case-insensitive name suitable for searching the information schema.
+
+        The column names are returned as they appear in the information schema. To match them with the columns of an
+        existing table, use the `schema.get_new_table_columns` method and pass the correct casing. Most of the casing
+        functions are irreversible, so it is not possible to convert identifiers from the INFORMATION_SCHEMA back into
+        a case-sensitive dlt schema.
+        """
+        pass
+
+
 class WithStagingDataset(ABC):
     """Adds capability to use staging dataset and request it from the loader"""

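For context, a minimal sketch of how the new `update_dlt_schema` API might be exercised once a destination client is obtained from a pipeline. This usage is an assumption, not part of the commit; the pipeline and dataset names are hypothetical.

import dlt

# any destination whose client mixes in WithTableReflection should qualify
pipeline = dlt.pipeline("my_pipeline", destination="duckdb", dataset_name="my_data")

with pipeline.destination_client() as client:
    # dry_run=True only reports what would be dropped; self.schema is untouched
    proposed = client.update_dlt_schema(dry_run=True)
    for table_name, partial_table in (proposed or {}).items():
        print(table_name, list(partial_table.get("columns", {})))
    # a second call without dry_run applies the drops to the dlt schema
    client.update_dlt_schema()
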
dlt/common/schema/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,13 +8,15 @@
     TColumnHint,
     TColumnSchema,
     TColumnSchemaBase,
+    TSchemaDrop,
 )
 from dlt.common.schema.typing import COLUMN_HINTS
 from dlt.common.schema.schema import Schema, DEFAULT_SCHEMA_CONTRACT_MODE
 from dlt.common.schema.exceptions import DataValidationError
 from dlt.common.schema.utils import verify_schema_hash

 __all__ = [
+    "TSchemaDrop",
     "TSchemaUpdate",
     "TSchemaTables",
     "TTableSchema",

dlt/common/schema/schema.py

Lines changed: 17 additions & 4 deletions
@@ -463,11 +463,24 @@ def drop_tables(
             self.data_item_normalizer.remove_table(table_name)
         return result

-    def drop_columns(self, table_name: str, column_names: Sequence[str]) -> List[TColumnSchema]:
-        """Drops columns from the table schema and returns the dropped columns"""
+    def drop_columns(self, table_name: str, column_names: Sequence[str]) -> TPartialTableSchema:
+        """Drops columns from the table schema and returns a partial table schema containing the dropped columns"""
+        table: TPartialTableSchema = {"name": table_name}
+        dropped_col_schemas: TTableSchemaColumns = {}
+
+        for col in column_names:
+            col_schema = self._schema_tables[table["name"]]["columns"].pop(col)
+            dropped_col_schemas[col] = col_schema
+
+        table["columns"] = dropped_col_schemas
+        return table
+
+    def get_child_tables(self, table_name: str) -> List[TTableSchema]:
+        """Returns tables that have `table_name` as their parent"""
         result = []
-        for col_name in column_names:
-            result.append(self._schema_tables[table_name]["columns"].pop(col_name))
+        for table in self.data_tables():
+            if table.get("parent", None) == table_name:
+                result.append(table)
         return result

     def filter_row_with_hint(

dlt/common/schema/typing.py

Lines changed: 1 addition & 0 deletions
@@ -315,6 +315,7 @@ class TPartialTableSchema(TTableSchema):

 TSchemaTables = Dict[str, TTableSchema]
 TSchemaUpdate = Dict[str, List[TPartialTableSchema]]
+TSchemaDrop = Dict[str, TPartialTableSchema]
 TColumnDefaultHint = Literal["not_null", TColumnHint]
 """Allows using not_null in default hints setting section"""
dlt/destinations/impl/athena/athena.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
TSortOrder,
3838
)
3939
from dlt.common.destination import DestinationCapabilitiesContext, PreparedTableSchema
40-
from dlt.common.destination.client import FollowupJobRequest, SupportsStagingDestination, LoadJob
40+
from dlt.common.destination.client import (
41+
FollowupJobRequest,
42+
SupportsStagingDestination,
43+
LoadJob,
44+
WithTableReflection,
45+
)
4146
from dlt.destinations.sql_jobs import (
4247
SqlStagingCopyFollowupJob,
4348
SqlStagingReplaceFollowupJob,
@@ -191,7 +196,7 @@ def _parse_and_log_lf_response(
191196
logger.debug(f"Success: {verb} LF tags {lf_tags} to " + resource_msg)
192197

193198

194-
class AthenaClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
199+
class AthenaClient(SqlJobClientWithStagingDataset, SupportsStagingDestination, WithTableReflection):
195200
def __init__(
196201
self,
197202
schema: Schema,

dlt/destinations/impl/bigquery/bigquery.py

Lines changed: 4 additions & 1 deletion
@@ -16,6 +16,7 @@
     RunnableLoadJob,
     SupportsStagingDestination,
     LoadJob,
+    WithTableReflection,
 )
 from dlt.common.json import json
 from dlt.common.runtime.signals import sleep
@@ -174,7 +175,9 @@ def gen_key_table_clauses(
     return sql


-class BigQueryClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
+class BigQueryClient(
+    SqlJobClientWithStagingDataset, SupportsStagingDestination, WithTableReflection
+):
     def __init__(
         self,
         schema: Schema,

dlt/destinations/impl/clickhouse/clickhouse.py

Lines changed: 4 additions & 1 deletion
@@ -21,6 +21,7 @@
     RunnableLoadJob,
     FollowupJobRequest,
     LoadJob,
+    WithTableReflection,
 )
 from dlt.common.schema import Schema, TColumnSchema
 from dlt.common.schema.typing import (
@@ -212,7 +213,9 @@ def requires_temp_table_for_delete(cls) -> bool:
         return True


-class ClickHouseClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
+class ClickHouseClient(
+    SqlJobClientWithStagingDataset, SupportsStagingDestination, WithTableReflection
+):
     def __init__(
         self,
         schema: Schema,

dlt/destinations/impl/databricks/databricks.py

Lines changed: 4 additions & 1 deletion
@@ -13,6 +13,7 @@
     RunnableLoadJob,
     SupportsStagingDestination,
     LoadJob,
+    WithTableReflection,
 )
 from dlt.common.configuration.specs import (
     AwsCredentialsWithoutDefaults,
@@ -302,7 +303,9 @@ def gen_delete_from_sql(
         """


-class DatabricksClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
+class DatabricksClient(
+    SqlJobClientWithStagingDataset, SupportsStagingDestination, WithTableReflection
+):
     def __init__(
         self,
         schema: Schema,

dlt/destinations/impl/dremio/dremio.py

Lines changed: 2 additions & 1 deletion
@@ -10,6 +10,7 @@
     SupportsStagingDestination,
     FollowupJobRequest,
     LoadJob,
+    WithTableReflection,
 )
 from dlt.common.schema import TColumnSchema, Schema
 from dlt.common.schema.typing import TColumnType, TTableFormat
@@ -97,7 +98,7 @@ def run(self) -> None:
         """)


-class DremioClient(SqlJobClientWithStagingDataset, SupportsStagingDestination):
+class DremioClient(SqlJobClientWithStagingDataset, SupportsStagingDestination, WithTableReflection):
     def __init__(
         self,
         schema: Schema,

dlt/destinations/impl/duckdb/duck.py

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@
     RunnableLoadJob,
     HasFollowupJobs,
     LoadJob,
+    WithTableReflection,
 )
 from dlt.common.schema.typing import TColumnSchema, TColumnType, TTableFormat
 from dlt.common.schema.utils import has_default_column_prop_value
@@ -49,7 +50,7 @@ def run(self) -> None:
         )


-class DuckDbClient(InsertValuesJobClient):
+class DuckDbClient(InsertValuesJobClient, WithTableReflection):
     def __init__(
         self,
         schema: Schema,

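Since `update_dlt_schema` raises `NotImplementedError` unless the client implements both `WithTableReflection` and `WithSqlClient`, a caller can probe support the same way before syncing. A sketch, with `client` obtained as in the earlier example:

from dlt.common.destination.client import WithTableReflection
from dlt.destinations.sql_client import WithSqlClient

if isinstance(client, WithTableReflection) and isinstance(client, WithSqlClient):
    client.update_dlt_schema(dry_run=True)
else:
    print("this destination does not support dlt schema sync")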