# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes
'''

import os
import socket
import struct
import sys
import tempfile
import threading
import unittest
from isc.testutils.tsigctx_mock import MockTSIGContext
from isc.cc.session import *
from pydnspp import *
from xfrout import *
import xfrout
import isc.log
import isc.acl.dns

TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")


# our fake socket, where we can read and insert messages
class MySocket():
    """In-memory stand-in for a stream socket.

    Everything passed to send() is appended to ``sendqueue`` so tests can
    read it back with readsent()/read_msg().
    """

    def __init__(self, family, type):
        self.family = family
        self.type = type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass

    def close(self):
        pass

    def send(self, data):
        """Queue all of ``data`` and report it as fully sent."""
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        """Pop one length-prefixed chunk from the send queue.

        The first two queued bytes are read as a big-endian length; that
        prefix plus that many following bytes are removed and returned.
        With fewer than two bytes queued, an empty bytearray is returned.
        """
        if len(self.sendqueue) >= 2:
            size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
        else:
            size = 0
        result = self.sendqueue[:size]
        self.sendqueue = self.sendqueue[size:]
        return result

    def read_msg(self):
        """Pop the next sent chunk and parse it (minus the 2-byte length
        prefix) into a DNS Message."""
        sent_data = self.readsent()
        get_msg = Message(Message.PARSE)
        get_msg.from_wire(bytes(sent_data[2:]))
        return get_msg

    def clear_send(self):
        del self.sendqueue[:]


# We subclass the Session class we're testing here, only
# to override the handle() and _send_data() method
class MyXfroutSession(XfroutSession):
    def handle(self):
        pass

    def _send_data(self, sock, data):
        # Keep calling send() until every byte of data was accepted.
        size = len(data)
        total_count = 0
        while total_count < size:
            count = sock.send(data[total_count:])
            total_count += count


class Dbserver:
    """Minimal fake of the server object handed to XfroutSession."""

    def __init__(self):
        self._shutdown_event = threading.Event()

    def get_db_file(self):
        return None

    def decrease_transfers_counter(self):
        pass


class TestXfroutSession(unittest.TestCase):
    def getmsg(self):
        msg = Message(Message.PARSE)
        msg.from_wire(self.mdata)
        return msg

    def create_mock_tsig_ctx(self, error):
        # This helper function creates a MockTSIGContext for a given key
        # and TSIG error to be used as a result of verify (normally faked
        # one)
        mock_ctx = MockTSIGContext(TSIG_KEY)
        mock_ctx.error = error
        return mock_ctx

    def message_has_tsig(self, msg):
        return msg.get_tsig_record() is not None

    def create_request_data(self, with_tsig=False):
        """Render an AXFR query for example.com, optionally TSIG-signed
        with a MockTSIGContext for TSIG_KEY, and return its wire data."""
        msg = Message(Message.RENDER)
        query_id = 0x1035
        msg.set_qid(query_id)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        query_question = Question(Name("example.com"), RRClass.IN(),
                                  RRType.AXFR())
        msg.add_question(query_question)

        renderer = MessageRenderer()
        if with_tsig:
            tsig_ctx = MockTSIGContext(TSIG_KEY)
            msg.to_wire(renderer, tsig_ctx)
        else:
            msg.to_wire(renderer)
        request_data = renderer.get_data()
        return request_data

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                       TSIGKeyRing(), ('127.0.0.1', 12345),
                                       # When not testing ACLs, simply accept
                                       isc.acl.dns.REQUEST_LOADER.load(
                                           [{"action": "ACCEPT"}]))
        self.mdata = self.create_request_data(False)
        self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA',
                           None,
                           'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')

    def _patch(self, obj, name, value):
        """Set ``obj.name`` to ``value`` for the rest of the current test.

        The previous value is restored by a cleanup, so fake module-level
        functions (sqlite3_ds.*, xfrout.get_rrset_len) cannot leak into
        tests that run later.
        """
        self.addCleanup(setattr, obj, name, getattr(obj, name))
        setattr(obj, name, value)

    def test_parse_query_message(self):
        [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(get_rcode.to_text(), "NOERROR")

        # tsig signed query message
        request_data = self.create_request_data(True)
        # BADKEY
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertTrue(self.xfrsess._tsig_ctx is not None)
        # NOERROR
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertTrue(self.xfrsess._tsig_ctx is not None)

    def check_transfer_acl(self, acl_setter):
        """Run the ACL scenarios; ``acl_setter(acl)`` installs each ACL."""
        # ACL checks, put some ACL inside
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {
                "from": "127.0.0.1",
                "action": "ACCEPT"
            },
            {
                "from": "192.0.2.1",
                "action": "DROP"
            }
        ]))
        # Localhost (the default in this test) is accepted
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # This should be dropped completely, therefore returning None
        self.xfrsess._remote = ('192.0.2.1', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(None, rcode)
        # This should be refused, therefore REFUSED
        self.xfrsess._remote = ('192.0.2.2', 12345)
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # TSIG signed request
        request_data = self.create_request_data(True)

        # If the TSIG check fails, it should not check ACL
        # (If it checked ACL as well, it would just drop the request)
        self.xfrsess._remote = ('192.0.2.1', 12345)
        self.xfrsess._tsig_key_ring = TSIGKeyRing()
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertTrue(self.xfrsess._tsig_ctx is not None)

        # ACL using TSIG: successful case
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # ACL using TSIG: key name doesn't match; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # ACL using TSIG: no TSIG; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        #
        # ACL using IP + TSIG: both should match
        #
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
             "action": "ACCEPT"},
            {"action": "REJECT"}
        ]))
        # both matches
        self.xfrsess._remote = ('192.0.2.1', 12345)
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # TSIG matches, but address doesn't
        self.xfrsess._remote = ('192.0.2.2', 12345)
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Address matches, but TSIG doesn't (not included)
        self.xfrsess._remote = ('192.0.2.1', 12345)
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Neither address nor TSIG matches
        self.xfrsess._remote = ('192.0.2.2', 12345)
        [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

    def test_transfer_acl(self):
        def acl_setter(acl):
            self.xfrsess._acl = acl
        self.check_transfer_acl(acl_setter)

    def test_get_query_zone_name(self):
        msg = self.getmsg()
        self.assertEqual(self.xfrsess._get_query_zone_name(msg),
                         "example.com.")

    def test_send_data(self):
        self.xfrsess._send_data(self.sock, self.mdata)
        senddata = self.sock.readsent()
        self.assertEqual(senddata, self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        msg = self.getmsg()
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")

        # tsig signed message
        msg = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_send_message(self):
        msg = self.getmsg()
        msg.make_response()
        # soa record data with different cases
        soa_record = (4, 3, 'Example.com.', 'com.Example.', 3600, 'SOA', None,
                      'master.Example.com. admin.exAmple.com. 1234 3600 1800 2419200 7200')
        rrset_soa = self.xfrsess._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self.xfrsess._send_message(self.sock, msg)
        send_out_data = self.sock.readsent()[2:]

        # CASE_INSENSITIVE compression mode
        render = MessageRenderer()
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertNotEqual(render.get_data(), send_out_data)

        # CASE_SENSITIVE compression mode
        render.clear()
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertEqual(render.get_data(), send_out_data)

    def test_clear_message(self):
        msg = self.getmsg()
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        self.xfrsess._clear_message(msg)
        self.assertEqual(msg.get_qid(), qid)
        self.assertEqual(msg.get_opcode(), opcode)
        self.assertEqual(msg.get_rcode(), rcode)
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))

    def test_create_rrset_from_db_record(self):
        rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(rrset.get_name().to_text(), "example.com.")
        self.assertEqual(rrset.get_class(), RRClass("IN"))
        self.assertEqual(rrset.get_type().to_text(), "SOA")
        rdata = rrset.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

    def test_send_message_with_last_soa(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # no TSIG context exists, so nothing is signed
        self.assertFalse(self.message_has_tsig(get_msg))

        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        # no TSIG context exists, so nothing is signed
        self.assertFalse(self.message_has_tsig(get_msg))

    def test_send_message_with_last_soa_with_tsig(self):
        # create tsig context
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        # msg is not the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_trigger_send_message_with_last_soa(self):
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(),
                        RRTTL(3600))
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)

        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_a)

        # length larger than MAX-len(rrset)
        length_need_split = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                             get_rrset_len(rrset_soa) + 1)
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "A")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), "192.0.2.1")

        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        # length larger than MAX-len(rrset)
        length_need_split = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                             get_rrset_len(rrset_soa) + 1)
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
        self.assertFalse(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_get_rrset_len(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(82, get_rrset_len(rrset_soa))

    def test_zone_has_soa(self):
        def mydb1(zone, file):
            return True
        self._patch(sqlite3_ds, 'get_zone_soa', mydb1)
        self.assertTrue(self.xfrsess._zone_has_soa(""))

        def mydb2(zone, file):
            return False
        self._patch(sqlite3_ds, 'get_zone_soa', mydb2)
        self.assertFalse(self.xfrsess._zone_has_soa(""))

    def test_zone_exist(self):
        def zone_exist(zone, file):
            return zone
        self._patch(sqlite3_ds, 'zone_exist', zone_exist)
        self.assertTrue(self.xfrsess._zone_exist(True))
        self.assertFalse(self.xfrsess._zone_exist(False))

    def test_check_xfrout_available(self):
        def zone_exist(zone):
            return zone

        def zone_has_soa(zone):
            return (not zone)
        self.xfrsess._zone_exist = zone_exist
        self.xfrsess._zone_has_soa = zone_has_soa
        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(),
                         "NOTAUTH")
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(),
                         "SERVFAIL")

        def zone_empty(zone):
            return zone
        self.xfrsess._zone_has_soa = zone_empty

        def false_func():
            return False
        self.xfrsess._server.increase_transfers_counter = false_func
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(),
                         "REFUSED")

        def true_func():
            return True
        self.xfrsess._server.increase_transfers_counter = true_func
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(),
                         "NOERROR")

    def test_dns_xfrout_start_formerror(self):
        # formerror
        self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
        sent_data = self.sock.readsent()
        self.assertEqual(len(sent_data), 0)

    def default(self, param):
        return "example.com"

    def test_dns_xfrout_start_notauth(self):
        self.xfrsess._get_query_zone_name = self.default

        def notauth(formpara):
            return Rcode.NOTAUTH()
        self.xfrsess._check_xfrout_available = notauth
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")

    def test_dns_xfrout_start_noerror(self):
        self.xfrsess._get_query_zone_name = self.default

        def noerror(form):
            return Rcode.NOERROR()
        self.xfrsess._check_xfrout_available = noerror

        def myreply(msg, sock, zonename):
            self.sock.send(b"success")

        self.xfrsess._reply_xfrout_query = myreply
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), b"success")

    def test_reply_xfrout_query_noerror(self):
        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            return [self.soa_record]

        self._patch(sqlite3_ds, 'get_zone_soa', get_zone_soa)
        self._patch(sqlite3_ds, 'get_zone_datas', get_zone_datas)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock,
                                         "example.com.")
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)

    def test_reply_xfrout_query_noerror_with_tsig(self):
        rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None,
                      '192.168.1.1')

        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            # 100 copies of the same A record
            return [rrset_data] * 100

        def fake_get_rrset_len(rrset):
            # Pretend every RRset is huge so each one goes in its own packet
            return 65520

        self._patch(sqlite3_ds, 'get_zone_soa', get_zone_soa)
        self._patch(sqlite3_ds, 'get_zone_datas', get_zone_datas)
        self._patch(xfrout, 'get_rrset_len', fake_get_rrset_len)

        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock,
                                         "example.com.")

        # tsig signed first package
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
        for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # TSIG_SIGN_EVERY_NTH packet has tsig
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        for i in range(0, 100 - xfrout.TSIG_SIGN_EVERY_NTH):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))

        # tsig signed last package
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))


class MyCCSession():
    """Fake config session answering get_remote_config_value()."""

    def __init__(self):
        pass

    def get_remote_config_value(self, module_name, identifier):
        if module_name == "Auth" and identifier == "database_file":
            return "initdb.file", False
        else:
            return "unknown", False


class MyUnixSockServer(UnixSockServer):
    """UnixSockServer that skips the real constructor and uses MyCCSession."""

    def __init__(self):
        self._shutdown_event = threading.Event()
        self._max_transfers_out = 10
        self._cc = MyCCSession()
        self._common_init()


class TestUnixSockServer(unittest.TestCase):
    def setUp(self):
        self.write_sock, self.read_sock = socket.socketpair()
        # Close both ends of the pair once the test is over.
        self.addCleanup(self.write_sock.close)
        self.addCleanup(self.read_sock.close)
        self.unix = MyUnixSockServer()

    def _silence_stdout(self):
        """Send sys.stdout to os.devnull until the end of the current test.

        The original stream is restored and the devnull handle closed by
        cleanups, even when an assertion in the test fails.
        """
        old_stdout = sys.stdout
        devnull = open(os.devnull, 'w')
        # Cleanups run last-in-first-out: restore stdout, then close devnull.
        self.addCleanup(devnull.close)
        self.addCleanup(setattr, sys, 'stdout', old_stdout)
        sys.stdout = devnull

    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.addCleanup(sock.close)
        sock.connect(('127.0.0.1', 12345))
        self.assertEqual(('127.0.0.1', 12345),
                         self.unix._guess_remote(sock.fileno()))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            sock6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            self.addCleanup(sock6.close)
            sock6.connect(('::1', 12345))
            self.assertEqual(('::1', 12345, 0, 0),
                             self.unix._guess_remote(sock6.fileno()))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            xfrout.socket.has_ipv6 = False
            try:
                sock4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                self.addCleanup(sock4.close)
                sock4.connect(('127.0.0.1', 12345))
                self.assertEqual(('127.0.0.1', 12345),
                                 self.unix._guess_remote(sock4.fileno()))
            finally:
                # Return it back, even if the assertion above failed
                xfrout.socket.has_ipv6 = True

    def test_receive_query_message(self):
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        msg_len = struct.pack('H', socket.htons(len(send_msg)))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def check_default_ACL(self):
        context = isc.acl.dns.RequestContext(socket.getaddrinfo(
            "127.0.0.1", 1234, 0, socket.SOCK_DGRAM, socket.IPPROTO_UDP,
            socket.AI_NUMERICHOST)[0][4])
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))

    def check_loaded_ACL(self, acl):
        context = isc.acl.dns.RequestContext(socket.getaddrinfo(
            "127.0.0.1", 1234, 0, socket.SOCK_DGRAM, socket.IPPROTO_UDP,
            socket.AI_NUMERICHOST)[0][4])
        self.assertEqual(isc.acl.acl.ACCEPT, acl.execute(context))
        context = isc.acl.dns.RequestContext(socket.getaddrinfo(
            "192.0.2.1", 1234, 0, socket.SOCK_DGRAM, socket.IPPROTO_UDP,
            socket.AI_NUMERICHOST)[0][4])
        self.assertEqual(isc.acl.acl.REJECT, acl.execute(context))

    def test_update_config_data(self):
        self.check_default_ACL()
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out': 10})
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertTrue(self.unix.tsig_key_ring is not None)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out': 9,
                                      'tsig_key_ring': tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key: the update itself must not raise (an exception
        # here fails the test), and the bad key must not be added.
        # (Previously written as assertRaises(None, <result>), which
        # asserted nothing and is a TypeError on current unittest.)
        config_data = {'transfers_out': 9, 'tsig_key_ring': bad_key_list}
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # Load the ACL
        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
                                                     'action': 'ACCEPT'}]})
        self.check_loaded_ACL(self.unix._acl)
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(isc.acl.acl.LoaderError,
                          self.unix.update_config_data,
                          {'query_acl': ['Something bad']})
        self.check_loaded_ACL(self.unix._acl)

    def test_zone_config_data(self):
        # By default, there's no specific zone config
        self.assertEqual({}, self.unix._zone_config)

        # Adding config for a specific zone.  The config is empty unless
        # explicitly specified.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'class': 'IN'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class can be omitted
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class, name are stored in the "normalized" form.  class
        # strings are upper cased, names are down cased.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'EXAMPLE.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # invalid zone class, name will result in exceptions
        self.assertRaises(EmptyLabel,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'bad..example'}]})
        self.assertRaises(InvalidRRClass,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'class': 'badclass'}]})

        # Configuring a couple of more zones
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'},
                                           {'origin': 'example.com',
                                            'class': 'CH'},
                                           {'origin': 'example.org'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])

        # Duplicate data: should be rejected with an exception
        self.assertRaises(ValueError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com'},
                                           {'origin': 'example.org'},
                                           {'origin': 'example.com'}]})

    def test_zone_config_data_with_acl(self):
        # Similar to the previous test, but with transfer_acl config
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'from': '127.0.0.1',
                                                  'action': 'ACCEPT'}]}]})
        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
        self.check_loaded_ACL(acl)

        # invalid ACL syntax will be rejected with exception
        self.assertRaises(isc.acl.acl.LoaderError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'action': 'BADACTION'}]}]})

    def test_get_db_file(self):
        self.assertEqual(self.unix.get_db_file(), "initdb.file")

    def test_increase_transfers_counter(self):
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        """Delete sock_file, ignoring OSError (e.g. when it is absent)."""
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def test_sock_file_in_use_file_exist(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self.assertFalse(os.path.exists(sock_file))

    def test_sock_file_in_use_file_not_exist(self):
        self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))

    def _start_unix_sock_server(self, sock_file):
        """Serve on sock_file from a daemon thread for the current test.

        Cleanups stop the server, close its socket and remove sock_file
        when the test ends, so nothing leaks into later tests.
        """
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        serv_thread.daemon = True
        serv_thread.start()
        # Cleanups run last-in-first-out: shutdown, close, then unlink.
        self.addCleanup(self._remove_file, sock_file)
        self.addCleanup(serv.server_close)
        self.addCleanup(serv.shutdown)

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        self._silence_stdout()
        self.assertTrue(self.unix._sock_file_in_use(sock_file))

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        self._silence_stdout()
        # The file is in use, so the server must refuse to continue.
        self.assertRaises(SystemExit,
                          self.unix._remove_unused_sock_file, sock_file)

    def test_remove_unused_sock_file_dir(self):
        dir_name = tempfile.mkdtemp()
        self.addCleanup(os.rmdir, dir_name)

        self._silence_stdout()
        # A directory is not a usable socket file; expect an exit.
        self.assertRaises(SystemExit,
                          self.unix._remove_unused_sock_file, dir_name)


class TestInitialization(unittest.TestCase):
    def setEnv(self, name, value):
        """Set environment variable name to value, or unset it if None."""
        if value is None:
            if name in os.environ:
                del os.environ[name]
        else:
            os.environ[name] = value

    def setUp(self):
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")

    def tearDown(self):
        self.setEnv("B10_FROM_BUILD", self._oldFromBuild)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", self._oldSocket)
        # Make sure even the computed values are back
        xfrout.init_paths()

    def testNoEnv(self):
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", None)
        xfrout.init_paths()
        self.assertEqual(xfrout.UNIX_SOCKET_FILE,
                         "@@LOCALSTATEDIR@@/auth_xfrout_conn")

    def testProvidedSocket(self):
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", "The/Socket/File")
        xfrout.init_paths()
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")


if __name__ == "__main__":
    isc.log.resetUnitTestRootLogger()
    unittest.main()