@@ -231,7 +231,9 @@ def _is_tensor(v):
                 return True
             return False
 
-        return all(_is_tensor(v) for v in flat_inputs)
+        return all(_is_tensor(v) for v in flat_inputs if v is not None) and any(
+            _is_tensor(v) for v in flat_inputs
+        )
 
     def __init__(
         self,
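With this change, `can_handle` ignores `None` leaves when checking that every input is a tensor, while the added `any(...)` guard still rejects inputs that contain no tensors at all. A quick illustrative check of that behaviour (not part of the patch; `tf.is_tensor` stands in for the adapter's private `_is_tensor` helper):

import tensorflow as tf

# A structure with a None leaf still qualifies, because at least one leaf
# is a real tensor; an all-None structure is rejected by the any() guard.
flat = tf.nest.flatten((tf.ones((4, 2)), None))
print(all(tf.is_tensor(v) for v in flat if v is not None)
      and any(tf.is_tensor(v) for v in flat))  # True

flat = tf.nest.flatten((None, None))
print(all(tf.is_tensor(v) for v in flat if v is not None)
      and any(tf.is_tensor(v) for v in flat))  # False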
@@ -259,7 +261,7 @@ def __init__(
         inputs = pack_x_y_sample_weight(x, y, sample_weights)
 
         num_samples = set(
-            int(i.shape[0]) for i in tf.nest.flatten(inputs)
+            int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
         ).pop()
         _check_data_cardinality(inputs)
 
@@ -386,7 +388,7 @@ def slice_inputs(self, indices_dataset, inputs):
 
         def grab_batch(i, data):
             return tf.nest.map_structure(
-                lambda d: tf.gather(d, i, axis=0), data
+                lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
             )
 
         dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
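`tf.nest.map_structure` treats `None` as a leaf and still invokes the mapped function on it, so without the `if d is not None else d` pass-through the lambda would call `tf.gather(None, ...)` and fail. A minimal illustration of the guarded lambda (not part of the patch; the sample data is made up):

import tensorflow as tf

data = {"images": tf.reshape(tf.range(12), (4, 3)), "weights": None}
i = tf.constant([0, 2])
batch = tf.nest.map_structure(
    lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
)
# batch["images"] now holds rows 0 and 2; batch["weights"] passes through
# as None instead of raising inside tf.gather.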
@@ -459,7 +461,9 @@ def _is_array_like(v):
         if not TensorLikeDataAdapter.can_handle(
             x, y
         ) and not CompositeTensorDataAdapter.can_handle(x, y):
-            return all(_is_array_like(v) for v in flat_inputs)
+            return all(
+                _is_array_like(v) for v in flat_inputs if v is not None
+            ) and any(v is not None for v in flat_inputs)
         else:
             return False
 
@@ -496,7 +500,7 @@ def dynamic_shape_like(t):
             shape[0] = None
             return tuple(shape)
 
-        flat_dtypes = [inp.dtype for inp in flat_inputs]
+        flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
         contiguous = True
         if self._shuffle and self._shuffle != "batch":
             contiguous = False
@@ -509,15 +513,26 @@ def grab_batch(indices):
             # to a Tensor may force it into memory..
             def py_method(ind):
                 def slice_array(data):
+                    if data is None:
+                        return None
                     return training_utils.slice_arrays(
                         data, ind.numpy(), contiguous=contiguous
                     )
 
-                return [slice_array(inp) for inp in flat_inputs]
+                return [
+                    slice_array(inp) for inp in flat_inputs if inp is not None
+                ]
 
-            flat_out = tf.py_function(py_method, [indices], flat_dtypes)
-            for v, original_inp in zip(flat_out, flat_inputs):
-                v.set_shape(dynamic_shape_like(original_inp))
+            results = tf.py_function(py_method, [indices], flat_dtypes)
+            results_it = iter(results)
+            flat_out = []
+            for original_inp in flat_inputs:
+                if original_inp is None:
+                    flat_out.append(None)
+                else:
+                    v = next(results_it)
+                    v.set_shape(dynamic_shape_like(original_inp))
+                    flat_out.append(v)
             return tf.nest.pack_sequence_as(inputs, flat_out)
 
         dataset = indices_dataset.map(
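The bookkeeping above is needed because `flat_dtypes` now excludes `None` entries, so `tf.py_function` yields one tensor per non-None input, while `tf.nest.pack_sequence_as` still expects a value for every leaf of the original structure. A standalone sketch of that re-interleaving pattern (illustrative only; the inputs and the slicing are made-up stand-ins for what `py_method` returns):

import numpy as np
import tensorflow as tf

inputs = (np.arange(8).reshape(4, 2), None, np.ones(4))
flat_inputs = tf.nest.flatten(inputs)  # [array, None, array]

# Stand-in for the py_function output: one result per non-None input.
results = [v[:2] for v in flat_inputs if v is not None]

# Re-interleave None placeholders so the flat list matches the structure.
results_it = iter(results)
flat_out = [None if inp is None else next(results_it) for inp in flat_inputs]
restored = tf.nest.pack_sequence_as(inputs, flat_out)
# restored keeps the original (x, None, y) shape, with the None untouched.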
@@ -608,8 +623,10 @@ def _is_tensor_or_composite(v):
                 return True
             return _is_composite(v)
 
-        return any(_is_composite(v) for v in flat_inputs) and all(
-            _is_tensor_or_composite(v) for v in flat_inputs
+        return any(
+            _is_composite(v) for v in flat_inputs if v is not None
+        ) and all(
+            _is_tensor_or_composite(v) for v in flat_inputs if v is not None
         )
 
     def __init__(
@@ -1944,14 +1961,18 @@ def single_batch_iterator(
 
 
 def _check_data_cardinality(data):
-    num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
+    num_samples = set(
+        int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
+    )
     if len(num_samples) > 1:
         msg = "Data cardinality is ambiguous:\n"
         for label, single_data in zip(["x", "y", "sample_weight"], data):
             msg += "  {} sizes: {}\n".format(
                 label,
                 ", ".join(
-                    str(i.shape[0]) for i in tf.nest.flatten(single_data)
+                    str(i.shape[0])
+                    for i in tf.nest.flatten(single_data)
+                    if i is not None
                 ),
             )
         msg += "Make sure all arrays contain the same number of samples."