diff --git a/tests/fixtures/tracer_fixtures.py b/tests/fixtures/tracer_fixtures.py new file mode 100644 index 00000000..6163ad52 --- /dev/null +++ b/tests/fixtures/tracer_fixtures.py @@ -0,0 +1,62 @@ + +import pytest +import json +from policyengine_api.services.tracer_analysis_service import TracerAnalysisService + +valid_request_body = { + "tracer_output": [ + "only_government_benefit <1500>", + " market_income <1000>", + " employment_income <1000>", + " main_employment_income <1000>", + " non_market_income <500>", + " pension_income <500>", + ] +} + +valid_tracer_row = { + "household_id": "71424", + "policy_id": "2", + "country_id": "us", + "api_version": "1.150.0", + "tracer_output": json.dumps(valid_request_body["tracer_output"]), +} + + +@pytest.fixture +def tracer_service(): + """Fixture to initialize the TracerAnalysisService.""" + return TracerAnalysisService() + + +@pytest.fixture +def test_tracer_data(test_db): + + # Insert data using query() + test_db.query( + """ + INSERT INTO tracers (household_id, policy_id, country_id, api_version, tracer_output) + VALUES (?, ?, ?, ?, ?) + """, + ( + valid_tracer_row["household_id"], + valid_tracer_row["policy_id"], + valid_tracer_row["country_id"], + valid_tracer_row["api_version"], + valid_tracer_row["tracer_output"], + ), + ) + + + # Verify that the data has been inserted + inserted_row = test_db.query( + "SELECT * FROM tracers WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ?", + ( + valid_tracer_row["household_id"], + valid_tracer_row["policy_id"], + valid_tracer_row["country_id"], + valid_tracer_row["api_version"], + ), + ).fetchone() + + return inserted_row \ No newline at end of file diff --git a/tests/unit/services/test_tracer_get_service.py b/tests/unit/services/test_tracer_get_service.py deleted file mode 100644 index 2c7d82c9..00000000 --- a/tests/unit/services/test_tracer_get_service.py +++ /dev/null @@ -1,87 +0,0 @@ -import pytest -import json -from policyengine_api.services.tracer_analysis_service import TracerAnalysisService -from werkzeug.exceptions import NotFound - - -@pytest.fixture -def tracer_service(): - """Fixture to initialize the TracerAnalysisService.""" - return TracerAnalysisService() - - -@pytest.fixture -def test_tracer_data(test_db): - """Fixture to insert a valid tracer record into the database.""" - valid_tracer_row = { - "household_id": "71424", - "policy_id": "2", - "country_id": "us", - "api_version": "1.150.0", - "tracer_output": json.dumps([ - "only_government_benefit <1500>", - " market_income <1000>", - " employment_income <1000>", - " main_employment_income <1000>", - " non_market_income <500>", - " pension_income <500>", - ]), - } - - # Insert data using query() - test_db.query( - """ - INSERT INTO tracers (household_id, policy_id, country_id, api_version, tracer_output) - VALUES (?, ?, ?, ?, ?) - """, - ( - valid_tracer_row["household_id"], - valid_tracer_row["policy_id"], - valid_tracer_row["country_id"], - valid_tracer_row["api_version"], - valid_tracer_row["tracer_output"], - ), - ) - - - # Verify that the data has been inserted - inserted_row = test_db.query( - "SELECT * FROM tracers WHERE household_id = ? AND policy_id = ? AND country_id = ? AND api_version = ?", - ( - valid_tracer_row["household_id"], - valid_tracer_row["policy_id"], - valid_tracer_row["country_id"], - valid_tracer_row["api_version"], - ), - ).fetchone() - - - return inserted_row - - - - -def test_get_tracer_valid(tracer_service, test_tracer_data): - # Test get_tracer successfully retrieves valid data from the database."" - record = test_tracer_data - - result = tracer_service.get_tracer( - record["country_id"], record["household_id"], record["policy_id"], record["api_version"] - ) - - assert isinstance(result, list) - assert result["tracer_output"] == record["tracer_output"] - - -def test_get_tracer_not_found(tracer_service): - # Test get_tracer raises NotFound when no matching record exists. - with pytest.raises(NotFound): - tracer_service.get_tracer("us", "999999", "999", "9.999.0") - - -def test_get_tracer_database_error(tracer_service): - # Test get_tracer handles database errors properly. - with pytest.raises(Exception): - tracer_service.get_tracer("us", "71424", "2", "1.150.0") - - diff --git a/tests/unit/services/test_tracer_service.py b/tests/unit/services/test_tracer_service.py new file mode 100644 index 00000000..d8b4cb54 --- /dev/null +++ b/tests/unit/services/test_tracer_service.py @@ -0,0 +1,59 @@ +import pytest +import json +from policyengine_api.services.tracer_analysis_service import TracerAnalysisService +from werkzeug.exceptions import NotFound + +from tests.fixtures.tracer_fixtures import ( + test_tracer_data, + valid_tracer_row, + valid_request_body, + tracer_service +) + + +def test_get_tracer_valid(tracer_service, test_tracer_data): + # Test get_tracer successfully retrieves valid data from the database. + record = test_tracer_data + + result = tracer_service.get_tracer( + record["country_id"], record["household_id"], record["policy_id"], record["api_version"] + ) + + # match the valid output as collected from fixture + valid_output = valid_request_body["tracer_output"] + assert result == valid_output + + +def test_get_tracer_not_found(tracer_service): + # Test get_tracer raises NotFound when no matching record exists. + with pytest.raises(NotFound): + tracer_service.get_tracer("us", "999999", "999", "9.999.0") + + +def test_get_tracer_database_error(tracer_service, test_db): + # Test get_tracer handles database errors properly. + with pytest.raises(Exception): + tracer_service.get_tracer("us", "71424", "2", "1.150.0") + + +def test_invalid_input_type(tracer_service , test_db): + # Test get_tracer handles invalid input parameter error + with pytest.raises(Exception) as exception_val: + tracer_service.get_tracer("us" , 100 , "2" , "1.150.0") + + if exception_val.type == TypeError: + assert exception_val.type == TypeError + else: + pytest.fail(f"Expected value not found instead {type(exception_val)}") + +def test_invalid_input_value(tracer_service,test_db): + # Test get_tracer handles invalid input value error + + with pytest.raises(Exception) as exception_val: + tracer_service.get_tracer("usa" , "-100" , "2" , "1.150.0") + + if exception_val.type == ValueError: + assert exception_val.type == ValueError + else: + pytest.fail(f"Expected value not found instead {type(exception_val)}") + \ No newline at end of file