diff --git a/src/tpmstream/io/auto/marshal.py b/src/tpmstream/io/auto/marshal.py index cb80900..783a9af 100644 --- a/src/tpmstream/io/auto/marshal.py +++ b/src/tpmstream/io/auto/marshal.py @@ -1,12 +1,12 @@ import binascii +import re -from ...spec.structures.interface_types import TPMI_ST_COMMAND_TAG from ..binary import Binary from ..hex import Hex from ..pcapng import Pcapng -def detect_format_and_yield_buffer(buffer): +def detect_format_and_yield_buffer(buffer, strict=True): """First yield is format. Rest is buffer bytewise.""" buffer_iter = iter(buffer) @@ -19,26 +19,29 @@ def detect_format_and_yield_buffer(buffer): ) from e if look_ahead == b"\x0a\x0d": - # TODO use enum or some sort of canonical mapping? yield "pcapng" - elif look_ahead == b"80": - yield "hex" - elif look_ahead in ( - tag.to_bytes() for tag in TPMI_ST_COMMAND_TAG._valid_values._values - ): - yield "binary" else: - raise IOError( - f"Unknown detect input format: magic number is {binascii.hexlify(look_ahead).decode()}" - ) + if re.match(b"[0-9a-fA-F]{2}", look_ahead): + # look ahead is valid hex, so it's MAYBE hex + if not strict: + yield "hex" + else: + raise IOError( + f"Ambiguous input format: magic number is {binascii.hexlify(look_ahead).decode()}. Could be binary or hex." + ) + else: + # not valid hex, so it must be binary + yield "binary" yield from look_ahead yield from buffer_iter -def marshal(tpm_type, buffer, root_path=None, command_code=None, **kwargs): +def marshal( + tpm_type, buffer, root_path=None, command_code=None, strict=False, **kwargs +): """Generator. Take iterable which yields single bytes. Yield MarshalEvents. Be smart about format.""" - format_buffer_iter = detect_format_and_yield_buffer(buffer) + format_buffer_iter = detect_format_and_yield_buffer(buffer, strict=strict) format = next(format_buffer_iter)