Browse Source

Use netaddr module to handle MAC addresses.

This solves the issue where the default lowercase MAC addresses returned from
Postgres were considered invalid in forms, even though existing values
were represented in lowercase. This also allows for common alternate formats
such as EUI64 (Windows-style) and the triple-dotted format used by Cisco.

The netaddr module is used since IPy has no equivalent functionality. However,
netaddr does provide the functionality IPy is used for...

A number of tests were added for MAC address fields and for both MAC and
IP forms. An issue where to_python() was not properly raising
ValidationError when checking for form validity was also fixed as a result.
James Oakley 12 years ago
parent
commit
3718b454e6
7 changed files with 210 additions and 51 deletions
  1. 14 0
      netfields/fields.py
  2. 22 7
      netfields/forms.py
  3. 7 0
      netfields/mac.py
  4. 1 1
      netfields/models.py
  5. 161 43
      netfields/tests.py
  6. 1 0
      requirements.txt
  7. 4 0
      tox.ini

+ 14 - 0
netfields/fields.py

@@ -1,9 +1,11 @@
 from IPy import IP
+from netaddr import EUI
 
 from django.db import models
 
 from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS
 from netfields.forms import NetAddressFormField, MACAddressFormField
+from netfields.mac import mac_unix_common
 
 
 class _NetAddressField(models.Field):
@@ -82,6 +84,18 @@ class MACAddressField(models.Field):
     def db_type(self, connection):
         return 'macaddr'
 
+    def to_python(self, value):
+        if not value:
+            return value
+
+        return EUI(value, dialect=mac_unix_common)
+
+    def get_prep_value(self, value):
+        if not value:
+            return None
+
+        return unicode(self.to_python(value))
+
     def formfield(self, **kwargs):
         defaults = {'form_class': MACAddressFormField}
         defaults.update(kwargs)

+ 22 - 7
netfields/forms.py

@@ -1,9 +1,12 @@
-import re
 from IPy import IP
+from netaddr import EUI, AddrFormatError
 
 from django import forms
 from django.utils.encoding import force_unicode
 from django.utils.safestring import mark_safe
+from django.core.exceptions import ValidationError
+
+from netfields.mac import mac_unix_common
 
 
 class NetInput(forms.Widget):
@@ -35,16 +38,28 @@ class NetAddressFormField(forms.Field):
         if isinstance(value, IP):
             return value
 
-        return IP(value)
-
-
-MAC_RE = re.compile(r'^(([A-F0-9]{2}:){5}[A-F0-9]{2})$')
+        try:
+            return IP(value)
+        except ValueError, e:
+            raise ValidationError(str(e))
 
 
-class MACAddressFormField(forms.RegexField):
+class MACAddressFormField(forms.Field):
     default_error_messages = {
         'invalid': u'Enter a valid MAC address.',
     }
 
     def __init__(self, *args, **kwargs):
-        super(MACAddressFormField, self).__init__(MAC_RE, *args, **kwargs)
+        super(MACAddressFormField, self).__init__(*args, **kwargs)
+
+    def to_python(self, value):
+        if not value:
+            return None
+
+        if isinstance(value, EUI):
+            return value
+
+        try:
+            return EUI(value, dialect=mac_unix_common)
+        except AddrFormatError:
+            raise ValidationError(self.error_messages['invalid'])

+ 7 - 0
netfields/mac.py

@@ -0,0 +1,7 @@
+import netaddr
+
+
+class mac_unix_common(netaddr.mac_eui48):
+    """Common form of UNIX MAC address dialect class"""
+    word_sep  = ':'
+    word_fmt  = '%.2x'

+ 1 - 1
netfields/models.py

@@ -37,7 +37,7 @@ class NullCidrTestModel(Model):
 
 
 class MACTestModel(Model):
-    mac = MACAddressField(null=True)
+    field = MACAddressField(null=True)
     objects = NetManager()
 
     class Meta:

+ 161 - 43
netfields/tests.py

@@ -1,14 +1,16 @@
 from IPy import IP
+from netaddr import EUI
 
 from django.db import IntegrityError
 from django.forms import ModelForm
 from django.test import TestCase
 
 from netfields.models import (CidrTestModel, InetTestModel, NullCidrTestModel,
-                              NullInetTestModel)
+                              NullInetTestModel, MACTestModel)
+from netfields.mac import mac_unix_common
 
 
-class BaseTestCase(object):
+class BaseSqlTestCase(object):
     select = 'SELECT "table"."id", "table"."field" FROM "table" '
 
     def assertSqlEquals(self, qs, sql):
@@ -21,55 +23,69 @@ class BaseTestCase(object):
     def test_init_with_blank(self):
         self.model()
 
