Skip to content

Commit

Permalink
Add support for PyTorch and Jax to KerasCV (#1925)
Browse files Browse the repository at this point in the history
* Port preprocessing layers and bounding box utils to Keras Core [try 2] (#1903)

* Port preprocessing layers and bounding box utils to Keras Core

* Fix pip install

* Another git clone fix

* Fix for nested scopes

* Use PAT for git clone

* PAT structure

* Update actions.yml

* Update actions.yml

* Update actions.yml

* Int64 for jax

* Update utils

* Remove commented-out code

* ANY

* No any typehints

* Deploy key for install

* Remove prints

* SSH key as github action variable

* New ssh key approach

* Review comments

* Fancy pytest markers

* operations -> ops plus some test marks

* Add torch hax

* Fix TF marker

* Format

* Review comments

* rename tf_only to tf_keras_only

* Try installing keras core with pip_build

* s/_/-

* namex

* Namex fix + torch

* Rich

* requirements.txt

* I promise I have used a computer before

* Back to private API for now

* Fix path for backend

* Back to namex install

* Use correct policy API for mixed precision

* Reverse aliases

* Rely on keras-core for validation of backend

* Copy the homework of matt

* Newline

* Port backbones to Keras Core (#1906)

* Port backbones to Keras Core

* Fix YOLOV8 presets

* Fix tests

* Fix backbones

* utils -> saving

* YOLOV8 presets -- again

* saving for csp

* Test fixes

* Squeeze_excite

Co-authored-by: Tirth Patel <[email protected]>

* Port losses to Keras Core (#1905)

* Port losses to Keras Core

* Saving

* Fix focal

* Update CenterNetBoxLoss

Co-authored-by: Tirth Patel <[email protected]>

* Port ImageClassifier to Keras Core (#1908)

* Port ImageClassifier to Keras Core

* Backbone property

* Port object detection layers to Keras Core (#1907)

* Port OD layers to Keras Core

* Add multi-backend NMS

* Use pytest fixture for skippage

* Nice asserts for tf.keras-only components

* Review comments

Co-authored-by: Tirth Patel <[email protected]>

* Backbone property fix (#1909)

* Port RetinaNet to Keras Core (#1912)

* Port RetinaNet to Keras Core

* Add defensive tuple casting

* Remove Torch one_hot workaround

* Cast to int

* Port PyCOCOCallback to Keras Core (#1913)

* Fix default training value for SqueezeExcite (#1917)

* Update to_numpy to use Keras Core ops (#1916)

* Implement native NMS for PyTorch backend (#1918)

* Use torch-native NMS where possible

* Implement native NMS for Torch

* Remove prints

* valid_det is an int, foo

* fix num dets

* Comment

* Bump version to 0.6.0 for Keras Core release (#1919)

* Add keras-core as a dependency (#1922)

* Add keras-core as a dependency

* Install keras core from PyPi on CI

* Fix tf.keras CI

* Port YOLOV8 to Keras Core' (#1920)

* Port YOLOV8 to Keras Core'

* Update test utils

* Mark backbone tests XL'

---------

Co-authored-by: Tirth Patel <[email protected]>

* Fix saving namespace for TF < 2.12

* Fix penalty reduced focal loss

* Run GPU tests on tf 2.13, require 2.12

* Run cpu tests on tf 2.13, require 2.13

* Fix ABI and FasterRCNN tests

---------

Co-authored-by: Ian Stenbit <[email protected]>
  • Loading branch information
tirthasheshpatel and ianstenbit authored Jul 11, 2023
1 parent ce06df9 commit 993ff61
Show file tree
Hide file tree
Showing 164 changed files with 3,037 additions and 1,919 deletions.
73 changes: 60 additions & 13 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ on:
types: [created]
jobs:
test:
name: Test the code
name: Test the code with tf.keras
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Get pip cache dir
id: pip-cache
run: |
Expand All @@ -29,27 +29,75 @@ jobs:
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install tensorflow-cpu==2.11.0
pip install tensorflow==2.13.0
pip install torch>=2.0.1+cpu
pip install "jax[cpu]"
pip install keras-core
pip install -e ".[tests]" --progress-bar off --upgrade
- name: Build custom ops for tests
run: |
python build_deps/configure.py
bazel build keras_cv/custom_ops:all
cp bazel-bin/keras_cv/custom_ops/*.so keras_cv/custom_ops/
- name: Test with pytest
env:
TEST_CUSTOM_OPS: true
TEST_CUSTOM_OPS: false
run: |
pytest keras_cv/ --ignore keras_cv/models/legacy/ --durations 0
multibackend:
name: Test the code with Keras Core
strategy:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v2
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
pip install tensorflow==2.13.0
pip install "jax[cpu]"
pip install torch>=2.0.1+cpu
pip install torchvision>=0.15.1
pip install keras-core
pip install -e ".[tests]" --progress-bar off --upgrade
- name: Test with pytest
env:
TEST_CUSTOM_OPS: false # TODO(ianstenbit): test custom ops, or figure out what our story is here
KERAS_CV_MULTI_BACKEND: true
KERAS_BACKEND: ${{ matrix.backend }}
JAX_ENABLE_X64: true
run: |
pytest --run_large keras_cv/bounding_box \
keras_cv/callbacks \
keras_cv/losses \
keras_cv/layers/object_detection \
keras_cv/layers/preprocessing \
keras_cv/models/backbones \
keras_cv/models/classification \
keras_cv/models/object_detection/retinanet \
keras_cv/models/object_detection/yolo_v8 \
--durations 0
format:
name: Check the code format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.7
python-version: 3.8
- name: Get pip cache dir
id: pip-cache
run: |
Expand All @@ -75,4 +123,3 @@ jobs:
extensions: 'h,c,cpp,hpp,cc'
clangFormatVersion: 14
style: google

2 changes: 1 addition & 1 deletion cloudbuild/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ To add a dependency for GPU tests:
- Have a Keras team member update the Docker image for GPU tests by running the remaining steps
- Create a `Dockerfile` with the following contents:
```
FROM tensorflow/tensorflow:2.11.0-gpu
FROM tensorflow/tensorflow:2.13.0-gpu
RUN \
apt-get -y update && \
apt-get -y install openjdk-8-jdk && \
Expand Down
11 changes: 10 additions & 1 deletion keras_cv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

try:
# When using torch and tensorflow, torch needs to be imported first,
# otherwise it will segfault upon import.
import torch

del torch
except ImportError:
pass

# isort:off
from keras_cv import version_check

Expand All @@ -33,4 +42,4 @@
from keras_cv.core import NormalFactorSampler
from keras_cv.core import UniformFactorSampler

__version__ = "0.5.1"
__version__ = "0.6.0"
91 changes: 91 additions & 0 deletions keras_cv/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Keras backend module.
This module adds a temporarily Keras API surface that is fully under KerasCV
control. This allows us to switch between `keras_core` and `tf.keras`, as well
as add shims to support older version of `tf.keras`.
- `config`: check which backend is being run.
- `keras`: The full `keras` API (via `keras_core` or `tf.keras`).
- `ops`: `keras_core.ops`, always tf-backed if using `tf.keras`.
"""

import types

from keras_cv.backend.config import multi_backend

# Keys are of the form: "module.where.attr.exists->module.where.to.alias"
# Value are of the form: ["attr1", "attr2", ...] or
# [("attr1_original_name", "attr1_alias_name")]
_KERAS_CORE_ALIASES = {
"utils->saving": [
"register_keras_serializable",
"deserialize_keras_object",
"serialize_keras_object",
"get_registered_object",
],
"models->saving": ["load_model"],
}


if multi_backend():
import keras_core as keras
else:
from tensorflow import keras

if not hasattr(keras, "saving"):
keras.saving = types.SimpleNamespace()

# add aliases
for key, value in _KERAS_CORE_ALIASES.items():
src, _, dst = key.partition("->")
src = src.split(".")
dst = dst.split(".")

src_mod, dst_mod = keras, keras

# navigate to where we want to alias the attributes
for mod in src:
src_mod = getattr(src_mod, mod)
for mod in dst:
dst_mod = getattr(dst_mod, mod)

# add an alias for each attribute
for attr in value:
if isinstance(attr, tuple):
src_attr, dst_attr = attr
else:
src_attr, dst_attr = attr, attr
attr_val = getattr(src_mod, src_attr)
setattr(dst_mod, dst_attr, attr_val)

# TF Keras doesn't have this rename.
keras.activations.silu = keras.activations.swish

from keras_cv.backend import config # noqa: E402
from keras_cv.backend import ops # noqa: E402
from keras_cv.backend import tf_ops # noqa: E402


def assert_tf_keras(src):
if multi_backend():
raise NotImplementedError(
f"KerasCV component {src} does not yet support Keras Core, and can "
"only be used in `tf.keras`."
)


def supports_ragged():
return not multi_backend()
69 changes: 69 additions & 0 deletions keras_cv/backend/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os

_MULTI_BACKEND = False

# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if "KERAS_HOME" in os.environ:
_keras_dir = os.environ.get("KERAS_HOME")
else:
_keras_base_dir = os.path.expanduser("~")
if not os.access(_keras_base_dir, os.W_OK):
_keras_base_dir = "/tmp"
_keras_dir = os.path.join(_keras_base_dir, ".keras")

# Attempt to read KerasCV config file.
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras-cv.json"))
if os.path.exists(_config_path):
try:
with open(_config_path) as f:
_config = json.load(f)
except ValueError:
_config = {}
_MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND)

# Save config file, if possible.
if not os.path.exists(_keras_dir):
try:
os.makedirs(_keras_dir)
except OSError:
# Except permission denied and potential race conditions
# in multi-threaded environments.
pass

if not os.path.exists(_config_path):
_config = {
"multi_backend": _MULTI_BACKEND,
}
try:
with open(_config_path, "w") as f:
f.write(json.dumps(_config, indent=4))
except IOError:
# Except permission denied.
pass

# Set Keras backend based on KERAS_CV_MULTI_BACKEND flag, if applicable.
if "KERAS_CV_MULTI_BACKEND" in os.environ:
if os.environ["KERAS_CV_MULTI_BACKEND"]:
_MULTI_BACKEND = True

if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]:
_MULTI_BACKEND = True


def multi_backend():
return _MULTI_BACKEND
23 changes: 23 additions & 0 deletions keras_cv/backend/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_cv.backend.config import multi_backend

if multi_backend():
from keras_core.src.backend import vectorized_map # noqa: F403, F401
from keras_core.src.ops import * # noqa: F403, F401
from keras_core.src.utils.image_utils import ( # noqa: F403, F401
smart_resize,
)
else:
from keras_cv.backend.tf_ops import * # noqa: F403, F401
59 changes: 59 additions & 0 deletions keras_cv/backend/scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools

from keras_cv import backend
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.backend import tf_ops
from keras_cv.backend.config import multi_backend

_ORIGINAL_OPS = copy.copy(backend.ops.__dict__)
_ORIGINAL_SUPPORTS_RAGGED = backend.supports_ragged

# A counter for potentially nested TF data scopes
_IN_TF_DATA_SCOPE = 0


def tf_data(function):
@functools.wraps(function)
def wrapper(*args, **kwargs):
if multi_backend() and keras.src.utils.backend_utils.in_tf_graph():
with TFDataScope():
return function(*args, **kwargs)
else:
return function(*args, **kwargs)

return wrapper


class TFDataScope:
def __enter__(self):
global _IN_TF_DATA_SCOPE
if _IN_TF_DATA_SCOPE == 0:
for k, v in ops.__dict__.items():
if k in tf_ops.__dict__:
setattr(ops, k, getattr(tf_ops, k))
backend.supports_ragged = lambda: True
_IN_TF_DATA_SCOPE += 1

def __exit__(self, exc_type, exc_value, exc_tb):
global _IN_TF_DATA_SCOPE
_IN_TF_DATA_SCOPE -= 1
if _IN_TF_DATA_SCOPE == 0:
for k, v in ops.__dict__.items():
setattr(ops, k, _ORIGINAL_OPS[k])
backend.supports_ragged = _ORIGINAL_SUPPORTS_RAGGED
_IN_TF_DATA_SCOPE = False
Loading

0 comments on commit 993ff61

Please sign in to comment.