Skip to content

Commit 2ce5f34

Browse files
authored
Feat: Use dbt manifest to load dbt projects (#821)
1 parent b72f007 commit 2ce5f34

23 files changed

+587
-940
lines changed

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"google-cloud-bigquery-storage",
6464
"black==22.6.0",
6565
"dbt-core",
66+
"dbt-duckdb",
6667
"Faker",
6768
"google-auth",
6869
"isort==5.10.1",

sqlmesh/dbt/adapter.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ def __init__(
2323
self,
2424
jinja_macros: JinjaMacroRegistry,
2525
jinja_globals: t.Optional[t.Dict[str, t.Any]] = None,
26-
dialect: str = "",
2726
):
2827
self.jinja_macros = jinja_macros
29-
self.jinja_globals = jinja_globals or {}
30-
self.dialect = dialect
28+
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
29+
self.jinja_globals["adapter"] = self
3130

3231
@abc.abstractmethod
3332
def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]:
@@ -78,18 +77,15 @@ def quote(self, identifier: str) -> str:
7877

7978
def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
8079
"""Returns a dialect-specific version of a macro with the given name."""
81-
dialect_name = f"{self.dialect}__{name}"
82-
default_name = f"default__{name}"
83-
80+
target_type = self.jinja_globals["target"]["type"]
8481
references_to_try = [
85-
MacroReference(package=package, name=dialect_name),
86-
MacroReference(package=package, name=default_name),
82+
MacroReference(package=f"{package}_{target_type}", name=f"{target_type}__{name}"),
83+
MacroReference(package=package, name=f"{target_type}__{name}"),
84+
MacroReference(package=package, name=f"default__{name}"),
8785
]
8886

8987
for reference in references_to_try:
90-
macro_callable = self.jinja_macros.build_macro(
91-
reference, **{**self.jinja_globals, "adapter": self}
92-
)
88+
macro_callable = self.jinja_macros.build_macro(reference, **self.jinja_globals)
9389
if macro_callable is not None:
9490
return macro_callable
9591

@@ -141,7 +137,7 @@ def __init__(
141137
):
142138
from dbt.adapters.base.relation import Policy
143139

144-
super().__init__(jinja_macros, jinja_globals=jinja_globals, dialect=engine_adapter.dialect)
140+
super().__init__(jinja_macros, jinja_globals=jinja_globals)
145141

146142
self.engine_adapter = engine_adapter
147143
# All engines quote by default except Snowflake

sqlmesh/dbt/basemodel.py

+7-169
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,29 @@
55
from enum import Enum
66
from pathlib import Path
77

8-
from dbt.adapters.base import BaseRelation
98
from dbt.contracts.relation import RelationType
10-
from jinja2 import nodes
11-
from jinja2.exceptions import UndefinedError
129
from pydantic import Field, validator
1310
from sqlglot.helper import ensure_list
1411

15-
from sqlmesh.core import constants as c
1612
from sqlmesh.core import dialect as d
1713
from sqlmesh.core.config.base import UpdateStrategy
1814
from sqlmesh.core.model import Model
19-
from sqlmesh.dbt.adapter import ParsetimeAdapter
2015
from sqlmesh.dbt.column import (
2116
ColumnConfig,
2217
column_descriptions_to_sqlmesh,
2318
column_types_to_sqlmesh,
24-
yaml_to_columns,
2519
)
2620
from sqlmesh.dbt.common import DbtConfig, GeneralConfig, QuotingConfig, SqlStr
27-
from sqlmesh.dbt.context import DbtContext
2821
from sqlmesh.utils import AttributeDict
2922
from sqlmesh.utils.conversions import ensure_bool
30-
from sqlmesh.utils.date import date_dict
3123
from sqlmesh.utils.errors import ConfigError
32-
from sqlmesh.utils.jinja import MacroReference, extract_macro_references
24+
from sqlmesh.utils.jinja import MacroReference
3325
from sqlmesh.utils.pydantic import PydanticModel
3426

