diff --git a/data/wikibase-test-data.db b/data/wikibase-test-data.db index f80195b4..5be8021d 100644 Binary files a/data/wikibase-test-data.db and b/data/wikibase-test-data.db differ diff --git a/fetch_data/__init__.py b/fetch_data/__init__.py index a185721c..633c28f1 100644 --- a/fetch_data/__init__.py +++ b/fetch_data/__init__.py @@ -29,3 +29,4 @@ create_special_statistics_observation, update_software_data, ) +from fetch_data.update_data import merge_software_by_id diff --git a/fetch_data/update_data/__init__.py b/fetch_data/update_data/__init__.py new file mode 100644 index 00000000..52153a93 --- /dev/null +++ b/fetch_data/update_data/__init__.py @@ -0,0 +1,3 @@ +"""Update Data""" + +from fetch_data.update_data.merge_software import merge_software_by_id diff --git a/fetch_data/update_data/merge_software.py b/fetch_data/update_data/merge_software.py new file mode 100644 index 00000000..086599d3 --- /dev/null +++ b/fetch_data/update_data/merge_software.py @@ -0,0 +1,88 @@ +"""Merge Software""" + +from sqlalchemy import Select, Update, and_, delete, select, update +from data.database_connection import get_async_session +from model.database import ( + WikibaseSoftwareVersionModel, +) +from model.database import WikibaseSoftwareModel +from model.database.wikibase_software.software_tag_xref_model import ( + software_tag_xref_table, +) + + +async def merge_software_by_id(base_id: int, additional_id: int) -> bool: + """Merge Software by ID""" + + software_query = get_select_software_query([base_id, additional_id]) + update_software_version_query = get_update_software_version_query( + base_id, additional_id + ) + update_software_tags_query = get_update_software_tags_query(base_id, additional_id) + delete_additional_tags_query = software_tag_xref_table.delete().where( + software_tag_xref_table.c.wikibase_software_id == additional_id + ) + delete_software_query = delete(WikibaseSoftwareModel).where( + WikibaseSoftwareModel.id == additional_id + ) + + async with get_async_session() as async_session: + software_list = (await async_session.scalars(software_query)).all() + assert len({s.software_type for s in software_list}) == 1 + + await async_session.execute(update_software_version_query) + await async_session.execute(update_software_tags_query) + await async_session.execute(delete_additional_tags_query) + await async_session.flush() + + await async_session.execute(delete_software_query) + await async_session.commit() + + async with get_async_session() as async_session: + remaining = (await async_session.scalars(software_query)).all() + return len(remaining) == 1 + + +def get_select_software_query(id_list: list[int]) -> Select[WikibaseSoftwareModel]: + """Select WikibaseSoftwareModel in ID list""" + + software_query = select(WikibaseSoftwareModel).where( + WikibaseSoftwareModel.id.in_(id_list) + ) + + return software_query + + +def get_update_software_tags_query(base_id: int, additional_id: int) -> Update: + """Add Additional Software Tags to Base""" + + update_software_tags_query = software_tag_xref_table.insert().from_select( + [ + software_tag_xref_table.c.wikibase_software_id, + software_tag_xref_table.c.wikibase_software_tag_id, + ], + select(base_id, software_tag_xref_table.c.wikibase_software_tag_id).where( + and_( + software_tag_xref_table.c.wikibase_software_id == additional_id, + software_tag_xref_table.c.wikibase_software_tag_id.not_in( + select(software_tag_xref_table.c.wikibase_software_tag_id).where( + software_tag_xref_table.c.wikibase_software_id == base_id + ) + ), + ) + ), + ) + + return update_software_tags_query + + +def get_update_software_version_query(base_id: int, additional_id: int) -> Update: + """Update Software Version from Additional ID to Base ID""" + + update_software_version_query = ( + update(WikibaseSoftwareVersionModel) + .where(WikibaseSoftwareVersionModel.software_id == additional_id) + .values(software_id=base_id) + ) + + return update_software_version_query diff --git a/model/strawberry/mutation.py b/model/strawberry/mutation.py index db1f65bf..fce7262d 100644 --- a/model/strawberry/mutation.py +++ b/model/strawberry/mutation.py @@ -12,6 +12,7 @@ create_software_version_observation, create_special_statistics_observation, create_user_observation, + merge_software_by_id, ) @@ -57,3 +58,8 @@ class Mutation: description="Scrape data from Special:Version page", resolver=create_software_version_observation, ) + + merge_software_by_id = strawberry.mutation( + description="Merge Software", + resolver=merge_software_by_id, + ) diff --git a/tests/test_create_observation/test_create_software_version_observation/test_update_software_data.py b/tests/test_create_observation/test_create_software_version_observation/test_update_software_data.py index d3dbffe7..bc647164 100644 --- a/tests/test_create_observation/test_create_software_version_observation/test_update_software_data.py +++ b/tests/test_create_observation/test_create_software_version_observation/test_update_software_data.py @@ -15,7 +15,9 @@ @freeze_time("2024-03-01") @pytest.mark.asyncio @pytest.mark.dependency( - name="update-software-data", depends=["software-version-success"], scope="session" + name="update-software-data", + depends=["software-version-success", "merge-software-by-id"], + scope="session", ) @pytest.mark.version async def test_update_software_data(mocker): diff --git a/tests/test_mutation/test_merge_software.py b/tests/test_mutation/test_merge_software.py new file mode 100644 index 00000000..636cef05 --- /dev/null +++ b/tests/test_mutation/test_merge_software.py @@ -0,0 +1,41 @@ +"""Test Merge Software""" + +import pytest + +from tests.test_schema import test_schema + + +MERGE_SOFTWARE_QUERY = """ +mutation MyMutation($baseId: Int!, $additionalId: Int!) { + mergeSoftwareById(baseId: $baseId, additionalId: $additionalId) +}""" + + +@pytest.mark.asyncio +@pytest.mark.mutation +@pytest.mark.dependency(name="merge-software-by-id") +async def test_merge_software_by_id_mutation(): + """Test Add Wikibase""" + + result = await test_schema.execute( + MERGE_SOFTWARE_QUERY, variable_values={"baseId": 1, "additionalId": 3} + ) + assert result.errors is None + assert result.data is not None + assert result.data.get("mergeSoftwareById") + + +@pytest.mark.asyncio +@pytest.mark.mutation +@pytest.mark.dependency( + depends=["software-version-success"], + name="merge-software-by-id-fail", + scope="session", +) +async def test_merge_software_by_id_fail_mutation(): + """Test Add Wikibase""" + + result = await test_schema.execute( + MERGE_SOFTWARE_QUERY, variable_values={"baseId": 1, "additionalId": 4} + ) + assert result.errors is not None diff --git a/tests/test_query/test_extension_list_query.py b/tests/test_query/test_extension_list_query.py index 0d7e074c..27eb01bd 100644 --- a/tests/test_query/test_extension_list_query.py +++ b/tests/test_query/test_extension_list_query.py @@ -90,7 +90,7 @@ async def test_extension_list_query(): ), ( 1, - "17", + "18", "Google Analytics Integration", "Google_Analytics_Integration", False, @@ -105,7 +105,7 @@ async def test_extension_list_query(): ), ( 2, - "11", + "12", "LabeledSectionTransclusion", "Labeled_Section_Transclusion", False, @@ -129,11 +129,11 @@ async def test_extension_list_query(): None, None, None, - [], + ["Magic", "extensionname"], ), ( 4, - "18", + "19", "ProofreadPage", "Proofread_Page", False, @@ -148,7 +148,7 @@ async def test_extension_list_query(): ), ( 5, - "12", + "13", "Scribunto", "Scribunto", False, @@ -162,7 +162,7 @@ async def test_extension_list_query(): ), ( 6, - "19", + "20", "UniversalLanguageSelector", "UniversalLanguageSelector", False, @@ -176,7 +176,7 @@ async def test_extension_list_query(): ), ( 7, - "13", + "14", "WikibaseClient", "Wikibase_Client", False, @@ -190,7 +190,7 @@ async def test_extension_list_query(): ), ( 8, - "14", + "15", "WikibaseLib", "WikibaseLib", True, @@ -204,7 +204,7 @@ async def test_extension_list_query(): ), ( 9, - "15", + "16", "WikibaseRepository", "Wikibase_Repository", False, @@ -218,7 +218,7 @@ async def test_extension_list_query(): ), ( 10, - "16", + "17", "WikibaseView", "WikibaseView", False,