Skip to content

Commit

Permalink
database connection handling, reference formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
spodgorny9 committed Sep 12, 2024
1 parent c66ae14 commit f8c7bd3
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 51 deletions.
40 changes: 26 additions & 14 deletions elm/web/rhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,19 +588,25 @@ def authors(self):
authors = []

for r in pa:
first = r.get('name').get('firstName')
last = r.get('name').get('lastName')

if first and last:
full = first + ' ' + last
elif first:
full = first
elif last:
full = last
name = r.get('name')

if name:
first = name.get('firstName')
last = name.get('lastName')

if first and last:
full = first + ' ' + last
elif first:
full = first
elif last:
full = last
else:
full = None

authors.append(full)
if full:
authors.append(full)

out = ', '.join(authors)
out = ', '.join(authors)

return out

Expand Down Expand Up @@ -653,8 +659,12 @@ def abstract(self):
String containing abstract text.
"""
abstract = self.get('abstract')
text = abstract.get('text')[0]
value = text.get('value')

if abstract:
text = abstract.get('text')[0]
value = text.get('value')
else:
value = None

return value

Expand Down Expand Up @@ -701,6 +711,9 @@ def download(self, pdf_dir, txt_dir):
if not os.path.exists(fp):
if abstract:
self.save_abstract(abstract, fp)
else:
logger.info(f'{self.title}: does not have an '
'abstract to downlod')
else:
if pdf_url and pdf_url.endswith('.pdf'):
fn = self.id.replace('/', '-') + '.pdf'
Expand Down Expand Up @@ -876,7 +889,6 @@ def download(self, pdf_dir, txt_dir):
try:
record.download(pdf_dir, txt_dir)
except Exception as e:
print(f"Could not download {record.title} with error {e}")
logger.exception('Could not download {}: {}'
.format(record.title, e))
logger.info('Finished publications download!')
100 changes: 63 additions & 37 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,14 @@ class EnergyWizardPostgres(EnergyWizardBase):
}
"""Optional mappings for weird azure names to tiktoken/openai names."""

DEFAULT_META_COLS = ('title', 'url', 'authors', 'year', 'category', 'id')
"""Default columns to retrieve for metadata"""

def __init__(self, db_host, db_port, db_name,
db_schema, db_table, meta_columns=None,
cursor=None, boto_client=None,
model=None, token_budget=3500,
tag=False):
db_schema, db_table, probes=25,
meta_columns=None, cursor=None,
boto_client=None, model=None,
token_budget=3500, tag=False):
"""
Parameters
----------
Expand All @@ -423,6 +426,9 @@ def __init__(self, db_host, db_port, db_name,
db_table : str
Table to query in Postgres database. Necessary columns: id,
chunks, embedding, title, and url.
probes : int
Number of IVFFlat index lists to probe per query (sets the
``ivfflat.probes`` value). Recommended value is sqrt(n_lists).
meta_columns : list
List of metadata columns to retrieve from database. Default
query returns title and url.
Expand All @@ -443,29 +449,32 @@ def __init__(self, db_host, db_port, db_name,
boto3 = try_import('boto3')
self.psycopg2 = try_import('psycopg2')

self.db_schema = db_schema
self.db_table = db_table

if meta_columns is None:
self.meta_columns = ['title', 'url']
else:
self.meta_columns = meta_columns

if cursor is None:
db_user = os.getenv("EWIZ_DB_USER")
db_password = os.getenv('EWIZ_DB_PASSWORD')
assert db_user is not None, "Must set EWIZ_DB_USER!"
assert db_password is not None, "Must set EWIZ_DB_PASSWORD!"
self.db_kwargs = dict(user=db_user, password=db_password,
host=db_host, port=db_port,
database=db_name)
self.db_host = db_host
self.db_port = db_port
self.db_name = db_name
self.db_schema = db_schema
self.db_table = db_table
self.db_user = os.getenv("EWIZ_DB_USER")
self.db_password = os.getenv('EWIZ_DB_PASSWORD')
assert self.db_user is not None, "Must set EWIZ_DB_USER!"
assert self.db_password is not None, "Must set EWIZ_DB_PASSWORD!"
self.db_kwargs = dict(user=self.db_user, password=self.db_password,
host=self.db_host, port=self.db_port,
database=self.db_name)
self.conn = self.psycopg2.connect(**self.db_kwargs)

self.cursor = self.conn.cursor()
else:
self.cursor = cursor

self.tag = tag
self.probes = probes

if boto_client is None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
Expand Down Expand Up @@ -553,16 +562,14 @@ def _add_tag(meta):

return tag

def query_vector_db(self, query, probes=25, limit=100):
def query_vector_db(self, query, limit=100):
"""Returns a list of strings and relatednesses, sorted from most
related to least.
Parameters
----------
query : str
Question being asked of GPT
probes: int
Number of lists to search in vector database index.
limit : int
Number of top results to return.
Expand All @@ -582,7 +589,7 @@ def query_vector_db(self, query, probes=25, limit=100):
with self.psycopg2.connect(**self.db_kwargs) as conn:
cursor = conn.cursor()
try:
cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
cursor.execute(f"SET LOCAL ivfflat.probes = {self.probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
Expand Down Expand Up @@ -612,45 +619,64 @@ def query_vector_db(self, query, probes=25, limit=100):

return strings, scores, best

def _format_refs(self, refs):
"""Parse and nicely format a reference dictionary into
a list of well formatted string representations
def _format_refs(self, refs, ids):
"""Parse and nicely format a list of reference rows into a list of
well-formatted JSON string representations.
Parameters
----------
refs : list
List of references returned from the vector db
ids : np.ndarray
IDs of the used text from the text corpus sorted by embedding
relevance.
Returns
-------
out : list
Unique ordered list of references
Unique ordered list of references (most relevant first)
"""

ref_list = []
for item in refs:
ref_dict = {self.meta_columns[i]: item[i]
for i in range(len(self.meta_columns))}

ilist = []
for key, value in ref_dict.items():
ref_dict = {}
for icol, col in enumerate(self.meta_columns):
value = item[icol]
value = str(value).replace(chr(34), '')
istr = f"\"{key}\": \"{value}\""
ilist.append(istr)
ref_dict[col] = value

ref_str = ", ".join(ilist)
ref_str = '{' + ref_str + '}'
ref_list.append(ref_str)
ref_list.append(ref_dict)

seen = set()
ref_list = [x for x in ref_list if not
(x in seen or seen.add(x))]
unique_ref_list = []
for ref_dict in ref_list:
if str(ref_dict) not in seen:
seen.add(str(ref_dict))
unique_ref_list.append(ref_dict)
ref_list = unique_ref_list

if 'id' in ref_list[0]:
ids_list = list(ids)
sorted_ref_list = []
for ref_id in ids_list:
for ref_dict in ref_list:
if ref_dict['id'] == ref_id:
sorted_ref_list.append(ref_dict)
break
ref_list = sorted_ref_list

ref_list = [json.dumps(ref) for ref in ref_list]

return ref_list

def make_ref_list(self, ids):
"""Make a reference list
Parameters
----------
used_index : np.ndarray
ids : np.ndarray
IDs of the used text from the text corpus
Returns
-------
ref_list : list
Expand All @@ -674,11 +700,11 @@ def make_ref_list(self, ids):
conn.rollback()
msg = (f'Received error when querying the postgres '
f'vector database: {exc}')
raise RuntimeError(msg) from exc
raise RuntimeError(msg)
else:
conn.commit()
refs = cursor.fetchall()

ref_list = self._format_refs(refs)
ref_list = self._format_refs(refs, ids)

return ref_list

0 comments on commit f8c7bd3

Please sign in to comment.