|
@@ -1,5 +1,6 @@
|
|
|
from IPy import IP
|
|
|
|
|
|
+from types import StringTypes
|
|
|
from django.db import models, connection
|
|
|
from django.db.models import sql, query
|
|
|
from django.db.models.query_utils import QueryWrapper
|
|
@@ -33,11 +34,18 @@ class NetQuery(sql.Query):
|
|
|
|
|
|
|
|
|
class NetWhere(sql.where.WhereNode):
|
|
|
- def make_atom(self, child, qn):
|
|
|
- table_alias, name, db_type, lookup_type, value_annot, params = child
|
|
|
+ def make_atom(self, child, qn , conn):
|
|
|
+ if isinstance(child[0] , sql.where.Constraint):
|
|
|
+ c = child[0]
|
|
|
+ table_alias = c.alias
|
|
|
+ name = c.col
|
|
|
+ field = c.field
|
|
|
+ lookup_type , value_annot , params = child[1:]
|
|
|
+ else:
|
|
|
+ table_alias, name, db_type, lookup_type, value_annot, params = child
|
|
|
|
|
|
- if db_type not in ['inet', 'cidr']:
|
|
|
- return super(NetWhere, self).make_atom(child, qn)
|
|
|
+ if field.db_type() not in ['inet', 'cidr']:
|
|
|
+ return super(NetWhere, self).make_atom(child, qn , conn)
|
|
|
|
|
|
if table_alias:
|
|
|
field_sql = '%s.%s' % (qn(table_alias), qn(name))
|
|
@@ -45,7 +53,7 @@ class NetWhere(sql.where.WhereNode):
|
|
|
field_sql = qn(name)
|
|
|
|
|
|
if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
|
|
|
- if db_type == 'inet':
|
|
|
+ if field.db_type() == 'inet':
|
|
|
field_sql = 'HOST(%s)' % field_sql
|
|
|
else:
|
|
|
field_sql = 'TEXT(%s)' % field_sql
|
|
@@ -55,6 +63,9 @@ class NetWhere(sql.where.WhereNode):
|
|
|
else:
|
|
|
extra = ''
|
|
|
|
|
|
+ if type(params) in StringTypes:
|
|
|
+ params = (params,)
|
|
|
+
|
|
|
if lookup_type in NET_OPERATORS:
|
|
|
return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]), params)
|
|
|
elif lookup_type == 'in':
|
|
@@ -75,5 +86,5 @@ class NetManager(models.Manager):
|
|
|
use_for_related_fields = True
|
|
|
|
|
|
def get_query_set(self):
|
|
|
- q = NetQuery(self.model, connection, NetWhere)
|
|
|
+ q = NetQuery(self.model, NetWhere)
|
|
|
return query.QuerySet(self.model, q)
|