managers.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from netaddr import IPNetwork
  2. from django.db import models, connection
  3. from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
  4. from django.db.models import sql, query
  5. from django.db.models.query_utils import QueryWrapper
  6. from django.utils import tree
  7. import datetime
  8. NET_OPERATORS = DatabaseWrapper.operators.copy()
  9. for operator in ['contains', 'startswith', 'endswith']:
  10. NET_OPERATORS[operator] = 'ILIKE %s'
  11. NET_OPERATORS['i%s' % operator] = 'ILIKE %s'
  12. NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
  13. NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
  14. NET_OPERATORS['net_contained'] = '<< %s'
  15. NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
  16. NET_OPERATORS['net_contains'] = '>> %s'
  17. NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
  18. NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
class NetQuery(sql.Query):
    """Query subclass that accepts the network lookup types.

    Its ``query_terms`` are the parent's terms extended with every key of
    NET_OPERATORS (e.g. ``net_contains``, ``net_contained_or_equal``).
    """
    # Copy first so the parent class's query_terms are left untouched, then
    # merge in the network lookups. copy()/update() is used rather than a
    # dict literal; NOTE(review): presumably so this works whatever container
    # type Query.query_terms is in the installed Django -- confirm.
    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_OPERATORS)
  22. class NetWhere(sql.where.WhereNode):
  23. def add(self, data, connector):
  24. """
  25. Special form of WhereNode.add() that does not automatically consume the
  26. __iter__ method of IPNetwork objects.
  27. """
  28. if not isinstance(data, (list, tuple)):
  29. # Need to bypass WhereNode
  30. tree.Node.add(self, data, connector)
  31. return
  32. obj, lookup_type, value = data
  33. if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
  34. # Consume any generators immediately, so that we can determine
  35. # emptiness and transform any non-empty values correctly.
  36. value = list(value)
  37. # The "value_annotation" parameter is used to pass auxilliary information
  38. # about the value(s) to the query construction. Specifically, datetime
  39. # and empty values need special handling. Other types could be used
  40. # here in the future (using Python types is suggested for consistency).
  41. if isinstance(value, datetime.datetime):
  42. value_annotation = datetime.datetime
  43. elif hasattr(value, 'value_annotation'):
  44. value_annotation = value.value_annotation
  45. else:
  46. value_annotation = bool(value)
  47. if hasattr(obj, "prepare"):
  48. value = obj.prepare(lookup_type, value)
  49. # Need to bypass WhereNode
  50. tree.Node.add(self,
  51. (obj, lookup_type, value_annotation, value), connector)
  52. def make_atom(self, child, qn, conn):
  53. lvalue, lookup_type, value_annot, params_or_value = child
  54. if hasattr(lvalue, 'process'):
  55. try:
  56. lvalue, params = lvalue.process(lookup_type, params_or_value,
  57. connection)
  58. except sql.where.EmptyShortCircuit:
  59. raise query.EmptyResultSet
  60. else:
  61. return super(NetWhere, self).make_atom(child, qn, conn)
  62. table_alias, name, db_type = lvalue
  63. if db_type not in ['inet', 'cidr']:
  64. return super(NetWhere, self).make_atom(child, qn, conn)
  65. if table_alias:
  66. field_sql = '%s.%s' % (qn(table_alias), qn(name))
  67. else:
  68. field_sql = qn(name)
  69. if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
  70. if db_type == 'inet':
  71. field_sql = 'HOST(%s)' % field_sql
  72. else:
  73. field_sql = 'TEXT(%s)' % field_sql
  74. if isinstance(params, QueryWrapper):
  75. extra, params = params.data
  76. else:
  77. extra = ''
  78. if isinstance(params, basestring):
  79. params = (params,)
  80. if lookup_type in NET_OPERATORS:
  81. return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
  82. params)
  83. elif lookup_type == 'in':
  84. if not value_annot:
  85. raise sql.datastructures.EmptyResultSet
  86. if extra:
  87. return ('%s IN %s' % (field_sql, extra), params)
  88. return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
  89. len(params))), params)
  90. elif lookup_type == 'range':
  91. return ('%s BETWEEN %%s and %%s' % field_sql, params)
  92. elif lookup_type == 'isnull':
  93. return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
  94. '')), params)
  95. raise ValueError('Invalid lookup type "%s"' % lookup_type)
  96. class NetManager(models.Manager):
  97. use_for_related_fields = True
  98. def get_query_set(self):
  99. q = NetQuery(self.model, NetWhere)
  100. return query.QuerySet(self.model, q)