Browse Source

Merge pull request #13 from jimfunk/master

Fix to_python() in form fields and add South introspection rules
Thomas Adamcik 12 years ago
parent
commit
61dda6ea71
8 changed files with 233 additions and 44 deletions
  1. 24 0
      netfields/fields.py
  2. 22 7
      netfields/forms.py
  3. 7 0
      netfields/mac.py
  4. 1 1
      netfields/models.py
  5. 173 36
      netfields/tests.py
  6. 1 0
      requirements.txt
  7. 1 0
      setup.py
  8. 4 0
      tox.ini

+ 24 - 0
netfields/fields.py

@@ -1,9 +1,11 @@
 from IPy import IP
+from netaddr import EUI
 
 from django.db import models
 
 from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS
 from netfields.forms import NetAddressFormField, MACAddressFormField
+from netfields.mac import mac_unix_common
 
 
 class _NetAddressField(models.Field):
@@ -82,7 +84,29 @@ class MACAddressField(models.Field):
     def db_type(self, connection):
         return 'macaddr'
 
+    def to_python(self, value):
+        if not value:
+            return value
+
+        return EUI(value, dialect=mac_unix_common)
+
+    def get_prep_value(self, value):
+        if not value:
+            return None
+
+        return unicode(self.to_python(value))
+
     def formfield(self, **kwargs):
         defaults = {'form_class': MACAddressFormField}
         defaults.update(kwargs)
         return super(MACAddressField, self).formfield(**defaults)
+
+try:
+    from south.modelsinspector import add_introspection_rules
+    add_introspection_rules([], [
+        "^netfields\.fields\.InetAddressField",
+        "^netfields\.fields\.CidrAddressField",
+        "^netfields\.fields\.MACAddressField",
+    ])
+except ImportError:
+    pass

+ 22 - 7
netfields/forms.py

@@ -1,9 +1,12 @@
-import re
 from IPy import IP
+from netaddr import EUI, AddrFormatError
 
 from django import forms
 from django.utils.encoding import force_unicode
 from django.utils.safestring import mark_safe
+from django.core.exceptions import ValidationError
+
+from netfields.mac import mac_unix_common
 
 
 class NetInput(forms.Widget):
@@ -35,16 +38,28 @@ class NetAddressFormField(forms.Field):
         if isinstance(value, IP):
             return value
 
-        return self.python_type(value)
-
-
-MAC_RE = re.compile(r'^(([A-F0-9]{2}:){5}[A-F0-9]{2})$')
+        try:
+            return IP(value)
+        except ValueError, e:
+            raise ValidationError(str(e))
 
 
-class MACAddressFormField(forms.RegexField):
+class MACAddressFormField(forms.Field):
     default_error_messages = {
         'invalid': u'Enter a valid MAC address.',
     }
 
     def __init__(self, *args, **kwargs):
-        super(MACAddressFormField, self).__init__(MAC_RE, *args, **kwargs)
+        super(MACAddressFormField, self).__init__(*args, **kwargs)
+
+    def to_python(self, value):
+        if not value:
+            return None
+
+        if isinstance(value, EUI):
+            return value
+
+        try:
+            return EUI(value, dialect=mac_unix_common)
+        except AddrFormatError:
+            raise ValidationError(self.error_messages['invalid'])

+ 7 - 0
netfields/mac.py

@@ -0,0 +1,7 @@
+import netaddr
+
+
+class mac_unix_common(netaddr.mac_eui48):
+    """Common form of UNIX MAC address dialect class"""
+    word_sep  = ':'
+    word_fmt  = '%.2x'

+ 1 - 1
netfields/models.py

@@ -37,7 +37,7 @@ class NullCidrTestModel(Model):
 
 
 class MACTestModel(Model):
-    mac = MACAddressField(null=True)
+    field = MACAddressField(null=True)
     objects = NetManager()
 
     class Meta:

+ 173 - 36
netfields/tests.py

@@ -1,13 +1,16 @@
 from IPy import IP
+from netaddr import EUI
 
 from django.db import IntegrityError
+from django.forms import ModelForm
 from django.test import TestCase
 
 from netfields.models import (CidrTestModel, InetTestModel, NullCidrTestModel,
-                              NullInetTestModel)
+                              NullInetTestModel, MACTestModel)
+from netfields.mac import mac_unix_common
 
 
-class BaseTestCase(object):
+class BaseSqlTestCase(object):
     select = 'SELECT "table"."id", "table"."field" FROM "table" '
 
     def assertSqlEquals(self, qs, sql):
