Skip to content

Commit

Permalink
support passing user-info to AOAI for Defender for cloud (#842)
Browse files Browse the repository at this point in the history
  • Loading branch information
slreznit authored May 6, 2024
1 parent 7fef128 commit 08f75ac
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
48 changes: 32 additions & 16 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from openai import AsyncAzureOpenAI
from azure.identity.aio import DefaultAzureCredential, get_bearer_token_provider
from backend.auth.auth_utils import get_authenticated_user_details
from backend.auth.auth_utils import get_authenticated_user_details, get_tenantid
from backend.history.cosmosdbservice import CosmosConversationClient

from backend.utils import (
Expand Down Expand Up @@ -267,7 +267,8 @@ async def assets(path):
},
"sanitize_answer": SANITIZE_ANSWER,
}

# Enable Microsoft Defender for Cloud Integration
MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "false").lower() == "true"

def should_use_data():
global DATASOURCE_TYPE
Expand Down Expand Up @@ -723,7 +724,7 @@ def get_configured_data_source():
return data_source


def prepare_model_args(request_body):
def prepare_model_args(request_body, request_headers):
request_messages = request_body.get("messages", [])
messages = []
if not SHOULD_USE_DATA:
Expand All @@ -733,6 +734,20 @@ def prepare_model_args(request_body):
if message:
messages.append({"role": message["role"], "content": message["content"]})

user_json = None
if (MS_DEFENDER_ENABLED):
authenticated_user_details = get_authenticated_user_details(request_headers)
tenantId = get_tenantid(authenticated_user_details.get("client_principal_b64"))
conversation_id = request_body.get("conversation_id", None)
user_args = {
"EndUserId": authenticated_user_details.get('user_principal_id'),
"EndUserIdType": 'Entra',
"EndUserTenantId": tenantId,
"ConversationId": conversation_id,
"SourceIp": request_headers.get('X-Forwarded-For', request_headers.get('Remote-Addr', '')),
}
user_json = json.dumps(user_args)

model_args = {
"messages": messages,
"temperature": float(AZURE_OPENAI_TEMPERATURE),
Expand All @@ -745,6 +760,7 @@ def prepare_model_args(request_body):
),
"stream": SHOULD_STREAM,
"model": AZURE_OPENAI_MODEL,
"user": user_json,
}

if SHOULD_USE_DATA:
Expand Down Expand Up @@ -822,15 +838,15 @@ async def promptflow_request(request):
logging.error(f"An error occurred while making promptflow_request: {e}")


async def send_chat_request(request):
async def send_chat_request(request_body, request_headers):
filtered_messages = []
messages = request.get("messages", [])
messages = request_body.get("messages", [])
for message in messages:
if message.get("role") != 'tool':
filtered_messages.append(message)

request['messages'] = filtered_messages
model_args = prepare_model_args(request)
request_body['messages'] = filtered_messages
model_args = prepare_model_args(request_body, request_headers)

try:
azure_openai_client = init_openai_client()
Expand All @@ -844,21 +860,21 @@ async def send_chat_request(request):
return response, apim_request_id


async def complete_chat_request(request_body):
async def complete_chat_request(request_body, request_headers):
if USE_PROMPTFLOW and PROMPTFLOW_ENDPOINT and PROMPTFLOW_API_KEY:
response = await promptflow_request(request_body)
history_metadata = request_body.get("history_metadata", {})
return format_pf_non_streaming_response(
response, history_metadata, PROMPTFLOW_RESPONSE_FIELD_NAME, PROMPTFLOW_CITATIONS_FIELD_NAME
)
else:
response, apim_request_id = await send_chat_request(request_body)
response, apim_request_id = await send_chat_request(request_body, request_headers)
history_metadata = request_body.get("history_metadata", {})
return format_non_streaming_response(response, history_metadata, apim_request_id)


async def stream_chat_request(request_body):
response, apim_request_id = await send_chat_request(request_body)
async def stream_chat_request(request_body, request_headers):
response, apim_request_id = await send_chat_request(request_body, request_headers)
history_metadata = request_body.get("history_metadata", {})

async def generate():
Expand All @@ -868,16 +884,16 @@ async def generate():
return generate()


async def conversation_internal(request_body):
async def conversation_internal(request_body, request_headers):
try:
if SHOULD_STREAM:
result = await stream_chat_request(request_body)
result = await stream_chat_request(request_body, request_headers)
response = await make_response(format_as_ndjson(result))
response.timeout = None
response.mimetype = "application/json-lines"
return response
else:
result = await complete_chat_request(request_body)
result = await complete_chat_request(request_body, request_headers)
return jsonify(result)

except Exception as ex:
Expand All @@ -894,7 +910,7 @@ async def conversation():
return jsonify({"error": "request must be json"}), 415
request_json = await request.get_json()

return await conversation_internal(request_json)
return await conversation_internal(request_json, request.headers)


@bp.route("/frontend_settings", methods=["GET"])
Expand Down Expand Up @@ -958,7 +974,7 @@ async def add_conversation():
request_body = await request.get_json()
history_metadata["conversation_id"] = conversation_id
request_body["history_metadata"] = history_metadata
return await conversation_internal(request_body)
return await conversation_internal(request_body, request.headers)

except Exception as e:
logging.exception("Exception in /history/generate")
Expand Down
21 changes: 20 additions & 1 deletion backend/auth/auth_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import base64
import json
import logging

def get_authenticated_user_details(request_headers):
user_object = {}

Expand All @@ -17,4 +21,19 @@ def get_authenticated_user_details(request_headers):
user_object['client_principal_b64'] = raw_user_object.get('X-Ms-Client-Principal')
user_object['aad_id_token'] = raw_user_object.get('X-Ms-Token-Aad-Id-Token')

return user_object
return user_object

def get_tenantid(client_principal_b64):
tenant_id = ''
if client_principal_b64:
try:
# Decode the base64 header to get the JSON string
decoded_bytes = base64.b64decode(client_principal_b64)
decoded_string = decoded_bytes.decode('utf-8')
# Convert the JSON string1into a Python dictionary
user_info = json.loads(decoded_string)
# Extract the tenant ID
tenant_id = user_info.get('tid') # 'tid' typically holds the tenant ID
except Exception as ex:
logging.exception(ex)
return tenant_id

0 comments on commit 08f75ac

Please sign in to comment.