-    def test_init_with_text_fails(self):
-        self.assertRaises(ValueError, self.model, field='abc')
+    def test_isnull_true_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__isnull=True),
+            self.select + 'WHERE "table"."field" IS NULL')
 
-    def test_save(self):
-        self.model(field='10.0.0.1').save()
+    def test_isnull_false_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__isnull=False),
+            self.select + 'WHERE "table"."field" IS NOT NULL')
 
-    def test_save_object(self):
-        self.model(field=IP('10.0.0.1')).save()
+    def test_save(self):
+        self.model(field=self.value1).save()
 
     def test_equals_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field=self.value1),
             self.select + 'WHERE "table"."field" = %s ')
 
     def test_exact_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__exact='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__exact=self.value1),
             self.select + 'WHERE "table"."field" = %s ')
 
-    def test_iexact_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__iexact='10.0.0.1'),
-            self.select + 'WHERE "table"."field" = %s ')
-
-    def test_net_contains_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'),
-            self.select + 'WHERE "table"."field" >> %s ')
-
     def test_in_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__in=['10.0.0.1', '10.0.0.2']),
+        self.assertSqlEquals(self.qs.filter(field__in=[self.value1, self.value2]),
             self.select + 'WHERE "table"."field" IN (%s, %s)')
 
     def test_gt_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__gt='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__gt=self.value1),
             self.select + 'WHERE "table"."field" > %s ')
 
     def test_gte_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__gte='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__gte=self.value1),
             self.select + 'WHERE "table"."field" >= %s ')
 
     def test_lt_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__lt='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__lt=self.value1),
             self.select + 'WHERE "table"."field" < %s ')
 
     def test_lte_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__lte='10.0.0.1'),
+        self.assertSqlEquals(self.qs.filter(field__lte=self.value1),
             self.select + 'WHERE "table"."field" <= %s ')
 
     def test_range_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__range=('10.0.0.1', '10.0.0.10')),
+        self.assertSqlEquals(self.qs.filter(field__range=(self.value1, self.value3)),
             self.select + 'WHERE "table"."field" BETWEEN %s and %s')
 
+
+
+class BaseInetTestCase(BaseSqlTestCase):
+    value1 = '10.0.0.1'
+    value2 = '10.0.0.2'
+    value3 = '10.0.0.10'
+
+    def test_save_object(self):
+        self.model(field=IP(self.value1)).save()
+
+    def test_init_with_text_fails(self):
+        self.assertRaises(ValueError, self.model, field='abc')
+
+    def test_iexact_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
+            self.select + 'WHERE "table"."field" = %s ')
+
+    def test_search_lookup_fails(self):
+        self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError)
+
     def test_year_lookup_fails(self):
         self.assertSqlRaises(self.qs.filter(field__year=1), ValueError)
 
@@ -79,16 +95,9 @@ class BaseTestCase(object):
     def test_day_lookup_fails(self):
         self.assertSqlRaises(self.qs.filter(field__day=1), ValueError)
 
-    def test_isnull_true_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__isnull=True),
-            self.select + 'WHERE "table"."field" IS NULL')
-
-    def test_isnull_false_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__isnull=False),
-            self.select + 'WHERE "table"."field" IS NOT NULL')
-
-    def test_search_lookup_fails(self):
-        self.assertSqlRaises(self.qs.filter(field__search='10'), ValueError)
+    def test_net_contains_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__net_contains='10.0.0.1'),
+            self.select + 'WHERE "table"."field" >> %s ')
 
     def test_net_contains_or_equals(self):
         self.assertSqlEquals(self.qs.filter(field__net_contains_or_equals='10.0.0.1'),
@@ -103,7 +112,7 @@ class BaseTestCase(object):
             self.select + 'WHERE "table"."field" <<= %s ')
 
 
-class BaseInetFieldTestCase(BaseTestCase):
+class BaseInetFieldTestCase(BaseInetTestCase):
     def test_startswith_lookup(self):
         self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
             self.select + 'WHERE HOST("table"."field") ILIKE %s ')
@@ -129,7 +138,7 @@ class BaseInetFieldTestCase(BaseTestCase):
             self.select + 'WHERE HOST("table"."field") ~* %s ')
 
 
-class BaseCidrFieldTestCase(BaseTestCase):
+class BaseCidrFieldTestCase(BaseInetTestCase):
     def test_startswith_lookup(self):
         self.assertSqlEquals(self.qs.filter(field__startswith='10.'),
             self.select + 'WHERE TEXT("table"."field") ILIKE %s ')
@@ -225,13 +234,122 @@ class InetTestModelForm(ModelForm):
 
 
 class TestNetAddressFormField(TestCase):
