Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,23 +0,0 @@
.DS_Store
*.pyc
.vscode-test
__pycache__
**/.vscode-test/**
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
.pytest_cache
tmp/**
.vs/
dist/**
**/*.egg-info/*
.vscode
examples/**/*.jpg
.python-version
.coverage
*coverage.xml
.ruff_cache
87 changes: 87 additions & 0 deletions CUSTOM_GRADIENT_JAX_FIX.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Fix for custom_gradient with JAX backend and Variables

## Issue
GitHub Issue [#21105](https://github.com/keras-team/keras/issues/21105)

When using `@ops.custom_gradient` with the JAX backend, passing Keras Variables as arguments would cause a `TypeError: 'NoneType' object is not callable` during training. This occurred because JAX's `custom_gradient` would capture the Variable object itself instead of extracting its underlying tensor value.

## Root Cause
The JAX backend's `custom_gradient` function was directly wrapping `jax.custom_gradient` without converting Variable objects to their values, unlike the `stop_gradient` function which already handled this correctly.

## Solution
Modified `keras/src/backend/jax/core.py` to add a wrapper that automatically extracts `.value` from Variable objects before passing them to the user's custom gradient function. This is done using `tree.map_structure` to recursively handle nested structures.

### Changes Made

**File: `keras/src/backend/jax/core.py`**

```python
def custom_gradient(fun):
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
def _convert_arg(arg):
if isinstance(arg, Variable):
return arg.value
return arg

args = tree.map_structure(_convert_arg, args)
kwargs = tree.map_structure(_convert_arg, kwargs)
return fun(*args, **kwargs)

return jax.custom_gradient(fun=wrapper)
```

**File: `keras/src/ops/core_test.py`**

Added `test_custom_gradient_with_variable()` to verify that Variables can be passed directly to custom_gradient functions without needing to manually add `.value`.

## Testing

### Run the specific test:
```bash
pytest keras/src/ops/core_test.py::CoreOpsCorrectnessTest::test_custom_gradient_with_variable -v
```

### Run all core ops tests:
```bash
pytest keras/src/ops/core_test.py -v
```

## Example Usage

Before the fix, you needed to manually extract `.value`:

```python
@ops.custom_gradient
def roundpass(x, log_scaling):
scaling = ops.exp(log_scaling)
rounded = ops.round(x * scaling) / scaling

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
return upstream, ops.zeros_like(log_scaling)

return rounded, grad

class QuantizedLayer(layers.Layer):
def call(self, x):
# Workaround: manually add .value
return roundpass(x, self.log_scaling.value)
```

After the fix, Variables work directly:

```python
class QuantizedLayer(layers.Layer):
def call(self, x):
# Works automatically now!
return roundpass(x, self.log_scaling)
```

## Impact
- ✅ Fixes the TypeError when Variables are passed to custom_gradient functions
- ✅ Makes JAX backend behavior consistent with user expectations
- ✅ Aligns with how `stop_gradient` already handles Variables
- ✅ Backward compatible - existing code using `.value` workaround still works
- ✅ No performance impact - conversion happens once at function decoration time
56 changes: 56 additions & 0 deletions TORCH_JIT_COMPILE_LIMITATIONS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Torch Backend jit_compile Limitations

## Issue #21647: jit_compile=True with EfficientNetV2 on torch backend

### Problem
When using `jit_compile=True` with certain Keras models (especially EfficientNetV2) on the torch backend, you may encounter `InternalTorchDynamoError` or `RuntimeError` related to torch.compile being unable to trace optree operations.

### Root Cause
Keras uses tree operations (from optree or torch._pytree) for handling nested structures. When `jit_compile=True` is enabled, PyTorch's torch.compile attempts to trace through all Python operations, including these tree utilities. However, torch.compile has limitations with certain C/C++ extensions and symbolic operations.

### Error Messages
- **GPU**: `InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'`
- **CPU**: `RuntimeError: TypeError: cannot determine truth value of Relational`

### Workarounds

#### Option 1: Disable JIT Compilation (Recommended)
```python
model.compile(
optimizer=Adam(learning_rate=0.001),
loss=CategoricalCrossentropy(),
metrics=['accuracy'],
jit_compile=False # or omit this parameter
)
```

#### Option 2: Use a Different Backend
Switch to TensorFlow or JAX backend which have better jit_compile support:
```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax"
```

#### Option 3: Use Fixed Input Shapes
If you must use jit_compile with torch, ensure all input shapes are fixed (no None dimensions):
```python
base_model = EfficientNetV2B2(
include_top=False,
input_shape=(224, 224, 3), # Fixed shape, no None
pooling='avg',
weights=None
)
```

### Status
This is a known limitation of torch.compile when working with complex nested structures. The PyTorch team is aware of limitations with certain patterns and continues to improve torch.compile support.

### Related Issues
- PyTorch Issue: torch.compile limitations with pytree operations
- Keras Issue #21647

### Future Improvements
Potential solutions being explored:
1. Add torch.compile skip decorators for tree operations
2. Use torch.compiler.disable() context for specific operations
3. Refactor to use pure torch operations where possible
107 changes: 107 additions & 0 deletions keras/src/applications/efficientnet_v2_jit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Test for Issue #21647: jit_compile=True with EfficientNetV2 on torch
backend."""

import numpy as np
import pytest

from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.applications.efficientnet_v2 import EfficientNetV2B2
from keras.src.losses import CategoricalCrossentropy
from keras.src.optimizers import Adam


@pytest.mark.skipif(
backend.backend() != "torch",
reason="This test is specifically for torch backend",
)
class EfficientNetV2JitCompileTest(testing.TestCase):
"""Test EfficientNetV2 models with jit_compile=True on torch backend."""

def test_efficientnet_v2_b2_with_jit_compile(self):
"""Test that EfficientNetV2B2 works with jit_compile=True."""
num_classes = 10
batch_size = 2 # Small batch for testing
steps_per_epoch = 1
epochs = 1

# Generate random data (use minimum supported size)
data_shape = (224, 224, 3) # Minimum size for EfficientNetV2
x_train = np.random.rand(
batch_size * steps_per_epoch, *data_shape
).astype(np.float32)
y_train = np.random.randint(
0, num_classes, size=(batch_size * steps_per_epoch,)
)
y_train = np.eye(num_classes)[y_train]

# Create model
base_model = EfficientNetV2B2(
include_top=False,
input_shape=(224, 224, 3), # Fixed shape for jit_compile
pooling="avg",
include_preprocessing=True,
weights=None, # Don't load weights for faster testing
)
x = base_model.output
output = layers.Dense(num_classes, activation="softmax")(x)
model = models.Model(inputs=base_model.input, outputs=output)

# Compile with jit_compile=True
model.compile(
optimizer=Adam(learning_rate=0.001),
loss=CategoricalCrossentropy(),
metrics=["accuracy"],
jit_compile=True,
)

# This should not raise InternalTorchDynamoError
history = model.fit(
x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=0
)

# Basic sanity check
self.assertIsNotNone(history)
self.assertIn("loss", history.history)

def test_efficientnet_v2_b0_with_jit_compile(self):
"""Test that EfficientNetV2B0 also works with jit_compile=True."""
from keras.src.applications.efficientnet_v2 import EfficientNetV2B0

num_classes = 5
batch_size = 2

# Generate random data
x_train = np.random.rand(batch_size, 224, 224, 3).astype(np.float32)
_ = np.eye(num_classes)[
np.random.randint(0, num_classes, size=(batch_size,))
]

# Create model
base_model = EfficientNetV2B0(
include_top=False,
input_shape=(224, 224, 3),
pooling="avg",
weights=None,
)
x = base_model.output
output = layers.Dense(num_classes, activation="softmax")(x)
model = models.Model(inputs=base_model.input, outputs=output)

# Compile with jit_compile=True
model.compile(
optimizer=Adam(learning_rate=0.001),
loss=CategoricalCrossentropy(),
metrics=["accuracy"],
jit_compile=True,
)

# Should work without errors
predictions = model.predict(x_train, verbose=0)
self.assertEqual(predictions.shape, (batch_size, num_classes))


if __name__ == "__main__":
pytest.main([__file__])
14 changes: 12 additions & 2 deletions keras/src/applications/imagenet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,18 @@ def obtain_input_shape(
"""
if weights != "imagenet" and input_shape and len(input_shape) == 3:
if data_format == "channels_first":
correct_channel_axis = 1 if len(input_shape) == 4 else 0
if input_shape[correct_channel_axis] not in {1, 3}:
# Check if user accidentally provided channels_last format
# when channels_first was expected
if input_shape[-1] in {1, 3} and input_shape[0] not in {1, 3}:
raise ValueError(
f"The `input_shape` argument has shape {input_shape}, "
"which appears to be in 'channels_last' format "
f"(with {input_shape[-1]} channels), but the model "
"is configured to use 'channels_first' data format. "
f"For 'channels_first', provide input_shape as "
f"({input_shape[-1]}, {input_shape[0]}, {input_shape[1]})."
)
if input_shape[0] not in {1, 3}:
warnings.warn(
"This model usually expects 1 or 3 input channels. "
"However, it was passed an input_shape "
Expand Down
27 changes: 26 additions & 1 deletion keras/src/backend/jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def _convert_to_tensor(self, value, dtype=None):

# Overload native accessor.
def __jax_array__(self):
# Handle case where Variable is copied during JAX tracing
# and both _value and _initializer become None
if self._value is None and self._initializer is None:
# This can happen when JAX copies Variables during tracing.
# In this case, we need to use the actual shape to create a
# placeholder tensor for shape inference.
import jax.numpy as jnp

from keras.src.backend.common import standardize_dtype

return jnp.zeros(self._shape, dtype=standardize_dtype(self._dtype))
return self.value


Expand Down Expand Up @@ -513,8 +524,22 @@ def random_seed_dtype():
return "uint32"


def _convert_variable_to_value(arg):
"""Convert Variable objects to their underlying values."""
if isinstance(arg, Variable):
return arg.value
return arg


def custom_gradient(fun):
return jax.custom_gradient(fun=fun)
@jax.custom_gradient
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
args = tree.map_structure(_convert_variable_to_value, args)
kwargs = tree.map_structure(_convert_variable_to_value, kwargs)
return fun(*args, **kwargs)

return wrapper


def remat(f):
Expand Down
6 changes: 6 additions & 0 deletions keras/src/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from keras.src.losses.losses import sparse_categorical_crossentropy
from keras.src.losses.losses import squared_hinge
from keras.src.losses.losses import tversky
from keras.src.losses.lpips import LPIPS
from keras.src.losses.lpips import lpips
from keras.src.saving import serialization_lib

ALL_OBJECTS = {
Expand Down Expand Up @@ -76,6 +78,8 @@
Tversky,
# Similarity
Circle,
# Feature extraction perceptual
LPIPS,
# Sequence
CTC,
# Probabilistic
Expand All @@ -94,6 +98,8 @@
cosine_similarity,
log_cosh,
huber,
# Feature extraction perceptual
lpips,
# Hinge
hinge,
squared_hinge,
Expand Down
Loading
Loading