This repository has been archived by the owner on Jun 10, 2024. It is now read-only.
Commit 80143ce (1 parent: 70d53f7)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 32 changed files with 1,898 additions and 0 deletions.
@@ -0,0 +1,6 @@
FROM python:3.11.0
WORKDIR /daath-ai-parser-classifier
COPY ./requirements.txt /daath-ai-parser-classifier/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /daath-ai-parser-classifier/requirements.txt
COPY ./app /daath-ai-parser-classifier/app
CMD ["gunicorn", "app.main:app", "--workers", "4", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0"]
@@ -0,0 +1,118 @@
from app.classify.functions.parser import Parser
from app.classify.functions.prompt_creator import PromptCreator
from app.classify.functions.results import Results
from app.schemas import *
import importlib
import json
import os


class Classify:
    def classify(self, targets: Targets):
        # Mock fixtures may only be used from inside a pytest run
        if targets.mock_name is not None and "PYTEST_CURRENT_TEST" not in os.environ:
            return {"error": "Mock name is only allowed in unit tests."}

        if targets.parse_only is None:
            # Create a new mock dictionary for the target if one is given
            if targets.mock_name is not None and "PYTEST_CURRENT_TEST" in os.environ:
                try:
                    with open(targets.mock_name) as json_file:
                        targets_dict = json.load(json_file)
                except (OSError, json.JSONDecodeError):
                    targets_dict = targets.dict()
                # Never persist the API key or local save options in the mock file
                targets_dict.pop("openai_key", None)
                targets_dict.pop("save_locally", None)
                targets_dict.pop("save_name", None)
                with open(targets.mock_name, "w") as pretty_json:
                    json.dump(targets_dict, pretty_json, indent=2, sort_keys=False)
                targets = Targets(**targets_dict)

            # Use the supplied classifier commands, or import them by parser path
            if targets.classifier is not None and targets.parse_only == False:
                classifier = json_to_pydantic(targets.classifier.dict())
            else:
                try:
                    parser = targets.path
                    path = "app.classify.parsers.{}".format(parser.lower())
                    classifier = importlib.import_module(path)
                    classifier = classifier.commands()
                    if isinstance(classifier, dict):  # commands() returned an error payload
                        return classifier
                except Exception:
                    return {"error": "Could not find parser classifier commands"}

            # Parse the incoming bodies, whether they are HTML, plain text, or a mix of both
            parser = Parser(classifier=classifier)
            desired_lines = parser.parse(targets.targets)

            # Create the prompts and estimate the maximum response and total token sizes
            prompt_objects = PromptObjects(desired_lines=desired_lines)
            prompt_creator = PromptCreator(classifier=classifier, prompt_objects=prompt_objects)
            classifier, prompt_objects = prompt_creator.get_prompts()

            # Return an error if every body is over the token limit
            if prompt_objects.prompts == []:
                return {"error": "None of the items are below the maximum token threshold for this prompt."}

            # Return mock prompt results, creating the mock file on first use
            if targets.prompts_only == True and targets.mock_name is not None and "PYTEST_CURRENT_TEST" in os.environ:
                mock_prompt_name = targets.mock_name.replace(".json", "-prompt.json")
                mock_prompt_name = mock_prompt_name.replace("/targets/", "/prompts/")
                try:
                    with open(mock_prompt_name) as json_file:
                        return json.load(json_file)
                except (OSError, json.JSONDecodeError):
                    prompts_only_dict = {
                        "prompts": prompt_objects.prompts,
                        "prompt_objects": {
                            "invalid_lines_indexes": prompt_objects.invalid_lines_indexes,
                            "desired_lines": prompt_objects.desired_lines,
                            "labels": prompt_objects.labels
                        }
                    }
                    with open(mock_prompt_name, "w") as pretty_json:
                        json.dump(prompts_only_dict, pretty_json, indent=2, sort_keys=False)
                    return prompts_only_dict
            elif targets.prompts_only == True:
                return {
                    "prompts": prompt_objects.prompts,
                    "prompt_objects": {
                        "invalid_lines_indexes": prompt_objects.invalid_lines_indexes,
                        "desired_lines": prompt_objects.desired_lines,
                        "labels": prompt_objects.labels
                    }
                }

        # Return mock classified results, creating the mock file on first use
        if targets.mock_name is not None and "PYTEST_CURRENT_TEST" in os.environ:
            mock_result_name = targets.mock_name.replace(".json", "-result.json")
            mock_result_name = mock_result_name.replace("/targets/", "/results/")
            try:
                with open(mock_result_name) as json_file:
                    return json.load(json_file)
            except (OSError, json.JSONDecodeError):
                if targets.parse_only is not None:
                    results = Results(targets=targets, classifier=None, prompt_objects=targets.parse_only.prompt_objects)
                else:
                    results = Results(targets=targets, classifier=classifier, prompt_objects=prompt_objects)
                    results.get_results_from_openai()
                results_to_write = results.to_json()
                with open(mock_result_name, "w") as pretty_json:
                    json.dump({"results": results_to_write}, pretty_json, indent=2, sort_keys=False)
                return {"results": results_to_write}

        # Return classified results
        if targets.parse_only is not None:
            results = Results(targets=targets, classifier=None, prompt_objects=targets.parse_only.prompt_objects)
        else:
            results = Results(targets=targets, classifier=classifier, prompt_objects=prompt_objects)
            results.get_results_from_openai()
        return {"results": results.to_json()}
@@ -0,0 +1,27 @@
from selectolax.parser import HTMLParser
from app.schemas import *
import re


class Parser:
    def __init__(self, classifier: Classifier):
        # Matches any HTML tag, tolerating quoted attribute values
        html_regex = r"<(\"[^\"]*\"|'[^']*'|[^'\">])*>"
        self.html_regex = re.compile(html_regex)
        self.classifier = classifier
        self.explicitly_excluded_regex = re.compile("|".join(self.classifier.explicitly_excluded_strings))

    def parse_single(self, text):
        # Strip the markup, then drop any explicitly excluded strings
        tree = HTMLParser(text)
        tree = tree.text(separator=" ", strip=True)
        tree = re.sub(self.explicitly_excluded_regex, "", tree).strip()
        return tree

    def parse(self, texts):
        # Reduce each target to a single line of plain text
        lined_targets = []
        for text in texts:
            if re.search(self.html_regex, text):
                entry = self.parse_single(text)
            else:
                entry = text.replace("\n", "")
            lined_targets.append(entry)
        return lined_targets
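
A small sketch of the parser on mixed input. The stand-in classifier is hypothetical and supplies only the one field Parser reads, explicitly_excluded_strings:

from types import SimpleNamespace

stub = SimpleNamespace(explicitly_excluded_strings=["Sponsored"])
parser = Parser(classifier=stub)
print(parser.parse([
    "<li><b>Alice</b> - Engineer Sponsored</li>",  # HTML: tags stripped, exclusions removed
    "Bob - Designer\nRemote",                      # plain text: newlines removed
]))
# Roughly: ['Alice - Engineer', 'Bob - DesignerRemote']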
@@ -0,0 +1,137 @@
import re
from app.schemas import *
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


class PromptCreator:
    def __init__(self, classifier: Classifier, prompt_objects: PromptObjects):
        self.classifier = classifier
        self.prompt_objects = prompt_objects
        self.keys = []
        self.number_of_labels = 0
        self.example_rows = []
        self.base_prompt_token_size = 0
        self.model_specific_token_size = classifier.model_specific_token_size

    def get_data_from_examples(self):
        # Collect the union of classification keys across all examples, in sorted order
        for example in self.classifier.examples_for_prompt:
            for key in sorted(example.classifications.keys()):
                if key not in self.keys:
                    self.keys.append(key)
        self.prompt_objects.labels = [key.title().replace("_", " ") for key in self.keys]
        self.number_of_labels = len(self.prompt_objects.labels)

    def get_main_prompt(self):
        return self.classifier.main_prompt.replace("NUMBER_OF_LABELS", str(self.number_of_labels))

    def get_example_lines(self):
        return "\n".join([example.text for example in self.classifier.examples_for_prompt])

    def get_desired_lines(self, lines):
        return "\n" + "\n".join(lines) + "\n"

    def get_table_labels(self):
        return "|" + "".join([" {} |".format(label) for label in self.prompt_objects.labels]) + "\n"

    def get_table_separator(self):
        return "|" + "".join([" --- |" for _ in self.prompt_objects.labels]) + "\n"

    def get_example_rows(self):
        # One markdown table row per example; "-" marks a label the example does not set
        for example in self.classifier.examples_for_prompt:
            row_text = "|"
            for key in self.keys:
                if key in example.classifications:
                    row_text = row_text + " {} |".format(example.classifications[key])
                else:
                    row_text = row_text + " - |"
            self.example_rows.append(row_text + "\n")
        self.example_rows = "".join(self.example_rows)[0:-1]  # drop the trailing newline
        return self.example_rows

    def calculate_token_size(self, line):
        return len(tokenizer(line)["input_ids"])

    def get_maximum_token_size(self, previous_max_token_size, line):
        # Worst case per line: the echoed line plus a full table row, e.g. "abc\n|1|abc|-|-|\n"
        token_size = self.number_of_labels + (2 * self.calculate_token_size(line)) + 2
        return previous_max_token_size + token_size

    def separate_for_calls(self):
        # Batch the desired lines so every prompt stays under the model's token limit
        token_sizes_of_lines = [(5 + self.calculate_token_size(line)) for line in self.prompt_objects.desired_lines]

        invalid_lines_indexes = []
        valid_calls = []
        valid_call = []
        previous_max_token_size = 0
        for i, (size, line) in enumerate(zip(token_sizes_of_lines, self.prompt_objects.desired_lines)):
            previous_max_token_size = self.get_maximum_token_size(previous_max_token_size, line)
            if (self.base_prompt_token_size + size) > self.model_specific_token_size:
                # The line alone exceeds the model window; skip it
                invalid_lines_indexes.append(i)
            elif (self.base_prompt_token_size + previous_max_token_size) > self.model_specific_token_size:
                valid_call.append(line)
            elif (self.base_prompt_token_size + previous_max_token_size) < self.model_specific_token_size:
                previous_max_token_size = 0
                valid_calls.append(valid_call)
                valid_call = []
            elif i == len(token_sizes_of_lines) - 1:
                valid_calls.append(valid_call)

        self.prompt_objects.invalid_lines_indexes = invalid_lines_indexes

        return valid_calls

    def get_prompts(self):
        self.get_data_from_examples()
        main_prompt_string = self.get_main_prompt()
        example_lines_string = self.get_example_lines()
        desired_lines_string = self.get_desired_lines(self.prompt_objects.desired_lines)
        table_labels_string = self.get_table_labels()
        table_separator_string = self.get_table_separator()
        example_rows = self.get_example_rows()

        prompt = "".join([
            main_prompt_string,
            example_lines_string,
            desired_lines_string,
            table_labels_string,
            table_separator_string,
            example_rows
        ])

        max_tokens_size = self.get_maximum_token_size(0, prompt)
        total_estimated_token_size = self.calculate_token_size(prompt) + max_tokens_size

        if total_estimated_token_size > self.model_specific_token_size:
            # Too big for one call: measure the prompt without the desired lines,
            # then split the lines across as many prompts as needed
            base_prompt = "".join([
                main_prompt_string,
                example_lines_string,
                table_labels_string,
                table_separator_string,
                example_rows
            ])
            self.base_prompt_token_size = self.calculate_token_size(base_prompt)
            valid_calls = self.separate_for_calls()

            if valid_calls == []:
                self.prompt_objects.prompts = []
            else:
                for i in range(len(valid_calls)):
                    desired_lines_string = self.get_desired_lines(valid_calls[i])
                    valid_calls[i] = "".join([
                        main_prompt_string,
                        example_lines_string,
                        desired_lines_string,
                        table_labels_string,
                        table_separator_string,
                        example_rows
                    ])
                self.prompt_objects.prompts = valid_calls
        else:
            self.prompt_objects.prompts = [prompt]
            self.prompt_objects.invalid_lines_indexes = []

        return self.classifier, self.prompt_objects
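
A short, self-contained sketch of the token arithmetic above. The line and label count are invented; the counting mirrors calculate_token_size and get_maximum_token_size:

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

def token_size(line):
    return len(tokenizer(line)["input_ids"])

line = "Alice - Engineer"
number_of_labels = 2  # e.g. Name and Occupation
# Worst case per line: the echoed line plus a full table row ("abc\n|1|abc|-|-|\n")
worst_case = number_of_labels + 2 * token_size(line) + 2
print(token_size(line), worst_case)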
@@ -0,0 +1,101 @@
import asyncio
import aiohttp
import re
from app.schemas import *


class Results:
    def __init__(self, targets: Targets, classifier: Classifier = None, prompt_objects: PromptObjects = None):
        self.targets = targets
        self.classifier = classifier
        self.prompt_objects = prompt_objects
        # Reuse pre-fetched responses when the caller only wants parsing
        if targets.parse_only is not None and targets.parse_only.responses is not None:
            self.responses = targets.parse_only.responses
        else:
            self.responses = []

    def get_results_from_openai(self):
        async def call_openai(session, prompt):
            self.classifier.data.prompt = prompt
            headers = {
                "Content-Type": "application/json",
                "Authorization": "Bearer {}".format(self.targets.openai_key)
            }
            try:
                async with session.post(self.classifier.openai_endpoint, headers=headers, json=self.classifier.data.dict()) as resp:
                    return await resp.json()
            except Exception:
                return {"error": "Error from Local Machine"}

        async def get_results(concurrent_prompts):
            # One session per batch; all requests in the batch run concurrently
            connector = aiohttp.TCPConnector(limit=None)
            async with aiohttp.ClientSession(connector=connector) as session:
                tasks = [asyncio.ensure_future(call_openai(session, prompt)) for prompt in concurrent_prompts]
                return await asyncio.gather(*tasks, return_exceptions=False)
        # Split the prompts into batches of at most `allowed_concurrency` and run
        # one batch of concurrent requests at a time,
        # e.g. five prompts at an allowed concurrency of two give [[1, 2], [3, 4], [5]]
        batch_size = max(1, self.targets.allowed_concurrency)
        prompts = self.prompt_objects.prompts
        for start in range(0, len(prompts), batch_size):
            self.responses = self.responses + asyncio.run(get_results(prompts[start:start + batch_size]))

    def to_json(self):
        results = []
        index = 0
        for response in self.responses:
            # Emit a placeholder entry for every line that was skipped as too long
            while index in self.prompt_objects.invalid_lines_indexes:
                results.append({"error": "Maximum Token Size is reached for this prompt. This is skipped."})
                index = index + 1
            if "error" in response:
                results.append({"error": response["error"]})
            elif "choices" in response:
                text = response["choices"][0]["text"]
                lines = [line for line in text.split("\n") if line != ""]
                for line_index, line in enumerate(lines):
                    result_dict = {}
                    cells = [word for word in re.split(r" \| |\| | \|", line) if word != ""]
                    for i in range(len(cells)):
                        if "#$" in cells[i]:  # "#$"-delimited cell: treat as an array
                            desired_array = []
                            array = [word for word in cells[i].split("#$") if word != ""]
                            for word in array:
                                # Keep only values that actually occur in the source line
                                desired_line = self.prompt_objects.desired_lines[index + line_index]
                                if word in desired_line:
                                    desired_array.append(word.strip())
                            if desired_array != []:
                                result_dict[self.prompt_objects.labels[i]] = desired_array
                        elif cells[i] != "-" and self.prompt_objects.labels[i] != "Line":  # plain string cell
                            desired_line = self.prompt_objects.desired_lines[index + line_index]
                            if cells[i] in desired_line:
                                result_dict[self.prompt_objects.labels[i]] = cells[i]
                    results.append(result_dict)
                index = index + len(lines)
        return results
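
A sketch of how a completion is folded back into results, feeding Results a pre-baked response instead of a live OpenAI call. The labels, lines, and response text are all invented; the table format follows the prompts built above:

from types import SimpleNamespace

prompt_objects = SimpleNamespace(
    invalid_lines_indexes=[],
    desired_lines=["Alice - Engineer", "Bob - Designer"],
    labels=["Name", "Occupation"],
)
targets = SimpleNamespace(parse_only=None)
results = Results(targets=targets, classifier=None, prompt_objects=prompt_objects)
results.responses = [{"choices": [{"text": "| Alice | Engineer |\n| Bob | Designer |"}]}]
print(results.to_json())
# [{'Name': 'Alice', 'Occupation': 'Engineer'}, {'Name': 'Bob', 'Occupation': 'Designer'}]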