allows data type diff and ensures valid migration separately #2150

Merged Dec 16, 2024 · 2 commits
Changes from all commits
4 changes: 2 additions & 2 deletions dlt/cli/command_wrappers.py
@@ -43,14 +43,14 @@ def init_command_wrapper(
destination_type: str,
repo_location: str,
branch: str,
omit_core_sources: bool = False,
eject_source: bool = False,
) -> None:
init_command(
source_name,
destination_type,
repo_location,
branch,
omit_core_sources,
eject_source,
)


31 changes: 18 additions & 13 deletions dlt/cli/init_command.py
@@ -157,7 +157,7 @@ def _list_core_sources() -> Dict[str, SourceConfiguration]:
sources: Dict[str, SourceConfiguration] = {}
for source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"):
sources[source_name] = files_ops.get_core_source_configuration(
core_sources_storage, source_name
core_sources_storage, source_name, eject_source=False
)
return sources

@@ -295,7 +295,7 @@ def init_command(
destination_type: str,
repo_location: str,
branch: str = None,
omit_core_sources: bool = False,
eject_source: bool = False,
) -> None:
# try to import the destination and get config spec
destination_reference = Destination.from_reference(destination_type)
@@ -310,13 +310,9 @@

# discover type of source
source_type: files_ops.TSourceType = "template"
if (
source_name in files_ops.get_sources_names(core_sources_storage, source_type="core")
) and not omit_core_sources:
if source_name in files_ops.get_sources_names(core_sources_storage, source_type="core"):
source_type = "core"
else:
if omit_core_sources:
fmt.echo("Omitting dlt core sources.")
verified_sources_storage = _clone_and_get_verified_sources_storage(repo_location, branch)
if source_name in files_ops.get_sources_names(
verified_sources_storage, source_type="verified"
@@ -380,7 +376,7 @@ def init_command(
else:
if source_type == "core":
source_configuration = files_ops.get_core_source_configuration(
core_sources_storage, source_name
core_sources_storage, source_name, eject_source
)
from importlib.metadata import Distribution

@@ -392,6 +388,9 @@

if canonical_source_name in extras:
source_configuration.requirements.update_dlt_extras(canonical_source_name)

# create remote modified index to copy files when ejecting
remote_modified = {file_name: None for file_name in source_configuration.files}
else:
if not is_valid_schema_name(source_name):
raise InvalidSchemaName(source_name)
@@ -536,11 +535,17 @@ def init_command(
"Creating a new pipeline with the dlt core source %s (%s)"
% (fmt.bold(source_name), source_configuration.doc)
)
fmt.echo(
"NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from the"
" verified sources repo but imported from dlt.sources. You can provide the"
" --omit-core-sources flag to revert to the old behavior." % (fmt.bold(source_name))
)
if eject_source:
fmt.echo(
"NOTE: Source code of %s will be ejected. Remember to modify the pipeline "
"example script to import the ejected source." % (fmt.bold(source_name))
)
else:
fmt.echo(
"NOTE: Beginning with dlt 1.0.0, the source %s will no longer be copied from"
" the verified sources repo but imported from dlt.sources. You can provide the"
" --eject flag to revert to the old behavior." % (fmt.bold(source_name))
)
elif source_configuration.source_type == "verified":
fmt.echo(
"Creating and configuring a new pipeline with the verified source %s (%s)"
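
For context, the two notes above differ only in how the generated pipeline script ends up importing the source. A minimal sketch of both situations, assuming the sql_database core source and a duckdb destination (these names and the script layout are illustrative, not taken from this diff):

import dlt
from dlt.sources.sql_database import sql_database  # default: core source shipped inside dlt

pipeline = dlt.pipeline(pipeline_name="sql_database_pipeline", destination="duckdb")

# After `dlt init sql_database duckdb --eject`, the source files are copied next to this
# script instead, and the import above would point at the ejected local package, e.g.:
# from sql_database import sql_database
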
40 changes: 23 additions & 17 deletions dlt/cli/pipeline_files.py
@@ -226,19 +226,39 @@ def get_template_configuration(
)


def _get_source_files(sources_storage: FileStorage, source_name: str) -> List[str]:
"""Get all files that belong to source `source_name`"""
files: List[str] = []
for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)):
# filter unwanted files
for subdir in list(subdirs):
if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES):
subdirs.remove(subdir)
rel_root = sources_storage.to_relative_path(root)
files.extend(
[
os.path.join(rel_root, file)
for file in _files
if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES)
]
)
return files


def get_core_source_configuration(
sources_storage: FileStorage, source_name: str
sources_storage: FileStorage, source_name: str, eject_source: bool
) -> SourceConfiguration:
src_pipeline_file = CORE_SOURCE_TEMPLATE_MODULE_NAME + "/" + source_name + PIPELINE_FILE_SUFFIX
dest_pipeline_file = source_name + PIPELINE_FILE_SUFFIX
files: List[str] = _get_source_files(sources_storage, source_name) if eject_source else []

return SourceConfiguration(
"core",
"dlt.sources." + source_name,
sources_storage,
src_pipeline_file,
dest_pipeline_file,
[".gitignore"],
files,
SourceRequirements([]),
_get_docstring_for_module(sources_storage, source_name),
False,
@@ -259,21 +279,7 @@ def get_verified_source_configuration(
f"Pipeline example script {example_script} could not be found in the repository",
source_name,
)
# get all files recursively
files: List[str] = []
for root, subdirs, _files in os.walk(sources_storage.make_full_path(source_name)):
# filter unwanted files
for subdir in list(subdirs):
if any(fnmatch.fnmatch(subdir, ignore) for ignore in IGNORE_FILES):
subdirs.remove(subdir)
rel_root = sources_storage.to_relative_path(root)
files.extend(
[
os.path.join(rel_root, file)
for file in _files
if all(not fnmatch.fnmatch(file, ignore) for ignore in IGNORE_FILES)
]
)
files = _get_source_files(sources_storage, source_name)
# read requirements
requirements_path = os.path.join(source_name, utils.REQUIREMENTS_TXT)
if sources_storage.has_file(requirements_path):
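
The extracted _get_source_files helper is a plain os.walk plus fnmatch filter. A minimal standalone sketch of the same idea, using illustrative ignore patterns rather than dlt's actual IGNORE_FILES list, and pruning directories with a slice assignment instead of the remove() loop used above:

import fnmatch
import os
from typing import List

IGNORE_PATTERNS = ["*.pyc", "__pycache__", ".DS_Store"]  # illustrative, not dlt's IGNORE_FILES

def list_source_files(root_dir: str) -> List[str]:
    files: List[str] = []
    for root, subdirs, names in os.walk(root_dir):
        # prune ignored directories in place so os.walk does not descend into them
        subdirs[:] = [d for d in subdirs if not any(fnmatch.fnmatch(d, p) for p in IGNORE_PATTERNS)]
        rel_root = os.path.relpath(root, root_dir)
        files.extend(
            os.path.join(rel_root, name)
            for name in names
            if not any(fnmatch.fnmatch(name, p) for p in IGNORE_PATTERNS)
        )
    return files
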
10 changes: 3 additions & 7 deletions dlt/cli/plugins.py
@@ -84,14 +84,10 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
)

parser.add_argument(
"--omit-core-sources",
"--eject",
default=False,
action="store_true",
help=(
"When present, will not create the new pipeline with a core source of the given"
" name but will take a source of this name from the default or provided"
" location."
),
help="Ejects the source code of the core source like sql_database",
)

def execute(self, args: argparse.Namespace) -> None:
Expand All @@ -107,7 +103,7 @@ def execute(self, args: argparse.Namespace) -> None:
args.destination,
args.location,
args.branch,
args.omit_core_sources,
args.eject,
)


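
The renamed flag is a plain store_true argument that gets forwarded as eject_source. A small self-contained sketch of the parsing side; the example invocation in the comment is an assumption rather than something shown in this diff:

import argparse

parser = argparse.ArgumentParser(prog="dlt init")
parser.add_argument("source")
parser.add_argument("destination")
parser.add_argument(
    "--eject",
    default=False,
    action="store_true",
    help="Ejects the source code of the core source like sql_database",
)

# e.g. `dlt init sql_database duckdb --eject`
args = parser.parse_args(["sql_database", "duckdb", "--eject"])
assert args.eject is True  # forwarded to init_command_wrapper(..., eject_source=args.eject)
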
70 changes: 48 additions & 22 deletions dlt/common/schema/utils.py
@@ -457,35 +457,15 @@ def diff_table(
* when columns with the same name have different data types
* when table links to different parent tables
"""
if tab_a["name"] != tab_b["name"]:
raise TablePropertiesConflictException(
schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"]
)
table_name = tab_a["name"]
# check if table properties can be merged
if tab_a.get("parent") != tab_b.get("parent"):
raise TablePropertiesConflictException(
schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent")
)
# allow for columns to differ
ensure_compatible_tables(schema_name, tab_a, tab_b, ensure_columns=False)

# get new columns, changes in the column data type or other properties are not allowed
tab_a_columns = tab_a["columns"]
new_columns: List[TColumnSchema] = []
for col_b_name, col_b in tab_b["columns"].items():
if col_b_name in tab_a_columns:
col_a = tab_a_columns[col_b_name]
# we do not support changing data types of columns
if is_complete_column(col_a) and is_complete_column(col_b):
if not compare_complete_columns(tab_a_columns[col_b_name], col_b):
# attempt to update to incompatible columns
raise CannotCoerceColumnException(
schema_name,
table_name,
col_b_name,
col_b["data_type"],
tab_a_columns[col_b_name]["data_type"],
None,
)
# all other properties can change
merged_column = merge_column(copy(col_a), col_b)
if merged_column != col_a:
@@ -494,6 +474,8 @@
new_columns.append(col_b)

# return partial table containing only name and properties that differ (column, filters etc.)
table_name = tab_a["name"]

partial_table: TPartialTableSchema = {
"name": table_name,
"columns": {} if new_columns is None else {c["name"]: c for c in new_columns},
@@ -519,6 +501,50 @@
return partial_table


def ensure_compatible_tables(
schema_name: str, tab_a: TTableSchema, tab_b: TPartialTableSchema, ensure_columns: bool = True
) -> None:
"""Ensures that `tab_a` and `tab_b` can be merged without conflicts. Conflicts are detected when

- tables have different names
- nested tables have different parents
- tables have any column with incompatible types

Note: all the identifiers must be already normalized

"""
if tab_a["name"] != tab_b["name"]:
raise TablePropertiesConflictException(
schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"]
)
table_name = tab_a["name"]
# check if table properties can be merged
if tab_a.get("parent") != tab_b.get("parent"):
raise TablePropertiesConflictException(
schema_name, table_name, "parent", tab_a.get("parent"), tab_b.get("parent")
)

if not ensure_columns:
return

tab_a_columns = tab_a["columns"]
for col_b_name, col_b in tab_b["columns"].items():
if col_b_name in tab_a_columns:
col_a = tab_a_columns[col_b_name]
# we do not support changing data types of columns
if is_complete_column(col_a) and is_complete_column(col_b):
if not compare_complete_columns(tab_a_columns[col_b_name], col_b):
# attempt to update to incompatible columns
raise CannotCoerceColumnException(
schema_name,
table_name,
col_b_name,
col_b["data_type"],
tab_a_columns[col_b_name]["data_type"],
None,
)


# def compare_tables(tab_a: TTableSchema, tab_b: TTableSchema) -> bool:
# try:
# table_name = tab_a["name"]
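
To make the split concrete: ensure_compatible_tables only raises, while diff_table now assumes that check already ran. A hedged sketch of the failure case, assuming the function lands exactly as added above; the table dicts are trimmed to the keys the check reads and are not complete TTableSchema objects:

from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.schema.utils import ensure_compatible_tables

tab_a = {
    "name": "events",
    "columns": {"value": {"name": "value", "data_type": "bigint", "nullable": False}},
}
tab_b = {
    "name": "events",
    "columns": {"value": {"name": "value", "data_type": "text", "nullable": False}},
}

try:
    ensure_compatible_tables("my_schema", tab_a, tab_b)  # type: ignore[arg-type]
except CannotCoerceColumnException:
    print("changing an existing column from bigint to text is rejected")
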
40 changes: 9 additions & 31 deletions dlt/normalize/normalize.py
@@ -20,7 +20,7 @@
LoadStorage,
ParsedLoadJobFileName,
)
from dlt.common.schema import TSchemaUpdate, Schema
from dlt.common.schema import Schema
from dlt.common.schema.exceptions import CannotCoerceColumnException
from dlt.common.pipeline import (
NormalizeInfo,
@@ -34,7 +34,7 @@
from dlt.normalize.configuration import NormalizeConfiguration
from dlt.normalize.exceptions import NormalizeJobFailed
from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV
from dlt.normalize.validate import verify_normalized_table
from dlt.normalize.validate import validate_and_update_schema, verify_normalized_table


# normalize worker wrapping function signature
@@ -80,16 +80,6 @@ def create_storages(self) -> None:
config=self.config._load_storage_config,
)

def update_schema(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None:
for schema_update in schema_updates:
for table_name, table_updates in schema_update.items():
logger.info(
f"Updating schema for table {table_name} with {len(table_updates)} deltas"
)
for partial_table in table_updates:
# merge columns where we expect identifiers to be normalized
schema.update_table(partial_table, normalize_identifiers=False)

def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWorkerRV:
workers: int = getattr(self.pool, "_max_workers", 1)
chunk_files = group_worker_files(files, workers)
@@ -123,7 +113,7 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW
result: TWorkerRV = pending.result()
try:
# gather schema from all manifests, validate consistency and combine
self.update_schema(schema, result[0])
validate_and_update_schema(schema, result[0])
summary.schema_updates.extend(result.schema_updates)
summary.file_metrics.extend(result.file_metrics)
# update metrics
@@ -162,7 +152,7 @@ def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TWor
load_id,
files,
)
self.update_schema(schema, result.schema_updates)
validate_and_update_schema(schema, result.schema_updates)
self.collector.update("Files", len(result.file_metrics))
self.collector.update(
"Items", sum(result.file_metrics, EMPTY_DATA_WRITER_METRICS).items_count
@@ -237,23 +227,11 @@ def spool_schema_files(self, load_id: str, schema: Schema, files: Sequence[str])
self.load_storage.import_extracted_package(
load_id, self.normalize_storage.extracted_packages
)
logger.info(f"Created new load package {load_id} on loading volume")
try:
# process parallel
self.spool_files(
load_id, schema.clone(update_normalizers=True), self.map_parallel, files
)
except CannotCoerceColumnException as exc:
# schema conflicts resulting from parallel executing
logger.warning(
f"Parallel schema update conflict, switching to single thread ({str(exc)}"
)
# start from scratch
self.load_storage.new_packages.delete_package(load_id)
self.load_storage.import_extracted_package(
load_id, self.normalize_storage.extracted_packages
)
self.spool_files(load_id, schema.clone(update_normalizers=True), self.map_single, files)
logger.info(f"Created new load package {load_id} on loading volume with ")
# get number of workers with default == 1 if not set (ie. NullExecutor)
workers: int = getattr(self.pool, "_max_workers", 1)
map_f: TMapFuncType = self.map_parallel if workers > 1 else self.map_single
self.spool_files(load_id, schema.clone(update_normalizers=True), map_f, files)

return load_id

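
The spooling change drops the "run parallel, restart single-threaded on a schema conflict" fallback and instead picks the map function once, based on pool size. A minimal sketch of that selection, with a stand-in pool and string placeholders for the two map methods (the one-worker default for pools without _max_workers, such as a NullExecutor, is taken from the comment above):

from concurrent.futures import ThreadPoolExecutor

def pick_map_func(pool, map_parallel, map_single):
    # pools that do not expose _max_workers (e.g. a NullExecutor) default to a single worker
    workers: int = getattr(pool, "_max_workers", 1)
    return map_parallel if workers > 1 else map_single

pool = ThreadPoolExecutor(max_workers=4)
print(pick_map_func(pool, "map_parallel", "map_single"))   # -> map_parallel
print(pick_map_func(object(), "map_parallel", "map_single"))  # no _max_workers -> map_single
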
20 changes: 19 additions & 1 deletion dlt/normalize/validate.py
@@ -1,7 +1,10 @@
from typing import List

from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema import Schema
from dlt.common.schema.typing import TTableSchema
from dlt.common.schema.typing import TTableSchema, TSchemaUpdate
from dlt.common.schema.utils import (
ensure_compatible_tables,
find_incomplete_columns,
get_first_column_name_with_prop,
is_nested_table,
@@ -10,6 +13,21 @@
from dlt.common import logger


def validate_and_update_schema(schema: Schema, schema_updates: List[TSchemaUpdate]) -> None:
"""Updates `schema` tables with partial tables in `schema_updates`"""
for schema_update in schema_updates:
for table_name, table_updates in schema_update.items():
logger.info(f"Updating schema for table {table_name} with {len(table_updates)} deltas")
for partial_table in table_updates:
# ensure updates will pass
if existing_table := schema.tables.get(partial_table["name"]):
ensure_compatible_tables(schema.name, existing_table, partial_table)

for partial_table in table_updates:
# merge columns where we expect identifiers to be normalized
schema.update_table(partial_table, normalize_identifiers=False)


def verify_normalized_table(
schema: Schema, table: TTableSchema, capabilities: DestinationCapabilitiesContext
) -> None:
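
For orientation, each TSchemaUpdate maps a table name to a list of partial tables, so the new helper first checks every delta against the existing table (if any) and only then applies it. A hedged usage sketch, assuming the helper lands as added above; the schema, table, and column names are illustrative:

import dlt
from dlt.normalize.validate import validate_and_update_schema

schema = dlt.Schema("demo")
schema_updates = [
    {"events": [{"name": "events", "columns": {"id": {"name": "id", "data_type": "bigint"}}}]}
]

# raises (e.g. CannotCoerceColumnException) if a delta conflicts with an existing table
validate_and_update_schema(schema, schema_updates)
print("events" in schema.tables)  # -> True once the update has been applied
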