Hybrid runtime-static type-checking @beartype maintainer @leycec here. Because we live on the worst timeline, @beartype users like @pablovela5620 (pour one out for Austin, Texas) are now futilely attempting to integrate @beartype with @tensorclass-decorated types.
Predictably, @beartype refuses. Minimal-working example or it didn't happen:
from beartype import beartype
from tensordict.tensorclass import tensorclass
from torch import Tensor

@beartype
@tensorclass
class MyData(object):
    floatdata: Tensor
    intdata: Tensor
    non_tensordata: str
...which raises the unreadable exception:
Traceback (most recent call last):
File "/home/leycec/tmp/mopy.py", line 8, in <module>
@beartype
^^^^^^^^
File "/home/leycec/py/beartype/beartype/_decor/decorcache.py", line 74, in beartype
return beartype_object(obj, conf)
File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 87, in beartype_object
_beartype_object_fatal(obj, conf=conf, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 133, in _beartype_object_fatal
beartype_type(obj, **kwargs) # type: ignore[return-value]
~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_decor/_decortype.py", line 224, in beartype_type
attr_value_beartyped = beartype_object( # type: ignore[type-var]
obj=attr_value, conf=conf, cls_stack=cls_stack)
File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 87, in beartype_object
_beartype_object_fatal(obj, conf=conf, **kwargs)
~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 137, in _beartype_object_fatal
beartype_nontype(obj, **kwargs) # type: ignore[return-value]
~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_decor/_decornontype.py", line 249, in beartype_nontype
return beartype_func(obj, **kwargs) # type: ignore[return-value]
File "/home/leycec/py/beartype/beartype/_decor/_decornontype.py", line 336, in beartype_func
func_wrapper_code = generate_code(decor_meta)
File "/home/leycec/py/beartype/beartype/_decor/wrap/wrapmain.py", line 126, in generate_code
code_check_return = _code_check_return(decor_meta)
File "/home/leycec/py/beartype/beartype/_decor/wrap/_wrapreturn.py", line 253, in code_check_return
reraise_exception_placeholder(
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
exception=exception,
^^^^^^^^^^^^^^^^^^^^
...<3 lines>...
),
^^
)
^
File "/home/leycec/py/beartype/beartype/_util/error/utilerrraise.py", line 137, in reraise_exception_placeholder
raise exception.with_traceback(exception.__traceback__)
File "/home/leycec/py/beartype/beartype/_decor/wrap/_wrapreturn.py", line 135, in code_check_return
hint_or_sane = sanify_hint_root_func(
decor_meta=decor_meta,
...<2 lines>...
exception_prefix=EXCEPTION_PLACEHOLDER,
)
File "/home/leycec/py/beartype/beartype/_check/convert/convsanify.py", line 200, in sanify_hint_root_func
hint_or_sane = reduce_hint(
hint=hint,
...<5 lines>...
exception_prefix=exception_prefix,
)
File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/redhint.py", line 379, in reduce_hint
hint_or_sane = _reduce_hint_cached(hint, conf, exception_prefix)
File "/home/leycec/py/beartype/beartype/_util/cache/utilcachecall.py", line 249, in _callable_cached
raise exception
File "/home/leycec/py/beartype/beartype/_util/cache/utilcachecall.py", line 241, in _callable_cached
return_value = args_flat_to_return_value[args_flat] = func(
~~~~^
*args)
^^^^^^
File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/redhint.py", line 511, in _reduce_hint_cached
hint = hint_reducer( # type: ignore[call-arg]
hint=hint, # pyright: ignore[reportGeneralTypeIssues]
conf=conf,
exception_prefix=exception_prefix,
)
File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/_nonpep/rednonpeptype.py", line 103, in reduce_hint_nonpep_type
die_unless_hint(hint=hint, exception_prefix=exception_prefix)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_util/hint/utilhinttest.py", line 103, in die_unless_hint
die_unless_hint_nonpep(hint=hint, exception_prefix=exception_prefix)
~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/leycec/py/beartype/beartype/_util/hint/nonpep/utilnonpeptest.py", line 204, in die_unless_hint_nonpep
raise exception_cls(
...<6 lines>...
)
beartype.roar.BeartypeDecorHintNonpepException: Method tensordict.tensorclass._eq()
return type hint <function _wrap_td_method.<locals>.wrapped_func at 0x726f56d64e00>
either PEP-noncompliant or PEP-compliant but currently unsupported by @beartype. You
suddenly feel encouraged to submit a feature request for this hint to our friendly
issue tracker at:
https://github.com/beartype/beartype/issues
I'm... outta my depth here, guys. There's so much badness happening all at once that it's non-trivial to track all of the badness.
Let's start with the obvious.
TensorDict Return Type Hints Appear to be Madness Incarnate
The tensordict.tensorclass._eq() method appears to be annotated by a return type hint that is (...waitforit) a <function _wrap_td_method.<locals>.wrapped_func at 0x726f56d64e00> closure function.
No idea, guys. @beartype doesn't generally lie about these things. Very well. I admit it. @beartype often lies. So, that's possibly what is happening. Alternately, @beartype could be as confused as we all are right now. Let's generously assume that what @beartype is thinking is happening is happening. It's unclear why anybody would want that to happen, though. Also, this is why type-checking actually is important. It's not something that TensorDict should be intentionally ignoring. Type-checking could have caught issues like this earlier.
Because something definitely smells in the TensorDict codebase. Very well. It might be the @beartype codebase that actually smells here. Functions aren't valid type hints, obviously. So why is TensorDict possibly using functions as type hints? Alternately, why does @beartype mistakenly believe that this is what is happening?
No idea, guys. It's probably just a trivial typo somewhere five layers deep in the bowels of some private submodule dynamically defined in-memory, just because. Still, this is badness.
Which leads us to...
type(MyData) is type
The type of @tensorclass-decorated dataclasses inexplicably appears to be (...waitforit) the ambiguous type metaclass, the default class of every ordinary Python class, which conveys no meaningful semantics:
# Uhh... wat? Tensorclasses don't even have a distinguishable class?
# You've gotta be friggin' kidding me here. What is this ugly madness?
>>> print(type(MyData))
<class 'type'>  # <-- ...you don't say
Uhh. Wat? Because these dataclasses inherit from no TensorDict-specific superclasses or metaclasses, downstream third-party consumers like @beartype have no sane means of detecting and thus supporting these dataclasses at runtime.
Like, seriously. Do something to distinguish the method-resolution orders (MROs) of these dataclasses from standard @dataclasses.dataclass-decorated dataclasses. Doing nothing definitely isn't cutting it.
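For concreteness, here is a minimal sketch of what "distinguishing the MRO" could look like. Everything in it is hypothetical: `TensorClassMarker`, `marked_dataclass`, and `has_marker` are made-up names, and nothing here claims to describe how TensorDict works or plans to work. It only shows that a decorator can leave a detectable base class behind, and that a downstream consumer can then find it by name alone.

```python
import dataclasses


class TensorClassMarker:
    """Hypothetical empty marker base. Not a real tensordict class."""


def marked_dataclass(cls: type) -> type:
    """Hypothetical decorator: a dataclass whose MRO carries the marker."""
    datacls = dataclasses.dataclass(cls)
    # Subclass the generated dataclass so the marker lands in the MRO
    # without touching the user's original class body.
    return type(
        cls.__name__,
        (datacls, TensorClassMarker),
        {"__module__": cls.__module__, "__qualname__": cls.__qualname__},
    )


def has_marker(cls: type, marker_name: str = "TensorClassMarker") -> bool:
    """Detect the marker by name only, so its module need not be imported."""
    return any(base.__name__ == marker_name for base in cls.__mro__)


@marked_dataclass
class Point:
    x: int
    y: int


@dataclasses.dataclass
class PlainPoint:
    x: int
    y: int


print(has_marker(Point), has_marker(PlainPoint))
```

A metaclass would work just as well as a marker base. The point is only that `type(MyData)` or `MyData.__mro__` would then carry something a third party can test for.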
Which leads us to...
That Ship Has Already Sailed
So. I get it. The @tensorclass ship has already sailed. There are already millions of lines of code in the wild using this API. @beartype can hope for a better API all it likes – but that doesn't particularly change the grim reality of the situation on the ground.
So. How should downstream third-party consumers like @beartype detect these dataclasses then? I note a public tensordict.is_tensorclass() tester function that seems particularly pertinent. The issue there, of course, is that calling that function requires @beartype to import that function, which requires @beartype to then require TensorDict as a mandatory dependency – which @beartype is absolutely not doing. @beartype has no mandatory dependencies and that's never changing.
So, what now? Even importing TensorDict with import tensordict and then immediately terminating the active Python process under python -v shows an extreme number of extremely costly imports. In other words, @beartype is absolutely also not doing something like this:
try:
    from tensordict import is_tensorclass
except ImportError:
    def is_tensorclass(obj: object) -> bool:
        return False
Although trivial, that would make importing @beartype itself equally costly. The whole point of @beartype is to be "cost-free." Which leads us to...
dir(MyData) is Truly a Nightmare on Earth
Seriously. I cannot believe how utterly intense the TensorDict API is. This is the most outrageously verbose dir() output I've ever seen Python spit out:
# My eyes and your eyes are now both bleeding.
>>> print(dir(MyData))
['__abs__', '__add__', '__and__', '__annotations__', '__bool__', '__class__',
'__dataclass_fields__', '__dataclass_params__', '__delattr__', '__dict__',
'__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__expected_keys__',
'__firstlineno__', '__format__', '__ge__', '__getattr__', '__getattribute__',
'__getitem__', '__getitems__', '__getstate__', '__gt__', '__hash__', '__iadd__',
'__imul__', '__init__', '__init_subclass__', '__invert__', '__ipow__', '__isub__',
'__itruediv__', '__le__', '__len__', '__lt__', '__match_args__', '__module__',
'__mul__', '__ne__', '__neg__', '__new__', '__or__', '__pow__', '__radd__',
'__rand__', '__reduce__', '__reduce_ex__', '__replace__', '__repr__', '__rmul__',
'__ror__', '__rpow__', '__rsub__', '__rtruediv__', '__rxor__', '__setattr__',
'__setitem__', '__setstate__', '__sizeof__', '__static_attributes__', '__str__',
'__sub__', '__subclasshook__', '__torch_function__', '__truediv__', '__weakref__',
'__xor__', '_add_batch_dim', '_apply_nest', '_autocast', '_check_batch_size',
'_check_device', '_check_dim_name', '_check_unlock', '_clone', '_clone_recurse',
'_data', '_default_get', '_erase_names', '_exclude', '_fast_apply',
'_flatten_keys_inplace', '_flatten_keys_outplace', '_from_dict_validated',
'_from_module', '_from_tensordict', '_frozen', '_get_at_str', '_get_at_tuple',
'_get_names_idx', '_get_str', '_get_sub_tensordict', '_get_tuple',
'_get_tuple_maybe_non_tensor', '_grad', '_has_names', '_is_non_tensor',
'_is_tensorclass', '_items_list', '_load_memmap', '_map', '_maybe_names',
'_maybe_remove_batch_dim', '_memmap_', '_multithread_apply_flat',
'_multithread_apply_nest', '_multithread_rebuild', '_new_unsafe', '_nocast',
'_permute', '_propagate_lock', '_propagate_unlock', '_reduce_get_metadata',
'_remove_batch_dim', '_repeat', '_select', '_set_at_str', '_set_at_tuple',
'_set_str', '_set_tuple', '_shadow', '_to_module', '_type_hints', '_unbind',
'_values_list', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addcdiv',
'addcdiv_', 'addcmul', 'addcmul_', 'all', 'amax', 'amin', 'any', 'apply',
'apply_', 'as_tensor', 'asin', 'asin_', 'atan', 'atan_', 'auto_batch_size_',
'auto_device_', 'batch_dims', 'batch_size', 'bfloat16', 'bitwise_and', 'bool',
'bytes', 'cat', 'cat_from_tensordict', 'cat_tensors', 'ceil', 'ceil_', 'chunk',
'clamp', 'clamp_max', 'clamp_max_', 'clamp_min', 'clamp_min_', 'clear',
'clear_device_', 'clear_refs_for_compile_', 'clone', 'complex128', 'complex32',
'complex64', 'consolidate', 'contiguous', 'copy', 'copy_', 'copy_at_', 'cos',
'cos_', 'cosh', 'cosh_', 'cpu', 'create_nested', 'cuda', 'cummax', 'cummin',
'data', 'data_ptr', 'del_', 'densify', 'depth', 'detach', 'detach_', 'device',
'dim', 'div', 'div_', 'double', 'dtype', 'dumps', 'empty', 'entry_class', 'erf',
'erf_', 'erfc', 'erfc_', 'exclude', 'exp', 'exp_', 'expand', 'expand_as', 'expm1',
'expm1_', 'fields', 'fill_', 'filter_empty_', 'filter_non_tensor_data', 'flatten',
'flatten_keys', 'float', 'float16', 'float32', 'float64', 'floor', 'floor_',
'frac', 'frac_', 'from_any', 'from_consolidated', 'from_dataclass', 'from_dict',
'from_dict_instance', 'from_h5', 'from_module', 'from_modules', 'from_namedtuple',
'from_pytree', 'from_struct_array', 'from_tensordict', 'from_tuple', 'fromkeys',
'gather', 'gather_and_stack', 'get', 'get_at', 'get_item_shape', 'get_non_tensor',
'grad', 'half', 'int', 'int16', 'int32', 'int64', 'int8', 'irecv',
'is_consolidated', 'is_contiguous', 'is_cpu', 'is_cuda', 'is_empty',
'is_floating_point', 'is_locked', 'is_memmap', 'is_meta', 'is_shared', 'isend',
'isfinite', 'isnan', 'isneginf', 'isposinf', 'isreal', 'items', 'keys',
'lazy_stack', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'load', 'load_',
'load_memmap', 'load_memmap_', 'load_state_dict', 'lock_', 'log', 'log10',
'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'logical_and', 'logsumexp',
'make_memmap', 'make_memmap_from_storage', 'make_memmap_from_tensor', 'map',
'map_iter', 'masked_fill', 'masked_fill_', 'masked_select', 'max', 'maximum',
'maximum_', 'maybe_dense_stack', 'mean', 'memmap', 'memmap_', 'memmap_like',
'memmap_refresh_', 'min', 'minimum', 'minimum_', 'mul', 'mul_', 'named_apply',
'names', 'nanmean', 'nansum', 'ndim', 'ndimension', 'neg', 'neg_', 'new_empty',
'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'non_tensor_items', 'norm',
'numel', 'numpy', 'param_count', 'permute', 'pin_memory', 'pin_memory_', 'pop',
'popitem', 'pow', 'pow_', 'prod', 'qint32', 'qint8', 'quint4x2', 'quint8',
'reciprocal', 'reciprocal_', 'record_stream', 'recv', 'reduce', 'refine_names',
'rename', 'rename_', 'rename_key_', 'repeat', 'repeat_interleave', 'replace',
'requires_grad', 'requires_grad_', 'reshape', 'round', 'round_', 'save',
'saved_path', 'select', 'send', 'separates', 'set', 'set_', 'set_at_',
'set_non_tensor', 'setdefault', 'shape', 'share_memory_', 'sigmoid', 'sigmoid_',
'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'softmax', 'sorted_keys',
'split', 'split_keys', 'sqrt', 'sqrt_', 'squeeze', 'stack',
'stack_from_tensordict', 'stack_tensors', 'state_dict', 'std', 'sub', 'sub_',
'sum', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dict', 'to_h5', 'to_module',
'to_namedtuple', 'to_padded_tensor', 'to_pytree', 'to_struct_array',
'to_tensordict', 'transpose', 'trunc', 'trunc_', 'type', 'uint16', 'uint32',
'uint64', 'uint8', 'unbind', 'unflatten', 'unflatten_keys', 'unlock_',
'unsqueeze', 'update', 'update_', 'update_at_', 'values', 'var', 'view', 'where',
'zero_', 'zero_grad']
I mean... just... WTTTTTTTTTTTTTTTTF!? There's even an attribute called cummin up there, which just raises even more questions than it answers. 🤣
I genuinely have no idea where to start with this beastly nightmare. Technically, I do note a private _is_tensorclass attribute in the above output. The existence of that attribute definitely suggests the current type to be a tensordict.tensorclass. There's no guarantee of that, though. Other third-party types in the wild might very well define the same private _is_tensorclass attribute.
But... I guess that's what I gotta roll with, huh? Sucks, guys. This pretty much sucks.
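To make the false-positive worry concrete, here is a tiny sketch using stand-in classes only. `LookAlike`, `Plain`, and `looks_like_tensorclass` are hypothetical names, and no real TensorDict type is involved; the test itself is just the attribute-presence check described above.

```python
class LookAlike:
    """Unrelated class that happens to reuse the private attribute name."""

    _is_tensorclass = True


class Plain:
    """Ordinary class with no such attribute."""


def looks_like_tensorclass(obj: object) -> bool:
    # Attribute-presence heuristic: a class whose metaclass is exactly
    # `type` and that exposes a `_is_tensorclass` attribute.
    return type(obj) is type and hasattr(obj, "_is_tensorclass")


print(looks_like_tensorclass(LookAlike))    # class with the attribute
print(looks_like_tensorclass(Plain))        # class without it
print(looks_like_tensorclass(LookAlike()))  # an instance, not a class
```

The first call is the problem case: the heuristic cannot tell `LookAlike` from a genuine tensorclass, because the only evidence it consults is an attribute name anyone can define.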
So. Are You Actually Suggesting Anything Useful?
Yeah. I'm not here to just complain. That's only one of the reasons I'm here. 😉
Basically, feature request #663 really needs to happen to make TensorDict usable at runtime by everybody else in the Python ecosystem. If #663 happens, then @beartype can just trivially compare the obj.__class__.__module__ and obj.__class__.__name__ dunder attributes guaranteed to exist on any arbitrary object to decide whether that object is a TensorDict dataclass or not.
This trivial test just operates on strings and thus doesn't require @beartype to import TensorDict in a global scope: e.g.,
def is_tensorclass(obj: object) -> bool:
    return (
        obj.__class__.__module__ == 'tensordict' and  # <-- yay!
        obj.__class__.__name__ == 'TensorDict'  # <-------- yay x 2!!
    )
Until that glorious release day happens, I guess the best we can do is (as suggested above):
def is_tensorclass(obj: object) -> bool:
    return type(obj) is type and hasattr(obj, '_is_tensorclass')  # <-- yikes
That's... not great. But that's what we rollin' with. I sigh! I sigh so hard. 😩
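One cheap way to take some of the sting out of that fallback, sketched here under a stated assumption: a tensorclass can only exist in a process after something has imported the `tensordict` package, so checking `sys.modules` first rules out look-alikes in processes that never touch TensorDict, and it does so without @beartype importing anything. `is_tensorclass_guarded` is a hypothetical name, not an existing @beartype or TensorDict function.

```python
import sys


def is_tensorclass_guarded(obj: object) -> bool:
    """Heuristic tensorclass test that never imports tensordict itself."""
    # Assumption: no tensorclass can exist unless some caller has already
    # imported the `tensordict` package. If it is absent from sys.modules,
    # bail out before consulting the attribute-name heuristic.
    if "tensordict" not in sys.modules:
        return False
    return type(obj) is type and hasattr(obj, "_is_tensorclass")


class LookAlike:
    """Hypothetical unrelated class reusing the private attribute name."""

    _is_tensorclass = True


print(is_tensorclass_guarded(LookAlike))
```

This only narrows the window. Once anything in the process has imported TensorDict, the look-alike problem is back, so it is a stopgap rather than a substitute for a real marker.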