Browse Source

Switch to new Django 1.7 Lookups

Aaron C. de Bruyn 10 years ago
parent
commit
03ff3d67f0
2 changed files with 47 additions and 143 deletions
  1. 47 136
      netfields/managers.py
  2. 0 7
      netfields/models.py

+ 47 - 136
netfields/managers.py

@@ -3,8 +3,8 @@ from netaddr import IPNetwork
 from django import VERSION
 from django.db import models, connection
 from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
-from django.db.models.fields import DateTimeField
-from django.db.models import sql, query
+from django.db.models.fields import DateTimeField, Field
+from django.db.models import sql, query, Lookup
 from django.db.models.query_utils import QueryWrapper
 from django.utils import tree
 
@@ -18,147 +18,58 @@ for operator in ['contains', 'startswith', 'endswith']:
 
 NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
 NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
-NET_OPERATORS['net_contained'] = '<< %s'
-NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
-NET_OPERATORS['net_contains'] = '>> %s'
-NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
 
 NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
 
 
-class NetQuery(sql.Query):
-    query_terms = sql.Query.query_terms.copy()
-    query_terms.update(NET_OPERATORS)
-
-
-class NetWhere(sql.where.WhereNode):
-
-
-    def _prepare_data(self, data):
-        """
-        Special form of WhereNode._prepare_data() that does not automatically consume the
-        __iter__ method of IPNetwork objects.  This is used in Django >= 1.6
-        """
-
-        if not isinstance(data, (list, tuple)):
-            return data
-        obj, lookup_type, value = data
-        if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
-            # Consume any generators immediately, so that we can determine
-            # emptiness and transform any non-empty values correctly.
-            value = list(value)
-
-
-        # The "value_annotation" parameter is used to pass auxilliary information
-        # about the value(s) to the query construction. Specifically, datetime
-        # and empty values need special handling. Other types could be used
-        # here in the future (using Python types is suggested for consistency).
-        if (isinstance(value, datetime.datetime)
-            or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
-            value_annotation = datetime.datetime
-        elif hasattr(value, 'value_annotation'):
-            value_annotation = value.value_annotation
-        else:
-            value_annotation = bool(value)
-
-        if hasattr(obj, "prepare"):
-            value = obj.prepare(lookup_type, value)
-        return (obj, lookup_type, value_annotation, value)
-
-
-    if VERSION[:2] < (1, 6):
-        def add(self, data, connector):
-            """
-            Special form of WhereNode.add() that does not automatically consume the
-            __iter__ method of IPNetwork objects.
-            """
-            if not isinstance(data, (list, tuple)):
-                # Need to bypass WhereNode
-                tree.Node.add(self, data, connector)
-                return
-
-            obj, lookup_type, value = data
-            if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
-                # Consume any generators immediately, so that we can determine
-                # emptiness and transform any non-empty values correctly.
-                value = list(value)
-
-            # The "value_annotation" parameter is used to pass auxilliary information
-            # about the value(s) to the query construction. Specifically, datetime
-            # and empty values need special handling. Other types could be used
-            # here in the future (using Python types is suggested for consistency).
-            if isinstance(value, datetime.datetime):
-                value_annotation = datetime.datetime
-            elif hasattr(value, 'value_annotation'):
-                value_annotation = value.value_annotation
-            else:
-                value_annotation = bool(value)
-
-            if hasattr(obj, "prepare"):
-                value = obj.prepare(lookup_type, value)
-
-            # Need to bypass WhereNode
-            tree.Node.add(self,
-                (obj, lookup_type, value_annotation, value), connector)
-
-    def make_atom(self, child, qn, conn):
-        lvalue, lookup_type, value_annot, params_or_value = child
-
-        if hasattr(lvalue, 'process'):
-            try:
-                lvalue, params = lvalue.process(lookup_type, params_or_value,
-                                                connection)
-            except sql.where.EmptyShortCircuit:
-                raise query.EmptyResultSet
-        else:
-            return super(NetWhere, self).make_atom(child, qn, conn)
-
-        table_alias, name, db_type = lvalue
-
-        if db_type not in ['inet', 'cidr']:
-            return super(NetWhere, self).make_atom(child, qn, conn)
-
-        if table_alias:
-            field_sql = '%s.%s' % (qn(table_alias), qn(name))
-        else:
-            field_sql = qn(name)
-
-        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
-            if db_type == 'inet':
-                field_sql = 'HOST(%s)' % field_sql
-            else:
-                field_sql = 'TEXT(%s)' % field_sql
-
-        if isinstance(params, QueryWrapper):
-            extra, params = params.data
-        else:
-            extra = ''
-
-        if isinstance(params, basestring):
-            params = (params,)
-
-        if lookup_type in NET_OPERATORS:
-            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
-                    params)
-        elif lookup_type == 'in':
-            if not value_annot:
-                raise sql.datastructures.EmptyResultSet
-            if extra:
-                return ('%s IN %s' % (field_sql, extra), params)
-            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
-                    len(params))), params)
-        elif lookup_type == 'range':
-            return ('%s BETWEEN %%s and %%s' % field_sql, params)
-        elif lookup_type == 'isnull':
-            return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
-                    '')), params)
-
-        raise ValueError('Invalid lookup type "%s"' % lookup_type)
+class NetContainsOrEqual(Lookup):
+    lookup_name = 'net_contains_or_equals'
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return "%s >>= %s" % (lhs, rhs), params
+
+Field.register_lookup(NetContainsOrEqual)
+
+class NetContains(Lookup):
+    lookup_name = 'net_contains'
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return "%s >> %s" % (lhs, rhs), params
+
+Field.register_lookup(NetContains)
+
+class NetContained(Lookup):
+    lookup_name = 'net_contained'
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return "%s << %s" % (lhs, rhs), params
+
+Field.register_lookup(NetContained)
+
+
+class NetContainedOrEqual(Lookup):
+    lookup_name = 'net_contained_or_equal'
+    def as_sql(self, qn, connection):
+        lhs, lhs_params = self.process_lhs(qn, connection)
+        rhs, rhs_params = self.process_rhs(qn, connection)
+        params = lhs_params + rhs_params
+        return "%s <<= %s" % (lhs, rhs), params
+
+Field.register_lookup(NetContainedOrEqual)
 
 
 class NetManager(models.Manager):
     use_for_related_fields = True
 
     def get_query_set(self):
