Skip to content

Commit 28fcbfc

Browse files
Allow Input to be optional to take None inputs, similar to what keras3 has.
PiperOrigin-RevId: 819935785
1 parent 0dec184 commit 28fcbfc

17 files changed

+257
-43
lines changed

tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ tf_module {
482482
}
483483
member_method {
484484
name: "Input"
485-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
485+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
486486
}
487487
member_method {
488488
name: "add"

tf_keras/api/golden/v1/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,6 @@ tf_module {
9090
}
9191
member_method {
9292
name: "Input"
93-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
93+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9494
}
9595
}

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ tf_module {
538538
}
539539
member_method {
540540
name: "Input"
541-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
541+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
542542
}
543543
member_method {
544544
name: "add"

tf_keras/api/golden/v2/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,6 @@ tf_module {
9595
}
9696
member_method {
9797
name: "Input"
98-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
98+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9999
}
100100
}

tf_keras/engine/data_adapter.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def _is_tensor(v):
231231
return True
232232
return False
233233

234-
return all(_is_tensor(v) for v in flat_inputs)
234+
return all(_is_tensor(v) for v in flat_inputs if v is not None) and any(
235+
_is_tensor(v) for v in flat_inputs
236+
)
235237

236238
def __init__(
237239
self,
@@ -259,7 +261,7 @@ def __init__(
259261
inputs = pack_x_y_sample_weight(x, y, sample_weights)
260262

261263
num_samples = set(
262-
int(i.shape[0]) for i in tf.nest.flatten(inputs)
264+
int(i.shape[0]) for i in tf.nest.flatten(inputs) if i is not None
263265
).pop()
264266
_check_data_cardinality(inputs)
265267

@@ -386,7 +388,7 @@ def slice_inputs(self, indices_dataset, inputs):
386388

387389
def grab_batch(i, data):
388390
return tf.nest.map_structure(
389-
lambda d: tf.gather(d, i, axis=0), data
391+
lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
390392
)
391393

392394
dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +461,9 @@ def _is_array_like(v):
459461
if not TensorLikeDataAdapter.can_handle(
460462
x, y
461463
) and not CompositeTensorDataAdapter.can_handle(x, y):
462-
return all(_is_array_like(v) for v in flat_inputs)
464+
return all(
465+
_is_array_like(v) for v in flat_inputs if v is not None
466+
) and any(v is not None for v in flat_inputs)
463467
else:
464468
return False
465469

@@ -496,7 +500,7 @@ def dynamic_shape_like(t):
496500
shape[0] = None
497501
return tuple(shape)
498502

499-
flat_dtypes = [inp.dtype for inp in flat_inputs]
503+
flat_dtypes = [inp.dtype for inp in flat_inputs if inp is not None]
500504
contiguous = True
501505
if self._shuffle and self._shuffle != "batch":
502506
contiguous = False
@@ -509,15 +513,26 @@ def grab_batch(indices):
509513
# to a Tensor may force it into memory..
510514
def py_method(ind):
511515
def slice_array(data):
516+
if data is None:
517+
return None
512518
return training_utils.slice_arrays(
513519
data, ind.numpy(), contiguous=contiguous
514520
)
515521

516-
return [slice_array(inp) for inp in flat_inputs]
522+
return [
523+
slice_array(inp) for inp in flat_inputs if inp is not None
524+
]
517525

518-
flat_out = tf.py_function(py_method, [indices], flat_dtypes)
519-
for v, original_inp in zip(flat_out, flat_inputs):
520-
v.set_shape(dynamic_shape_like(original_inp))
526+
results = tf.py_function(py_method, [indices], flat_dtypes)
527+
results_it = iter(results)
528+
flat_out = []
529+
for original_inp in flat_inputs:
530+
if original_inp is None:
531+
flat_out.append(None)
532+
else:
533+
v = next(results_it)
534+
v.set_shape(dynamic_shape_like(original_inp))
535+
flat_out.append(v)
521536
return tf.nest.pack_sequence_as(inputs, flat_out)
522537

523538
dataset = indices_dataset.map(
@@ -608,8 +623,10 @@ def _is_tensor_or_composite(v):
608623
return True
609624
return _is_composite(v)
610625

611-
return any(_is_composite(v) for v in flat_inputs) and all(
612-
_is_tensor_or_composite(v) for v in flat_inputs
626+
return any(
627+
_is_composite(v) for v in flat_inputs if v is not None
628+
) and all(
629+
_is_tensor_or_composite(v) for v in flat_inputs if v is not None
613630
)
614631

615632
def __init__(
@@ -1944,14 +1961,18 @@ def single_batch_iterator(
19441961

19451962

19461963
def _check_data_cardinality(data):
1947-
num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
1964+
num_samples = set(
1965+
int(i.shape[0]) for i in tf.nest.flatten(data) if i is not None
1966+
)
19481967
if len(num_samples) > 1:
19491968
msg = "Data cardinality is ambiguous:\n"
19501969
for label, single_data in zip(["x", "y", "sample_weight"], data):
19511970
msg += " {} sizes: {}\n".format(
19521971
label,
19531972
", ".join(
1954-
str(i.shape[0]) for i in tf.nest.flatten(single_data)
1973+
str(i.shape[0])
1974+
for i in tf.nest.flatten(single_data)
1975+
if i is not None
19551976
),
19561977
)
19571978
msg += "Make sure all arrays contain the same number of samples."

0 commit comments

Comments
 (0)