1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from inspect import Parameter , Signature
1615import ast
16+ import functools
1717import os
18-
19- from pathlib import Path
20-
2118from contextlib import closing
22-
23- from typing import Union , Optional
24- from typing import Sequence , List , Any
19+ from inspect import Parameter , Signature
20+ from pathlib import Path
21+ from typing import Any , List , Literal , Optional , Sequence , Union
2522
2623from nvidia .dali import backend as _b
24+ from nvidia .dali import fn , ops , types
2725from nvidia .dali import types as _types
28- from nvidia .dali .ops import _registry , _names , _docs
29- from nvidia .dali import types
30- from nvidia .dali import ops , fn
26+ from nvidia .dali .ops import _docs , _names , _registry
3127
3228
3329def _create_annotation_placeholder (typename ):
@@ -77,6 +73,11 @@ def __repr__(self):
7773 types .DALIInterpType : _DALIInterpType ,
7874}
7975
76+ # Placeholders for dynamic mode
77+ _Tensor = _create_annotation_placeholder ("Tensor" )
78+ _Batch = _create_annotation_placeholder ("Batch" )
79+ _TensorLike = _create_annotation_placeholder ("TensorLike" )
80+
8081
8182def _scalar_element_annotation (scalar_dtype ):
8283 # We already have function that converts a scalar constant/literal into the desired type,
@@ -241,7 +242,7 @@ def _get_positional_input_params(schema, input_annotation_gen=_get_annotation_in
241242 return param_list
242243
243244
244- def _get_keyword_params (schema , all_args_optional = False ):
245+ def _get_keyword_params (schema , all_args_optional = False , data_node_tensors = False ):
245246 """Get the list of annotated keyword Parameters to the operator."""
246247 param_list = []
247248 for arg in schema .GetArgumentNames ():
@@ -253,7 +254,11 @@ def _get_keyword_params(schema, all_args_optional=False):
253254 is_arg_input = schema .IsTensorArgument (arg )
254255
255256 if is_arg_input :
256- annotation = Union [_DataNode , _TensorLikeArg , kw_annotation ]
257+ annotation = (
258+ Union [_DataNode , _TensorLikeArg , kw_annotation ]
259+ if data_node_tensors
260+ else Union [_TensorLikeArg , kw_annotation ]
261+ )
257262 else :
258263 annotation = kw_annotation
259264
@@ -310,7 +315,9 @@ def _call_signature(
310315 include_inputs = True ,
311316 include_kwargs = True ,
312317 include_self = False ,
318+ include_batch_size = False ,
313319 data_node_return = True ,
320+ data_node_kwargs = True ,
314321 all_args_optional = False ,
315322 input_annotation_gen = _get_annotation_input_regular ,
316323 return_annotation_gen = _get_annotation_return_regular ,
@@ -328,9 +335,13 @@ def _call_signature(
328335 If keyword arguments should be included in the signature, by default True
329336 include_self : bool, optional
330337 Prepend `self` as first positional argument in the signature, by default False
338+ include_batch_size : bool, optional
339+ Preped `batch_size` as first keyword-only argument in the signature, by default False
331340 data_node_return : bool, optional
332341 If the signature should have a return annotation or return None (for ops class __init__),
333342 by default True
343+ data_node_kwargs : bool, optional
344+ If tensor keyword arguments should accept DataNodes, by default True
334345 all_args_optional : bool, optional
335346 Make all keyword arguments optional, even if they are not - needed by the ops API, where
336347 the argument can be specified in either __init__ or __call__, by default False
@@ -348,8 +359,18 @@ def _call_signature(
348359 _get_positional_input_params (schema , input_annotation_gen = input_annotation_gen )
349360 )
350361
362+ if include_batch_size :
363+ parameter = Parameter (name = "batch_size" , kind = Parameter .KEYWORD_ONLY , annotation = int )
364+ param_list .append (parameter )
365+
351366 if include_kwargs :
352- param_list .extend (_get_keyword_params (schema , all_args_optional = all_args_optional ))
367+ param_list .extend (
368+ _get_keyword_params (
369+ schema ,
370+ all_args_optional = all_args_optional ,
371+ data_node_tensors = data_node_kwargs ,
372+ )
373+ )
353374 param_list .extend (_get_implicit_keyword_params (schema , all_args_optional = all_args_optional ))
354375
355376 if data_node_return :
@@ -494,6 +515,33 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True,
494515 )
495516
496517
518+ def _gen_dynamic_signature_no_input (schema : _b .OpSchema , schema_name : str , fn_name : str ):
519+ """TODO"""
520+ call_signature = functools .partial (_call_signature , schema , data_node_kwargs = False )
521+ return f"""
522+ @overload
523+ def { fn_name } { call_signature (return_annotation_gen = lambda _ : _Tensor )} :
524+ \" ""{ _docs ._docstring_generator_fn (schema_name )}
525+ \" ""
526+
527+ @overload
528+ def { fn_name } { call_signature (include_batch_size = True , return_annotation_gen = lambda _ : _Batch )} :
529+ \" ""{ _docs ._docstring_generator_fn (schema_name )}
530+ \" ""
531+ """
532+
533+
534+ def _gen_dynamic_signature (schema : _b .OpSchema , schema_name : str , fn_name : str ):
535+ """TODO"""
536+ signature = (
537+ _gen_dynamic_signature_no_input (schema , schema_name , fn_name )
538+ if schema .MaxNumInput () == 0
539+ else "\n ...\n " # _gen_dynamic_signature_with_inputs(schema, schema_name, fn_name)
540+ )
541+
542+ return inspect_repr_fixups (signature )
543+
544+
497545# Preamble with license and helper imports for the stub file.
498546# We need the placeholders for actual Python classes, as the ones that are exported from backend
499547# don't seem to work with the intellisense.
@@ -515,12 +563,19 @@ def __init__{_call_signature(schema, include_inputs=False, include_kwargs=True,
515563from typing import Union, Optional, overload
516564from typing import Any, List, Sequence
517565
518- from nvidia.dali._typing import TensorLikeIn, TensorLikeArg
566+ from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType
519567
520- from nvidia.dali.data_node import DataNode
568+ """
521569
522- from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType
570+ _PIPELINE_HEADER = """
571+ from nvidia.dali._typing import TensorLikeIn, TensorLikeArg
572+ from nvidia.dali.data_node import DataNode
573+ """
523574
575+ _DYNAMIC_HEADER = """
576+ from nvidia.dali._typing import TensorLike, TensorLikeArg
577+ from nvidia.dali.experimental.dynamic._tensor import Tensor
578+ from nvidia.dali.experimental.dynamic._batch import Batch
524579"""
525580
526581
@@ -574,8 +629,8 @@ def _get_op(api_module, full_qualified_name: List[str]):
574629
575630
576631def _group_signatures (api : str ):
577- """Divide all operators registered into the "ops" or "fn" api into 4 categories and return them
578- as a dictionary:
632+ """Divide all operators registered into the "ops", "fn" or "dynamic api into 4 categories
633+ and return them as a dictionary:
579634 * python_only - there is just the Python definition
580635 * hidden_or_internal - op is hidden or internal, defined in backend
581636 * python_wrapper - op defined in backend, has a hand-written wrapper (op._generated = False)
@@ -585,31 +640,39 @@ def _group_signatures(api: str):
585640 depending on the api type.
586641
587642 """
643+
644+ from nvidia .dali .experimental import dynamic
645+
588646 sig_groups = {
589647 "python_only" : [],
590648 "hidden_or_internal" : [],
591649 "python_wrapper" : [],
592650 "generated" : [],
593651 }
594652
595- api_module = fn if api == "fn" else ops
653+ api_module = {"fn" : fn , "ops" : ops , "dynamic" : dynamic }[api ]
654+ naming_convention = api if api != "dynamic" else "fn"
596655
597656 for schema_name in sorted (_registry ._all_registered_ops ()):
598657 schema = _b .TryGetSchema (schema_name )
599658
600- _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = api )
659+ _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = naming_convention )
601660 op = _get_op (api_module , module_nesting + [op_name ])
602661
662+ if op is None :
663+ continue
664+
603665 if schema is None :
604- if op is not None :
605- sig_groups ["python_only" ].append ((schema_name , op ))
666+ sig_groups ["python_only" ].append ((schema_name , op ))
606667 continue
607668
608669 if schema .IsDocHidden () or schema .IsInternal ():
609670 sig_groups ["hidden_or_internal" ].append ((schema_name , op ))
610671 continue
611672
612- if not getattr (op , "_generated" , False ):
673+ # Dynamic mode doesn't have registered python wrappers yet
674+ # If necessary, we can later check hasattr(op, "op_class")
675+ if not hasattr (op , "_generated" ) and api != "dynamic" :
613676 sig_groups ["python_wrapper" ].append ((schema_name , op ))
614677 continue
615678
@@ -624,6 +687,12 @@ def __init__(self, nvidia_dali_path: Path, api: str):
624687 self ._nvidia_dali_path = nvidia_dali_path
625688 self ._api = api
626689 self ._module_tree = _build_module_tree ()
690+ self ._header = _HEADER
691+
692+ if api in ("ops" , "fn" ):
693+ self ._header += _PIPELINE_HEADER
694+ else :
695+ self ._header += _DYNAMIC_HEADER
627696
628697 def get (self , module_nesting : List [str ]):
629698 """Get the file representing the given submodule nesting.
@@ -639,7 +708,7 @@ def get(self, module_nesting: List[str]):
639708 open (file_path , "w" ).close () # clear the file
640709 f = open (file_path , "a" )
641710 self ._module_to_file [module_path ] = f
642- f .write (_HEADER )
711+ f .write (self . _header )
643712 full_module_nesting = ["" ] + module_nesting
644713 # Find out all the direct submodules and add the imports
645714 submodules_dict = self ._module_tree
@@ -657,24 +726,27 @@ def close(self):
657726 f .close ()
658727
659728
660- def gen_all_signatures (nvidia_dali_path , api ):
661- """Generate the signatures for "fn" or "ops " api.
729+ def gen_all_signatures (nvidia_dali_path : Path , api : Literal [ "fn" , "ops" , "dynamic" ] ):
730+ """Generate the signatures for "fn", "ops" or "dynamic " api.
662731
663732 Parameters
664733 ----------
665734 nvidia_dali_path : Path
666735 The path to the wheel pre-packaging to the nvidia/dali directory.
667736 api : str
668- "fn" or "ops "
737+ "fn", "ops" or "dynamic "
669738 """
670- nvidia_dali_path = Path (nvidia_dali_path )
739+ api_path = naming_convention = api
740+ if api == "dynamic" :
741+ api_path = os .path .join ("experimental" , api )
742+ naming_convention = "fn"
671743
672- with closing (StubFileManager (nvidia_dali_path , api )) as stub_manager :
744+ with closing (StubFileManager (nvidia_dali_path , api_path )) as stub_manager :
673745 sig_groups = _group_signatures (api )
674746
675747 # Python-only and the manually defined ones are reexported from their respective modules
676748 for schema_name , op in sig_groups ["python_only" ] + sig_groups ["python_wrapper" ]:
677- _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = api )
749+ _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = naming_convention )
678750
679751 stub_manager .get (module_nesting ).write (
680752 f"\n \n from { op ._impl_module } import" f" ({ op .__name__ } as { op .__name__ } )\n \n "
@@ -684,15 +756,14 @@ def gen_all_signatures(nvidia_dali_path, api):
684756 # directly visible
685757
686758 # Runtime generated classes use fully specified stubs.
759+ signature_generators = {
760+ "fn" : _gen_fn_signature ,
761+ "ops" : _gen_ops_signature ,
762+ "dynamic" : _gen_dynamic_signature ,
763+ }
687764 for schema_name , op in sig_groups ["generated" ]:
688- _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = api )
765+ _ , module_nesting , op_name = _names ._process_op_name (schema_name , api = naming_convention )
689766 schema = _b .TryGetSchema (schema_name )
690767
691- if api == "fn" :
692- stub_manager .get (module_nesting ).write (
693- _gen_fn_signature (schema , schema_name , op_name )
694- )
695- else :
696- stub_manager .get (module_nesting ).write (
697- _gen_ops_signature (schema , schema_name , op_name )
698- )
768+ signature = signature_generators [api ](schema , schema_name , op_name )
769+ stub_manager .get (module_nesting ).write (signature )
0 commit comments