Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making codeQA a bit faster #2

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


Blog Links:

[An attempt to build cursor's @codebase feature - RAG on codebases - part 1](https://blog.lancedb.com/rag-codebase-1/)
Expand All @@ -10,6 +9,11 @@ A powerful code search and query system that lets you explore codebases using na

> **Note**: New OpenAI/Anthropic accounts may experience token rate limits. Consider using an established account.

# Optimized Branch

Please read this quick write-up about the optimizations [here](https://sankalp.bearblog.dev/lessons-from-speeding-up-codeqa/).
This branch runs 2.5x faster than the main branch in the worst case.

## What is CodeQA?

CodeQA helps you understand codebases by:
Expand Down Expand Up @@ -67,7 +71,10 @@ Create a .env file and add the following:
```
OPENAI_API_KEY="your-openai-api-key"
JINA_API_KEY="your-jina-api-key"
SAMBANOVA_API_KEY="your-sambanova-api-key"
```

This branch uses SambaNova's API for faster LLM processing, a 2x speedup over gpt-4o-mini timings.
## Building the Codebase Index

To build the index for the codebase, run the following script:
Expand Down
250 changes: 209 additions & 41 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
import json
from dotenv import load_dotenv
from redis import ConnectionPool
import time
from concurrent.futures import ThreadPoolExecutor
import openai

load_dotenv()

from prompts import (
HYDE_SYSTEM_PROMPT,
HYDE_V2_SYSTEM_PROMPT,
CHAT_SYSTEM_PROMPT
CHAT_SYSTEM_PROMPT,
RERANK_PROMPT
)

# Configuration
Expand All @@ -35,14 +39,29 @@

# Logging setup
def setup_logging(config):
logging.basicConfig(
filename=config['LOG_FILE'],
level=logging.INFO,
format=config['LOG_FORMAT'],
# Create a formatter
formatter = logging.Formatter(
config['LOG_FORMAT'],
datefmt=config['LOG_DATE_FORMAT']
)
# Return a logger instance
return logging.getLogger(__name__)

# Setup file handler
file_handler = logging.FileHandler(config['LOG_FILE'])
file_handler.setFormatter(formatter)

# Setup console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)

# Get the logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Add both handlers
logger.addHandler(file_handler)
logger.addHandler(console_handler)

return logger

# Database setup
def setup_database(codebase_path):
Expand Down Expand Up @@ -88,16 +107,21 @@ def markdown_filter(text):
app = setup_app()

# OpenAI client setup
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
client = openai.OpenAI(
api_key=os.environ.get("SAMBANOVA_API_KEY"),
base_url="https://api.sambanova.ai/v1",
)


# Initialize the reranker
reranker = AnswerdotaiRerankers(column="source_code")

# Replace groq_hyde function
def openai_hyde(query):
chat_completion = client.chat.completions.create(
chat_completion = openai_client.chat.completions.create(
model="gpt-4o-mini",
max_tokens=400,
messages=[
{
"role": "system",
Expand All @@ -109,28 +133,33 @@ def openai_hyde(query):
}
]
)
app.logger.info(f"First HYDE response: {chat_completion.choices[0].message.content}")
return chat_completion.choices[0].message.content

def openai_hyde_v2(query, temp_context, hyde_query):
chat_completion = client.chat.completions.create(
chat_completion = openai_client.chat.completions.create(
model="gpt-4o-mini",
max_tokens=768,
messages=[
{
"role": "system",
"content": HYDE_V2_SYSTEM_PROMPT.format(query=query, temp_context=temp_context)
"content": HYDE_V2_SYSTEM_PROMPT.format(temp_context=temp_context)
},
{
"role": "user",
"content": f"Predict the answer to the query: {hyde_query}",
"content": f"Predict the answer to the query: {query}",
}
]
)
app.logger.info(f"Second HYDE response: {chat_completion.choices[0].message.content}")
return chat_completion.choices[0].message.content


def openai_chat(query, context):
start_time = time.time()

chat_completion = client.chat.completions.create(
model="gpt-4o",
model='Meta-Llama-3.1-70B-Instruct',
messages=[
{
"role": "system",
Expand All @@ -142,6 +171,32 @@ def openai_chat(query, context):
}
]
)

chat_time = time.time() - start_time
app.logger.info(f"Chat response took: {chat_time:.2f} seconds")

return chat_completion.choices[0].message.content

def rerank_using_small_model(query, context):
    """Run the rerank prompt over ``context`` for ``query`` with the 8B model.

    Sends one chat-completion request through the module-level ``client``:
    the system message is ``RERANK_PROMPT`` formatted with ``context`` and
    the user message is ``query`` verbatim. The wall-clock duration of the
    request is written to ``app.logger`` at INFO level.

    Args:
        query: The user's question, passed as the user message.
        context: Text substituted into ``RERANK_PROMPT``'s ``{context}`` slot.

    Returns:
        The message content of the first choice in the completion response.
    """
    # Start the clock before the prompt is formatted so the logged duration
    # covers the same work as the request call itself.
    began = time.time()

    system_message = {
        "role": "system",
        "content": RERANK_PROMPT.format(context=context)
    }
    user_message = {
        "role": "user",
        "content": query,
    }
    completion = client.chat.completions.create(
        model='Meta-Llama-3.1-8B-Instruct',
        messages=[system_message, user_message]
    )

    elapsed = time.time() - began
    app.logger.info(f"Llama 8B reranker response took: {elapsed:.2f} seconds")

    first_choice = completion.choices[0]
    return first_choice.message.content

def process_input(input_text):
Expand All @@ -152,48 +207,118 @@ def process_input(input_text):
return processed_text

def generate_context(query, rerank=False):
start_time = time.time()

# First HYDE call
hyde_query = openai_hyde(query)
hyde_time = time.time()
app.logger.info(f"First HYDE call took: {hyde_time - start_time:.2f} seconds")

method_docs = method_table.search(hyde_query).limit(5).to_pandas()
class_docs = class_table.search(hyde_query).limit(5).to_pandas()
# Concurrent execution of first database searches
def search_method_table():
return method_table.search(hyde_query).limit(5).to_pandas()

temp_context = '\n'.join(method_docs['code'] + '\n'.join(class_docs['source_code']) )
def search_class_table():
return class_table.search(hyde_query).limit(5).to_pandas()

hyde_query_v2 = openai_hyde_v2(query, temp_context, hyde_query)
with ThreadPoolExecutor(max_workers=2) as executor:
future_method_docs = executor.submit(search_method_table)
future_class_docs = executor.submit(search_class_table)
method_docs = future_method_docs.result()
class_docs = future_class_docs.result()

logging.info("-query_v2-")
logging.info(hyde_query_v2)
first_search_time = time.time()
app.logger.info(f"First DB search took: {first_search_time - hyde_time:.2f} seconds")

method_search = method_table.search(hyde_query_v2)
class_search = class_table.search(hyde_query_v2)
temp_context = '\n'.join(method_docs['code'].tolist() + class_docs['source_code'].tolist())

if rerank:
method_search = method_search.rerank(reranker)
class_search = class_search.rerank(reranker)
# Second HYDE call
hyde_query_v2 = openai_hyde_v2(query, temp_context, hyde_query)
second_hyde_time = time.time()
app.logger.info(f"Second HYDE call took: {second_hyde_time - first_search_time:.2f} seconds")

method_docs = method_search.limit(5).to_list()
class_docs = class_search.limit(5).to_list()
# Concurrent execution of second database searches
def search_method_table_v2():
return method_table.search(hyde_query_v2)

top_3_methods = method_docs[:3]
methods_combined = "\n\n".join(f"File: {doc['file_path']}\nCode:\n{doc['code']}" for doc in top_3_methods)
def search_class_table_v2():
return class_table.search(hyde_query_v2)

top_3_classes = class_docs[:3]
classes_combined = "\n\n".join(f"File: {doc['file_path']}\nClass Info:\n{doc['source_code']} References: \n{doc['references']} \n END OF ROW {i}" for i, doc in enumerate(top_3_classes))
with ThreadPoolExecutor(max_workers=2) as executor:
future_method_search = executor.submit(search_method_table_v2)
future_class_search = executor.submit(search_class_table_v2)
method_search = future_method_search.result()
class_search = future_class_search.result()

app.logger.info("Classes Combined:")
app.logger.info("-" * 40)
app.logger.info(classes_combined)
app.logger.info(f"Length of classes_combined: {len(classes_combined)}")
app.logger.info("-" * 40)
search_time = time.time()
app.logger.info(f"Second DB search took: {search_time - second_hyde_time:.2f} seconds")

app.logger.info("Methods Combined:")
app.logger.info("-" * 40)
app.logger.info(methods_combined)
app.logger.info("-" * 40)
# Concurrent reranking if enabled
app.logger.info(f"Reranking enabled: {rerank}")
if rerank:
rerank_start_time = time.time() # Start timing before reranking

def rerank_method_search():
return method_search.rerank(reranker)

def rerank_class_search():
return class_search.rerank(reranker)

with ThreadPoolExecutor(max_workers=2) as executor:
future_method_search = executor.submit(rerank_method_search)
future_class_search = executor.submit(rerank_class_search)
method_search = future_method_search.result()
class_search = future_class_search.result()

rerank_time = time.time()
app.logger.info(f"Reranking took: {rerank_time - rerank_start_time:.2f} seconds")

# Set final time reference point
rerank_time = time.time() if rerank else search_time

# Fetch top documents
method_docs = method_search.limit(5).to_list()
class_docs = class_search.limit(5).to_list()
final_search_time = time.time()
app.logger.info(f"Final DB search took: {final_search_time - rerank_time:.2f} seconds")

def process_methods():
top_3_methods = method_docs[:3]
methods_combined = "\n\n".join(
f"File: {doc['file_path']}\nCode:\n{doc['code']}" for doc in top_3_methods
)
return rerank_using_small_model(query, methods_combined)

def process_classes():
top_3_classes = class_docs[:3]
classes_combined = "\n\n".join(
f"File: {doc['file_path']}\nClass Info:\n{doc['source_code']} References: \n{doc['references']} \n END OF ROW {i}"
for i, doc in enumerate(top_3_classes)
)
return rerank_using_small_model(query, classes_combined)

# Parallel execution of reranking
parallel_start_time = time.time()
with ThreadPoolExecutor(max_workers=2) as executor:
future_methods = executor.submit(process_methods)
future_classes = executor.submit(process_classes)
methods_context = future_methods.result()
classes_context = future_classes.result()
parallel_time = time.time() - parallel_start_time
app.logger.info(f"Parallel reranking took: {parallel_time:.2f} seconds")

final_context = f"{methods_context}\n{classes_context}"

app.logger.info(f"Final context: {final_context}")

app.logger.info("Context generation complete.")

return methods_combined + "\n below is class or constructor related code \n" + classes_combined
total_time = time.time() - start_time
app.logger.info(f"Total context generation took: {total_time:.2f} seconds")
return final_context


# return methods_combined + "\n below is class or constructor related code \n" + classes_combined

@app.route('/', methods=['GET', 'POST'])
def home():
Expand Down Expand Up @@ -224,7 +349,7 @@ def home():
context = context.decode()

# Now, apply reranking during the chat response if needed
response = openai_chat(query, context[:12000]) # Adjust as needed
response = openai_chat(query, context[:8192]) # Adjust as needed

# Store the conversation history
redis_key = f"user:{user_id}:responses"
Expand Down Expand Up @@ -257,4 +382,47 @@ def home():
# Setup database
method_table, class_table = setup_database(codebase_path)

app.logger.info("Server starting up...") # Test log message
app.run(host='0.0.0.0', port=5001)


# The main latency here comes from context + LLM processing, so a faster LLM is needed.


# SambaNova halves the total effective time
# 13-Nov-24 03:33:14 - First HYDE call took: 2.20 seconds
# 13-Nov-24 03:33:15 - First DB search took: 1.44 seconds
# 13-Nov-24 03:33:20 - Second HYDE call took: 4.91 seconds
# 13-Nov-24 03:33:22 - Second DB search took: 1.53 seconds
# 13-Nov-24 03:33:22 - Reranking enabled: True
# 13-Nov-24 03:33:22 - Reranking took: 0.00 seconds
# 13-Nov-24 03:33:22 - Final DB search took: 0.55 seconds
# 13-Nov-24 03:33:22 - Context generation complete.
# 13-Nov-24 03:33:22 - Total context generation took: 10.63 seconds
# 13-Nov-24 03:33:22 - Generated context for query with @codebase.
# 13-Nov-24 03:33:28 - Chat response took: 5.59 seconds

# 127.0.0.1 - - [13/Nov/2024 02:45:06] "GET / HTTP/1.1" 200 -
# 13-Nov-24 02:45:21 - First HYDE call took: 3.05 seconds
# 13-Nov-24 02:45:23 - First DB search took: 2.36 seconds
# 13-Nov-24 02:45:34 - Second HYDE call took: 10.82 seconds
# 13-Nov-24 02:45:36 - Reranking took: 2.44 seconds
# 13-Nov-24 02:45:37 - Second DB search took: 0.65 seconds
# 13-Nov-24 02:45:37 - Context generation complete.
# 13-Nov-24 02:45:37 - Total context generation took: 19.32 seconds
# 13-Nov-24 02:45:37 - Generated context for query with @codebase.
# 13-Nov-24 02:46:00 - Chat response took: 23.01 seconds


# 127.0.0.1 - - [13/Nov/2024 03:01:37] "POST / HTTP/1.1" 200 -
# 13-Nov-24 03:01:54 - First HYDE call took: 3.18 seconds
# 13-Nov-24 03:01:55 - First DB search took: 1.28 seconds
# 13-Nov-24 03:02:02 - Second HYDE call took: 6.87 seconds
# 13-Nov-24 03:02:03 - Second DB search took: 0.85 seconds
# 13-Nov-24 03:02:03 - Reranking took: 0.00 seconds
# 13-Nov-24 03:02:04 - Final DB search took: 0.68 seconds
# 13-Nov-24 03:02:04 - Context generation complete.
# 13-Nov-24 03:02:04 - Total context generation took: 12.86 seconds
# 13-Nov-24 03:02:04 - Generated context for query with @codebase.
# 13-Nov-24 03:02:26 - Chat response took: 22.19 seconds

Loading