from netaddr import IPNetwork

from django import VERSION
from django.db import models, connection
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
from django.db.models.fields import DateTimeField
from django.db.models import sql, query
from django.db.models.query_utils import QueryWrapper
from django.utils import tree

import datetime

# Start from the PostgreSQL backend's operator table and override/extend it
# with the operators used for inet/cidr columns.
NET_OPERATORS = DatabaseWrapper.operators.copy()

# Every text-matching lookup (and its case-insensitive "i" variant) is mapped
# to ILIKE.
for operator in ['contains', 'startswith', 'endswith']:
    NET_OPERATORS[operator] = 'ILIKE %s'
    NET_OPERATORS['i%s' % operator] = 'ILIKE %s'

NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
NET_OPERATORS['net_contained'] = '<< %s'
NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
NET_OPERATORS['net_contains'] = '>> %s'
NET_OPERATORS['net_contains_or_equals'] = '>>= %s'

# Operators for which make_atom() wraps the column in HOST()/TEXT() before
# comparing.
NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']


def _consume_generator(value):
    """
    Return ``value`` as a list if it looks like a generator, else unchanged.

    IPNetwork instances are never consumed, even though they are iterable.
    """
    if (not isinstance(value, IPNetwork) and hasattr(value, '__iter__')
            and hasattr(value, 'next')):
        # Consume any generators immediately, so that we can determine
        # emptiness and transform any non-empty values correctly.
        return list(value)
    return value


class NetQuery(sql.Query):
    """Query whose ``query_terms`` also include every NET_OPERATORS lookup."""

    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_OPERATORS)


class NetWhere(sql.where.WhereNode):
    """WhereNode that builds its own SQL for ``inet`` and ``cidr`` columns."""

    def _prepare_data(self, data):
        """
        Prepare data for addition to the tree. If the data is a list or tuple,
        it is expected to be of the form (obj, lookup_type, value), where obj
        is a Constraint object, and is then slightly munged before being
        stored (to avoid storing any reference to field objects). Otherwise,
        the 'data' is stored unchanged and can be any class with an 'as_sql()'
        method.

        Returns either ``data`` unchanged or the 4-tuple
        ``(obj, lookup_type, value_annotation, value)``.
        """
        if not isinstance(data, (list, tuple)):
            return data

        obj, lookup_type, value = data
        value = _consume_generator(value)

        # The "value_annotation" parameter is used to pass auxiliary
        # information about the value(s) to the query construction.
        # Specifically, datetime and empty values need special handling.
        # Other types could be used here in the future (using Python types is
        # suggested for consistency).
        if (isinstance(value, datetime.datetime)
                or (isinstance(obj.field, DateTimeField)
                    and lookup_type != 'isnull')):
            value_annotation = datetime.datetime
        elif hasattr(value, 'value_annotation'):
            value_annotation = value.value_annotation
        else:
            value_annotation = bool(value)

        if hasattr(obj, "prepare"):
            value = obj.prepare(lookup_type, value)

        return (obj, lookup_type, value_annotation, value)

    if VERSION[:2] < (1, 6):
        def add(self, data, connector):
            """
            Special form of WhereNode.add() that does not automatically
            consume the __iter__ method of IPNetwork objects.

            ``data`` is either an arbitrary node (added as-is) or a
            (obj, lookup_type, value) triple, which is stored as
            (obj, lookup_type, value_annotation, value).
            """
            if not isinstance(data, (list, tuple)):
                # Need to bypass WhereNode
                tree.Node.add(self, data, connector)
                return

            obj, lookup_type, value = data
            value = _consume_generator(value)

            # The "value_annotation" parameter is used to pass auxiliary
            # information about the value(s) to the query construction.
            # Specifically, datetime and empty values need special handling.
            # Other types could be used here in the future (using Python
            # types is suggested for consistency).
            if isinstance(value, datetime.datetime):
                value_annotation = datetime.datetime
            elif hasattr(value, 'value_annotation'):
                value_annotation = value.value_annotation
            else:
                value_annotation = bool(value)

            if hasattr(obj, "prepare"):
                value = obj.prepare(lookup_type, value)

            # Need to bypass WhereNode
            tree.Node.add(self,
                          (obj, lookup_type, value_annotation, value),
                          connector)

    def make_atom(self, child, qn, conn):
        """
        Turn one ``(lvalue, lookup_type, value_annot, params_or_value)`` child
        into a ``(sql_string, params)`` pair.

        Children whose lvalue has no ``process`` method, or whose column type
        is not ``inet``/``cidr``, are delegated to the parent implementation.

        Raises EmptyResultSet when the constraint short-circuits or an ``in``
        lookup has an empty value, and ValueError for an unknown lookup type.
        """
        lvalue, lookup_type, value_annot, params_or_value = child

        if not hasattr(lvalue, 'process'):
            return super(NetWhere, self).make_atom(child, qn, conn)

        try:
            # Use the connection this atom is being compiled for, not the
            # module-level default connection, so that queries against a
            # non-default database are prepared correctly.
            lvalue, params = lvalue.process(lookup_type, params_or_value, conn)
        except sql.where.EmptyShortCircuit:
            # Same exception path as the 'in' branch below.
            raise sql.datastructures.EmptyResultSet

        table_alias, name, db_type = lvalue

        if db_type not in ['inet', 'cidr']:
            return super(NetWhere, self).make_atom(child, qn, conn)

        if table_alias:
            field_sql = '%s.%s' % (qn(table_alias), qn(name))
        else:
            field_sql = qn(name)

        # Text-style operators compare against a textual form of the column.
        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
            if db_type == 'inet':
                field_sql = 'HOST(%s)' % field_sql
            else:
                field_sql = 'TEXT(%s)' % field_sql

        if isinstance(params, QueryWrapper):
            extra, params = params.data
        else:
            extra = ''

        # A bare string is a single parameter, not a sequence of characters.
        if isinstance(params, basestring):
            params = (params,)

        if lookup_type in NET_OPERATORS:
            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
                    params)
        elif lookup_type == 'in':
            if not value_annot:
                raise sql.datastructures.EmptyResultSet
            if extra:
                return ('%s IN %s' % (field_sql, extra), params)
            placeholders = ', '.join(['%s'] * len(params))
            return ('%s IN (%s)' % (field_sql, placeholders), params)
        elif lookup_type == 'range':
            return ('%s BETWEEN %%s and %%s' % field_sql, params)
        elif lookup_type == 'isnull':
            negation = '' if value_annot else 'NOT '
            return ('%s IS %sNULL' % (field_sql, negation), params)

        raise ValueError('Invalid lookup type "%s"' % lookup_type)


class NetManager(models.Manager):
    """Manager whose querysets are built on NetQuery and NetWhere."""

    use_for_related_fields = True

    # NOTE(review): Django 1.6 renamed this hook to get_queryset(); confirm
    # which Django versions must be supported before adding an alias.
    def get_query_set(self):
        """Return a QuerySet for the model backed by NetQuery/NetWhere."""
        q = NetQuery(self.model, NetWhere)
        return query.QuerySet(self.model, q)