# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes '''

# The standard-library modules used directly by this file are imported
# explicitly (and before the star imports, so any name the star imports
# rebind still wins exactly as it did before).
import os
import socket
import struct
import sys
import tempfile
import threading
import unittest
from isc.testutils.tsigctx_mock import MockTSIGContext
from isc.cc.session import *
from pydnspp import *
from xfrout import *
import xfrout

TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")


# our fake socket, where we can read and insert messages
class MySocket():
    '''In-memory stand-in for a socket.

    Everything passed to send() is appended to ``sendqueue``; the read
    helpers pop one length-prefixed message at a time from its front.
    '''

    def __init__(self, family, type):
        self.family = family
        self.type = type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass

    def close(self):
        pass

    def send(self, data):
        '''Queue ``data`` and report all of it as sent.'''
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        '''Pop and return the first queued message, length prefix included.

        The first two bytes are read as a big-endian unsigned length.
        An empty bytearray is returned when fewer than two bytes are queued.
        '''
        if len(self.sendqueue) >= 2:
            size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
        else:
            size = 0
        result = self.sendqueue[:size]
        self.sendqueue = self.sendqueue[size:]
        return result

    def read_msg(self):
        '''Pop the first queued message and parse it (minus the 2-byte
        length prefix) into a Message.'''
        sent_data = self.readsent()
        get_msg = Message(Message.PARSE)
        get_msg.from_wire(bytes(sent_data[2:]))
        return get_msg

    def clear_send(self):
        '''Drop everything still queued.'''
        del self.sendqueue[:]


# We subclass the Session class we're testing here, only
# to override the handle() and _send_data() method
class MyXfroutSession(XfroutSession):
    def handle(self):
        pass

    def _send_data(self, sock, data):
        '''Keep calling sock.send() until every byte of ``data`` is sent.'''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = sock.send(data[total_count:])
            total_count += count


class Dbserver:
    '''Minimal stand-in for the server object handed to the session.'''

    def __init__(self):
        self._shutdown_event = threading.Event()

    def get_db_file(self):
        return None

    def decrease_transfers_counter(self):
        pass


