querying.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. from lettuce import *
  2. import subprocess
  3. import re
  4. # This script provides querying functionality
  5. # The most important step is
  6. #
  7. # query for <name> [type X] [class X] [to <addr>[:port]] should have rcode <rc>
  8. #
  9. # By default, it will send queries to 127.0.0.1:47806 unless specified
  10. # otherwise. The rcode is always checked. If the result is not NO_ANSWER,
  11. # the result will be stored in last_query_result, which can then be inspected
  12. # more closely, for instance with the step
  13. #
  14. # "the last query response should have <property> <value>"
  15. #
  16. # Also see example.feature for some examples
  17. #
  18. # define a class to easily access different parts
  19. # We may consider using our full library for this, but for now
  20. # simply store several parts of the response as text values in
  21. # this structure.
  22. # (this actually has the advantage of not relying on our own libraries
  23. # to test our own, well, libraries)
  24. #
  25. # The following attributes are 'parsed' from the response, all as strings,
  26. # and end up as direct attributes of the QueryResult object:
  27. # opcode, rcode, id, flags, qdcount, ancount, nscount, adcount
  28. # (flags is one string with all flags)
  29. #
  30. # this will set 'rcode' as the result code, we 'define' one additional
  31. # rcode, "NO_ANSWER", if the dig process returned an error code itself
  32. # In this case none of the other attributes will be set.
  33. #
  34. # The different sections will be lists of strings, one for each RR in the
  35. # section. The question section will start with ';', as per dig output
  36. #
  37. # See server_from_sqlite3.feature for various examples to perform queries
  38. class QueryResult(object):
  39. status_re = re.compile("opcode: ([A-Z])+, status: ([A-Z]+), id: ([0-9]+)")
  40. flags_re = re.compile("flags: ([a-z ]+); QUERY: ([0-9]+), ANSWER: " +
  41. "([0-9]+), AUTHORITY: ([0-9]+), ADDITIONAL: ([0-9]+)")
  42. def __init__(self, name, qtype, qclass, address, port):
  43. args = [ 'dig', '+tries=1', '@' + address, '-p', str(port) ]
  44. if qtype is not None:
  45. args.append('-t')
  46. args.append(str(qtype))
  47. if qclass is not None:
  48. args.append('-c')
  49. args.append(str(qclass))
  50. args.append(name)
  51. dig_process = subprocess.Popen(args, 1, None, None, subprocess.PIPE,
  52. None)
  53. result = dig_process.wait()
  54. if result != 0:
  55. self.rcode = "NO_ANSWER"
  56. else:
  57. self.rcode = None
  58. parsing = "HEADER"
  59. self.question_section = []
  60. self.answer_section = []
  61. self.authority_section = []
  62. self.additional_section = []
  63. self.line_handler = self.parse_header
  64. for out in dig_process.stdout:
  65. self.line_handler(out)
  66. def _check_next_header(self, line):
  67. """Returns true if we found a next header, and sets the internal
  68. line handler to the appropriate value.
  69. """
  70. if line == ";; ANSWER SECTION:\n":
  71. self.line_handler = self.parse_answer
  72. elif line == ";; AUTHORITY SECTION:\n":
  73. self.line_handler = self.parse_authority
  74. elif line == ";; ADDITIONAL SECTION:\n":
  75. self.line_handler = self.parse_additional
  76. elif line.startswith(";; Query time"):
  77. self.line_handler = self.parse_footer
  78. else:
  79. return False
  80. return True
  81. def parse_header(self, line):
  82. if not self._check_next_header(line):
  83. status_match = self.status_re.search(line)
  84. flags_match = self.flags_re.search(line)
  85. if status_match is not None:
  86. self.opcode = status_match.group(1)
  87. self.rcode = status_match.group(2)
  88. elif flags_match is not None:
  89. self.flags = flags_match.group(1)
  90. self.qdcount = flags_match.group(2)
  91. self.ancount = flags_match.group(3)
  92. self.nscount = flags_match.group(4)
  93. self.adcount = flags_match.group(5)
  94. def parse_question(self, line):
  95. if not self._check_next_header(line):
  96. if line != "\n":
  97. self.question_section.append(line.strip())
  98. def parse_answer(self, line):
  99. if not self._check_next_header(line):
  100. if line != "\n":
  101. self.answer_section.append(line.strip())
  102. def parse_authority(self, line):
  103. if not self._check_next_header(line):
  104. if line != "\n":
  105. self.authority_section.append(line.strip())
  106. def parse_authority(self, line):
  107. if not self._check_next_header(line):
  108. if line != "\n":
  109. self.additional_section.append(line.strip())
  110. def parse_footer(self, line):
  111. pass
  112. @step('A query for ([\w.]+) (?:type ([A-Z]+) )?(?:class ([A-Z]+) )?' +
  113. '(?:to ([^:]+)(?::([0-9]+))? )?should have rcode ([\w.]+)')
  114. def query(step, query_name, qtype, qclass, addr, port, rcode):
  115. if qtype is None:
  116. qtype = "A"
  117. if qclass is None:
  118. qclass = "IN"
  119. if addr is None:
  120. addr = "127.0.0.1"
  121. if port is None:
  122. port = 47806
  123. query_result = QueryResult(query_name, qtype, qclass, addr, port)
  124. assert query_result.rcode == rcode,\
  125. "Expected: " + rcode + ", got " + query_result.rcode
  126. world.last_query_result = query_result
  127. @step('The SOA serial for ([\w.]+) should be ([0-9]+)')
  128. def query_soa(step, query_name, serial):
  129. query_result = QueryResult(query_name, "SOA", "IN", "127.0.0.1", "47806")
  130. assert "NOERROR" == query_result.rcode,\
  131. "Got " + query_result.rcode + ", expected NOERROR"
  132. assert len(query_result.answer_section) == 1,\
  133. "Too few or too many answers in SOA response"
  134. soa_parts = query_result.answer_section[0].split()
  135. assert serial == soa_parts[6],\
  136. "Got SOA serial " + soa_parts[6] + ", expected " + serial
  137. @step('last query response should have (\S+) (.+)')
  138. def check_last_query(step, item, value):
  139. assert world.last_query_result is not None
  140. assert item in world.last_query_result.__dict__
  141. lq_val = world.last_query_result.__dict__[item]
  142. assert str(value) == str(lq_val),\
  143. "Got: " + str(lq_val) + ", expected: " + str(value)
  144. @step('([a-zA-Z]+) section of the last query response should be')
  145. def check_last_query_section(step, section):
  146. response_string = None
  147. if section.lower() == 'question':
  148. response_string = "\n".join(world.last_query_result.question_section)
  149. elif section.lower() == 'answer':
  150. response_string = "\n".join(world.last_query_result.answer_section)
  151. elif section.lower() == 'authority':
  152. response_string = "\n".join(world.last_query_result.answer_section)
  153. elif section.lower() == 'additional':
  154. response_string = "\n".join(world.last_query_result.answer_section)
  155. else:
  156. assert False, "Unknown section " + section
  157. # replace whitespace of any length by one space
  158. response_string = re.sub("[ \t]+", " ", response_string)
  159. expect = re.sub("[ \t]+", " ", step.multiline)
  160. assert response_string.strip() == expect.strip(),\
  161. "Got:\n'" + response_string + "'\nExpected:\n'" + step.multiline +"'"