Parcourir la source

Fix form handling

Thomas Adamcik il y a 15 ans
Parent
commit
636b8ceb60
1 fichiers modifiés avec 31 ajouts et 6 suppressions
  1. 31 6
      manager.py

+ 31 - 6
manager.py

@@ -1,10 +1,12 @@
 import re
 from IPy import IP
 
+from django import forms
 from django.db import models, connection
 from django.db.models import sql, query
 from django.db.models.query_utils import QueryWrapper
-from django import forms
+from django.utils.encoding import force_unicode
+from django.utils.safestring import mark_safe
 
 NET_OPERATORS = {
     'lt': '<',
@@ -40,6 +42,7 @@ class NetQuery(sql.Query):
         return super(NetQuery, self).add_filter(
             (filter_string, value), *args, **kwargs)
 
+
 class NetWhere(sql.where.WhereNode):
     def make_atom(self, child, qn):
         table_alias, name, db_type, lookup_type, value_annot, params = child
@@ -78,6 +81,7 @@ class NetWhere(sql.where.WhereNode):
 
         raise ValueError('Invalid lookup type "%s"' % lookup_type)
 
+
 class NetManger(models.Manager):
     use_for_related_fields = True
 
@@ -85,13 +89,27 @@ class NetManger(models.Manager):
         q = NetQuery(self.model, connection, NetWhere)
         return query.QuerySet(self.model, q)
 
+
+class NetInput(forms.Widget):
+    input_type = 'text'
+
+    def render(self, name, value, attrs=None):
+        # Default forms.Widget compares value != '' which breaks IP...
+        if value is None: value = ''
+        final_attrs = self.build_attrs(attrs, type=self.input_type, name=name)
+        if value:
+            final_attrs['value'] = force_unicode(value)
+        return mark_safe(u'<input%s />' % forms.util.flatatt(final_attrs))
+
+
 class NetAddressFormField(forms.Field):
+    widget = NetInput
     default_error_messages = {
         'invalid': u'Enter a valid IP Address.',
     }
 
     def __init__(self, *args, **kwargs):
-        super(DateTimeField, self).__init__(*args, **kwargs)
+        super(NetAddressFormField, self).__init__(*args, **kwargs)
 
     def clean(self, value):
         super(NetAddressFormField, self).clean(value)
@@ -102,10 +120,11 @@ class NetAddressFormField(forms.Field):
             return value
         try:
             return IP(value)
-        except ValueError:
-            raise forms.ValidationError(self.error_messages['invalid'])
+        except ValueError, e:
+            raise forms.ValidationError(e)
+
 
-mac_re = re.compile(r'^(([A-F0-9]:){5}[A-F0-9])$')
+mac_re = re.compile(r'^(([A-F0-9]{2}:){5}[A-F0-9]{2})$')
 
 class MACAddressFormField(forms.RegexField):
     default_error_messages = {
@@ -113,7 +132,8 @@ class MACAddressFormField(forms.RegexField):
     }
 
     def __init__(self, *args, **kwargs):
-        super(IPAddressField, self).__init__(mac_re, *args, **kwargs)
+        super(MACAddressFormField, self).__init__(mac_re, *args, **kwargs)
+
 
 class _NetAddressField(models.Field):
     empty_strings_allowed = False
@@ -153,6 +173,7 @@ class _NetAddressField(models.Field):
         defaults.update(kwargs)
         return super(_NetAddressField, self).formfield(**defaults)
 
+
 class InetAddressField(_NetAddressField):
     description = "PostgreSQL INET field"
     max_length = 39
@@ -161,6 +182,7 @@ class InetAddressField(_NetAddressField):
     def db_type(self):
         return 'inet'
 
+
 class CidrAddressField(_NetAddressField):
     description = "PostgreSQL CIDR field"
     max_length = 43
@@ -169,6 +191,7 @@ class CidrAddressField(_NetAddressField):
     def db_type(self):
         return 'cidr'
 
+
 class MACAddressField(models.Field):
     description = "PostgreSQL MACADDR field"
 
@@ -184,6 +207,8 @@ class MACAddressField(models.Field):
         defaults.update(kwargs)
         return super(MACAddressField, self).formfield(**defaults)
 
+
+# ---- TESTS ----
 class InetTestModel(models.Model):
     '''
     >>> cursor = connection.cursor()