Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions infrahub_sdk/spec/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..exceptions import ObjectValidationError, ValidationError
from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema
from ..utils import is_valid_uuid
from ..yaml import InfrahubFile, InfrahubFileKind
from .models import InfrahubObjectParameters
from .processors.factory import DataProcessorFactory
Expand All @@ -33,6 +34,36 @@ def validate_list_of_objects(value: list[Any]) -> bool:
return all(isinstance(item, dict) for item in value)


def normalize_hfid_reference(value: str | list[str]) -> str | list[str]:
"""Normalize a reference value to HFID format.

Only call this function when the peer schema has human_friendly_id defined.

Args:
value: Either a string (ID or single-component HFID) or a list of strings (multi-component HFID).

Returns:
- If value is already a list: returns it unchanged as list[str]
- If value is a valid UUID string: returns it unchanged as str (will be treated as an ID)
- If value is a non-UUID string: wraps it in a list as list[str] (single-component HFID)
"""
if isinstance(value, list):
return value
if is_valid_uuid(value):
return value
return [value]


def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]:
"""Normalize a list of reference values to HFID format.

Only call this function when the peer schema has human_friendly_id defined.

Each string that is not a valid UUID will be wrapped in a list to treat it as a single-component HFID.
"""
return [normalize_hfid_reference(v) for v in values]


class RelationshipDataFormat(str, Enum):
UNKNOWN = "unknown"

Expand All @@ -51,6 +82,7 @@ class RelationshipInfo(BaseModel):
peer_rel: RelationshipSchema | None = None
reason_relationship_not_valid: str | None = None
format: RelationshipDataFormat = RelationshipDataFormat.UNKNOWN
peer_has_hfid: bool = False

