Parcourir la source

Fix an issue with queries of CidrAddressField where the ORM assumed that
IPNetwork objects are really lists since they have an __iter__ method.
This problem also manifests when saving a CidrAddressField with a unique
constraint from a ModelForm.

James Oakley il y a 12 ans
Parent
commit
aca50f8432
3 fichiers modifiés avec 110 ajouts et 10 suppressions
  1. 39 0
      netfields/managers.py
  2. 16 0
      netfields/models.py
  3. 55 10
      netfields/tests.py

+ 39 - 0
netfields/managers.py

@@ -1,7 +1,12 @@
+from netaddr import IPNetwork
+
 from django.db import models, connection
 from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
 from django.db.models import sql, query
 from django.db.models.query_utils import QueryWrapper
+from django.utils import tree
+
+import datetime
 
 NET_OPERATORS = DatabaseWrapper.operators.copy()
 
@@ -25,6 +30,40 @@ class NetQuery(sql.Query):
 
 
 class NetWhere(sql.where.WhereNode):
+    def add(self, data, connector):
+        """
+        Special form of WhereNode.add() that does not automatically consume the
+        __iter__ method of IPNetwork objects.
+        """
+        if not isinstance(data, (list, tuple)):
+            # Need to bypass WhereNode
+            tree.Node.add(self, data, connector)
+            return
+
+        obj, lookup_type, value = data
+        if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
+            # Consume any generators immediately, so that we can determine
+            # emptiness and transform any non-empty values correctly.
+            value = list(value)
+
+        # The "value_annotation" parameter is used to pass auxilliary information
+        # about the value(s) to the query construction. Specifically, datetime
+        # and empty values need special handling. Other types could be used
+        # here in the future (using Python types is suggested for consistency).
+        if isinstance(value, datetime.datetime):
+            value_annotation = datetime.datetime
+        elif hasattr(value, 'value_annotation'):
+            value_annotation = value.value_annotation
+        else:
+            value_annotation = bool(value)
+
+        if hasattr(obj, "prepare"):
+            value = obj.prepare(lookup_type, value)
+
+        # Need to bypass WhereNode
+        tree.Node.add(self,
+            (obj, lookup_type, value_annotation, value), connector)
+
     def make_atom(self, child, qn, conn):
         lvalue, lookup_type, value_annot, params_or_value = child
 

+ 16 - 0
netfields/models.py

@@ -20,6 +20,14 @@ class NullInetTestModel(Model):
         db_table = 'nullinet'
 
 
+class UniqueInetTestModel(Model):
+    field = InetAddressField(unique=True)
+    objects = NetManager()
+
+    class Meta:
+        db_table = 'uniqueinet'
+
+
 class CidrTestModel(Model):
     field = CidrAddressField()
     objects = NetManager()
@@ -36,6 +44,14 @@ class NullCidrTestModel(Model):
         db_table = 'nullcidr'
 
 
+class UniqueCidrTestModel(Model):
+    field = CidrAddressField(unique=True)
+    objects = NetManager()
+
+    class Meta:
+        db_table = 'uniquecidr'
+
+
 class MACTestModel(Model):
     field = MACAddressField(null=True)
     objects = NetManager()

+ 55 - 10
netfields/tests.py

@@ -5,7 +5,8 @@ from django.forms import ModelForm
 from django.test import TestCase
 
 from netfields.models import (CidrTestModel, InetTestModel, NullCidrTestModel,
-                              NullInetTestModel, MACTestModel)
+                              NullInetTestModel, UniqueInetTestModel,
+                              UniqueCidrTestModel, MACTestModel)
 from netfields.mac import mac_unix_common
 
 
@@ -200,6 +201,17 @@ class TestInetFieldNullable(BaseInetFieldTestCase, TestCase):
         self.model().save()
 
 
+class TestInetFieldUnique(BaseInetFieldTestCase, TestCase):
+    def setUp(self):
+        self.model = UniqueInetTestModel
+        self.qs = self.model.objects.all()
+        self.table = 'uniqueinet'
+
+    def test_save_nonunique(self):
+        self.model(field='1.2.3.4').save()
+        self.assertRaises(IntegrityError, self.model(field='1.2.3.4').save)
+
+
 class TestCidrField(BaseCidrFieldTestCase, TestCase):
     def setUp(self):
         self.model = CidrTestModel
