12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879 |
- # Copyright (C) 2010 Internet Systems Consortium.
- #
- # Permission to use, copy, modify, and distribute this software for any
- # purpose with or without fee is hereby granted, provided that the above
- # copyright notice and this permission notice appear in all copies.
- #
- # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
- # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
- # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
- # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
- # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
- # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
- # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
- '''Tests for the XfroutSession and UnixSockServer classes '''
- import unittest
- import os
- from isc.testutils.tsigctx_mock import MockTSIGContext
- from isc.cc.session import *
- from pydnspp import *
- from xfrout import *
- import xfrout
- import isc.log
- import isc.acl.dns
# "name:base64-secret" specification of the TSIG key used throughout the tests.
_TSIG_KEY_SPEC = "example.com:SFuWd/q99SzF8Yzd1QbB9g=="
TSIG_KEY = TSIGKey(_TSIG_KEY_SPEC)
class MySocket:
    """Fake socket that buffers everything sent through it.

    Tests read the buffered data back one length-prefixed message at a time
    with readsent(), or parsed into a DNS Message with read_msg().
    """

    def __init__(self, family, type):
        self.family, self.type = family, type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass

    def close(self):
        pass

    def send(self, data):
        """Append data to the send buffer; report every byte as sent."""
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        """Remove and return the first message in the buffer.

        A message is a 2-byte network-order length followed by that many
        bytes; the returned slice includes the length prefix. If fewer than
        two bytes are buffered, an empty bytearray is returned and the
        buffer is left untouched.
        """
        queue = self.sendqueue
        size = 0
        if len(queue) >= 2:
            (body_len,) = struct.unpack("!H", queue[:2])
            size = body_len + 2
        result, self.sendqueue = queue[:size], queue[size:]
        return result

    def read_msg(self):
        """Pop the next message and parse it (minus its length prefix)."""
        wire = bytes(self.readsent()[2:])
        parsed = Message(Message.PARSE)
        parsed.from_wire(wire)
        return parsed

    def clear_send(self):
        """Discard everything buffered so far (in place)."""
        del self.sendqueue[:]
class MyXfroutSession(XfroutSession):
    """XfroutSession with handle() disabled and a simple _send_data().

    The class under test is subclassed only to override these two methods;
    the tests call the remaining session methods directly.
    """

    def handle(self):
        pass

    def _send_data(self, sock, data):
        """Keep calling sock.send() until every byte of data is accepted."""
        sent = 0
        remaining = len(data)
        while sent < remaining:
            sent += sock.send(data[sent:])
class Dbserver:
    """Stub for the server object handed to MyXfroutSession."""

    def __init__(self):
        self._shutdown_event = threading.Event()

    def get_db_file(self):
        """No database file is configured for these tests."""
        return None

    def decrease_transfers_counter(self):
        """No-op: transfer accounting is not tracked by this stub."""
