diff --git a/src/rybak/__init__.py b/src/rybak/__init__.py index 686f73c..a4c7297 100644 --- a/src/rybak/__init__.py +++ b/src/rybak/__init__.py @@ -4,49 +4,38 @@ 'render', ] -from pathlib import Path -from typing import Any, Iterable, Mapping, Optional, Type, Union +from pathlib import Path, PurePath +from typing import Iterable, Union from ._types import RenderError, TemplateData from .adapter import RendererAdapter -from .pycompat import Traversable from .tree_renderer import RenderContext, TreeRenderer def render( - template_root: Traversable, target_root: Path, - adapter: Union[Type[RendererAdapter], RendererAdapter], + adapter: RendererAdapter, data: TemplateData, *, - renderer_args: Optional[Mapping[str, Any]] = None, - excluded: Iterable[Path] = (), + excluded: Union[Iterable[Path], Iterable[str]] = (), remove_suffixes: Iterable[str] = (), ) -> None: """Render a directory-tree from a template and a data dictionary - :param template_root: root template directory (filesystem or importlib resource) :param target_root: render target root directory (filesystem) - :param adapter: template engine adapter (jinja, mako, tornado) or its type + :param adapter: template engine adapter (jinja, mako) :param data: template data - :param renderer_args: parameters for template engine adapter, when just the adapter class is passed :param excluded: paths within the template root directory, which are not templates :param remove_suffixes: filename suffixes to be removed when rendering file names, in `.suffix` format """ - actual_renderer = ( - adapter - if isinstance(adapter, RendererAdapter) - else adapter(template_root=template_root, **renderer_args if renderer_args else {}) - ) - + exclude_paths = {Path(path) for path in excluded} TreeRenderer( RenderContext( - template_root=template_root, target_root=target_root, - adapter=actual_renderer, - excluded=excluded, + adapter=adapter, + excluded=exclude_paths, remove_suffixes=remove_suffixes, ), - Path(), + PurePath(), Path(), ).render(data) diff --git a/src/rybak/adapter.py b/src/rybak/adapter.py index 85b67be..066dc26 100644 --- a/src/rybak/adapter.py +++ b/src/rybak/adapter.py @@ -3,6 +3,7 @@ from typing import Any from ._types import LoopOverFn, TemplateData +from .pycompat import Traversable class RendererAdapter(abc.ABC): @@ -17,3 +18,8 @@ def render_str(self, template: str, data: TemplateData, loop_over: LoopOverFn) - @abc.abstractmethod def render_file(self, template_path: str, target_file: Path, data: TemplateData) -> None: pass + + @property + @abc.abstractmethod + def template_root(self) -> Traversable: + pass diff --git a/src/rybak/jinja.py b/src/rybak/jinja.py index 316ba6a..6fdefba 100644 --- a/src/rybak/jinja.py +++ b/src/rybak/jinja.py @@ -1,22 +1,38 @@ from pathlib import Path -from typing import Any, Optional +from typing import Optional, cast import jinja2 from ._types import LoopOverFn, RenderError, TemplateData from .adapter import RendererAdapter +from .pycompat import Traversable, files class JinjaAdapter(RendererAdapter): """Adapter for Jinja engine. - Unless you pass your own jinja.Environment instance, the default for keep_trailing_newline is True.""" + Unless you pass your own jinja.Environment instance, the default for keep_trailing_newline is True, + and the default loader is FileSystemLoader.""" - def __init__(self, environment: Optional[jinja2.Environment] = None, **env_kwargs: Any) -> None: - keep_trailing_newline = env_kwargs.pop('keep_trailing_newline', True) - self._env = environment or jinja2.Environment( - keep_trailing_newline=keep_trailing_newline, - **env_kwargs, - ) + def __init__( + self, + environment: Optional[jinja2.Environment] = None, + loader: Optional[jinja2.BaseLoader] = None, + keep_trailing_newline: Optional[bool] = True, + ) -> None: + """Create adapter for Jinja Environment. Only either `loader` or `environment` is accepted.""" + + if environment: + if loader: + raise ValueError('Set loader in the Jinja environment') + elif not loader: + raise ValueError('Either environment or loader is required') + + if not environment: + self._env = jinja2.Environment(loader=loader, keep_trailing_newline=keep_trailing_newline) + else: + if keep_trailing_newline is not None: + self._env = environment.overlay() + self._env.keep_trailing_newline = self._keep_trailing_newline def render_str(self, template: str, data: TemplateData, loop_over: Optional[LoopOverFn] = None) -> str: env = self._env.overlay() @@ -35,3 +51,16 @@ def render_file(self, template_path: str, target_file: Path, data: TemplateData) except (jinja2.TemplateError, ValueError) as e: raise RenderError from e target_file.write_text(text) + + @property + def template_root(self) -> Traversable: + loader = self._env.loader + if isinstance(loader, jinja2.FileSystemLoader): + path = loader.searchpath + if len(path) != 1: + raise ValueError('Template root path must be a single path') + return Path(path[0]) + elif isinstance(loader, jinja2.PackageLoader): + return cast(Traversable, files(loader.package_name) / loader.package_path) + else: + raise TypeError(type(loader)) diff --git a/src/rybak/mako.py b/src/rybak/mako.py index 330d81a..4cbdbf5 100644 --- a/src/rybak/mako.py +++ b/src/rybak/mako.py @@ -6,12 +6,13 @@ from pathlib import Path from typing import Optional -import mako.exceptions +import mako.exceptions # type: ignore[import-untyped] import mako.lookup import mako.template from ._types import LoopOverFn, RenderError, TemplateData from .adapter import RendererAdapter +from .pycompat import Traversable class MakoAdapter(RendererAdapter): @@ -33,6 +34,13 @@ def render_file(self, template_path: str, target_file: Path, data: TemplateData) raise RenderError from e target_file.write_text(text) + @property + def template_root(self) -> Traversable: + paths = self._loader.directories + if len(paths) != 1: + raise ValueError('Template root path must be a single path') + return Path(paths[0]) + @functools.lru_cache(maxsize=10) def str_template(text: str) -> mako.template.Template: diff --git a/src/rybak/pycompat.py b/src/rybak/pycompat.py index 9f47b0b..2751b9a 100644 --- a/src/rybak/pycompat.py +++ b/src/rybak/pycompat.py @@ -1,13 +1,16 @@ __all__ = [ 'Traversable', 'TypeAlias', + 'files', ] import sys if sys.version_info >= (3, 12): + from importlib.resources import files from importlib.resources.abc import Traversable else: + from importlib_resources import files # type: ignore[import-not-found] from importlib_resources.abc import Traversable # type: ignore[import-not-found] if sys.version_info >= (3, 10): diff --git a/src/rybak/tree_renderer.py b/src/rybak/tree_renderer.py index 0ef8726..a22b745 100644 --- a/src/rybak/tree_renderer.py +++ b/src/rybak/tree_renderer.py @@ -2,7 +2,7 @@ import logging import os.path from functools import cached_property -from pathlib import Path +from pathlib import Path, PurePath from typing import Iterable, NoReturn from ._types import LoopOverFn, RenderFn, TemplateData @@ -32,19 +32,14 @@ def loop_over(items: Iterable) -> NoReturn: @dataclasses.dataclass class RenderContext: - template_root: Traversable target_root: Path adapter: RendererAdapter - excluded: Iterable[Path] = () - remove_suffixes: Iterable[str] = () - - def __post_init__(self): - if not self.template_root.is_dir(): - raise ValueError('template_root must exist and be a directory', self.template_root) + excluded: Iterable[PurePath] + remove_suffixes: Iterable[str] class TreeRenderer: - def __init__(self, context: RenderContext, template_path: Path, target_path: Path) -> None: + def __init__(self, context: RenderContext, template_path: PurePath, target_path: Path) -> None: self._context = context self._template_path = template_path self._target_path = target_path @@ -61,7 +56,7 @@ def _render(self, file_name: str, data: TemplateData) -> None: logger.debug('Excluded %s', rel_path) return - path = self._context.template_root / str(rel_path) + path = self._context.adapter.template_root / rel_path if path.is_dir(): render_single = self._render_dir else: @@ -130,7 +125,7 @@ def _with_subdir(self, template_name: str, target_name: str) -> 'TreeRenderer': @cached_property def _full_template_path(self) -> Traversable: - return self._context.template_root / str(self._template_path) + return self._context.adapter.template_root / str(self._template_path) @cached_property def _full_target_path(self) -> Path: diff --git a/tests/test_render.py b/tests/test_render.py index 5fa2337..2f48cdf 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -1,9 +1,11 @@ from itertools import product from pathlib import Path -from typing import Any, Iterable, Mapping, NamedTuple, Optional +from typing import Any, Callable, Iterable, Mapping, NamedTuple, Optional +import jinja2 import pytest from rybak import RenderError, render +from rybak.adapter import RendererAdapter from rybak.jinja import JinjaAdapter from rybak.mako import MakoAdapter @@ -69,7 +71,7 @@ class TestData(NamedTuple): ] adapters = { - 'jinja': JinjaAdapter, + 'jinja': lambda template_root: JinjaAdapter(loader=jinja2.FileSystemLoader(template_root)), 'mako': MakoAdapter, } @@ -79,21 +81,28 @@ class TestData(NamedTuple): } adapter_test_data = [ - (adapter, *param_set, exclusions[adapter]) for adapter, param_set in product(adapters.keys(), jinja_test_data) + (*adapter, *param_set, exclusions[adapter[0]]) for adapter, param_set in product(adapters.items(), jinja_test_data) ] -@pytest.mark.parametrize('renderer,test_name,data,error,excluded', adapter_test_data) -def test_render(renderer: str, test_name: str, data: Mapping, error: bool, excluded: Iterable, tmp_path: Path) -> None: +@pytest.mark.parametrize('adapter_name,adapter,test_name,data,error,excluded', adapter_test_data) +def test_render( + adapter_name: str, + adapter: Callable[[Path], RendererAdapter], + test_name: str, + data: Mapping, + error: bool, + excluded: Iterable, + tmp_path: Path, +) -> None: root = Path(__file__).parent / 'test_render' - target_path = tmp_path / f'{renderer}_{test_name}' + target_path = tmp_path / f'{adapter_name}_{test_name}' target_path.mkdir() def fn(): render( - root / 'templates' / renderer / test_name, target_path, - adapters[renderer], + adapter(root / 'templates' / adapter_name / test_name), data, excluded=[Path(item) for item in excluded] + [Path('__pycache__')], remove_suffixes=['.jinja', '.mako'],