Skip to content

Commit e0e2173

Browse files
committed
Generate dynamic mode signatures for no-input operators
Signed-off-by: Rostan Tabet <[email protected]>
1 parent af6d1db commit e0e2173

File tree

3 files changed

+117
-43
lines changed

3 files changed

+117
-43
lines changed

dali/python/nvidia/dali/_typing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -71,5 +71,6 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]: ...
7171

7272
TensorLike: TypeAlias = ArrayInterface | CudaArrayInterface | DLPack
7373
"""
74-
Argument compatible with ``dali.dynamic.Tensor`` used as input for per-sample dynamic mode functions.
74+
Object compatible with ``dali.dynamic.Tensor`` used as input for per-sample
75+
dynamic mode functions.
7576
"""

dali/python/nvidia/dali/ops/_signatures.py

Lines changed: 110 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from inspect import Parameter, Signature
1615
import ast
16+
import functools
1717
import os
18-
19-
from pathlib import Path
20-
2118
from contextlib import closing
22-
23-
from typing import Union, Optional
24-
from typing import Sequence, List, Any
19+
from inspect import Parameter, Signature
20+
from pathlib import Path
21+
from typing import Any, List, Literal, Optional, Sequence, Union
2522

2623
from nvidia.dali import backend as _b
24+
from nvidia.dali import fn, ops, types
2725
from nvidia.dali import types as _types
28-
from nvidia.dali.ops import _registry, _names, _docs
29-
from nvidia.dali import types
30-
from nvidia.dali import ops, fn
26+
from nvidia.dali.ops import _docs, _names, _registry
3127

3228

3329
def _create_annotation_placeholder(typename):
@@ -77,6 +73,11 @@ def __repr__(self):
7773
types.DALIInterpType: _DALIInterpType,
7874
}
7975

76+
# Placeholders for dynamic mode
77+
_Tensor = _create_annotation_placeholder("Tensor")
78+
_Batch = _create_annotation_placeholder("Batch")
79+
_TensorLike = _create_annotation_placeholder("TensorLike")
80+
8081

