managers.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. from netaddr import IPNetwork
  2. from django import VERSION
  3. from django.db import models, connection
  4. from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper
  5. from django.db.models.fields import DateTimeField
  6. from django.db.models import sql, query
  7. from django.db.models.query_utils import QueryWrapper
  8. from django.utils import tree
  9. import datetime
  10. NET_OPERATORS = DatabaseWrapper.operators.copy()
  11. for operator in ['contains', 'startswith', 'endswith']:
  12. NET_OPERATORS[operator] = 'ILIKE %s'
  13. NET_OPERATORS['i%s' % operator] = 'ILIKE %s'
  14. NET_OPERATORS['iexact'] = NET_OPERATORS['exact']
  15. NET_OPERATORS['regex'] = NET_OPERATORS['iregex']
  16. NET_OPERATORS['net_contained'] = '<< %s'
  17. NET_OPERATORS['net_contained_or_equal'] = '<<= %s'
  18. NET_OPERATORS['net_contains'] = '>> %s'
  19. NET_OPERATORS['net_contains_or_equals'] = '>>= %s'
  20. NET_TEXT_OPERATORS = ['ILIKE %s', '~* %s']
  21. class NetQuery(sql.Query):
  22. query_terms = sql.Query.query_terms.copy()
  23. query_terms.update(NET_OPERATORS)
  24. class NetWhere(sql.where.WhereNode):
  25. def _prepare_data(self, data):
  26. """
  27. Special form of WhereNode._prepare_data() that does not automatically consume the
  28. __iter__ method of IPNetwork objects. This is used in Django >= 1.6
  29. """
  30. if not isinstance(data, (list, tuple)):
  31. return data
  32. obj, lookup_type, value = data
  33. if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
  34. # Consume any generators immediately, so that we can determine
  35. # emptiness and transform any non-empty values correctly.
  36. value = list(value)
  37. # The "value_annotation" parameter is used to pass auxilliary information
  38. # about the value(s) to the query construction. Specifically, datetime
  39. # and empty values need special handling. Other types could be used
  40. # here in the future (using Python types is suggested for consistency).
  41. if (isinstance(value, datetime.datetime)
  42. or (isinstance(obj.field, DateTimeField) and lookup_type != 'isnull')):
  43. value_annotation = datetime.datetime
  44. elif hasattr(value, 'value_annotation'):
  45. value_annotation = value.value_annotation
  46. else:
  47. value_annotation = bool(value)
  48. if hasattr(obj, "prepare"):
  49. value = obj.prepare(lookup_type, value)
  50. return (obj, lookup_type, value_annotation, value)
  51. if VERSION[:2] < (1, 6):
  52. def add(self, data, connector):
  53. """
  54. Special form of WhereNode.add() that does not automatically consume the
  55. __iter__ method of IPNetwork objects.
  56. """
  57. if not isinstance(data, (list, tuple)):
  58. # Need to bypass WhereNode
  59. tree.Node.add(self, data, connector)
  60. return
  61. obj, lookup_type, value = data
  62. if not isinstance(value, IPNetwork) and hasattr(value, '__iter__') and hasattr(value, 'next'):
  63. # Consume any generators immediately, so that we can determine
  64. # emptiness and transform any non-empty values correctly.
  65. value = list(value)
  66. # The "value_annotation" parameter is used to pass auxilliary information
  67. # about the value(s) to the query construction. Specifically, datetime
  68. # and empty values need special handling. Other types could be used
  69. # here in the future (using Python types is suggested for consistency).
  70. if isinstance(value, datetime.datetime):
  71. value_annotation = datetime.datetime
  72. elif hasattr(value, 'value_annotation'):
  73. value_annotation = value.value_annotation
  74. else:
  75. value_annotation = bool(value)
  76. if hasattr(obj, "prepare"):
  77. value = obj.prepare(lookup_type, value)
  78. # Need to bypass WhereNode
  79. tree.Node.add(self,
  80. (obj, lookup_type, value_annotation, value), connector)
  81. def make_atom(self, child, qn, conn):
  82. lvalue, lookup_type, value_annot, params_or_value = child
  83. if hasattr(lvalue, 'process'):
  84. try:
  85. lvalue, params = lvalue.process(lookup_type, params_or_value,
  86. connection)
  87. except sql.where.EmptyShortCircuit:
  88. raise query.EmptyResultSet
  89. else:
  90. return super(NetWhere, self).make_atom(child, qn, conn)
  91. table_alias, name, db_type = lvalue
  92. if db_type not in ['inet', 'cidr']:
  93. return super(NetWhere, self).make_atom(child, qn, conn)
  94. if table_alias:
  95. field_sql = '%s.%s' % (qn(table_alias), qn(name))
  96. else:
  97. field_sql = qn(name)
  98. if NET_OPERATORS.get(lookup_type, '') in NET_TEXT_OPERATORS:
  99. if db_type == 'inet':
  100. field_sql = 'HOST(%s)' % field_sql
  101. else:
  102. field_sql = 'TEXT(%s)' % field_sql
  103. if isinstance(params, QueryWrapper):
  104. extra, params = params.data
  105. else:
  106. extra = ''
  107. if isinstance(params, basestring):
  108. params = (params,)
  109. if lookup_type in NET_OPERATORS:
  110. return (' '.join([field_sql, NET_OPERATORS[lookup_type], extra]),
  111. params)
  112. elif lookup_type == 'in':
  113. if not value_annot:
  114. raise sql.datastructures.EmptyResultSet
  115. if extra:
  116. return ('%s IN %s' % (field_sql, extra), params)
  117. return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] *
  118. len(params))), params)
  119. elif lookup_type == 'range':
  120. return ('%s BETWEEN %%s and %%s' % field_sql, params)
  121. elif lookup_type == 'isnull':
  122. return ('%s IS %sNULL' % (field_sql, (not value_annot and 'NOT ' or
  123. '')), params)
  124. raise ValueError('Invalid lookup type "%s"' % lookup_type)
  125. class NetManager(models.Manager):
  126. use_for_related_fields = True
  127. def get_query_set(self):
  128. q = NetQuery(self.model, NetWhere)
  129. return query.QuerySet(self.model, q)