allows data type diff and ensures valid migration separately (#2150)
* allows data type diff and ensures valid migration separately

* removes dlt init flag to skip core sources, adds flag to eject core source
rudolfix authored Dec 16, 2024
1 parent 4a051b0 commit 1b0d7b2
Showing 16 changed files with 399 additions and 311 deletions.
4 changes: 2 additions & 2 deletions dlt/cli/command_wrappers.py
@@ -43,14 +43,14 @@ def init_command_wrapper(
destination_type: str,
repo_location: str,
branch: str,
omit_core_sources: bool = False,
eject_source: bool = False,
) -> None:
init_command(
source_name,
destination_type,
repo_location,
branch,
omit_core_sources,
eject_source,
)


31 changes: 18 additions & 13 deletions dlt/cli/init_command.py
@@ -157,7 +157,7 @@ def _list_core_sources() -> Dict[str, SourceConfiguration]:
sources: Dict[str, SourceConfiguration] = {}
for source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"):
sources[source_name] = files_ops.get_core_source_configuration(
core_sources_storage, source_name
core_sources_storage, source_name, eject_source=False
)
return sources

@@ -295,7 +295,7 @@ def init_command(
destination_type: str,
repo_location: str,
branch: str = None,
omit_core_sources: bool = False,
eject_source: bool = False,
) -> None:
# try to import the destination and get config spec
destination_reference = Destination.from_reference(destination_type)
@@ -310,13 +310,9 @@

# discover type of source
source_type: files_ops.TSourceType = "template"
if (
source_name in files_ops.get_sources_names(core_sources_storage, source_type="core")
) and not omit_core_sources:
if source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"):
source_type = "core"
else:
if omit_core_sources:
fmt.echo("Omitting dlt core sources.")
verified_sources_storage = _clone_and_get_verified_sources_storage(repo_location, branch)
if source_name in files_ops.get_sources_names(
verified_sources_storage, source_type="verified"
@@ -380,7 +376,7 @@ init_command(
else:
if source_type == "core":
source_configuration = files_ops.get_core_source_configuration(
core_sources_storage, source_name
core_sources_storage, source_name, eject_source
)
from importlib.metadata import Distribution

@@ -392,6 +388,9 @@

if canonical_source_name in extras:
source_configuration.requirements.update_dlt_extras(canonical_source_name)

# create remote modified index to copy files when ejecting
remote_modified = {file_name: None for file_name in source_configuration.files}
else:
if not is_valid_schema_name(source_name):
raise InvalidSchemaName(source_name)
@@ -536,11 +535,17 @@ def init_command(
"Creating a new pipeline with the dlt core source %s (%s)"
% (fmt.bold(source_name), source_configuration.doc)
)
fmt.echo(
"NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from the"
" verified sources repo but imported from dlt.sources. You can provide the"
" --omit-core-sources flag to revert to the old behavior." % (fmt.bold(source_name))
)
if eject_source:
fmt.echo(
"NOTE: Source code of %s will be ejected. Remember to modify the pipeline "
"example script to import the ejected source." % (fmt.bold(source_name))
)
else:
fmt.echo(
"NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from"
" the verified sources repo but imported from dlt.sources. You can provide the"
" --eject flag to revert to the old behavior." % (fmt.bold(source_name))
)
elif source_configuration.source_type == "verified":
fmt.echo(
"Creating and configuring a new pipeline with the verified source %s (%s)"
40 changes: 23 additions & 17 deletions dlt/cli/pipeline_files.py
@@ -226,19 +226,39 @@ def get_template_configuration(
)


def _get_source_files(sources_storage: FileStorage, source_name: str) -> List[str]:
"""Get all files that belong to source `source_name`"""
files: List[str] = []
for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)):
# filter unwanted files
for subdir in list(subdirs):
if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES):
subdirs.remove(subdir)
rel_root = sources_storage.to_relative_path(root)
files.extend(
[
os.path.join(rel_root, file)
for file in _files
if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES)
]
)
return files


def get_core_source_configuration(
sources_storage: FileStorage, source_name: str
sources_storage: FileStorage, source_name: str, eject_source: bool
) -> SourceConfiguration:
src_pipeline_file = CORE_SOURCE_TEMPLATE_MODULE_NAME + "/" + source_name + PIPELINE_FILE_SUFFIX
dest_pipeline_file = source_name + PIPELINE_FILE_SUFFIX
files: List[str] = _get_source_files(sources_storage, source_name) if eject_source else []

return SourceConfiguration(
"core",
"dlt.sources." + source_name,
sources_storage,
src_pipeline_file,
dest_pipeline_file,
[".gitignore"],
files,
SourceRequirements([]),
_get_docstring_for_module(sources_storage, source_name),
False,
@@ -259,21 +279,7 @@ def get_verified_source_configuration(
f"Pipeline example script {example_script} could not be found in the repository",
source_name,
)
# get all files recursively
files: List[str] = []
for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)):
# filter unwanted files
for subdir in list(subdirs):
if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES):
subdirs.remove(subdir)
rel_root = sources_storage.to_relative_path(root)
files.extend(
[
os.path.join(rel_root, file)
for file in _files
if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES)
]
)
files = _get_source_files(sources_storage, source_name)
# read requirements
requirements_path = os.path.join(source_name, utils.REQUIREMENTS_TXT)
if sources_storage.has_file(requirements_path):
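
The ejection path above comes down to get_core_source_configuration(..., eject_source=True) filling the files list through the new _get_source_files helper; the CLI triggers it with e.g. dlt init sql_database duckdb --eject. Below is a minimal sketch of the helper's behavior, not part of the commit: the directory layout is invented for illustration and it assumes the __pycache__ directory is covered by the IGNORE_FILES patterns.

import os
import tempfile

from dlt.common.storages import FileStorage
from dlt.cli.pipeline_files import _get_source_files

with tempfile.TemporaryDirectory() as root:
    # lay out a fake "sql_database" source with one ignorable artifact
    os.makedirs(os.path.join(root, "sql_database", "__pycache__"))
    for name in ("__init__.py", "helpers.py"):
        open(os.path.join(root, "sql_database", name), "w").close()
    open(os.path.join(root, "sql_database", "__pycache__", "cached.pyc"), "w").close()

    storage = FileStorage(root)
    # walks sql_database/ recursively and returns paths relative to the storage root,
    # dropping anything that matches IGNORE_FILES (here: the __pycache__ contents)
    print(sorted(_get_source_files(storage, "sql_database")))
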
10 changes: 3 additions & 7 deletions dlt/cli/plugins.py
@@ -84,14 +84,10 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
)

parser.add_argument(
"--omit-core-sources",
"--eject",
default=False,
action="store_true",
help=(
"When present, will not create the new pipeline with a core source of the given"
" name but will take a source of this name from the default or provided"
" location."
),
help="Ejects the source code of the core source like sql_database",
)

def execute(self, args: argparse.Namespace) -> None:
@@ -107,7 +103,7 @@ def execute(self, args: argparse.Namespace) -> None:
args.destination,
args.location,
args.branch,
args.omit_core_sources,
args.eject,
)


70 changes: 48 additions & 22 deletions dlt/common/schema/utils.py
@@ -457,35 +457,15 @@ def diff_table(
* when columns with the same name have different data types
* when table links to different parent tables
"""
if tab_a["name"] != tab_b["name"]:
raise TablePropertiesConflictException(
schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"]
)
table_name = tab_a["name"]
# check if table properties can be merged
if tab_a.get("parent") != tab_b.get("parent"):
raise TablePropertiesConflictException(
schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent")
)
# allow for columns to differ
ensure_compatible_tables(schema_name, tab_a, tab_b, ensure_columns=False)

# get new columns, changes in the column data type or other properties are not allowed
tab_a_columns = tab_a["columns"]
new_columns: List[TColumnSchema] = []
for col_b_name, col_b in tab_b["columns"].items():
if col_b_name in tab_a_columns:
col_a = tab_a_columns[col_b_name]
# we do not support changing data types of columns
if is_complete_column(col_a) and is_complete_column(col_b):
if not compare_complete_columns(tab_a_columns[col_b_name], col_b):
# attempt to update to incompatible columns
raise CannotCoerceColumnException(
schema_name,
table_name,
col_b_name,
col_b["data_type"],
tab_a_columns[col_b_name]["data_type"],
None,
)
# all other properties can change
merged_column = merge_column(copy(col_a), col_b)
if merged_column != col_a:
@@ -494,6 +474,8 @@
new_columns.append(col_b)

# return partial table containing only name and properties that differ (column, filters etc.)
table_name = tab_a["name"]

partial_table: TPartialTableSchema = {
"name": table_name,
"columns": {} if new_columns is None else {c["name"]: c for c in new_columns},
@@ -519,6 +501,50 @@
return partial_table


def ensure_compatible_tables(
schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema, ensure_columns: bool = True
) -> None:
"""Ensures that `tab_a` and `tab_b` can be merged without conflicts. Conflicts are detected when
- tables have different names
- nested tables have different parents
- tables have any column with incompatible types
Note: all the identifiers must be already normalized
"""
if tab_a["name"] != tab_b["name"]:
raise TablePropertiesConflictException(
schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"]
)
table_name = tab_a["name"]
# check if table properties can be merged
if tab_a.get("parent") != tab_b.get("parent"):
raise TablePropertiesConflictException(
schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent")
)

if not ensure_columns:
return

tab_a_columns = tab_a["columns"]
for col_b_name, col_b in tab_b["columns"].items():
if col_b_name in tab_a_columns:
col_a = tab_a_columns[col_b_name]
# we do not support changing data types of columns
if is_complete_column(col_a) and is_complete_column(col_b):
if not compare_complete_columns(tab_a_columns[col_b_name], col_b):
# attempt to update to incompatible columns
raise CannotCoerceColumnException(
schema_name,
table_name,
col_b_name,
col_b["data_type"],
tab_a_columns[col_b_name]["data_type"],
None,
)


# def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool:
# try:
# table_name = tab_a["name"]
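
With this split, diff_table only verifies table-level properties itself and delegates the column-type check to ensure_compatible_tables, which callers can now run separately (see validate.py below). A minimal sketch, not from the commit, with hand-written dicts standing in for real TTableSchema values:

from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.schema.utils import ensure_compatible_tables

existing = {
    "name": "events",
    "columns": {"id": {"name": "id", "data_type": "bigint", "nullable": False}},
}
incoming = {
    "name": "events",
    "columns": {"id": {"name": "id", "data_type": "text", "nullable": True}},
}

try:
    # full check: raises because column `id` would change type from bigint to text
    ensure_compatible_tables("my_schema", existing, incoming)
except CannotCoerceColumnException as exc:
    print(exc)

# table-level check only (name and parent), which is how diff_table now calls it
# before computing the column delta
ensure_compatible_tables("my_schema", existing, incoming, ensure_columns=False)
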
40 changes: 9 additions & 31 deletions dlt/normalize/normalize.py
@@ -20,7 +20,7 @@
LoadStorage,
ParsedLoadJobFileName,
)
from dlt.common.schema import TSchemaUpdate, Schema
from dlt.common.schema import Schema
from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.pipeline import (
NormalizeInfo,
@@ -34,7 +34,7 @@
from dlt.normalize.configuration import NormalizeConfiguration
from dlt.normalize.exceptions import NormalizeJobFailed
from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV
from dlt.normalize.validate import verify_normalized_table
from dlt.normalize.validate import validate_and_update_schema, verify_normalized_table


# normalize worker wrapping function signature
@@ -80,16 +80,6 @@ def create_storages(self) -> None:
config=self.config._load_storage_config,
)

def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None:
for schema_update in schema_updates:
for table_name, table_updates in schema_update.items():
logger.info(
f"Updating schema for table {table_name} with {len(table_updates)} deltas"
)
for partial_table in table_updates:
# merge columns where we expect identifiers to be normalized
schema.update_table(partial_table, normalize_identifiers=False)

def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV:
workers: int = getattr(self.pool, "_max_workers", 1)
chunk_files = group_worker_files(files, workers)
@@ -123,7 +113,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV:
result: TWorkerRV = pending.result()
try:
# gather schema from all manifests, validate consistency and combine
self.update_schema(schema, result[0])
validate_and_update_schema(schema, result[0])
summary.schema_updates.extend(result.schema_updates)
summary.file_metrics.extend(result.file_metrics)
# update metrics
@@ -162,7 +152,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV:
load_id,
files,
)
self.update_schema(schema, result.schema_updates)
validate_and_update_schema(schema, result.schema_updates)
self.collector.update("Files", len(result.file_metrics))
self.collector.update(
"Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count
@@ -237,23 +227,11 @@ def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str])
self.load_storage.import_extracted_package(
load_id, self.normalize_storage.extracted_packages
)
logger.info(f"Created new load package {load_id} on loading volume")
try:
# process parallel
self.spool_files(
load_id, schema.clone(update_normalizers=True), self.map_parallel, files
)
except CannotCoerceColumnException as exc:
# schema conflicts resulting from parallel executing
logger.warning(
f"Parallel schema update conflict, switching to single thread ({str(exc)}"
)
# start from scratch
self.load_storage.new_packages.delete_package(load_id)
self.load_storage.import_extracted_package(
load_id, self.normalize_storage.extracted_packages
)
self.spool_files(load_id, schema.clone(update_normalizers=True), self.map_single, files)
logger.info(f"Created new load package {load_id} on loading volume with ")
# get number of workers with default == 1 if not set (ie. NullExecutor)
workers: int = getattr(self.pool, "_max_workers", 1)
map_f: TMapFuncType = self.map_parallel if workers > 1 else self.map_single
self.spool_files(load_id, schema.clone(update_normalizers=True), map_f, files)

return load_id

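
Instead of retrying a failed parallel pass in a single thread, the normalizer now validates schema updates upfront (validate_and_update_schema, below) and picks the mapping strategy once from the pool size. A tiny sketch of that selection pattern, not from the commit; pools without _max_workers (e.g. dlt's NullExecutor) fall back to a single worker:

from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=4)
# _max_workers is a private attribute of stdlib executors; default to 1 when absent
workers = getattr(pool, "_max_workers", 1)
map_f = "map_parallel" if workers > 1 else "map_single"
print(workers, map_f)  # 4 map_parallel
pool.shutdown()
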
20 changes: 19 additions & 1 deletion dlt/normalize/validate.py
@@ -1,7 +1,10 @@
from typing import List

from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema import Schema
from dlt.common.schema.typing import TTableSchema
from dlt.common.schema.typing import TTableSchema, TSchemaUpdate
from dlt.common.schema.utils import (
ensure_compatible_tables,
find_incomplete_columns,
get_first_column_name_with_prop,
is_nested_table,
@@ -10,6 +13,21 @@
from dlt.common import logger


def validate_and_update_schema(schema: Schema, schema_updates: List[TSchemaUpdate]) -> None:
"""Updates `schema` tables with partial tables in `schema_updates`"""
for schema_update in schema_updates:
for table_name, table_updates in schema_update.items():
logger.info(f"Updating schema for table {table_name} with {len(table_updates)} deltas")
for partial_table in table_updates:
# ensure updates will pass
if existing_table := schema.tables.get(partial_table["name"]):
ensure_compatible_tables(schema.name, existing_table, partial_table)

for partial_table in table_updates:
# merge columns where we expect identifiers to be normalized
schema.update_table(partial_table, normalize_identifiers=False)


def verify_normalized_table(
schema: Schema, table: TTableSchema, capabilities: DestinationCapabilitiesContext
) -> None:
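
validate_and_update_schema replaces the removed Normalize.update_schema: it first checks every partial table against the live schema with ensure_compatible_tables and only then merges it. A minimal sketch, not from the commit, using a hand-written update shaped like TSchemaUpdate (table name mapped to a list of partial tables); the exact Schema behavior for brand-new tables is assumed here:

from dlt.common.schema import Schema
from dlt.normalize.validate import validate_and_update_schema

schema = Schema("my_schema")
update = {
    "events": [{"name": "events", "columns": {"id": {"name": "id", "data_type": "bigint"}}}]
}
# compatibility check is a no-op here (the table is new), then the partial table is
# merged via schema.update_table(..., normalize_identifiers=False)
validate_and_update_schema(schema, [update])
print(list(schema.tables["events"]["columns"]))
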