|
@@ -1,9 +1,10 @@
|
|
|
from IPy import IP
|
|
|
|
|
|
from django.core.exceptions import ValidationError
|
|
|
-from django.utils.translation import ugettext_lazy
|
|
|
from django.db import models, connection
|
|
|
from django.db.models import sql, query
|
|
|
+from django.db.models.query_utils import QueryWrapper
|
|
|
+from django.utils.translation import ugettext_lazy
|
|
|
|
|
|
NET_OPERATORS = {
|
|
|
'lt': '<',
|
|
@@ -28,8 +29,6 @@ NET_OPERATORS = {
|
|
|
|
|
|
NET_TEXT_OPERATORS = ['ILIKE', '~*']
|
|
|
|
|
|
-
|
|
|
-
|
|
|
class NetQuery(sql.Query):
|
|
|
query_terms = sql.Query.query_terms.copy()
|
|
|
query_terms.update(NET_OPERATORS)
|
|
@@ -44,20 +43,30 @@ class NetQuery(sql.Query):
|
|
|
class NetWhere(sql.where.WhereNode):
|
|
|
def make_atom(self, child, qn):
|
|
|
table_alias, name, db_type, lookup_type, value_annot, params = child
|
|
|
+
|
|
|
+ if db_type not in ['inet', 'cidr']:
|
|
|
+ return super(NetWhere, self).make_atom(child, qn)
|
|
|
+
|
|
|
if table_alias:
|
|
|
field_sql = '%s.%s' % (qn(table_alias), qn(name))
|
|
|
else:
|
|
|
field_sql = qn(name)
|
|
|
|
|
|
- if db_type not in ['inet', 'cidr']:
|
|
|
- return super(NetWhere, self).make_atom(child, qn)
|
|
|
-
|
|
|
if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
|
|
|
field_sql = 'HOST(%s)' % field_sql
|
|
|
|
|
|
+ if isinstance(params, QueryWrapper):
|
|
|
+ extra, params = params.data
|
|
|
+ else:
|
|
|
+ extra = ''
|
|
|
+
|
|
|
if lookup_type in NET_OPERATORS:
|
|
|
- return ('%s %s %%s' % (field_sql, NET_OPERATORS[lookup_type]), params)
|
|
|
+ return ('%s %s %%s %s' % (field_sql, NET_OPERATORS[lookup_type], extra), params)
|
|
|
elif lookup_type == 'in':
|
|
|
+ if not value_annot:
|
|
|
+ raise sql.datastructures.EmptyResultSet
|
|
|
+ if extra:
|
|
|
+ return ('%s IN %s' % (field_sql, extra), params)
|
|
|
return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))), params)
|
|
|
elif lookup_type == 'range':
|
|
|
return ('%s BETWEEN %%s and %%s' % field_sql, params)
|