Commit 3c59c1b
Refactor NL Server bootstrap code for readability (#3725)

This is to make follow-on custom DC-specific changes cleaner. There is no intentional logic change other than using a single cache entry (instead of one per index) in the test environment. A key simplification is to rely directly on the embeddings file name to extract the base/tuned model (see `config._parse()`), instead of the current code, which uses `models.yaml` and also relies partially on the naming.
1 parent fef25f3  commit 3c59c1b

9 files changed, +220 -151 lines
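For orientation, here is a rough sketch of what the two config files parse into, written as the Python dicts that `yaml.full_load()` returns in `create_app()` below. The index keys and the models.yaml shape are assumptions for illustration; the CSV file-name formats match the two patterns handled by `config._parse()` in nl_server/config.py.

# Illustrative only: embeddings.yaml maps index name -> embeddings CSV file name.
embeddings_map = {
    'sdg_ft': 'embeddings_sdg_2023_09_12_16_38_04.ft_final_v20230717230459.all-MiniLM-L6-v2.csv',
    'small': 'embeddings_small_2023_05_24_23_17_03.csv',
}

# Assumed shape for models.yaml: its *values* are the tuned-model folder names
# that config.load() cross-checks against the names embedded in the CSVs above.
models_map = {
    'tuned_model': 'ft_final_v20230717230459.all-MiniLM-L6-v2',
}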

nl_server/__init__.py (+24, -33)
@@ -12,62 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys

 from flask import Flask
 import torch
 import yaml

+from nl_server import config
 import nl_server.loader as loader
 import nl_server.routes as routes

+_MODEL_YAML = 'models.yaml'
+_EMBEDDINGS_YAML = 'embeddings.yaml'
+

 def create_app():
   app = Flask(__name__)
   app.register_blueprint(routes.bp)

-  flask_env = os.environ.get('FLASK_ENV')
-
   # https://github.com/UKPLab/sentence-transformers/issues/1318
   if sys.version_info >= (3, 8) and sys.platform == "darwin":
     torch.set_num_threads(1)

-  # Download existing finetuned models (if not already downloaded).
-  models_downloaded_paths = {}
-  models_config_path = '/datacommons/nl/models.yaml'
-  if flask_env in ['local', 'test', 'integration_test', 'webdriver']:
-    models_config_path = os.path.join(
-        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
-        'deploy/nl/models.yaml')
-  app.config['MODELS_CONFIG_PATH'] = models_config_path
-  with open(app.config['MODELS_CONFIG_PATH']) as f:
+  with open(get_env_path(_MODEL_YAML)) as f:
     models_map = yaml.full_load(f)
-  if not models_map:
-    logging.error("No configuration found for model")
-    return
+  assert models_map, 'No models.yaml found!'

-  models_downloaded_paths = loader.download_models(models_map)
+  with open(get_env_path(_EMBEDDINGS_YAML)) as f:
+    embeddings_map = yaml.full_load(f)
+  assert embeddings_map, 'No embeddings.yaml found!'
+  app.config[config.NL_EMBEDDINGS_VERSION_KEY] = embeddings_map

-  assert models_downloaded_paths, "No models were found/downloaded. Check deploy/nl/models.yaml"
+  loader.load_server_state(app, embeddings_map, models_map)

-  # Download existing embeddings (if not already downloaded).
-  embeddings_config_path = '/datacommons/nl/embeddings.yaml'
-  if flask_env in ['local', 'test', 'integration_test', 'webdriver']:
-    embeddings_config_path = os.path.join(
-        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
-        'deploy/nl/embeddings.yaml')
-  app.config['EMBEDDINGS_CONFIG_PATH'] = embeddings_config_path
+  return app

-  # Initialize the NL module.
-  with open(app.config['EMBEDDINGS_CONFIG_PATH']) as f:
-    embeddings_map = yaml.full_load(f)
-  if not embeddings_map:
-    logging.error("No configuration found for embeddings")
-    return

-  app.config['EMBEDDINGS_VERSION_MAP'] = embeddings_map
-  loader.load_embeddings(app, embeddings_map, models_downloaded_paths)
+
+#
+# On prod the yaml files are in /datacommons/nl/, whereas
+# in test-like environments it is the checked-in path
+# (deploy/nl/).
+#
+def get_env_path(file_name: str) -> str:
+  flask_env = os.environ.get('FLASK_ENV')
+  if flask_env in ['local', 'test', 'integration_test', 'webdriver']:
+    return os.path.join(
+        os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        f'deploy/nl/{file_name}')

-  return app
+  return f'/datacommons/nl/{file_name}'
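A quick illustration of the new `get_env_path()` helper; the 'production' value below is just a stand-in for any FLASK_ENV outside the test-like list:

import os

os.environ['FLASK_ENV'] = 'local'
get_env_path('models.yaml')
# -> '<repo_root>/deploy/nl/models.yaml'

os.environ['FLASK_ENV'] = 'production'  # any value outside the test-like list
get_env_path('models.yaml')
# -> '/datacommons/nl/models.yaml'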

nl_server/config.py (new file, +106)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List

from nl_server import embeddings
from nl_server import gcs

# Index constants. Passed in `url=`
CUSTOM_DC_INDEX = 'custom'
DEFAULT_INDEX_TYPE = 'medium_ft'

# The default base model we use.
EMBEDDINGS_BASE_MODEL_NAME = 'all-MiniLM-L6-v2'

# App Config constants.
NL_MODEL_KEY = 'NL_MODEL'
NL_EMBEDDINGS_KEY = 'NL_EMBEDDINGS'
NL_EMBEDDINGS_VERSION_KEY = 'NL_EMBEDDINGS_VERSION_MAP'


# Defines one embeddings index config.
@dataclass
class EmbeddingsIndex:
  # Name provided in the yaml file, and set in `idx=` URL param.
  name: str

  # File name provided in the yaml file.
  embeddings_file_name: str
  # Local path.
  embeddings_local_path: str = ""

  # Fine-tuned model name ("" if embeddings uses base model).
  tuned_model: str = ""
  # Fine-tuned model local path.
  tuned_model_local_path: str = ""


#
# Validates the config input, downloads all the files and returns a list of
# indexes to load.
#
def load(embeddings_map: Dict[str, str],
         models_map: Dict[str, str]) -> List[EmbeddingsIndex]:
  # Create Index objects.
  indexes = _parse(embeddings_map)

  # This is just a sanity check; we can soon deprecate models.yaml.
  tuned_models_provided = list(set(models_map.values()))
  tuned_models_configured = list(
      set([i.tuned_model for i in indexes if i.tuned_model]))
  assert sorted(tuned_models_configured) == sorted(tuned_models_provided), \
      f'{tuned_models_configured} vs. {tuned_models_provided}'

  #
  # Download all the models.
  #
  model2path = {
      d: gcs.download_model_folder(d) for d in tuned_models_configured
  }
  for idx in indexes:
    if idx.tuned_model:
      idx.tuned_model_local_path = model2path[idx.tuned_model]

  #
  # Download all the embeddings.
  #
  for idx in indexes:
    idx.embeddings_local_path = gcs.download_embeddings(
        idx.embeddings_file_name)

  return indexes


def _parse(embeddings_map: Dict[str, str]) -> List[EmbeddingsIndex]:
  indexes: List[EmbeddingsIndex] = []

  for key, value in embeddings_map.items():
    idx = EmbeddingsIndex(name=key, embeddings_file_name=value)

    parts = value.split('.')
    assert parts[-1] == 'csv', \
        f'Embeddings file {value} name does not end with .csv!'

    if len(parts) == 4:
      # Expect: <embeddings_version>.<fine-tuned-model-version>.<base-model>.csv
      # Example: embeddings_sdg_2023_09_12_16_38_04.ft_final_v20230717230459.all-MiniLM-L6-v2.csv
      assert parts[2] == EMBEDDINGS_BASE_MODEL_NAME, \
          f'Unexpected base model {parts[2]}'
      idx.tuned_model = f'{parts[1]}.{parts[2]}'
    else:
      # Expect: <embeddings_version>.csv
      # Example: embeddings_small_2023_05_24_23_17_03.csv
      assert len(parts) == 2, f'Unexpected file name format {value}'
    indexes.append(idx)

  return indexes
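As a concrete illustration of `_parse()`, using the two example file names from its own comments (the index keys are made up):

indexes = _parse({
    'sdg_ft': 'embeddings_sdg_2023_09_12_16_38_04.ft_final_v20230717230459.all-MiniLM-L6-v2.csv',
    'small': 'embeddings_small_2023_05_24_23_17_03.csv',
})
# indexes[0].tuned_model == 'ft_final_v20230717230459.all-MiniLM-L6-v2'
# indexes[1].tuned_model == ''   (this index uses the base model)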

nl_server/embeddings.py (+2, -4)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Managing the embeddings."""
-from dataclasses import dataclass
 import logging
 import os
 from typing import Dict, List, Union
@@ -22,13 +21,12 @@
 from sentence_transformers.util import semantic_search
 import torch

+from nl_server import config
 from nl_server import query_util
 from shared.lib import constants
 from shared.lib import detected_variables as vars
 from shared.lib import utils

-MODEL_NAME = 'all-MiniLM-L6-v2'
-
 # A value higher than the highest score.
 _HIGHEST_SCORE = 1.0
 _INIT_SCORE = (_HIGHEST_SCORE + 0.1)
@@ -52,7 +50,7 @@ def __init__(self,
       assert os.path.exists(existing_model_path)
       self.model = SentenceTransformer(existing_model_path)
     else:
-      self.model = SentenceTransformer(MODEL_NAME)
+      self.model = SentenceTransformer(config.EMBEDDINGS_BASE_MODEL_NAME)
     self.dataset_embeddings: torch.Tensor = None
     self.dcids: List[str] = []
     self.sentences: List[str] = []
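The only functional bit here is the fallback to the base model when no fine-tuned model path is given. A minimal sketch, assuming the constructor arguments match the `Embeddings(embeddings_path, existing_model_path)` call in embeddings_store.py below (paths are placeholders):

# Uses the downloaded fine-tuned SentenceTransformer at the given folder.
emb = Embeddings('/tmp/embeddings.csv',
                 '/tmp/ft_final_v20230717230459.all-MiniLM-L6-v2')

# No tuned-model path: falls back to
# SentenceTransformer(config.EMBEDDINGS_BASE_MODEL_NAME), i.e. 'all-MiniLM-L6-v2'.
emb = Embeddings('/tmp/embeddings.csv', '')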

nl_server/embeddings_store.py (new file, +37)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from nl_server.config import DEFAULT_INDEX_TYPE
from nl_server.config import EmbeddingsIndex
from nl_server.embeddings import Embeddings


#
# A simple wrapper class around multiple embeddings indexes.
#
# TODO: Handle custom DC specific logic here.
#
class Store:

  def __init__(self, indexes: List[EmbeddingsIndex]):
    self.embeddings_map = {}
    for idx in indexes:
      self.embeddings_map[idx.name] = Embeddings(idx.embeddings_local_path,
                                                 idx.tuned_model_local_path)

  # Note: The caller takes care of exceptions.
  def get(self, index_type: str = DEFAULT_INDEX_TYPE) -> Embeddings:
    return self.embeddings_map[index_type]
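For context, a hedged sketch of how these pieces fit together at startup. `loader.load_server_state()` is not part of this excerpt, so the exact wiring is an assumption; the calls to `config.load()` and `Store` follow the signatures above:

import yaml

from nl_server import config
from nl_server.embeddings_store import Store

with open('deploy/nl/embeddings.yaml') as f:
  embeddings_map = yaml.full_load(f)
with open('deploy/nl/models.yaml') as f:
  models_map = yaml.full_load(f)

# Validates the yaml contents and downloads embeddings CSVs + tuned models.
indexes = config.load(embeddings_map, models_map)

store = Store(indexes)
emb = store.get('medium_ft')  # or store.get() for DEFAULT_INDEX_TYPE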
