12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- from IPy import IP
- from django.db import models, connection
- from django.db.models import sql, query
- from django.db.models.query_utils import QueryWrapper
# Lookup operators used for inet/cidr columns: a copy of the backend's
# standard operator table with the substring/regex lookups made
# case-insensitive and the network containment operators added.
NET_OPERATORS = connection.operators.copy()

# Both the plain and the i-prefixed substring lookups map to ILIKE. The loop
# name is private and deleted afterwards so it does not leak into the module
# namespace (or into ``from ... import *``).
for _lookup in ('contains', 'startswith', 'endswith'):
    NET_OPERATORS[_lookup] = 'ILIKE %s'
    NET_OPERATORS['i' + _lookup] = 'ILIKE %s'
del _lookup

# iexact reuses the exact operator; regex reuses the case-insensitive
# iregex operator.
NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
NET_OPERATORS['regex'] = NET_OPERATORS['iregex']

# Network containment lookups. The lookup names are kept exactly as
# published, including the contained_or_equal / contains_or_equals spelling.
NET_OPERATORS['net_contained'] = '<< %s'
NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
NET_OPERATORS['net_contains'] = '>> %s'
NET_OPERATORS['net_contains_or_equals'] = '>>= %s'

# Operators that work on the column's text form; NetWhere.make_atom wraps the
# column in HOST()/TEXT() when the lookup's operator is in this list. Taken
# from the entries above so that membership test always matches the exact
# strings actually stored in NET_OPERATORS.
NET_TEXT_OPERATORS = [NET_OPERATORS['icontains'], NET_OPERATORS['iregex']]
class NetQuery(sql.Query):
    """Query that accepts the network lookups defined in NET_OPERATORS.

    IPy ``IP`` instances given as filter values are converted to unicode
    before being handed to the parent ``add_filter``.
    """

    # Copy of the parent's lookup-term table extended with every key of
    # NET_OPERATORS (net_contains, net_contained, ...). Presumably consulted
    # by Django when splitting "field__lookup" filter strings -- confirm
    # against sql.Query.
    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_OPERATORS)

    def add_filter(self, filter_expr, *args, **kwargs):
        """Add a filter, converting an IPy ``IP`` value to unicode first.

        ``filter_expr`` is the ``(filter_string, value)`` pair; it is unpacked
        explicitly instead of via tuple-parameter syntax, which Python 3
        removed (PEP 3113). All other arguments are forwarded unchanged to
        ``sql.Query.add_filter``.
        """
        filter_string, value = filter_expr
        # IP(...) == '' fails, so force the value to a string while we can.
        if isinstance(value, IP):
            value = unicode(value)
        return super(NetQuery, self).add_filter(
            (filter_string, value), *args, **kwargs)
class NetWhere(sql.where.WhereNode):
    """WhereNode that renders lookups on ``inet``/``cidr`` columns itself.

    Lookups on columns of any other database type are delegated unchanged to
    the parent ``WhereNode.make_atom``.
    """

    def make_atom(self, child, qn):
        """Turn one filter ``child`` into an ``(sql, params)`` pair.

        ``child`` is unpacked as ``(table_alias, name, db_type, lookup_type,
        value_annot, params)``; ``qn`` is the callable used to quote the
        table alias and column name.

        Raises ``sql.datastructures.EmptyResultSet`` for an ``in`` lookup whose
        ``value_annot`` is false, and ``ValueError`` for a lookup type this
        method does not handle.
        """
        table_alias, name, db_type, lookup_type, value_annot, params = child
        if db_type not in ['inet', 'cidr']:
            return super(NetWhere, self).make_atom(child, qn)

        if table_alias:
            field_sql = '%s.%s' % (qn(table_alias), qn(name))
        else:
            field_sql = qn(name)

        # Text operators (ILIKE / regex) need the column as text first:
        # HOST() for inet columns, TEXT() for cidr columns.
        if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
            if db_type == 'inet':
                field_sql = 'HOST(%s)' % field_sql
            else:
                field_sql = 'TEXT(%s)' % field_sql

        # A QueryWrapper supplies an extra SQL fragment together with its own
        # params; otherwise there is no extra fragment.
        if isinstance(params, QueryWrapper):
            extra, params = params.data
        else:
            extra = ''

        if lookup_type in NET_OPERATORS:
            # The operator string keeps its '%s' placeholder for the DB-API
            # parameter binding.
            return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
                    params)

        if lookup_type == 'in':
            if not value_annot:
                raise sql.datastructures.EmptyResultSet
            if extra:
                return ('%s IN %s' % (field_sql, extra), params)
            placeholders = ', '.join(['%s'] * len(params))
            return ('%s IN (%s)' % (field_sql, placeholders), params)

        if lookup_type == 'range':
            return ('%s BETWEEN %%s and %%s' % field_sql, params)

        if lookup_type == 'isnull':
            # A true value_annot means IS NULL, a false one IS NOT NULL. The
            # SQL contains no placeholders, so no params may accompany it;
            # passing any through would mismatch the placeholder count.
            if value_annot:
                negation = ''
            else:
                negation = 'NOT '
            return ('%s IS %sNULL' % (field_sql, negation), ())

        raise ValueError('Invalid lookup type "%s"' % lookup_type)
class NetManger(models.Manager):
    """Manager whose querysets use NetQuery with NetWhere.

    This is what makes the network lookups in NET_OPERATORS available on
    inet/cidr fields of the managed model.
    """

    # Django attribute asking for this manager to be used for related-object
    # access as well.
    use_for_related_fields = True

    def get_query_set(self):
        """Return a QuerySet backed by a NetQuery whose WHERE tree is NetWhere."""
        net_query = NetQuery(self.model, connection, NetWhere)
        return query.QuerySet(self.model, net_query)


# Correctly spelled alias; the original ``NetManger`` name is kept so existing
# imports keep working.
NetManager = NetManger
|