Browse Source

Bring code back into running state after refactoring

Thomas Adamcik 15 years ago
parent
commit
d8918712c6
4 changed files with 160 additions and 153 deletions
  1. 1 1
      netfields/__init__.py
  2. 76 77
      netfields/fields.py
  3. 77 75
      netfields/managers.py
  4. 6 0
      netfields/tests.py

+ 1 - 1
netfields/__init__.py

@@ -1,3 +1,3 @@
 from netfields.managers import NetManger
 from netfields.fields import (InetAddressField, CidrAddressField,
-        MACAddressFormField)
+        MACAddressField)

+ 76 - 77
netfields/fields.py

@@ -1,79 +1,78 @@
 from IPy import IP
 
-from django.db import models, connection
-from django.db.models import sql, query
-from django.db.models.query_utils import QueryWrapper
-
-NET_OPERATORS = connection.operators.copy()
-
-for operator in ['contains', 'startswith', 'endswith']:
-    NET_OPERATORS[operator] = 'ILIKE %s'
-    NET_OPERATORS['i%s' % operator] = 'ILIKE %s'
-
-NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
-NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
-NET_OPERATORS['net_contained'] = '<< %s'
-NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
-NET_OPERATORS['net_contains'] = '>> %s'
-NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
-
-NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
-
-
-class NetQuery(sql.Query):
-    query_terms = sql.Query.query_terms.copy()
-    query_terms.update(NET_OPERATORS)
-
-    def add_filter(self, (filter_string, value), *args, **kwargs):
-        # IP(...) == '' fails so make sure to force to string while we can
-        if isinstance(value, IP):
-            value = unicode(value)
-        return super(NetQuery, self).add_filter(
-            (filter_string, value), *args, **kwargs)
-
-
-class NetWhere(sql.where.WhereNode):
-    def make_atom(self, child, qn):
-        table_alias, name, db_type, lookup_type, value_annot, params = child
-
-        if db_type not in ['inet', 'cidr']:
-            return super(NetWhere, self).make_atom(child, qn)
-
-        if table_alias:
-            field_sql = '%s.%s' % (qn(table_alias), qn(name))
-        else:
-            field_sql = qn(name)
-
-        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
-            if db_type == 'inet':
-                field_sql  = 'HOST(%s)' % field_sql
-            else:
-                field_sql  = 'TEXT(%s)' % field_sql
-
-        if isinstance(params, QueryWrapper):
-            extra, params = params.data
-        else:
-            extra = ''
-
-        if lookup_type in NET_OPERATORS:
-            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params)
-        elif lookup_type == 'in':
-            if not value_annot:
-                raise sql.datastructures.EmptyResultSet
-            if extra:
-                return ('%s IN %s' % (field_sql, extra), params)
-            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
-        elif lookup_type == 'range':
-            return ('%s BETWEEN %%s and %%s' % field_sql, params)
-        elif lookup_type == 'isnull':
-            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), params)
-
-        raise ValueError('Invalid lookup type "%s"' % lookup_type)
-
-
-class NetManger(models.Manager):
-    use_for_related_fields = True
-
-    def get_query_set(self):
-        q = NetQuery(self.model, connection, NetWhere)
-        return query.QuerySet(self.model, q)
+from django.db import models
+
+from netfields.managers import NET_OPERATORS, NET_TEXT_OPERATORS
+from netfields.forms import NetAddressFormField, MACAddressFormField
+
+class _NetAddressField(models.Field):
+    empty_strings_allowed = False
+
+    def __init__(self, *args, **kwargs):
+        kwargs['max_length'] = self.max_length
+        super(_NetAddressField, self).__init__(*args, **kwargs)
+
+    def to_python(self, value):
+        if not value:
+            value = None
+
+        if value is None:
+            return value
+
+        return IP(value)
+
+    def get_db_prep_value(self, value):
+        if value is None:
+            return value
+
+        return unicode(self.to_python(value))
+
+    def get_db_prep_lookup(self, lookup_type, value):
+        if value is None:
+            return value
+
+        if (lookup_type in NET_OPERATORS and
+                NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
+            return [self.get_db_prep_value(value)]
+
+        return super(_NetAddressField, self).get_db_prep_lookup(
+            lookup_type, value)
+
+    def formfield(self, **kwargs):
+        defaults = {'form_class': NetAddressFormField}
+        defaults.update(kwargs)
+        return super(_NetAddressField, self).formfield(**defaults)
+
+
+class InetAddressField(_NetAddressField):
+    description = "PostgreSQL INET field"
+    max_length = 39
+    __metaclass__ = models.SubfieldBase
+
+    def db_type(self):
+        return 'inet'
+
+
+class CidrAddressField(_NetAddressField):
+    description = "PostgreSQL CIDR field"
+    max_length = 43
+    __metaclass__ = models.SubfieldBase
+
+    def db_type(self):
+        return 'cidr'
+
+
+class MACAddressField(models.Field):
+    description = "PostgreSQL MACADDR field"
+
+    def __init__(self, *args, **kwargs):
+        kwargs['max_length'] = 17
+        super(MACAddressField, self).__init__(*args, **kwargs)
+
+    def db_type(self):
+        return 'macaddr'
+
+    def formfield(self, **kwargs):
+        defaults = {'form_class': MACAddressFormField}
+        defaults.update(kwargs)
+        return super(MACAddressField, self).formfield(**defaults)

+ 77 - 75
netfields/managers.py

@@ -1,77 +1,79 @@
 from IPy import IP
 
-from django.db import models
-
-from netfields.forms import NetAddressFormField, MACAddressFormField
-
-class _NetAddressField(models.Field):
-    empty_strings_allowed = False
-
-    def __init__(self, *args, **kwargs):
-        kwargs['max_length'] = self.max_length
-        super(_NetAddressField, self).__init__(*args, **kwargs)
-
-    def to_python(self, value):
-        if not value:
-            value = None
-
-        if value is None:
-            return value
-
-        return IP(value)
-
-    def get_db_prep_value(self, value):
-        if value is None:
-            return value
-
-        return unicode(self.to_python(value))
-
-    def get_db_prep_lookup(self, lookup_type, value):
-        if value is None:
-            return value
-
-        if (lookup_type in NET_OPERATORS and
-                NET_OPERATORS[lookup_type] not in NET_TEXT_OPERATORS):
-            return [self.get_db_prep_value(value)]
-
-        return super(_NetAddressField, self).get_db_prep_lookup(
-            lookup_type, value)
-
-    def formfield(self, **kwargs):
-        defaults = {'form_class': NetAddressFormField}
-        defaults.update(kwargs)
-        return super(_NetAddressField, self).formfield(**defaults)
-
-
-class InetAddressField(_NetAddressField):
-    description = "PostgreSQL INET field"
-    max_length = 39
-    __metaclass__ = models.SubfieldBase
-
-    def db_type(self):
-        return 'inet'
-
-
-class CidrAddressField(_NetAddressField):
-    description = "PostgreSQL CIDR field"
-    max_length = 43
-    __metaclass__ = models.SubfieldBase
-
-    def db_type(self):
-        return 'cidr'
-
-
-class MACAddressField(models.Field):
-    description = "PostgreSQL MACADDR field"
-
-    def __init__(self, *args, **kwargs):
-        kwargs['max_length'] = 17
-        super(MACAddressField, self).__init__(*args, **kwargs)
-
-    def db_type(self):
-        return 'macaddr'
-
-    def formfield(self, **kwargs):
-        defaults = {'form_class': MACAddressFormField}
-        defaults.update(kwargs)
-        return super(MACAddressField, self).formfield(**defaults)
+from django.db import models, connection
+from django.db.models import sql, query
+from django.db.models.query_utils import QueryWrapper
+
+NET_OPERATORS = connection.operators.copy()
+
+for operator in ['contains', 'startswith', 'endswith']:
+    NET_OPERATORS[operator] = 'ILIKE %s'
+    NET_OPERATORS['i%s' % operator] = 'ILIKE %s'
+
+NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
+NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
+NET_OPERATORS['net_contained'] = '<< %s'
+NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
+NET_OPERATORS['net_contains'] = '>> %s'
+NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
+
+NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
+
+
+class NetQuery(sql.Query):
+    query_terms = sql.Query.query_terms.copy()
+    query_terms.update(NET_OPERATORS)
+
+    def add_filter(self, (filter_string, value), *args, **kwargs):
+        # IP(...) == '' fails so make sure to force to string while we can
+        if isinstance(value, IP):
+            value = unicode(value)
+        return super(NetQuery, self).add_filter(
+            (filter_string, value), *args, **kwargs)
+
+
+class NetWhere(sql.where.WhereNode):
+    def make_atom(self, child, qn):
+        table_alias, name, db_type, lookup_type, value_annot, params = child
+
+        if db_type not in ['inet', 'cidr']:
+            return super(NetWhere, self).make_atom(child, qn)
+
+        if table_alias:
+            field_sql = '%s.%s' % (qn(table_alias), qn(name))
+        else:
+            field_sql = qn(name)
+
+        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
+            if db_type == 'inet':
+                field_sql  = 'HOST(%s)' % field_sql
+            else:
+                field_sql  = 'TEXT(%s)' % field_sql
+
+        if isinstance(params, QueryWrapper):
+            extra, params = params.data
+        else:
+            extra = ''
+
+        if lookup_type in NET_OPERATORS:
+            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params)
+        elif lookup_type == 'in':
+            if not value_annot:
+                raise sql.datastructures.EmptyResultSet
+            if extra:
+                return ('%s IN %s' % (field_sql, extra), params)
+            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
+        elif lookup_type == 'range':
+            return ('%s BETWEEN %%s and %%s' % field_sql, params)
+        elif lookup_type == 'isnull':
+            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or '')), params)
+
+        raise ValueError('Invalid lookup type "%s"' % lookup_type)
+
+
+class NetManger(models.Manager):
+    use_for_related_fields = True
+
+    def get_query_set(self):
+        q = NetQuery(self.model, connection, NetWhere)
+        return query.QuerySet(self.model, q)

+ 6 - 0
netfields/tests.py

@@ -1,3 +1,9 @@
+from IPy import IP
+
+from django.db import models, connection
+
+from netfields import *
+
 class InetTestModel(models.Model):
     '''
     >>> cursor = connection.cursor()