-        q = NetQuery(self.model, NetWhere)
-        return query.QuerySet(self.model, q)
+        return super(NetManager, self).get_queryset()
+
+    def get_queryset(self):
+        return super(NetManager, self).get_queryset()
+

+ 0 - 7
netfields/models.py

@@ -6,7 +6,6 @@ from netfields import InetAddressField, CidrAddressField, MACAddressField, \
 
 class InetTestModel(Model):
     field = InetAddressField()
-    objects = NetManager()
 
     class Meta:
         db_table = 'inet'
@@ -14,7 +13,6 @@ class InetTestModel(Model):
 
 class NullInetTestModel(Model):
     field = InetAddressField(null=True)
-    objects = NetManager()
 
     class Meta:
         db_table = 'nullinet'
@@ -22,7 +20,6 @@ class NullInetTestModel(Model):
 
 class UniqueInetTestModel(Model):
     field = InetAddressField(unique=True)
-    objects = NetManager()
 
     class Meta:
         db_table = 'uniqueinet'
@@ -30,7 +27,6 @@ class UniqueInetTestModel(Model):
 
 class CidrTestModel(Model):
     field = CidrAddressField()
-    objects = NetManager()
 
     class Meta:
         db_table = 'cidr'
@@ -38,7 +34,6 @@ class CidrTestModel(Model):
 
 class NullCidrTestModel(Model):
     field = CidrAddressField(null=True)
-    objects = NetManager()
 
     class Meta:
         db_table = 'nullcidr'
@@ -46,7 +41,6 @@ class NullCidrTestModel(Model):
 
 class UniqueCidrTestModel(Model):
     field = CidrAddressField(unique=True)
-    objects = NetManager()
 
     class Meta:
         db_table = 'uniquecidr'
@@ -54,7 +48,6 @@ class UniqueCidrTestModel(Model):
 
 class MACTestModel(Model):
     field = MACAddressField(null=True)
-    objects = NetManager()
 
     class Meta:
         db_table = 'mac'