1- """
2- Tests for the AbstractGraph.
3- """
4-
5- from unittest .mock import patch
6-
71import pytest
2+
83from langchain_aws import ChatBedrock
94from langchain_ollama import ChatOllama
105from langchain_openai import AzureChatOpenAI , ChatOpenAI
11-
126from scrapegraphai .graphs import AbstractGraph , BaseGraph
137from scrapegraphai .models import DeepSeek , OneApi
148from scrapegraphai .nodes import FetchNode , ParseNode
9+ from unittest .mock import Mock , patch
1510
11+ """
12+ Tests for the AbstractGraph.
13+ """
1614
1715class TestGraph (AbstractGraph ):
1816 def __init__ (self , prompt : str , config : dict ):
@@ -50,7 +48,6 @@ def run(self) -> str:
5048
5149 return self .final_state .get ("answer" , "No answer found." )
5250
53-
5451class TestAbstractGraph :
5552 @pytest .mark .parametrize (
5653 "llm_config, expected_model" ,
@@ -161,3 +158,45 @@ async def test_run_safe_async(self):
161158 result = await graph .run_safe_async ()
162159 assert result == "Async result"
163160 mock_run .assert_called_once ()
161+
162+ def test_create_llm_with_custom_model_instance (self ):
163+ """
164+ Test that the _create_llm method correctly uses a custom model instance
165+ when provided in the configuration.
166+ """
167+ mock_model = Mock ()
168+ mock_model .model_name = "custom-model"
169+
170+ config = {
171+ "llm" : {
172+ "model_instance" : mock_model ,
173+ "model_tokens" : 1000 ,
174+ "model" : "custom/model"
175+ }
176+ }
177+
178+ graph = TestGraph ("Test prompt" , config )
179+
180+ assert graph .llm_model == mock_model
181+ assert graph .model_token == 1000
182+
183+ def test_set_common_params (self ):
184+ """
185+ Test that the set_common_params method correctly updates the configuration
186+ of all nodes in the graph.
187+ """
188+ # Create a mock graph with mock nodes
189+ mock_graph = Mock ()
190+ mock_node1 = Mock ()
191+ mock_node2 = Mock ()
192+ mock_graph .nodes = [mock_node1 , mock_node2 ]
193+
194+ # Create a TestGraph instance with the mock graph
195+ with patch ('scrapegraphai.graphs.abstract_graph.AbstractGraph._create_graph' , return_value = mock_graph ):
196+ graph = TestGraph ("Test prompt" , {"llm" : {"model" : "openai/gpt-3.5-turbo" , "openai_api_key" : "sk-test" }})
197+
198+ # Call set_common_params with test parameters
199+ test_params = {"param1" : "value1" , "param2" : "value2" }
200+ graph .set_common_params (test_params )
201+
202+ # Assert that update_config was called on each node with the correct parameters
0 commit comments