Browse Source

Incorperate corner cases from django's make_atom

Thomas Adamcik 15 years ago
parent
commit
6f17f93783
1 changed files with 16 additions and 7 deletions
  1. 16 7
      manager.py

+ 16 - 7
manager.py

@@ -1,9 +1,10 @@
 from IPy import IP
 
 from django.core.exceptions import ValidationError
-from django.utils.translation import ugettext_lazy
 from django.db import models, connection
 from django.db.models import sql, query
+from django.db.models.query_utils import QueryWrapper
+from django.utils.translation import ugettext_lazy
 
 NET_OPERATORS = {
     'lt': '<',
@@ -28,8 +29,6 @@ NET_OPERATORS = {
 
 NET_TEXT_OPERATORS = ['ILIKE', '~*']
 
-# FIXME test with .extra() and QueryWrapper
-
 class NetQuery(sql.Query):
     query_terms = sql.Query.query_terms.copy()
     query_terms.update(NET_OPERATORS)
@@ -44,20 +43,30 @@ class NetQuery(sql.Query):
 class NetWhere(sql.where.WhereNode):
     def make_atom(self, child, qn):
         table_alias, name, db_type, lookup_type, value_annot, params = child
+
+        if db_type not in ['inet', 'cidr']:
+            return super(NetWhere, self).make_atom(child, qn)
+
         if table_alias:
             field_sql = '%s.%s' % (qn(table_alias), qn(name))
         else:
             field_sql = qn(name)
 
-        if db_type not in ['inet', 'cidr']:
-            return super(NetWhere, self).make_atom(child, qn)
-
         if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
             field_sql  = 'HOST(%s)' % field_sql
 
+        if isinstance(params, QueryWrapper):
+            extra, params = params.data
+        else:
+            extra = ''
+
         if lookup_type in NET_OPERATORS:
-            return ('%s %s %%s' % (field_sql, NET_OPERATORS[lookup_type]), params)
+            return ('%s %s %%s %s' % (field_sql, NET_OPERATORS[lookup_type], extra), params)
         elif lookup_type == 'in':
+            if not value_annot:
+                raise sql.datastructures.EmptyResultSet
+            if extra:
+                return ('%s IN %s' % (field_sql, extra), params)
             return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
         elif lookup_type == 'range':
             return ('%s BETWEEN %%s and %%s' % field_sql, params)