|
@@ -3,8 +3,8 @@ from netaddr import IPNetwork
|
|
|
from django import VERSION
|
|
|
from django.db import models, connection
|
|
|
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
|
|
|
-from django.db.models.fields import DateTimeField
|
|
|
-from django.db.models import sql, query
|
|
|
+from django.db.models.fields import DateTimeField, Field
|
|
|
+from django.db.models import sql, query, Lookup
|
|
|
from django.db.models.query_utils import QueryWrapper
|
|
|
from django.utils import tree
|
|
|
|
|
@@ -18,147 +18,58 @@ for operator in ['contains', 'startswith', 'endswith']:
|
|
|
|
|
|
NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
|
|
|
NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
|
|
|
-NET_OPERATORS['net_contained'] = '<< %s'
|
|
|
-NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
|
|
|
-NET_OPERATORS['net_contains'] = '>> %s'
|
|
|
-NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
|
|
|
|
|
|
NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
|
|
|
|
|
|
|
|
|
-class NetQuery(sql.Query):
|
|
|
- query_terms = sql.Query.query_terms.copy()
|
|
|
- query_terms.update(NET_OPERATORS)
|
|
|
-
|
|
|
-
|
|
|
-class NetWhere(sql.where.WhereNode):
|
|
|
-
|
|
|
-
|
|
|
- def _prepare_data(self, data):
|
|
|
- """
|
|
|
- Special form of WhereNode._prepare_data() that does not automatically consume the
|
|
|
- __iter__ method of IPNetwork objects. This is used in Django >= 1.6
|
|
|
- """
|
|
|
-
|
|
|
- if not isinstance(data, (list, tuple)):
|
|
|
- return data
|
|
|
- obj, lookup_type, value = data
|
|
|
- if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
|
|
|
- # Consume any generators immediately, so that we can determine
|
|
|
- # emptiness and transform any non-empty values correctly.
|
|
|
- value = list(value)
|
|
|
-
|
|
|
-
|
|
|
- # The "value_annotation" parameter is used to pass auxilliary information
|
|
|
- # about the value(s) to the query construction. Specifically, datetime
|
|
|
- # and empty values need special handling. Other types could be used
|
|
|
- # here in the future (using Python types is suggested for consistency).
|
|
|
- if (isinstance(value, datetime.datetime)
|
|
|
- or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
|
|
|
- value_annotation = datetime.datetime
|
|
|
- elif hasattr(value, 'value_annotation'):
|
|
|
- value_annotation = value.value_annotation
|
|
|
- else:
|
|
|
- value_annotation = bool(value)
|
|
|
-
|
|
|
- if hasattr(obj, "prepare"):
|
|
|
- value = obj.prepare(lookup_type, value)
|
|
|
- return (obj, lookup_type, value_annotation, value)
|
|
|
-
|
|
|
-
|
|
|
- if VERSION[:2] < (1, 6):
|
|
|
- def add(self, data, connector):
|
|
|
- """
|
|
|
- Special form of WhereNode.add() that does not automatically consume the
|
|
|
- __iter__ method of IPNetwork objects.
|
|
|
- """
|
|
|
- if not isinstance(data, (list, tuple)):
|
|
|
- # Need to bypass WhereNode
|
|
|
- tree.Node.add(self, data, connector)
|
|
|
- return
|
|
|
-
|
|
|
- obj, lookup_type, value = data
|
|
|
- if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
|
|
|
- # Consume any generators immediately, so that we can determine
|
|
|
- # emptiness and transform any non-empty values correctly.
|
|
|
- value = list(value)
|
|
|
-
|
|
|
- # The "value_annotation" parameter is used to pass auxilliary information
|
|
|
- # about the value(s) to the query construction. Specifically, datetime
|
|
|
- # and empty values need special handling. Other types could be used
|
|
|
- # here in the future (using Python types is suggested for consistency).
|
|
|
- if isinstance(value, datetime.datetime):
|
|
|
- value_annotation = datetime.datetime
|
|
|
- elif hasattr(value, 'value_annotation'):
|
|
|
- value_annotation = value.value_annotation
|
|
|
- else:
|
|
|
- value_annotation = bool(value)
|
|
|
-
|
|
|
- if hasattr(obj, "prepare"):
|
|
|
- value = obj.prepare(lookup_type, value)
|
|
|
-
|
|
|
- # Need to bypass WhereNode
|
|
|
- tree.Node.add(self,
|
|
|
- (obj, lookup_type, value_annotation, value), connector)
|
|
|
-
|
|
|
- def make_atom(self, child, qn, conn):
|
|
|
- lvalue, lookup_type, value_annot, params_or_value = child
|
|
|
-
|
|
|
- if hasattr(lvalue, 'process'):
|
|
|
- try:
|
|
|
- lvalue, params = lvalue.process(lookup_type, params_or_value,
|
|
|
- connection)
|
|
|
- except sql.where.EmptyShortCircuit:
|
|
|
- raise query.EmptyResultSet
|
|
|
- else:
|
|
|
- return super(NetWhere, self).make_atom(child, qn, conn)
|
|
|
-
|
|
|
- table_alias, name, db_type = lvalue
|
|
|
-
|
|
|
- if db_type not in ['inet', 'cidr']:
|
|
|
- return super(NetWhere, self).make_atom(child, qn, conn)
|
|
|
-
|
|
|
- if table_alias:
|
|
|
- field_sql = '%s.%s' % (qn(table_alias), qn(name))
|
|
|
- else:
|
|
|
- field_sql = qn(name)
|
|
|
-
|
|
|
- if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
|
|
|
- if db_type == 'inet':
|
|
|
- field_sql = 'HOST(%s)' % field_sql
|
|
|
- else:
|
|
|
- field_sql = 'TEXT(%s)' % field_sql
|
|
|
-
|
|
|
- if isinstance(params, QueryWrapper):
|
|
|
- extra, params = params.data
|
|
|
- else:
|
|
|
- extra = ''
|
|
|
-
|
|
|
- if isinstance(params, basestring):
|
|
|
- params = (params,)
|
|
|
-
|
|
|
- if lookup_type in NET_OPERATORS:
|
|
|
- return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
|
|
|
- params)
|
|
|
- elif lookup_type == 'in':
|
|
|
- if not value_annot:
|
|
|
- raise sql.datastructures.EmptyResultSet
|
|
|
- if extra:
|
|
|
- return ('%s IN %s' % (field_sql, extra), params)
|
|
|
- return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
|
|
|
- len(params))), params)
|
|
|
- elif lookup_type == 'range':
|
|
|
- return ('%s BETWEEN %%s and %%s' % field_sql, params)
|
|
|
- elif lookup_type == 'isnull':
|
|
|
- return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
|
|
|
- '')), params)
|
|
|
-
|
|
|
- raise ValueError('Invalid lookup type "%s"' % lookup_type)
|
|
|
+class NetContainsOrEqual(Lookup):
|
|
|
+ lookup_name = 'net_contains_or_equals'
|
|
|
+ def as_sql(self, qn, connection):
|
|
|
+ lhs, lhs_params = self.process_lhs(qn, connection)
|
|
|
+ rhs, rhs_params = self.process_rhs(qn, connection)
|
|
|
+ params = lhs_params + rhs_params
|
|
|
+ return "%s >>= %s" % (lhs, rhs), params
|
|
|
+
|
|
|
+Field.register_lookup(NetContainsOrEqual)
|
|
|
+
|
|
|
+class NetContains(Lookup):
|
|
|
+ lookup_name = 'net_contains'
|
|
|
+ def as_sql(self, qn, connection):
|
|
|
+ lhs, lhs_params = self.process_lhs(qn, connection)
|
|
|
+ rhs, rhs_params = self.process_rhs(qn, connection)
|
|
|
+ params = lhs_params + rhs_params
|
|
|
+ return "%s >> %s" % (lhs, rhs), params
|
|
|
+
|
|
|
+Field.register_lookup(NetContains)
|
|
|
+
|
|
|
+class NetContained(Lookup):
|
|
|
+ lookup_name = 'net_contained'
|
|
|
+ def as_sql(self, qn, connection):
|
|
|
+ lhs, lhs_params = self.process_lhs(qn, connection)
|
|
|
+ rhs, rhs_params = self.process_rhs(qn, connection)
|
|
|
+ params = lhs_params + rhs_params
|
|
|
+ return "%s << %s" % (lhs, rhs), params
|
|
|
+
|
|
|
+Field.register_lookup(NetContained)
|
|
|
+
|
|
|
+
|
|
|
+class NetContainedOrEqual(Lookup):
|
|
|
+ lookup_name = 'net_contained_or_equal'
|
|
|
+ def as_sql(self, qn, connection):
|
|
|
+ lhs, lhs_params = self.process_lhs(qn, connection)
|
|
|
+ rhs, rhs_params = self.process_rhs(qn, connection)
|
|
|
+ params = lhs_params + rhs_params
|
|
|
+ return "%s <<= %s" % (lhs, rhs), params
|
|
|
+
|
|
|
+Field.register_lookup(NetContainedOrEqual)
|
|
|
|
|
|
|
|
|
class NetManager(models.Manager):
|
|
|
use_for_related_fields = True
|
|
|
|
|
|
def get_query_set(self):
|
|
|
- q = NetQuery(self.model, NetWhere)
|
|
|
- return query.QuerySet(self.model, q)
|
|
|
+ return super(NetManager, self).get_queryset()
|
|
|
+
|
|
|
+ def get_queryset(self):
|
|
|
+ return super(NetManager, self).get_queryset()
|
|
|
+
|