Skip to content

Commit 09e5654

Browse files
authored
NL: reject inappropriate words and improve schools (#2952)
* The initial bad-word list is located [here](https://storage.mtls.cloud.google.com/datcom-website-config/nl_bad_words.txt). * Also, avoid school types from being considered as stop-words. There were two other ripple effects to this change: 1. It also requires special-handling for fallback logic (to not say "schools in sunnyvale" => "sunnyvale") 2. It also requires not regressing the demo query [how big are public schools in sunnyvale] Note: making the schools change also uncovered #2953. This whole stop-word removal business needs streamlining! Post fishfood maybe, and as part of fixing 2853. Screenshot ![image](https://github.com/datacommonsorg/website/assets/4375037/55a824e5-2ed3-4358-b928-e421cb9dd99f)
1 parent 25566e8 commit 09e5654

File tree

18 files changed

+918
-74
lines changed

18 files changed

+918
-74
lines changed

nl_server/gcs.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
BUCKET = 'datcom-nl-models'
16-
TEMP_DIR = '/tmp/'
1716

1817
import os
1918
from pathlib import Path
@@ -23,21 +22,15 @@
2322
from google.cloud import storage
2423
from sentence_transformers import SentenceTransformer
2524

25+
from shared.lib import gcs as gcs_lib
2626

27-
# Downloads the `embeddings_file` from GCS to TEMP_DIR
28-
# and return its path.
29-
def download_embeddings(embeddings_file: str) -> str:
30-
storage_client = storage.Client()
31-
bucket = storage_client.bucket(bucket_name=BUCKET)
32-
blob = bucket.get_blob(embeddings_file)
33-
# Download
34-
local_embeddings_path = local_path(embeddings_file)
35-
blob.download_to_filename(local_embeddings_path)
36-
return local_embeddings_path
27+
28+
def download_embeddings(embeddings_filename: str) -> str:
29+
return gcs_lib.download_file(bucket=BUCKET, filename=embeddings_filename)
3730

3831

3932
def local_path(embeddings_file: str) -> str:
40-
return os.path.join(TEMP_DIR, embeddings_file)
33+
return os.path.join(gcs_lib.TEMP_DIR, embeddings_file)
4134

4235

4336
def download_model_from_gcs(gcs_bucket: Any, local_dir: str,
@@ -78,7 +71,7 @@ def download_model_from_gcs(gcs_bucket: Any, local_dir: str,
7871
def download_model_folder(model_folder: str) -> str:
7972
sc = storage.Client()
8073
bucket = sc.bucket(bucket_name=BUCKET)
81-
directory = TEMP_DIR
74+
directory = gcs_lib.TEMP_DIR
8275

8376
# Only download if needed.
8477
model_path = os.path.join(directory, model_folder)

server/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import server.lib.config as libconfig
3636
from server.lib.disaster_dashboard import get_disaster_dashboard_data
3737
import server.lib.i18n as i18n
38+
from server.lib.nl.common import bad_words
3839
import server.lib.util as libutil
3940
import server.services.bigtable as bt
4041
from server.services.discovery import configure_endpoints_from_ingress
@@ -370,6 +371,7 @@ def create_app():
370371
secret_response = secret_client.access_secret_version(name=secret_name)
371372
app.config['PALM_API_KEY'] = secret_response.payload.data.decode(
372373
'UTF-8')
374+
app.config['NL_BAD_WORDS'] = bad_words.load_bad_words()
373375

374376
# Get and save the blocklisted svgs.
375377
blocklist_svg = []

server/integration_tests/nlnext_test.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def run_sequence(self,
3535
detector='hybrid',
3636
check_place_detection=False,
3737
expected_detectors=[],
38-
place_detector='ner'):
38+
place_detector='ner',
39+
failure=''):
3940
if detector == 'heuristic':
4041
detection_method = 'Heuristic Based'
4142
elif detector == 'llm':
@@ -76,6 +77,11 @@ def run_sequence(self,
7677
}
7778
infile.write(json.dumps(dbg_to_write, indent=2))
7879
else:
80+
if failure:
81+
self.assertTrue(failure in resp["failure"]), resp["failure"]
82+
self.assertTrue(not resp["config"])
83+
return
84+
7985
if not expected_detectors:
8086
self.assertTrue(dbg.get('detection_type').startswith(detection_method)), \
8187
'Query {q} failed!'
@@ -135,13 +141,16 @@ def test_demo_cities_feb2023(self):
135141
self.run_sequence(
136142
'demo2_cities_feb2023',
137143
[
144+
# This should list public school entities.
138145
'How big are the public schools in Sunnyvale',
139146
'What is the prevalence of asthma there',
140147
'What is the commute pattern there',
141148
'How does that compare with San Bruno',
142149
# Proxy for parks in magiceye
143150
'Which cities in the SF Bay Area have the highest larceny',
144151
'What countries in Africa had the greatest increase in life expectancy',
152+
# This should list stats about the middle school students.
153+
'How many middle schools are there in Sunnyvale',
145154
])
146155

147156
def test_demo_fallback(self):
@@ -254,3 +263,9 @@ def test_medium_index(self):
254263
self.run_sequence('medium_index',
255264
['cars per family in california counties'],
256265
idx='medium')
266+
267+
def test_inappropriate_query(self):
268+
self.run_sequence('inappropriate_query',
269+
['how many wise asses live in sunnyvale?'],
270+
idx='medium',
271+
failure='inappropriate words')

0 commit comments

Comments
 (0)