diff --git a/keras/src/layers/preprocessing/discretization.py b/keras/src/layers/preprocessing/discretization.py index 50e79ba2f49..2262bd235b8 100644 --- a/keras/src/layers/preprocessing/discretization.py +++ b/keras/src/layers/preprocessing/discretization.py @@ -95,9 +95,6 @@ def __init__( dtype=None, name=None, ): - if dtype is None: - dtype = "int64" if output_mode == "int" else backend.floatx() - super().__init__(name=name, dtype=dtype) if sparse and not backend.SUPPORTS_SPARSE_TENSORS: @@ -155,6 +152,10 @@ def __init__( def input_dtype(self): return backend.floatx() + @property + def output_dtype(self): + return self.compute_dtype if self.output_mode != "int" else "int32" + def adapt(self, data, steps=None): """Computes bin boundaries from quantiles in a input dataset. @@ -213,7 +214,7 @@ def reset_state(self): self.summary = np.array([[], []], dtype="float32") def compute_output_spec(self, inputs): - return backend.KerasTensor(shape=inputs.shape, dtype=self.compute_dtype) + return backend.KerasTensor(shape=inputs.shape, dtype=self.output_dtype) def load_own_variables(self, store): if len(store) == 1: @@ -234,7 +235,7 @@ def call(self, inputs): indices, output_mode=self.output_mode, depth=len(self.bin_boundaries) + 1, - dtype=self.compute_dtype, + dtype=self.output_dtype, sparse=self.sparse, backend_module=self.backend, ) diff --git a/keras/src/layers/preprocessing/discretization_test.py b/keras/src/layers/preprocessing/discretization_test.py index 500c6e9ca03..b9cda1d34a8 100644 --- a/keras/src/layers/preprocessing/discretization_test.py +++ b/keras/src/layers/preprocessing/discretization_test.py @@ -205,3 +205,29 @@ def test_call_before_adapt_raises(self): layer = layers.Discretization(num_bins=3) with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"): layer([[0.1, 0.8, 0.9]]) + + def test_model_call_vs_predict_consistency(self): + """Test that model(input) and model.predict(input) produce consistent outputs.""" # noqa: E501 + # Test with int output mode + layer = layers.Discretization( + bin_boundaries=[-0.5, 0, 0.1, 0.2, 3], + output_mode="int", + ) + x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]]) + + # Create model + inputs = layers.Input(shape=(4,), dtype="float32") + outputs = layer(inputs) + model = models.Model(inputs=inputs, outputs=outputs) + + # Test both execution modes + model_call_output = model(x) + predict_output = model.predict(x) + + # Check consistency + self.assertAllClose(model_call_output, predict_output) + self.assertEqual( + backend.standardize_dtype(model_call_output.dtype), + backend.standardize_dtype(predict_output.dtype), + ) + self.assertTrue(backend.is_int_dtype(model_call_output.dtype))