class TestXfroutSession(unittest.TestCase):
    """Tests for XfroutSession, driven through MyXfroutSession and MySocket."""

    def _patch(self, target, attr, value):
        """Set target.attr to value until the end of the current test.

        The previous value is restored by a cleanup, so monkeypatching a
        shared module (sqlite3_ds, xfrout) cannot leak into later tests.
        """
        original = getattr(target, attr)
        setattr(target, attr, value)
        self.addCleanup(setattr, target, attr, original)

    def getmsg(self):
        """Return the default request (self.mdata) parsed into a Message."""
        msg = Message(Message.PARSE)
        msg.from_wire(self.mdata)
        return msg

    def create_mock_tsig_ctx(self, error):
        # This helper function creates a MockTSIGContext for a given key
        # and TSIG error to be used as a result of verify (normally faked
        # one)
        mock_ctx = MockTSIGContext(TSIG_KEY)
        mock_ctx.error = error
        return mock_ctx

    def message_has_tsig(self, msg):
        """Return True if msg carries a TSIG record."""
        return msg.get_tsig_record() is not None

    def create_request_data(self, with_tsig=False):
        """Render an AXFR query for example.com/IN with qid 0x1035.

        When with_tsig is True, the query is rendered with a
        MockTSIGContext built from TSIG_KEY.
        """
        msg = Message(Message.RENDER)
        query_id = 0x1035
        msg.set_qid(query_id)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        query_question = Question(Name("example.com"), RRClass.IN(),
                                  RRType.AXFR())
        msg.add_question(query_question)

        renderer = MessageRenderer()
        if with_tsig:
            tsig_ctx = MockTSIGContext(TSIG_KEY)
            msg.to_wire(renderer, tsig_ctx)
        else:
            msg.to_wire(renderer)
        return renderer.get_data()

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                       TSIGKeyRing(), ('127.0.0.1', 12345),
                                       # When not testing ACLs, simply accept
                                       isc.acl.dns.REQUEST_LOADER.load(
                                           [{"action": "ACCEPT"}]))
        self.mdata = self.create_request_data(False)
        self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')

    def test_parse_query_message(self):
        get_rcode, get_msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(get_rcode.to_text(), "NOERROR")

        # tsig signed query message
        request_data = self.create_request_data(True)
        # BADKEY
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

        # NOERROR
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

    def check_transfer_acl(self, acl_setter):
        """Exercise IP, TSIG and combined ACLs installed through acl_setter."""
        # ACL checks, put some ACL inside
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {
                "from": "127.0.0.1",
                "action": "ACCEPT"
            },
            {
                "from": "192.0.2.1",
                "action": "DROP"
            }
        ]))
        # Localhost (the default in this test) is accepted
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # This should be dropped completely, therefore returning None
        self.xfrsess._remote = ('192.0.2.1', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertIsNone(rcode)
        # This should be refused, therefore REFUSED
        self.xfrsess._remote = ('192.0.2.2', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # TSIG signed request
        request_data = self.create_request_data(True)

        # If the TSIG check fails, it should not check ACL
        # (If it checked ACL as well, it would just drop the request)
        self.xfrsess._remote = ('192.0.2.1', 12345)
        self.xfrsess._tsig_key_ring = TSIGKeyRing()
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

        # ACL using TSIG: successful case
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # ACL using TSIG: key name doesn't match; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # ACL using TSIG: no TSIG; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        #
        # ACL using IP + TSIG: both should match
        #
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
             "action": "ACCEPT"},
            {"action": "REJECT"}
        ]))
        # both matches
        self.xfrsess._remote = ('192.0.2.1', 12345)
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # TSIG matches, but address doesn't
        self.xfrsess._remote = ('192.0.2.2', 12345)
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Address matches, but TSIG doesn't (not included)
        self.xfrsess._remote = ('192.0.2.1', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Neither address nor TSIG matches
        self.xfrsess._remote = ('192.0.2.2', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

    def test_transfer_acl(self):
        def acl_setter(acl):
            self.xfrsess._acl = acl
        self.check_transfer_acl(acl_setter)

    def test_get_query_zone_name(self):
        msg = self.getmsg()
        self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")

    def test_send_data(self):
        self.xfrsess._send_data(self.sock, self.mdata)
        senddata = self.sock.readsent()
        self.assertEqual(senddata, self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        msg = self.getmsg()
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")

        # tsig signed message
        msg = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_send_message(self):
        msg = self.getmsg()
        msg.make_response()
        # soa record data with different cases
        soa_record = (4, 3, 'Example.com.', 'com.Example.', 3600, 'SOA', None, 'master.Example.com. admin.exAmple.com. 1234 3600 1800 2419200 7200')
        rrset_soa = self.xfrsess._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self.xfrsess._send_message(self.sock, msg)
        send_out_data = self.sock.readsent()[2:]

        # CASE_INSENSITIVE compression mode
        render = MessageRenderer()
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertNotEqual(render.get_data(), send_out_data)

        # CASE_SENSITIVE compression mode
        render.clear()
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertEqual(render.get_data(), send_out_data)

    def test_clear_message(self):
        msg = self.getmsg()
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        self.xfrsess._clear_message(msg)
        self.assertEqual(msg.get_qid(), qid)
        self.assertEqual(msg.get_opcode(), opcode)
        self.assertEqual(msg.get_rcode(), rcode)
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))

    def test_create_rrset_from_db_record(self):
        rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(rrset.get_name().to_text(), "example.com.")
        self.assertEqual(rrset.get_class(), RRClass("IN"))
        self.assertEqual(rrset.get_type().to_text(), "SOA")
        rdata = rrset.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

    def test_send_message_with_last_soa(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # no TSIG context exists, so the message is not signed
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        # no TSIG context exists, so the message is not signed
        self.assertFalse(self.message_has_tsig(get_msg))

    def test_send_message_with_last_soa_with_tsig(self):
        # create tsig context
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        # msg is not the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_trigger_send_message_with_last_soa(self):
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)

        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_a)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "A")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), "192.0.2.1")

        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
        self.assertFalse(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_get_rrset_len(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(82, get_rrset_len(rrset_soa))

    def test_zone_has_soa(self):
        self._patch(sqlite3_ds, 'get_zone_soa', lambda zone, file: True)
        self.assertTrue(self.xfrsess._zone_has_soa(""))

        self._patch(sqlite3_ds, 'get_zone_soa', lambda zone, file: False)
        self.assertFalse(self.xfrsess._zone_has_soa(""))

    def test_zone_exist(self):
        # The fake zone_exist simply echoes its zone argument back
        self._patch(sqlite3_ds, 'zone_exist', lambda zone, file: zone)
        self.assertTrue(self.xfrsess._zone_exist(True))
        self.assertFalse(self.xfrsess._zone_exist(False))

    def test_check_xfrout_available(self):
        def zone_exist(zone):
            return zone
        def zone_has_soa(zone):
            return (not zone)
        self.xfrsess._zone_exist = zone_exist
        self.xfrsess._zone_has_soa = zone_has_soa
        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")

        def zone_empty(zone):
            return zone
        self.xfrsess._zone_has_soa = zone_empty

        def false_func():
            return False
        self.xfrsess._server.increase_transfers_counter = false_func
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")

        def true_func():
            return True
        self.xfrsess._server.increase_transfers_counter = true_func
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")

    def test_dns_xfrout_start_formerror(self):
        # formerror
        self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
        sent_data = self.sock.readsent()
        self.assertEqual(len(sent_data), 0)

    def default(self, param):
        """Stand-in for _get_query_zone_name that always names example.com."""
        return "example.com"

    def test_dns_xfrout_start_notauth(self):
        self.xfrsess._get_query_zone_name = self.default
        def notauth(formpara):
            return Rcode.NOTAUTH()
        self.xfrsess._check_xfrout_available = notauth
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")

    def test_dns_xfrout_start_noerror(self):
        self.xfrsess._get_query_zone_name = self.default
        def noerror(form):
            return Rcode.NOERROR()
        self.xfrsess._check_xfrout_available = noerror

        def myreply(msg, sock, zonename):
            self.sock.send(b"success")

        self.xfrsess._reply_xfrout_query = myreply
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), b"success")

    def test_reply_xfrout_query_noerror(self):
        self._patch(sqlite3_ds, 'get_zone_soa',
                    lambda zonename, file: self.soa_record)
        self._patch(sqlite3_ds, 'get_zone_datas',
                    lambda zone, file: [self.soa_record])

        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)

    def test_reply_xfrout_query_noerror_with_tsig(self):
        rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None, '192.168.1.1')
        self._patch(sqlite3_ds, 'get_zone_soa',
                    lambda zonename, file: self.soa_record)
        # The zone consists of 100 copies of the same A record
        self._patch(sqlite3_ds, 'get_zone_datas',
                    lambda zone, file: [rrset_data] * 100)
        # Report a huge length for every RRset so the transfer is split
        # into many messages. The cleanup restores the real function, which
        # _send_message_with_last_soa relies on in the other tests.
        self._patch(xfrout, 'get_rrset_len', lambda rrset: 65520)

        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")

        # tsig signed first packet
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
        for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # TSIG_SIGN_EVERY_NTH packet has tsig
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        for i in range(0, 100 - xfrout.TSIG_SIGN_EVERY_NTH):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # tsig signed last packet
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))
class MyCCSession:
    """Fake config session that answers get_remote_config_value() lookups."""

    def __init__(self):
        pass

    def get_remote_config_value(self, module_name, identifier):
        """Return (value, False).

        Only Auth's "database_file" is known, with value "initdb.file".
        Any other lookup yields "unknown".
        """
        wants_db_file = (module_name == "Auth"
                         and identifier == "database_file")
        value = "initdb.file" if wants_db_file else "unknown"
        return value, False
