Feature/api #1

Open
wants to merge 7 commits into base: master
68 changes: 16 additions & 52 deletions privateGPT.py
@@ -7,8 +7,6 @@
 from langchain.llms import GPT4All, LlamaCpp
 import chromadb
 import os
-import argparse
-import time

 if not load_dotenv():
     print("Could not load .env file or it is empty. Please check if it exists and is readable.")
@@ -20,68 +18,34 @@
 model_type = os.environ.get('MODEL_TYPE')
 model_path = os.environ.get('MODEL_PATH')
 model_n_ctx = os.environ.get('MODEL_N_CTX')
-model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))
-target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))
+model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
+target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

 from constants import CHROMA_SETTINGS

-def main():
-    # Parse the command line arguments
-    args = parse_arguments()
+def init():
     embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)
-    chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS , path=persist_directory)
-    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS, client=chroma_client)
+    chroma_client = chromadb.PersistentClient(settings=CHROMA_SETTINGS, path=persist_directory)
+    db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS,
+                client=chroma_client)
     retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
     # activate/deactivate the streaming StdOut callback for LLMs
-    callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()]
+    callbacks = [StreamingStdOutCallbackHandler()]
     # Prepare the LLM
     match model_type:
         case "LlamaCpp":
-            llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks, verbose=False)
+            llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_batch=model_n_batch, callbacks=callbacks,
+                           verbose=False)
         case "GPT4All":
-            llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False)
+            llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch,
+                          callbacks=callbacks, verbose=False)
         case _default:
             # raise exception if model_type is not supported
-            raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
+            raise Exception(
+                f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")

-    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents= not args.hide_source)
-    # Interactive questions and answers
-    while True:
-        query = input("\nEnter a query: ")
-        if query == "exit":
-            break
-        if query.strip() == "":
-            continue
-
-        # Get the answer from the chain
-        start = time.time()
-        res = qa(query)
-        answer, docs = res['result'], [] if args.hide_source else res['source_documents']
-        end = time.time()
-
-        # Print the result
-        print("\n\n> Question:")
-        print(query)
-        print(f"\n> Answer (took {round(end - start, 2)} s.):")
-        print(answer)
-
-        # Print the relevant sources used for the answer
-        for document in docs:
-            print("\n> " + document.metadata["source"] + ":")
-            print(document.page_content)
-
-def parse_arguments():
-    parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, '
-                                                 'using the power of LLMs.')
-    parser.add_argument("--hide-source", "-S", action='store_true',
-                        help='Use this flag to disable printing of source documents used for answers.')
-
-    parser.add_argument("--mute-stream", "-M",
-                        action='store_true',
-                        help='Use this flag to disable the streaming StdOut callback for LLMs.')
-
-    return parser.parse_args()
+    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=False)

-if __name__ == "__main__":
-    main()
+    return qa
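
The refactor turns the interactive CLI into a reusable factory: init() builds the embeddings, the Chroma retriever and the LLM once, then returns the RetrievalQA chain instead of looping on stdin. A minimal usage sketch of the returned chain, assuming the same .env setup as before; the example question is hypothetical, and the chain is called the same way server.py calls it:

    # Sketch only: reuse the chain returned by init() outside the old CLI loop.
    from privateGPT import init

    qa = init()  # builds embeddings, retriever and LLM once
    res = qa("What do the ingested documents say about licensing?")  # hypothetical query
    print(res['result'])  # no 'source_documents' key: return_source_documents=False
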
2 changes: 2 additions & 0 deletions requirements.txt
@@ -12,3 +12,5 @@ pandoc==2.3
 pypandoc==1.11
 tqdm==4.66.1
 sentence_transformers==2.2.2
+flask==2.0.1
+waitress==2.1.2
33 changes: 33 additions & 0 deletions server.py
@@ -0,0 +1,33 @@
+from flask import Flask, request, jsonify
+from privateGPT import init
+import time
+import uuid
+
+app = Flask(__name__)
+qa = None
+
+@app.route("/query", methods=["GET"])
+def query():
+    req_id = str(uuid.uuid4())
+    print(f"Request {req_id} received")
+    q = request.args.get("q")
+
+    if q is None or q == '':
+        return jsonify(query=q, answer="Empty input")
+
+    start = time.time()
+    res = qa(q)
+    answer, docs = res['result'], []
+    end = time.time()
+    print(f"Request {req_id} | Query: {q} | Answer: {answer} | Time: {end-start}")
+    return jsonify(query=q, answer=res['result'])
+
+@app.route("/health", methods=["GET"])
+def health():
+    return jsonify({"status": "OK"})
+
+
+if __name__ == '__main__':
+    qa = init()
+    from waitress import serve
+    serve(app, host="0.0.0.0", port=5000)
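
A rough client-side sketch of how the two new endpoints could be exercised once the waitress server is running; the localhost:5000 address matches the serve() call above, while the example question is an assumption. Only the Python standard library is used:

    # Hypothetical client for the new API; assumes the server is up on localhost:5000.
    import json
    import urllib.parse
    import urllib.request

    base = "http://localhost:5000"

    # GET /health should return {"status": "OK"}
    with urllib.request.urlopen(f"{base}/health") as resp:
        print(json.load(resp))

    # GET /query?q=... returns {"query": ..., "answer": ...}
    params = urllib.parse.urlencode({"q": "What is this document about?"})
    with urllib.request.urlopen(f"{base}/query?{params}") as resp:
        print(json.load(resp)["answer"])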