Use conditional keras_nlp imports
smitlg committed Mar 19, 2024
1 parent fb05c82 commit 8265a17
Showing 4 changed files with 101 additions and 87 deletions.
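The change replaces ad-hoc `keras_nlp is None` checks scattered through the CLIP classes with one shared helper in `keras_cv.utils.conditional_imports`. A minimal standalone sketch of the pattern, with a hypothetical `MyFeature` consumer standing in for the CLIP classes:

try:
    import keras_nlp
except ImportError:
    keras_nlp = None


def assert_keras_nlp_installed(symbol_name):
    # Raise one uniform error that names the symbol needing the package.
    if keras_nlp is None:
        raise ImportError(
            f"{symbol_name} requires the `keras_nlp` package. "
            "Please install the package using `pip install keras_nlp`."
        )


class MyFeature:  # hypothetical consumer
    def __init__(self):
        assert_keras_nlp_installed("MyFeature")  # fail fast with a clear message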
7 changes: 2 additions & 5 deletions keras_cv/models/feature_extractor/clip/clip_model.py
@@ -26,6 +26,7 @@
     CLIPTextEncoder,
 )
 from keras_cv.models.task import Task
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
 from keras_cv.utils.python_utils import classproperty
 
 try:
@@ -98,11 +99,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
+        assert_keras_nlp_installed("CLIP")
         self.embed_dim = embed_dim
         self.image_resolution = image_resolution
         self.vision_layers = vision_layers
9 changes: 7 additions & 2 deletions keras_cv/models/feature_extractor/clip/clip_processor.py
@@ -11,13 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from keras_nlp.layers import StartEndPacker
 
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
 from keras_cv.backend import ops
 from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
+
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
 
 
 @keras_cv_export("keras_cv.models.feature_extractor.CLIPProcessor")
 class CLIPProcessor:
@@ -45,6 +49,7 @@ class CLIPProcessor:
     """
 
     def __init__(self, input_resolution, vocabulary, merges, **kwargs):
+        assert_keras_nlp_installed("CLIPProcessor")
        self.input_resolution = input_resolution
         self.vocabulary = vocabulary
         self.merges = merges
@@ -54,7 +59,7 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs):
             merges=self.merges,
             unsplittable_tokens=["</w>"],
         )
-        self.packer = StartEndPacker(
+        self.packer = keras_nlp.layers.StartEndPacker(
             start_value=self.tokenizer.token_to_id("<|startoftext|>"),
             end_value=self.tokenizer.token_to_id("<|endoftext|>"),
             pad_value=None,
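With the top-level `from keras_nlp.layers import StartEndPacker` removed, importing `clip_processor` no longer fails when keras_nlp is absent; `keras_nlp.layers.StartEndPacker` is only resolved inside `__init__`, after `assert_keras_nlp_installed` has already raised a clear error. A standalone sketch of the idea, with a hypothetical `Packer` consumer (the `StartEndPacker` arguments shown are the documented keras_nlp ones):

try:
    import keras_nlp  # binding the module can fail; attribute access is deferred
except ImportError:
    keras_nlp = None


class Packer:  # hypothetical consumer
    def __init__(self):
        if keras_nlp is None:
            raise ImportError("Packer requires the `keras_nlp` package.")
        # Resolved at construction time, not at module import time.
        self.packer = keras_nlp.layers.StartEndPacker(
            sequence_length=8, start_value=0, end_value=1
        )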
158 changes: 78 additions & 80 deletions keras_cv/models/feature_extractor/clip/clip_tokenizer.py
@@ -16,10 +16,9 @@
 import tensorflow_text as tf_text
 
 try:
-    import keras_nlp
     from keras_nlp.tokenizers import BytePairTokenizer
 except ImportError:
-    keras_nlp = None
+    BytePairTokenizer = None
 
 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
@@ -104,83 +103,82 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-class CLIPTokenizer(BytePairTokenizer):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
-
-    def _bpe_merge_and_update_cache(self, tokens):
-        """Process unseen tokens and add to cache."""
-        words = self._transform_bytes(tokens)
-        tokenized_words = self._bpe_merge(words)
-
-        # For each word, join all its token by a whitespace,
-        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
-        tokenized_words = tf.strings.reduce_join(
-            tokenized_words,
-            axis=1,
-        )
-        self.cache.insert(tokens, tokenized_words)
-
-    def tokenize(self, inputs):
-        self._check_vocabulary()
-        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-            inputs = tf.convert_to_tensor(inputs)
-
-        if self.add_prefix_space:
-            inputs = tf.strings.join([" ", inputs])
-
-        scalar_input = inputs.shape.rank == 0
-        if scalar_input:
-            inputs = tf.expand_dims(inputs, 0)
-
-        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
-        token_row_splits = raw_tokens.row_splits
-        flat_tokens = raw_tokens.flat_values
-        # Check cache.
-        cache_lookup = self.cache.lookup(flat_tokens)
-        cache_mask = cache_lookup == ""
-
-        has_unseen_words = tf.math.reduce_any(
-            (cache_lookup == "") & (flat_tokens != "")
-        )
-
-        def process_unseen_tokens():
-            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
-            self._bpe_merge_and_update_cache(unseen_tokens)
-            return self.cache.lookup(flat_tokens)
-
-        # If `has_unseen_words == True`, it means not all tokens are in cache,
-        # we will process the unseen tokens. Otherwise return the cache lookup.
-        tokenized_words = tf.cond(
-            has_unseen_words,
-            process_unseen_tokens,
-            lambda: cache_lookup,
-        )
-        tokens = tf.strings.split(tokenized_words, sep=" ")
-        if self.compute_dtype != tf.string:
-            # Encode merged tokens.
-            tokens = self.token_to_id_map.lookup(tokens)
-
-        # Unflatten to match input.
-        tokens = tf.RaggedTensor.from_row_splits(
-            tokens.flat_values,
-            tf.gather(tokens.row_splits, token_row_splits),
-        )
-
-        # Convert to a dense output if `sequence_length` is set.
-        if self.sequence_length:
-            output_shape = tokens.shape.as_list()
-            output_shape[-1] = self.sequence_length
-            tokens = tokens.to_tensor(shape=output_shape)
-
-        # Convert to a dense output if input in scalar
-        if scalar_input:
-            tokens = tf.squeeze(tokens, 0)
-            tf.ensure_shape(tokens, shape=[self.sequence_length])
-
-        return tokens
+if BytePairTokenizer:
+    class CLIPTokenizer(BytePairTokenizer):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def _bpe_merge_and_update_cache(self, tokens):
+            """Process unseen tokens and add to cache."""
+            words = self._transform_bytes(tokens)
+            tokenized_words = self._bpe_merge(words)
+
+            # For each word, join all its token by a whitespace,
+            # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+            tokenized_words = tf.strings.reduce_join(
+                tokenized_words,
+                axis=1,
+            )
+            self.cache.insert(tokens, tokenized_words)
+
+        def tokenize(self, inputs):
+            self._check_vocabulary()
+            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+                inputs = tf.convert_to_tensor(inputs)
+
+            if self.add_prefix_space:
+                inputs = tf.strings.join([" ", inputs])
+
+            scalar_input = inputs.shape.rank == 0
+            if scalar_input:
+                inputs = tf.expand_dims(inputs, 0)
+
+            raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+            token_row_splits = raw_tokens.row_splits
+            flat_tokens = raw_tokens.flat_values
+            # Check cache.
+            cache_lookup = self.cache.lookup(flat_tokens)
+            cache_mask = cache_lookup == ""
+
+            has_unseen_words = tf.math.reduce_any(
+                (cache_lookup == "") & (flat_tokens != "")
+            )
+
+            def process_unseen_tokens():
+                unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+                self._bpe_merge_and_update_cache(unseen_tokens)
+                return self.cache.lookup(flat_tokens)
+
+            # If `has_unseen_words == True`, it means not all tokens are in cache,
+            # we will process the unseen tokens. Otherwise return the cache lookup.
+            tokenized_words = tf.cond(
+                has_unseen_words,
+                process_unseen_tokens,
+                lambda: cache_lookup,
+            )
+            tokens = tf.strings.split(tokenized_words, sep=" ")
+            if self.compute_dtype != tf.string:
+                # Encode merged tokens.
+                tokens = self.token_to_id_map.lookup(tokens)
+
+            # Unflatten to match input.
+            tokens = tf.RaggedTensor.from_row_splits(
+                tokens.flat_values,
+                tf.gather(tokens.row_splits, token_row_splits),
+            )
+
+            # Convert to a dense output if `sequence_length` is set.
+            if self.sequence_length:
+                output_shape = tokens.shape.as_list()
+                output_shape[-1] = self.sequence_length
+                tokens = tokens.to_tensor(shape=output_shape)
+
+            # Convert to a dense output if input in scalar
+            if scalar_input:
+                tokens = tf.squeeze(tokens, 0)
+                tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+            return tokens
+
+else:
+    CLIPTokenizer = None
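The tokenizer needs more than the shared helper: `CLIPTokenizer` subclasses `BytePairTokenizer`, and base classes are evaluated when the `class` statement executes at import time, so the old in-`__init__` check could never be reached; without keras_nlp the import itself crashed first. Guarding the whole class definition keeps the module importable and leaves a `None` sentinel for callers to check. A reduced sketch of the idiom, with hypothetical names `some_optional_dep`, `Base`, and `Child`:

try:
    from some_optional_dep import Base  # hypothetical optional dependency
except ImportError:
    Base = None

if Base:
    class Child(Base):  # only defined when the base class exists
        pass

else:
    Child = None  # module import still succeeds; consumers test for None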
14 changes: 14 additions & 0 deletions keras_cv/utils/conditional_imports.py
@@ -33,6 +33,11 @@
 except ImportError:
     pycocotools = None
 
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
+
 
 def assert_cv2_installed(symbol_name):
     if cv2 is None:
@@ -70,3 +75,12 @@ def assert_pycocotools_installed(symbol_name):
             "Please install the package using "
             "`pip install pycocotools`."
         )
+
+
+def assert_keras_nlp_installed(symbol_name):
+    if keras_nlp is None:
+        raise ImportError(
+            f"{symbol_name} requires the `keras_nlp` package. "
+            "Please install the package using "
+            "`pip install keras_nlp`."
+        )
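With the helper centralized, every CLIP entry point fails the same way when keras_nlp is missing. A quick check of the intended behavior in an environment without keras_nlp (assumed, abbreviated output):

from keras_cv.utils.conditional_imports import assert_keras_nlp_installed

try:
    assert_keras_nlp_installed("CLIPProcessor")
except ImportError as err:
    print(err)
    # CLIPProcessor requires the `keras_nlp` package. Please install the
    # package using `pip install keras_nlp`.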
