diff --git a/apex/services/deep_research/deep_research_langchain.py b/apex/services/deep_research/deep_research_langchain.py index 84ae33bd..4db05327 100644 --- a/apex/services/deep_research/deep_research_langchain.py +++ b/apex/services/deep_research/deep_research_langchain.py @@ -171,7 +171,7 @@ async def invoke( collected_sources: list[dict[str, str]] = [] seen_urls: set[str] = set() - agent_chain = self._build_agent_chain() + agent_chain = await self._build_agent_chain() while step_index < max_iterations: logger.debug(f"Starting deep researcher {step_index + 1}/{max_iterations} step") @@ -202,17 +202,18 @@ async def invoke( # Final answer branch if "final_answer" in parsed: - logger.debug("Early-stopping deep research due to the final answer") final_answer = str(parsed.get("final_answer", "")) reasoning_traces.append( { "step": f"iteration-{step_index}", "model": getattr(self.research_model, "model_name", "unknown"), "thought": thought, - "final_answer": final_answer, + "observation": final_answer, } ) - return final_answer, self.tool_history, reasoning_traces + logger.debug("Early-stopping deep research due to the final answer") + # return final_answer, self.tool_history, reasoning_traces + break # Action branch (only websearch supported) action = parsed.get("action") or {} @@ -255,7 +256,7 @@ async def invoke( "model": getattr(self.research_model, "model_name", "unknown"), "thought": thought, "action": {"tool": "websearch", "query": query, "max_results": max_results}, - "observation": observation_text[:1000], + "observation": observation_text[:1200], } ) continue @@ -289,7 +290,7 @@ async def invoke( "model": getattr(self.research_model, "model_name", "unknown"), "thought": thought, "action": {"tool": "python_repl", "code": code[:1000]}, - "observation": observation_text[:1000], + "observation": observation_text[:1200], } ) continue @@ -305,19 +306,9 @@ async def invoke( ) notes.append("Agent returned an unsupported action. Use the websearch tool or provide final_answer.") - # Fallback: if loop ends without final answer, ask final model to synthesize from notes + # If loop ends without final answer, ask final model to synthesize from notes. logger.debug("Generating final answer") - final_prompt = PromptTemplate( - input_variables=["question", "notes", "sources"], - template=( - self._FINAL_ANSWER_INST + "Do NOT use JSON, or any other structured data format.\n" - "Question:\n{question}\n\n" - "Notes:\n{notes}\n\n" - "Sources:\n{sources}\n\n" - "Research Report:" - ), - ) - final_chain = final_prompt | self.final_model | StrOutputParser() + final_chain = await self._build_final_chain() final_report: str = await self._try_invoke( final_chain, @@ -336,6 +327,19 @@ async def invoke( ) return final_report, self.tool_history, reasoning_traces + async def _build_final_chain(self) -> RunnableSerializable[dict[str, Any], str]: + final_prompt = PromptTemplate( + input_variables=["question", "notes", "sources"], + template=( + self._FINAL_ANSWER_INST + "Do NOT use JSON, or any other structured data format. Provide \n" + "Question:\n{question}\n\n" + "Notes:\n{notes}\n\n" + "Sources:\n{sources}\n\n" + "Research Report:" + ), + ) + return final_prompt | self.final_model | StrOutputParser() + def _render_sources(self, collected_sources: list[dict[str, str]], max_items: int = 12) -> str: if not collected_sources: return "(none)" @@ -352,19 +356,19 @@ def _render_notes(self, notes: list[str], max_items: int = 8) -> str: clipped = notes[-max_items:] return "\n".join(f"- {item}" for item in clipped) - def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]: + async def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]: prompt = PromptTemplate( input_variables=["question", "notes", "sources"], template=( "You are DeepResearcher, a meticulous, tool-using research agent.\n" "You can use exactly these tools: websearch, python_repl.\n\n" "Tool: websearch\n" - "- description: Search the web for relevant information.\n" - "- args: keys: 'query' (string), 'max_results' (integer <= 10)\n\n" + " - description: Search the web for relevant information.\n" + " - args: keys: 'query' (string), 'max_results' (integer <= 10)\n\n" "Tool: python_repl\n" - "- description: A Python shell for executing Python commands.\n" - "- note: Print values to see output, e.g., `print(...)`.\n" - "- args: keys: 'code' (string: valid python command).\n\n" + " - description: A Python shell for executing Python commands.\n" + " - note: Print values to see output, e.g., `print(...)`.\n" + " - args: keys: 'code' (string: valid python command).\n\n" "Follow an iterative think-act-observe loop. " "Prefer rich internal reasoning over issuing many tool calls.\n" "Spend time thinking: produce substantial, explicit reasoning in each 'thought'.\n" @@ -372,27 +376,29 @@ def _build_agent_chain(self) -> RunnableSerializable[dict[str, Any], str]: "unless the question is truly trivial. " "If no tool use is needed in a step, still provide a reflective 'thought'\n" "that evaluates evidence, identifies gaps, and plans the next step.\n\n" - "Always respond in strict JSON. Use one of the two schemas:\n\n" - "1) Action step (JSON keys shown with dot-paths):\n" - "- thought: string\n" - "- action.tool: 'websearch' | 'python_repl'\n" - "- action.input: for websearch -> {{query: string, max_results: integer}}\n" - "- action.input: for python_repl -> {{code: string}}\n\n" - "2) Final answer step:\n" - "- thought: string\n" - "- final_answer: string (use plain text for final answer, not a JSON)\n\n" + "Always respond in strict JSON for deep research steps (do not use JSON for final answer). " + "Use one of the two schemas:\n\n" + "1. Action step (JSON keys shown with dot-paths):\n" + " - thought: string\n" + " - action.tool: 'websearch' | 'python_repl'\n" + " - action.input: for websearch -> {{query: string, max_results: integer}}\n" + " - action.input: for python_repl -> {{code: string}}\n\n" + "2. Final answer step:\n" + " - thought: string\n" + " - final_answer: string\n\n" "In every step, make 'thought' a detailed paragraph (120-200 words) that:\n" - "- Summarizes what is known and unknown so far\n" - "- Justifies the chosen next action or decision not to act\n" - "- Evaluates evidence quality and cites source numbers when applicable\n" - "- Identifies risks, uncertainties, and alternative hypotheses\n\n" + " - Summarizes what is known and unknown so far\n" + " - Justifies the chosen next action or decision not to act\n" + " - Evaluates evidence quality and cites source numbers when applicable\n" + " - Identifies risks, uncertainties, and alternative hypotheses\n\n" + "Respond with JSON only during deep research steps, " + "final answer must be always in a plain text formatted as a research report, with sections:\n" "Executive Summary, Key Findings, Evidence, Limitations, Conclusion.\n" "Use inline numeric citations like [1], [2] that refer to Sources.\n" "Include a final section titled 'Sources' listing the numbered citations.\n\n" "Question:\n{question}\n\n" "Notes and observations so far:\n{notes}\n\n" "Sources (use these for citations):\n{sources}\n\n" - "Respond with JSON always, except for final_anwer (use plain text)." ), ) return prompt | self.research_model | StrOutputParser() diff --git a/tests/services/deep_research/test_deep_research_langchain.py b/tests/services/deep_research/test_deep_research_langchain.py index 7e639107..c68f9bfa 100644 --- a/tests/services/deep_research/test_deep_research_langchain.py +++ b/tests/services/deep_research/test_deep_research_langchain.py @@ -1,3 +1,4 @@ +import json from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -76,17 +77,26 @@ async def test_invoke_with_documents_in_body(deep_research_langchain, mock_webse body = {"documents": [{"page_content": "doc1"}, {"page_content": "doc2"}]} with ( - patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, - patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain" + ) as mock_build_agent_chain, + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain" + ) as mock_build_final_chain, ): agent_chain = AsyncMock() - agent_chain.ainvoke.return_value = '{"thought": "enough info", "final_answer": "final_report"}' - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + return_value = json.dumps({"thought": "enough info", "final_answer": "final_report"}) + agent_chain.ainvoke.return_value = return_value + mock_build_agent_chain.return_value = agent_chain + + final_chain_mock = AsyncMock() + final_chain_mock.ainvoke.return_value = return_value + mock_build_final_chain.return_value = final_chain_mock result = await deep_research_langchain.invoke(messages, body) mock_websearch.search.assert_not_called() - assert result[0] == "final_report" + assert result[0] == return_value @pytest.mark.asyncio @@ -96,8 +106,12 @@ async def test_invoke_with_websearch(deep_research_langchain, mock_websearch): mock_websearch.search.return_value = [MagicMock(content="web_doc", url="http://a.com", title="A")] with ( - patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, - patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain" + ) as mock_build_agent_chain, + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain" + ) as mock_build_final_chain, ): agent_chain = AsyncMock() agent_chain.ainvoke.side_effect = [ @@ -107,7 +121,11 @@ async def test_invoke_with_websearch(deep_research_langchain, mock_websearch): ), '{"thought": "done", "final_answer": "final_answer"}', ] - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + mock_build_agent_chain.return_value = agent_chain + + final_chain_mock = AsyncMock() + final_chain_mock.ainvoke.return_value = "final_answer" + mock_build_final_chain.return_value = final_chain_mock result = await deep_research_langchain.invoke(messages) @@ -121,17 +139,26 @@ async def test_invoke_no_websearch_needed_final_answer(deep_research_langchain, messages = [{"role": "user", "content": "test question"}] with ( - patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, - patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain" + ) as mock_build_agent_chain, + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain" + ) as mock_build_final_chain, ): agent_chain = AsyncMock() - agent_chain.ainvoke.return_value = '{"thought": "clear", "final_answer": "final_report"}' - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + return_value = json.dumps({"thought": "enough info", "final_answer": "final_report"}) + agent_chain.ainvoke.return_value = return_value + mock_build_agent_chain.return_value = agent_chain + + final_chain_mock = AsyncMock() + final_chain_mock.ainvoke.return_value = return_value + mock_build_final_chain.return_value = final_chain_mock result = await deep_research_langchain.invoke(messages) mock_websearch.search.assert_not_called() - assert result[0] == "final_report" + assert result[0] == return_value @pytest.mark.asyncio @@ -149,8 +176,12 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m ] with ( - patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, - patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain" + ) as mock_build_agent_chain, + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain" + ) as mock_build_final_chain, ): agent_chain = AsyncMock() agent_chain.ainvoke.side_effect = [ @@ -164,7 +195,11 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m ), '{"thought": "complete", "final_answer": "final_report"}', ] - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + mock_build_agent_chain.return_value = agent_chain + + final_chain_mock = AsyncMock() + final_chain_mock.ainvoke.return_value = "final_report" + mock_build_final_chain.return_value = final_chain_mock result = await deep_research_langchain.invoke(messages) @@ -186,15 +221,23 @@ async def test_full_invoke_flow_with_multiple_actions(deep_research_langchain, m async def test_invoke_with_python_repl(deep_research_langchain): """Agent chooses python_repl then produces final answer.""" with ( - patch("apex.services.deep_research.deep_research_langchain.PromptTemplate") as mock_prompt_template, - patch("apex.services.deep_research.deep_research_langchain.StrOutputParser"), + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_agent_chain" + ) as mock_build_agent_chain, + patch( + "apex.services.deep_research.deep_research_langchain.DeepResearchLangchain._build_final_chain" + ) as mock_build_final_chain, ): agent_chain = AsyncMock() agent_chain.ainvoke.side_effect = [ ('{"thought": "compute needed", "action": {"tool": "python_repl", "input": {"code": "print(1+1)"}}}'), '{"thought": "done", "final_answer": "final_answer"}', ] - mock_prompt_template.return_value.__or__.return_value.__or__.return_value = agent_chain + mock_build_agent_chain.return_value = agent_chain + + final_chain_mock = AsyncMock() + final_chain_mock.ainvoke.return_value = "final_answer" + mock_build_final_chain.return_value = final_chain_mock result = await deep_research_langchain.invoke([{"role": "user", "content": "q"}])