Closes #21257: Introduce & adopt MultiValueContentTypeFilter (#21417)

This commit is contained in:
Jeremy Stretch
2026-02-13 05:24:36 -05:00
committed by GitHub
parent 76fd3e3c61
commit dc738c7102
21 changed files with 105 additions and 78 deletions

View File

@@ -9,7 +9,7 @@ from ipam.models import ASN
from netbox.filtersets import NetBoxModelFilterSet, OrganizationalModelFilterSet, PrimaryModelFilterSet
from tenancy.filtersets import ContactModelFilterSet, TenancyFilterSet
from utilities.filters import (
ContentTypeFilter, MultiValueCharFilter, MultiValueNumberFilter, TreeNodeMultipleChoiceFilter,
MultiValueCharFilter, MultiValueContentTypeFilter, MultiValueNumberFilter, TreeNodeMultipleChoiceFilter,
)
from utilities.filtersets import register_filterset
from .choices import *
@@ -281,7 +281,7 @@ class CircuitTerminationFilterSet(NetBoxModelFilterSet, CabledObjectFilterSet):
queryset=Circuit.objects.all(),
label=_('Circuit'),
)
termination_type = ContentTypeFilter()
termination_type = MultiValueContentTypeFilter()
region_id = TreeNodeMultipleChoiceFilter(
queryset=Region.objects.all(),
field_name='_region',
@@ -381,7 +381,7 @@ class CircuitGroupAssignmentFilterSet(NetBoxModelFilterSet):
method='search',
label=_('Search'),
)
member_type = ContentTypeFilter()
member_type = MultiValueContentTypeFilter()
circuit = MultiValueCharFilter(
method='filter_circuit',
field_name='cid',

View File

@@ -6,7 +6,7 @@ from django.utils.translation import gettext as _
from netbox.filtersets import BaseFilterSet, ChangeLoggedModelFilterSet, PrimaryModelFilterSet
from netbox.utils import get_data_backend_choices
from users.models import User
from utilities.filters import ContentTypeFilter
from utilities.filters import MultiValueContentTypeFilter
from utilities.filtersets import register_filterset
from .choices import *
from .models import *
@@ -88,7 +88,7 @@ class JobFilterSet(BaseFilterSet):
queryset=ObjectType.objects.with_feature('jobs'),
field_name='object_type_id',
)
object_type = ContentTypeFilter()
object_type = MultiValueContentTypeFilter()
created = django_filters.DateTimeFilter()
created__before = django_filters.DateTimeFilter(
field_name='created',
@@ -180,11 +180,11 @@ class ObjectChangeFilterSet(BaseFilterSet):
label=_('Search'),
)
time = django_filters.DateTimeFromToRangeFilter()
changed_object_type = ContentTypeFilter()
changed_object_type = MultiValueContentTypeFilter()
changed_object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ContentType.objects.all()
)
related_object_type = ContentTypeFilter()
related_object_type = MultiValueContentTypeFilter()
user_id = django_filters.ModelMultipleChoiceFilter(
queryset=User.objects.all(),
label=_('User (ID)'),

View File

@@ -237,9 +237,9 @@ class ObjectChangeTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_changed_object_type(self):
params = {'changed_object_type': 'dcim.site'}
params = {'changed_object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'changed_object_type_id': [ContentType.objects.get(app_label='dcim', model='site').pk]}
params = {'changed_object_type_id': [ContentType.objects.get_by_natural_key('dcim', 'site').pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)

View File

@@ -2,7 +2,7 @@ import django_filters
from django.utils.translation import gettext as _
from netbox.filtersets import BaseFilterSet
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
from utilities.filters import MultiValueContentTypeFilter, TreeNodeMultipleChoiceFilter
from .models import *
__all__ = (
@@ -14,7 +14,7 @@ class ScopedFilterSet(BaseFilterSet):
"""
Provides additional filtering functionality for location, site, etc.. for Scoped models.
"""
scope_type = ContentTypeFilter()
scope_type = MultiValueContentTypeFilter()
region_id = TreeNodeMultipleChoiceFilter(
queryset=Region.objects.all(),
field_name='_region',

View File

@@ -21,8 +21,8 @@ from tenancy.models import *
from users.filterset_mixins import OwnerFilterMixin
from users.models import User
from utilities.filters import (
ContentTypeFilter, MultiValueCharFilter, MultiValueMACAddressFilter, MultiValueNumberFilter, MultiValueWWNFilter,
NumericArrayFilter, TreeNodeMultipleChoiceFilter,
MultiValueCharFilter, MultiValueContentTypeFilter, MultiValueMACAddressFilter, MultiValueNumberFilter,
MultiValueWWNFilter, NumericArrayFilter, TreeNodeMultipleChoiceFilter,
)
from utilities.filtersets import register_filterset
from virtualization.models import Cluster, ClusterGroup, VirtualMachine, VMInterface
@@ -977,7 +977,7 @@ class InventoryItemTemplateFilterSet(ChangeLoggedModelFilterSet, DeviceTypeCompo
to_field_name='slug',
label=_('Role (slug)'),
)
component_type = ContentTypeFilter()
component_type = MultiValueContentTypeFilter()
component_id = MultiValueNumberFilter()
class Meta:
@@ -1822,7 +1822,7 @@ class PowerOutletFilterSet(ModularDeviceComponentFilterSet, CabledObjectFilterSe
@register_filterset
class MACAddressFilterSet(PrimaryModelFilterSet):
mac_address = MultiValueMACAddressFilter()
assigned_object_type = ContentTypeFilter()
assigned_object_type = MultiValueContentTypeFilter()
device = MultiValueCharFilter(
method='filter_device',
field_name='name',
@@ -2267,7 +2267,7 @@ class InventoryItemFilterSet(DeviceComponentFilterSet):
to_field_name='slug',
label=_('Role (slug)'),
)
component_type = ContentTypeFilter()
component_type = MultiValueContentTypeFilter()
component_id = MultiValueNumberFilter()
serial = MultiValueCharFilter(
lookup_expr='iexact'
@@ -2381,14 +2381,14 @@ class VirtualChassisFilterSet(PrimaryModelFilterSet):
@register_filterset
class CableFilterSet(TenancyFilterSet, PrimaryModelFilterSet):
termination_a_type = ContentTypeFilter(
termination_a_type = MultiValueContentTypeFilter(
field_name='terminations__termination_type'
)
termination_a_id = MultiValueNumberFilter(
method='filter_by_cable_end_a',
field_name='terminations__termination_id'
)
termination_b_type = ContentTypeFilter(
termination_b_type = MultiValueContentTypeFilter(
field_name='terminations__termination_type'
)
termination_b_id = MultiValueNumberFilter(
@@ -2554,7 +2554,7 @@ class CableFilterSet(TenancyFilterSet, PrimaryModelFilterSet):
@register_filterset
class CableTerminationFilterSet(ChangeLoggedModelFilterSet):
termination_type = ContentTypeFilter()
termination_type = MultiValueContentTypeFilter()
class Meta:
model = CableTermination

View File

@@ -6251,7 +6251,7 @@ class InventoryItemTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_component_type(self):
params = {'component_type': 'dcim.interface'}
params = {'component_type': ['dcim.interface']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_status(self):
@@ -6723,10 +6723,8 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_termination_types(self):
params = {'termination_a_type': 'dcim.consoleport'}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
# params = {'termination_b_type': 'dcim.consoleserverport'}
# self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'termination_a_type': ['dcim.consoleport', 'dcim.consoleserverport']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_termination_ids(self):
interface_ids = CableTermination.objects.filter(
@@ -6734,7 +6732,7 @@ class CableTestCase(TestCase, ChangeLoggedFilterSetTests):
cable_end='A'
).values_list('termination_id', flat=True)
params = {
'termination_a_type': 'dcim.interface',
'termination_a_type': ['dcim.interface'],
'termination_a_id': list(interface_ids),
}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)

View File

@@ -10,7 +10,7 @@ from tenancy.models import Tenant, TenantGroup
from users.filterset_mixins import OwnerFilterMixin
from users.models import Group, User
from utilities.filters import (
ContentTypeFilter, MultiValueCharFilter, MultiValueNumberFilter
MultiValueCharFilter, MultiValueContentTypeFilter, MultiValueNumberFilter
)
from utilities.filtersets import register_filterset
from virtualization.models import Cluster, ClusterGroup, ClusterType
@@ -104,7 +104,7 @@ class EventRuleFilterSet(OwnerFilterMixin, NetBoxModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
event_type = MultiValueCharFilter(
@@ -113,7 +113,7 @@ class EventRuleFilterSet(OwnerFilterMixin, NetBoxModelFilterSet):
action_type = django_filters.MultipleChoiceFilter(
choices=EventRuleActionChoices
)
action_object_type = ContentTypeFilter()
action_object_type = MultiValueContentTypeFilter()
action_object_id = MultiValueNumberFilter()
class Meta:
@@ -148,14 +148,14 @@ class CustomFieldFilterSet(OwnerFilterMixin, ChangeLoggedModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
related_object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ObjectType.objects.all(),
field_name='related_object_type'
)
related_object_type = ContentTypeFilter()
related_object_type = MultiValueContentTypeFilter()
choice_set_id = django_filters.ModelMultipleChoiceFilter(
queryset=CustomFieldChoiceSet.objects.all()
)
@@ -224,7 +224,7 @@ class CustomLinkFilterSet(OwnerFilterMixin, ChangeLoggedModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
@@ -255,7 +255,7 @@ class ExportTemplateFilterSet(OwnerFilterMixin, ChangeLoggedModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
data_source_id = django_filters.ModelMultipleChoiceFilter(
@@ -294,7 +294,7 @@ class SavedFilterFilterSet(OwnerFilterMixin, ChangeLoggedModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
user_id = django_filters.ModelMultipleChoiceFilter(
@@ -347,7 +347,7 @@ class TableConfigFilterSet(ChangeLoggedModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_type'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_type'
)
user_id = django_filters.ModelMultipleChoiceFilter(
@@ -395,7 +395,7 @@ class TableConfigFilterSet(ChangeLoggedModelFilterSet):
class BookmarkFilterSet(BaseFilterSet):
created = django_filters.DateTimeFilter()
object_type_id = MultiValueNumberFilter()
object_type = ContentTypeFilter()
object_type = MultiValueContentTypeFilter()
user_id = django_filters.ModelMultipleChoiceFilter(
queryset=User.objects.all(),
label=_('User (ID)'),
@@ -462,7 +462,7 @@ class ImageAttachmentFilterSet(ChangeLoggedModelFilterSet):
method='search',
label=_('Search'),
)
object_type = ContentTypeFilter()
object_type = MultiValueContentTypeFilter()
class Meta:
model = ImageAttachment
@@ -481,7 +481,7 @@ class ImageAttachmentFilterSet(ChangeLoggedModelFilterSet):
@register_filterset
class JournalEntryFilterSet(NetBoxModelFilterSet):
created = django_filters.DateTimeFromToRangeFilter()
assigned_object_type = ContentTypeFilter()
assigned_object_type = MultiValueContentTypeFilter()
assigned_object_type_id = django_filters.ModelMultipleChoiceFilter(
queryset=ContentType.objects.all()
)
@@ -576,7 +576,7 @@ class TaggedItemFilterSet(BaseFilterSet):
method='search',
label=_('Search'),
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='content_type'
)
object_type_id = django_filters.ModelMultipleChoiceFilter(

View File

@@ -304,7 +304,7 @@ class ConditionSetTest(TestCase):
Test Event Rule with incorrect condition (key "foo" is wrong). Must return false.
"""
ct = ContentType.objects.get(app_label='extras', model='webhook')
ct = ContentType.objects.get_by_natural_key('extras', 'webhook')
site_ct = ContentType.objects.get_for_model(Site)
webhook = Webhook.objects.create(name='Webhook 100', payload_url='http://example.com/?1', http_method='POST')
form = EventRuleForm({

View File

@@ -111,13 +111,13 @@ class CustomFieldTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'object_type_id': [ObjectType.objects.get_by_natural_key('dcim', 'site').pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_related_object_type(self):
params = {'related_object_type': 'dcim.site'}
params = {'related_object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'related_object_type_id': [ObjectType.objects.get_by_natural_key('dcim', 'site').pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -348,7 +348,7 @@ class EventRuleTestCase(TestCase, BaseFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.region'}
params = {'object_type': ['dcim.region']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'object_type_id': [ObjectType.objects.get_for_model(Region).pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -417,7 +417,7 @@ class CustomLinkTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'object_type_id': [ObjectType.objects.get_for_model(Site).pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -508,7 +508,7 @@ class SavedFilterTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'object_type_id': [ObjectType.objects.get_for_model(Site).pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -600,7 +600,7 @@ class BookmarkTestCase(TestCase, BaseFilterSetTests):
Bookmark.objects.bulk_create(bookmarks)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'object_type_id': [ContentType.objects.get_for_model(Site).pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
@@ -663,7 +663,7 @@ class ExportTemplateTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
params = {'object_type_id': [ObjectType.objects.get_for_model(Site).pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -697,8 +697,8 @@ class ImageAttachmentTestCase(TestCase, ChangeLoggedFilterSetTests):
@classmethod
def setUpTestData(cls):
site_ct = ContentType.objects.get(app_label='dcim', model='site')
rack_ct = ContentType.objects.get(app_label='dcim', model='rack')
site_ct = ContentType.objects.get_by_natural_key('dcim', 'site')
rack_ct = ContentType.objects.get_by_natural_key('dcim', 'rack')
sites = (
Site(name='Site 1', slug='site-1'),
@@ -757,12 +757,12 @@ class ImageAttachmentTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type(self):
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_object_type_id_and_object_id(self):
params = {
'object_type_id': ContentType.objects.get(app_label='dcim', model='site').pk,
'object_type_id': ContentType.objects.get_by_natural_key('dcim', 'site').pk,
'object_id': [Site.objects.first().pk],
}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
@@ -845,14 +845,14 @@ class JournalEntryTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_assigned_object_type(self):
params = {'assigned_object_type': 'dcim.site'}
params = {'assigned_object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'assigned_object_type_id': [ContentType.objects.get(app_label='dcim', model='site').pk]}
params = {'assigned_object_type_id': [ContentType.objects.get_by_natural_key('dcim', 'site').pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_assigned_object(self):
params = {
'assigned_object_type': 'dcim.site',
'assigned_object_type': ['dcim.site'],
'assigned_object_id': [Site.objects.first().pk],
}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
@@ -1426,15 +1426,15 @@ class TaggedItemFilterSetTestCase(TestCase):
def test_object_type(self):
object_type = ObjectType.objects.get_for_model(Site)
params = {'object_type': 'dcim.site'}
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'object_type_id': [object_type.pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_object_id(self):
def test_object(self):
site_ids = Site.objects.values_list('pk', flat=True)
params = {
'object_type': 'dcim.site',
'object_type': ['dcim.site'],
'object_id': site_ids[:2],
}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)

View File

@@ -17,7 +17,7 @@ from virtualization.models import Cluster, ClusterGroup, ClusterType, VirtualMac
class ImageAttachmentTests(TestCase):
@classmethod
def setUpTestData(cls):
cls.ct_rack = ContentType.objects.get(app_label='dcim', model='rack')
cls.ct_rack = ContentType.objects.get_by_natural_key('dcim', 'rack')
cls.image_content = b''
def _stub_image_attachment(self, object_id, image_filename, name=None):

View File

@@ -27,7 +27,7 @@ class ImageUploadTests(TestCase):
def setUpTestData(cls):
# We only need a ContentType with model="rack" for the prefix;
# this doesn't require creating a Rack object.
cls.ct_rack = ContentType.objects.get(app_label='dcim', model='rack')
cls.ct_rack = ContentType.objects.get_by_natural_key('dcim', 'rack')
def _stub_instance(self, object_id=12, name=None):
"""

View File

@@ -16,7 +16,8 @@ from netbox.filtersets import (
)
from tenancy.filtersets import ContactModelFilterSet, TenancyFilterSet
from utilities.filters import (
ContentTypeFilter, MultiValueCharFilter, MultiValueNumberFilter, NumericArrayFilter, TreeNodeMultipleChoiceFilter,
MultiValueCharFilter, MultiValueContentTypeFilter, MultiValueNumberFilter, NumericArrayFilter,
TreeNodeMultipleChoiceFilter,
)
from utilities.filtersets import register_filterset
from virtualization.models import VirtualMachine, VMInterface
@@ -607,7 +608,7 @@ class IPAddressFilterSet(PrimaryModelFilterSet, TenancyFilterSet, ContactModelFi
to_field_name='rd',
label=_('VRF (RD)'),
)
assigned_object_type = ContentTypeFilter()
assigned_object_type = MultiValueContentTypeFilter()
device = MultiValueCharFilter(
method='filter_device',
field_name='name',
@@ -846,7 +847,7 @@ class FHRPGroupFilterSet(PrimaryModelFilterSet):
@register_filterset
class FHRPGroupAssignmentFilterSet(ChangeLoggedModelFilterSet):
interface_type = ContentTypeFilter()
interface_type = MultiValueContentTypeFilter()
group_id = django_filters.ModelMultipleChoiceFilter(
queryset=FHRPGroup.objects.all(),
label=_('Group (ID)'),
@@ -901,7 +902,7 @@ class FHRPGroupAssignmentFilterSet(ChangeLoggedModelFilterSet):
@register_filterset
class VLANGroupFilterSet(OrganizationalModelFilterSet, TenancyFilterSet):
scope_type = ContentTypeFilter()
scope_type = MultiValueContentTypeFilter()
region = django_filters.NumberFilter(
method='filter_scope'
)
@@ -1173,7 +1174,7 @@ class ServiceTemplateFilterSet(PrimaryModelFilterSet):
@register_filterset
class ServiceFilterSet(ContactModelFilterSet, PrimaryModelFilterSet):
parent_object_type = ContentTypeFilter()
parent_object_type = MultiValueContentTypeFilter()
device = MultiValueCharFilter(
method='filter_device',
field_name='name',

View File

@@ -1572,12 +1572,12 @@ class FHRPGroupAssignmentTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_interface_type(self):
params = {'interface_type': 'dcim.interface'}
params = {'interface_type': ['dcim.interface']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_interface(self):
interfaces = Interface.objects.all()[:2]
params = {'interface_type': 'dcim.interface', 'interface_id': [interfaces[0].pk, interfaces[1].pk]}
params = {'interface_type': ['dcim.interface'], 'interface_id': [interfaces[0].pk, interfaces[1].pk]}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)
def test_priority(self):

View File

@@ -5,7 +5,7 @@ from django.utils.translation import gettext as _
from netbox.filtersets import (
NestedGroupModelFilterSet, NetBoxModelFilterSet, OrganizationalModelFilterSet, PrimaryModelFilterSet,
)
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
from utilities.filters import MultiValueContentTypeFilter, TreeNodeMultipleChoiceFilter
from utilities.filtersets import register_filterset
from .models import *
@@ -110,7 +110,7 @@ class ContactAssignmentFilterSet(NetBoxModelFilterSet):
method='search',
label=_('Search'),
)
object_type = ContentTypeFilter()
object_type = MultiValueContentTypeFilter()
contact_id = django_filters.ModelMultipleChoiceFilter(
queryset=Contact.objects.all(),
label=_('Contact (ID)'),

View File

@@ -355,6 +355,8 @@ class ContactAssignmentTestCase(TestCase, ChangeLoggedFilterSetTests):
ContactAssignment.objects.bulk_create(assignments)
def test_object_type(self):
params = {'object_type': ['dcim.site']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'object_type_id': ObjectType.objects.get_by_natural_key('dcim', 'site')}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)

View File

@@ -6,7 +6,7 @@ from core.models import ObjectType
from extras.models import NotificationGroup
from netbox.filtersets import BaseFilterSet
from users.models import Group, ObjectPermission, Owner, OwnerGroup, Token, User
from utilities.filters import ContentTypeFilter
from utilities.filters import MultiValueContentTypeFilter
from utilities.filtersets import register_filterset
__all__ = (
@@ -194,7 +194,7 @@ class ObjectPermissionFilterSet(BaseFilterSet):
queryset=ObjectType.objects.all(),
field_name='object_types'
)
object_type = ContentTypeFilter(
object_type = MultiValueContentTypeFilter(
field_name='object_types'
)
can_view = django_filters.BooleanFilter(

View File

@@ -1,6 +1,7 @@
import django_filters
from django import forms
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django_filters.constants import EMPTY_VALUES
from drf_spectacular.types import OpenApiTypes
@@ -10,6 +11,7 @@ __all__ = (
'ContentTypeFilter',
'MultiValueArrayFilter',
'MultiValueCharFilter',
'MultiValueContentTypeFilter',
'MultiValueDateFilter',
'MultiValueDateTimeFilter',
'MultiValueDecimalFilter',
@@ -171,3 +173,27 @@ class ContentTypeFilter(django_filters.CharFilter):
f'{self.field_name}__model': model
}
)
class MultiValueContentTypeFilter(MultiValueCharFilter):
"""
A multi-value version of ContentTypeFilter.
"""
def filter(self, qs, value):
if value in EMPTY_VALUES:
return qs
content_types = []
for key in value:
try:
app_label, model = key.lower().split('.')
ct = ContentType.objects.get_by_natural_key(app_label, model)
content_types.append(ct)
except (ValueError, ContentType.DoesNotExist):
continue
return qs.filter(
**{
f'{self.field_name}__in': content_types,
}
)

View File

@@ -10,7 +10,7 @@ from mptt.models import MPTTModel
from taggit.managers import TaggableManager
from extras.filters import TagFilter
from utilities.filters import ContentTypeFilter, TreeNodeMultipleChoiceFilter
from utilities.filters import MultiValueContentTypeFilter, TreeNodeMultipleChoiceFilter
__all__ = (
'BaseFilterSetTests',
@@ -75,7 +75,7 @@ class BaseFilterSetTests:
# Standardize on object_type for filter name even though it's technically a ContentType
filter_name = 'object_type'
return [
(filter_name, ContentTypeFilter),
(filter_name, MultiValueContentTypeFilter),
(f'{filter_name}_id', django_filters.ModelMultipleChoiceFilter),
]

View File

@@ -7,7 +7,7 @@ from dcim.models import Device, Interface
from ipam.models import IPAddress, RouteTarget, VLAN
from netbox.filtersets import NetBoxModelFilterSet, OrganizationalModelFilterSet, PrimaryModelFilterSet
from tenancy.filtersets import ContactModelFilterSet, TenancyFilterSet
from utilities.filters import ContentTypeFilter, MultiValueCharFilter, MultiValueNumberFilter
from utilities.filters import MultiValueCharFilter, MultiValueContentTypeFilter, MultiValueNumberFilter
from utilities.filtersets import register_filterset
from virtualization.models import VirtualMachine, VMInterface
from .choices import *
@@ -94,7 +94,7 @@ class TunnelTerminationFilterSet(NetBoxModelFilterSet):
role = django_filters.MultipleChoiceFilter(
choices=TunnelTerminationRoleChoices
)
termination_type = ContentTypeFilter()
termination_type = MultiValueContentTypeFilter()
interface = django_filters.ModelMultipleChoiceFilter(
field_name='interface__name',
queryset=Interface.objects.all(),
@@ -445,7 +445,7 @@ class L2VPNTerminationFilterSet(NetBoxModelFilterSet):
queryset=ObjectType.objects.all(),
field_name='assigned_object_type'
)
assigned_object_type = ContentTypeFilter()
assigned_object_type = MultiValueContentTypeFilter()
class Meta:
model = L2VPNTermination

View File

@@ -268,9 +268,9 @@ class TunnelTerminationTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 4)
def test_termination_type(self):
params = {'termination_type': 'dcim.interface'}
params = {'termination_type': ['dcim.interface']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
params = {'termination_type': 'virtualization.vminterface'}
params = {'termination_type': ['virtualization.vminterface']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_interface(self):
@@ -902,7 +902,7 @@ class L2VPNTerminationTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 6)
def test_termination_type(self):
params = {'assigned_object_type': 'ipam.vlan'}
params = {'assigned_object_type': ['ipam.vlan']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 3)
def test_interface(self):

View File

@@ -305,7 +305,7 @@ class WirelessLANTestCase(TestCase, ChangeLoggedFilterSetTests):
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 1)
def test_scope_type(self):
params = {'scope_type': 'dcim.location'}
params = {'scope_type': ['dcim.location']}
self.assertEqual(self.filterset(params, self.queryset).qs.count(), 2)