27+
if t.TYPE_CHECKING:
28+
from sqlmesh.dbt.context import DbtContext
29+
30+
3531
BMC = t.TypeVar("BMC", bound="BaseModelConfig")
3632

3733

@@ -43,21 +39,17 @@ class Dependencies(PydanticModel):
4339
macros: The references to macros
4440
sources: The "source_name.table_name" for source tables used
4541
refs: The table_name for models used
46-
variables: The names of variables used, mapped to a flag that indicates whether their
47-
definition is optional or not.
4842
"""
4943

5044
macros: t.Set[MacroReference] = set()
5145
sources: t.Set[str] = set()
5246
refs: t.Set[str] = set()
53-
variables: t.Set[str] = set()
5447

5548
def union(self, other: Dependencies) -> Dependencies:
5649
dependencies = Dependencies()
5750
dependencies.macros = self.macros | other.macros
5851
dependencies.sources = self.sources | other.sources
5952
dependencies.refs = self.refs | other.refs
60-
dependencies.variables = self.variables | other.variables
6153

6254
return dependencies
6355

@@ -101,7 +93,6 @@ class BaseModelConfig(GeneralConfig):
10193
storage_format: The storage format used to store the physical table, only applicable in certain engines.
10294
(eg. 'parquet')
10395
path: The file path of the model
104-
target_schema: The schema for the profile target
10596
dependencies: The macro, source, and ref dependencies used to execute the model and its hooks
10697
database: Database the model is stored in
10798
schema: Custom schema name added to the model schema name
@@ -119,12 +110,11 @@ class BaseModelConfig(GeneralConfig):
119110
stamp: t.Optional[str] = None
120111
storage_format: t.Optional[str] = None
121112
path: Path = Path()
122-
target_schema: str = ""
123113
dependencies: Dependencies = Dependencies()
124114

125115
# DBT configuration fields
116+
schema_: str = Field("", alias="schema")
126117
database: t.Optional[str] = None
127-
schema_: t.Optional[str] = Field(None, alias="schema")
128118
alias: t.Optional[str] = None
129119
pre_hook: t.List[Hook] = Field([], alias="pre-hook")
130120
post_hook: t.List[Hook] = Field([], alias="post-hook")
@@ -156,13 +146,6 @@ def _validate_bool(cls, v: str) -> bool:
156146
def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]:
157147
return {key: ensure_list(value) for key, value in v.items()}
158148

159-
@validator("columns", pre=True)
160-
def _validate_columns(cls, v: t.Any) -> t.Dict[str, ColumnConfig]:
161-
if isinstance(v, dict) and all(isinstance(col, ColumnConfig) for col in v.values()):
162-
return v
163-
164-
return yaml_to_columns(v)
165-
166149
_FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = {
167150
**GeneralConfig._FIELD_UPDATE_STRATEGY,
168151
**{
@@ -197,7 +180,7 @@ def table_schema(self) -> str:
197180
"""
198181
Get the full schema name
199182
"""
200-
return "_".join(part for part in (self.target_schema, self.schema_) if part)
183+
return self.schema_
201184

202185
@property
203186
def table_name(self) -> str:
@@ -293,21 +276,6 @@ def sqlmesh_model_kwargs(self, model_context: DbtContext) -> t.Dict[str, t.Any]:
293276
**optional_kwargs,
294277
}
295278

296-
def render_config(self: BMC, context: DbtContext) -> BMC:
297-
rendered = super().render_config(context)
298-
rendered = ModelSqlRenderer(context, rendered).enriched_config
299-
300-
rendered_dependencies = rendered.dependencies
301-
for dependency in rendered_dependencies.refs:
302-
model = context.models.get(dependency)
303-
if model and model.materialized == Materialization.EPHEMERAL:
304-
rendered.dependencies = rendered.dependencies.union(
305-
model.render_config(context).dependencies
306-
)
307-
rendered.dependencies.refs.discard(dependency)
308-
309-
return rendered
310-
311279
@abstractmethod
312280
def to_sqlmesh(self, context: DbtContext) -> Model:
313281
"""Convert DBT model into sqlmesh Model"""
@@ -338,135 +306,5 @@ def _context_for_dependencies(
338306
model_context.sources = sources
339307
model_context.seeds = seeds
340308
model_context.models = models
341-
model_context.variables = {
342-
name: value
343-
for name, value in context.variables.items()
344-
if name in dependencies.variables
345-
}
346309

347310
return model_context
348-
349-
350-
class ModelSqlRenderer(t.Generic[BMC]):
351-
def __init__(self, context: DbtContext, config: BMC):
352-
from sqlmesh.dbt.builtin import create_builtin_globals
353-
354-
self.context = context
355-
self.config = config
356-
357-
self._captured_dependencies: Dependencies = Dependencies()
358-
self._rendered_sql: t.Optional[str] = None
359-
self._enriched_config: BMC = config.copy()
360-
361-
self._jinja_globals = create_builtin_globals(
362-
jinja_macros=context.jinja_macros,
363-
jinja_globals={
364-
**context.jinja_globals,
365-
**date_dict(c.EPOCH, c.EPOCH, c.EPOCH),
366-
"config": lambda *args, **kwargs: "",
367-
"ref": self._ref,
368-
"var": self._var,
369-
"source": self._source,
370-
"this": self.config.relation_info,
371-
"model": self.config.model_function(),
372-
"schema": self.config.table_schema,
373-
},
374-
engine_adapter=None,
375-
)
376-
377-
# Set the adapter separately since it requires jinja globals to be passed into it.
378-
self._jinja_globals["adapter"] = ModelSqlRenderer.TrackingAdapter(
379-
self,
380-
context.jinja_macros,
381-
jinja_globals=self._jinja_globals,
382-
dialect=context.engine_adapter.dialect if context.engine_adapter else "",
383-
)
384-
385-
self.jinja_env = self.context.jinja_macros.build_environment(**self._jinja_globals)
386-
387-
@property
388-
def enriched_config(self) -> BMC:
389-
if self._rendered_sql is None:
390-
self._enriched_config = self._update_with_sql_config(self._enriched_config)
391-
self._enriched_config.dependencies = Dependencies(
392-
macros=extract_macro_references(self._enriched_config.all_sql)
393-
)
394-
self.render()
395-
self._enriched_config.dependencies = self._enriched_config.dependencies.union(
396-
self._captured_dependencies
397-
)
398-
return self._enriched_config
399-
400-
def render(self) -> str:
401-
if self._rendered_sql is None:
402-
try:
403-
self._rendered_sql = self.jinja_env.from_string(
404-
self._enriched_config.all_sql
405-
).render()
406-
except UndefinedError as e:
407-
raise ConfigError(e.message)
408-
return self._rendered_sql
409-
410-
def _update_with_sql_config(self, config: BMC) -> BMC:
411-
def _extract_value(node: t.Any) -> t.Any:
412-
if not isinstance(node, nodes.Node):
413-
return node
414-
if isinstance(node, nodes.Const):
415-
return _extract_value(node.value)
416-
if isinstance(node, nodes.TemplateData):
417-
return _extract_value(node.data)
418-
if isinstance(node, nodes.List):
419-
return [_extract_value(val) for val in node.items]
420-
if isinstance(node, nodes.Dict):
421-
return {_extract_value(pair.key): _extract_value(pair.value) for pair in node.items}
422-
if isinstance(node, nodes.Tuple):
423-
return tuple(_extract_value(val) for val in node.items)
424-
425-
return self.jinja_env.from_string(nodes.Template([nodes.Output([node])])).render()
426-
427-
for call in self.jinja_env.parse(self._enriched_config.sql_embedded_config).find_all(
428-
nodes.Call
429-
):
430-
if not isinstance(call.node, nodes.Name) or call.node.name != "config":
431-
continue
432-
config = config.update_with(
433-
{kwarg.key: _extract_value(kwarg.value) for kwarg in call.kwargs}
434-
)
435-
436-
return config
437-
438-
def _ref(self, package_name: str, model_name: t.Optional[str] = None) -> BaseRelation:
439-
self._captured_dependencies.refs.add(package_name)
440-
return BaseRelation.create()
441-
442-
def _var(self, name: str, default: t.Optional[str] = None) -> t.Any:
443-
if default is None and name not in self.context.variables:
444-
raise ConfigError(
445-
f"Variable '{name}' was not found for model '{self.config.table_name}'."
446-
)
447-
self._captured_dependencies.variables.add(name)
448-
return self.context.variables.get(name, default)
449-
450-
def _source(self, source_name: str, table_name: str) -> BaseRelation:
451-
full_name = ".".join([source_name, table_name])
452-
self._captured_dependencies.sources.add(full_name)
453-
return BaseRelation.create()
454-
455-
class TrackingAdapter(ParsetimeAdapter):
456-
def __init__(self, outer_self: ModelSqlRenderer, *args: t.Any, **kwargs: t.Any):
457-
super().__init__(*args, **kwargs)
458-
self.outer_self = outer_self
459-
self.context = outer_self.context
460-
461-
def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable:
462-
macros = (
463-
self.context.jinja_macros.packages.get(package, {})
464-
if package is not None
465-
else self.context.jinja_macros.root_macros
466-
)
467-
for target_name in macros:
468-
if target_name.endswith(f"__{name}"):
469-
self.outer_self._captured_dependencies.macros.add(
470-
MacroReference(package=package, name=target_name)
471-
)
472-
return super().dispatch(name, package=package)

sqlmesh/dbt/builtin.py

+3-33
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@
22

33
import json
44
import os
5-
import sys
65
import typing as t
76
from ast import literal_eval
8-
from pathlib import Path
97

108
import agate
119
import jinja2
10+
from dbt import version
1211
from dbt.adapters.base import BaseRelation
1312
from dbt.contracts.relation import Policy
1413
from ruamel.yaml import YAMLError
1514

1615
from sqlmesh.core.engine_adapter import EngineAdapter
1716
from sqlmesh.dbt.adapter import ParsetimeAdapter, RuntimeAdapter
18-
from sqlmesh.dbt.context import DbtContext
19-
from sqlmesh.dbt.package import PackageLoader
2017
from sqlmesh.utils import AttributeDict, yaml
2118
from sqlmesh.utils.errors import ConfigError, MacroEvalError
2219
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReturnVal
@@ -250,30 +247,9 @@ def _try_literal_eval(value: str) -> t.Any:
250247
return value
251248

252249

253-
def _dbt_macro_registry() -> JinjaMacroRegistry:
254-
registry = JinjaMacroRegistry()
255-
256-
try:
257-
site_packages = next(
258-
p for p in sys.path if "site-packages" in p and Path(p, "dbt").exists()
259-
)
260-
except:
261-
return registry
262-
263-
for project_file in Path(site_packages).glob("dbt/include/*/dbt_project.yml"):
264-
if project_file.parent.stem == "starter_project":
265-
continue
266-
context = DbtContext(project_root=project_file.parent, jinja_macros=JinjaMacroRegistry())
267-
package = PackageLoader(context).load()
268-
registry.add_macros(package.macro_infos, package="dbt")
269-
270-
return registry
271-
272-
273-
DBT_MACRO_REGISTRY = _dbt_macro_registry()
274-
275250
BUILTIN_GLOBALS = {
276251
"api": Api(),
252+
"dbt_version": version.__version__,
277253
"env_var": env_var,
278254
"exceptions": Exceptions(),
279255
"flags": Flags(),
@@ -367,13 +343,7 @@ def create_builtin_globals(
367343
}
368344
)
369345

370-
builtin_globals.update(jinja_globals)
371-
if "dbt" not in builtin_globals:
372-
builtin_globals["dbt"] = DBT_MACRO_REGISTRY.build_environment(
373-
**builtin_globals
374-
).globals.get("dbt", {})
375-
376-
return builtin_globals
346+
return {**builtin_globals, **jinja_globals}
377347

378348

379349
def create_builtin_filters() -> t.Dict[str, t.Callable]:

0 commit comments

Comments
 (0)