Skip to content

Commit

Permalink
add reference code, query error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed Aug 30, 2024
1 parent 6ca811b commit c66ae14
Showing 1 changed file with 76 additions and 39 deletions.
115 changes: 76 additions & 39 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,10 +441,11 @@ def __init__(self, db_host, db_port, db_name,
GPT.
"""
boto3 = try_import('boto3')
psycopg2 = try_import('psycopg2')
self.psycopg2 = try_import('psycopg2')

self.db_schema = db_schema
self.db_table = db_table

if meta_columns is None:
self.meta_columns = ['title', 'url']
else:
Expand All @@ -455,11 +456,10 @@ def __init__(self, db_host, db_port, db_name,
db_password = os.getenv('EWIZ_DB_PASSWORD')
assert db_user is not None, "Must set EWIZ_DB_USER!"
assert db_password is not None, "Must set EWIZ_DB_PASSWORD!"
self.conn = psycopg2.connect(user=db_user,
password=db_password,
host=db_host,
port=db_port,
database=db_name)
self.db_kwargs = dict(user=db_user, password=db_password,
host=db_host, port=db_port,
database=db_name)
self.conn = self.psycopg2.connect(**self.db_kwargs)

self.cursor = self.conn.cursor()
else:
Expand Down Expand Up @@ -579,19 +579,28 @@ def query_vector_db(self, query, probes=25, limit=100):

query_embedding = self.get_embedding(query)

self.cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
"<=> %s::vector as score, "
f"{self.db_table}.title, "
f"{self.db_table}.authors, "
f"{self.db_table}.year "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )

result = self.cursor.fetchall()
with self.psycopg2.connect(**self.db_kwargs) as conn:
cursor = conn.cursor()
try:
cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
"<=> %s::vector as score, "
f"{self.db_table}.title, "
f"{self.db_table}.authors, "
f"{self.db_table}.year "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )
except Exception as exc:
conn.rollback()
msg = (f'Received error when querying the postgres '
f'vector database: {exc}')
raise RuntimeError(msg) from exc
else:
conn.commit()
result = cursor.fetchall()

if self.tag:
strings = [self._add_tag(s[3:]) + s[1] for s in result]
Expand All @@ -603,14 +612,45 @@ def query_vector_db(self, query, probes=25, limit=100):

return strings, scores, best

def _format_refs(self, refs):
"""Parse and nicely format a reference dictionary into
a list of well formatted string representations
Parameters
----------
refs : list
List of references returned from the vector db
Returns
-------
out : list
Unique ordered list of references
"""
ref_list = []
for item in refs:
ref_dict = {self.meta_columns[i]: item[i]
for i in range(len(self.meta_columns))}

ilist = []
for key, value in ref_dict.items():
value = str(value).replace(chr(34), '')
istr = f"\"{key}\": \"{value}\""
ilist.append(istr)

ref_str = ", ".join(ilist)
ref_str = '{' + ref_str + '}'
ref_list.append(ref_str)

seen = set()
ref_list = [x for x in ref_list if not
(x in seen or seen.add(x))]

return ref_list

def make_ref_list(self, ids):
"""Make a reference list
Parameters
----------
used_index : np.ndarray
IDs of the used text from the text corpus
Returns
-------
ref_list : list
Expand All @@ -626,22 +666,19 @@ def make_ref_list(self, ids):
f"FROM {self.db_schema}.{self.db_table} "
f"WHERE {self.db_table}.id IN (" + placeholders + ")")

self.cursor.execute(sql_query, ids)

refs = self.cursor.fetchall()
with self.psycopg2.connect(**self.db_kwargs) as conn:
cursor = conn.cursor()
try:
cursor.execute(sql_query, ids)
except Exception as exc:
conn.rollback()
msg = (f'Received error when querying the postgres '
f'vector database: {exc}')
raise RuntimeError(msg) from exc
else:
conn.commit()
refs = cursor.fetchall()

ref_list = self._format_refs(refs)

ref_list = []
for item in refs:
ref_dict = {self.meta_columns[i]: item[i]
for i in range(len(self.meta_columns))}
ref_str = "{"
ref_str += ", ".join([f"\"{key}\": \"{value}\""
for key, value in ref_dict.items()])
ref_str += "}"

ref_list.append(ref_str)

unique_values = set(ref_list)
unique_list = list(unique_values)

return unique_list
return ref_list

0 comments on commit c66ae14

Please sign in to comment.