@@ -20,55 +23,69 @@ class BaseTestCase(object):
     def test_init_with_blank(self):
         self.model()
 
-    def test_init_with_text_fails(self):
-        self.assertRaises(ValueError, self.model, field='abc')
+    def test_isnull_true_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__isnull=True),
+            self.select + 'WHERE "table"."field" IS NULL')
 
-    def test_save(self):
-        self.model(field='10.0.0.1').save()
+    def test_isnull_false_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__isnull=False),
+            self.select + 'WHERE "table"."field" IS NOT NULL')
 
-    def test_save_object(self):
-        self.model(field=IP('10.0.0.1')).save()
+    def test_save(self):
+        self.model(field=self.value1).save()
 
     def test_equals_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field=self.value1),
             self.select + 'WHERE "table"."field" = %s ')
 
     def test_exact_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__exact='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__exact=self.value1),
             self.select + 'WHERE "table"."field" = %s ')
 
-    def test_iexact_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__iexact='10.0.0.1'),
-            self.select + 'WHERE "table"."field" = %s ')
-
-    def test_net_contains_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'),
-            self.select + 'WHERE "table"."field" >> %s ')
-
     def test_in_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__in=['10.0.0.1', '10.0.0.2']),
+        self.assertSqlEquals(self.qs.filter(field__in=[self.value1, self.value2]),
             self.select + 'WHERE "table"."field" IN (%s, %s)')
 
     def test_gt_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__gt='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__gt=self.value1),
             self.select + 'WHERE "table"."field" > %s ')
 
     def test_gte_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__gte='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__gte=self.value1),
             self.select + 'WHERE "table"."field" >= %s ')
 
     def test_lt_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__lt='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__lt=self.value1),
             self.select + 'WHERE "table"."field" < %s ')
 
     def test_lte_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__lte='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__lte=self.value1),
             self.select + 'WHERE "table"."field" <= %s ')
 
     def test_range_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__range=('10.0.0.1', '10.0.0.10')),
+        self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)),
             self.select + 'WHERE "table"."field" BETWEEN %s and %s')
 
+
+
+class BaseInetTestCase(BaseSqlTestCase):
+    value1 = '10.0.0.1'
+    value2 = '10.0.0.2'
+    value3 = '10.0.0.10'
+
+    def test_save_object(self):
+        self.model(field=IP(self.value1)).save()
+
+    def test_init_with_text_fails(self):
+        self.assertRaises(ValueError, self.model, field='abc')
+
+    def test_iexact_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
+            self.select + 'WHERE "table"."field" = %s ')
+
+    def test_search_lookup_fails(self):
+        self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError)
+
     def test_year_lookup_fails(self):
         self.assertSqlRaises(self.qs.filter(field__year=1), ValueError)
 
@@ -78,16 +95,9 @@ class BaseTestCase(object):
     def test_day_lookup_fails(self):
         self.assertSqlRaises(self.qs.filter(field__day=1), ValueError)
 
-    def test_isnull_true_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__isnull=True),
-            self.select + 'WHERE "table"."field" IS NULL')
-
-    def test_isnull_false_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__isnull=False),
-            self.select + 'WHERE "table"."field" IS NOT NULL')
-
-    def test_search_lookup_fails(self):
-        self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError)
+    def test_net_contains_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'),
+            self.select + 'WHERE "table"."field" >> %s ')
 
     def test_net_contains_or_equals(self):
         self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'),
@@ -102,7 +112,7 @@ class BaseTestCase(object):
             self.select + 'WHERE "table"."field" <<= %s ')
 
 
-class BaseInetFieldTestCase(BaseTestCase):
+class BaseInetFieldTestCase(BaseInetTestCase):
     def test_startswith_lookup(self):
         self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
             self.select + 'WHERE HOST("table"."field") ILIKE %s ')
@@ -128,7 +138,7 @@ class BaseInetFieldTestCase(BaseTestCase):
             self.select + 'WHERE HOST("table"."field") ~* %s ')
 
 
-class BaseCidrFieldTestCase(BaseTestCase):
+class BaseCidrFieldTestCase(BaseInetTestCase):
     def test_startswith_lookup(self):
         self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
             self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
