Skip to content

Commit 61ac8c1

Browse files
authored
StringLookup & IntegerLookup now save vocabulary loaded from file (#21751)
in the `.keras` archive when they are initialized with a path to a vocabulary file. This makes the `.keras` archive fully self contained. This was already the behavior when using either `set_vocabulary` or `adapt`. Simply, this behavior was extended to the case when `__init__` is called with a vocabulary file. Note that this is technically a breaking change. Previously, upon doing `keras.saving.load_model`, it would be looking up the vocabulary file at the exact same path as when originally constructed. Also disallow loading an arbitrary vocabulary file during model loading with `safe_mode=True` since the vocabulary file should now come from the archive.
1 parent a6345cd commit 61ac8c1

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

keras/src/layers/preprocessing/index_lookup.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from keras.src import backend
66
from keras.src.layers.layer import Layer
7+
from keras.src.saving import serialization_lib
78
from keras.src.utils import argument_validation
89
from keras.src.utils import numerical_utils
910
from keras.src.utils import tf_utils
@@ -178,7 +179,12 @@ def __init__(
178179
self.vocabulary_dtype = tf.as_dtype(vocabulary_dtype).name
179180
self._frozen_vocab_size = kwargs.pop("vocabulary_size", None)
180181

181-
self.input_vocabulary = vocabulary
182+
# Remember original `vocabulary` as `input_vocabulary` for serialization
183+
# via `get_config`. However, if `vocabulary` is a file path or a URL, we
184+
# serialize the vocabulary as an asset and clear the original path/URL.
185+
self.input_vocabulary = (
186+
vocabulary if not isinstance(vocabulary, str) else None
187+
)
182188
self.input_idf_weights = idf_weights
183189

184190
# We set this hidden attr to
@@ -382,6 +388,18 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
382388
)
383389

384390
if isinstance(vocabulary, str):
391+
if serialization_lib.in_safe_mode():
392+
raise ValueError(
393+
"Requested the loading of a vocabulary file outside of the "
394+
"model archive. This carries a potential risk of loading "
395+
"arbitrary and sensitive files and thus it is disallowed "
396+
"by default. If you trust the source of the artifact, you "
397+
"can override this error by passing `safe_mode=False` to "
398+
"the loading function, or calling "
399+
"`keras.config.enable_unsafe_deserialization(). "
400+
f"Vocabulary file: '{vocabulary}'"
401+
)
402+
385403
if not tf.io.gfile.exists(vocabulary):
386404
raise ValueError(
387405
f"Vocabulary file {vocabulary} does not exist."

keras/src/layers/preprocessing/string_lookup_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import os
2+
13
import numpy as np
24
import pytest
35
from tensorflow import data as tf_data
46

57
from keras.src import backend
68
from keras.src import layers
9+
from keras.src import models
10+
from keras.src import saving
711
from keras.src import testing
812
from keras.src.ops import convert_to_tensor
913

@@ -19,6 +23,40 @@ def test_config(self):
1923
mask_token="[MASK]",
2024
)
2125
self.run_class_serialization_test(layer)
26+
self.assertEqual(layer.get_config()["vocabulary"], ["a", "b", "c"])
27+
28+
def test_vocabulary_file(self):
29+
temp_dir = self.get_temp_dir()
30+
vocab_path = os.path.join(temp_dir, "vocab.txt")
31+
with open(vocab_path, "w") as file:
32+
file.write("a\nb\nc\n")
33+
34+
layer = layers.StringLookup(
35+
output_mode="int",
36+
vocabulary=vocab_path,
37+
oov_token="[OOV]",
38+
mask_token="[MASK]",
39+
name="index",
40+
)
41+
self.assertEqual(
42+
[str(v) for v in layer.get_vocabulary()],
43+
["[MASK]", "[OOV]", "a", "b", "c"],
44+
)
45+
self.assertIsNone(layer.get_config().get("vocabulary", None))
46+
47+
# Make sure vocabulary comes from the archive, not the original file.
48+
os.remove(vocab_path)
49+
50+
model = models.Sequential([layer])
51+
model_path = os.path.join(temp_dir, "test_model.keras")
52+
model.save(model_path)
53+
54+
reloaded_model = saving.load_model(model_path)
55+
reloaded_layer = reloaded_model.get_layer("index")
56+
self.assertEqual(
57+
[str(v) for v in reloaded_layer.get_vocabulary()],
58+
["[MASK]", "[OOV]", "a", "b", "c"],
59+
)
2260

2361
def test_adapt_flow(self):
2462
layer = layers.StringLookup(

0 commit comments

Comments
 (0)