Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions keras/src/layers/preprocessing/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,6 @@ def __init__(
dtype=None,
name=None,
):
if dtype is None:
dtype = "int64" if output_mode == "int" else backend.floatx()

super().__init__(name=name, dtype=dtype)

if sparse and not backend.SUPPORTS_SPARSE_TENSORS:
Expand Down
21 changes: 21 additions & 0 deletions keras/src/layers/preprocessing/discretization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,24 @@ def test_call_before_adapt_raises(self):
layer = layers.Discretization(num_bins=3)
with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
layer([[0.1, 0.8, 0.9]])

def test_model_call_vs_predict_consistency(self):
"""Test that model(input) and model.predict(input) produce consistent outputs.""" # noqa: E501
# Test with int output mode
layer = layers.Discretization(
bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],
output_mode="int",
)
x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])

# Create model
inputs = layers.Input(shape=(4,), dtype="float32")
outputs = layer(inputs)
model = models.Model(inputs=inputs, outputs=outputs)

# Test both execution modes
model_call_output = model(x)
predict_output = model.predict(x)

# Check consistency
self.assertAllClose(model_call_output, predict_output)
Loading