From a37e72fc01ece2bfe6b619d09318196844b2fe6e Mon Sep 17 00:00:00 2001 From: David Randall Stokes Date: Wed, 10 May 2023 19:47:58 -0400 Subject: [PATCH] Converted `children` from a `dict` to a `set` of IDs (for future development). Also: * Updated tests to handle `children` change * Added `util.flatiter()` function (was planned for later, but needed for tests) * Updated tool tests to run only in GitHub Actions (does not work locally if package not installed) --- ebmlite/core.py | 139 +++++++++++++++++++----------------------- ebmlite/util.py | 33 +++++++++- tests/test_general.py | 53 +++++----------- tests/test_tools.py | 28 +++++++++ 4 files changed, 137 insertions(+), 116 deletions(-) diff --git a/ebmlite/core.py b/ebmlite/core.py index ae605ed..453c4a3 100644 --- a/ebmlite/core.py +++ b/ebmlite/core.py @@ -154,10 +154,10 @@ def __init__(self, stream=None, offset=0, size=0, payloadOffset=0): elements should be created when a `Document` is loaded, rather than instantiated explicitly. - @keyword stream: A file-like object containing EBML data. - @keyword offset: The element's starting location in the file. - @keyword size: The size of the whole element. - @keyword payloadOffset: The starting location of the element's + :param stream: A file-like object containing EBML data. + :param offset: The element's starting location in the file. + :param size: The size of the whole element. + :param payloadOffset: The starting location of the element's payload (i.e. immediately after the element's header). """ self.stream = stream @@ -235,13 +235,13 @@ def encodePayload(cls, data, length=None): def encode(cls, value, length=None, lengthSize=None, infinite=False): """ Encode an EBML element. - @param value: The value to encode, or a list of values to encode. + :param value: The value to encode, or a list of values to encode. If a list is provided, each item will be encoded as its own element. - @keyword length: An explicit length for the encoded data, + :param length: An explicit length for the encoded data, overriding the variable length encoding. For producing byte-aligned structures. - @keyword lengthSize: An explicit length for the encoded element + :param lengthSize: An explicit length for the encoded element size, overriding the variable length encoding. @return: A bytearray containing the encoded EBML data. """ @@ -479,15 +479,15 @@ def __init__(self, stream=None, offset=0, size=0, payloadOffset=0, eid=None, most cases, elements should be created when a `Document` is loaded, rather than instantiated explicitly. - @keyword stream: A file-like object containing EBML data. - @keyword offset: The element's starting location in the file. - @keyword size: The size of the whole element. - @keyword payloadOffset: The starting location of the element's + :param stream: A file-like object containing EBML data. + :param offset: The element's starting location in the file. + :param size: The size of the whole element. + :param payloadOffset: The starting location of the element's payload (i.e. immediately after the element's header). - @keyword id: The unknown element's ID. Unlike 'normal' elements, + :param id: The unknown element's ID. Unlike 'normal' elements, in which ID is a class attribute, each UnknownElement instance explicitly defines this. - @keyword schema: The schema used to load the element. Specified + :param schema: The schema used to load the element. Specified explicitly because `UnknownElement`s are not part of any schema. """ @@ -534,12 +534,12 @@ def parseElement(self, stream, nocache=False): object, and then return it and the offset of the next element (this element's position + size). - @param stream: The source file-like stream. - @keyword nocache: If `True`, the parsed element's `precache` + :param stream: The source file-like stream. + :param nocache: If `True`, the parsed element's `precache` attribute is ignored, and the element's value will not be cached. For faster iteration when the element value doesn't matter (e.g. counting child elements). - @return: The parsed element and the offset of the next element + :return: The parsed element and the offset of the next element (i.e. the end of the parsed element). """ offset = stream.tell() @@ -569,14 +569,8 @@ def _isValidChild(cls, elId): if not cls.children: return False - try: - return elId in cls._childIds - except AttributeError: - # The set of valid child IDs hasn't been created yet. - cls._childIds = set(cls.children) - if cls.schema is not None: - cls._childIds.update(cls.schema.globals) - return elId in cls._childIds + return elId in cls.children or elId in cls.schema.globals + @property def size(self): @@ -708,14 +702,14 @@ def encodePayload(cls, data, length=None): def encode(cls, data, length=None, lengthSize=None, infinite=False): """ Encode an EBML master element. - @param data: The data to encode, provided as a dictionary keyed by + :param data: The data to encode, provided as a dictionary keyed by element name, a list of two-item name/value tuples, or a list of either. Note: individual items in a list of name/value pairs *must* be tuples! - @keyword infinite: If `True`, the element will be written with an + :param infinite: If `True`, the element will be written with an undefined size. When parsed, its end will be determined by the occurrence of an invalid child element (or end-of-file). - @return: A bytearray containing the encoded EBML binary. + :return: A bytearray containing the encoded EBML binary. """ # TODO: Use 'length' to automatically generate `Void` element? if isinstance(data, list) and len(data) > 0 and isinstance(data[0], list): @@ -769,13 +763,13 @@ def __init__(self, stream, name=None, size=None, headers=True): In most cases, `Schema.load()` should be used instead of explicitly instantiating a `Document`. - @param stream: A stream object (e.g. a file) from which to read + :param stream: A stream object (e.g. a file) from which to read the EBML content. - @keyword name: The name of the document. Defaults to the filename + :param name: The name of the document. Defaults to the filename (if applicable). - @keyword size: The size of the document, in bytes. Use if the + :param size: The size of the document, in bytes. Use if the stream is neither a file or a `BytesIO` object. - @keyword headers: If `False`, the file's ``EBML`` header element + :param headers: If `False`, the file's ``EBML`` header element (if present) will not appear as a root element in the document. The contents of the ``EBML`` element will always be read, regardless, and stored in the Document's `info` attribute. @@ -941,7 +935,7 @@ def _createHeaders(cls): """ Create the default EBML 'header' elements for a Document, using the default values in the schema. - @return: A dictionary containing a single key (``EBML``) with a + :return: A dictionary containing a single key (``EBML``) with a dictionary as its value. The child dictionary contains element names and values. """ @@ -965,11 +959,11 @@ def _createHeaders(cls): def encode(cls, stream, data, headers=False, **kwargs): """ Encode an EBML document. - @param value: The data to encode, provided as a dictionary keyed + :param data: The data to encode, provided as a dictionary keyed by element name, or a list of two-item name/value tuples. Note: individual items in a list of name/value pairs *must* be tuples! - @return: A bytearray containing the encoded EBML binary. + :return: A bytearray containing the encoded EBML binary. """ if headers is True: stream.write(cls.encodePayload(cls._createHeaders())) @@ -1049,9 +1043,9 @@ class Schema(object): def __init__(self, source, name=None): """ Constructor. Creates a new Schema from a schema description XML. - @param source: The Schema's source, either a string with the full + :param source: The Schema's source, either a string with the full path and name of the schema XML file, or a file-like stream. - @keyword name: The schema's name. Defaults to the document type + :param name: The schema's name. Defaults to the document type element's default value (if defined) or the base file name. """ self.source = source @@ -1067,7 +1061,7 @@ def __init__(self, source, name=None): self.elementInfo = {} # Raw element schema attributes, keyed by ID self.globals = {} # Elements valid for any parent, by ID - self.children = {} # Valid root elements, by ID + self.children = set() # Valid root elements, by ID # Parse, using the correct method for the schema format. schema = ET.parse(source) @@ -1158,7 +1152,7 @@ def _parseSchema(self, el, parent=None): for chEl in el: self._parseSchema(chEl, cls) - def addElement(self, eid, ename, baseClass, attribs={}, parent=None, + def addElement(self, eid, ename, baseClass, attribs=None, parent=None, docs=None): """ Create a new `Element` subclass and add it to the schema. @@ -1168,23 +1162,16 @@ def addElement(self, eid, ename, baseClass, attribs={}, parent=None, schema must contain the required ID, name, and type; successive appearances only need the ID and/or name. - @param eid: The element's EBML ID. - @param ename: The element's name. - @keyword multiple: If `True`, an EBML document can contain more - than one of this element. Not currently enforced. - @keyword mandatory: If `True`, a valid EBML document requires one - (or more) of this element. Not currently enforced. - @keyword length: A fixed length to use when writing the element. - `None` will use the minimum length required. - @keyword precache: If `True`, the element's value will be read - when the element is parsed, rather than when the value is - explicitly accessed. Can save time for small elements. - @keyword attribs: A dictionary of raw element attributes, as read + :param eid: The element's EBML ID. + :param ename: The element's name. + :param baseClass: + :param attribs: A dictionary of raw element attributes, as read from the schema file. - @keyword parent: The new element's parent element class. - @keyword docs: The new element's docstring (e.g. the defining XML + :param parent: The new element's parent element class. + :param docs: The new element's docstring (e.g. the defining XML element's text content). """ + attribs = {} if attribs is None else attribs def _getBool(d, k, default): """ Helper function to get a dictionary value cast to bool. """ @@ -1265,7 +1252,7 @@ def _getInt(d, k, default): {'id': eid, 'name': ename, 'schema': self, 'mandatory': mandatory, 'multiple': multiple, 'precache': precache, 'length': length, - 'children': dict(), '__doc__': docs, + 'children': set(), '__doc__': docs, '__slots__': baseClass.__slots__}) self.elements[eid] = eclass @@ -1277,8 +1264,8 @@ def _getInt(d, k, default): parent = parent or self if parent.children is None: - parent.children = {} - parent.children[eid] = eclass + parent.children = set() + parent.children.add(eid) return eclass @@ -1321,10 +1308,10 @@ def get(self, key, default=None): def load(self, fp, name=None, headers=False, **kwargs): """ Load an EBML file using this Schema. - @param fp: A file-like object containing the EBML to load, or the + :param fp: A file-like object containing the EBML to load, or the name of an EBML file. - @keyword name: The name of the document. Defaults to filename. - @keyword headers: If `False`, the file's ``EBML`` header element + :param name: The name of the document. Defaults to filename. + :param headers: If `False`, the file's ``EBML`` header element (if present) will not appear as a root element in the document. The contents of the ``EBML`` element will always be read. @@ -1334,8 +1321,8 @@ def load(self, fp, name=None, headers=False, **kwargs): def loads(self, data, name=None): """ Load EBML from a string using this Schema. - @param data: A string or bytearray containing raw EBML data. - @keyword name: The name of the document. Defaults to the Schema's + :param data: A string or bytearray containing raw EBML data. + :param name: The name of the document. Defaults to the Schema's document class name. """ return self.load(BytesIO(data), name=name) @@ -1346,9 +1333,9 @@ def __call__(self, fp, name=None): @todo: Decide if this is worth keeping. It exists for historical reasons that may have been refactored out. - @param fp: A file-like object containing the EBML to load, or the + :param fp: A file-like object containing the EBML to load, or the name of an EBML file. - @keyword name: The name of the document. Defaults to filename. + :param name: The name of the document. Defaults to filename. """ return self.load(fp, name=name) @@ -1381,9 +1368,9 @@ def encode(self, stream, data, headers=False): """ Write an EBML document using this Schema to a file or file-like stream. - @param stream: The file (or ``.write()``-supporting file-like + :param stream: The file (or ``.write()``-supporting file-like object) to which to write the encoded EBML. - @param data: The data to encode, provided as a dictionary keyed by + :param data: The data to encode, provided as a dictionary keyed by element name, or a list of two-item name/value tuples. Note: individual items in a list of name/value pairs *must* be tuples! """ @@ -1393,10 +1380,10 @@ def encode(self, stream, data, headers=False): def encodes(self, data, headers=False): """ Create an EBML document using this Schema, returned as a string. - @param data: The data to encode, provided as a dictionary keyed by + :param data: The data to encode, provided as a dictionary keyed by element name, or a list of two-item name/value tuples. Note: individual items in a list of name/value pairs *must* be tuples! - @return: A string containing the encoded EBML binary. + :return: A string containing the encoded EBML binary. """ stream = BytesIO() self.encode(stream, data, headers=headers) @@ -1430,13 +1417,13 @@ def _expandSchemaPath(path, name=''): """ Helper function to process a schema path or name, converting module references to Paths. - @param path: The schema path. May be a directory name, a module + :param path: The schema path. May be a directory name, a module name in braces (e.g., `{idelib.schemata}`), or a module instance. Directory and module names may contain schema filenames. - @param name: An optional schema base filename. Will get appended + :param name: An optional schema base filename. Will get appended to the resulting `Path`/`Traversable`. - @return: A `Path`/`Traversable` object. + :return: A `Path`/`Traversable` object. """ strpath = str(path) subdir = '' @@ -1474,7 +1461,7 @@ def listSchemata(*paths, absolute=True): alternatively, one or more paths or modules can be supplied as arguments. - @returns: A dictionary of schema files. Keys are the base name of the + :returns: A dictionary of schema files. Keys are the base name of the schema XML, values are lists of full paths to the XML. The first filename in the list is what will load if the base name is used with `loadSchema()`. @@ -1510,14 +1497,14 @@ def loadSchema(filename, reload=False, paths=None, **kwargs): """ Import a Schema XML file. Loading the same file more than once will return the initial instantiation, unless `reload` is `True`. - @param filename: The name of the Schema XML file. If the file cannot + :param filename: The name of the Schema XML file. If the file cannot be found and file's path is not absolute, the paths listed in `SCHEMA_PATH` will be searched (similar to `sys.path` when importing modules). - @param reload: If `True`, the resulting Schema is guaranteed to be + :param reload: If `True`, the resulting Schema is guaranteed to be new. Note: existing references to previous instances of the Schema and/or its elements will not update. - @param paths: A list of paths to search for schemata, an alternative + :param paths: A list of paths to search for schemata, an alternative to `ebmlite.SCHEMA_PATH` Additional keyword arguments are sent verbatim to the `Schema` @@ -1570,10 +1557,10 @@ def parseSchema(src, name=None, reload=False, **kwargs): is `True`. Calls to `loadSchema()` using a name previously used with `parseSchema()` will also return the previously instantiated Schema. - @param src: The XML string, or a stream containing XML. - @param name: The name of the schema. If none is supplied, + :param src: The XML string, or a stream containing XML. + :param name: The name of the schema. If none is supplied, the name defined within the schema will be used. - @param reload: If `True`, the resulting Schema is guaranteed to be + :param reload: If `True`, the resulting Schema is guaranteed to be new. Note: existing references to previous instances of the Schema and/or its elements will not update. diff --git a/ebmlite/util.py b/ebmlite/util.py index 6fe77da..b8fecd5 100644 --- a/ebmlite/util.py +++ b/ebmlite/util.py @@ -15,7 +15,7 @@ __credits__ = "David Randall Stokes, Connor Flanigan, Becker Awqatty, Derek Witt" __all__ = ['createID', 'validateID', 'toXml', 'xml2ebml', 'loadXml', 'pprint', - 'printSchemata'] + 'printSchemata', 'flatiter'] import ast from base64 import b64encode, b64decode @@ -194,9 +194,9 @@ def toXml(el, parent=None, offsets=True, sizes=True, types=True, ids=True, return xmlEl -#=============================================================================== +# ============================================================================== # -#=============================================================================== +# ============================================================================== def xmlElement2ebml(xmlEl, ebmlFile, schema, sizeLength=None, unknown=True): """ Convert an XML element to EBML, recursing if necessary. For converting @@ -473,3 +473,30 @@ def printSchemata(paths=None, out=sys.stdout, absolute=True): finally: if newfile: out.close() + + +#=============================================================================== +# +#=============================================================================== + + +def flatiter(element, depth=None): + """ Recursively crawl an EBML document or element, depth-first, + yielding all elements (or elements down to a given depth). + + :param element: The EBML `Document` or `Element` to iterate. + :param depth: The maximum recursion depth. `None` or a value less + than zero will fully recurse without limit. + """ + depth = -1 if depth is None else depth + + def _flatiter(el, d, first): + if not first: + yield el + if abs(d) > 0 and isinstance(el, core.MasterElement): + for ch in el: + for grandchild in _flatiter(ch, d-1, False): + yield grandchild + + for child in _flatiter(element, depth, True): + yield child diff --git a/tests/test_general.py b/tests/test_general.py index c108083..3114ef3 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -4,6 +4,7 @@ @author: dstokes """ +from itertools import zip_longest import os.path import unittest from xml.dom.minidom import parseString @@ -66,18 +67,11 @@ def testMkv(self): xmlDoc2 = util.loadXml(xmlFile2, schema) # Compare each element from the XML - xmlEls1 = [xmlDoc1] - xmlEls2 = [xmlDoc2] - while len(xmlEls1) > 0: - self.assertEqual(xmlEls1[0], xmlEls2[0], 'Element ' - + repr(xmlEls1[0]) - + ' was not converted properly') - for x in list(xmlEls1.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls1.append(x) - for x in list(xmlEls2.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls2.append(x) + for el1, el2 in zip_longest(util.flatiter(xmlDoc1), + util.flatiter(xmlDoc2), + fillvalue=None): + self.assertEqual(el1, el2, + 'Element {!r} was not converted properly'.format(el1)) def testIde(self): @@ -121,19 +115,11 @@ def testIde(self): xmlDoc2 = util.loadXml(xmlFile2, schema) # Compare each element from the XML - xmlEls1 = [xmlDoc1] - xmlEls2 = [xmlDoc2] - while len(xmlEls1) > 0: - self.assertEqual(xmlEls1[0], xmlEls2[0], 'Element ' - + repr(xmlEls1[0]) - + ' was not converted properly') - for x in list(xmlEls1.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls1.append(x) - for x in list(xmlEls2.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls2.append(x) - + for el1, el2 in zip_longest(util.flatiter(xmlDoc1), + util.flatiter(xmlDoc2), + fillvalue=None): + self.assertEqual(el1, el2, + 'Element {!r} was not converted properly'.format(el1)) def testPPrint(self): @@ -298,18 +284,11 @@ def testMkv(self): xmlDoc2 = util.loadXml(xmlFile2, schema) # Compare each element from the XML - xmlEls1 = [xmlDoc1] - xmlEls2 = [xmlDoc2] - while len(xmlEls1) > 0: - self.assertEqual(xmlEls1[0], xmlEls2[0], 'Element ' - + repr(xmlEls1[0]) - + ' was not converted properly') - for x in list(xmlEls1.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls1.append(x) - for x in list(xmlEls2.pop(0).children.values()): - if issubclass(x, core.Element): - xmlEls2.append(x) + for el1, el2 in zip_longest(util.flatiter(xmlDoc1), + util.flatiter(xmlDoc2), + fillvalue=None): + self.assertEqual(el1, el2, + 'Element {!r} was not converted properly'.format(el1)) if __name__ == "__main__": diff --git a/tests/test_tools.py b/tests/test_tools.py index 396fd92..f1f293a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,6 +12,13 @@ @pytest.mark.script_launch_mode('subprocess') def test_ebml2xml(script_runner): + + # This test can only run if the library has been installed, + # e.g., in a GitHub action. Bail if not. + # TODO: This is a hack and should be redone. + if os.getenv("GITHUB_ACTIONS") != "true": + return + path_base = os.path.join(".", "tests", "video-4{ext}") path_in = path_base.format(ext=".ebml") path_out = path_base.format(ext=".ebml.xml") @@ -56,6 +63,13 @@ def assert_elements_are_equiv(e1, e2): @pytest.mark.script_launch_mode('subprocess') def test_xml2ebml(script_runner): + + # This test can only run if the library has been installed, + # e.g., in a GitHub action. Bail if not. + # TODO: This is a hack and should be redone. + if os.getenv("GITHUB_ACTIONS") != "true": + return + path_base = os.path.join(".", "tests", "video-4{ext}") path_in = path_base.format(ext=".xml") path_out = path_base.format(ext=".xml.ebml") @@ -84,6 +98,13 @@ def test_xml2ebml(script_runner): @pytest.mark.script_launch_mode('subprocess') def test_view(script_runner): + + # This test can only run if the library has been installed, + # e.g., in a GitHub action. Bail if not. + # TODO: This is a hack and should be redone. + if os.getenv("GITHUB_ACTIONS") != "true": + return + path_base = os.path.join(".", "tests", "video-4{ext}") path_in = path_base.format(ext=".ebml") path_out = path_base.format(ext=".xml.txt") @@ -112,6 +133,13 @@ def test_view(script_runner): @pytest.mark.script_launch_mode('subprocess') def test_list_schemata(script_runner): + + # This test can only run if the library has been installed, + # e.g., in a GitHub action. Bail if not. + # TODO: This is a hack and should be redone. + if os.getenv("GITHUB_ACTIONS") != "true": + return + core.SCHEMA_PATH = [os.path.dirname(schemata.__file__)] path_out = os.path.join(".", "tests", "list-schemata.txt")