compiler.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # -*- coding: utf-8 -*-
  2. #
  3. # django-ldapdb
  4. # Copyright (c) 2009-2011, Bolloré telecom
  5. # Copyright (c) 2013, Jeremy Lainé
  6. # All rights reserved.
  7. #
  8. # See AUTHORS file for a full list of contributors.
  9. #
  10. # Redistribution and use in source and binary forms, with or without
  11. # modification, are permitted provided that the following conditions are met:
  12. #
  13. # 1. Redistributions of source code must retain the above copyright notice,
  14. # this list of conditions and the following disclaimer.
  15. #
  16. # 2. Redistributions in binary form must reproduce the above copyright
  17. # notice, this list of conditions and the following disclaimer in the
  18. # documentation and/or other materials provided with the distribution.
  19. #
  20. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  21. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  22. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  23. # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  24. # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  25. # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  26. # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  27. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  28. # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  29. # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  30. # POSSIBILITY OF SUCH DAMAGE.
  31. #
  32. import ldap
  33. from django.db.models.sql import aggregates, compiler
  34. from django.db.models.sql.where import AND, OR
  35. def get_lookup_operator(lookup_type):
  36. if lookup_type == 'gte':
  37. return '>='
  38. elif lookup_type == 'lte':
  39. return '<='
  40. else:
  41. return '='
  42. def query_as_ldap(query):
  43. # starting with django 1.6 we can receive empty querysets
  44. if hasattr(query, 'is_empty') and query.is_empty():
  45. return
  46. filterstr = ''.join(['(objectClass=%s)' % cls for cls in
  47. query.model.object_classes])
  48. sql, params = where_as_ldap(query.where)
  49. filterstr += sql
  50. return '(&%s)' % filterstr
  51. def where_as_ldap(self):
  52. bits = []
  53. for item in self.children:
  54. if hasattr(item, 'as_sql'):
  55. sql, params = where_as_ldap(item)
  56. bits.append(sql)
  57. continue
  58. constraint, lookup_type, y, values = item
  59. comp = get_lookup_operator(lookup_type)
  60. if lookup_type == 'in':
  61. equal_bits = ["(%s%s%s)" % (constraint.col, comp, value) for value
  62. in values]
  63. clause = '(|%s)' % ''.join(equal_bits)
  64. else:
  65. clause = "(%s%s%s)" % (constraint.col, comp, values)
  66. bits.append(clause)
  67. if not len(bits):
  68. return '', []
  69. if len(bits) == 1:
  70. sql_string = bits[0]
  71. elif self.connector == AND:
  72. sql_string = '(&%s)' % ''.join(bits)
  73. elif self.connector == OR:
  74. sql_string = '(|%s)' % ''.join(bits)
  75. else:
  76. raise Exception("Unhandled WHERE connector: %s" % self.connector)
  77. if self.negated:
  78. sql_string = ('(!%s)' % sql_string)
  79. return sql_string, []
  80. class SQLCompiler(object):
  81. def __init__(self, query, connection, using):
  82. self.query = query
  83. self.connection = connection
  84. self.using = using
  85. def execute_sql(self, result_type=compiler.MULTI):
  86. if result_type != compiler.SINGLE:
  87. raise Exception("LDAP does not support MULTI queries")
  88. for key, aggregate in self.query.aggregate_select.items():
  89. if not isinstance(aggregate, aggregates.Count):
  90. raise Exception("Unsupported aggregate %s" % aggregate)
  91. filterstr = query_as_ldap(self.query)
  92. if not filterstr:
  93. return
  94. try:
  95. vals = self.connection.search_s(
  96. self.query.model.base_dn,
  97. self.query.model.search_scope,
  98. filterstr=filterstr,
  99. attrlist=['dn'],
  100. )
  101. except ldap.NO_SUCH_OBJECT:
  102. vals = []
  103. if not vals:
  104. return None
  105. output = []
  106. for alias, col in self.query.extra_select.iteritems():
  107. output.append(col[0])
  108. for key, aggregate in self.query.aggregate_select.items():
  109. if isinstance(aggregate, aggregates.Count):
  110. output.append(len(vals))
  111. else:
  112. output.append(None)
  113. return output
  114. def results_iter(self):
  115. filterstr = query_as_ldap(self.query)
  116. if not filterstr:
  117. return
  118. if hasattr(self.query, 'select_fields') and len(self.query.select_fields):
  119. # django < 1.6
  120. fields = self.query.select_fields
  121. elif len(self.query.select):
  122. # django >= 1.6
  123. fields = [x.field for x in self.query.select]
  124. else:
  125. fields = self.query.model._meta.fields
  126. attrlist = [x.db_column for x in fields if x.db_column]
  127. try:
  128. vals = self.connection.search_s(
  129. self.query.model.base_dn,
  130. self.query.model.search_scope,
  131. filterstr=filterstr,
  132. attrlist=attrlist,
  133. )
  134. except ldap.NO_SUCH_OBJECT:
  135. return
  136. # perform sorting
  137. if self.query.extra_order_by:
  138. ordering = self.query.extra_order_by
  139. elif not self.query.default_ordering:
  140. ordering = self.query.order_by
  141. else:
  142. ordering = self.query.order_by or self.query.model._meta.ordering
  143. def cmpvals(x, y):
  144. for fieldname in ordering:
  145. if fieldname.startswith('-'):
  146. fieldname = fieldname[1:]
  147. negate = True
  148. else:
  149. negate = False
  150. if fieldname == 'pk':
  151. fieldname = self.query.model._meta.pk.name
  152. field = self.query.model._meta.get_field(fieldname)
  153. attr_x = field.from_ldap(x[1].get(field.db_column, []),
  154. connection=self.connection)
  155. attr_y = field.from_ldap(y[1].get(field.db_column, []),
  156. connection=self.connection)
  157. # perform case insensitive comparison
  158. if hasattr(attr_x, 'lower'):
  159. attr_x = attr_x.lower()
  160. if hasattr(attr_y, 'lower'):
  161. attr_y = attr_y.lower()
  162. val = negate and cmp(attr_y, attr_x) or cmp(attr_x, attr_y)
  163. if val:
  164. return val
  165. return 0
  166. vals = sorted(vals, cmp=cmpvals)
  167. # process results
  168. pos = 0
  169. results = []
  170. for dn, attrs in vals:
  171. # FIXME : This is not optimal, we retrieve more results than we
  172. # need but there is probably no other options as we can't perform
  173. # ordering server side.
  174. if (self.query.low_mark and pos < self.query.low_mark) or \
  175. (self.query.high_mark is not None and
  176. pos >= self.query.high_mark):
  177. pos += 1
  178. continue
  179. row = []
  180. for field in iter(fields):
  181. if field.attname == 'dn':
  182. row.append(dn)
  183. elif hasattr(field, 'from_ldap'):
  184. row.append(field.from_ldap(attrs.get(field.db_column, []),
  185. connection=self.connection))
  186. else:
  187. row.append(None)
  188. if self.query.distinct:
  189. if row in results:
  190. continue
  191. else:
  192. results.append(row)
  193. yield row
  194. pos += 1
  195. class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
  196. pass
  197. class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
  198. def execute_sql(self, result_type=compiler.MULTI):
  199. filterstr = query_as_ldap(self.query)
  200. if not filterstr:
  201. return
  202. try:
  203. vals = self.connection.search_s(
  204. self.query.model.base_dn,
  205. self.query.model.search_scope,
  206. filterstr=filterstr,
  207. attrlist=['dn'],
  208. )
  209. except ldap.NO_SUCH_OBJECT:
  210. return
  211. # FIXME : there is probably a more efficient way to do this
  212. for dn, attrs in vals:
  213. self.connection.delete_s(dn)
  214. class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
  215. pass
  216. class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
  217. pass
  218. class SQLDateCompiler(compiler.SQLDateCompiler, SQLCompiler):
  219. pass