Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions infrahub_sdk/spec/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..exceptions import ObjectValidationError, ValidationError
from ..schema import GenericSchemaAPI, RelationshipKind, RelationshipSchema
from ..utils import is_valid_uuid
from ..yaml import InfrahubFile, InfrahubFileKind
from .models import InfrahubObjectParameters
from .processors.factory import DataProcessorFactory
Expand All @@ -33,6 +34,28 @@ def validate_list_of_objects(value: list[Any]) -> bool:
return all(isinstance(item, dict) for item in value)


def normalize_hfid_reference(value: str | list[str]) -> list[str]:
"""Normalize a reference value to HFID format.

If the value is a string and not a valid UUID, wrap it in a list to treat it as a single-component HFID.
If the value is already a list, return it as-is.
If the value is a UUID string, return it as-is (will be treated as an ID).
"""
if isinstance(value, list):
return value
if is_valid_uuid(value):
return value # type: ignore[return-value]
return [value]


def normalize_hfid_references(values: list[str | list[str]]) -> list[str | list[str]]:
"""Normalize a list of reference values to HFID format.

Each string that is not a valid UUID will be wrapped in a list to treat it as a single-component HFID.
"""
return [normalize_hfid_reference(v) for v in values]


class RelationshipDataFormat(str, Enum):
UNKNOWN = "unknown"