class MyUnixSockServer(UnixSockServer):
    """UnixSockServer built without running UnixSockServer.__init__.

    Only the attributes the tests need are set up (a shutdown event, a
    transfer limit of 10 and a fake config session) before _common_init()
    is called.
    """

    def __init__(self):
        # Attribute setup must precede _common_init(), as in the original.
        self._shutdown_event = threading.Event()
        self._max_transfers_out = 10
        self._cc = MyCCSession()
        self._common_init()
class TestUnixSockServer(unittest.TestCase):
    """Tests for UnixSockServer, driven through MyUnixSockServer."""

    def setUp(self):
        self.write_sock, self.read_sock = socket.socketpair()
        self.unix = MyUnixSockServer()

    def tearDown(self):
        # Release the socketpair created in setUp so descriptors don't leak
        # across the whole test run.
        self.write_sock.close()
        self.read_sock.close()

    def _check_guess_remote(self, family, address, expected):
        """Connect a UDP socket of family to address and check that
        _guess_remote() on its descriptor returns expected.

        The socket is always closed, even if the assertion fails.
        """
        sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            sock.connect(address)
            self.assertEqual(expected, self.unix._guess_remote(sock.fileno()))
        finally:
            sock.close()

    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint
        self._check_guess_remote(socket.AF_INET, ('127.0.0.1', 12345),
                                 ('127.0.0.1', 12345))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            self._check_guess_remote(socket.AF_INET6, ('::1', 12345),
                                     ('::1', 12345, 0, 0))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            xfrout.socket.has_ipv6 = False
            try:
                self._check_guess_remote(socket.AF_INET, ('127.0.0.1', 12345),
                                         ('127.0.0.1', 12345))
            finally:
                # Return it back, even if the check above failed
                xfrout.socket.has_ipv6 = True

    def test_receive_query_message(self):
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        # 2-byte length prefix in network byte order
        msg_len = struct.pack('!H', len(send_msg))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def _request_context(self, address):
        """Build an ACL RequestContext for a UDP request from address, port 1234."""
        return isc.acl.dns.RequestContext(
            socket.getaddrinfo(address, 1234, 0, socket.SOCK_DGRAM,
                               socket.IPPROTO_UDP,
                               socket.AI_NUMERICHOST)[0][4])

    def check_default_ACL(self):
        context = self._request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))

    def check_loaded_ACL(self, acl):
        context = self._request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, acl.execute(context))
        context = self._request_context("192.0.2.1")
        self.assertEqual(isc.acl.acl.REJECT, acl.execute(context))

    def test_update_config_data(self):
        self.check_default_ACL()
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out':10 })
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertIsNotNone(self.unix.tsig_key_ring)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out':9,
                                      'tsig_key_ring':tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key: the update must not raise, and the bad key must not
        # end up in the key ring. (The former assertRaises(None, <result>)
        # asserted nothing: the call had already run eagerly.)
        config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # Load the ACL
        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
                                                     'action': 'ACCEPT'}]})
        self.check_loaded_ACL(self.unix._acl)
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(isc.acl.acl.LoaderError,
                          self.unix.update_config_data,
                          {'query_acl': ['Something bad']})
        self.check_loaded_ACL(self.unix._acl)

    def test_zone_config_data(self):
        # By default, there's no specific zone config
        self.assertEqual({}, self.unix._zone_config)

        # Adding config for a specific zone. The config is empty unless
        # explicitly specified.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'class': 'IN'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class can be omitted
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class, name are stored in the "normalized" form. class
        # strings are upper cased, names are down cased.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'EXAMPLE.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # invalid zone class, name will result in exceptions
        self.assertRaises(EmptyLabel,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'bad..example'}]})
        self.assertRaises(InvalidRRClass,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'class': 'badclass'}]})

        # Configuring a couple of more zones
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'},
                                           {'origin': 'example.com',
                                            'class': 'CH'},
                                           {'origin': 'example.org'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])

        # Duplicate data: should be rejected with an exception
        self.assertRaises(ValueError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com'},
                                           {'origin': 'example.org'},
                                           {'origin': 'example.com'}]})

    def test_zone_config_data_with_acl(self):
        # Similar to the previous test, but with transfer_acl config
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'from': '127.0.0.1',
                                                  'action': 'ACCEPT'}]}]})
        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
        self.check_loaded_ACL(acl)

        # invalid ACL syntax will be rejected with exception
        self.assertRaises(isc.acl.acl.LoaderError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'action': 'BADACTION'}]}]})

    def test_get_db_file(self):
        self.assertEqual(self.unix.get_db_file(), "initdb.file")

    def test_increase_transfers_counter(self):
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        """Delete sock_file, ignoring OSError (e.g. when it does not exist)."""
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def _call_quietly(self, func, *args):
        """Return func(*args), run with sys.stdout redirected to os.devnull.

        sys.stdout is restored and the devnull handle closed even when func
        raises (including SystemExit).
        """
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                return func(*args)
            finally:
                sys.stdout = old_stdout

    def test_sock_file_in_use_file_exist(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self.assertFalse(os.path.exists(sock_file))

    def test_sock_file_in_use_file_not_exist(self):
        self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))

    def _start_unix_sock_server(self, sock_file):
        """Serve sock_file from a background daemon thread; return the server."""
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        serv_thread.daemon = True
        serv_thread.start()
        return serv

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        # Don't leave the socket file behind in the working directory
        self.addCleanup(self._remove_file, sock_file)
        self.assertTrue(self._call_quietly(self.unix._sock_file_in_use,
                                           sock_file))

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        self.addCleanup(self._remove_file, sock_file)
        # A socket file that is still in use must make it exit
        self.assertRaises(SystemExit, self._call_quietly,
                          self.unix._remove_unused_sock_file, sock_file)

    def test_remove_unused_sock_file_dir(self):
        import tempfile
        dir_name = tempfile.mkdtemp()
        # Remove the directory even if the assertion below fails
        self.addCleanup(os.rmdir, dir_name)
        self.assertRaises(SystemExit, self._call_quietly,
                          self.unix._remove_unused_sock_file, dir_name)