+    def test_form_ipv4_valid(self):
+        form = InetTestModelForm({'field': '10.0.0.1'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], IP('10.0.0.1'))
+
+    def test_form_ipv4_invalid(self):
+        form = InetTestModelForm({'field': '10.0.0.1.2'})
+        self.assertFalse(form.is_valid())
+
+    def test_form_ipv6(self):
+        form = InetTestModelForm({'field': '2001:0:1::2'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], IP('2001:0:1::2'))
+
+    def test_form_ipv6_invalid(self):
+        form = InetTestModelForm({'field': '2001:0::1::2'})
+        self.assertFalse(form.is_valid())
+
+
+class BaseMacTestCase(BaseSqlTestCase):
+    value1 = '00:aa:2b:c3:dd:44'
+    value2 = '00:aa:2b:c3:dd:45'
+    value3 = '00:aa:2b:c3:dd:ff'
+
+    def test_save_object(self):
+        self.model(field=EUI(self.value1)).save()
+
+    def test_iexact_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iexact=self.value1),
+            self.select + 'WHERE UPPER("table"."field"::text) = UPPER(%s) ')
+
+    def test_startswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__startswith='00:'),
+            self.select + 'WHERE "table"."field"::text LIKE %s ')
+
+    def test_istartswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__istartswith='00:'),
+            self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
+
+    def test_endswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__endswith=':ff'),
+            self.select + 'WHERE "table"."field"::text LIKE %s ')
+
+    def test_iendswith_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iendswith=':ff'),
+            self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
+
+    def test_regex_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__regex='00'),
+            self.select + 'WHERE "table"."field" ~ %s ')
+
+    def test_iregex_lookup(self):
+        self.assertSqlEquals(self.qs.filter(field__iregex='00'),
+            self.select + 'WHERE "table"."field" ~* %s ')
+
+
+class TestMacAddressField(BaseMacTestCase, TestCase):
     def setUp(self):
-        self.data = {
-            'field': '10.0.0.1',
-        }
-        self.addr = IP('10.0.0.1')
+        self.model = MACTestModel
+        self.qs = self.model.objects.all()
+        self.table = 'mac'
+
+    def test_save_blank(self):
+        self.model().save()
 
-    def test_form(self):
-        form = InetTestModelForm(self.data)
+    def test_save_none(self):
+        self.model(field=None).save()
+
+    def test_save_nothing_fails(self):
+        self.model().save()
+
+
+class MacAddressTestModelForm(ModelForm):
+    class Meta:
+        model = MACTestModel
+
+
+class TestMacAddressFormField(TestCase):
+    def setUp(self):
+        self.mac = EUI('00:aa:2b:c3:dd:44', dialect=mac_unix_common)
+
+    def test_unix(self):
+        form = MacAddressTestModelForm({'field': '0:AA:2b:c3:dd:44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_unix_common(self):
+        form = MacAddressTestModelForm({'field': '00:aa:2b:c3:dd:44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_eui48(self):
+        form = MacAddressTestModelForm({'field': '00-AA-2B-C3-DD-44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_cisco(self):
+        form = MacAddressTestModelForm({'field': '00aa.2bc3.dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_24bit_colon(self):
+        form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_24bit_hyphen(self):
+        form = MacAddressTestModelForm({'field': '00aa2b-c3dd44'})
         self.assertTrue(form.is_valid())
-        self.assertEqual(form.cleaned_data['field'], self.addr)
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_bare(self):
+        form = MacAddressTestModelForm({'field': '00aa2b:c3dd44'})
+        self.assertTrue(form.is_valid())
+        self.assertEqual(form.cleaned_data['field'], self.mac)
+
+    def test_invalid(self):
+        form = MacAddressTestModelForm({'field': 'notvalid'})
+        self.assertFalse(form.is_valid())

+ 1 - 0
requirements.txt

@@ -1,3 +1,4 @@
 IPy
 django>=1.3
+netaddr
 psycopg2

+ 4 - 0
tox.ini

@@ -16,6 +16,7 @@ basepython=python2.6
 deps=
     IPy
     django==1.3
+    netaddr
     psycopg2==2.4.1
 
 [testenv:py27-django13]
@@ -23,6 +24,7 @@ basepython=python2.7
 deps=
     IPy
     django==1.3
+    netaddr
     psycopg2==2.4.1
 
 [testenv:py26-django14]
@@ -30,6 +32,7 @@ basepython=python2.6
 deps=
     IPy
     django==1.4
+    netaddr
     psycopg2
 
 [testenv:py27-django14]
@@ -37,4 +40,5 @@ basepython=python2.7
 deps=
     IPy
     django==1.4
+    netaddr
     psycopg2