diff --git a/04-UX-demos/05-strands-playground/app/main.py b/04-UX-demos/05-strands-playground/app/main.py index fa420a63..98f53912 100644 --- a/04-UX-demos/05-strands-playground/app/main.py +++ b/04-UX-demos/05-strands-playground/app/main.py @@ -172,8 +172,50 @@ def save_agent_state(self, user_id): try: logger.debug("TABLE_NAME and TABLE_REGION environment variable not set, fallback to local file session management, saving conversation to file") os.makedirs("sessions", exist_ok=True) + # Check if each message is JSON serializable + # Recursive function to filter out binary image data + def filter_binary_data(obj): + if isinstance(obj, bytes): + # Directly handle bytes objects + return "[Binary data removed for serialization]" + elif isinstance(obj, dict): + # If this is an image object, replace it with a text description + if 'image' in obj: + return {"text": "[Image data removed for serialization]"} + + # Otherwise, recursively filter each value in the dictionary + return {k: filter_binary_data(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively filter each item in the list + return [filter_binary_data(item) for item in obj] + else: + # Return primitive values as is + return obj + + # Process all messages to ensure they're serializable + serializable_messages = [] + for message in self.messages: + try: + # Create a deep copy and filter binary data + import copy + filtered_message = copy.deepcopy(message) + filtered_message = filter_binary_data(filtered_message) + + # Test if the filtered message is JSON serializable + json.dumps(filtered_message) + serializable_messages.append(filtered_message) + except (TypeError, OverflowError) as e: + # If still not serializable, log the error and add a placeholder + logger.warning(f"Found non-serializable message after filtering: {str(e)}") + # Add a placeholder that maintains the role + if isinstance(message, dict) and 'role' in message: + serializable_messages.append({ + "role": message.get("role", "unknown"), + "content": [{"text": "[Content not serializable]"}] + }) + state = { - "messages": self.messages + "messages": serializable_messages } # Store state (e.g., database, file system, cache) with open(f"sessions/{user_id}.json", "w") as f: @@ -185,13 +227,55 @@ def save_agent_state(self, user_id): logger.debug(f"Saving conversation to dynamodb table {table_name} in {table_region} region") dynamodb = boto3.resource('dynamodb', region_name=table_region) table = dynamodb.Table(table_name) + # Check if each message is JSON serializable + # Recursive function to filter out binary image data + def filter_binary_data(obj): + if isinstance(obj, bytes): + # Directly handle bytes objects + return "[Binary data removed for serialization]" + elif isinstance(obj, dict): + # If this is an image object, replace it with a text description + if 'image' in obj: + return {"text": "[Image data removed for serialization]"} + + # Otherwise, recursively filter each value in the dictionary + return {k: filter_binary_data(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively filter each item in the list + return [filter_binary_data(item) for item in obj] + else: + # Return primitive values as is + return obj + + # Process all messages to ensure they're serializable + serializable_messages = [] + for message in self.messages: + try: + # Create a deep copy and filter binary data + import copy + filtered_message = copy.deepcopy(message) + filtered_message = filter_binary_data(filtered_message) + + # Test if the filtered message is JSON serializable + json.dumps(filtered_message) + serializable_messages.append(filtered_message) + except (TypeError, OverflowError) as e: + # If still not serializable, log the error and add a placeholder + logger.warning(f"Found non-serializable message after filtering: {str(e)}") + # Add a placeholder that maintains the role + if isinstance(message, dict) and 'role' in message: + serializable_messages.append({ + "role": message.get("role", "unknown"), + "content": [{"text": "[Content not serializable]"}] + }) + state = { - "messages": self.messages + "messages": serializable_messages } table.put_item( Item={ primary_key: user_id, - 'messages': state['messages'] + 'messages': serializable_messages } ) except ClientError as e: @@ -264,12 +348,56 @@ def get_agent_response(request: PromptRequest): logger.debug(f"Model response: {result.message}") agent.save_agent_state(request.userId) logger.info(f"Agent state saved for user: {request.userId}") - return { - "messages": result.message, + + # Create a deep copy of the message to avoid modifying the original + import copy + filtered_message = copy.deepcopy(result.message) + + # Recursive function to filter out binary image data + def filter_binary_data(obj): + if isinstance(obj, bytes): + # Directly handle bytes objects + return "[Binary data removed for serialization]" + elif isinstance(obj, dict): + # If this is an image object, replace it with a text description + if 'image' in obj: + return {"text": "[Image data removed for serialization]"} + + # Otherwise, recursively filter each value in the dictionary + return {k: filter_binary_data(v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recursively filter each item in the list + return [filter_binary_data(item) for item in obj] + else: + # Return primitive values as is + return obj + + # Apply the filtering to the entire message + filtered_message = filter_binary_data(filtered_message) + + # Create a serializable response + response = { + "messages": filtered_message, "latencyMs": result.metrics.accumulated_metrics["latencyMs"], "totalTokens": result.metrics.accumulated_usage["totalTokens"], "summary": result.metrics.get_summary() } + + # Verify that the response is serializable + import json + try: + json.dumps(response) + except TypeError as e: + logger.error(f"Response still contains non-serializable data: {str(e)}") + # Fall back to a minimal response if serialization fails + return { + "messages": {"role": "assistant", "content": [{"text": "Image generated successfully but couldn't be included in the response."}]}, + "latencyMs": result.metrics.accumulated_metrics["latencyMs"], + "totalTokens": result.metrics.accumulated_usage["totalTokens"], + "summary": "Image generation completed" + } + + return response except Exception as e: logger.error(f"Error processing agent response: {str(e)}") raise HTTPException(status_code=500, detail=f"Error processing agent response: {str(e)}")