class TestInitialization(unittest.TestCase):
    """Checks how xfrout.init_paths() derives UNIX_SOCKET_FILE from the
    B10_FROM_BUILD and BIND10_XFROUT_SOCKET_FILE environment variables."""

    def setEnv(self, name, value):
        """Set environment variable name to value; a value of None unsets it."""
        if value is not None:
            os.environ[name] = value
        else:
            os.environ.pop(name, None)

    def setUp(self):
        # Remember the caller's environment so tearDown can restore it.
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")

    def tearDown(self):
        self.setEnv("B10_FROM_BUILD", self._oldFromBuild)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", self._oldSocket)
        # Make sure even the computed values are back
        xfrout.init_paths()

    def _init_paths_with(self, socket_file):
        """Re-run init_paths() with B10_FROM_BUILD unset and the socket
        variable set to socket_file (None unsets it)."""
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", socket_file)
        xfrout.init_paths()

    def testNoEnv(self):
        self._init_paths_with(None)
        self.assertEqual(xfrout.UNIX_SOCKET_FILE,
                         "@@LOCALSTATEDIR@@/auth_xfrout_conn")

    def testProvidedSocket(self):
        self._init_paths_with("The/Socket/File")
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")
def _run_tests():
    """Reset isc logging via resetUnitTestRootLogger(), then run the suite."""
    isc.log.resetUnitTestRootLogger()
    unittest.main()


if __name__ == "__main__":
    _run_tests()
|