
[BUG] @tensorclass-decorated dataclasses violate @beartype-based runtime type-checking 😭 #1243

Hybrid runtime-static type-checking @beartype maintainer @leycec here. Because we live on the worst timeline, @beartype users like @pablovela5620 (pour one out for Austin, Texas) are now futilely attempting to integrate @beartype with @tensorclass-decorated types.

Predictably, @beartype refuses. Minimal-working example or it didn't happen:

from beartype import beartype
from tensordict.tensorclass import tensorclass
from torch import Tensor

@beartype
@tensorclass
class MyData(object):
    floatdata: Tensor
    intdata: Tensor
    non_tensordata: str

...which raises the unreadable exception:

Traceback (most recent call last):
  File "/home/leycec/tmp/mopy.py", line 8, in <module>
    @beartype
     ^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_decor/decorcache.py", line 74, in beartype
    return beartype_object(obj, conf)
  File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 87, in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 133, in _beartype_object_fatal
    beartype_type(obj, **kwargs)  # type: ignore[return-value]
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_decor/_decortype.py", line 224, in beartype_type
    attr_value_beartyped = beartype_object(  # type: ignore[type-var]
        obj=attr_value, conf=conf, cls_stack=cls_stack)
  File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 87, in beartype_object
    _beartype_object_fatal(obj, conf=conf, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_decor/decorcore.py", line 137, in _beartype_object_fatal
    beartype_nontype(obj, **kwargs)  # type: ignore[return-value]
    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_decor/_decornontype.py", line 249, in beartype_nontype
    return beartype_func(obj, **kwargs)  # type: ignore[return-value]
  File "/home/leycec/py/beartype/beartype/_decor/_decornontype.py", line 336, in beartype_func
    func_wrapper_code = generate_code(decor_meta)
  File "/home/leycec/py/beartype/beartype/_decor/wrap/wrapmain.py", line 126, in generate_code
    code_check_return = _code_check_return(decor_meta)
  File "/home/leycec/py/beartype/beartype/_decor/wrap/_wrapreturn.py", line 253, in code_check_return
    reraise_exception_placeholder(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        exception=exception,
        ^^^^^^^^^^^^^^^^^^^^
    ...<3 lines>...
        ),
        ^^
    )
    ^
  File "/home/leycec/py/beartype/beartype/_util/error/utilerrraise.py", line 137, in reraise_exception_placeholder
    raise exception.with_traceback(exception.__traceback__)
  File "/home/leycec/py/beartype/beartype/_decor/wrap/_wrapreturn.py", line 135, in code_check_return
    hint_or_sane = sanify_hint_root_func(
        decor_meta=decor_meta,
    ...<2 lines>...
        exception_prefix=EXCEPTION_PLACEHOLDER,
    )
  File "/home/leycec/py/beartype/beartype/_check/convert/convsanify.py", line 200, in sanify_hint_root_func
    hint_or_sane = reduce_hint(
        hint=hint,
    ...<5 lines>...
        exception_prefix=exception_prefix,
    )
  File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/redhint.py", line 379, in reduce_hint
    hint_or_sane = _reduce_hint_cached(hint, conf, exception_prefix)
  File "/home/leycec/py/beartype/beartype/_util/cache/utilcachecall.py", line 249, in _callable_cached
    raise exception
  File "/home/leycec/py/beartype/beartype/_util/cache/utilcachecall.py", line 241, in _callable_cached
    return_value = args_flat_to_return_value[args_flat] = func(
                                                          ~~~~^
        *args)
        ^^^^^^
  File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/redhint.py", line 511, in _reduce_hint_cached
    hint = hint_reducer(  # type: ignore[call-arg]
        hint=hint,  # pyright: ignore[reportGeneralTypeIssues]
        conf=conf,
        exception_prefix=exception_prefix,
    )
  File "/home/leycec/py/beartype/beartype/_check/convert/_reduce/_nonpep/rednonpeptype.py", line 103, in reduce_hint_nonpep_type
    die_unless_hint(hint=hint, exception_prefix=exception_prefix)
    ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_util/hint/utilhinttest.py", line 103, in die_unless_hint
    die_unless_hint_nonpep(hint=hint, exception_prefix=exception_prefix)
    ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/leycec/py/beartype/beartype/_util/hint/nonpep/utilnonpeptest.py", line 204, in die_unless_hint_nonpep
    raise exception_cls(
    ...<6 lines>...
    )
beartype.roar.BeartypeDecorHintNonpepException: Method tensordict.tensorclass._eq()
return type hint <function _wrap_td_method.<locals>.wrapped_func at 0x726f56d64e00> 
either PEP-noncompliant or PEP-compliant but currently unsupported by @beartype. You 
suddenly feel encouraged to submit a feature request for this hint to our friendly 
issue tracker at:
	https://github.com/beartype/beartype/issues

I'm... outta my depth here, guys. There's so much badness happening all at once that it's non-trivial to track all of the badness.

Let's start with the obvious.

TensorDict Return Type Hints Appear to be Madness Incarnate

The tensordict.tensorclass._eq() method appears to be annotated by a return type hint that is (...waitforit) a <function _wrap_td_method.<locals>.wrapped_func at 0x726f56d64e00> closure function.

No idea, guys. @beartype doesn't generally lie about these things. Very well. I admit it. @beartype often lies. So, that's possibly what is happening. Alternately, @beartype could be as confused as we all are right now. Let's generously assume that what @beartype is thinking is happening is happening. It's unclear why anybody would want that to happen, though. Also, this is why type-checking actually is important. It's not something that TensorDict should be intentionally ignoring. Type-checking could have caught issues like this earlier.

Because something definitely smells in the TensorDict codebase. Very well. It might be the @beartype codebase that actually smells here. Functions aren't valid type hints, obviously. So why is TensorDict possibly using functions as type hints? Alternately, why does @beartype mistakenly believe that this is what is happening?

No idea, guys. It's probably just a trivial typo somewhere five layers deep in the bowels of some private submodule dynamically defined in-memory, just because. Still, this is badness.

Which leads us to...

type(MyData) is type

The type of @tensorclass-decorated dataclasses inexplicably appears to be (...waitforit) the ambiguous type superclass that conveys no meaningful semantics:

# Uhh... wat? Tensorclasses don't even have a distinguishable class?
# You've gotta be friggin' kidding me here. What is this ugly madness?
>>> print(type(MyData))
<class 'type'>  # <-- ...you don't say

Uhh. Wat? Because these dataclasses inherit from no TensorDict-specific superclasses or metaclasses, downstream third-party consumers like @beartype have no sane means of detecting and thus supporting these dataclasses at runtime.

Like, seriously. Do something to distinguish the method-resolution orders (MROs) of these dataclasses from standard @dataclasses.dataclass-decorated dataclasses. Doing nothing definitely isn't cutting it.
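To make the ask concrete, here is a rough sketch of what a distinguishable MRO would buy downstream consumers. It assumes a hypothetical marker superclass (called TensorClassMarker below, which is not a real TensorDict name) and shows that detection then needs only string comparisons over cls.__mro__, with no import of the defining package.

```python
def has_base_named(cls: type, module_name: str, class_name: str) -> bool:
    """Return True if any class in the MRO of ``cls`` has the given names."""
    return any(
        base.__module__ == module_name and base.__name__ == class_name
        for base in cls.__mro__
    )


class TensorClassMarker:
    """Hypothetical marker superclass; a stand-in, not a TensorDict API."""


class MyMarkedData(TensorClassMarker):
    """Stand-in for a decorated dataclass that inherits the marker."""


class PlainData:
    """Stand-in for an ordinary class with no marker in its MRO."""


marker_module = TensorClassMarker.__module__
print(has_base_named(MyMarkedData, marker_module, 'TensorClassMarker'))
print(has_base_named(PlainData, marker_module, 'TensorClassMarker'))
```

Walking the whole MRO rather than comparing only the immediate class means user-defined subclasses are detected too.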

Which leads us to...

That Ship Has Already Sailed

So. I get it. The @tensorclass ship has already sailed. There's already millions of lines of code in the wild using this API. @beartype can hope for a better API all it likes – but that doesn't particularly change the grim reality of the situation on the ground.

So. How should downstream third-party consumers like @beartype detect these dataclasses then? I note a public tensordict.is_tensorclass() tester function that seems particularly pertinent. The issue there, of course, is that calling that function requires @beartype to import that function, which in turn requires @beartype to take on TensorDict as a mandatory dependency. @beartype is absolutely not doing that. @beartype has no mandatory dependencies and that's never changing.

So, what now? Even importing TensorDict with import tensordict and then immediately terminating the active Python process under python -v shows an extreme number of extremely costly imports. In other words, @beartype is absolutely also not doing something like this:

try:
    from tensordict import is_tensorclass
except ImportError:
    def is_tensorclass(obj: object) -> bool: return False

Although trivial, that would make importing @beartype itself equally costly. The whole point of @beartype is to be "cost-free." Which leads us to...

dir(MyData) is Truly a Nightmare on Earth

Seriously. I cannot believe how utterly intense the TensorDict API is. This is the most outrageously verbose dir() output I've ever seen Python spit out:

# My eyes and your eyes are now both bleeding.
>>> print(dir(MyData))
['__abs__', '__add__', '__and__', '__annotations__', '__bool__', '__class__',
'__dataclass_fields__', '__dataclass_params__', '__delattr__', '__dict__',
'__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__expected_keys__',
'__firstlineno__', '__format__', '__ge__', '__getattr__', '__getattribute__', 
'__getitem__', '__getitems__', '__getstate__', '__gt__', '__hash__', '__iadd__', 
'__imul__', '__init__', '__init_subclass__', '__invert__', '__ipow__', '__isub__',
'__itruediv__', '__le__', '__len__', '__lt__', '__match_args__', '__module__', 
'__mul__', '__ne__', '__neg__', '__new__', '__or__', '__pow__', '__radd__', 
'__rand__', '__reduce__', '__reduce_ex__', '__replace__', '__repr__', '__rmul__', 
'__ror__', '__rpow__', '__rsub__', '__rtruediv__', '__rxor__', '__setattr__', 
'__setitem__', '__setstate__', '__sizeof__', '__static_attributes__', '__str__', 
'__sub__', '__subclasshook__', '__torch_function__', '__truediv__', '__weakref__',
'__xor__', '_add_batch_dim', '_apply_nest', '_autocast', '_check_batch_size', 
'_check_device', '_check_dim_name', '_check_unlock', '_clone', '_clone_recurse', 
'_data', '_default_get', '_erase_names', '_exclude', '_fast_apply', 
'_flatten_keys_inplace', '_flatten_keys_outplace', '_from_dict_validated', 
'_from_module', '_from_tensordict', '_frozen', '_get_at_str', '_get_at_tuple', 
'_get_names_idx', '_get_str', '_get_sub_tensordict', '_get_tuple', 
'_get_tuple_maybe_non_tensor', '_grad', '_has_names', '_is_non_tensor', 
'_is_tensorclass', '_items_list', '_load_memmap', '_map', '_maybe_names', 
'_maybe_remove_batch_dim', '_memmap_', '_multithread_apply_flat', 
'_multithread_apply_nest', '_multithread_rebuild', '_new_unsafe', '_nocast', 
'_permute', '_propagate_lock', '_propagate_unlock', '_reduce_get_metadata', 
'_remove_batch_dim', '_repeat', '_select', '_set_at_str', '_set_at_tuple', 
'_set_str', '_set_tuple', '_shadow', '_to_module', '_type_hints', '_unbind', 
'_values_list', 'abs', 'abs_', 'acos', 'acos_', 'add', 'add_', 'addcdiv', 
'addcdiv_', 'addcmul', 'addcmul_', 'all', 'amax', 'amin', 'any', 'apply', 
'apply_', 'as_tensor', 'asin', 'asin_', 'atan', 'atan_', 'auto_batch_size_', 
'auto_device_', 'batch_dims', 'batch_size', 'bfloat16', 'bitwise_and', 'bool', 
'bytes', 'cat', 'cat_from_tensordict', 'cat_tensors', 'ceil', 'ceil_', 'chunk', 
'clamp', 'clamp_max', 'clamp_max_', 'clamp_min', 'clamp_min_', 'clear', 
'clear_device_', 'clear_refs_for_compile_', 'clone', 'complex128', 'complex32', 
'complex64', 'consolidate', 'contiguous', 'copy', 'copy_', 'copy_at_', 'cos', 
'cos_', 'cosh', 'cosh_', 'cpu', 'create_nested', 'cuda', 'cummax', 'cummin', 
'data', 'data_ptr', 'del_', 'densify', 'depth', 'detach', 'detach_', 'device', 
'dim', 'div', 'div_', 'double', 'dtype', 'dumps', 'empty', 'entry_class', 'erf', 
'erf_', 'erfc', 'erfc_', 'exclude', 'exp', 'exp_', 'expand', 'expand_as', 'expm1', 
'expm1_', 'fields', 'fill_', 'filter_empty_', 'filter_non_tensor_data', 'flatten', 
'flatten_keys', 'float', 'float16', 'float32', 'float64', 'floor', 'floor_', 
'frac', 'frac_', 'from_any', 'from_consolidated', 'from_dataclass', 'from_dict', 
'from_dict_instance', 'from_h5', 'from_module', 'from_modules', 'from_namedtuple', 
'from_pytree', 'from_struct_array', 'from_tensordict', 'from_tuple', 'fromkeys', 
'gather', 'gather_and_stack', 'get', 'get_at', 'get_item_shape', 'get_non_tensor', 
'grad', 'half', 'int', 'int16', 'int32', 'int64', 'int8', 'irecv', 
'is_consolidated', 'is_contiguous', 'is_cpu', 'is_cuda', 'is_empty', 
'is_floating_point', 'is_locked', 'is_memmap', 'is_meta', 'is_shared', 'isend', 
'isfinite', 'isnan', 'isneginf', 'isposinf', 'isreal', 'items', 'keys', 
'lazy_stack', 'lerp', 'lerp_', 'lgamma', 'lgamma_', 'load', 'load_', 
'load_memmap', 'load_memmap_', 'load_state_dict', 'lock_', 'log', 'log10', 
'log10_', 'log1p', 'log1p_', 'log2', 'log2_', 'log_', 'logical_and', 'logsumexp', 
'make_memmap', 'make_memmap_from_storage', 'make_memmap_from_tensor', 'map', 
'map_iter', 'masked_fill', 'masked_fill_', 'masked_select', 'max', 'maximum', 
'maximum_', 'maybe_dense_stack', 'mean', 'memmap', 'memmap_', 'memmap_like', 
'memmap_refresh_', 'min', 'minimum', 'minimum_', 'mul', 'mul_', 'named_apply', 
'names', 'nanmean', 'nansum', 'ndim', 'ndimension', 'neg', 'neg_', 'new_empty', 
'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'non_tensor_items', 'norm', 
'numel', 'numpy', 'param_count', 'permute', 'pin_memory', 'pin_memory_', 'pop', 
'popitem', 'pow', 'pow_', 'prod', 'qint32', 'qint8', 'quint4x2', 'quint8', 
'reciprocal', 'reciprocal_', 'record_stream', 'recv', 'reduce', 'refine_names', 
'rename', 'rename_', 'rename_key_', 'repeat', 'repeat_interleave', 'replace', 
'requires_grad', 'requires_grad_', 'reshape', 'round', 'round_', 'save', 
'saved_path', 'select', 'send', 'separates', 'set', 'set_', 'set_at_', 
'set_non_tensor', 'setdefault', 'shape', 'share_memory_', 'sigmoid', 'sigmoid_', 
'sign', 'sign_', 'sin', 'sin_', 'sinh', 'sinh_', 'size', 'softmax', 'sorted_keys', 
'split', 'split_keys', 'sqrt', 'sqrt_', 'squeeze', 'stack', 
'stack_from_tensordict', 'stack_tensors', 'state_dict', 'std', 'sub', 'sub_', 
'sum', 'tan', 'tan_', 'tanh', 'tanh_', 'to', 'to_dict', 'to_h5', 'to_module', 
'to_namedtuple', 'to_padded_tensor', 'to_pytree', 'to_struct_array', 
'to_tensordict', 'transpose', 'trunc', 'trunc_', 'type', 'uint16', 'uint32', 
'uint64', 'uint8', 'unbind', 'unflatten', 'unflatten_keys', 'unlock_', 
'unsqueeze', 'update', 'update_', 'update_at_', 'values', 'var', 'view', 'where', 
'zero_', 'zero_grad']

I mean... just... WTTTTTTTTTTTTTTTTF!? There's even an attribute called cummin up there, which just raises even more questions than it answers. 🤣

I genuinely have no idea where to start with this beastly nightmare. Technically, I do note a private _is_tensorclass attribute in the above output. The existence of that attribute definitely suggests the current type to be a tensordict.tensorclass. There's no guarantee of that, though. Other third-party types in the wild might very well define the same private _is_tensorclass attribute.

But... I guess that's what I gotta roll with, huh? Sucks, guys. This pretty much sucks.

So. Are You Actually Suggesting Anything Useful?

Yeah. I'm not here to just complain. That's only one of the reasons I'm here. 😉

Basically, feature request #663 really needs to happen to make TensorDict usable at runtime by everybody else in the Python ecosystem. If #663 happens, then @beartype can just trivially compare the obj.__class__.__module__ and obj.__class__.__name__ dunder attributes guaranteed to exist on any arbitrary object to decide whether that object is a TensorDict dataclass or not.

This trivial test just operates on strings and thus doesn't require @beartype to import TensorDict in a global scope: e.g.,

def is_tensorclass(obj: object) -> bool:
    return (
        obj.__class__.__module__ == 'tensordict' and  # <-- yay!
        obj.__class__.__name__ == 'TensorDict'  # <-------- yay x 2!!
    )

Until that glorious release day happens, I guess the best we can do is (as suggested above):

def is_tensorclass(obj: object) -> bool:
    return type(obj) is type and hasattr(obj, '_is_tensorclass')  # <-- yikes

That's... not great. But that's what we rollin' with. I sigh! I sigh so hard. 😩
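One possible way to tighten that heuristic a little, sketched here as an assumption rather than anything @beartype actually ships: a @tensorclass-decorated type cannot exist unless the tensordict package has already been imported by somebody, so checking sys.modules first rules out look-alike attributes in processes that never loaded TensorDict, and it costs no import. The function name is_tensorclass_heuristic is hypothetical, and the private _is_tensorclass attribute is only known from the dir() output above, so it may change without notice.

```python
import sys


def is_tensorclass_heuristic(obj: object) -> bool:
    """Best-effort guess at whether ``obj`` is a tensorclass type."""
    # If nobody imported tensordict, no tensorclass can exist in this process.
    if 'tensordict' not in sys.modules:
        return False
    # Fall back to the private marker attribute seen in dir(MyData).
    return isinstance(obj, type) and hasattr(obj, '_is_tensorclass')


class LookAlike:
    """Unrelated class that happens to define the same private attribute."""
    _is_tensorclass = True


print(is_tensorclass_heuristic(LookAlike))
```

This still yields a false positive for a look-alike class once tensordict has been imported, so it narrows the problem rather than solving it.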
