Refactor ingest file handling for maintainability #73

Merged
merged 7 commits on Sep 15, 2023
124 changes: 53 additions & 71 deletions ai_ta_backend/vector_database.py
@@ -201,81 +201,63 @@ def get_context_stuffed_prompt(self, user_question: str, course_name: str, top_n


def bulk_ingest(self, s3_paths: Union[List[str], str], course_name: str, **kwargs) -> Dict[str, List[str]]:
# https://python.langchain.com/en/latest/modules/indexes/document_loaders/examples/microsoft_word.html
success_status = {"success_ingest": [], "failure_ingest": []}

def ingest(file_ext_mapping, s3_path, *args, **kwargs):
handler = file_ext_mapping.get(Path(s3_path).suffix)
if handler:
ret = handler(s3_path, *args, **kwargs)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)

file_ext_mapping = {
'.html': self._ingest_html,
'.py': self._ingest_single_py,
'.vtt': self._ingest_single_vtt,
'.pdf': self._ingest_single_pdf,
'.txt': self._ingest_single_txt,
'.md': self._ingest_single_txt,
'.srt': self._ingest_single_srt,
'.docx': self._ingest_single_docx,
'.ppt': self._ingest_single_ppt,
'.pptx': self._ingest_single_ppt,
}

try:
if isinstance(s3_paths, str):
s3_paths = [s3_paths]

for s3_path in s3_paths:
ext = Path(s3_path).suffix # check mimetype of file
# TODO: no need to download, just guess_type against the s3_path...
with NamedTemporaryFile(suffix=ext) as tmpfile:
self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile)
mime_type = str(mimetypes.guess_type(tmpfile.name)[0])
category, subcategory = mime_type.split('/')

# TODO: if mime-type is text, we should handle that via .txt ingest

if s3_path.endswith('.html'):
ret = self._ingest_html(s3_path, course_name, kwargs=kwargs)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.py'):
ret = self._ingest_single_py(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.vtt'):
ret = self._ingest_single_vtt(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.pdf'):
ret = self._ingest_single_pdf(s3_path, course_name, kwargs=kwargs)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.txt') or s3_path.endswith('.md'):
ret = self._ingest_single_txt(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.srt'):
ret = self._ingest_single_srt(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.docx'):
ret = self._ingest_single_docx(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif s3_path.endswith('.ppt') or s3_path.endswith('.pptx'):
ret = self._ingest_single_ppt(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif category == 'video' or category == 'audio':
ret = self._ingest_single_video(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
return success_status
if isinstance(s3_paths, str):
s3_paths = [s3_paths]

for s3_path in s3_paths:
with NamedTemporaryFile(suffix=Path(s3_path).suffix) as tmpfile:
self.s3_client.download_fileobj(Bucket=os.environ['S3_BUCKET_NAME'], Key=s3_path, Fileobj=tmpfile)
mime_type = mimetypes.guess_type(tmpfile.name)[0]
category, _ = mime_type.split('/')
match_file_ext = "." + _

if category in ['video', 'audio']:
ret = self._ingest_single_video(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
elif category == 'text' and match_file_ext not in file_ext_mapping.keys():
print(category, match_file_ext)
ret = self._ingest_single_txt(s3_path, course_name)
if ret != "Success":
success_status['failure_ingest'].append(s3_path)
else:
success_status['success_ingest'].append(s3_path)
else:
ingest(file_ext_mapping, s3_path, course_name, kwargs=kwargs)

return success_status

except Exception as e:
success_status['failure_ingest'].append("MAJOR ERROR IN /bulk_ingest: Error: " + str(e))
return success_status
success_status['failure_ingest'].append(f"MAJOR ERROR IN /bulk_ingest: Error: {str(e)}")
return success_status
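
For reference, a quick sketch of how the mimetypes-based fallback above categorizes files; the file names are illustrative only, and the real code guesses the type from a temporary copy downloaded from S3 rather than from the S3 key directly.

import mimetypes

# Example extensions only; all three are in Python's built-in type map.
for name in ["lecture.mp4", "notes.csv", "report.pdf"]:
    mime_type = mimetypes.guess_type(name)[0]      # e.g. "video/mp4"
    category, subtype = mime_type.split("/")
    print(name, "->", category, subtype)

# In bulk_ingest above: "video"/"audio" categories go to _ingest_single_video,
# and "text" files with no extension handler fall back to _ingest_single_txt.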


def _ingest_single_py(self, s3_path: str, course_name: str):
try:
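The core of the bulk_ingest refactor above is a lookup-table dispatch: file extensions map to ingest handlers, so supporting a new file type becomes a one-line dictionary entry instead of another elif branch. Below is a minimal, self-contained sketch of that pattern; the handler functions are hypothetical stand-ins for the real _ingest_single_* methods, and unmapped extensions are simply marked as failures here, whereas the real code falls back to the MIME-type checks shown above.

from pathlib import Path
from typing import Callable, Dict, List

def bulk_ingest_sketch(s3_paths: List[str]) -> Dict[str, List[str]]:
    """Dispatch each path to a handler keyed on its file extension."""
    success_status: Dict[str, List[str]] = {"success_ingest": [], "failure_ingest": []}

    # Hypothetical handlers; the real code uses bound methods like self._ingest_single_pdf.
    def ingest_pdf(path: str) -> str:
        return "Success"

    def ingest_txt(path: str) -> str:
        return "Success"

    file_ext_mapping: Dict[str, Callable[[str], str]] = {
        ".pdf": ingest_pdf,
        ".txt": ingest_txt,
        ".md": ingest_txt,   # several extensions can share one handler
    }

    for s3_path in s3_paths:
        handler = file_ext_mapping.get(Path(s3_path).suffix)
        if handler is None:
            success_status["failure_ingest"].append(s3_path)   # no handler registered
            continue
        ret = handler(s3_path)
        key = "success_ingest" if ret == "Success" else "failure_ingest"
        success_status[key].append(s3_path)

    return success_status

print(bulk_ingest_sketch(["notes/week1.pdf", "notes/README.md", "media/lecture.mp3"]))

Keeping the success/failure bookkeeping in one place is what eliminates most of the repeated if ret != "Success" blocks from the old elif chain.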
33 changes: 22 additions & 11 deletions ai_ta_backend/web_scrape.py
@@ -151,7 +151,7 @@ def remove_duplicates(urls:list, supabase_urls:list=None):
print("deleted", og_len-len(not_repeated_files), "duplicate files")
return urls

def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url_on:str=None, _depth:int=0, _soup:BeautifulSoup=None, _filetype:str=None, _invalid_urls:list=[], _existing_urls:list=None):
def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url_on:str=None, _depth:int=0, _soup:BeautifulSoup=None, _filetype:str=None, _invalid_urls:list=[], _existing_urls:list=[]):
'''Function gets titles of urls and the urls themselves'''
# Prints the depth of the current search
print("depth: ", _depth)
@@ -181,7 +181,7 @@ def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url
url, s, filetype = valid_url(url)
time.sleep(timeout)
url_contents.append((url,s, filetype))
print("Scraped:", url)
print("Scraped:", url, "✅")
if url:
if filetype == '.html':
try:
@@ -227,7 +227,7 @@ def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url
if url.startswith(site):
url, s, filetype = valid_url(url)
if url:
print("Scraped:", url)
print("Scraped:", url, "✅")
url_contents.append((url, s, filetype))
else:
_invalid_urls.append(url)
@@ -236,7 +236,7 @@ def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url
else:
url, s, filetype = valid_url(url)
if url:
print("Scraped:", url)
print("Scraped:", url, "✅")
url_contents.append((url, s, filetype))
else:
_invalid_urls.append(url)
@@ -285,6 +285,17 @@ def crawler(url:str, max_urls:int=1000, max_depth:int=3, timeout:int=1, base_url

return url_contents

def is_github_repo(url):
pattern = re.compile(r'^https://github\.com/[^/]+/[^/]+$')
if not pattern.match(url):
return False

response = requests.head(url)
if response.status_code == 200 and response.headers['Content-Type'].startswith('text/html'):
return url
else:
return False
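
A brief illustration of how the new is_github_repo helper behaves; the URLs below are examples only, not taken from this PR. Only a bare https://github.com/<owner>/<repo> URL passes the regex, the HEAD request must then return 200 with an HTML content type, and on success the function returns the URL itself (truthy) rather than True, so an if is_github_repo(url): check still reads naturally at the call site.

# Example calls only; the first performs a real HEAD request if executed.
print(is_github_repo("https://github.com/octocat/Hello-World"))           # bare repo URL -> returned as-is (truthy) if the HEAD check passes
print(is_github_repo("https://github.com/octocat/Hello-World/issues"))    # extra path segment -> fails the regex, returns False
print(is_github_repo("https://gitlab.com/some-group/some-project"))       # not github.com -> fails the regex, returns False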

def main_crawler(url:str, course_name:str, max_urls:int=100, max_depth:int=3, timeout:int=1, stay_on_baseurl:bool=False):
"""
Crawl a site and scrape its content and PDFs, then upload the data to S3 and ingest it.
@@ -305,8 +316,8 @@ def main_crawler(url:str, course_name:str, max_urls:int=100, max_depth:int=3, ti
timeout = int(timeout)
stay_on_baseurl = bool(stay_on_baseurl)
if stay_on_baseurl:
stay_on_baseurl = base_url(url)
print(stay_on_baseurl)
baseurl = base_url(url)
print(baseurl)

ingester = Ingest()
s3_client = boto3.client(
@@ -316,7 +327,7 @@ def main_crawler(url:str, course_name:str, max_urls:int=100, max_depth:int=3, ti
)

# Check for GitHub repository coming soon
if url.startswith("https://github.com/"):
if is_github_repo(url):
print("Begin Ingesting GitHub page")
results = ingester.ingest_github(url, course_name)
print("Finished ingesting GitHub page")
@@ -331,7 +342,7 @@ def main_crawler(url:str, course_name:str, max_urls:int=100, max_depth:int=3, ti
urls = supabase_client.table(os.getenv('NEW_NEW_NEWNEW_MATERIALS_SUPABASE_TABLE')).select('course_name, url, contexts').eq('course_name', course_name).execute()
del supabase_client
if urls.data == []:
existing_urls = None
existing_urls = []
else:
existing_urls = []
for thing in urls.data:
@@ -340,13 +351,13 @@ def main_crawler(url:str, course_name:str, max_urls:int=100, max_depth:int=3, ti
whole += t['text']
existing_urls.append((thing['url'], whole))
print("Finished gathering existing urls from Supabase")
print("Length of existing urls:", len(existing_urls))
except Exception as e:
print("Error:", e)
print("Could not gather existing urls from Supabase")
existing_urls = None

existing_urls = []
print("Begin Ingesting Web page")
data = crawler(url=url, max_urls=max_urls, max_depth=max_depth, timeout=timeout, base_url_on=stay_on_baseurl, _existing_urls=existing_urls)
data = crawler(url=url, max_urls=max_urls, max_depth=max_depth, timeout=timeout, base_url_on=baseurl, _existing_urls=existing_urls)

# Clean some keys for a proper file name
# todo: have a default title
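
Taken together, the web_scrape.py changes give main_crawler the following order of operations: resolve the base URL without overwriting the stay_on_baseurl flag, take the GitHub path only when is_github_repo accepts the URL, and otherwise crawl with existing_urls guaranteed to be a list. The sketch below compresses that flow into a runnable form; every function in it is a stub standing in for the real implementation in web_scrape.py and vector_database.py, and nothing here touches the network.

from typing import List, Optional, Tuple

def is_github_repo(url: str):                 # stub: the real helper regex-checks and HEADs the URL
    return url if url.startswith("https://github.com/") and url.count("/") == 4 else False

def base_url(url: str) -> str:                # stub: the real helper trims the URL to its site root
    return "/".join(url.split("/")[:3]) + "/"

def fetch_existing_urls() -> List[Tuple[str, str]]:   # stub for the Supabase lookup
    return []                                 # always a list now, never None

def crawler(url: str, base_url_on: Optional[str], _existing_urls: List[Tuple[str, str]]):
    return [(url, "<html>...</html>", ".html")]       # stub for the recursive crawl

def main_crawler_sketch(url: str, course_name: str, stay_on_baseurl: bool = False):
    baseurl = base_url(url) if stay_on_baseurl else None   # flag is no longer overwritten

    if is_github_repo(url):                   # only bare, reachable repo URLs take the GitHub path
        return f"ingest_github({url}, {course_name})"

    existing_urls = fetch_existing_urls()
    print("Length of existing urls:", len(existing_urls))  # safe: always a list
    return crawler(url=url, base_url_on=baseurl, _existing_urls=existing_urls)

print(main_crawler_sketch("https://example.edu/course/", "demo-course", stay_on_baseurl=True))
print(main_crawler_sketch("https://github.com/octocat/Hello-World", "demo-course"))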