Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT SUBMIT] CLOUDRUNDEBUG #4453

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions build/web_compose/compose.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
# limitations under the License.

export MIXER_API_KEY=$DC_API_KEY
# https://stackoverflow.com/a/62703850
export TOKENIZERS_PARALLELISM=false
# https://github.com/UKPLab/sentence-transformers/issues/1318#issuecomment-1084731111
export OMP_NUM_THREADS=1

if [[ $USE_SQLITE == "true" ]]; then
export SQLITE_PATH=/sqlite/datacommons.db
Expand Down
15 changes: 15 additions & 0 deletions nl_server/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
import logging
from typing import Dict, List

import torch
Expand Down Expand Up @@ -78,20 +79,34 @@ class Embeddings:
def __init__(self, model: EmbeddingsModel, store: EmbeddingsStore):
self.model: EmbeddingsModel = model
self.store: EmbeddingsStore = store
self.search_count = 0

# Given a list of queries, returns
def vector_search(self, queries: List[str], top_k: int) -> SearchVarsResult:
self.search_count += 1
logging.info('CLOUDRUNDEBUG In Embeddings.vector_search #%s: %s (%s)',
self.search_count, queries, self)
query_embeddings = self.model.encode(queries)
logging.info('CLOUDRUNDEBUG len(query_embeddings) #%s: %s',
self.search_count, len(query_embeddings))

if self.model.returns_tensor and not self.store.needs_tensor:
# Convert to List[List[float]]
logging.info('In query_embeddings condition #1')
query_embeddings = query_embeddings.tolist()
logging.info('Done len(query_embeddings) condition #1: %s',
len(query_embeddings))
elif not self.model.returns_tensor and self.store.needs_tensor:
# Convert to torch.Tensor
logging.info('In query_embeddings condition #2')
query_embeddings = torch.tensor(query_embeddings, dtype=torch.float)
logging.info('Done len(query_embeddings) condition #2: %s',
len(query_embeddings))

# Call the store.
logging.info('Before store.vector_search: %s', self.store)
results = self.store.vector_search(query_embeddings, top_k)
logging.info('after store.vector_search len(results): %s', len(results))

# Turn this into a map:
return {k: v for k, v in zip(queries, results)}
11 changes: 11 additions & 0 deletions nl_server/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,33 @@ def create_app():
try:
# Build the registry before creating the Flask app to make sure all resources
# are loaded.
logging.info('CLOUDRUNDEBUG Building registry.')
reg = registry.build()
logging.info('CLOUDRUNDEBUG Registry built.')

if not lib_utils.is_test_env():
# Below is a safe check to ensure that the model and embedding is loaded.
server_config = reg.server_config()
idx_type = server_config.default_indexes[0]
embeddings = reg.get_index(idx_type)
query = server_config.indexes[idx_type].healthcheck_query
logging.info('CLOUDRUNDEBUG Healthcheck query: %s', query)
result = search.search_vars([embeddings], [query]).get(query)
logging.info('CLOUDRUNDEBUG Healthcheck query len(result.svs): %s',
len(result.svs))
if not result or not result.svs:
raise Exception(f'Registry does not have default index {idx_type}')

app = Flask(__name__)
app.register_blueprint(routes.bp)
app.config[registry.REGISTRY_KEY] = reg

embeddings = reg.get_index("medium_ft")
logging.info('CLOUDRUNDEBUG reg.get_index("medium_ft"): %s', embeddings)
if embeddings:
results = search.search_vars([embeddings], ["life expectancy"])
logging.info('CLOUDRUNDEBUG len(results): %s', len(results))

logging.info('NL Server Flask app initialized')
return app
except Exception as e:
Expand Down
14 changes: 14 additions & 0 deletions nl_server/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@

@bp.route('/healthz')
def healthz():
print("In healthz")
reg: Registry = current_app.config[REGISTRY_KEY]
embeddings = reg.get_index("medium_ft")
logging.info('CLOUDRUNDEBUG healthz reg.get_index("medium_ft"): %s',
embeddings)
if embeddings:
results = search.search_vars([embeddings], ["life expectancy"])
logging.info('CLOUDRUNDEBUG healthz len(results): %s', len(results))
return 'NL Server is healthy', 200


