Browse Source

Merge pull request #17 from jmacul2/master

Fixes for Django 1.6
jimfunk 11 years ago
parent
commit
db1a400de0
2 changed files with 61 additions and 15 deletions
  1. 48 11
      netfields/managers.py
  2. 13 4
      netfields/tests.py

+ 48 - 11
netfields/managers.py

@@ -1,7 +1,9 @@
 from netaddr import IPNetwork
 
+from django import VERSION
 from django.db import models, connection
 from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
+from django.db.models.fields import DateTimeField
 from django.db.models import sql, query
 from django.db.models.query_utils import QueryWrapper
 from django.utils import tree
@@ -30,27 +32,29 @@ class NetQuery(sql.Query):
 
 
 class NetWhere(sql.where.WhereNode):
-    def add(self, data, connector):
+
+
+    def _prepare_data(self, data):
         """
-        Special form of WhereNode.add() that does not automatically consume the
-        __iter__ method of IPNetwork objects.
+            Special form of WhereNode._prepare_data() that does not automatically consume the
+            __iter__ method of IPNetwork objects.  This is used in Django >= 1.6
         """
-        if not isinstance(data, (list, tuple)):
-            # Need to bypass WhereNode
-            tree.Node.add(self, data, connector)
-            return
 
+        if not isinstance(data, (list, tuple)):
+            return data
         obj, lookup_type, value = data
         if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
             # Consume any generators immediately, so that we can determine
             # emptiness and transform any non-empty values correctly.
             value = list(value)
 
+
         # The "value_annotation" parameter is used to pass auxilliary information
         # about the value(s) to the query construction. Specifically, datetime
         # and empty values need special handling. Other types could be used
         # here in the future (using Python types is suggested for consistency).
-        if isinstance(value, datetime.datetime):
+        if (isinstance(value, datetime.datetime)
+            or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
             value_annotation = datetime.datetime
         elif hasattr(value, 'value_annotation'):
             value_annotation = value.value_annotation
@@ -59,10 +63,43 @@ class NetWhere(sql.where.WhereNode):
 
         if hasattr(obj, "prepare"):
             value = obj.prepare(lookup_type, value)
+        return (obj, lookup_type, value_annotation, value)
+
+
+    if VERSION[:2] < (1, 6):
+        def add(self, data, connector):
+            """
+            Special form of WhereNode.add() that does not automatically consume the
+            __iter__ method of IPNetwork objects.
+            """
+            if not isinstance(data, (list, tuple)):
+                # Need to bypass WhereNode
+                tree.Node.add(self, data, connector)
+                return
+
+            obj, lookup_type, value = data
+            if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
+                # Consume any generators immediately, so that we can determine
+                # emptiness and transform any non-empty values correctly.
+                value = list(value)
+
+            # The "value_annotation" parameter is used to pass auxilliary information
+            # about the value(s) to the query construction. Specifically, datetime
+            # and empty values need special handling. Other types could be used
+            # here in the future (using Python types is suggested for consistency).
+            if isinstance(value, datetime.datetime):
+                value_annotation = datetime.datetime
+            elif hasattr(value, 'value_annotation'):
+                value_annotation = value.value_annotation
+            else:
+                value_annotation = bool(value)
 
-        # Need to bypass WhereNode
-        tree.Node.add(self,
-            (obj, lookup_type, value_annotation, value), connector)
+            if hasattr(obj, "prepare"):
+                value = obj.prepare(lookup_type, value)
+
+            # Need to bypass WhereNode
+            tree.Node.add(self,
+                (obj, lookup_type, value_annotation, value), connector)
 
     def make_atom(self, child, qn, conn):
         lvalue, lookup_type, value_annot, params_or_value = child

+ 13 - 4
netfields/tests.py

@@ -1,5 +1,6 @@
 from netaddr import IPAddress, IPNetwork, EUI, AddrFormatError
 
+from django import VERSION
 from django.db import IntegrityError
 from django.forms import ModelForm
 from django.test import TestCase
@@ -355,12 +356,20 @@ class BaseMacTestCase(BaseSqlTestCase):
             self.select + 'WHERE UPPER("table"."field"::text) LIKE UPPER(%s) ')
 
     def test_regex_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__regex='00'),
-            self.select + 'WHERE "table"."field" ~ %s ')
+        if VERSION[:2] < (1, 6):
+            self.assertSqlEquals(self.qs.filter(field__regex='00'),
+                self.select + 'WHERE "table"."field" ~ %s ')
+        else:
+            self.assertSqlEquals(self.qs.filter(field__regex='00'),
+                self.select + 'WHERE "table"."field"::text ~ %s ')
 
     def test_iregex_lookup(self):
-        self.assertSqlEquals(self.qs.filter(field__iregex='00'),
-            self.select + 'WHERE "table"."field" ~* %s ')
+        if VERSION[:2] < (1, 6):
+            self.assertSqlEquals(self.qs.filter(field__iregex='00'),
+                self.select + 'WHERE "table"."field" ~* %s ')
+        else:
+            self.assertSqlEquals(self.qs.filter(field__iregex='00'),
+                self.select + 'WHERE "table"."field"::text ~* %s ')
 
 
 class TestMacAddressField(BaseMacTestCase, TestCase):