diff --git a/docker/requirements-diode-netbox-plugin.txt b/docker/requirements-diode-netbox-plugin.txt index 8285846..644111c 100644 --- a/docker/requirements-diode-netbox-plugin.txt +++ b/docker/requirements-diode-netbox-plugin.txt @@ -4,4 +4,4 @@ coverage==7.6.0 grpcio==1.62.1 protobuf==5.28.1 pytest==8.0.2 -netboxlabs-netbox-branching \ No newline at end of file +netboxlabs-netbox-branching==0.5.7 \ No newline at end of file diff --git a/netbox_diode_plugin/api/applier.py b/netbox_diode_plugin/api/applier.py index 3dc6532..4ecc6fe 100644 --- a/netbox_diode_plugin/api/applier.py +++ b/netbox_diode_plugin/api/applier.py @@ -101,7 +101,7 @@ def _pre_apply(model_class: models.Model, change: Change, created: dict): # resolve foreign key references to new objects for ref_field in change.new_refs: v = _get_path(data, ref_field) - if isinstance(v, (list, tuple)): + if isinstance(v, list | tuple): ref_list = [] for ref in v: if isinstance(ref, str): diff --git a/netbox_diode_plugin/api/common.py b/netbox_diode_plugin/api/common.py index a00a504..17026b6 100644 --- a/netbox_diode_plugin/api/common.py +++ b/netbox_diode_plugin/api/common.py @@ -9,6 +9,7 @@ from collections import defaultdict from dataclasses import dataclass, field from enum import Enum +from zoneinfo import ZoneInfo import netaddr from django.apps import apps @@ -20,7 +21,6 @@ from extras.models import CustomField from netaddr.eui import EUI from rest_framework import status -from zoneinfo import ZoneInfo logger = logging.getLogger("netbox.diode_data") @@ -166,7 +166,7 @@ def _validate_relations(self, change_data: dict, model: models.Model) -> tuple[l excluded_relation_fields = [] rel_errors = defaultdict(list) for f in model._meta.get_fields(): - if isinstance(f, (GenericRelation, GenericForeignKey)): + if isinstance(f, GenericRelation | GenericForeignKey): excluded_relation_fields.append(f.name) continue if not f.is_relation: @@ -251,7 +251,7 @@ def error_from_validation_error(e, object_name): if e.detail: if isinstance(e.detail, dict): errors[object_name] = e.detail - elif isinstance(e.detail, (list, tuple)): + elif isinstance(e.detail, list | tuple): errors[object_name] = { NON_FIELD_ERRORS: e.detail } diff --git a/netbox_diode_plugin/api/compat.py b/netbox_diode_plugin/api/compat.py new file mode 100644 index 0000000..cf31f8c --- /dev/null +++ b/netbox_diode_plugin/api/compat.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# Copyright 2025 NetBox Labs Inc +"""Diode NetBox Plugin - API - Compatibility Transformations.""" + +import logging +import re +from collections import defaultdict +from functools import cache + +from django.conf import settings +from packaging import version +from utilities.release import load_release_data + +logger = logging.getLogger(__name__) + +_MIGRATIONS_BY_OBJECT_TYPE = defaultdict(list) + +def apply_entity_migrations(data: dict, object_type: str): + """ + Applies migrations to diode entity data prior to diffing to improve compatibility with current NetBox version. + + These represent cases like deprecated fields that have been replaced with new fields, but + are supported for backwards compatibility. + """ + for migration in _MIGRATIONS_BY_OBJECT_TYPE.get(object_type, []): + logger.debug(f"Applying migration {migration.__name__} for {object_type}") + migration(data) + +def _register_migration(func, min_version, max_version, object_type): + """Registers a migration function.""" + if in_version_range(min_version, max_version): + logger.debug(f"Registering migration {func.__name__} for {object_type}.") + _MIGRATIONS_BY_OBJECT_TYPE[object_type].append(func) + else: + logger.debug(f"Skipping migration {func.__name__} for {object_type}: {min_version} to {max_version}.") + +@cache +def _current_netbox_version(): + """Returns the current version of NetBox.""" + try: + return version.parse(settings.RELEASE.version) + except Exception: + logger.exception("Failed to determine current version of NetBox.") + return (0, 0, 0) + +def in_version_range(min_version: str | None, max_version: str | None): + """Returns True if the current version of NetBox is within the given version range.""" + min_version = version.parse(min_version) if min_version else None + max_version = version.parse(max_version) if max_version else None + current_version = _current_netbox_version() + if min_version and current_version < min_version: + return False + if max_version and current_version > max_version: + return False + return True + +def diode_migration(min_version: str, max_version: str | None, object_type: str): + """Decorator to mark a function as a diode migration.""" + def decorator(func): + _register_migration(func, min_version, max_version, object_type) + return func + return decorator + +@diode_migration(min_version="4.3.0", max_version=None, object_type="ipam.service") +def _migrate_service_parent_object(data: dict): + """Transforms ipam.service device and virtual_machine references to parent_object.""" + device = data.pop("device", None) + if device: + if data.get("parent_object_device") is None: + data["parent_object_device"] = device + # else ignored. + + virtual_machine = data.pop("virtual_machine", None) + if virtual_machine: + if data.get("parent_object_virtual_machine") is None: + data["parent_object_virtual_machine"] = virtual_machine + # else ignored. + +@diode_migration(min_version="4.3.0", max_version=None, object_type="tenancy.contact") +def _migrate_contact_group(data: dict): + """Transforms tenancy.contact group references to groups.""" + group = data.pop("group", None) + if group: + if data.get("groups") is None: + data["groups"] = [group] + # else ignored. diff --git a/netbox_diode_plugin/api/differ.py b/netbox_diode_plugin/api/differ.py index e213784..40d7e93 100644 --- a/netbox_diode_plugin/api/differ.py +++ b/netbox_diode_plugin/api/differ.py @@ -84,7 +84,7 @@ def prechange_data_from_instance(instance) -> dict: # noqa: C901 custom_field_values = instance.get_custom_fields() cfmap = {} for cf, value in custom_field_values.items(): - if isinstance(value, (datetime.datetime, datetime.date)): + if isinstance(value, datetime.datetime | datetime.date): cfmap[cf.name] = value else: cfmap[cf.name] = cf.serialize(value) diff --git a/netbox_diode_plugin/api/matcher.py b/netbox_diode_plugin/api/matcher.py index 1e4c512..4d2dca5 100644 --- a/netbox_diode_plugin/api/matcher.py +++ b/netbox_diode_plugin/api/matcher.py @@ -5,7 +5,6 @@ import logging from dataclasses import dataclass from functools import cache, lru_cache -from typing import Type import netaddr from django.contrib.contenttypes.fields import ContentType @@ -18,6 +17,7 @@ from extras.models.customfields import CustomField from .common import UnresolvedReference +from .compat import in_version_range from .plugin_utils import content_type_id, get_object_type, get_object_type_model logger = logging.getLogger(__name__) @@ -163,18 +163,28 @@ name="logical_service_name_no_device_or_vm", model_class=get_object_type_model("ipam.service"), condition=Q(device__isnull=True, virtual_machine__isnull=True), + max_version="4.2.99", ), ObjectMatchCriteria( fields=("name", "device"), name="logical_service_name_on_device", model_class=get_object_type_model("ipam.service"), condition=Q(device__isnull=False), + max_version="4.2.99", ), ObjectMatchCriteria( fields=("name", "virtual_machine"), name="logical_service_name_on_vm", model_class=get_object_type_model("ipam.service"), condition=Q(virtual_machine__isnull=False), + max_version="4.2.99", + ), + ObjectMatchCriteria( + fields=("name", "parent_object_type", "parent_object_id"), + name="logical_service_name_on_parent", + model_class=get_object_type_model("ipam.service"), + condition=Q(parent_object_type__isnull=False), + min_version="4.3.0" ), ], "dcim.modulebay": lambda: [ @@ -202,6 +212,32 @@ model_class=get_object_type_model("ipam.fhrpgroup"), ) ], + "tenancy.contact": lambda: [ + ObjectMatchCriteria( + # contacts are unconstrained in 4.3.0 + # in 4.2 they are constrained by unique name per group + fields=("name", ), + name="logical_contact_name", + model_class=get_object_type_model("tenancy.contact"), + min_version="4.3.0", + ) + ], + "dcim.devicerole": lambda: [ + ObjectMatchCriteria( + fields=("name",), + name="logical_device_role_name_no_parent", + model_class=get_object_type_model("dcim.devicerole"), + condition=Q(parent__isnull=True), + min_version="4.3.0", + ), + ObjectMatchCriteria( + fields=("slug",), + name="logical_device_role_slug_no_parent", + model_class=get_object_type_model("dcim.devicerole"), + condition=Q(parent__isnull=True), + min_version="4.3.0", + ) + ], } @dataclass @@ -221,9 +257,12 @@ class ObjectMatchCriteria: fields: tuple[str] | None = None expressions: tuple | None = None condition: Q | None = None - model_class: Type[models.Model] | None = None + model_class: type[models.Model] | None = None name: str | None = None + min_version: str | None = None + max_version: str | None = None + def __hash__(self): """Hash the object match criteria.""" return hash((self.fields, self.expressions, self.condition, self.model_class.__name__, self.name)) @@ -365,7 +404,7 @@ def _build_expressions_queryset(self, data) -> models.QuerySet: """Builds a queryset for the constraint with the given data.""" data = self._prepare_data(data) replacements = { - F(field): Value(value) if isinstance(value, (str, int, float, bool)) else value + F(field): Value(value) if isinstance(value, str | int | float | bool) else value for field, value in data.items() } @@ -413,7 +452,10 @@ class CustomFieldMatcher: name: str custom_field: str - model_class: Type[models.Model] + model_class: type[models.Model] + + min_version: str | None = None + max_version: str | None = None def fingerprint(self, data: dict) -> str|None: """Fingerprint the custom field value.""" @@ -448,9 +490,12 @@ class GlobalIPNetworkIPMatcher: ip_fields: tuple[str] vrf_field: str - model_class: Type[models.Model] + model_class: type[models.Model] name: str + min_version: str | None = None + max_version: str | None = None + def _check_condition(self, data: dict) -> bool: """Check the condition for the custom field.""" return data.get(self.vrf_field, None) is None @@ -508,9 +553,12 @@ class VRFIPNetworkIPMatcher: ip_fields: tuple[str] vrf_field: str - model_class: Type[models.Model] + model_class: type[models.Model] name: str + min_version: str | None = None + max_version: str | None = None + def _check_condition(self, data: dict) -> bool: """Check the condition for the custom field.""" return data.get(self.vrf_field, None) is not None @@ -583,7 +631,10 @@ class AutoSlugMatcher: name: str slug_field: str - model_class: Type[models.Model] + model_class: type[models.Model] + + min_version: str | None = None + max_version: str | None = None def fingerprint(self, data: dict) -> str|None: """Fingerprint the custom field value.""" @@ -650,7 +701,10 @@ def _get_autoslug_matchers(model_class) -> list: @lru_cache(maxsize=256) def _get_model_matchers(model_class) -> list[ObjectMatchCriteria]: object_type = get_object_type(model_class) - matchers = _LOGICAL_MATCHERS.get(object_type, lambda: [])() + matchers = [ + x for x in _LOGICAL_MATCHERS.get(object_type, lambda: [])() + if in_version_range(x.min_version, x.max_version) + ] # collect single fields that are unique for field in model_class._meta.fields: @@ -750,7 +804,7 @@ def _fingerprint_all(data: dict, object_type: str|None = None) -> str: if k.startswith("_"): continue values.append(k) - if isinstance(v, (list, tuple)): + if isinstance(v, list | tuple): values.extend(sorted(v)) elif isinstance(v, dict): values.append(_fingerprint_all(v)) diff --git a/netbox_diode_plugin/api/plugin_utils.py b/netbox_diode_plugin/api/plugin_utils.py index 15247ca..5d9e31c 100644 --- a/netbox_diode_plugin/api/plugin_utils.py +++ b/netbox_diode_plugin/api/plugin_utils.py @@ -544,6 +544,8 @@ class RefInfo: 'ipam.service': { 'device': RefInfo(object_type='dcim.device', field_name='device'), 'ipaddresses': RefInfo(object_type='ipam.ipaddress', field_name='ipaddresses', is_many=True), + 'parent_object_device': RefInfo(object_type='dcim.device', field_name='parent_object', is_generic=True), + 'parent_object_virtual_machine': RefInfo(object_type='virtualization.virtualmachine', field_name='parent_object', is_generic=True), 'tags': RefInfo(object_type='extras.tag', field_name='tags', is_many=True), 'virtual_machine': RefInfo(object_type='virtualization.virtualmachine', field_name='virtual_machine'), }, @@ -576,6 +578,7 @@ class RefInfo: }, 'tenancy.contact': { 'group': RefInfo(object_type='tenancy.contactgroup', field_name='group'), + 'groups': RefInfo(object_type='tenancy.contactgroup', field_name='groups', is_many=True), 'tags': RefInfo(object_type='extras.tag', field_name='tags', is_many=True), }, 'tenancy.contactassignment': { @@ -948,13 +951,13 @@ def get_json_ref_info(object_type: str|Type[models.Model], json_field_name: str) 'ipam.rir': frozenset(['custom_fields', 'description', 'is_private', 'name', 'slug', 'tags']), 'ipam.role': frozenset(['custom_fields', 'description', 'name', 'slug', 'tags', 'weight']), 'ipam.routetarget': frozenset(['comments', 'custom_fields', 'description', 'name', 'tags', 'tenant']), - 'ipam.service': frozenset(['comments', 'custom_fields', 'description', 'device', 'ipaddresses', 'name', 'ports', 'protocol', 'tags', 'virtual_machine']), + 'ipam.service': frozenset(['comments', 'custom_fields', 'description', 'device', 'ipaddresses', 'name', 'parent_object_id', 'parent_object_type', 'ports', 'protocol', 'tags', 'virtual_machine']), 'ipam.vlan': frozenset(['comments', 'custom_fields', 'description', 'group', 'name', 'qinq_role', 'qinq_svlan', 'role', 'site', 'status', 'tags', 'tenant', 'vid']), 'ipam.vlangroup': frozenset(['custom_fields', 'description', 'name', 'scope_id', 'scope_type', 'slug', 'tags', 'vid_ranges']), 'ipam.vlantranslationpolicy': frozenset(['description', 'name']), 'ipam.vlantranslationrule': frozenset(['description', 'local_vid', 'policy', 'remote_vid']), 'ipam.vrf': frozenset(['comments', 'custom_fields', 'description', 'enforce_unique', 'export_targets', 'import_targets', 'name', 'rd', 'tags', 'tenant']), - 'tenancy.contact': frozenset(['address', 'comments', 'custom_fields', 'description', 'email', 'group', 'link', 'name', 'phone', 'tags', 'title']), + 'tenancy.contact': frozenset(['address', 'comments', 'custom_fields', 'description', 'email', 'group', 'groups', 'link', 'name', 'phone', 'tags', 'title']), 'tenancy.contactassignment': frozenset(['contact', 'custom_fields', 'object_id', 'object_type', 'priority', 'role', 'tags']), 'tenancy.contactgroup': frozenset(['custom_fields', 'description', 'name', 'parent', 'slug', 'tags']), 'tenancy.contactrole': frozenset(['custom_fields', 'description', 'name', 'slug', 'tags']), diff --git a/netbox_diode_plugin/api/supported_models.py b/netbox_diode_plugin/api/supported_models.py index b2b7a09..2315a86 100644 --- a/netbox_diode_plugin/api/supported_models.py +++ b/netbox_diode_plugin/api/supported_models.py @@ -6,7 +6,6 @@ import logging import time from functools import lru_cache -from typing import List, Type from django.apps import apps from django.db import models @@ -82,9 +81,9 @@ def extract_supported_models() -> dict[str, dict]: return extracted_models -def get_prerequisites(model_class, fields) -> List[dict[str, str]]: +def get_prerequisites(model_class, fields) -> list[dict[str, str]]: """Get the prerequisite models for the model.""" - prerequisites: List[dict[str, str]] = [] + prerequisites: list[dict[str, str]] = [] prerequisite_models = getattr(model_class, "prerequisite_models", []) for prereq in prerequisite_models: @@ -252,7 +251,7 @@ def get_serializer_for_model(model, prefix=""): return netbox_get_serializer_for_model(model, prefix) -def discover_models(root_packages: List[str]) -> list[Type[models.Model]]: +def discover_models(root_packages: list[str]) -> list[type[models.Model]]: """Discovers all model classes in specified root packages.""" discovered_models = [] diff --git a/netbox_diode_plugin/api/transformer.py b/netbox_diode_plugin/api/transformer.py index 5876425..d97802d 100644 --- a/netbox_diode_plugin/api/transformer.py +++ b/netbox_diode_plugin/api/transformer.py @@ -4,6 +4,7 @@ import copy import datetime +import graphlib import json import logging import re @@ -11,12 +12,12 @@ from functools import lru_cache from uuid import uuid4 -import graphlib from django.utils.text import slugify from extras.models.customfields import CustomField from rest_framework import serializers from .common import NON_FIELD_ERRORS, AutoSlug, ChangeSetException, UnresolvedReference, harmonize_formats, sort_ints_first +from .compat import apply_entity_migrations from .matcher import find_existing_object, fingerprints from .plugin_utils import ( CUSTOM_FIELD_OBJECT_REFERENCE_TYPE, @@ -125,6 +126,7 @@ def _transform_proto_json_1(proto_json: dict, object_type: str, context=None) -> # handle camelCase protoJSON if provided... proto_json = _ensure_snake_case(proto_json, object_type) apply_format_transformations(proto_json, object_type) + apply_entity_migrations(proto_json, object_type) # context pushed down from parent nodes if context is not None: @@ -474,7 +476,7 @@ def _update_dict_refs(data, new_refs): for k, v in data.items(): if isinstance(v, UnresolvedReference) and v.uuid in new_refs: v.uuid = new_refs[v.uuid] - elif isinstance(v, (list, tuple)): + elif isinstance(v, list | tuple): for item in v: if isinstance(item, UnresolvedReference) and item.uuid in new_refs: item.uuid = new_refs[item.uuid] @@ -517,7 +519,7 @@ def _update_resolved_refs(data, new_refs): for k, v in list(data.items()): if isinstance(v, UnresolvedReference) and v.uuid in new_refs: data[k] = new_refs[v.uuid] - elif isinstance(v, (list, tuple)): + elif isinstance(v, list | tuple): new_items = [] has_refs = False for item in v: @@ -539,7 +541,7 @@ def cleanup_unresolved_references(data: dict) -> list[str]: if k != 'id': unresolved.add(k) data[k] = str(v) - elif isinstance(v, (list, tuple)): + elif isinstance(v, list | tuple): items = [] for item in v: if isinstance(item, UnresolvedReference): @@ -608,7 +610,7 @@ def _prepare_custom_fields(object_type: str, custom_fields: dict) -> tuple[dict, keyname = key try: value_type, value = _pop_custom_field_type_and_value(value) - if value_type in ("text", "longText", "decimal", "boolean", "datetime", "selection", "url", "multipleSelection"): + if value_type in ("text", "long_text", "decimal", "boolean", "datetime", "selection", "url", "multiple_selection"): out[key] = value elif value_type == "date": # truncate to YYYY-MM-DD @@ -629,11 +631,11 @@ def _prepare_custom_fields(object_type: str, custom_fields: dict) -> tuple[dict, object_type=ref['_object_type'], uuid=ref['_uuid'], ) - elif value_type == "multipleObjects": + elif value_type == "multiple_objects": vals = [] for i, item in enumerate(value): keyname = f"{key}[{i}]" - nested = _prepare_custom_ref(value) + nested = _prepare_custom_ref(item) ref = nested[0] refs.add(ref['_uuid']) nodes += nested diff --git a/netbox_diode_plugin/navigation.py b/netbox_diode_plugin/navigation.py index cdc64ce..7efb9ff 100644 --- a/netbox_diode_plugin/navigation.py +++ b/netbox_diode_plugin/navigation.py @@ -9,12 +9,12 @@ PluginMenuItem( link="plugins:netbox_diode_plugin:settings", link_text=_("Settings"), - staff_only= True, + permissions=("netbox_diode_plugin.view_setting",), ), PluginMenuItem( link="plugins:netbox_diode_plugin:client_credential_list", link_text=_("Client Credentials"), - staff_only= True, + permissions=("netbox_diode_plugin.view_clientcredentials",), ), ) diff --git a/netbox_diode_plugin/tests/test_api_diff_and_apply.py b/netbox_diode_plugin/tests/test_api_diff_and_apply.py index 3847429..d3081ee 100644 --- a/netbox_diode_plugin/tests/test_api_diff_and_apply.py +++ b/netbox_diode_plugin/tests/test_api_diff_and_apply.py @@ -15,7 +15,7 @@ from core.models import ObjectType from dcim.models import Device, Interface, ModuleBay, Site from extras.models import CustomField -from extras.models.customfields import CustomFieldTypeChoices +from extras.models.customfields import CustomFieldChoiceSet, CustomFieldChoiceSetBaseChoices, CustomFieldTypeChoices from ipam.models import IPAddress, VLANGroup from rest_framework import status from utilities.testing import APITestCase @@ -98,6 +98,59 @@ def setUp(self): self.decimal_field.object_types.set([self.object_type]) self.decimal_field.save() + self.long_text_field = CustomField.objects.create( + name='my_long_text', + type=CustomFieldTypeChoices.TYPE_LONGTEXT, + required=False, + unique=False, + ) + self.long_text_field.object_types.set([self.object_type]) + self.long_text_field.save() + + choices = CustomFieldChoiceSet.objects.create( + name='my_choices', + base_choices=CustomFieldChoiceSetBaseChoices.IATA, + ) + self.selection_field = CustomField.objects.create( + name='my_selection', + type=CustomFieldTypeChoices.TYPE_SELECT, + required=False, + unique=False, + choice_set=choices, + ) + self.selection_field.object_types.set([self.object_type]) + self.selection_field.save() + + self.multiple_selection_field = CustomField.objects.create( + name='my_multiple_selection', + type=CustomFieldTypeChoices.TYPE_MULTISELECT, + required=False, + unique=False, + choice_set=choices, + ) + self.multiple_selection_field.object_types.set([self.object_type]) + self.multiple_selection_field.save() + + self.object_field = CustomField.objects.create( + name='my_object', + type=CustomFieldTypeChoices.TYPE_OBJECT, + required=False, + unique=False, + related_object_type=self.object_type, + ) + self.object_field.object_types.set([self.object_type]) + self.object_field.save() + + self.multiple_objects_field = CustomField.objects.create( + name='my_multiple_objects', + type=CustomFieldTypeChoices.TYPE_MULTIOBJECT, + required=False, + unique=False, + related_object_type=self.object_type, + ) + self.multiple_objects_field.object_types.set([self.object_type]) + self.multiple_objects_field.save() + def tearDown(self): """Clean up after tests.""" self.introspect_patcher.stop() @@ -539,6 +592,36 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): "some_json": { "json": '{"some_key": 9876543210}', }, + "my_long_text": { + "long_text": "This is a long text", + }, + "my_selection": { + "selection": "LAX", + }, + "my_multiple_selection": { + "multiple_selection": ["JFK", "LAX"], + }, + "my_object": { + "object": { + "site": { + "name": "Custom Object Site Ref 1", + } + }, + }, + "my_multiple_objects": { + "multiple_objects": [ + { + "site": { + "name": "Custom Object Site Ref 2", + } + }, + { + "site": { + "name": "Custom Object Site Ref 3", + } + }, + ], + }, }, }, } @@ -549,6 +632,19 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): self.assertEqual(new_site.custom_field_data[self.uuid_field.name], site_uuid) self.assertEqual(new_site.custom_field_data[self.json_field.name], {"some_key": 9876543210}) self.assertEqual(new_site.custom_field_data[self.decimal_field.name], 1234.567) + self.assertEqual(new_site.custom_field_data[self.long_text_field.name], "This is a long text") + self.assertEqual(new_site.custom_field_data[self.selection_field.name], "LAX") + self.assertEqual(new_site.custom_field_data[self.multiple_selection_field.name], ["JFK", "LAX"]) + + siteRef1 = Site.objects.get(name="Custom Object Site Ref 1") + self.assertIsNotNone(siteRef1) + self.assertEqual(new_site.custom_field_data[self.object_field.name], siteRef1.pk) + siteRef2 = Site.objects.get(name="Custom Object Site Ref 2") + self.assertIsNotNone(siteRef2) + self.assertEqual(new_site.custom_field_data[self.multiple_objects_field.name][0], siteRef2.pk) + siteRef3 = Site.objects.get(name="Custom Object Site Ref 3") + self.assertIsNotNone(siteRef3) + self.assertEqual(new_site.custom_field_data[self.multiple_objects_field.name][1], siteRef3.pk) payload = { "timestamp": 1, @@ -596,6 +692,21 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): "mydate": { "date": "2026-01-02T00:00:00Z", }, + "my_multiple_objects": { + "multiple_objects": [ + { + "site": { + "name": "Custom Object Site Ref 2", + } + }, + { + "site": { + "name": "Custom Object Site Ref 4", + } + }, + ], + }, + }, }, } @@ -608,6 +719,15 @@ def test_generate_diff_and_apply_create_and_update_site_with_custom_field(self): self.assertEqual(new_site.cf[self.datetime_field.name], datetime.datetime(2026, 1, 1, 10, 0, 0, tzinfo=datetime.timezone.utc)) self.assertEqual(new_site.cf[self.date_field.name], datetime.date(2026, 1, 2)) + self.assertEqual(len(new_site.custom_field_data[self.multiple_objects_field.name]), 2) + siteRef2 = Site.objects.get(name="Custom Object Site Ref 2") + self.assertIsNotNone(siteRef2) + self.assertEqual(new_site.custom_field_data[self.multiple_objects_field.name][0], siteRef2.pk) + siteRef4 = Site.objects.get(name="Custom Object Site Ref 4") + self.assertIsNotNone(siteRef3) + self.assertEqual(new_site.custom_field_data[self.multiple_objects_field.name][1], siteRef4.pk) + + payload = { "timestamp": 1, "object_type": "dcim.site", diff --git a/netbox_diode_plugin/tests/test_updates.py b/netbox_diode_plugin/tests/test_updates.py index d0f085c..57e4981 100644 --- a/netbox_diode_plugin/tests/test_updates.py +++ b/netbox_diode_plugin/tests/test_updates.py @@ -26,7 +26,7 @@ def _harmonize_formats(data): return _tuples_to_lists(data) def _tuples_to_lists(data): - if isinstance(data, (tuple, list)): + if isinstance(data, tuple | list): return [_tuples_to_lists(d) for d in data] if isinstance(data, dict): return {k: _tuples_to_lists(v) for k, v in data.items()} @@ -165,7 +165,7 @@ def _check_set_by(self, obj, path, value): path = path[:-1] cur = self._follow_path(obj, path) - if isinstance(value, (list, tuple)): + if isinstance(value, list | tuple): vals = set(value) else: vals = {value} diff --git a/netbox_diode_plugin/tests/test_views.py b/netbox_diode_plugin/tests/test_views.py index cf5d93a..80620f6 100644 --- a/netbox_diode_plugin/tests/test_views.py +++ b/netbox_diode_plugin/tests/test_views.py @@ -8,9 +8,12 @@ from django.contrib.messages.middleware import MessageMiddleware from django.contrib.messages.storage.fallback import FallbackStorage from django.contrib.sessions.middleware import SessionMiddleware -from django.test import RequestFactory, TestCase +from django.test import RequestFactory +from django.test import TestCase as _TestCase from django.urls import reverse from rest_framework import status +from users.models import ObjectPermission +from utilities.permissions import resolve_permission_type from netbox_diode_plugin.models import Setting from netbox_diode_plugin.views import SettingsEditView, SettingsView @@ -18,6 +21,19 @@ User = get_user_model() +class TestCase(_TestCase): + """Base test case class for NetBox Diode plugin tests.""" + + def add_permissions(self, user, *names): + """Assign a set of permissions to the test user. Accepts permission names in the form ._.""" + for name in names: + object_type, action = resolve_permission_type(name) + obj_perm = ObjectPermission(name=name, actions=[action]) + obj_perm.save() + obj_perm.users.add(user) + obj_perm.object_types.add(object_type) + + class SettingsViewTestCase(TestCase): """Test case for the SettingsView.""" @@ -31,7 +47,7 @@ def setUp(self): def test_returns_200_for_authenticated(self): """Test that the view returns 200 for an authenticated user.""" self.request.user = User.objects.create_user("foo", password="pass") - self.request.user.is_staff = True + self.add_permissions(self.request.user, "netbox_diode_plugin.view_setting") response = self.view.get(self.request) self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -49,7 +65,7 @@ def test_redirects_to_login_page_for_unauthenticated_user(self): def test_settings_created_if_not_found(self): """Test that the settings are created with placeholder data if not found.""" self.request.user = User.objects.create_user("foo", password="pass") - self.request.user.is_staff = True + self.add_permissions(self.request.user, "netbox_diode_plugin.view_setting") with mock.patch("netbox_diode_plugin.models.Setting.objects.get") as mock_get: mock_get.side_effect = Setting.DoesNotExist @@ -71,8 +87,8 @@ def setUp(self): def test_returns_200_for_authenticated(self): """Test that the view returns 200 for an authenticated user.""" request = self.request_factory.get(self.path) - request.user = User.objects.create_user("foo", password="pass") - request.user.is_staff = True + request.user = User.objects.create_user("foo", password="pass", is_staff=True) + self.add_permissions(request.user, "netbox_diode_plugin.view_setting", "netbox_diode_plugin.change_setting") request.htmx = None self.view.setup(request) @@ -91,8 +107,8 @@ def test_redirects_to_login_page_for_unauthenticated_user(self): def test_settings_updated(self): """Test that the settings are updated.""" - user = User.objects.create_user("foo", password="pass") - user.is_staff = True + user = User.objects.create_user("foo", password="pass", is_staff=True) + self.add_permissions(user, "netbox_diode_plugin.view_setting", "netbox_diode_plugin.change_setting") request = self.request_factory.get(self.path) request.user = user @@ -149,8 +165,13 @@ def test_settings_update_disallowed_on_get_method(self): ) as mock_get_plugin_config: mock_get_plugin_config.return_value = "grpc://localhost:8080/diode" - user = User.objects.create_user("foo", password="pass") - user.is_staff = True + user = User.objects.create_user("foo", password="pass", is_staff=True) + self.add_permissions( + user, + "netbox_diode_plugin.view_setting", + "netbox_diode_plugin.add_setting", + "netbox_diode_plugin.change_setting", + ) request = self.request_factory.post(self.path) request.user = user @@ -188,8 +209,13 @@ def test_settings_update_disallowed_on_post_method(self): ) as mock_get_plugin_config: mock_get_plugin_config.return_value = "grpc://localhost:8080/diode" - user = User.objects.create_user("foo", password="pass") - user.is_staff = True + user = User.objects.create_user("foo", password="pass", is_staff=True) + self.add_permissions( + user, + "netbox_diode_plugin.view_setting", + "netbox_diode_plugin.add_setting", + "netbox_diode_plugin.change_setting", + ) request = self.request_factory.post(self.path) request.user = user diff --git a/netbox_diode_plugin/views.py b/netbox_diode_plugin/views.py index 1143c5d..79f9cc6 100644 --- a/netbox_diode_plugin/views.py +++ b/netbox_diode_plugin/views.py @@ -2,10 +2,13 @@ # Copyright 2025 NetBox Labs, Inc. """Diode NetBox Plugin - Views.""" import logging +from collections import defaultdict from django.conf import settings as netbox_settings from django.contrib import messages from django.contrib.auth import get_user_model +from django.core.exceptions import ImproperlyConfigured +from django.db.models import Q from django.http import HttpResponseRedirect from django.shortcuts import redirect, render from django.urls import reverse @@ -14,9 +17,10 @@ from django.views.generic import View from netbox.plugins import get_plugin_config from netbox.views import generic +from users.models import ObjectPermission from utilities.forms import ConfirmationForm from utilities.htmx import htmx_partial -from utilities.permissions import get_permission_for_model +from utilities.permissions import get_permission_for_model, permission_is_exempt from utilities.views import register_model_view from netbox_diode_plugin.client import create_client, delete_client, get_client, list_clients @@ -40,13 +44,105 @@ def redirect_to_login(request): return HttpResponseRedirect(redirect_url) -class SettingsView(View): +class BaseDiodeView(View): + """ + Base view class for Diode plugin views. + + Provides authentication and permission checking functionality for views + that need to interact with the Diode API. Includes methods for: + - Object permission filtering and retrieval + - Permission checking for authenticated users + - Authentication validation for requests + """ + + def get_permission_filter(self, user_obj): + """Return the permission filter for the user.""" + return Q(users=user_obj) | Q(groups__user=user_obj) + + def get_object_permissions(self, user_obj): + """Return all permissions granted to the user by an ObjectPermission.""" + # Initialize a dictionary mapping permission names to sets of constraints + perms = defaultdict(list) + + # Collect any configured default permissions + for perm_name, constraints in netbox_settings.DEFAULT_PERMISSIONS.items(): + constraints = constraints or () + if type(constraints) not in (list, tuple): + raise ImproperlyConfigured( + f"Constraints for default permission {perm_name} must be defined as a list or tuple." + ) + perms[perm_name].extend(constraints) + + # Retrieve all assigned and enabled ObjectPermissions + object_permissions = ObjectPermission.objects.filter( + self.get_permission_filter(user_obj), + enabled=True + ).order_by('id').distinct('id').prefetch_related('object_types') + + # Create a dictionary mapping permissions to their constraints + for obj_perm in object_permissions: + for object_type in obj_perm.object_types.all(): + for action in obj_perm.actions: + perm_name = f"{object_type.app_label}.{action}_{object_type.model}" + perms[perm_name].extend(obj_perm.list_constraints()) + + return perms + + def get_all_permissions(self, user_obj, obj=None): + """Get all permissions for the user.""" + if not user_obj.is_active or user_obj.is_anonymous: + return {} + if not hasattr(user_obj, '_object_perm_cache'): + user_obj._object_perm_cache = self.get_object_permissions(user_obj) + return user_obj._object_perm_cache + + def has_perm(self, user_obj, perm): + """Check if the user has the required permission.""" + # Superusers implicitly have all permissions + if not user_obj.is_authenticated: + return False + + if user_obj.is_active and user_obj.is_superuser: + return True + + # Permission is exempt from enforcement (i.e. listed in EXEMPT_VIEW_PERMISSIONS) + if permission_is_exempt(perm): + return True + + # Handle inactive/anonymous users + if not user_obj.is_active or user_obj.is_anonymous: + return False + + object_permissions = self.get_all_permissions(user_obj) + + # If no applicable ObjectPermissions have been created for this user/permission, deny permission + if perm not in object_permissions: + return False + + return True + + def check_authentication(self, request): + """Check if the user has the required permission.""" + if not request.user.is_authenticated: + return redirect_to_login(request) + + if not self.has_perm(request.user, self.get_required_permission()): + return redirect( + reverse("home",) + ) + return None + +class SettingsView(BaseDiodeView): """Settings view.""" + def get_required_permission(self): + """Return the permission required to view Diode plugin settings.""" + return "netbox_diode_plugin.view_setting" + def get(self, request): """Render settings template.""" - if not request.user.is_authenticated or not request.user.is_staff: - return redirect_to_login(request) + if ret := self.check_authentication(request): + return ret diode_target_override = get_plugin_config( "netbox_diode_plugin", "diode_target_override" @@ -73,7 +169,7 @@ def get(self, request): @register_model_view(Setting, "edit") -class SettingsEditView(generic.ObjectEditView): +class SettingsEditView(BaseDiodeView,generic.ObjectEditView): """Settings edit view.""" queryset = Setting.objects @@ -81,10 +177,14 @@ class SettingsEditView(generic.ObjectEditView): template_name = "diode/settings_edit.html" default_return_url = "plugins:netbox_diode_plugin:settings" + def get_required_permission(self): + """Return the permission required to view Diode plugin settings.""" + return "netbox_diode_plugin.change_setting" + def get(self, request, *args, **kwargs): """GET request handler.""" - if not request.user.is_authenticated or not request.user.is_staff: - return redirect_to_login(request) + if ret := self.check_authentication(request): + return ret diode_target_override = get_plugin_config( "netbox_diode_plugin", "diode_target_override" @@ -103,8 +203,8 @@ def get(self, request, *args, **kwargs): def post(self, request, *args, **kwargs): """POST request handler.""" - if not request.user.is_authenticated or not request.user.is_staff: - return redirect_to_login(request) + if ret := self.check_authentication(request): + return ret diode_target_override = get_plugin_config( "netbox_diode_plugin", "diode_target_override" @@ -138,19 +238,6 @@ def get_return_url(self, request): return None -class BaseDiodeView(View): - """Base diode view.""" - - def check_authentication(self, request): - """Check authentication.""" - if not request.user.is_authenticated or not request.user.is_staff: - return redirect_to_login(request) - return None - - def get_required_permission(self): - """Get required permission.""" - return get_permission_for_model(self.model, "view") - class ClientCredentialListView(BaseDiodeView): """Client credential list view.""" @@ -158,6 +245,10 @@ class ClientCredentialListView(BaseDiodeView): template_name = "diode/client_credential_list.html" model = ClientCredentials + def get_required_permission(self): + """Return the permission required to view client credentials list.""" + return "netbox_diode_plugin.view_clientcredentials" + def get_table_data(self, request): """Get table data.""" try: @@ -211,6 +302,10 @@ class ClientCredentialDeleteView(GetReturnURLMixin, BaseDiodeView): template_name = "diode/client_credential_delete.html" default_return_url = "plugins:netbox_diode_plugin:client_credential_list" + def get_required_permission(self): + """Return the permission required to delete client credentials.""" + return "netbox_diode_plugin.delete_clientcredentials" + def get(self, request, client_credential_id): """GET request handler.""" if ret := self.check_authentication(request): @@ -260,6 +355,10 @@ class ClientCredentialAddView(GetReturnURLMixin, BaseDiodeView): form_class = ClientCredentialForm default_return_url = "plugins:netbox_diode_plugin:client_credential_list" + def get_required_permission(self): + """Return the permission required to add new client credentials.""" + return "netbox_diode_plugin.add_clientcredentials" + def get(self, request): """GET request handler.""" if ret := self.check_authentication(request): @@ -312,6 +411,10 @@ class ClientCredentialSecretView(BaseDiodeView): template_name = "diode/client_credential_secret.html" + def get_required_permission(self): + """Return the permission required to view client credential secrets.""" + return "netbox_diode_plugin.view_clientcredentials" + def get(self, request): """Get request handler.""" if ret := self.check_authentication(request): diff --git a/pyproject.toml b/pyproject.toml index 54c7eed..fe04775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "netboxlabs-diode-netbox-plugin" version = "0.0.1" # Overwritten during the build process description = "NetBox Labs, Diode NetBox plugin" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10" license = { text = "NetBox Limited Use License 1.0" } authors = [ {name = "NetBox Labs", email = "support@netboxlabs.com" } @@ -18,10 +18,10 @@ classifiers = [ "Topic :: Software Development :: Build Tools", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ] dependencies = [