xfrout_test.py 10 KB


  1. # Copyright (C) 2010 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. '''Tests for the XfroutSession and UnixSockServer classes '''
  16. import unittest
  17. import os
  18. from isc.cc.session import *
  19. from bind10_dns import *
  20. from xfrout import *
  21. # our fake socket, where we can read and insert messages
  22. class MySocket():
  23. def __init__(self, family, type):
  24. self.family = family
  25. self.type = type
  26. self.sendqueue = bytearray()
  27. def connect(self, to):
  28. pass
  29. def close(self):
  30. pass
  31. def send(self, data):
  32. self.sendqueue.extend(data);
  33. return len(data)
  34. def readsent(self):
  35. result = self.sendqueue[:]
  36. del self.sendqueue[:]
  37. return result
  38. def read_msg(self):
  39. sent_data = self.readsent()
  40. get_msg = message(message_mode.PARSE)
  41. get_msg.from_wire(input_buffer(bytes(sent_data[2:])))
  42. return get_msg
  43. def clear_send(self):
  44. del self.sendqueue[:]
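# We subclass the session class we're testing here only to replace its
# handle() method with a no-op, so that a session can be created around a
# fake socket without it processing a request (presumably the base class
# calls handle() on construction -- confirm against XfroutSession).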
  47. class MyXfroutSession(XfroutSession):
  48. def handle(self):
  49. pass
  50. class Dbserver:
  51. def __init__(self):
  52. self._shutdown_event = threading.Event()
  53. def get_db_file(self):
  54. return None
  55. def decrease_transfers_counter(self):
  56. pass
  57. class TestXfroutSession(unittest.TestCase):
  58. def getmsg(self):
  59. msg = message(message_mode.PARSE)
  60. msg.from_wire(input_buffer(self.mdata))
  61. return msg
  62. def setUp(self):
  63. request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
  64. self.xfrsess = MyXfroutSession(request, None, None)
  65. self.xfrsess.server = Dbserver()
  66. self.mdata = b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01'
  67. self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
  68. self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
  69. def test_parse_query_message(self):
  70. [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
  71. self.assertEqual(get_rcode.to_text(), "NOERROR")
  72. def test_get_query_zone_name(self):
  73. msg = self.getmsg()
  74. self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
  75. def test_send_data(self):
  76. self.xfrsess._send_data(self.sock, self.mdata)
  77. senddata = self.sock.readsent()
  78. self.assertEqual(senddata, self.mdata)
  79. def test_reply_xfrout_query_with_error_rcode(self):
  80. msg = self.getmsg()
  81. self.xfrsess._reply_query_with_error_rcode(msg, self.sock, rcode(3))
  82. get_msg = self.sock.read_msg()
  83. self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
  84. def test_clear_message(self):
  85. msg = self.getmsg()
  86. qid = msg.get_qid()
  87. opcode = msg.get_opcode()
  88. rcode = msg.get_rcode()
  89. self.xfrsess._clear_message(msg)
  90. self.assertEqual(msg.get_qid(), qid)
  91. self.assertEqual(msg.get_opcode(), opcode)
  92. self.assertEqual(msg.get_rcode(), rcode)
  93. self.assertTrue(msg.get_header_flag(message_flag.AA()))
  94. def test_reply_query_with_format_error(self):
  95. msg = self.getmsg()
  96. self.xfrsess._reply_query_with_format_error(msg, self.sock)
  97. get_msg = self.sock.read_msg()
  98. self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
  99. def test_create_rrset_from_db_record(self):
  100. rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  101. self.assertEqual(rrset.get_name().to_text(), "example.com.")
  102. self.assertEqual(rrset.get_class(), rr_class.IN())
  103. self.assertEqual(rrset.get_type().to_text(), "SOA")
  104. rdata_iter = rrset.get_rdata_iterator()
  105. rdata_iter.first()
  106. self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
  107. def test_send_message_with_last_soa(self):
  108. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  109. msg = self.getmsg()
  110. msg.make_response()
  111. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa)
  112. get_msg = self.sock.read_msg()
  113. self.assertEqual(get_msg.get_rr_count(section.QUESTION()), 1)
  114. self.assertEqual(get_msg.get_rr_count(section.ANSWER()), 1)
  115. self.assertEqual(get_msg.get_rr_count(section.AUTHORITY()), 0)
  116. answer_rrset_iter = section_iter(get_msg, section.ANSWER())
  117. answer = answer_rrset_iter.get_rrset()
  118. self.assertEqual(answer.get_name().to_text(), "example.com.")
  119. self.assertEqual(answer.get_class(), rr_class.IN())
  120. self.assertEqual(answer.get_type().to_text(), "SOA")
  121. rdata_iter = answer.get_rdata_iterator()
  122. rdata_iter.first()
  123. self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
  124. def test_get_message_len(self):
  125. msg = self.getmsg()
  126. msg.make_response()
  127. self.assertEqual(self.xfrsess._get_message_len(msg), 29)
  128. def test_zone_is_empty(self):
  129. global sqlite3_ds
  130. def mydb1(zone, file):
  131. return True
  132. sqlite3_ds.get_zone_soa = mydb1
  133. self.assertEqual(self.xfrsess._zone_is_empty(""), False)
  134. def mydb2(zone, file):
  135. return False
  136. sqlite3_ds.get_zone_soa = mydb2
  137. self.assertEqual(self.xfrsess._zone_is_empty(""), True)
  138. def test_zone_exist(self):
  139. global sqlite3_ds
  140. def zone_soa(zone, file):
  141. return zone
  142. sqlite3_ds.get_zone_soa = zone_soa
  143. self.assertEqual(self.xfrsess._zone_exist(True), True)
  144. self.assertEqual(self.xfrsess._zone_exist(False), False)
  145. def test_check_xfrout_available(self):
  146. def zone_exist(zone):
  147. return zone
  148. self.xfrsess._zone_exist = zone_exist
  149. self.xfrsess._zone_is_empty = zone_exist
  150. self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
  151. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
  152. def zone_empty(zone):
  153. return not zone
  154. self.xfrsess._zone_is_empty = zone_empty
  155. def false_func():
  156. return False
  157. self.xfrsess.server.increase_transfers_counter = false_func
  158. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
  159. def true_func():
  160. return True
  161. self.xfrsess.server.increase_transfers_counter = true_func
  162. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
  163. def test_dns_xfrout_start_formerror(self):
  164. # formerror
  165. self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
  166. sent_data = self.sock.readsent()
  167. self.assertEqual(len(sent_data), 0)
  168. def default(self, param):
  169. return "example.com"
  170. def test_dns_xfrout_start_notauth(self):
  171. self.xfrsess._get_query_zone_name = self.default
  172. def notauth(formpara):
  173. return rcode.NOTAUTH()
  174. self.xfrsess._check_xfrout_available = notauth
  175. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  176. get_msg = self.sock.read_msg()
  177. self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
  178. def test_dns_xfrout_start_noerror(self):
  179. self.xfrsess._get_query_zone_name = self.default
  180. def noerror(form):
  181. return rcode.NOERROR()
  182. self.xfrsess._check_xfrout_available = noerror
  183. def myreply(msg, sock, zonename):
  184. self.sock.send(b"success")
  185. self.xfrsess._reply_xfrout_query = myreply
  186. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  187. self.assertEqual(self.sock.readsent(), b"success")
  188. def test_reply_xfrout_query_noerror(self):
  189. global sqlite3_ds
  190. def get_zone_soa(zonename, file):
  191. return self.soa_record
  192. def get_zone_datas(zone, file):
  193. return [self.soa_record]
  194. sqlite3_ds.get_zone_soa = get_zone_soa
  195. sqlite3_ds.get_zone_datas = get_zone_datas
  196. self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
  197. reply_msg = self.sock.read_msg()
  198. self.assertEqual(reply_msg.get_rr_count(section.ANSWER()), 2)
  199. # set event
  200. self.xfrsess.server._shutdown_event.set()
  201. self.assertRaises(XfroutException, self.xfrsess._reply_xfrout_query, self.getmsg(), self.sock, "example.com.")
  202. class MyUnixSockServer(UnixSockServer):
  203. def __init__(self):
  204. self._lock = threading.Lock()
  205. self._transfers_counter = 0
  206. self._shutdown_event = threading.Event()
  207. self._db_file = "initdb.file"
  208. self._max_transfers_out = 10
  209. class TestUnixSockServer(unittest.TestCase):
  210. def setUp(self):
  211. self.unix = MyUnixSockServer()
  212. def test_updata_config_data(self):
  213. self.unix.update_config_data({'transfers_out':10, 'db_file':"db.file"})
  214. self.assertEqual(self.unix._max_transfers_out, 10)
  215. self.assertEqual(self.unix._db_file, "db.file")
  216. def test_get_db_file(self):
  217. self.assertEqual(self.unix.get_db_file(), "initdb.file")
  218. def test_increase_transfers_counter(self):
  219. self.unix._max_transfers_out = 10
  220. count = self.unix._transfers_counter
  221. self.assertEqual(self.unix.increase_transfers_counter(), True)
  222. self.assertEqual(count + 1, self.unix._transfers_counter)
  223. self.unix._max_transfers_out = 0
  224. count = self.unix._transfers_counter
  225. self.assertEqual(self.unix.increase_transfers_counter(), False)
  226. self.assertEqual(count, self.unix._transfers_counter)
  227. def test_decrease_transfers_counter(self):
  228. count = self.unix._transfers_counter
  229. self.unix.decrease_transfers_counter()
  230. self.assertEqual(count - 1, self.unix._transfers_counter)
  231. if __name__== "__main__":
  232. unittest.main()