Refactor file handling
kdeldycke committed Nov 3, 2024
1 parent 0075ef4 commit 828af15
Showing 2 changed files with 72 additions and 73 deletions.
107 changes: 49 additions & 58 deletions meta_package_manager/cli.py
@@ -19,13 +19,12 @@
 import logging
 import sys
 from collections import Counter, namedtuple
-from contextlib import contextmanager
 from datetime import datetime
 from functools import partial
 from io import TextIOWrapper
 from operator import attrgetter
 from pathlib import Path
-from typing import Iterable
+from typing import IO, Iterable
 from unittest.mock import patch

 import tomli_w
@@ -49,10 +48,7 @@
 )
 from click_extra.colorize import KO, OK, highlight
 from click_extra.colorize import default_theme as theme
-from extra_platforms import (
-    is_windows,
-    reduce,
-)
+from extra_platforms import is_windows, reduce

 if sys.version_info >= (3, 11):
     import tomllib
@@ -108,16 +104,11 @@ def is_stdout(filepath: Path) -> bool:
     return str(filepath) == "-"


-@contextmanager
-def file_writer(filepath, mode: str = "w"):
-    """A context-aware file writer which default to stdout if no path is
-    provided."""
+def prep_path(filepath: Path) -> IO | None:
+    """Prepare the output file parameter for Click's echo function."""
     if is_stdout(filepath):
-        yield sys.stdout
-    else:
-        writer = filepath.open(mode)
-        yield writer
-        writer.close()
+        return None
+    return filepath.open("w", encoding="UTF-8")


 def update_manager_selection(
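
For context, the new prep_path() helper is built around the file= parameter of Click's echo(), which falls back to stdout when passed None. A minimal, self-contained sketch of that interplay, reusing the two helpers from the diff (the dump() wrapper and file names are hypothetical, assuming click is installed):

from __future__ import annotations

from pathlib import Path
from typing import IO

import click


def is_stdout(filepath: Path) -> bool:
    # By convention, "-" means "write to standard output".
    return str(filepath) == "-"


def prep_path(filepath: Path) -> IO | None:
    # None makes click.echo() fall back to stdout; otherwise hand it an open file.
    if is_stdout(filepath):
        return None
    return filepath.open("w", encoding="UTF-8")


def dump(filepath: Path, content: str) -> None:
    # click.echo() writes to the given file object, or to stdout when file is None.
    click.echo(content, file=prep_path(filepath))


dump(Path("-"), "printed to the terminal")
dump(Path("report.txt"), "written to report.txt")
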
@@ -170,9 +161,9 @@ def update_manager_selection(
         # instantiation, we have to reverse the process to get our value.
         # Example: --apt-mint => apt_mint => apt-mint
         manager_id = param.name.removeprefix("no_").replace("_", "-")
-        assert manager_id == value, (
-            f"unrecognized single manager selector {param.name!r}"
-        )
+        assert (
+            manager_id == value
+        ), f"unrecognized single manager selector {param.name!r}"
         if param.name.startswith("no_"):
             assert isinstance(value, str)
             to_remove.add(value)
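
The comment in this hunk describes how a Click parameter name is mapped back to the manager ID it was derived from. A quick illustration of that round-trip, reusing the apt-mint example from the comment (the deselection flag spelling is inferred from the "no_" prefix handling above):

# A flag such as --apt-mint is stored by Click as "apt_mint", and its
# deselection variant as "no_apt_mint"; strip the negation prefix and
# restore the dashes to recover the manager ID.
for param_name in ("apt_mint", "no_apt_mint"):
    manager_id = param_name.removeprefix("no_").replace("_", "-")
    assert manager_id == "apt-mint"
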
@@ -1354,50 +1345,51 @@ def backup(ctx, overwrite, merge, update_version, toml_path):
     if merge or update_version:
         installed_data = tomllib.loads(toml_path.read_text(**encoding_args))

-    with file_writer(toml_path) as f:
-        # Leave some metadata as comment.
-        f.write(f"# Generated by mpm v{__version__}.\n")
-        f.write(f"# Timestamp: {datetime.now().isoformat()}.\n")
+    # Leave some metadata as comment.
+    content = (
+        f"# Generated by mpm v{__version__}.\n"
+        f"# Timestamp: {datetime.now().isoformat()}.\n\n"
+    )
+    # Create one section for each manager.
+    for manager in ctx.obj.selected_managers(implements_operation=Operations.installed):
+        logging.info(f"Dumping packages from {theme.invoked_command(manager.id)}...")

-        # Create one section for each manager.
-        for manager in ctx.obj.selected_managers(
-            implements_operation=Operations.installed,
-        ):
-            logging.info(
-                f"Dumping packages from {theme.invoked_command(manager.id)}..."
-            )
+        packages = tuple(packages_asdict(manager.installed, fields))

-            packages = tuple(packages_asdict(manager.installed, fields))
-
-            for pkg in packages:
-                # Only update version in that mode if the package is already referenced
-                # into original TOML file.
-                if update_version:
-                    if pkg["id"] in installed_data.get(manager.id, {}):
-                        installed_data[manager.id][pkg["id"]] = str(
-                            pkg["installed_version"],
-                        )
-                # Insert installed package in data structure for standard dump and merge
-                # mode.
-                else:
-                    installed_data.setdefault(manager.id, {})[pkg["id"]] = str(
+        for pkg in packages:
+            # Only update version in that mode if the package is already referenced
+            # into original TOML file.
+            if update_version:
+                if pkg["id"] in installed_data.get(manager.id, {}):
+                    installed_data[manager.id][pkg["id"]] = str(
                         pkg["installed_version"],
                     )
-
-            # Re-sort package list.
-            if installed_data.get(manager.id):
-                installed_data[manager.id] = dict(
-                    sorted(
-                        installed_data[manager.id].items(),
-                        # Case-insensitive lexicographical sort on keys.
-                        key=lambda i: (i[0].lower(), i[0]),
-                    ),
+            # Insert installed package in data structure for standard dump and merge
+            # mode.
+            else:
+                installed_data.setdefault(manager.id, {})[pkg["id"]] = str(
+                    pkg["installed_version"],
                 )

-            # Write each section separated by an empty line for readability.
-            for manager_id, packages in installed_data.items():
-                f.write("\n")
-                f.write(tomli_w.dumps({manager_id: packages}))
+        # Re-sort package list.
+        if installed_data.get(manager.id):
+            installed_data[manager.id] = dict(
+                sorted(
+                    installed_data[manager.id].items(),
+                    # Case-insensitive lexicographical sort on keys.
+                    key=lambda i: (i[0].lower(), i[0]),
+                ),
+            )
+
+    # Write each section separated by an empty line for readability.
+    content += "\n".join(
+        (
+            tomli_w.dumps({manager_id: packages})
+            for manager_id, packages in installed_data.items()
+        )
+    )
+
+    echo(content, file=prep_path(toml_path))

     if ctx.obj.stats:
         print_stats(Counter({k: len(v) for k, v in installed_data.items()}))
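
The re-sorting step in this hunk uses a two-part key to get a case-insensitive yet deterministic order. A self-contained illustration of the ordering it produces (package IDs and versions are made up):

# Sort on the lowercased key first, then on the original key to break ties
# between IDs that differ only by case.
packages = {"Zlib": "1.3", "openssl": "3.3", "OpenSSL": "1.1", "abc": "0.1"}
ordered = dict(
    sorted(
        packages.items(),
        key=lambda i: (i[0].lower(), i[0]),
    )
)
assert list(ordered) == ["abc", "OpenSSL", "openssl", "Zlib"]
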
@@ -1539,5 +1531,4 @@ def sbom(ctx, spdx, format, overwrite, export_path):
         for package in manager.installed:
             sbom.add_package(manager, package)

-    with file_writer(export_path) as f:
-        sbom.export(f)
+    echo(sbom.export(), file=prep_path(export_path))
38 changes: 23 additions & 15 deletions meta_package_manager/sbom.py
@@ -16,12 +16,13 @@

 from __future__ import annotations

+import io
 import logging
 import re
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import IO, Any
+from typing import Any

 from boltons.ecoutils import get_profile
 from cyclonedx.model import (
@@ -100,6 +101,9 @@ def __init__(self, export_format: ExportFormat = ExportFormat.JSON) -> None:
     def autodetect_export_format(file_path: Path) -> ExportFormat | None:
         """Better version of ``spdx_tools.spdx.formats.file_name_to_format`` which is
         based on ``Path`` objects and is case-insensitive.
+
+        .. todo:
+            Contribute generic autodetection method to Click Extra?
         """
         suffixes = tuple(s.lower() for s in file_path.suffixes[-2:])
         export_format = None
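
The docstring above describes a case-insensitive, Path-based take on spdx_tools' file_name_to_format(). A rough sketch of what such suffix matching can look like; the guess_format() helper and its return values are illustrative only, not mpm's actual ExportFormat mapping:

from __future__ import annotations

from pathlib import Path


def guess_format(file_path: Path) -> str | None:
    # Look at the last one or two suffixes, lowercased, so "export.JSON" and
    # "export.rdf.xml" are both recognized.
    suffixes = tuple(s.lower() for s in file_path.suffixes[-2:])
    if suffixes[-1:] == (".json",):
        return "JSON"
    if suffixes[-1:] == (".xml",) or suffixes == (".rdf", ".xml"):
        return "RDF_XML"
    return None


assert guess_format(Path("export.JSON")) == "JSON"
assert guess_format(Path("export.rdf.xml")) == "RDF_XML"
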
@@ -141,14 +145,16 @@ def init_doc(self) -> None:
         """
         profile = get_profile()
         system_id = self.normalize_spdx_id(
-            "-".join((
-                CURRENT_OS_LABEL,
-                profile["linux_dist_name"],
-                profile["linux_dist_version"],
-                profile["uname"]["system"],
-                profile["uname"]["release"],
-                profile["uname"]["machine"],
-            ))
+            "-".join(
+                (
+                    CURRENT_OS_LABEL,
+                    profile["linux_dist_name"],
+                    profile["linux_dist_version"],
+                    profile["uname"]["system"],
+                    profile["uname"]["release"],
+                    profile["uname"]["machine"],
+                )
+            )
         )

         self.document = Document(
@@ -212,10 +218,12 @@ def add_package(self, manager: PackageManager, package: Package) -> None:
             Relationship(self.DOC_ID, RelationshipType.DESCRIBES, package_docid)
         )

-    def export(self, stream: IO):
+    def export(self) -> str:
         """Similar to ``spdx_tools.spdx.writer.write_anything.write_file`` but write
         directly to provided stream instead of file path.
         """
+        stream = io.StringIO()
+
         writer: Any
         if self.export_format == ExportFormat.JSON:
             writer = json_writer
@@ -227,9 +235,8 @@ def export(self, stream: IO):
             writer = tagvalue_writer
         elif self.export_format == ExportFormat.RDF_XML:
             writer = rdf_writer
-            # RDF writer expects a binary-mode IO stream but the one provided is
-            # string-based.
-            stream = stream.buffer  # type: ignore[attr-defined]
+            # RDF writer expects a binary-mode IO stream.
+            stream = io.BytesIO()
         else:
             raise ValueError(f"{self.export_format} not supported.")

@@ -242,6 +249,7 @@

         logging.debug(f"Export with {writer.__name__}")
         writer.write_document_to_stream(self.document, stream, validate=False)
+        return stream.getvalue()


 class CycloneDX(SBOM):
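
The reworked export() above no longer writes to a caller-provided stream: it lets the writer fill an in-memory buffer and returns the captured content (the RDF branch swaps in io.BytesIO, so its getvalue() yields bytes rather than str). A minimal sketch of that capture pattern, with json.dump() standing in for the spdx_tools writers:

import io
import json


def export_to_string(document: dict) -> str:
    # Let a stream-oriented writer do its work against an in-memory buffer...
    stream = io.StringIO()
    json.dump(document, stream, indent=2)
    # ...then hand the captured text back as a plain string.
    return stream.getvalue()


print(export_to_string({"spdxVersion": "SPDX-2.3", "name": "example-sbom"}))
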
@@ -367,7 +375,7 @@ def add_package(self, manager: PackageManager, package: Package) -> None:
             [data],
         )

-    def export(self, stream: IO):
+    def export(self) -> str:
         validator: BaseSchemabasedValidator
         if self.export_format == ExportFormat.JSON:
             content = JsonV1Dot5(self.document).output_as_string(indent=2)
@@ -389,4 +397,4 @@ def export(self, stream: IO):
             logging.debug(content)
             raise ValueError(f"Document is not valid. Errors: {errors}")

-        stream.write(content)
+        return content