@property
def is_bidirectional(self) -> bool:
Expand Down Expand Up @@ -119,6 +151,7 @@ async def get_relationship_info(
info.peer_kind = value["kind"]

peer_schema = await client.schema.get(kind=info.peer_kind, branch=branch)
info.peer_has_hfid = bool(peer_schema.human_friendly_id)

try:
info.peer_rel = peer_schema.get_matching_relationship(
Expand Down Expand Up @@ -444,10 +477,12 @@ async def create_node(
# - if the relationship is bidirectional and is mandatory on the other side, then we need to create this object First
# - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First
# - if the relationship is not bidirectional, then we need to create the related object First
if rel_info.is_reference and isinstance(value, list):
clean_data[key] = value
elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str):
clean_data[key] = [value]
if rel_info.format == RelationshipDataFormat.MANY_REF and isinstance(value, list):
# Cardinality-many reference: normalize string HFIDs to list format if peer has HFID defined
clean_data[key] = normalize_hfid_references(value) if rel_info.peer_has_hfid else value
elif rel_info.format == RelationshipDataFormat.ONE_REF:
# Cardinality-one reference: normalize string to HFID list if peer has HFID, else pass as-is
clean_data[key] = normalize_hfid_reference(value) if rel_info.peer_has_hfid else value
elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory:
remaining_rels.append(key)
elif not rel_info.is_reference and not rel_info.is_mandatory:
Expand Down
5 changes: 4 additions & 1 deletion tests/fixtures/schema_01.json
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@
"label": null,
"inherit_from": [],
"branch": "aware",
"default_filter": "name__value"
"default_filter": "name__value",
"human_friendly_id": [
"name__value"
]
},
{
"name": "Location",
Expand Down
266 changes: 264 additions & 2 deletions tests/unit/sdk/spec/test_object.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from infrahub_sdk.exceptions import ValidationError
from infrahub_sdk.spec.object import ObjectFile, RelationshipDataFormat, get_relationship_info
from infrahub_sdk.node.related_node import RelatedNode
from infrahub_sdk.spec.object import (
ObjectFile,
RelationshipDataFormat,
get_relationship_info,
normalize_hfid_reference,
)

if TYPE_CHECKING:
from infrahub_sdk.client import InfrahubClient
from infrahub_sdk.node import InfrahubNode


@pytest.fixture
Expand Down Expand Up @@ -263,3 +272,256 @@ async def test_parameters_non_dict(client_with_schema_01: InfrahubClient, locati
obj = ObjectFile(location="some/path", content=location_with_non_dict_parameters)
with pytest.raises(ValidationError):
await obj.validate_format(client=client_with_schema_01)


@dataclass
class HfidLoadTestCase:
"""Test case for HFID normalization in object loading."""

name: str
data: list[dict[str, Any]]
expected_primary_tag: str | list[str] | None
expected_tags: list[str] | list[list[str]] | None


HFID_NORMALIZATION_TEST_CASES = [
HfidLoadTestCase(
name="cardinality_one_string_hfid_normalized",
data=[{"name": "Mexico", "type": "Country", "primary_tag": "Important"}],
expected_primary_tag=["Important"],
expected_tags=None,
),
HfidLoadTestCase(
name="cardinality_one_list_hfid_unchanged",
data=[{"name": "Mexico", "type": "Country", "primary_tag": ["Important"]}],
expected_primary_tag=["Important"],
expected_tags=None,
),
HfidLoadTestCase(
name="cardinality_one_uuid_unchanged",
data=[{"name": "Mexico", "type": "Country", "primary_tag": "550e8400-e29b-41d4-a716-446655440000"}],
expected_primary_tag="550e8400-e29b-41d4-a716-446655440000",
expected_tags=None,
),
HfidLoadTestCase(
name="cardinality_many_string_hfids_normalized",
data=[{"name": "Mexico", "type": "Country", "tags": ["Important", "Active"]}],
expected_primary_tag=None,
expected_tags=[["Important"], ["Active"]],
),
HfidLoadTestCase(
name="cardinality_many_list_hfids_unchanged",
data=[{"name": "Mexico", "type": "Country", "tags": [["Important"], ["Active"]]}],
expected_primary_tag=None,
expected_tags=[["Important"], ["Active"]],
),
HfidLoadTestCase(
name="cardinality_many_mixed_hfids_normalized",
data=[{"name": "Mexico", "type": "Country", "tags": ["Important", ["namespace", "name"]]}],
expected_primary_tag=None,
expected_tags=[["Important"], ["namespace", "name"]],
),
HfidLoadTestCase(
name="cardinality_many_uuids_unchanged",
data=[
{
"name": "Mexico",
"type": "Country",
"tags": ["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"],
}
],
expected_primary_tag=None,
expected_tags=["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"],
),
]


@pytest.mark.parametrize("test_case", HFID_NORMALIZATION_TEST_CASES, ids=lambda tc: tc.name)
async def test_hfid_normalization_in_object_loading(
client_with_schema_01: InfrahubClient, test_case: HfidLoadTestCase
) -> None:
"""Test that HFIDs are normalized correctly based on cardinality and format."""

root_location = {"apiVersion": "infrahub.app/v1", "kind": "Object", "spec": {"kind": "BuiltinLocation", "data": []}}
location = {
"apiVersion": root_location["apiVersion"],
"kind": root_location["kind"],
"spec": {"kind": root_location["spec"]["kind"], "data": test_case.data},
}

obj = ObjectFile(location="some/path", content=location)
await obj.validate_format(client=client_with_schema_01)

create_calls: list[dict[str, Any]] = []

async def mock_create(
kind: str,
branch: str | None = None,
data: dict | None = None,
**kwargs: Any, # noqa: ANN401
) -> InfrahubNode:
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
if test_case.expected_primary_tag is not None:
assert create_calls[0]["data"]["primary_tag"] == test_case.expected_primary_tag
if test_case.expected_tags is not None:
assert create_calls[0]["data"]["tags"] == test_case.expected_tags


@dataclass
class GraphQLPayloadTestCase:
"""Test case for verifying data format that leads to correct GraphQL payload.

The RelatedNode interprets data as follows:
- list → stored as hfid → GraphQL: {"hfid": [...]}
- string → stored as id → GraphQL: {"id": "..."}
"""

name: str
peer_has_hfid: bool
input_value: str | list[str]
expected_output_type: str # "list" for hfid, "string" for id
expected_output_value: str | list[str]


GRAPHQL_PAYLOAD_TEST_CASES = [
# Peer HAS HFID - non-UUID string should become list (hfid)
GraphQLPayloadTestCase(
name="hfid_defined_string_becomes_list",
peer_has_hfid=True,
input_value="Important",
expected_output_type="list",
expected_output_value=["Important"],
),
# Peer HAS HFID - UUID string should stay as string (id)
GraphQLPayloadTestCase(
name="hfid_defined_uuid_stays_string",
peer_has_hfid=True,
input_value="550e8400-e29b-41d4-a716-446655440000",
expected_output_type="string",
expected_output_value="550e8400-e29b-41d4-a716-446655440000",
),
# Peer HAS HFID - list stays as list (hfid)
GraphQLPayloadTestCase(
name="hfid_defined_list_stays_list",
peer_has_hfid=True,
input_value=["namespace", "name"],
expected_output_type="list",
expected_output_value=["namespace", "name"],
),
# Peer has NO HFID - non-UUID string stays as string (id lookup)
GraphQLPayloadTestCase(
name="no_hfid_string_stays_string",
peer_has_hfid=False,
input_value="some-string-value",
expected_output_type="string",
expected_output_value="some-string-value",
),
# Peer has NO HFID - UUID stays as string (id)
GraphQLPayloadTestCase(
name="no_hfid_uuid_stays_string",
peer_has_hfid=False,
input_value="550e8400-e29b-41d4-a716-446655440000",
expected_output_type="string",
expected_output_value="550e8400-e29b-41d4-a716-446655440000",
),
]


@pytest.mark.parametrize("test_case", GRAPHQL_PAYLOAD_TEST_CASES, ids=lambda tc: tc.name)
def test_graphql_payload_format(test_case: GraphQLPayloadTestCase) -> None:
"""Test that relationship data is formatted correctly for GraphQL payload.

The RelatedNode class interprets:
- list input → {"hfid": [...]} in GraphQL
- string input → {"id": "..."} in GraphQL

This test verifies the normalization produces the correct format.
"""
if test_case.peer_has_hfid:
# When peer has HFID, use normalization
processed_value = normalize_hfid_reference(test_case.input_value)
else:
# When peer has no HFID, pass value as-is (no normalization)
processed_value = test_case.input_value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't actually test the logic that checks if the peer schema has an HFID defined. for that, this test probably needs to call get_relationship_info. although that looks like it will require some mocking


# Verify the output type matches expected
if test_case.expected_output_type == "list":
assert isinstance(processed_value, list), (
f"Expected list output for hfid, got {type(processed_value).__name__}: {processed_value}"
)
else:
assert isinstance(processed_value, str), (
f"Expected string output for id, got {type(processed_value).__name__}: {processed_value}"
)

# Verify the actual value
assert processed_value == test_case.expected_output_value, (
f"Expected {test_case.expected_output_value}, got {processed_value}"
)


@dataclass
class RelatedNodePayloadTestCase:
"""Test case for verifying the actual GraphQL payload structure from RelatedNode."""

name: str
input_data: str | list[str]
expected_payload: dict[str, Any]


RELATED_NODE_PAYLOAD_TEST_CASES = [
# String (UUID) → {"id": "uuid"}
RelatedNodePayloadTestCase(
name="uuid_string_becomes_id_payload",
input_data="550e8400-e29b-41d4-a716-446655440000",
expected_payload={"id": "550e8400-e29b-41d4-a716-446655440000"},
),
# List (HFID) → {"hfid": [...]}
RelatedNodePayloadTestCase(
name="list_becomes_hfid_payload",
input_data=["Important"],
expected_payload={"hfid": ["Important"]},
),
# Multi-component HFID list → {"hfid": [...]}
RelatedNodePayloadTestCase(
name="multi_component_hfid_payload",
input_data=["namespace", "name"],
expected_payload={"hfid": ["namespace", "name"]},
),
]


@pytest.mark.parametrize("test_case", RELATED_NODE_PAYLOAD_TEST_CASES, ids=lambda tc: tc.name)
def test_related_node_graphql_payload(test_case: RelatedNodePayloadTestCase) -> None:
"""Test that RelatedNode produces the correct GraphQL payload structure.

This test verifies the actual {"id": ...} or {"hfid": ...} payload
that gets sent in GraphQL mutations.
"""
# Create mock dependencies
mock_client = MagicMock()
mock_schema = MagicMock()

# Create RelatedNode with the input data
related_node = RelatedNode(
schema=mock_schema,
name="test_rel",
branch="main",
client=mock_client,
data=test_case.input_data,
)

# Generate the input data that would go into GraphQL mutation
payload = related_node._generate_input_data()

# Verify the payload structure
assert payload == test_case.expected_payload, f"Expected payload {test_case.expected_payload}, got {payload}"