diff --git a/adala/runtimes/_litellm.py b/adala/runtimes/_litellm.py
index 8ff18058..bd2a191b 100644
--- a/adala/runtimes/_litellm.py
+++ b/adala/runtimes/_litellm.py
@@ -565,7 +565,8 @@ def get_cost_estimate(
     ) -> CostEstimate:
         try:
             user_prompts = [
-                prompt.format(**substitution) for substitution in substitutions
+                partial_str_format(prompt, **substitution)
+                for substitution in substitutions
             ]
             cumulative_prompt_cost = 0
             cumulative_completion_cost = 0
diff --git a/tests/test_cost_estimation.py b/tests/test_cost_estimation.py
index b32c286d..bd2be2c6 100644
--- a/tests/test_cost_estimation.py
+++ b/tests/test_cost_estimation.py
@@ -2,12 +2,8 @@ import pytest
 from adala.runtimes._litellm import AsyncLiteLLMChatRuntime
 from adala.runtimes.base import CostEstimate
-from adala.agents import Agent
-from adala.skills import ClassificationSkill
 import numpy as np
 import os
-from fastapi.testclient import TestClient
-from server.app import app, CostEstimateRequest
 
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
 
@@ -60,7 +56,19 @@ def test_estimate_cost_endpoint(client):
                 }
             },
         },
-        "prompt": "test {text}",
+        "prompt": """
+        test {text}
+
+        Use the following JSON format:
+        {
+            "data": [
+                {
+                    "output": "",
+                    "reasoning": "",
+                }
+            ]
+        }
+        """,
         "substitutions": [{"text": "test"}],
     }
     resp = client.post(