xfrout_test.py.in 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. # Copyright (C) 2010 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. '''Tests for the XfroutSession and UnixSockServer classes '''
  16. import unittest
  17. import os
  18. from isc.testutils.tsigctx_mock import MockTSIGContext
  19. from isc.cc.session import *
  20. from pydnspp import *
  21. from xfrout import *
  22. import xfrout
  23. import isc.acl.dns
  24. TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
  25. # our fake socket, where we can read and insert messages
  26. class MySocket():
  27. def __init__(self, family, type):
  28. self.family = family
  29. self.type = type
  30. self.sendqueue = bytearray()
  31. def connect(self, to):
  32. pass
  33. def close(self):
  34. pass
  35. def send(self, data):
  36. self.sendqueue.extend(data);
  37. return len(data)
  38. def readsent(self):
  39. if len(self.sendqueue) >= 2:
  40. size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
  41. else:
  42. size = 0
  43. result = self.sendqueue[:size]
  44. self.sendqueue = self.sendqueue[size:]
  45. return result
  46. def read_msg(self):
  47. sent_data = self.readsent()
  48. get_msg = Message(Message.PARSE)
  49. get_msg.from_wire(bytes(sent_data[2:]))
  50. return get_msg
  51. def clear_send(self):
  52. del self.sendqueue[:]
  53. # We subclass the Session class we're testing here, only
  54. # to override the handle() and _send_data() method
  55. class MyXfroutSession(XfroutSession):
  56. def handle(self):
  57. pass
  58. def _send_data(self, sock, data):
  59. size = len(data)
  60. total_count = 0
  61. while total_count < size:
  62. count = sock.send(data[total_count:])
  63. total_count += count
  64. class Dbserver:
  65. def __init__(self):
  66. self._shutdown_event = threading.Event()
  67. def get_db_file(self):
  68. return None
  69. def decrease_transfers_counter(self):
  70. pass
  71. class TestXfroutSession(unittest.TestCase):
  72. def getmsg(self):
  73. msg = Message(Message.PARSE)
  74. msg.from_wire(self.mdata)
  75. return msg
  76. def create_mock_tsig_ctx(self, error):
  77. # This helper function creates a MockTSIGContext for a given key
  78. # and TSIG error to be used as a result of verify (normally faked
  79. # one)
  80. mock_ctx = MockTSIGContext(TSIG_KEY)
  81. mock_ctx.error = error
  82. return mock_ctx
  83. def message_has_tsig(self, msg):
  84. return msg.get_tsig_record() is not None
  85. def create_request_data_with_tsig(self):
  86. msg = Message(Message.RENDER)
  87. query_id = 0x1035
  88. msg.set_qid(query_id)
  89. msg.set_opcode(Opcode.QUERY())
  90. msg.set_rcode(Rcode.NOERROR())
  91. query_question = Question(Name("example.com."), RRClass.IN(), RRType.AXFR())
  92. msg.add_question(query_question)
  93. renderer = MessageRenderer()
  94. tsig_ctx = MockTSIGContext(TSIG_KEY)
  95. msg.to_wire(renderer, tsig_ctx)
  96. reply_data = renderer.get_data()
  97. return reply_data
  98. def setUp(self):
  99. self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
  100. self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
  101. TSIGKeyRing(), ('127.0.0.1', 12345),
  102. # When not testing ACLs, simply accept
  103. isc.acl.dns.REQUEST_LOADER.load(
  104. [{"action": "ACCEPT"}]))
  105. self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
  106. self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
  107. def test_parse_query_message(self):
  108. [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
  109. self.assertEqual(get_rcode.to_text(), "NOERROR")
  110. # tsig signed query message
  111. request_data = self.create_request_data_with_tsig()
  112. # BADKEY
  113. [rcode, msg] = self.xfrsess._parse_query_message(request_data)
  114. self.assertEqual(rcode.to_text(), "NOTAUTH")
  115. self.assertTrue(self.xfrsess._tsig_ctx is not None)
  116. # NOERROR
  117. self.xfrsess._tsig_key_ring.add(TSIG_KEY)
  118. [rcode, msg] = self.xfrsess._parse_query_message(request_data)
  119. self.assertEqual(rcode.to_text(), "NOERROR")
  120. self.assertTrue(self.xfrsess._tsig_ctx is not None)
  121. # ACL checks, put some ACL inside
  122. self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
  123. {
  124. "from": "127.0.0.1",
  125. "action": "ACCEPT"
  126. },
  127. {
  128. "from": "192.0.2.1",
  129. "action": "DROP"
  130. }
  131. ])
  132. # Localhost (the default in this test) is accepted
  133. rcode, msg = self.xfrsess._parse_query_message(self.mdata)
  134. self.assertEqual(rcode.to_text(), "NOERROR")
  135. # This should be dropped completely, therefore returning None
  136. self.xfrsess._remote = ('192.0.2.1', 12345)
  137. rcode, msg = self.xfrsess._parse_query_message(self.mdata)
  138. self.assertEqual(None, rcode)
  139. # This should be refused, therefore REFUSED
  140. self.xfrsess._remote = ('192.0.2.2', 12345)
  141. rcode, msg = self.xfrsess._parse_query_message(self.mdata)
  142. self.assertEqual(rcode.to_text(), "REFUSED")
  143. # If the TSIG check fails, it should not check ACL
  144. # (If it checked ACL as well, it would just drop the request)
  145. self.xfrsess._remote = ('192.0.2.1', 12345)
  146. self.xfrsess._tsig_key_ring = TSIGKeyRing()
  147. rcode, msg = self.xfrsess._parse_query_message(request_data)
  148. self.assertEqual(rcode.to_text(), "NOTAUTH")
  149. self.assertTrue(self.xfrsess._tsig_ctx is not None)
  150. def test_get_query_zone_name(self):
  151. msg = self.getmsg()
  152. self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
  153. def test_send_data(self):
  154. self.xfrsess._send_data(self.sock, self.mdata)
  155. senddata = self.sock.readsent()
  156. self.assertEqual(senddata, self.mdata)
  157. def test_reply_xfrout_query_with_error_rcode(self):
  158. msg = self.getmsg()
  159. self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
  160. get_msg = self.sock.read_msg()
  161. self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
  162. # tsig signed message
  163. msg = self.getmsg()
  164. self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
  165. self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
  166. get_msg = self.sock.read_msg()
  167. self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
  168. self.assertTrue(self.message_has_tsig(get_msg))
  169. def test_send_message(self):
  170. msg = self.getmsg()
  171. msg.make_response()
  172. # soa record data with different cases
  173. soa_record = (4, 3, 'Example.com.', 'com.Example.', 3600, 'SOA', None, 'master.Example.com. admin.exAmple.com. 1234 3600 1800 2419200 7200')
  174. rrset_soa = self.xfrsess._create_rrset_from_db_record(soa_record)
  175. msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
  176. self.xfrsess._send_message(self.sock, msg)
  177. send_out_data = self.sock.readsent()[2:]
  178. # CASE_INSENSITIVE compression mode
  179. render = MessageRenderer();
  180. render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
  181. msg.to_wire(render)
  182. self.assertNotEqual(render.get_data(), send_out_data)
  183. # CASE_SENSITIVE compression mode
  184. render.clear()
  185. render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
  186. render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
  187. msg.to_wire(render)
  188. self.assertEqual(render.get_data(), send_out_data)
  189. def test_clear_message(self):
  190. msg = self.getmsg()
  191. qid = msg.get_qid()
  192. opcode = msg.get_opcode()
  193. rcode = msg.get_rcode()
  194. self.xfrsess._clear_message(msg)
  195. self.assertEqual(msg.get_qid(), qid)
  196. self.assertEqual(msg.get_opcode(), opcode)
  197. self.assertEqual(msg.get_rcode(), rcode)
  198. self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
  199. def test_create_rrset_from_db_record(self):
  200. rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  201. self.assertEqual(rrset.get_name().to_text(), "example.com.")
  202. self.assertEqual(rrset.get_class(), RRClass("IN"))
  203. self.assertEqual(rrset.get_type().to_text(), "SOA")
  204. rdata = rrset.get_rdata()
  205. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  206. def test_send_message_with_last_soa(self):
  207. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  208. msg = self.getmsg()
  209. msg.make_response()
  210. # packet number less than TSIG_SIGN_EVERY_NTH
  211. packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
  212. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
  213. 0, packet_neet_not_sign)
  214. get_msg = self.sock.read_msg()
  215. # tsig context is not exist
  216. self.assertFalse(self.message_has_tsig(get_msg))
  217. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
  218. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  219. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  220. #answer_rrset_iter = section_iter(get_msg, section.ANSWER())
  221. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]#answer_rrset_iter.get_rrset()
  222. self.assertEqual(answer.get_name().to_text(), "example.com.")
  223. self.assertEqual(answer.get_class(), RRClass("IN"))
  224. self.assertEqual(answer.get_type().to_text(), "SOA")
  225. rdata = answer.get_rdata()
  226. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  227. # msg is the TSIG_SIGN_EVERY_NTH one
  228. # sending the message with last soa together
  229. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
  230. 0, TSIG_SIGN_EVERY_NTH)
  231. get_msg = self.sock.read_msg()
  232. # tsig context is not exist
  233. self.assertFalse(self.message_has_tsig(get_msg))
  234. def test_send_message_with_last_soa_with_tsig(self):
  235. # create tsig context
  236. self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
  237. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  238. msg = self.getmsg()
  239. msg.make_response()
  240. # packet number less than TSIG_SIGN_EVERY_NTH
  241. packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
  242. # msg is not the TSIG_SIGN_EVERY_NTH one
  243. # sending the message with last soa together
  244. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
  245. 0, packet_neet_not_sign)
  246. get_msg = self.sock.read_msg()
  247. self.assertTrue(self.message_has_tsig(get_msg))
  248. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
  249. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  250. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  251. # msg is the TSIG_SIGN_EVERY_NTH one
  252. # sending the message with last soa together
  253. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
  254. 0, TSIG_SIGN_EVERY_NTH)
  255. get_msg = self.sock.read_msg()
  256. self.assertTrue(self.message_has_tsig(get_msg))
  257. def test_trigger_send_message_with_last_soa(self):
  258. rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
  259. rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
  260. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  261. msg = self.getmsg()
  262. msg.make_response()
  263. msg.add_rrset(Message.SECTION_ANSWER, rrset_a)
  264. # length larger than MAX-len(rrset)
  265. length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
  266. # packet number less than TSIG_SIGN_EVERY_NTH
  267. packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
  268. # give the function a value that is larger than MAX-len(rrset)
  269. # this should have triggered the sending of two messages
  270. # (1 with the rrset we added manually, and 1 that triggered
  271. # the sending in _with_last_soa)
  272. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
  273. packet_neet_not_sign)
  274. get_msg = self.sock.read_msg()
  275. self.assertFalse(self.message_has_tsig(get_msg))
  276. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
  277. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  278. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  279. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
  280. self.assertEqual(answer.get_name().to_text(), "example.com.")
  281. self.assertEqual(answer.get_class(), RRClass("IN"))
  282. self.assertEqual(answer.get_type().to_text(), "A")
  283. rdata = answer.get_rdata()
  284. self.assertEqual(rdata[0].to_text(), "192.0.2.1")
  285. get_msg = self.sock.read_msg()
  286. self.assertFalse(self.message_has_tsig(get_msg))
  287. self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
  288. self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  289. self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)
  290. #answer_rrset_iter = section_iter(get_msg, Message.SECTION_ANSWER)
  291. answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
  292. self.assertEqual(answer.get_name().to_text(), "example.com.")
  293. self.assertEqual(answer.get_class(), RRClass("IN"))
  294. self.assertEqual(answer.get_type().to_text(), "SOA")
  295. rdata = answer.get_rdata()
  296. self.assertEqual(rdata[0].to_text(), self.soa_record[7])
  297. # and it should not have sent anything else
  298. self.assertEqual(0, len(self.sock.sendqueue))
  299. def test_trigger_send_message_with_last_soa_with_tsig(self):
  300. self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
  301. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  302. msg = self.getmsg()
  303. msg.make_response()
  304. msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
  305. # length larger than MAX-len(rrset)
  306. length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
  307. # packet number less than TSIG_SIGN_EVERY_NTH
  308. packet_neet_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
  309. # give the function a value that is larger than MAX-len(rrset)
  310. # this should have triggered the sending of two messages
  311. # (1 with the rrset we added manually, and 1 that triggered
  312. # the sending in _with_last_soa)
  313. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
  314. packet_neet_not_sign)
  315. get_msg = self.sock.read_msg()
  316. # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
  317. self.assertFalse(self.message_has_tsig(get_msg))
  318. # the last packet should be tsig signed
  319. get_msg = self.sock.read_msg()
  320. self.assertTrue(self.message_has_tsig(get_msg))
  321. # and it should not have sent anything else
  322. self.assertEqual(0, len(self.sock.sendqueue))
  323. # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
  324. self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa, length_need_split,
  325. xfrout.TSIG_SIGN_EVERY_NTH)
  326. get_msg = self.sock.read_msg()
  327. self.assertTrue(self.message_has_tsig(get_msg))
  328. # the last packet should be tsig signed
  329. get_msg = self.sock.read_msg()
  330. self.assertTrue(self.message_has_tsig(get_msg))
  331. # and it should not have sent anything else
  332. self.assertEqual(0, len(self.sock.sendqueue))
  333. def test_get_rrset_len(self):
  334. rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
  335. self.assertEqual(82, get_rrset_len(rrset_soa))
  336. def test_zone_has_soa(self):
  337. global sqlite3_ds
  338. def mydb1(zone, file):
  339. return True
  340. sqlite3_ds.get_zone_soa = mydb1
  341. self.assertTrue(self.xfrsess._zone_has_soa(""))
  342. def mydb2(zone, file):
  343. return False
  344. sqlite3_ds.get_zone_soa = mydb2
  345. self.assertFalse(self.xfrsess._zone_has_soa(""))
  346. def test_zone_exist(self):
  347. global sqlite3_ds
  348. def zone_exist(zone, file):
  349. return zone
  350. sqlite3_ds.zone_exist = zone_exist
  351. self.assertTrue(self.xfrsess._zone_exist(True))
  352. self.assertFalse(self.xfrsess._zone_exist(False))
  353. def test_check_xfrout_available(self):
  354. def zone_exist(zone):
  355. return zone
  356. def zone_has_soa(zone):
  357. return (not zone)
  358. self.xfrsess._zone_exist = zone_exist
  359. self.xfrsess._zone_has_soa = zone_has_soa
  360. self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
  361. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
  362. def zone_empty(zone):
  363. return zone
  364. self.xfrsess._zone_has_soa = zone_empty
  365. def false_func():
  366. return False
  367. self.xfrsess._server.increase_transfers_counter = false_func
  368. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
  369. def true_func():
  370. return True
  371. self.xfrsess._server.increase_transfers_counter = true_func
  372. self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
  373. def test_dns_xfrout_start_formerror(self):
  374. # formerror
  375. self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
  376. sent_data = self.sock.readsent()
  377. self.assertEqual(len(sent_data), 0)
  378. def default(self, param):
  379. return "example.com"
  380. def test_dns_xfrout_start_notauth(self):
  381. self.xfrsess._get_query_zone_name = self.default
  382. def notauth(formpara):
  383. return Rcode.NOTAUTH()
  384. self.xfrsess._check_xfrout_available = notauth
  385. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  386. get_msg = self.sock.read_msg()
  387. self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
  388. def test_dns_xfrout_start_noerror(self):
  389. self.xfrsess._get_query_zone_name = self.default
  390. def noerror(form):
  391. return Rcode.NOERROR()
  392. self.xfrsess._check_xfrout_available = noerror
  393. def myreply(msg, sock, zonename):
  394. self.sock.send(b"success")
  395. self.xfrsess._reply_xfrout_query = myreply
  396. self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
  397. self.assertEqual(self.sock.readsent(), b"success")
  398. def test_reply_xfrout_query_noerror(self):
  399. global sqlite3_ds
  400. def get_zone_soa(zonename, file):
  401. return self.soa_record
  402. def get_zone_datas(zone, file):
  403. return [self.soa_record]
  404. sqlite3_ds.get_zone_soa = get_zone_soa
  405. sqlite3_ds.get_zone_datas = get_zone_datas
  406. self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
  407. reply_msg = self.sock.read_msg()
  408. self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
  409. def test_reply_xfrout_query_noerror_with_tsig(self):
  410. rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None, '192.168.1.1')
  411. global sqlite3_ds
  412. global xfrout
  413. def get_zone_soa(zonename, file):
  414. return self.soa_record
  415. def get_zone_datas(zone, file):
  416. zone_rrsets = []
  417. for i in range(0, 100):
  418. zone_rrsets.insert(i, rrset_data)
  419. return zone_rrsets
  420. def get_rrset_len(rrset):
  421. return 65520
  422. sqlite3_ds.get_zone_soa = get_zone_soa
  423. sqlite3_ds.get_zone_datas = get_zone_datas
  424. xfrout.get_rrset_len = get_rrset_len
  425. self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
  426. self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
  427. # tsig signed first package
  428. reply_msg = self.sock.read_msg()
  429. self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
  430. self.assertTrue(self.message_has_tsig(reply_msg))
  431. # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
  432. for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
  433. reply_msg = self.sock.read_msg()
  434. self.assertFalse(self.message_has_tsig(reply_msg))
  435. # TSIG_SIGN_EVERY_NTH packet has tsig
  436. reply_msg = self.sock.read_msg()
  437. self.assertTrue(self.message_has_tsig(reply_msg))
  438. for i in range(0, 100 - TSIG_SIGN_EVERY_NTH):
  439. reply_msg = self.sock.read_msg()
  440. self.assertFalse(self.message_has_tsig(reply_msg))
  441. # tsig signed last package
  442. reply_msg = self.sock.read_msg()
  443. self.assertTrue(self.message_has_tsig(reply_msg))
  444. # and it should not have sent anything else
  445. self.assertEqual(0, len(self.sock.sendqueue))
  446. class MyCCSession():
  447. def __init__(self):
  448. pass
  449. def get_remote_config_value(self, module_name, identifier):
  450. if module_name == "Auth" and identifier == "database_file":
  451. return "initdb.file", False
  452. else:
  453. return "unknown", False
  454. class MyUnixSockServer(UnixSockServer):
  455. def __init__(self):
  456. self._shutdown_event = threading.Event()
  457. self._max_transfers_out = 10
  458. self._cc = MyCCSession()
  459. self._common_init()
  460. class TestUnixSockServer(unittest.TestCase):
  461. def setUp(self):
  462. self.write_sock, self.read_sock = socket.socketpair()
  463. self.unix = MyUnixSockServer()
  464. def test_guess_remote(self):
  465. """Test we can guess the remote endpoint when we have only the
  466. file descriptor. This is needed, because we get only that one
  467. from auth."""
  468. # We test with UDP, as it can be "connected" without other
  469. # endpoint
  470. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  471. sock.connect(('127.0.0.1', 12345))
  472. self.assertEqual(('127.0.0.1', 12345),
  473. self.unix._guess_remote(sock.fileno()))
  474. if socket.has_ipv6:
  475. # Don't check IPv6 address on hosts not supporting them
  476. sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  477. sock.connect(('::1', 12345))
  478. self.assertEqual(('::1', 12345, 0, 0),
  479. self.unix._guess_remote(sock.fileno()))
  480. # Try when pretending there's no IPv6 support
  481. # (No need to pretend when there's really no IPv6)
  482. xfrout.socket.has_ipv6 = False
  483. sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  484. sock.connect(('127.0.0.1', 12345))
  485. self.assertEqual(('127.0.0.1', 12345),
  486. self.unix._guess_remote(sock.fileno()))
  487. # Return it back
  488. xfrout.socket.has_ipv6 = True
  489. def test_receive_query_message(self):
  490. send_msg = b"\xd6=\x00\x00\x00\x01\x00"
  491. msg_len = struct.pack('H', socket.htons(len(send_msg)))
  492. self.write_sock.send(msg_len)
  493. self.write_sock.send(send_msg)
  494. recv_msg = self.unix._receive_query_message(self.read_sock)
  495. self.assertEqual(recv_msg, send_msg)
  496. def check_default_ACL(self):
  497. context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
  498. 1234, 0, 0, 0,
  499. socket.AI_NUMERICHOST)[0][4])
  500. self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
  501. def check_loaded_ACL(self):
  502. context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
  503. 1234, 0, 0, 0,
  504. socket.AI_NUMERICHOST)[0][4])
  505. self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
  506. context = isc.acl.dns.RequestContext(socket.getaddrinfo("192.0.2.1",
  507. 1234, 0, 0, 0,
  508. socket.AI_NUMERICHOST)[0][4])
  509. self.assertEqual(isc.acl.acl.REJECT, self.unix._acl.execute(context))
  510. def test_update_config_data(self):
  511. self.check_default_ACL()
  512. tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
  513. tsig_key_list = [tsig_key_str]
  514. bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
  515. self.unix.update_config_data({'transfers_out':10 })
  516. self.assertEqual(self.unix._max_transfers_out, 10)
  517. self.assertTrue(self.unix.tsig_key_ring is not None)
  518. self.check_default_ACL()
  519. self.unix.update_config_data({'transfers_out':9,
  520. 'tsig_key_ring':tsig_key_list})
  521. self.assertEqual(self.unix._max_transfers_out, 9)
  522. self.assertEqual(self.unix.tsig_key_ring.size(), 1)
  523. self.unix.tsig_key_ring.remove(Name("example.com."))
  524. self.assertEqual(self.unix.tsig_key_ring.size(), 0)
  525. # bad tsig key
  526. config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
  527. self.assertRaises(None, self.unix.update_config_data(config_data))
  528. self.assertEqual(self.unix.tsig_key_ring.size(), 0)
  529. # Load the ACL
  530. self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
  531. 'action': 'ACCEPT'}]})
  532. self.check_loaded_ACL()
  533. # Pass a wrong data there and check it does not replace the old one
  534. self.assertRaises(isc.acl.acl.LoaderError,
  535. self.unix.update_config_data,
  536. {'query_acl': ['Something bad']})
  537. self.check_loaded_ACL()
  538. def test_get_db_file(self):
  539. self.assertEqual(self.unix.get_db_file(), "initdb.file")
  540. def test_increase_transfers_counter(self):
  541. self.unix._max_transfers_out = 10
  542. count = self.unix._transfers_counter
  543. self.assertEqual(self.unix.increase_transfers_counter(), True)
  544. self.assertEqual(count + 1, self.unix._transfers_counter)
  545. self.unix._max_transfers_out = 0
  546. count = self.unix._transfers_counter
  547. self.assertEqual(self.unix.increase_transfers_counter(), False)
  548. self.assertEqual(count, self.unix._transfers_counter)
  549. def test_decrease_transfers_counter(self):
  550. count = self.unix._transfers_counter
  551. self.unix.decrease_transfers_counter()
  552. self.assertEqual(count - 1, self.unix._transfers_counter)
  553. def _remove_file(self, sock_file):
  554. try:
  555. os.remove(sock_file)
  556. except OSError:
  557. pass
  558. def test_sock_file_in_use_file_exist(self):
  559. sock_file = 'temp.sock.file'
  560. self._remove_file(sock_file)
  561. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  562. self.assertFalse(os.path.exists(sock_file))
  563. def test_sock_file_in_use_file_not_exist(self):
  564. self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))
  565. def _start_unix_sock_server(self, sock_file):
  566. serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
  567. serv_thread = threading.Thread(target=serv.serve_forever)
  568. serv_thread.setDaemon(True)
  569. serv_thread.start()
  570. def test_sock_file_in_use(self):
  571. sock_file = 'temp.sock.file'
  572. self._remove_file(sock_file)
  573. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  574. self._start_unix_sock_server(sock_file)
  575. old_stdout = sys.stdout
  576. sys.stdout = open(os.devnull, 'w')
  577. self.assertTrue(self.unix._sock_file_in_use(sock_file))
  578. sys.stdout = old_stdout
  579. def test_remove_unused_sock_file_in_use(self):
  580. sock_file = 'temp.sock.file'
  581. self._remove_file(sock_file)
  582. self.assertFalse(self.unix._sock_file_in_use(sock_file))
  583. self._start_unix_sock_server(sock_file)
  584. old_stdout = sys.stdout
  585. sys.stdout = open(os.devnull, 'w')
  586. try:
  587. self.unix._remove_unused_sock_file(sock_file)
  588. except SystemExit:
  589. pass
  590. else:
  591. # This should never happen
  592. self.assertTrue(False)
  593. sys.stdout = old_stdout
  594. def test_remove_unused_sock_file_dir(self):
  595. import tempfile
  596. dir_name = tempfile.mkdtemp()
  597. old_stdout = sys.stdout
  598. sys.stdout = open(os.devnull, 'w')
  599. try:
  600. self.unix._remove_unused_sock_file(dir_name)
  601. except SystemExit:
  602. pass
  603. else:
  604. # This should never happen
  605. self.assertTrue(False)
  606. sys.stdout = old_stdout
  607. os.rmdir(dir_name)
  608. class TestInitialization(unittest.TestCase):
  609. def setEnv(self, name, value):
  610. if value is None:
  611. if name in os.environ:
  612. del os.environ[name]
  613. else:
  614. os.environ[name] = value
  615. def setUp(self):
  616. self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")
  617. self._oldFromBuild = os.getenv("B10_FROM_BUILD")
  618. def tearDown(self):
  619. self.setEnv("B10_FROM_BUILD", self._oldFromBuild)
  620. self.setEnv("BIND10_XFROUT_SOCKET_FILE", self._oldSocket)
  621. # Make sure even the computed values are back
  622. xfrout.init_paths()
  623. def testNoEnv(self):
  624. self.setEnv("B10_FROM_BUILD", None)
  625. self.setEnv("BIND10_XFROUT_SOCKET_FILE", None)
  626. xfrout.init_paths()
  627. self.assertEqual(xfrout.UNIX_SOCKET_FILE,
  628. "@@LOCALSTATEDIR@@/auth_xfrout_conn")
  629. def testProvidedSocket(self):
  630. self.setEnv("B10_FROM_BUILD", None)
  631. self.setEnv("BIND10_XFROUT_SOCKET_FILE", "The/Socket/File")
  632. xfrout.init_paths()
  633. self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")
  634. if __name__== "__main__":
  635. unittest.main()