class TestXfroutSession(unittest.TestCase):
    '''Tests for XfroutSession, driven through MyXfroutSession/MySocket.'''

    def _patch_attr(self, owner, name, replacement):
        '''Replace ``owner.name`` for the rest of the current test only.

        The previous value is restored by a cleanup, so a fake installed
        on a shared module object does not leak into other tests.
        '''
        original = getattr(owner, name)
        self.addCleanup(setattr, owner, name, original)
        setattr(owner, name, replacement)

    def check_rr_counts(self, msg, question, answer, authority):
        '''Assert the RR counts of the question/answer/authority sections.'''
        self.assertEqual(msg.get_rr_count(Message.SECTION_QUESTION), question)
        self.assertEqual(msg.get_rr_count(Message.SECTION_ANSWER), answer)
        self.assertEqual(msg.get_rr_count(Message.SECTION_AUTHORITY),
                         authority)

    def check_first_answer(self, msg, type_text, rdata_text):
        '''Assert that the first answer RRset of ``msg`` is an IN-class
        RRset owned by example.com. with the given type and first rdata.'''
        answer = msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), type_text)
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), rdata_text)

    def getmsg(self):
        '''Return a fresh Message parsed from the canned query in mdata.'''
        msg = Message(Message.PARSE)
        msg.from_wire(self.mdata)
        return msg

    def create_mock_tsig_ctx(self, error):
        # This helper function creates a MockTSIGContext for a given key
        # and TSIG error to be used as a result of verify (normally faked
        # one)
        mock_ctx = MockTSIGContext(TSIG_KEY)
        mock_ctx.error = error
        return mock_ctx

    def message_has_tsig(self, msg):
        return msg.get_tsig_record() is not None

    def create_request_data_with_tsig(self):
        '''Render an AXFR query for example.com. through a MockTSIGContext
        and return the rendered wire data.'''
        msg = Message(Message.RENDER)
        query_id = 0x1035
        msg.set_qid(query_id)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        query_question = Question(Name("example.com."), RRClass.IN(),
                                  RRType.AXFR())
        msg.add_question(query_question)

        renderer = MessageRenderer()
        tsig_ctx = MockTSIGContext(TSIG_KEY)
        msg.to_wire(renderer, tsig_ctx)
        reply_data = renderer.get_data()
        return reply_data

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                       TSIGKeyRing())
        # Canned wire-format query for example.com. (qtype 0xfc, class IN).
        self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00'
                           b'\x07example\x03com\x00\x00\xfc\x00\x01')
        self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA',
                           None,
                           'master.example.com. admin.example.com. '
                           '1234 3600 1800 2419200 7200')

    def test_parse_query_message(self):
        [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(get_rcode.to_text(), "NOERROR")

        # tsig signed query message
        request_data = self.create_request_data_with_tsig()
        # BADKEY
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)
        # NOERROR
        self.xfrsess._tsig_key_ring.add(TSIG_KEY)
        [rcode, msg] = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

    def test_get_query_zone_name(self):
        msg = self.getmsg()
        self.assertEqual(self.xfrsess._get_query_zone_name(msg),
                         "example.com.")

    def test_send_data(self):
        self.xfrsess._send_data(self.sock, self.mdata)
        senddata = self.sock.readsent()
        self.assertEqual(senddata, self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        msg = self.getmsg()
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")

        # tsig signed message
        msg = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_send_message(self):
        msg = self.getmsg()
        msg.make_response()
        # soa record data with different cases
        soa_record = (4, 3, 'Example.com.', 'com.Example.', 3600, 'SOA', None,
                      'master.Example.com. admin.exAmple.com. '
                      '1234 3600 1800 2419200 7200')
        rrset_soa = self.xfrsess._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self.xfrsess._send_message(self.sock, msg)
        send_out_data = self.sock.readsent()[2:]

        # CASE_INSENSITIVE compression mode
        render = MessageRenderer()
        render.set_length_limit(xfrout.XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertNotEqual(render.get_data(), send_out_data)

        # CASE_SENSITIVE compression mode
        render.clear()
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(xfrout.XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertEqual(render.get_data(), send_out_data)

    def test_clear_message(self):
        msg = self.getmsg()
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        self.xfrsess._clear_message(msg)
        self.assertEqual(msg.get_qid(), qid)
        self.assertEqual(msg.get_opcode(), opcode)
        self.assertEqual(msg.get_rcode(), rcode)
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))

    def test_reply_query_with_format_error(self):
        msg = self.getmsg()
        self.xfrsess._reply_query_with_format_error(msg, self.sock)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")

        # tsig signed message
        msg = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_format_error(msg, self.sock)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_create_rrset_from_db_record(self):
        rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(rrset.get_name().to_text(), "example.com.")
        self.assertEqual(rrset.get_class(), RRClass("IN"))
        self.assertEqual(rrset.get_type().to_text(), "SOA")
        rdata = rrset.get_rdata()
        self.assertEqual(rdata[0].to_text(), self.soa_record[7])

    def test_send_message_with_last_soa(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # tsig context is not exist
        self.assertFalse(self.message_has_tsig(get_msg))

        self.check_rr_counts(get_msg, 1, 1, 0)
        self.check_first_answer(get_msg, "SOA", self.soa_record[7])

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        # tsig context is not exist
        self.assertFalse(self.message_has_tsig(get_msg))

    def test_send_message_with_last_soa_with_tsig(self):
        # create tsig context
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        # msg is not the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0, packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

        self.check_rr_counts(get_msg, 1, 1, 0)

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 0,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_trigger_send_message_with_last_soa(self):
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(),
                        RRTTL(3600))
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)

        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_a)

        # length larger than MAX-len(rrset)
        length_need_split = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                             get_rrset_len(rrset_soa) + 1)
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.check_rr_counts(get_msg, 1, 1, 0)
        self.check_first_answer(get_msg, "A", "192.0.2.1")

        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.check_rr_counts(get_msg, 0, 1, 0)
        self.check_first_answer(get_msg, "SOA", self.soa_record[7])

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        # length larger than MAX-len(rrset)
        length_need_split = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                             get_rrset_len(rrset_soa) + 1)
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
        self.assertFalse(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa,
                                                 length_need_split,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_get_rrset_len(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(82, get_rrset_len(rrset_soa))

    def test_zone_has_soa(self):
        def mydb1(zone, file):
            return True
        self._patch_attr(sqlite3_ds, 'get_zone_soa', mydb1)
        self.assertTrue(self.xfrsess._zone_has_soa(""))

        def mydb2(zone, file):
            return False
        self._patch_attr(sqlite3_ds, 'get_zone_soa', mydb2)
        self.assertFalse(self.xfrsess._zone_has_soa(""))

    def test_zone_exist(self):
        def zone_exist(zone, file):
            return zone
        self._patch_attr(sqlite3_ds, 'zone_exist', zone_exist)
        self.assertTrue(self.xfrsess._zone_exist(True))
        self.assertFalse(self.xfrsess._zone_exist(False))

    def test_check_xfrout_available(self):
        def zone_exist(zone):
            return zone

        def zone_has_soa(zone):
            return (not zone)

        self.xfrsess._zone_exist = zone_exist
        self.xfrsess._zone_has_soa = zone_has_soa
        self.assertEqual(
            self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
        self.assertEqual(
            self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")

        def zone_empty(zone):
            return zone

        self.xfrsess._zone_has_soa = zone_empty

        def false_func():
            return False

        self.xfrsess._server.increase_transfers_counter = false_func
        self.assertEqual(
            self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")

        def true_func():
            return True

        self.xfrsess._server.increase_transfers_counter = true_func
        self.assertEqual(
            self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")

    def test_dns_xfrout_start_formerror(self):
        # formerror
        self.xfrsess.dns_xfrout_start(self.sock,
                                      b"\xd6=\x00\x00\x00\x01\x00")
        sent_data = self.sock.readsent()
        self.assertEqual(len(sent_data), 0)

    def default(self, param):
        '''Fake _get_query_zone_name(): always answers "example.com".'''
        return "example.com"

    def test_dns_xfrout_start_notauth(self):
        self.xfrsess._get_query_zone_name = self.default

        def notauth(formpara):
            return Rcode.NOTAUTH()

        self.xfrsess._check_xfrout_available = notauth
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")

    def test_dns_xfrout_start_noerror(self):
        self.xfrsess._get_query_zone_name = self.default

        def noerror(form):
            return Rcode.NOERROR()

        self.xfrsess._check_xfrout_available = noerror

        def myreply(msg, sock, zonename):
            self.sock.send(b"success")

        self.xfrsess._reply_xfrout_query = myreply
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), b"success")

    def test_reply_xfrout_query_noerror(self):
        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            return [self.soa_record]

        self._patch_attr(sqlite3_ds, 'get_zone_soa', get_zone_soa)
        self._patch_attr(sqlite3_ds, 'get_zone_datas', get_zone_datas)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock,
                                         "example.com.")
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)

    def test_reply_xfrout_query_noerror_with_tsig(self):
        rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None,
                      '192.168.1.1')

        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            # 100 identical A records
            return [rrset_data] * 100

        def fake_rrset_len(rrset):
            # Fixed large length reported for every rrset in this test.
            return 65520

        self._patch_attr(sqlite3_ds, 'get_zone_soa', get_zone_soa)
        self._patch_attr(sqlite3_ds, 'get_zone_datas', get_zone_datas)
        self._patch_attr(xfrout, 'get_rrset_len', fake_rrset_len)

        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock,
                                         "example.com.")

        # tsig signed first package
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
        for i in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # TSIG_SIGN_EVERY_NTH packet has tsig
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        for i in range(0, 100 - xfrout.TSIG_SIGN_EVERY_NTH):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # tsig signed last package
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))


