Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/type annotations #38

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ build/*
dist/*
env
env?
.mypy*
venv
venv?
*/.*
Expand Down
97 changes: 58 additions & 39 deletions ebmlite/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import errno
import importlib
from io import BytesIO, StringIO, IOBase
from typing import Any, Dict, List, Optional, Union, BinaryIO
import os.path
from pathlib import Path
import re
Expand Down Expand Up @@ -83,15 +84,15 @@
# SCHEMA_PATH: A list of paths for schema XML files, similar to `sys.path`.
# When `loadSchema()` is used, it will search these paths, in order, to find
# the schema file.
SCHEMA_PATH = ['',
SCHEMA_PATH: List[str] = ['',
os.path.realpath(os.path.dirname(schemata.__file__))]

SCHEMA_PATH.extend(p for p in os.environ.get('EBMLITE_SCHEMA_PATH', '').split(os.path.pathsep)
if p not in SCHEMA_PATH)

# SCHEMATA: A dictionary of loaded schemata, keyed by filename. Used by
# `loadSchema()`. In most cases, SCHEMATA should not be otherwise modified.
SCHEMATA = {}
SCHEMATA: Dict[Union[str, bytes, bytearray], Any] = {}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the keys will always be string.



# ==============================================================================
Expand Down Expand Up @@ -202,7 +203,7 @@ def getRaw(self):
self.stream.seek(self.offset)
return self.stream.read(self.size + (self.payloadOffset - self.offset))

def getRawValue(self):
def getRawValue(self) -> Union[bytes, bytearray]:
""" Get the raw binary of the element's value.
"""
self.stream.seek(self.payloadOffset)
Expand All @@ -212,7 +213,7 @@ def getRawValue(self):
# Caching (experimental)
# ==========================================================================

def gc(self, recurse=False):
def gc(self, recurse: bool = False) -> int:
""" Clear any cached values. To save memory and/or force values to be
re-read from the file. Returns the number of cached values cleared.
"""
Expand All @@ -227,12 +228,16 @@ def gc(self, recurse=False):
# ==========================================================================

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder. """
return encoding.encodeBinary(data, length)

@classmethod
def encode(cls, value, length=None, lengthSize=None, infinite=False):
def encode(cls,
value,
length: Optional[int] = None,
lengthSize: Optional[int] = None,
infinite: bool = False) -> bytes:
""" Encode an EBML element.

@param value: The value to encode, or a list of values to encode.
Expand Down Expand Up @@ -285,14 +290,14 @@ def __eq__(self, other):
return False
return self.value == other.value

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> int:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readInt(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: int, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for signed integer elements. """
return encoding.encodeInt(data, length)

Expand All @@ -308,14 +313,14 @@ class UIntegerElement(IntegerElement):
dtype = int
precache = True

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> int:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readUInt(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: int, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for unsigned integer elements. """
return encoding.encodeUInt(data, length)

Expand All @@ -336,14 +341,14 @@ def __eq__(self, other):
return False
return self.value == other.value

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> float:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readFloat(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: float, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for floating point elements. """
return encoding.encodeFloat(data, length)

Expand All @@ -366,14 +371,14 @@ def __eq__(self, other):
def __len__(self):
return self.size

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> bytes:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readString(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: Union[bytes, bytearray], length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for ASCII string elements. """
return encoding.encodeString(data, length)

Expand All @@ -392,14 +397,14 @@ def __len__(self):
# Value may be multiple bytes per character
return len(self.value)

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> str:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readUnicode(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: str, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for Unicode string elements. """
return encoding.encodeUnicode(data, length)

Expand All @@ -412,16 +417,16 @@ class DateElement(IntegerElement):
generated when a `Schema` is loaded.
"""
__slots__ = ("stream", "offset", "size", "sizeLength", "payloadOffset", "_value")
dtype = datetime
dtype: datetime = datetime

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: int) -> datetime:
""" Type-specific helper function for parsing the element's payload.
It is assumed the file pointer is at the start of the payload.
"""
return readDate(stream, size)

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data: datetime, length: Optional[int] = None) -> bytes:
""" Type-specific payload encoder for date elements. """
return encoding.encodeDate(data, length)

Expand Down Expand Up @@ -450,11 +455,11 @@ class VoidElement(BinaryElement):
"""
__slots__ = ("stream", "offset", "size", "sizeLength", "payloadOffset", "_value")

def parse(self, stream, size):
def parse(self, stream: BinaryIO, size: Any):
return bytearray()

@classmethod
def encodePayload(cls, data, length=0):
def encodePayload(cls, data: Any, length: Optional[int] = 0) -> bytearray:
""" Type-specific payload encoder for Void elements. """
length = 0 if length is None else length
return bytearray(b'\xff' * length)
Expand All @@ -473,8 +478,13 @@ class UnknownElement(BinaryElement):
name = "UnknownElement"
precache = False

def __init__(self, stream=None, offset=0, size=0, payloadOffset=0, eid=None,
schema=None):
def __init__(self,
stream: Optional[BinaryIO] = None,
offset: int = 0,
size: int = 0,
payloadOffset: int = 0,
eid: Optional[int] = None,
schema: Any = None):
""" Constructor. Instantiate a new `UnknownElement` from a file. In
most cases, elements should be created when a `Document` is loaded,
rather than instantiated explicitly.
Expand Down Expand Up @@ -529,7 +539,7 @@ def parse(self):
# parse(). Used only when pre-caching.
return self.value

def parseElement(self, stream, nocache=False):
def parseElement(self, stream: BinaryIO, nocache: bool = False):
""" Read the next element from a stream, instantiate a `MasterElement`
object, and then return it and the offset of the next element
(this element's position + size).
Expand Down Expand Up @@ -561,7 +571,7 @@ def parseElement(self, stream, nocache=False):
return el, payloadOffset + el.size

@classmethod
def _isValidChild(cls, elId):
def _isValidChild(cls, elId: int) -> bool:
""" Is the given element ID represent a valid sub-element, i.e.
explicitly specified as a child element or a 'global' in the
schema?
Expand All @@ -579,7 +589,7 @@ def _isValidChild(cls, elId):
return elId in cls._childIds

@property
def size(self):
def size(self) -> int:
""" The element's size. Master elements can be instantiated with this
as `None`; this denotes an 'infinite' EBML element, and its size
will be determined by iterating over its contents until an invalid
Expand Down Expand Up @@ -613,13 +623,13 @@ def size(self):
return self._size

@size.setter
def size(self, esize):
def size(self, esize: int):
if esize is not None:
# Only create the `_size` attribute for a real value. Don't
# define it if it's `None`, so `size` will get calculated.
self._size = esize

def __iter__(self, nocache=False):
def __iter__(self, nocache: bool = False):
""" x.__iter__() <==> iter(x)
"""
# TODO: Better support for 'infinite' elements (getting the size of
Expand Down Expand Up @@ -669,7 +679,7 @@ def __getitem__(self, *args):
# Caching (experimental!)
# ==========================================================================

def gc(self, recurse=False):
def gc(self, recurse: bool = False) -> int:
""" Clear any cached values. To save memory and/or force values to be
re-read from the file.
"""
Expand All @@ -685,7 +695,7 @@ def gc(self, recurse=False):
# ==========================================================================

@classmethod
def encodePayload(cls, data, length=None):
def encodePayload(cls, data, length: Optional[int] = None) -> bytearray:
""" Type-specific payload encoder for 'master' elements.
"""
result = bytearray()
Expand All @@ -705,7 +715,11 @@ def encodePayload(cls, data, length=None):
return result

@classmethod
def encode(cls, data, length=None, lengthSize=None, infinite=False):
def encode(cls,
data,
length: Optional[int] = None,
lengthSize: Optional[int] = None,
infinite: bool = False) -> bytes:
""" Encode an EBML master element.

@param data: The data to encode, provided as a dictionary keyed by
Expand Down Expand Up @@ -734,7 +748,7 @@ def encode(cls, data, length=None, lengthSize=None, infinite=False):
lengthSize=lengthSize,
infinite=infinite)

def dump(self):
def dump(self) -> OrderedDict:
""" Dump this element's value as nested dictionaries, keyed by
element name. The values of 'multiple' elements return as lists.
Note: The order of 'multiple' elements relative to other elements
Expand Down Expand Up @@ -764,7 +778,7 @@ class Document(MasterElement):
Loading a `Schema` generates a subclass.
"""

def __init__(self, stream, name=None, size=None, headers=True):
def __init__(self, stream, name: Optional[str] = None, size: Optional[int] = None, headers: bool = True):
""" Constructor. Instantiate a `Document` from a file-like stream.
In most cases, `Schema.load()` should be used instead of
explicitly instantiating a `Document`.
Expand Down Expand Up @@ -928,7 +942,7 @@ def type(self):
# Caching (experimental!)
# ==========================================================================

def gc(self, recurse=False):
def gc(self, recurse: bool = False) -> int:
# TODO: Implement this if/when caching of root elements is implemented.
return 0

Expand Down Expand Up @@ -1046,7 +1060,7 @@ class Schema(object):
# factory function.
UNKNOWN = UnknownElement

def __init__(self, source, name=None):
def __init__(self, source, name: Optional[Union[str, bytes, bytearray]] = None):
""" Constructor. Creates a new Schema from a schema description XML.

@param source: The Schema's source, either a string with the full
Expand Down Expand Up @@ -1158,7 +1172,12 @@ def _parseSchema(self, el, parent=None):
for chEl in el:
self._parseSchema(chEl, cls)

def addElement(self, eid, ename, baseClass, attribs={}, parent=None,
def addElement(self,
eid: int,
ename: Union[str, bytes, bytearray],
baseClass,
attribs={},
parent=None,
docs=None):
""" Create a new `Element` subclass and add it to the schema.

Expand Down Expand Up @@ -1356,7 +1375,7 @@ def __call__(self, fp, name=None):
# Schema info stuff. Uses python-ebml schema XML data. Refactor later.
# ==========================================================================

def _getInfo(self, eid, dtype):
def _getInfo(self, eid: int, dtype):
""" Helper method to get the 'default' value of an element. """
try:
return dtype(self.elementInfo[eid]['default'])
Expand All @@ -1377,7 +1396,7 @@ def type(self):
# Encoding
# ==========================================================================

def encode(self, stream, data, headers=False):
def encode(self, stream: BinaryIO, data, headers=False) -> bytes:
""" Write an EBML document using this Schema to a file or file-like
stream.

Expand Down Expand Up @@ -1506,7 +1525,7 @@ def listSchemata(*paths, absolute=True):
return schemata


def loadSchema(filename, reload=False, paths=None, **kwargs):
def loadSchema(filename: Union[str, bytes, bytearray], reload=False, paths=None, **kwargs):
""" Import a Schema XML file. Loading the same file more than once will
return the initial instantiation, unless `reload` is `True`.

Expand Down
39 changes: 39 additions & 0 deletions ebmlite/core.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import (Any,
BinaryIO,
ClassVar,
Dict,
List,
Optional,
Text,
Type,
Union)

SCHEMA_PATH: List[str] = ...
SCHEMATA: Dict[Union[str, bytes, bytearray], Any] = ...


class Element:

id: Any
name: Optional[Text]

schema: ClassVar[Optional[Schema]]
dtype: ClassVar[Optional[Type]]

def parse(self, stream: BinaryIO, size: int) -> Any: ...

def __init__(self,
stream: Optional[BinaryIO] = ...,
offset: int = ...,
size: int = ...,
payloadOffset: int = ...): ...

def getRaw(self) -> Union[bytes, bytearray]: ...


class Schema(object): ...


def loadSchema(filename: Union[str, bytes, bytearray],
reload: bool = ...,
**kwargs) -> Schema: ...
Loading