diff --git a/infrahub_sdk/spec/object.py b/infrahub_sdk/spec/object.py index 0df3f95c..456c9f1e 100644 --- a/infrahub_sdk/spec/object.py +++ b/infrahub_sdk/spec/object.py @@ -7,6 +7,7 @@ from ..exceptions import ObjectValidationError, ValidationError from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema +from ..utils import is_valid_uuid from ..yaml import InfrahubFile, InfrahubFileKind from .models import InfrahubObjectParameters from .processors.factory import DataProcessorFactory @@ -33,6 +34,36 @@ def validate_list_of_objects(value: list[Any]) -> bool: return all(isinstance(item, dict) for item in value) +def normalize_hfid_reference(value: str | list[str]) -> str | list[str]: + """Normalize a reference value to HFID format. + + Only call this function when the peer schema has human_friendly_id defined. + + Args: + value: Either a string (ID or single-component HFID) or a list of strings (multi-component HFID). + + Returns: + - If value is already a list: returns it unchanged as list[str] + - If value is a valid UUID string: returns it unchanged as str (will be treated as an ID) + - If value is a non-UUID string: wraps it in a list as list[str] (single-component HFID) + """ + if isinstance(value, list): + return value + if is_valid_uuid(value): + return value + return [value] + + +def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]: + """Normalize a list of reference values to HFID format. + + Only call this function when the peer schema has human_friendly_id defined. + + Each string that is not a valid UUID will be wrapped in a list to treat it as a single-component HFID. + """ + return [normalize_hfid_reference(v) for v in values] + + class RelationshipDataFormat(str, Enum): UNKNOWN = "unknown" @@ -51,6 +82,12 @@ class RelationshipInfo(BaseModel): peer_rel: RelationshipSchema | None = None reason_relationship_not_valid: str | None = None format: RelationshipDataFormat = RelationshipDataFormat.UNKNOWN + peer_human_friendly_id: list[str] | None = None + + @property + def peer_has_hfid(self) -> bool: + """Indicate if the peer schema has a human-friendly ID defined.""" + return bool(self.peer_human_friendly_id) @property def is_bidirectional(self) -> bool: @@ -119,6 +156,7 @@ async def get_relationship_info( info.peer_kind = value["kind"] peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch) + info.peer_human_friendly_id = peer_schema.human_friendly_id try: info.peer_rel = peer_schema.get_matching_relationship( @@ -444,10 +482,12 @@ async def create_node( # - if the relationship is bidirectional and is mandatory on the other side, then we need to create this object First # - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First # - if the relationship is not bidirectional, then we need to create the related object First - if rel_info.is_reference and isinstance(value, list): - clean_data[key] = value - elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str): - clean_data[key] = [value] + if rel_info.format == RelationshipDataFormat.MANY_REF and isinstance(value, list): + # Cardinality-many reference: normalize string HFIDs to list format if peer has HFID defined + clean_data[key] = normalize_hfid_references(value) if rel_info.peer_has_hfid else value + elif rel_info.format == RelationshipDataFormat.ONE_REF: + # Cardinality-one reference: normalize string to HFID list if peer has HFID, else pass as-is + clean_data[key] = normalize_hfid_reference(value) if rel_info.peer_has_hfid else value elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory: remaining_rels.append(key) elif not rel_info.is_reference and not rel_info.is_mandatory: diff --git a/tests/fixtures/schema_01.json b/tests/fixtures/schema_01.json index 344ebeab..c2fab38a 100644 --- a/tests/fixtures/schema_01.json +++ b/tests/fixtures/schema_01.json @@ -242,7 +242,10 @@ "label": null, "inherit_from": [], "branch": "aware", - "default_filter": "name__value" + "default_filter": "name__value", + "human_friendly_id": [ + "name__value" + ] }, { "name": "Location", diff --git a/tests/unit/sdk/spec/test_object.py b/tests/unit/sdk/spec/test_object.py index 1af02ac3..b5199c78 100644 --- a/tests/unit/sdk/spec/test_object.py +++ b/tests/unit/sdk/spec/test_object.py @@ -1,14 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest from infrahub_sdk.exceptions import ValidationError -from infrahub_sdk.spec.object import ObjectFile, RelationshipDataFormat, get_relationship_info +from infrahub_sdk.node.related_node import RelatedNode +from infrahub_sdk.spec.object import ( + ObjectFile, + RelationshipDataFormat, + get_relationship_info, + normalize_hfid_reference, +) if TYPE_CHECKING: from infrahub_sdk.client import InfrahubClient + from infrahub_sdk.node import InfrahubNode @pytest.fixture @@ -263,3 +272,183 @@ async def test_parameters_non_dict(client_with_schema_01: InfrahubClient, locati obj = ObjectFile(location="some/path", content=location_with_non_dict_parameters) with pytest.raises(ValidationError): await obj.validate_format(client=client_with_schema_01) + + +@dataclass +class HfidLoadTestCase: + """Test case for HFID normalization in object loading.""" + + name: str + data: list[dict[str, Any]] + expected_primary_tag: str | list[str] | None + expected_tags: list[str] | list[list[str]] | None + + +HFID_NORMALIZATION_TEST_CASES = [ + HfidLoadTestCase( + name="cardinality_one_string_hfid_normalized", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "Important"}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_list_hfid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": ["Important"]}], + expected_primary_tag=["Important"], + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_one_uuid_unchanged", + data=[{"name": "Mexico", "type": "Country", "primary_tag": "550e8400-e29b-41d4-a716-446655440000"}], + expected_primary_tag="550e8400-e29b-41d4-a716-446655440000", + expected_tags=None, + ), + HfidLoadTestCase( + name="cardinality_many_string_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", "Active"]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_list_hfids_unchanged", + data=[{"name": "Mexico", "type": "Country", "tags": [["Important"], ["Active"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["Active"]], + ), + HfidLoadTestCase( + name="cardinality_many_mixed_hfids_normalized", + data=[{"name": "Mexico", "type": "Country", "tags": ["Important", ["namespace", "name"]]}], + expected_primary_tag=None, + expected_tags=[["Important"], ["namespace", "name"]], + ), + HfidLoadTestCase( + name="cardinality_many_uuids_unchanged", + data=[ + { + "name": "Mexico", + "type": "Country", + "tags": ["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + } + ], + expected_primary_tag=None, + expected_tags=["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"], + ), +] + + +@pytest.mark.parametrize("test_case", HFID_NORMALIZATION_TEST_CASES, ids=lambda tc: tc.name) +async def test_hfid_normalization_in_object_loading( + client_with_schema_01: InfrahubClient, test_case: HfidLoadTestCase +) -> None: + """Test that HFIDs are normalized correctly based on cardinality and format.""" + + root_location = {"apiVersion": "infrahub.app/v1", "kind": "Object", "spec": {"kind": "BuiltinLocation", "data": []}} + location = { + "apiVersion": root_location["apiVersion"], + "kind": root_location["kind"], + "spec": {"kind": root_location["spec"]["kind"], "data": test_case.data}, + } + + obj = ObjectFile(location="some/path", content=location) + await obj.validate_format(client=client_with_schema_01) + + create_calls: list[dict[str, Any]] = [] + + async def mock_create( + kind: str, + branch: str | None = None, + data: dict | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> InfrahubNode: + create_calls.append({"kind": kind, "data": data}) + original_create = client_with_schema_01.__class__.create + return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs) + + client_with_schema_01.create = mock_create + + with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock): + await obj.process(client=client_with_schema_01) + + assert len(create_calls) == 1 + if test_case.expected_primary_tag is not None: + assert create_calls[0]["data"]["primary_tag"] == test_case.expected_primary_tag + if test_case.expected_tags is not None: + assert create_calls[0]["data"]["tags"] == test_case.expected_tags + + +def test_normalize_hfid_reference_function() -> None: + """Test the normalize_hfid_reference function directly. + + This tests the normalization logic in isolation: + - Non-UUID strings get wrapped in a list (for HFID lookup) + - UUID strings stay as strings (for ID lookup) + - Lists stay unchanged + """ + # Non-UUID string becomes list + assert normalize_hfid_reference("Important") == ["Important"] + + # UUID string stays as string + uuid_value = "550e8400-e29b-41d4-a716-446655440000" + assert normalize_hfid_reference(uuid_value) == uuid_value + + # List stays unchanged + assert normalize_hfid_reference(["namespace", "name"]) == ["namespace", "name"] + assert normalize_hfid_reference(["single"]) == ["single"] + + +@dataclass +class RelatedNodePayloadTestCase: + """Test case for verifying the actual GraphQL payload structure from RelatedNode.""" + + name: str + input_data: str | list[str] + expected_payload: dict[str, Any] + + +RELATED_NODE_PAYLOAD_TEST_CASES = [ + # String (UUID) → {"id": "uuid"} + RelatedNodePayloadTestCase( + name="uuid_string_becomes_id_payload", + input_data="550e8400-e29b-41d4-a716-446655440000", + expected_payload={"id": "550e8400-e29b-41d4-a716-446655440000"}, + ), + # List (HFID) → {"hfid": [...]} + RelatedNodePayloadTestCase( + name="list_becomes_hfid_payload", + input_data=["Important"], + expected_payload={"hfid": ["Important"]}, + ), + # Multi-component HFID list → {"hfid": [...]} + RelatedNodePayloadTestCase( + name="multi_component_hfid_payload", + input_data=["namespace", "name"], + expected_payload={"hfid": ["namespace", "name"]}, + ), +] + + +@pytest.mark.parametrize("test_case", RELATED_NODE_PAYLOAD_TEST_CASES, ids=lambda tc: tc.name) +def test_related_node_graphql_payload(test_case: RelatedNodePayloadTestCase) -> None: + """Test that RelatedNode produces the correct GraphQL payload structure. + + This test verifies the actual {"id": ...} or {"hfid": ...} payload + that gets sent in GraphQL mutations. + """ + # Create mock dependencies + mock_client = MagicMock() + mock_schema = MagicMock() + + # Create RelatedNode with the input data + related_node = RelatedNode( + schema=mock_schema, + name="test_rel", + branch="main", + client=mock_client, + data=test_case.input_data, + ) + + # Generate the input data that would go into GraphQL mutation + payload = related_node._generate_input_data() + + # Verify the payload structure + assert payload == test_case.expected_payload, f"Expected payload {test_case.expected_payload}, got {payload}"