8182
def _scalar_element_annotation(scalar_dtype):
8283
# We already have function that converts a scalar constant/literal into the desired type,
@@ -241,7 +242,7 @@ def _get_positional_input_params(schema, input_annotation_gen=_get_annotation_in
241242
return param_list
242243

243244

244-
def _get_keyword_params(schema, all_args_optional=False):
245+
def _get_keyword_params(schema, all_args_optional=False, data_node_tensors=False):
245246
"""Get the list of annotated keyword Parameters to the operator."""
246247
param_list = []
247248
for arg in schema.GetArgumentNames():
@@ -253,7 +254,11 @@ def _get_keyword_params(schema, all_args_optional=False):
253254
is_arg_input = schema.IsTensorArgument(arg)
254255

255256
if is_arg_input:
256-
annotation = Union[_DataNode, _TensorLikeArg, kw_annotation]
257+
annotation = (
258+
Union[_DataNode, _TensorLikeArg, kw_annotation]
259+
if data_node_tensors
260+
else Union[_TensorLikeArg, kw_annotation]
261+
)
257262
else:
258263
annotation = kw_annotation
259264

@@ -310,7 +315,9 @@ def _call_signature(
310315
include_inputs=True,
311316
include_kwargs=True,
312317
include_self=False,
318+
include_batch_size=False,
313319
data_node_return=True,
320+
data_node_kwargs=True,
314321
all_args_optional=False,
315322
input_annotation_gen=_get_annotation_input_regular,
316323
return_annotation_gen=_get_annotation_return_regular,
@@ -328,9 +335,13 @@ def _call_signature(
328335
If keyword arguments should be included in the signature, by default True
329336
include_self : bool, optional
330337
Prepend `self` as first positional argument in the signature, by default False
338+
include_batch_size : bool, optional
339+
Preped `batch_size` as first keyword-only argument in the signature, by default False
331340
data_node_return : bool, optional
332341
If the signature should have a return annotation or return None (for ops class __init__),
333342
by default True
343+
data_node_kwargs : bool, optional
344+
If tensor keyword arguments should accept DataNodes, by default True
334345
all_args_optional : bool, optional
335346
Make all keyword arguments optional, even if they are not - needed by the ops API, where
336347
the argument can be specified in either __init__ or __call__, by default False
@@ -348,8 +359,18 @@ def _call_signature(
348359
_get_positional_input_params(schema, input_annotation_gen=input_annotation_gen)
349360
)
350361

362+
if include_batch_size:
363+
parameter = Parameter(name="batch_size", kind=Parameter.KEYWORD_ONLY, annotation=int)
364+
param_list.append(parameter)
365+
351366
if include_kwargs:
352-
param_list.extend(_get_keyword_params(schema, all_args_optional=all_args_optional))
367+
param_list.extend(
368+
_get_keyword_params(
369+
schema,
370+
all_args_optional=all_args_optional,
371+
data_node_tensors=data_node_kwargs,
372+
)
373+
)
353374
param_list.extend(_get_implicit_keyword_params(schema, all_args_optional=all_args_optional))
354375

355376
if data_node_return:
@@ -494,6 +515,33 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True,
494515
)
495516

496517

518+
def _gen_dynamic_signature_no_input(schema: _b.OpSchema, schema_name: str, fn_name: str):
519+
"""TODO"""
520+
call_signature = functools.partial(_call_signature, schema, data_node_kwargs=False)
521+
return f"""
522+
@overload
523+
def {fn_name}{call_signature(return_annotation_gen=lambda _: _Tensor)}:
524+
\"""{_docs._docstring_generator_fn(schema_name)}
525+
\"""
526+
527+
@overload
528+
def {fn_name}{call_signature(include_batch_size=True, return_annotation_gen=lambda _: _Batch)}:
529+
\"""{_docs._docstring_generator_fn(schema_name)}
530+
\"""
531+
"""
532+
533+
534+
def _gen_dynamic_signature(schema: _b.OpSchema, schema_name: str, fn_name: str):
535+
"""TODO"""
536+
signature = (
537+
_gen_dynamic_signature_no_input(schema, schema_name, fn_name)
538+
if schema.MaxNumInput() == 0
539+
else "\n...\n" # _gen_dynamic_signature_with_inputs(schema, schema_name, fn_name)
540+
)
541+
542+
return inspect_repr_fixups(signature)
543+
544+
497545
# Preamble with license and helper imports for the stub file.
498546
# We need the placeholders for actual Python classes, as the ones that are exported from backend
499547
# don't seem to work with the intellisense.
@@ -515,12 +563,19 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True,
515563
from typing import Union, Optional, overload
516564
from typing import Any, List, Sequence
517565
518-
from nvidia.dali._typing import TensorLikeIn, TensorLikeArg
566+
from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType
519567
520-
from nvidia.dali.data_node import DataNode
568+
"""
521569

522-
from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType
570+
_PIPELINE_HEADER = """
571+
from nvidia.dali._typing import TensorLikeIn, TensorLikeArg
572+
from nvidia.dali.data_node import DataNode
573+
"""
523574

575+
_DYNAMIC_HEADER = """
576+
from nvidia.dali._typing import TensorLike, TensorLikeArg
577+
from nvidia.dali.experimental.dynamic._tensor import Tensor
578+
from nvidia.dali.experimental.dynamic._batch import Batch
524579
"""
525580

526581

@@ -574,8 +629,8 @@ def _get_op(api_module, full_qualified_name: List[str]):
574629

575630

576631
def _group_signatures(api: str):
577-
"""Divide all operators registered into the "ops" or "fn" api into 4 categories and return them
578-
as a dictionary:
632+
"""Divide all operators registered into the "ops", "fn" or "dynamic api into 4 categories
633+
and return them as a dictionary:
579634
* python_only - there is just the Python definition
580635
* hidden_or_internal - op is hidden or internal, defined in backend
581636
* python_wrapper - op defined in backend, has a hand-written wrapper (op._generated = False)
@@ -585,31 +640,39 @@ def _group_signatures(api: str):
585640
depending on the api type.
586641
587642
"""
643+
644+
from nvidia.dali.experimental import dynamic
645+
588646
sig_groups = {
589647
"python_only": [],
590648
"hidden_or_internal": [],
591649
"python_wrapper": [],
592650
"generated": [],
593651
}
594652

595-
api_module = fn if api == "fn" else ops
653+
api_module = {"fn": fn, "ops": ops, "dynamic": dynamic}[api]
654+
naming_convention = api if api != "dynamic" else "fn"
596655

597656
for schema_name in sorted(_registry._all_registered_ops()):
598657
schema = _b.TryGetSchema(schema_name)
599658

600-
_, module_nesting, op_name = _names._process_op_name(schema_name, api=api)
659+
_, module_nesting, op_name = _names._process_op_name(schema_name, api=naming_convention)
601660
op = _get_op(api_module, module_nesting + [op_name])
602661

662+
if op is None:
663+
continue
664+
603665
if schema is None:
604-
if op is not None:
605-
sig_groups["python_only"].append((schema_name, op))
666+
sig_groups["python_only"].append((schema_name, op))
606667
continue
607668

608669
if schema.IsDocHidden() or schema.IsInternal():
609670
sig_groups["hidden_or_internal"].append((schema_name, op))
610671
continue
611672

612-
if not getattr(op, "_generated", False):
673+
# Dynamic mode doesn't have registered python wrappers yet
674+
# If necessary, we can later check hasattr(op, "op_class")
675+
if not hasattr(op, "_generated") and api != "dynamic":
613676
sig_groups["python_wrapper"].append((schema_name, op))
614677
continue
615678

@@ -624,6 +687,12 @@ def __init__(self, nvidia_dali_path: Path, api: str):
624687
self._nvidia_dali_path = nvidia_dali_path
625688
self._api = api
626689
self._module_tree = _build_module_tree()
690+
self._header = _HEADER
691+
692+
if api in ("ops", "fn"):
693+
self._header += _PIPELINE_HEADER
694+
else:
695+
self._header += _DYNAMIC_HEADER
627696

628697
def get(self, module_nesting: List[str]):
629698
"""Get the file representing the given submodule nesting.
@@ -639,7 +708,7 @@ def get(self, module_nesting: List[str]):
639708
open(file_path, "w").close() # clear the file
640709
f = open(file_path, "a")
641710
self._module_to_file[module_path] = f
642-
f.write(_HEADER)
711+
f.write(self._header)
643712
full_module_nesting = [""] + module_nesting
644713
# Find out all the direct submodules and add the imports
645714
submodules_dict = self._module_tree
@@ -657,24 +726,27 @@ def close(self):
657726
f.close()
658727

659728

660-
def gen_all_signatures(nvidia_dali_path, api):
661-
"""Generate the signatures for "fn" or "ops" api.
729+
def gen_all_signatures(nvidia_dali_path: Path, api: Literal["fn", "ops", "dynamic"]):
730+
"""Generate the signatures for "fn", "ops" or "dynamic" api.
662731
663732
Parameters
664733
----------
665734
nvidia_dali_path : Path
666735
The path to the wheel pre-packaging to the nvidia/dali directory.
667736
api : str
668-
"fn" or "ops"
737+
"fn", "ops" or "dynamic"
669738
"""
670-
nvidia_dali_path = Path(nvidia_dali_path)
739+
api_path = naming_convention = api
740+
if api == "dynamic":
741+
api_path = os.path.join("experimental", api)
742+
naming_convention = "fn"
671743

672-
with closing(StubFileManager(nvidia_dali_path, api)) as stub_manager:
744+
with closing(StubFileManager(nvidia_dali_path, api_path)) as stub_manager:
673745
sig_groups = _group_signatures(api)
674746

675747
# Python-only and the manually defined ones are reexported from their respective modules
676748
for schema_name, op in sig_groups["python_only"] + sig_groups["python_wrapper"]:
677-
_, module_nesting, op_name = _names._process_op_name(schema_name, api=api)
749+
_, module_nesting, op_name = _names._process_op_name(schema_name, api=naming_convention)
678750

679751
stub_manager.get(module_nesting).write(
680752
f"\n\nfrom {op._impl_module} import" f" ({op.__name__} as {op.__name__})\n\n"
@@ -684,15 +756,14 @@ def gen_all_signatures(nvidia_dali_path, api):
684756
# directly visible
685757

686758
# Runtime generated classes use fully specified stubs.
759+
signature_generators = {
760+
"fn": _gen_fn_signature,
761+
"ops": _gen_ops_signature,
762+
"dynamic": _gen_dynamic_signature,
763+
}
687764
for schema_name, op in sig_groups["generated"]:
688-
_, module_nesting, op_name = _names._process_op_name(schema_name, api=api)
765+
_, module_nesting, op_name = _names._process_op_name(schema_name, api=naming_convention)
689766
schema = _b.TryGetSchema(schema_name)
690767

691-
if api == "fn":
692-
stub_manager.get(module_nesting).write(
693-
_gen_fn_signature(schema, schema_name, op_name)
694-
)
695-
else:
696-
stub_manager.get(module_nesting).write(
697-
_gen_ops_signature(schema, schema_name, op_name)
698-
)
768+
signature = signature_generators[api](schema, schema_name, op_name)
769+
stub_manager.get(module_nesting).write(signature)

internal_tools/python_stub_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,7 @@
3434

3535
print(f"Generating signatures for {args.wheel_path=}")
3636

37-
_signatures.gen_all_signatures(Path(args.wheel_path), "fn")
38-
_signatures.gen_all_signatures(Path(args.wheel_path), "ops")
37+
wheel_path = Path(args.wheel_path)
38+
_signatures.gen_all_signatures(wheel_path, "fn")
39+
_signatures.gen_all_signatures(wheel_path, "ops")
40+
_signatures.gen_all_signatures(wheel_path, "dynamic")

0 commit comments

Comments
 (0)