managers.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from netaddr import IPNetwork
  2. from django import VERSION
  3. from django.db import models, connection
  4. from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
  5. from django.db.models.fields import DateTimeField
  6. from django.db.models import sql, query
  7. from django.db.models.query_utils import QueryWrapper
  8. from django.utils import tree
  9. import datetime
  10. NET_OPERATORS = DatabaseWrapper.operators.copy()
  11. for operator in ['contains', 'startswith', 'endswith']:
  12. NET_OPERATORS[operator] = 'ILIKE %s'
  13. NET_OPERATORS['i%s' % operator] = 'ILIKE %s'
  14. NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
  15. NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
  16. NET_OPERATORS['net_contained'] = '<< %s'
  17. NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
  18. NET_OPERATORS['net_contains'] = '>> %s'
  19. NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
  20. NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
class NetQuery(sql.Query):
    """sql.Query whose ``query_terms`` also include every NET_OPERATORS key.

    Presumably this is what lets Django's lookup parsing accept the net_*
    lookups; verify against the Django version in use.
    """
    # Copy first so the inherited query_terms object is not mutated in place.
    # NOTE(review): query_terms is a dict or a set depending on the Django
    # version. .copy() and .update(<dict>) work for both types (a set adds
    # the dict's keys).
    query_terms = sql.Query.query_terms.copy()
    query_terms.update(NET_OPERATORS)
  24. class NetWhere(sql.where.WhereNode):
  25. def _prepare_data(self, data):
  26. """
  27. Prepare data for addition to the tree. If the data is a list or tuple,
  28. it is expected to be of the form (obj, lookup_type, value), where obj
  29. is a Constraint object, and is then slightly munged before being
  30. stored (to avoid storing any reference to field objects). Otherwise,
  31. the 'data' is stored unchanged and can be any class with an 'as_sql()'
  32. method.
  33. """
  34. if not isinstance(data, (list, tuple)):
  35. return data
  36. obj, lookup_type, value = data
  37. if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
  38. # Consume any generators immediately, so that we can determine
  39. # emptiness and transform any non-empty values correctly.
  40. value = list(value)
  41. # The "value_annotation" parameter is used to pass auxilliary information
  42. # about the value(s) to the query construction. Specifically, datetime
  43. # and empty values need special handling. Other types could be used
  44. # here in the future (using Python types is suggested for consistency).
  45. if (isinstance(value, datetime.datetime)
  46. or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
  47. value_annotation = datetime.datetime
  48. elif hasattr(value, 'value_annotation'):
  49. value_annotation = value.value_annotation
  50. else:
  51. value_annotation = bool(value)
  52. if hasattr(obj, "prepare"):
  53. value = obj.prepare(lookup_type, value)
  54. return (obj, lookup_type, value_annotation, value)
  55. if VERSION[:2] < (1, 6):
  56. def add(self, data, connector):
  57. """
  58. Special form of WhereNode.add() that does not automatically consume the
  59. __iter__ method of IPNetwork objects.
  60. """
  61. if not isinstance(data, (list, tuple)):
  62. # Need to bypass WhereNode
  63. tree.Node.add(self, data, connector)
  64. return
  65. obj, lookup_type, value = data
  66. if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
  67. # Consume any generators immediately, so that we can determine
  68. # emptiness and transform any non-empty values correctly.
  69. value = list(value)
  70. # The "value_annotation" parameter is used to pass auxilliary information
  71. # about the value(s) to the query construction. Specifically, datetime
  72. # and empty values need special handling. Other types could be used
  73. # here in the future (using Python types is suggested for consistency).
  74. if isinstance(value, datetime.datetime):
  75. value_annotation = datetime.datetime
  76. elif hasattr(value, 'value_annotation'):
  77. value_annotation = value.value_annotation
  78. else:
  79. value_annotation = bool(value)
  80. if hasattr(obj, "prepare"):
  81. value = obj.prepare(lookup_type, value)
  82. # Need to bypass WhereNode
  83. tree.Node.add(self,
  84. (obj, lookup_type, value_annotation, value), connector)
  85. def make_atom(self, child, qn, conn):
  86. lvalue, lookup_type, value_annot, params_or_value = child
  87. if hasattr(lvalue, 'process'):
  88. try:
  89. lvalue, params = lvalue.process(lookup_type, params_or_value,
  90. connection)
  91. except sql.where.EmptyShortCircuit:
  92. raise query.EmptyResultSet
  93. else:
  94. return super(NetWhere, self).make_atom(child, qn, conn)
  95. table_alias, name, db_type = lvalue
  96. if db_type not in ['inet', 'cidr']:
  97. return super(NetWhere, self).make_atom(child, qn, conn)
  98. if table_alias:
  99. field_sql = '%s.%s' % (qn(table_alias), qn(name))
  100. else:
  101. field_sql = qn(name)
  102. if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
  103. if db_type == 'inet':
  104. field_sql = 'HOST(%s)' % field_sql
  105. else:
  106. field_sql = 'TEXT(%s)' % field_sql
  107. if isinstance(params, QueryWrapper):
  108. extra, params = params.data
  109. else:
  110. extra = ''
  111. if isinstance(params, basestring):
  112. params = (params,)
  113. if lookup_type in NET_OPERATORS:
  114. return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
  115. params)
  116. elif lookup_type == 'in':
  117. if not value_annot:
  118. raise sql.datastructures.EmptyResultSet
  119. if extra:
  120. return ('%s IN %s' % (field_sql, extra), params)
  121. return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
  122. len(params))), params)
  123. elif lookup_type == 'range':
  124. return ('%s BETWEEN %%s and %%s' % field_sql, params)
  125. elif lookup_type == 'isnull':
  126. return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
  127. '')), params)
  128. raise ValueError('Invalid lookup type "%s"' % lookup_type)
  129. class NetManager(models.Manager):
  130. use_for_related_fields = True
  131. def get_query_set(self):
  132. q = NetQuery(self.model, NetWhere)
  133. return query.QuerySet(self.model, q)