@@ -216,3 +226,130 @@ class TestCidrFieldNullable(BaseCidrFieldTestCase, TestCase):
 
     def test_save_nothing_fails(self):
         self.model().save()
+
+
+class InetTestModelForm(ModelForm):
+    class Meta:
+        model = InetTestModel
+
+
+class TestNetAddressFormField(TestCase):
+    def test_form_ipv4_valid(self):
+        form = InetTestModelForm({'field': '10.0.0.1'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], IP('10.0.0.1'))
+
+    def test_form_ipv4_invalid(self):
+        form = InetTestModelForm({'field': '10.0.0.1.2'})
+        self.assertFalse(form.is_valid())
+
+    def test_form_ipv6(self):
+        form = InetTestModelForm({'field': '2001:0:1::2'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], IP('2001:0:1::2'))
+
+    def test_form_ipv6_invalid(self):
+        form = InetTestModelForm({'field': '2001:0::1::2'})
+        self.assertFalse(form.is_valid())
+
+
+class BaseMacTestCase(BaseSqlTestCase):
+    value1 = '00:aa:2b:c3:dd:44'
+    value2 = '00:aa:2b:c3:dd:45'
+    value3 = '00:aa:2b:c3:dd:ff'
+
+    def test_save_object(self):
+        self.model(field=EUI(self.value1)).save()
+
+    def test_iexact_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
+            self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s) ')
+
+    def test_startswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__startswith='00:'),
+            self.select + 'WHERE "table"."field"::text LIKE %s ')
+
+    def test_istartswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__istartswith='00:'),
+            self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
+
+    def test_endswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__endswith=':ff'),
+            self.select + 'WHERE "table"."field"::text LIKE %s ')
+
+    def test_iendswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'),
+            self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
+
+    def test_regex_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__regex='00'),
+            self.select + 'WHERE "table"."field" ~ %s ')
+
+    def test_iregex_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iregex='00'),
+            self.select + 'WHERE "table"."field" ~* %s ')
+
+
+class TestMacAddressField(BaseMacTestCase, TestCase):
+    def setUp(self):
+        self.model = MACTestModel
+        self.qs = self.model.objects.all()
+        self.table = 'mac'
+
+    def test_save_blank(self):
+        self.model().save()
+
+    def test_save_none(self):
+        self.model(field=None).save()
+
+    def test_save_nothing_fails(self):
+        self.model().save()
+
+
+class MacAddressTestModelForm(ModelForm):
+    class Meta:
+        model = MACTestModel
+
+
+class TestMacAddressFormField(TestCase):
+    def setUp(self):
+        self.mac = EUI('00:aa:2b:c3:dd:44', dialect=mac_unix_common)
+
+    def test_unix(self):
+        form = MacAddressTestModelForm({'field': '0:AA:2b:c3:dd:44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_unix_common(self):
+        form = MacAddressTestModelForm({'field': '00:aa:2b:c3:dd:44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_eui48(self):
+        form = MacAddressTestModelForm({'field': '00-AA-2B-C3-DD-44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_cisco(self):
+        form = MacAddressTestModelForm({'field': '00aa.2bc3.dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_24bit_colon(self):
+        form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_24bit_hyphen(self):
+        form = MacAddressTestModelForm({'field': '00aa2b-c3dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_bare(self):
+        form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_invalid(self):
+        form = MacAddressTestModelForm({'field': 'notvalid'})
+        self.assertFalse(form.is_valid())

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
 IPy
 django>=1.3
+netaddr
 psycopg2

+ 1 - 0
setup.py

@@ -26,6 +26,7 @@ setup(
     zip_safe=False,
     install_requires=[
         'IPy',
+        'netaddr',
         'django>=1.3',
     ],
 

+ 4 - 0
tox.ini

@@ -16,6 +16,7 @@ basepython=python2.6
 deps=
     IPy
     django==1.3
+    netaddr
     psycopg2==2.4.1
 
 [testenv:py27-django13]
@@ -23,6 +24,7 @@ basepython=python2.7
 deps=
     IPy
     django==1.3
+    netaddr
     psycopg2==2.4.1
 
 [testenv:py26-django14]
@@ -30,6 +32,7 @@ basepython=python2.6
 deps=
     IPy
     django==1.4
+    netaddr
     psycopg2
 
 [testenv:py27-django14]
@@ -37,4 +40,5 @@ basepython=python2.7
 deps=
     IPy
     django==1.4
+    netaddr
     psycopg2