Commit 11da67d

Fix TF trainer bug when first input is None (#21630)
* Fix TF trainer bug when the first flattened input is None
* Test the fix by adapting 2 unit tests
* Same fix for jax & torch
1 parent bc3d38c commit 11da67d

File tree

4 files changed (+34, -21 lines)
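Why this matters: each backend trainer infers the batch size for loss tracking from the first entry of `tree.flatten(x)`. With optional inputs, that entry can be `None`, and `.shape[0]` then raises `AttributeError: 'NoneType' object has no attribute 'shape'`. A minimal sketch of the failure and of the fix's selection logic, using a plain Python list to mimic a `tree.flatten(x)` result (illustrative only, not the trainer code itself):

```python
import numpy as np

# Mimics tree.flatten(x) when an optional input was passed as None
# and happens to come first in the flattened structure.
flat_inputs = [None, np.ones((2, 2))]

# Old logic: assumes the first flattened input is a real tensor.
try:
    batch_size = flat_inputs[0].shape[0]
except AttributeError as err:
    print(f"old logic fails: {err}")

# Fixed logic: take the first non-None entry instead.
batch_size = next(i for i in flat_inputs if i is not None).shape[0]
print(batch_size)  # 2
```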

keras/src/backend/jax/trainer.py

Lines changed: 4 additions & 1 deletion
@@ -105,7 +105,10 @@ def _update_metrics_variables(
             ]
         ) as scope:
             self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+                unscaled_loss,
+                sample_weight=next(
+                    i for i in tree.flatten(x) if i is not None
+                ).shape[0],
             )
             logs = self.compute_metrics(x, y, y_pred, sample_weight)

keras/src/backend/tensorflow/trainer.py

Lines changed: 6 additions & 2 deletions
@@ -68,7 +68,9 @@ def train_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
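A side note on the TensorFlow variant: it reads the batch size with `tf.shape(...)[0]` instead of `.shape[0]` because, inside a `tf.function`-compiled step, the static batch dimension is often unknown. A small illustration of that distinction (my own sketch, not part of this commit):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2])])
def dynamic_batch_size(t):
    # Static shape: t.shape[0] is None at trace time (unknown batch dim).
    # Dynamic shape: tf.shape(t)[0] is resolved when the step runs.
    return tf.shape(t)[0]

print(dynamic_batch_size(tf.ones((3, 2))))  # tf.Tensor(3, shape=(), dtype=int32)
```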

keras/src/backend/torch/trainer.py

Lines changed: 8 additions & 2 deletions
@@ -54,7 +54,10 @@ def train_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
@@ -90,7 +93,10 @@ def test_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
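The same `next(i for i in tree.flatten(x) if i is not None)` expression now appears in the jax, tensorflow, and torch trainers. One caveat: `next(...)` with no default raises `StopIteration` when every flattened input is `None`, so the trainers implicitly assume at least one real tensor is present. A hypothetical shared helper (not part of this commit) that would make both the pattern and the assumption explicit:

```python
def first_non_none(flat_inputs):
    """Return the first non-None entry of a flattened input structure.

    Hypothetical sketch: fails with a clearer error than the bare
    StopIteration that next() raises when every input is None.
    """
    for entry in flat_inputs:
        if entry is not None:
            return entry
    raise ValueError("Expected at least one non-None input.")
```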

keras/src/models/model_test.py

Lines changed: 16 additions & 16 deletions
@@ -163,14 +163,14 @@ def __init__(self):
             super().__init__()
             self.dense = layers.Dense(2)
 
-        def call(self, a, b=None):
-            x = a if b is None else a + b
-            return self.dense(x)
-
-    x1 = Input((2,), name="x1")
-    x2 = Input((2,), name="x2", optional=True)
-    y = OptionalInputLayer()(x1, x2)
-    model = Model({"x1": x1, "x2": x2}, y)
+        def call(self, x, o=None):
+            z = x if o is None else x + o
+            return self.dense(z)
+
+    x = Input((2,), name="x")
+    o = Input((2,), name="o", optional=True)
+    y = OptionalInputLayer()(x, o)
+    model = Model({"x": x, "o": o}, y)
     return model
@@ -1241,27 +1241,27 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
     )
     def test_functional_optional_inputs(self, is_optional_none):
         model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))
 
         model.compile(loss="mse", optimizer="adam")
-        model.fit(x={"x1": x1, "x2": x2}, y=y_true)
-        model.evaluate(x={"x1": x1, "x2": x2}, y=y_true)
-        model.predict(x={"x1": x1, "x2": x2})
+        model.fit(x={"x": x, "o": o}, y=y_true)
+        model.evaluate(x={"x": x, "o": o}, y=y_true)
+        model.predict(x={"x": x, "o": o})
 
     @parameterized.named_parameters(
         ("optional_none", True), ("optional_tensor", False)
     )
     def test_functional_optional_inputs_generator(self, is_optional_none):
         model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))
 
         def data_generator(with_y=True):
             for _ in range(4):
-                yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())
+                yield ({"x": x, "o": o},) + ((y_true,) if with_y else ())
 
         model.compile(loss="mse", optimizer="adam")
         model.fit(data_generator())
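The test renames look cosmetic but are presumably what makes the regression reproducible: dict inputs are flattened in sorted key order, so with the old names the non-optional `x1` always came first and `tree.flatten(x)[0]` never hit the `None`. With the new names, the optional input `o` sorts before `x`. A quick illustration of that ordering assumption using plain `sorted`:

```python
# Old names: the non-optional input flattens first, hiding the bug.
print(sorted(["x1", "x2"]))  # ['x1', 'x2']

# New names: the optional input "o" flattens first, exposing the bug.
print(sorted(["x", "o"]))  # ['o', 'x']
```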
