Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 133 additions & 5 deletions 04-UX-demos/05-strands-playground/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,50 @@ def save_agent_state(self, user_id):
try:
logger.debug("TABLE_NAME and TABLE_REGION environment variable not set, fallback to local file session management, saving conversation to file")
os.makedirs("sessions", exist_ok=True)
# Check if each message is JSON serializable
# Recursive function to filter out binary image data
def filter_binary_data(obj):
if isinstance(obj, bytes):
# Directly handle bytes objects
return "[Binary data removed for serialization]"
elif isinstance(obj, dict):
# If this is an image object, replace it with a text description
if 'image' in obj:
return {"text": "[Image data removed for serialization]"}

# Otherwise, recursively filter each value in the dictionary
return {k: filter_binary_data(v) for k, v in obj.items()}
elif isinstance(obj, list):
# Recursively filter each item in the list
return [filter_binary_data(item) for item in obj]
else:
# Return primitive values as is
return obj

# Process all messages to ensure they're serializable
serializable_messages = []
for message in self.messages:
try:
# Create a deep copy and filter binary data
import copy
filtered_message = copy.deepcopy(message)
filtered_message = filter_binary_data(filtered_message)

# Test if the filtered message is JSON serializable
json.dumps(filtered_message)
serializable_messages.append(filtered_message)
except (TypeError, OverflowError) as e:
# If still not serializable, log the error and add a placeholder
logger.warning(f"Found non-serializable message after filtering: {str(e)}")
# Add a placeholder that maintains the role
if isinstance(message, dict) and 'role' in message:
serializable_messages.append({
"role": message.get("role", "unknown"),
"content": [{"text": "[Content not serializable]"}]
})

state = {
"messages": self.messages
"messages": serializable_messages
}
# Store state (e.g., database, file system, cache)
with open(f"sessions/{user_id}.json", "w") as f:
Expand All @@ -185,13 +227,55 @@ def save_agent_state(self, user_id):
logger.debug(f"Saving conversation to dynamodb table {table_name} in {table_region} region")
dynamodb = boto3.resource('dynamodb', region_name=table_region)
table = dynamodb.Table(table_name)
# Check if each message is JSON serializable
# Recursive function to filter out binary image data
def filter_binary_data(obj):
if isinstance(obj, bytes):
# Directly handle bytes objects
return "[Binary data removed for serialization]"
elif isinstance(obj, dict):
# If this is an image object, replace it with a text description
if 'image' in obj:
return {"text": "[Image data removed for serialization]"}

# Otherwise, recursively filter each value in the dictionary
return {k: filter_binary_data(v) for k, v in obj.items()}
elif isinstance(obj, list):
# Recursively filter each item in the list
return [filter_binary_data(item) for item in obj]
else:
# Return primitive values as is
return obj

# Process all messages to ensure they're serializable
serializable_messages = []
for message in self.messages:
try:
# Create a deep copy and filter binary data
import copy
filtered_message = copy.deepcopy(message)
filtered_message = filter_binary_data(filtered_message)

# Test if the filtered message is JSON serializable
json.dumps(filtered_message)
serializable_messages.append(filtered_message)
except (TypeError, OverflowError) as e:
# If still not serializable, log the error and add a placeholder
logger.warning(f"Found non-serializable message after filtering: {str(e)}")
# Add a placeholder that maintains the role
if isinstance(message, dict) and 'role' in message:
serializable_messages.append({
"role": message.get("role", "unknown"),
"content": [{"text": "[Content not serializable]"}]
})

state = {
"messages": self.messages
"messages": serializable_messages
}
table.put_item(
Item={
primary_key: user_id,
'messages': state['messages']
'messages': serializable_messages
}
)
except ClientError as e:
Expand Down Expand Up @@ -264,12 +348,56 @@ def get_agent_response(request: PromptRequest):
logger.debug(f"Model response: {result.message}")
agent.save_agent_state(request.userId)
logger.info(f"Agent state saved for user: {request.userId}")
return {
"messages": result.message,

# Create a deep copy of the message to avoid modifying the original
import copy
filtered_message = copy.deepcopy(result.message)

# Recursive function to filter out binary image data
def filter_binary_data(obj):
if isinstance(obj, bytes):
# Directly handle bytes objects
return "[Binary data removed for serialization]"
elif isinstance(obj, dict):
# If this is an image object, replace it with a text description
if 'image' in obj:
return {"text": "[Image data removed for serialization]"}

# Otherwise, recursively filter each value in the dictionary
return {k: filter_binary_data(v) for k, v in obj.items()}
elif isinstance(obj, list):
# Recursively filter each item in the list
return [filter_binary_data(item) for item in obj]
else:
# Return primitive values as is
return obj

# Apply the filtering to the entire message
filtered_message = filter_binary_data(filtered_message)

# Create a serializable response
response = {
"messages": filtered_message,
"latencyMs": result.metrics.accumulated_metrics["latencyMs"],
"totalTokens": result.metrics.accumulated_usage["totalTokens"],
"summary": result.metrics.get_summary()
}

# Verify that the response is serializable
import json
try:
json.dumps(response)
except TypeError as e:
logger.error(f"Response still contains non-serializable data: {str(e)}")
# Fall back to a minimal response if serialization fails
return {
"messages": {"role": "assistant", "content": [{"text": "Image generated successfully but couldn't be included in the response."}]},
"latencyMs": result.metrics.accumulated_metrics["latencyMs"],
"totalTokens": result.metrics.accumulated_usage["totalTokens"],
"summary": "Image generation completed"
}

return response
except Exception as e:
logger.error(f"Error processing agent response: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error processing agent response: {str(e)}")
Expand Down