# xfrout_test.py (16 KB)


# Copyright (C) 2010 Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes.'''
import os
import socket
import struct
import sys
import tempfile
import threading
import unittest

from isc.cc.session import *
from pydnspp import *
from xfrout import *
  21. # our fake socket, where we can read and insert messages
  22. class MySocket():
  23. def __init__(self, family, type):
  24. self.family = family
  25. self.type = type
  26. self.sendqueue = bytearray()
  27. def connect(self, to):
  28. pass
  29. def close(self):
  30. pass
  31. def send(self, data):
  32. self.sendqueue.extend(data);
  33. return len(data)
  34. def readsent(self):
  35. if len(self.sendqueue) >= 2:
  36. size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
  37. else:
  38. size = 0
  39. result = self.sendqueue[:size]
  40. self.sendqueue = self.sendqueue[size:]
  41. return result
  42. def read_msg(self):
  43. sent_data = self.readsent()
  44. get_msg = Message(Message.PARSE)
  45. get_msg.from_wire(bytes(sent_data[2:]))
  46. return get_msg
  47. def clear_send(self):
  48. del self.sendqueue[:]
  49. # We subclass the Session class we're testing here, only
  50. # to override the handle() and _send_data() method
  51. class MyXfroutSession(XfroutSession):
  52. def handle(self):
  53. pass
  54. def _send_data(self, sock, data):
  55. size = len(data)
  56. total_count = 0
  57. while total_count < size:
  58. count = sock.send(data[total_count:])
  59. total_count += count
  60. class Dbserver:
  61. def __init__(self):
  62. self._shutdown_event = threading.Event()
  63. def get_db_file(self):
  64. return None
  65. def decrease_transfers_counter(self):
  66. pass
  67. class TestXfroutSession(unittest.TestCase):
  68. def getmsg(self):
  69. msg = Message(Message.PARSE)
  70. msg.from_wire(self.mdata)
  71. return msg
  72. def setUp(self):
  73. request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
  74. self.log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
  75. (self.write_sock, self.read_sock) = socket.socketpair()
  76. self.xfrsess = MyXfroutSession(request, None, None, self.log, self.read_sock)
  77. self.xfrsess.server = Dbserver()
  78. self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
  79. self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
  80. self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
  81. def test_receive_query_message(self):
  82. send_msg = b"\xd6=\x00\x00\x00\x01\x00"
  83. msg_len = struct.pack('H', socket.htons(len(send_msg)))
  84. self.write_sock.send(msg_len)
  85. self.write_sock.send(send_msg)
  86. recv_msg = self.xfrsess._receive_query_message(self.read_sock)
  87. self.assertEqual(recv_msg, send_msg)
  88. def test_parse_query_message(self):
  89. [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
  90. self.assertEqual(get_rcode.to_text(), "NOERROR")
  91. def test_get_query_zone_name(self):
  92. msg = self.getmsg()
  93. self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
  94. def test_send_data(self):
  95. self.xfrsess._send_data(self.sock, self.mdata)
  96. senddata = self.sock.readsent()
  97. self.assertEqual(senddata, self.mdata)
  98. def test_reply_xfrout_query_with_error_rcode(self):
  99. msg = self.getmsg()
  100. self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
  101. get_msg = self.sock.read_msg()
  102. self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
  103. def test_clear_message(self):
  104. msg = self.getmsg()
  105. qid = msg.get_qid()
  106. opcode = msg.get_opcode()
  107. rcode = msg.get_rcode()
  108. self.xfrsess._clear_message(msg)
  109. self.assertEqual(msg.get_qid(), qid)
  110. self.assertEqual(msg.get_opcode(), opcode)
  111. self.assertEqual(msg.get_rcode(), rcode)
  112. self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
  113. def test_reply_query_with_format_error(self):
  114. msg = self.getmsg()
  115. self.xfrsess._reply_query_with_format_error(msg, self.sock)
  116. get_msg = self.sock.read_msg()
  117. self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
  118. def test_create_rrset_from_db_record(self):
  119. rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  120. self.assertEqual(rrset.get_name().to_text(), "example.com.")
  121. self.assertEqual(rrset.get_class(), RRClass("IN"))
  122. self.assertEqual(rrset.get_type().to_text(), "SOA")
  123. rdata = rrset.get_rdata()
  124. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  125. def test_send_message_with_last_soa(self):
  126. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  127. msg = self.getmsg()
  128. msg.make_response()
  129. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 0)
  130. get_msg = self.sock.read_msg()
  131. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
  132. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  133. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  134. #answer_rrset_iter = section_iter(get_msg, section.ANSWER())
  135. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]#answer_rrset_iter.get_rrset()
  136. self.assertEqual(answer.get_name().to_text(), "example.com.")
  137. self.assertEqual(answer.get_class(), RRClass("IN"))
  138. self.assertEqual(answer.get_type().to_text(), "SOA")
  139. rdata = answer.get_rdata()
  140. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  141. def test_trigger_send_message_with_last_soa(self):
  142. rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
  143. rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
  144. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  145. msg = self.getmsg()
  146. msg.make_response()
  147. msg.add_rrset(Message.SECTION_ANSWER, rrset_a)
  148. # give the function a value that is larger than MAX-len(rrset)
  149. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, 65520)
  150. # this should have triggered the sending of two messages
  151. # (1 with the rrset we added manually, and 1 that triggered
  152. # the sending in _with_last_soa)
  153. get_msg = self.sock.read_msg()
  154. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
  155. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  156. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  157. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
  158. self.assertEqual(answer.get_name().to_text(), "example.com.")
  159. self.assertEqual(answer.get_class(), RRClass("IN"))
  160. self.assertEqual(answer.get_type().to_text(), "A")
  161. rdata = answer.get_rdata()
  162. self.assertEqual(rdata[0].to_text(), "192.0.2.1")
  163. get_msg = self.sock.read_msg()
  164. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
  165. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  166. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  167. #answer_rrset_iter = section_iter(get_msg, Message.SECTION_ANSWER)
  168. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
  169. self.assertEqual(answer.get_name().to_text(), "example.com.")
  170. self.assertEqual(answer.get_class(), RRClass("IN"))
  171. self.assertEqual(answer.get_type().to_text(), "SOA")
  172. rdata = answer.get_rdata()
  173. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  174. # and it should not have sent anything else
  175. self.assertEqual(0, len(self.sock.sendqueue))
  176. def test_get_rrset_len(self):
  177. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  178. self.assertEqual(82, get_rrset_len(rrset_soa))
  179. def test_zone_has_soa(self):
  180. global sqlite3_ds
  181. def mydb1(zone, file):
  182. return True
  183. sqlite3_ds.get_zone_soa = mydb1
  184. self.assertTrue(self.xfrsess._zone_has_soa(""))
  185. def mydb2(zone, file):
  186. return False
  187. sqlite3_ds.get_zone_soa = mydb2
  188. self.assertFalse(self.xfrsess._zone_has_soa(""))
  189. def test_zone_exist(self):
  190. global sqlite3_ds
  191. def zone_exist(zone, file):
  192. return zone
  193. sqlite3_ds.zone_exist = zone_exist
  194. self.assertTrue(self.xfrsess._zone_exist(True))
  195. self.assertFalse(self.xfrsess._zone_exist(False))
  196. def test_check_xfrout_available(self):
  197. def zone_exist(zone):
  198. return zone
  199. def zone_has_soa(zone):
  200. return (not zone)
  201. self.xfrsess._zone_exist = zone_exist
  202. self.xfrsess._zone_has_soa = zone_has_soa
  203. self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
  204. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
  205. def zone_empty(zone):
  206. return zone
  207. self.xfrsess._zone_has_soa = zone_empty
  208. def false_func():
  209. return False
  210. self.xfrsess.server.increase_transfers_counter = false_func
  211. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
  212. def true_func():
  213. return True
  214. self.xfrsess.server.increase_transfers_counter = true_func
  215. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
  216. def test_dns_xfrout_start_formerror(self):
  217. # formerror
  218. self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
  219. sent_data = self.sock.readsent()
  220. self.assertEqual(len(sent_data), 0)
  221. def default(self, param):
  222. return "example.com"
  223. def test_dns_xfrout_start_notauth(self):
  224. self.xfrsess._get_query_zone_name = self.default
  225. def notauth(formpara):
  226. return Rcode.NOTAUTH()
  227. self.xfrsess._check_xfrout_available = notauth
  228. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  229. get_msg = self.sock.read_msg()
  230. self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
  231. def test_dns_xfrout_start_noerror(self):
  232. self.xfrsess._get_query_zone_name = self.default
  233. def noerror(form):
  234. return Rcode.NOERROR()
  235. self.xfrsess._check_xfrout_available = noerror
  236. def myreply(msg, sock, zonename):
  237. self.sock.send(b"success")
  238. self.xfrsess._reply_xfrout_query = myreply
  239. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  240. self.assertEqual(self.sock.readsent(), b"success")
  241. def test_reply_xfrout_query_noerror(self):
  242. global sqlite3_ds
  243. def get_zone_soa(zonename, file):
  244. return self.soa_record
  245. def get_zone_datas(zone, file):
  246. return [self.soa_record]
  247. sqlite3_ds.get_zone_soa = get_zone_soa
  248. sqlite3_ds.get_zone_datas = get_zone_datas
  249. self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
  250. reply_msg = self.sock.read_msg()
  251. self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
  252. class MyCCSession():
  253. def __init__(self):
  254. pass
  255. def get_remote_config_value(self, module_name, identifier):
  256. if module_name == "Auth" and identifier == "database_file":
  257. return "initdb.file", False
  258. else:
  259. return "unknown", False
  260. class MyUnixSockServer(UnixSockServer):
  261. def __init__(self):
  262. self._lock = threading.Lock()
  263. self._transfers_counter = 0
  264. self._shutdown_event = threading.Event()
  265. self._max_transfers_out = 10
  266. self._cc = MyCCSession()
  267. self._log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
  268. class TestUnixSockServer(unittest.TestCase):
  269. def setUp(self):
  270. self.unix = MyUnixSockServer()
  271. def test_updata_config_data(self):
  272. self.unix.update_config_data({'transfers_out':10 })
  273. self.assertEqual(self.unix._max_transfers_out, 10)
  274. def test_get_db_file(self):
  275. self.assertEqual(self.unix.get_db_file(), "initdb.file")
  276. def test_increase_transfers_counter(self):
  277. self.unix._max_transfers_out = 10
  278. count = self.unix._transfers_counter
  279. self.assertEqual(self.unix.increase_transfers_counter(), True)
  280. self.assertEqual(count + 1, self.unix._transfers_counter)
  281. self.unix._max_transfers_out = 0
  282. count = self.unix._transfers_counter
  283. self.assertEqual(self.unix.increase_transfers_counter(), False)
  284. self.assertEqual(count, self.unix._transfers_counter)
  285. def test_decrease_transfers_counter(self):
  286. count = self.unix._transfers_counter
  287. self.unix.decrease_transfers_counter()
  288. self.assertEqual(count - 1, self.unix._transfers_counter)
  289. def _remove_file(self, sock_file):
  290. try:
  291. os.remove(sock_file)
  292. except OSError:
  293. pass
  294. def test_sock_file_in_use_file_exist(self):
  295. sock_file = 'temp.sock.file'
  296. self._remove_file(sock_file)
  297. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  298. self.assertFalse(os.path.exists(sock_file))
  299. def test_sock_file_in_use_file_not_exist(self):
  300. self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))
  301. def _start_unix_sock_server(self, sock_file):
  302. serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
  303. serv_thread = threading.Thread(target=serv.serve_forever)
  304. serv_thread.setDaemon(True)
  305. serv_thread.start()
  306. def test_sock_file_in_use(self):
  307. sock_file = 'temp.sock.file'
  308. self._remove_file(sock_file)
  309. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  310. self._start_unix_sock_server(sock_file)
  311. old_stdout = sys.stdout
  312. sys.stdout = open(os.devnull, 'w')
  313. self.assertTrue(self.unix._sock_file_in_use(sock_file))
  314. sys.stdout = old_stdout
  315. def test_remove_unused_sock_file_in_use(self):
  316. sock_file = 'temp.sock.file'
  317. self._remove_file(sock_file)
  318. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  319. self._start_unix_sock_server(sock_file)
  320. old_stdout = sys.stdout
  321. sys.stdout = open(os.devnull, 'w')
  322. try:
  323. self.unix._remove_unused_sock_file(sock_file)
  324. except SystemExit:
  325. pass
  326. else:
  327. # This should never happen
  328. self.assertTrue(False)
  329. sys.stdout = old_stdout
  330. def test_remove_unused_sock_file_dir(self):
  331. import tempfile
  332. dir_name = tempfile.mkdtemp()
  333. old_stdout = sys.stdout
  334. sys.stdout = open(os.devnull, 'w')
  335. try:
  336. self.unix._remove_unused_sock_file(dir_name)
  337. except SystemExit:
  338. pass
  339. else:
  340. # This should never happen
  341. self.assertTrue(False)
  342. sys.stdout = old_stdout
  343. os.rmdir(dir_name)
  344. if __name__== "__main__":
  345. unittest.main()