managers.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from IPy import IP
  2. from django.db import models, connection
  3. from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
  4. from django.db.models import sql, query
  5. from django.db.models.query_utils import QueryWrapper
  # Lookup-name -> SQL operator table for network columns. It starts from the
  # stock PostgreSQL backend operators and overrides the entries below.
  NET_OPERATORS = DatabaseWrapper.operators.copy()

  # Substring-style lookups: both the plain and the ``i``-prefixed variant
  # of each map to ILIKE (i.e. matching is always case-insensitive).
  NET_OPERATORS.update(
      (lookup, 'ILIKE %s')
      for base in ('contains', 'startswith', 'endswith')
      for lookup in (base, 'i' + base)
  )

  # ``iexact`` reuses the ``exact`` operator and ``regex`` reuses the
  # ``iregex`` operator from the backend table.
  NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
  NET_OPERATORS['regex'] = NET_OPERATORS['iregex']

  # PostgreSQL network containment operators exposed as custom lookups.
  # NOTE(review): "_or_equal" vs "_or_equals" is inconsistent, but these are
  # public lookup names and are kept exactly as they are.
  NET_OPERATORS.update({
      'net_contained': '<< %s',
      'net_contained_or_equal': '<<= %s',
      'net_contains': '>> %s',
      'net_contains_or_equals': '>>= %s',
  })

  # Operator strings that make NetWhere.make_atom compare the column's text
  # form (HOST() for inet, TEXT() for cidr) instead of the raw column.
  NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
  class NetQuery(sql.Query):
      """``sql.Query`` subclass whose ``query_terms`` include the net lookups.

      The collection inherited from ``sql.Query`` is copied, never modified
      in place. Every key of ``NET_OPERATORS`` is then merged into the copy,
      presumably so that lookup names such as ``net_contains`` are recognised
      when filters are parsed.
      """

      # Copy first so sql.Query's own class attribute stays untouched, then
      # merge in the network operator names.
      query_terms = sql.Query.query_terms.copy()
      query_terms.update(NET_OPERATORS)
  class NetWhere(sql.where.WhereNode):
      """WHERE-clause node that compiles lookups on inet/cidr columns itself.

      Children whose lvalue cannot ``process()`` its parameters, or whose
      column type is neither ``inet`` nor ``cidr``, are delegated unchanged
      to ``sql.where.WhereNode.make_atom``.
      """

      def make_atom(self, child, qn, conn):
          """Compile one lookup into an ``(sql_fragment, params)`` pair.

          :param child: 4-tuple ``(lvalue, lookup_type, value_annot,
              params_or_value)``.
          :param qn: callable used to quote table and column names.
          :param conn: the database connection the query is compiled for.
          :returns: ``(sql_fragment, params)``.
          :raises EmptyResultSet: if ``lvalue.process`` short-circuits, or if
              an ``in`` lookup has a falsy ``value_annot``.
          :raises ValueError: for a lookup type this node cannot compile.
          """
          lvalue, lookup_type, value_annot, params_or_value = child

          # Only lvalues that prepare their own params are handled here.
          if not hasattr(lvalue, 'process'):
              return super(NetWhere, self).make_atom(child, qn, conn)

          try:
              # Prepare params with the connection being compiled for, not
              # the module-level default connection, so querying a
              # non-default database uses that database's backend.
              lvalue, params = lvalue.process(
                  lookup_type, params_or_value, conn)
          except sql.where.EmptyShortCircuit:
              # Same exception class as the empty ``in`` case below.
              raise sql.datastructures.EmptyResultSet

          table_alias, name, db_type = lvalue

          if db_type not in ('inet', 'cidr'):
              return super(NetWhere, self).make_atom(child, qn, conn)

          if table_alias:
              field_sql = '%s.%s' % (qn(table_alias), qn(name))
          else:
              field_sql = qn(name)

          # For the text-matching operators in NET_TEXT_OPERATORS, compare
          # against the column's text form: HOST() for inet, TEXT() for cidr.
          if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
              if db_type == 'inet':
                  field_sql = 'HOST(%s)' % field_sql
              else:
                  field_sql = 'TEXT(%s)' % field_sql

          # A QueryWrapper carries an SQL fragment (presumably a subquery)
          # plus its own params; the fragment is appended to the clause.
          if isinstance(params, QueryWrapper):
              extra, params = params.data
          else:
              extra = ''

          # Normalise a single string parameter into a one-element tuple.
          if isinstance(params, basestring):
              params = (params,)

          if lookup_type in NET_OPERATORS:
              return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
                      params)
          elif lookup_type == 'in':
              if not value_annot:
                  raise sql.datastructures.EmptyResultSet
              if extra:
                  return ('%s IN %s' % (field_sql, extra), params)
              placeholders = ', '.join(['%s'] * len(params))
              return ('%s IN (%s)' % (field_sql, placeholders), params)
          elif lookup_type == 'range':
              return ('%s BETWEEN %%s and %%s' % field_sql, params)
          elif lookup_type == 'isnull':
              # A truthy value_annot means "IS NULL"; falsy means "IS NOT NULL".
              if value_annot:
                  negation = ''
              else:
                  negation = 'NOT '
              return ('%s IS %sNULL' % (field_sql, negation), params)

          raise ValueError('Invalid lookup type "%s"' % lookup_type)
  class NetManager(models.Manager):
      """Manager whose querysets compile network lookups via NetWhere."""

      # Ask Django to use this manager class for related-object access too
      # (Django 1.x ``use_for_related_fields`` manager attribute).
      use_for_related_fields = True

      def get_query_set(self):
          """Return a QuerySet backed by a NetQuery that uses NetWhere nodes.

          The manager's database alias (``self._db``) is passed through. This
          makes managers obtained via ``db_manager(alias)`` query that
          database; without it, the alias is silently dropped and the default
          routing is used.
          """
          q = NetQuery(self.model, NetWhere)
          return query.QuerySet(self.model, q, using=self._db)