class MyCCSession():
    '''Fake config session that answers a single remote config lookup.'''

    def __init__(self):
        pass

    def get_remote_config_value(self, module_name, identifier):
        if module_name == "Auth" and identifier == "database_file":
            return "initdb.file", False
        else:
            return "unknown", False


class MyUnixSockServer(UnixSockServer):
    '''UnixSockServer whose constructor only sets the attributes assigned
    below, instead of running the parent class's __init__.'''

    def __init__(self):
        self._lock = threading.Lock()
        self._transfers_counter = 0
        self._shutdown_event = threading.Event()
        self._max_transfers_out = 10
        self._cc = MyCCSession()


class TestUnixSockServer(unittest.TestCase):
    '''Tests for UnixSockServer, driven through MyUnixSockServer.'''

    def setUp(self):
        self.write_sock, self.read_sock = socket.socketpair()
        self.unix = MyUnixSockServer()

    def tearDown(self):
        # Release the socket pair created in setUp().
        self.write_sock.close()
        self.read_sock.close()

    def _call_with_stdout_suppressed(self, func, *args):
        '''Return func(*args), with sys.stdout pointed at os.devnull for
        the duration of the call.

        sys.stdout is restored and the devnull handle closed even when
        func raises.
        '''
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                return func(*args)
            finally:
                sys.stdout = old_stdout

    def test_receive_query_message(self):
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        msg_len = struct.pack('H', socket.htons(len(send_msg)))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def test_updata_config_data(self):
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out': 10})
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertIsNotNone(self.unix.tsig_key_ring)

        self.unix.update_config_data({'transfers_out': 9,
                                      'tsig_key_ring': tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key
        config_data = {'transfers_out': 9, 'tsig_key_ring': bad_key_list}
        # The update itself is expected to complete without raising; the
        # assertion below then checks that no key was added.  (Calling it
        # directly replaces an assertRaises(None, <call result>) that
        # evaluated the call eagerly and named no exception type.)
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

    def test_get_db_file(self):
        self.assertEqual(self.unix.get_db_file(), "initdb.file")

    def test_increase_transfers_counter(self):
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        '''Delete sock_file, ignoring OSError (e.g. when it is absent).'''
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def test_sock_file_in_use_file_exist(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self.assertFalse(os.path.exists(sock_file))

    def test_sock_file_in_use_file_not_exist(self):
        self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))

    def _start_unix_sock_server(self, sock_file):
        '''Serve on sock_file from a background daemon thread.'''
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        serv_thread.daemon = True
        serv_thread.start()

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        self.assertTrue(self._call_with_stdout_suppressed(
            self.unix._sock_file_in_use, sock_file))

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        # A socket file with a live server behind it must not be removed;
        # the call is expected to exit instead.
        self.assertRaises(SystemExit, self._call_with_stdout_suppressed,
                          self.unix._remove_unused_sock_file, sock_file)

    def test_remove_unused_sock_file_dir(self):
        dir_name = tempfile.mkdtemp()
        try:
            # Passing a directory is expected to make the call exit.
            self.assertRaises(SystemExit, self._call_with_stdout_suppressed,
                              self.unix._remove_unused_sock_file, dir_name)
        finally:
            os.rmdir(dir_name)


class TestInitialization(unittest.TestCase):
    '''Tests for how xfrout.init_paths() reacts to the environment.'''

    def setEnv(self, name, value):
        '''Set environment variable ``name``, or unset it if value is None.'''
        if value is None:
            if name in os.environ:
                del os.environ[name]
        else:
            os.environ[name] = value

    def setUp(self):
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")

    def tearDown(self):
        self.setEnv("B10_FROM_BUILD", self._oldFromBuild)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", self._oldSocket)
        # Make sure even the computed values are back
        xfrout.init_paths()

    def testNoEnv(self):
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", None)
        xfrout.init_paths()
        self.assertEqual(xfrout.UNIX_SOCKET_FILE,
                         "@@LOCALSTATEDIR@@/auth_xfrout_conn")

    def testProvidedSocket(self):
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", "The/Socket/File")
        xfrout.init_paths()
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")


if __name__ == "__main__":
    unittest.main()