diff --git a/examples/somersaultecu.py b/examples/somersaultecu.py index 99ba9083..f07daade 100755 --- a/examples/somersaultecu.py +++ b/examples/somersaultecu.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import pathlib from enum import IntEnum +from io import BytesIO from itertools import chain from typing import Any, Dict from xml.etree import ElementTree @@ -2467,10 +2468,11 @@ class SomersaultSID(IntEnum): database = Database() database._diag_layer_containers = NamedItemList([somersault_dlc]) database._comparam_subsets = NamedItemList(comparam_subsets) -database.add_auxiliary_file("jobs.py", b""" +database.add_auxiliary_file("jobs.py", + BytesIO(b""" def compulsory_program(): print("Hello, World") -""") +""")) # Create ID mapping and resolve references database.refresh() diff --git a/odxtools/database.py b/odxtools/database.py index 521a5fb8..be6aedd0 100644 --- a/odxtools/database.py +++ b/odxtools/database.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT from itertools import chain from pathlib import Path -from typing import List, Optional, OrderedDict +from typing import IO, List, Optional, OrderedDict from xml.etree import ElementTree from zipfile import ZipFile @@ -27,7 +27,7 @@ def __init__(self, pdx_zip: Optional[ZipFile] = None, odx_d_file_name: Optional[str] = None) -> None: self.model_version: Optional[Version] = None - self.auxiliary_files: OrderedDict[str, bytes] = OrderedDict() + self.auxiliary_files: OrderedDict[str, IO[bytes]] = OrderedDict() # create an empty database object self._diag_layer_containers = NamedItemList[DiagLayerContainer]() @@ -47,16 +47,18 @@ def add_pdx_file(self, pdx_file_name: str) -> None: root = ElementTree.parse(pdx_zip.open(zip_member)).getroot() self._process_xml_tree(root) elif p.name.lower() != "index.xml": - self.add_auxiliary_file(zip_member, pdx_zip.read(zip_member)) + self.add_auxiliary_file(zip_member, pdx_zip.open(zip_member)) def add_odx_file(self, odx_file_name: str) -> None: self._process_xml_tree(ElementTree.parse(odx_file_name).getroot()) - def add_auxiliary_file(self, aux_file_name: str, aux_file_data: Optional[bytes] = None) -> None: - if aux_file_data is None: - aux_file_data = open(aux_file_name, "rb").read() + def add_auxiliary_file(self, + aux_file_name: str, + aux_file_obj: Optional[IO[bytes]] = None) -> None: + if aux_file_obj is None: + aux_file_obj = open(aux_file_name, "rb") - self.auxiliary_files[aux_file_name] = aux_file_data + self.auxiliary_files[aux_file_name] = aux_file_obj def _process_xml_tree(self, root: ElementTree.Element) -> None: dlcs: List[DiagLayerContainer] = [] diff --git a/odxtools/loadfile.py b/odxtools/loadfile.py index cad54af8..6a7884a5 100644 --- a/odxtools/loadfile.py +++ b/odxtools/loadfile.py @@ -58,7 +58,7 @@ def load_directory(dir_name: Union[str, Path]) -> Database: elif p.suffix.lower().startswith(".odx"): db.add_odx_file(str(p)) elif p.name.lower() != "index.xml": - db.add_auxiliary_file(p.name, open(str(p), "rb").read()) + db.add_auxiliary_file(p.name, open(str(p), "rb")) db.refresh() return db diff --git a/odxtools/progcode.py b/odxtools/progcode.py index 191b4416..1c5ea69b 100644 --- a/odxtools/progcode.py +++ b/odxtools/progcode.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: MIT from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from xml.etree import ElementTree -from .exceptions import odxrequire +from .exceptions import odxraise, odxrequire from .odxlink import OdxDocFragment, OdxLinkDatabase, OdxLinkId, OdxLinkRef if TYPE_CHECKING: @@ -58,7 +58,13 @@ def _resolve_odxlinks(self, odxlinks: OdxLinkDatabase) -> None: def _resolve_snrefs(self, diag_layer: "DiagLayer") -> None: db = diag_layer._database - self._code = odxrequire( - db.auxiliary_files.get(self.code_file), - f"Reference to auxiliary file '{self.code_file}' " - f"could not be resolved") + aux_file = db.auxiliary_files.get(self.code_file) + + if aux_file is None: + odxraise(f"Reference to auxiliary file '{self.code_file}' " + f"could not be resolved") + self._code: bytes = cast(bytes, None) + return + + self._code = aux_file.read() + aux_file.seek(0) diff --git a/odxtools/writepdxfile.py b/odxtools/writepdxfile.py index a9e7612f..7ad05869 100644 --- a/odxtools/writepdxfile.py +++ b/odxtools/writepdxfile.py @@ -118,7 +118,7 @@ def write_pdx_file( out_file.write(open(in_file_name, "rb").read()) # write the auxiliary files - for output_file_name, data in database.auxiliary_files.items(): + for output_file_name, data_file in database.auxiliary_files.items(): file_cdate = datetime.datetime.fromtimestamp(time.time()) creation_date = file_cdate.strftime("%Y-%m-%dT%H:%M:%S") @@ -137,7 +137,7 @@ def write_pdx_file( zf_name = os.path.basename(output_file_name) with zf.open(zf_name, "w") as out_file: file_index.append((zf_name, creation_date, mime_type)) - out_file.write(data) + out_file.write(data_file.read()) jinja_env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir)) jinja_env.globals["hasattr"] = hasattr diff --git a/tests/test_singleecujob.py b/tests/test_singleecujob.py index ad30a20c..d193607c 100644 --- a/tests/test_singleecujob.py +++ b/tests/test_singleecujob.py @@ -2,6 +2,7 @@ import inspect import os import unittest +from io import BytesIO from typing import NamedTuple, cast from xml.etree import ElementTree @@ -468,7 +469,7 @@ def test_resolve_odxlinks(self) -> None: db = Database() db.add_auxiliary_file("abc.jar", - b"this is supposed to be a JAR archive, but it isn't (HARR)") + BytesIO(b"this is supposed to be a JAR archive, but it isn't (HARR)")) dl._resolve_odxlinks(odxlinks) dl._finalize_init(db, odxlinks)