Expand Down Expand Up @@ -75,6 +83,7 @@ def search_vars():
"""
queries = request.json.get('queries', [])
queries = [str(escape(q)) for q in queries]
logging.info('CLOUDRUNDEBUG In search_vars: %s', queries)

# TODO: clean up skip topics, may not be used anymore
skip_topics = False
Expand All @@ -100,8 +109,10 @@ def search_vars():
embeddings = _get_indexes(reg, idx_types)

debug_logs = {'sv_detection_query_index_types': idx_types}
logging.info('CLOUDRUNDEBUG debug_logs: %s', debug_logs)
results = search.search_vars(embeddings, queries, skip_topics, reranker_model,
debug_logs)
logging.info('CLOUDRUNDEBUG results: %s', results)
q2result = {q: var_candidates_to_dict(result) for q, result in results.items()}
return json.dumps({
'queryResults': q2result,
Expand Down Expand Up @@ -130,14 +141,17 @@ def embeddings_version_map():

@bp.route('/api/load/', methods=['POST'])
def load():
logging.info('CLOUDRUNDEBUG In nl_server.load')
catalog = request.json.get('catalog', None)
logging.info('CLOUDRUNDEBUG catalog: %s', catalog)
try:
current_app.config[REGISTRY_KEY] = registry.build(
additional_catalog=catalog)
except Exception as e:
logging.error(f'Server registry not built due to error: {str(e)}')
reg: Registry = current_app.config[REGISTRY_KEY]
server_config = reg.server_config()
logging.info('CLOUDRUNDEBUG server_config: %s', server_config)
return json.dumps(asdict(server_config))


Expand Down
9 changes: 8 additions & 1 deletion nl_server/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Library that exposes search_vars"""

import logging
import time
from typing import Dict, List

Expand Down Expand Up @@ -49,8 +50,12 @@ def search_vars(embeddings_list: List[Embeddings],

# Call vector search for each index.
query2candidates_list: List[EmbeddingsResult] = []
logging.info('CLOUDRUNDEBUG In search_vars')
for embeddings in embeddings_list:
query2candidates_list.append(embeddings.vector_search(queries, topk))
result = embeddings.vector_search(queries, topk)
logging.info('CLOUDRUNDEBUG vector_search len(result): %s (%s)',
len(result), embeddings)
query2candidates_list.append(result)

# Merge the results.
query2candidates = merge_search_results(query2candidates_list)
Expand All @@ -60,6 +65,8 @@ def search_vars(embeddings_list: List[Embeddings],
for query, candidates in query2candidates.items():
results[query] = _rank_vars(candidates, skip_topics)

logging.info('CLOUDRUNDEBUG merged len(results): %s', len(results))

if rerank_model:
start = time.time()
results = rerank.rerank(rerank_model, results, debug_logs)
Expand Down
2 changes: 2 additions & 0 deletions nl_server/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,11 @@ def __init__(self, idx_info: MemoryIndexConfig) -> None:
#
def vector_search(self, query_embeddings: torch.Tensor,
top_k: int) -> List[EmbeddingsResult]:
logging.info('CLOUDRUNDEBUG In vector_search')
hits = semantic_search(query_embeddings,
self.dataset_embeddings,
top_k=top_k)
logging.info('CLOUDRUNDEBUG len(hits): %s', len(hits))
results: List[EmbeddingsResult] = []
for hit in hits:
matches: List[EmbeddingsMatch] = []
Expand Down
1 change: 1 addition & 0 deletions run_cdc_dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ set -e
trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT

source .run_cdc_dev.env && export $(sed '/^#/d' .run_cdc_dev.env | cut -d= -f1)
export TOKENIZERS_PARALLELISM=false

# Print commit hashes.
echo -e "\033[0;32m" # Set different color.
Expand Down
Loading