Expand Down Expand Up @@ -445,9 +468,12 @@ async def create_node(
# - if the relationship is bidirectional and is not mandatory on the other side, then we need should create the related object First
# - if the relationship is not bidirectional, then we need to create the related object First
if rel_info.is_reference and isinstance(value, list):
clean_data[key] = value
# Normalize string HFIDs to list format: "name" -> ["name"]
# UUIDs are left as-is since they are treated as IDs
clean_data[key] = normalize_hfid_references(value)
elif rel_info.format == RelationshipDataFormat.ONE_REF and isinstance(value, str):
clean_data[key] = [value]
# Normalize string to HFID format if not a UUID
clean_data[key] = [normalize_hfid_reference(value)]
elif not rel_info.is_reference and rel_info.is_bidirectional and rel_info.is_mandatory:
remaining_rels.append(key)
elif not rel_info.is_reference and not rel_info.is_mandatory:
Expand Down
247 changes: 247 additions & 0 deletions tests/unit/sdk/spec/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,250 @@ async def test_parameters_non_dict(client_with_schema_01: InfrahubClient, locati
obj = ObjectFile(location="some/path", content=location_with_non_dict_parameters)
with pytest.raises(ValidationError):
await obj.validate_format(client=client_with_schema_01)


class TestHfidNormalizationInObjectLoading:
"""Tests to verify HFID normalization works correctly through the object loading code path."""

@pytest.fixture
def location_with_cardinality_one_string_hfid(self, root_location: dict) -> dict:
"""Location with a cardinality-one relationship using string HFID."""
data = [{"name": "Mexico", "type": "Country", "primary_tag": "Important"}]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_one_list_hfid(self, root_location: dict) -> dict:
"""Location with a cardinality-one relationship using list HFID."""
data = [{"name": "Mexico", "type": "Country", "primary_tag": ["Important"]}]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_one_uuid(self, root_location: dict) -> dict:
"""Location with a cardinality-one relationship using UUID."""
data = [
{"name": "Mexico", "type": "Country", "primary_tag": "550e8400-e29b-41d4-a716-446655440000"}
]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_many_string_hfids(self, root_location: dict) -> dict:
"""Location with a cardinality-many relationship using string HFIDs."""
data = [{"name": "Mexico", "type": "Country", "tags": ["Important", "Active"]}]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_many_list_hfids(self, root_location: dict) -> dict:
"""Location with a cardinality-many relationship using list HFIDs."""
data = [{"name": "Mexico", "type": "Country", "tags": [["Important"], ["Active"]]}]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_many_mixed_hfids(self, root_location: dict) -> dict:
"""Location with a cardinality-many relationship using mixed string and list HFIDs."""
data = [{"name": "Mexico", "type": "Country", "tags": ["Important", ["namespace", "name"]]}]
location = root_location.copy()
location["spec"]["data"] = data
return location

@pytest.fixture
def location_with_cardinality_many_uuids(self, root_location: dict) -> dict:
"""Location with a cardinality-many relationship using UUIDs."""
data = [
{
"name": "Mexico",
"type": "Country",
"tags": ["550e8400-e29b-41d4-a716-446655440000", "6ba7b810-9dad-11d1-80b4-00c04fd430c8"],
}
]
location = root_location.copy()
location["spec"]["data"] = data
return location

async def test_cardinality_one_string_hfid_normalized(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_one_string_hfid: dict
) -> None:
"""String HFID for cardinality-one should be wrapped in a list."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_one_string_hfid)
await obj.validate_format(client=client_with_schema_01)

# Track calls to client.create
create_calls = []
original_create = client_with_schema_01.create

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
# Return a mock node that has the required methods
node = await original_create(kind=kind, branch=branch, data=data, **kwargs)
return node

client_with_schema_01.create = mock_create

# Mock the save method to avoid API calls
from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

# Verify the data passed to create has the normalized HFID
assert len(create_calls) == 1
assert create_calls[0]["data"]["primary_tag"] == [["Important"]]

async def test_cardinality_one_list_hfid_unchanged(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_one_list_hfid: dict
) -> None:
"""List HFID for cardinality-one should remain unchanged."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_one_list_hfid)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
assert create_calls[0]["data"]["primary_tag"] == [["Important"]]

async def test_cardinality_one_uuid_unchanged(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_one_uuid: dict
) -> None:
"""UUID for cardinality-one should remain as a string (not wrapped in list)."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_one_uuid)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
# UUID should be passed as-is (not wrapped in a list)
assert create_calls[0]["data"]["primary_tag"] == ["550e8400-e29b-41d4-a716-446655440000"]

async def test_cardinality_many_string_hfids_normalized(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_many_string_hfids: dict
) -> None:
"""String HFIDs for cardinality-many should each be wrapped in a list."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_many_string_hfids)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
assert create_calls[0]["data"]["tags"] == [["Important"], ["Active"]]

async def test_cardinality_many_list_hfids_unchanged(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_many_list_hfids: dict
) -> None:
"""List HFIDs for cardinality-many should remain unchanged."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_many_list_hfids)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
assert create_calls[0]["data"]["tags"] == [["Important"], ["Active"]]

async def test_cardinality_many_mixed_hfids_normalized(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_many_mixed_hfids: dict
) -> None:
"""Mixed string and list HFIDs for cardinality-many should be normalized correctly."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_many_mixed_hfids)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
# "Important" should be wrapped, ["namespace", "name"] should remain unchanged
assert create_calls[0]["data"]["tags"] == [["Important"], ["namespace", "name"]]

async def test_cardinality_many_uuids_unchanged(
self, client_with_schema_01: InfrahubClient, location_with_cardinality_many_uuids: dict
) -> None:
"""UUIDs for cardinality-many should remain as strings (not wrapped)."""
obj = ObjectFile(location="some/path", content=location_with_cardinality_many_uuids)
await obj.validate_format(client=client_with_schema_01)

create_calls = []

async def mock_create(kind, branch=None, data=None, **kwargs):
create_calls.append({"kind": kind, "data": data})
original_create = client_with_schema_01.__class__.create
return await original_create(client_with_schema_01, kind=kind, branch=branch, data=data, **kwargs)

client_with_schema_01.create = mock_create

from unittest.mock import AsyncMock, patch

with patch("infrahub_sdk.node.InfrahubNode.save", new_callable=AsyncMock):
await obj.process(client=client_with_schema_01)

assert len(create_calls) == 1
# UUIDs should remain as-is
assert create_calls[0]["data"]["tags"] == [
"550e8400-e29b-41d4-a716-446655440000",
"6ba7b810-9dad-11d1-80b4-00c04fd430c8",
]
Loading