@@ -232,56 +244,89 @@ class TestCidrFieldNullable(BaseCidrFieldTestCase, TestCase):
         self.model().save()
 
 
+class TestCidrFieldUnique(BaseCidrFieldTestCase, TestCase):
+    def setUp(self):
+        self.model = UniqueCidrTestModel
+        self.qs = self.model.objects.all()
+        self.table = 'uniquecidr'
+
+    def test_save_nonunique(self):
+        self.model(field='1.2.3.0/24').save()
+        self.assertRaises(IntegrityError, self.model(field='1.2.3.0/24').save)
+
+
 class InetAddressTestModelForm(ModelForm):
     class Meta:
         model = InetTestModel
 
 
 class TestInetAddressFormField(TestCase):
+    form_class = InetAddressTestModelForm
+
     def test_form_ipv4_valid(self):
-        form = InetAddressTestModelForm({'field': '10.0.0.1'})
+        form = self.form_class({'field': '10.0.0.1'})
         self.assertTrue(form.is_valid())
         self.assertEqual(form.cleaned_data['field'], IPAddress('10.0.0.1'))
 
     def test_form_ipv4_invalid(self):
-        form = InetAddressTestModelForm({'field': '10.0.0.1.2'})
+        form = self.form_class({'field': '10.0.0.1.2'})
         self.assertFalse(form.is_valid())
 
     def test_form_ipv6(self):
-        form = InetAddressTestModelForm({'field': '2001:0:1::2'})
+        form = self.form_class({'field': '2001:0:1::2'})
         self.assertTrue(form.is_valid())
         self.assertEqual(form.cleaned_data['field'], IPAddress('2001:0:1::2'))
 
     def test_form_ipv6_invalid(self):
-        form = InetAddressTestModelForm({'field': '2001:0::1::2'})
+        form = self.form_class({'field': '2001:0::1::2'})
         self.assertFalse(form.is_valid())
 
 
+class UniqueInetAddressTestModelForm(ModelForm):
+    class Meta:
+        model = UniqueInetTestModel
+
+
+class TestUniqueInetAddressFormField(TestInetAddressFormField):
+    form_class = UniqueInetAddressTestModelForm
+
+
 class CidrAddressTestModelForm(ModelForm):
     class Meta:
         model = CidrTestModel
 
 
-class TestNetAddressFormField(TestCase):
+class TestCidrAddressFormField(TestCase):
+    form_class = CidrAddressTestModelForm
+
     def test_form_ipv4_valid(self):
-        form = CidrAddressTestModelForm({'field': '10.0.0.1/24'})
+        form = self.form_class({'field': '10.0.0.1/24'})
         self.assertTrue(form.is_valid())
         self.assertEqual(form.cleaned_data['field'], IPNetwork('10.0.0.1/24'))
 
     def test_form_ipv4_invalid(self):
-        form = CidrAddressTestModelForm({'field': '10.0.0.1.2/32'})
+        form = self.form_class({'field': '10.0.0.1.2/32'})
         self.assertFalse(form.is_valid())
 
     def test_form_ipv6(self):
-        form = CidrAddressTestModelForm({'field': '2001:0:1::2/64'})
+        form = self.form_class({'field': '2001:0:1::2/64'})
         self.assertTrue(form.is_valid())
         self.assertEqual(form.cleaned_data['field'], IPNetwork('2001:0:1::2/64'))
 
     def test_form_ipv6_invalid(self):
-        form = CidrAddressTestModelForm({'field': '2001:0::1::2/128'})
+        form = self.form_class({'field': '2001:0::1::2/128'})
         self.assertFalse(form.is_valid())
 
 
+class UniqueCidrAddressTestModelForm(ModelForm):
+    class Meta:
+        model = UniqueCidrTestModel
+
+
+class TestUniqueCidrAddressFormField(TestCidrAddressFormField):
+    form_class = UniqueCidrAddressTestModelForm
+
+
 class BaseMacTestCase(BaseSqlTestCase):
     value1 = '00:aa:2b:c3:dd:44'
     value2 = '00:aa:2b:c3:dd:45'