Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed py files to seed database and corrected some paths #21

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions azure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ services:
prebuild:
windows:
shell: pwsh
run: cd ../frontend;npm install;npm run build
run: cd ../frontend;npm install;npm run build
interactive: false
continueOnError: false
posix:
shell: sh
run: cd ../frontend;npm install;npm run build
run: cd ../frontend;npm install;npm run build
interactive: false
continueOnError: false
hooks:
Expand Down
649 changes: 301 additions & 348 deletions data/cases_final.csv → data/cases_updated.csv

Large diffs are not rendered by default.

4 changes: 0 additions & 4 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,6 @@ var webAppEnv = union(azureOpenAIKeyEnv, openAIComKeyEnv, [
name: 'POSTGRES_HOST'
value: postgresServer.outputs.POSTGRES_DOMAIN_NAME
}
{
name: 'POSTGRES_ADMIN_LOGIN_KEY'
value: postgresServer.outputs.POSTGRES_ADMIN_LOGIN_KEY
}
{
name: 'POSTGRES_USERNAME'
value: webAppIdentityName
Expand Down
8 changes: 8 additions & 0 deletions infra/skip-postgres.bicep
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
param name string
param location string = resourceGroup().location
param tags object = {}

// Output the existing PostgreSQL connection details
output POSTGRES_HOST string = '' // Will be overridden by .env
output POSTGRES_USERNAME string = '' // Will be overridden by .env
output POSTGRES_DATABASE string = '' // Will be overridden by .env
6 changes: 5 additions & 1 deletion scripts/setup_postgres_azurerole.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,9 @@ if ([string]::IsNullOrEmpty($POSTGRES_HOST) -or [string]::IsNullOrEmpty($POSTGRE
Write-Host "Can't find POSTGRES_HOST, POSTGRES_USERNAME, and SERVICE_WEB_IDENTITY_NAME environment variables. Make sure you run azd up first."
exit 1
}
Write-Host "Running script with these parameters:"
Write-Host "Host: $POSTGRES_HOST"
Write-Host "Username: $POSTGRES_USERNAME"
Write-Host "App Identity Name: $APP_IDENTITY_NAME"

python ./src/backend/fastapi_app/setup_postgres_azurerole.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --app-identity-name $APP_IDENTITY_NAME
py ./src/backend/fastapi_app/setup_postgres_azurerole.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --app-identity-name $APP_IDENTITY_NAME --database $POSTGRES_DB
2 changes: 1 addition & 1 deletion scripts/setup_postgres_database.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ if ([string]::IsNullOrEmpty($POSTGRES_HOST) -or [string]::IsNullOrEmpty($POSTGRE
exit 1
}

python ./src/backend/fastapi_app/setup_postgres_legal_database.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE
py ./src/backend/fastapi_app/setup_postgres_legal_database.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE
2 changes: 1 addition & 1 deletion scripts/setup_postgres_seeddata.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ if ([string]::IsNullOrEmpty($POSTGRES_HOST) -or [string]::IsNullOrEmpty($POSTGRE
exit 1
}

python ./src/backend/fastapi_app/setup_postgres_legal_seeddata.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE
py ./src/backend/fastapi_app/setup_postgres_legal_seeddata.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE
3 changes: 2 additions & 1 deletion scripts/setup_postgres_seeddata.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ APP_IDENTITY_NAME=$(azd env get-value SERVICE_WEB_IDENTITY_NAME)

. ./scripts/load_python_env.sh

.venv/bin/python ./src/backend/fastapi_app/setup_postgres_legal_seeddata.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE --app-identity-name $APP_IDENTITY_NAME
.venv/bin/python ./src/backend/fastapi_app/setup_postgres_legal_seeddata.py --host $POSTGRES_HOST --username $POSTGRES_USERNAME --database $POSTGRES_DATABASE --app-identity-name $APP_IDENTITY_NAME
.venv/bin/python ./src/backend/fastapi_app/get_token.py > ./src/backend/fastapi_app/postgres_token.txt
60 changes: 60 additions & 0 deletions src/backend/fastapi_app/get_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import asyncio
import logging
import os

from azure.identity import AzureDeveloperCliCredential, ManagedIdentityCredential

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def get_azure_credential():
"""
Authenticate to Azure using Azure Developer CLI Credential or Managed Identity.
Returns an instance of AzureDeveloperCliCredential or ManagedIdentityCredential.
"""
try:
if client_id := os.getenv("APP_IDENTITY_ID"):
# Authenticate using a user-assigned managed identity on Azure
# See web.bicep for the value of APP_IDENTITY_ID
logger.info("Using managed identity for client ID %s", client_id)
azure_credential = ManagedIdentityCredential(client_id=client_id)
else:
if tenant_id := os.getenv("AZURE_TENANT_ID"):
logger.info(
"Authenticating to Azure using Azure Developer CLI Credential for tenant %s",
tenant_id,
)
azure_credential = AzureDeveloperCliCredential(tenant_id=tenant_id, process_timeout=60)
else:
logger.info("Authenticating to Azure using Azure Developer CLI Credential")
azure_credential = AzureDeveloperCliCredential(process_timeout=60)
return azure_credential
except Exception as e:
logger.warning("Failed to authenticate to Azure: %s", e)
raise e


def get_password_from_azure_credential():
"""
Fetch the Azure token using the credential obtained from get_azure_credential.
Returns the token string.
"""

async def get_token():
azure_credential = await get_azure_credential()
token = azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
return token.token

# Run the asynchronous token retrieval in the event loop
loop = asyncio.get_event_loop()
return loop.run_until_complete(get_token())


if __name__ == "__main__":
try:
password = get_password_from_azure_credential()
print(password)
except Exception as e:
logger.error("Failed to retrieve password: %s", e)
18 changes: 2 additions & 16 deletions src/backend/fastapi_app/postgres_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,8 @@ def to_str_for_rag(self):
"""
Converts Case to a string representation for Retrieval-Augmented Generation (RAG) usage.
"""
# data_fields = " ".join([f"{key}:{value}" for key, value in self.data.items()])
# return f"ID: {self.id} Data: {data_fields}"
casebody_text = self.data.get("casebody", {}).get("opinions", [{}])[0].get("text", "")
truncated_text = casebody_text[:800] # Truncate to 800 characters

# Include truncated data alongside other fields if necessary
data_fields = [
f"{key}:{value}"
for key, value in self.data.items()
if key != "casebody" # Exclude the large 'casebody' field
]
data_fields.append(f"casebody_opinion_text:{truncated_text}")

# Join the fields into a single string
data_str = " ".join(data_fields)
return f"ID: {self.id} Data: {data_str}"
data_fields = " ".join([f"{key}:{value}" for key, value in self.data.items()])
return f"ID: {self.id} Data: {data_fields}"

def to_str_for_embedding(self):
"""
Expand Down
74 changes: 42 additions & 32 deletions src/backend/fastapi_app/postgres_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,45 +56,56 @@ async def search(
filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
table_name = Case.__tablename__

try:
script_path = Path(__file__).parent / "setup_postgres_age.py"
token_file = Path(__file__).parent / "postgres_token.txt"

if not script_path.exists():
logger.error(f"Setup script not found at {script_path}")
raise FileNotFoundError(f"Script {script_path} does not exist.")
if token_file.exists() and retrieval_mode == RetrievalMode.GRAPHRAG:
try:
script_path = Path(__file__).parent / "setup_postgres_age.py"

logger.info("Running setup_postgres_age.py...")
subprocess.run(["python", str(script_path)], check=True)
logger.info("setup_postgres_age.py completed successfully.")
if not script_path.exists():
logger.error(f"Setup script not found at {script_path}")
raise FileNotFoundError(f"Script {script_path} does not exist.")

except subprocess.CalledProcessError as e:
logger.error(f"Error occurred while running setup_postgres_age.py: {e}")
except Exception as e:
logger.error(f"Unexpected error: {e}")
logger.info("Running setup_postgres_age.py...")
subprocess.run(["python", str(script_path)], check=True)
logger.info("setup_postgres_age.py completed successfully.")

await self.db_session.execute(text('SET search_path = ag_catalog, "$user", public;'))
except subprocess.CalledProcessError as e:
logger.error(f"Error occurred while running setup_postgres_age.py: {e}")
except Exception as e:
logger.error(f"Unexpected error: {e}")

ranking_type_map = {
RetrievalMode.GRAPHRAG: "score",
RetrievalMode.SEMANTIC: "semantic_rank",
RetrievalMode.VECTOR: "vector_rank",
}
await self.db_session.execute(text('SET search_path = ag_catalog, "$user", public;'))

if retrieval_mode not in ranking_type_map:
if retrieval_mode == RetrievalMode.GRAPHRAG:
function_call = """
SELECT * FROM get_vector_semantic_graphrag_cases(
:query_text,
CAST(:embedding AS vector(1536)),
:top_n,
:consider_n
);
"""
elif retrieval_mode == RetrievalMode.SEMANTIC:
function_call = """
SELECT * FROM get_vector_semantic_cases(
:query_text,
CAST(:embedding AS vector(1536)),
:top_n,
:consider_n
);
"""
elif retrieval_mode == RetrievalMode.VECTOR:
function_call = """
SELECT * FROM get_vector_cases(
:query_text,
CAST(:embedding AS vector(1536)),
:top_n
);
"""
else:
raise ValueError("Invalid retrieval_mode. Options are: VECTOR, SEMANTIC, GRAPHRAG")

ranking_type = ranking_type_map[retrieval_mode]

function_call = f"""
SELECT * FROM get_vector_semantic_graphrag_optimized(
:query_text,
CAST(:embedding AS vector(1536)),
:top_n,
:consider_n,
'{ranking_type}'
);
"""

sql = text(function_call).columns(
column("rrf.id", String),
)
Expand Down Expand Up @@ -123,7 +134,6 @@ async def search(
id = row.id # Adjust if column names differ
# Fetch the corresponding row using the ID
item = await self.db_session.execute(select(Case).where(Case.id == id))
# logger.info(f"Item found: {item.scalar()}")
row_models.append(item.scalar())
return row_models

Expand Down
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
13 changes: 8 additions & 5 deletions src/backend/fastapi_app/setup_postgres_age.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def configure_age(conn, app_identity_name):
logger.info("Executing first MATCH query...")
cur.execute(
"""SELECT * FROM cypher('case_graph', $$
MATCH ()-[r:REF]->()
MATCH ()-[r:CITES]->()
RETURN COUNT(r) AS cites_count
$$) AS (cites_count agtype);"""
)
Expand All @@ -39,7 +39,7 @@ def configure_age(conn, app_identity_name):
logger.info("Executing second MATCH query...")
cur.execute(
"""SELECT * FROM cypher('case_graph', $$
MATCH ()-[r:REF]->()
MATCH ()-[r:CITES]->()
RETURN COUNT(r) AS cites_count
$$) AS (cites_count agtype);"""
)
Expand All @@ -49,12 +49,16 @@ def configure_age(conn, app_identity_name):


def main():
# Fetch environment variables
host = os.getenv("POSTGRES_HOST")
username = "legalcaseadmin"
username = os.getenv("POSTGRES_ADMIN")
database = os.getenv("POSTGRES_DATABASE")
sslmode = os.getenv("POSTGRES_SSLMODE", "require") # Default to 'require' SSL mode
app_identity_name = os.getenv("APP_IDENTITY_NAME")
password = os.getenv("POSTGRES_ADMIN_LOGIN_KEY")

filepath = os.path.join(os.path.dirname(__file__), "postgres_token.txt")
with open(filepath) as file:
password = file.read().strip()

# Ensure environment variables are set
if not all([host, username, password, database, app_identity_name]):
Expand All @@ -65,7 +69,6 @@ def main():
if sslmode.lower() in ["require", "verify-ca", "verify-full"]:
sslmode_params["sslmode"] = sslmode

conn = None
try:
# Connect to PostgreSQL using psycopg2
conn = psycopg2.connect(
Expand Down
Loading