Skip to content

Commit fbffd33

Browse files
authored
V0.15.1.dev1 (#1844)
* Preprocessing decorator fixes (#1843) * Fix handling of bytestring input to tokenizers and preprocessing * Fix no-convert scope in multithreaded contexts * Version bump for dev release
1 parent 8390c65 commit fbffd33

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

keras_nlp/src/utils/tensor_utils.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,19 @@
3030

3131

3232
NO_CONVERT_COUNTER = threading.local()
33-
NO_CONVERT_COUNTER.count = 0
3433

3534

3635
@contextlib.contextmanager
3736
def no_convert_scope():
3837
try:
39-
NO_CONVERT_COUNTER.count += 1
38+
NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) + 1
4039
yield
4140
finally:
42-
NO_CONVERT_COUNTER.count -= 1
41+
NO_CONVERT_COUNTER.count = getattr(NO_CONVERT_COUNTER, "count", 0) - 1
4342

4443

4544
def in_no_convert_scope():
46-
return NO_CONVERT_COUNTER.count > 0
45+
return getattr(NO_CONVERT_COUNTER, "count", 0) > 0
4746

4847

4948
def preprocessing_function(fn):
@@ -119,7 +118,7 @@ def convert_preprocessing_inputs(x):
119118
return {k: convert_preprocessing_inputs(x[k]) for k, v in x.items()}
120119
if isinstance(x, tuple):
121120
return tuple(convert_preprocessing_inputs(v) for v in x)
122-
if isinstance(x, str):
121+
if isinstance(x, (str, bytes)):
123122
return tf.constant(x)
124123
if isinstance(x, list):
125124
try:
@@ -132,7 +131,7 @@ def convert_preprocessing_inputs(x):
132131
# If ragged conversion failed return to the numpy error.
133132
raise e
134133
# If we have a string input, use tf.tensor.
135-
if numpy_x.dtype.type is np.str_:
134+
if numpy_x.dtype.type is np.str_ or numpy_x.dtype.type is np.bytes_:
136135
return tf.convert_to_tensor(x)
137136
# Numpy will default to int64, int32 works with more ops.
138137
if numpy_x.dtype == np.int64:

keras_nlp/src/utils/tensor_utils_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ def test_strings(self):
4949
self.assertIsInstance(outputs, list)
5050
self.assertEqual(outputs, inputs)
5151

52+
def test_bytestrings(self):
53+
inputs = ["one".encode("utf-8"), "two".encode("utf-8")]
54+
# Convert to tf.
55+
outputs = convert_preprocessing_inputs(inputs)
56+
self.assertIsInstance(outputs, tf.Tensor)
57+
self.assertAllEqual(outputs, tf.constant(inputs))
58+
# Convert from tf.
59+
outputs = convert_preprocessing_outputs(outputs)
60+
self.assertIsInstance(outputs, list)
61+
self.assertEqual(outputs, [x.decode("utf-8") for x in inputs])
62+
5263
def test_ragged(self):
5364
inputs = [np.ones((1, 3)), np.ones((1, 2))]
5465
# Convert to tf.

keras_nlp/src/version_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from keras_nlp.src.api_export import keras_nlp_export
1616

1717
# Unique source of truth for the version number.
18-
__version__ = "0.15.1.dev0"
18+
__version__ = "0.15.1.dev1"
1919

2020

2121
@keras_nlp_export("keras_nlp.version")

0 commit comments

Comments
 (0)