|
@@ -0,0 +1,283 @@
|
|
|
+# Copyright (C) 2010 Internet Systems Consortium.
|
|
|
+#
|
|
|
+# Permission to use, copy, modify, and distribute this software for any
|
|
|
+# purpose with or without fee is hereby granted, provided that the above
|
|
|
+# copyright notice and this permission notice appear in all copies.
|
|
|
+#
|
|
|
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
|
|
|
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
|
|
|
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
|
|
|
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
|
|
|
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
|
|
|
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
|
|
|
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
|
|
|
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
|
+
|
|
|
+'''Tests for the XfroutSession and UnixSockServer classes '''
|
|
|
+
|
|
|
+
|
|
|
+import unittest
|
|
|
+import os
|
|
|
+from isc.cc.session import *
|
|
|
+from bind10_dns import *
|
|
|
+from xfrout import *
|
|
|
+
|
|
|
+# our fake socket, where we can read and insert messages
|
|
|
+class MySocket():
|
|
|
+ def __init__(self, family, type):
|
|
|
+ self.family = family
|
|
|
+ self.type = type
|
|
|
+ self.sendqueue = bytearray()
|
|
|
+
|
|
|
+ def connect(self, to):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def close(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+ def send(self, data):
|
|
|
+ self.sendqueue.extend(data);
|
|
|
+ return len(data)
|
|
|
+
|
|
|
+ def readsent(self):
|
|
|
+ result = self.sendqueue[:]
|
|
|
+ del self.sendqueue[:]
|
|
|
+ return result
|
|
|
+
|
|
|
+ def read_msg(self):
|
|
|
+ sent_data = self.readsent()
|
|
|
+ get_msg = message(message_mode.PARSE)
|
|
|
+ get_msg.from_wire(input_buffer(bytes(sent_data[2:])))
|
|
|
+ return get_msg
|
|
|
+
|
|
|
+ def clear_send(self):
|
|
|
+ del self.sendqueue[:]
|
|
|
+
|
|
|
+# We subclass the Session class we're testing here, only
|
|
|
+# to override the __init__() method, which wants a socket,
|
|
|
+class MyXfroutSession(XfroutSession):
|
|
|
+ def handle(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+class Dbserver:
|
|
|
+ def __init__(self):
|
|
|
+ self._shutdown_event = threading.Event()
|
|
|
+ def get_db_file(self):
|
|
|
+ return None
|
|
|
+ def decrease_transfers_counter(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+class TestXfroutSession(unittest.TestCase):
|
|
|
+ def getmsg(self):
|
|
|
+ msg = message(message_mode.PARSE)
|
|
|
+ msg.from_wire(input_buffer(self.mdata))
|
|
|
+ return msg
|
|
|
+
|
|
|
+ def setUp(self):
|
|
|
+ request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
|
|
|
+ self.xfrsess = MyXfroutSession(request, None, None)
|
|
|
+ self.xfrsess.server = Dbserver()
|
|
|
+ self.mdata = b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01'
|
|
|
+ self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
|
|
|
+ self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
|
|
|
+
|
|
|
+ def test_parse_query_message(self):
|
|
|
+ [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
|
|
|
+ self.assertEqual(get_rcode.to_text(), "NOERROR")
|
|
|
+
|
|
|
+ def test_get_query_zone_name(self):
|
|
|
+ msg = self.getmsg()
|
|
|
+ self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
|
|
|
+
|
|
|
+ def test_send_data(self):
|
|
|
+ self.xfrsess._send_data(self.sock, self.mdata)
|
|
|
+ senddata = self.sock.readsent()
|
|
|
+ self.assertEqual(senddata, self.mdata)
|
|
|
+
|
|
|
+ def test_reply_xfrout_query_with_error_rcode(self):
|
|
|
+ msg = self.getmsg()
|
|
|
+ self.xfrsess._reply_query_with_error_rcode(msg, self.sock, rcode(3))
|
|
|
+ get_msg = self.sock.read_msg()
|
|
|
+ self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
|
|
|
+
|
|
|
+ def test_clear_message(self):
|
|
|
+ msg = self.getmsg()
|
|
|
+ qid = msg.get_qid()
|
|
|
+ opcode = msg.get_opcode()
|
|
|
+ rcode = msg.get_rcode()
|
|
|
+
|
|
|
+ self.xfrsess._clear_message(msg)
|
|
|
+ self.assertEqual(msg.get_qid(), qid)
|
|
|
+ self.assertEqual(msg.get_opcode(), opcode)
|
|
|
+ self.assertEqual(msg.get_rcode(), rcode)
|
|
|
+ self.assertTrue(msg.get_header_flag(message_flag.AA()))
|
|
|
+
|
|
|
+ def test_reply_query_with_format_error(self):
|
|
|
+
|
|
|
+ msg = self.getmsg()
|
|
|
+ self.xfrsess._reply_query_with_format_error(msg, self.sock)
|
|
|
+ get_msg = self.sock.read_msg()
|
|
|
+ self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
|
|
|
+
|
|
|
+ def test_create_rrset_from_db_record(self):
|
|
|
+ rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
|
|
|
+ self.assertEqual(rrset.get_name().to_text(), "example.com.")
|
|
|
+ self.assertEqual(rrset.get_class(), rr_class.IN())
|
|
|
+ self.assertEqual(rrset.get_type().to_text(), "SOA")
|
|
|
+ rdata_iter = rrset.get_rdata_iterator()
|
|
|
+ rdata_iter.first()
|
|
|
+ self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
|
|
|
+
|
|
|
+ def test_send_message_with_last_soa(self):
|
|
|
+ rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
|
|
|
+
|
|
|
+ msg = self.getmsg()
|
|
|
+ msg.make_response()
|
|
|
+ self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa)
|
|
|
+ get_msg = self.sock.read_msg()
|
|
|
+
|
|
|
+ self.assertEqual(get_msg.get_rr_count(section.QUESTION()), 1)
|
|
|
+ self.assertEqual(get_msg.get_rr_count(section.ANSWER()), 1)
|
|
|
+ self.assertEqual(get_msg.get_rr_count(section.AUTHORITY()), 0)
|
|
|
+
|
|
|
+ answer_rrset_iter = section_iter(get_msg, section.ANSWER())
|
|
|
+ answer = answer_rrset_iter.get_rrset()
|
|
|
+ self.assertEqual(answer.get_name().to_text(), "example.com.")
|
|
|
+ self.assertEqual(answer.get_class(), rr_class.IN())
|
|
|
+ self.assertEqual(answer.get_type().to_text(), "SOA")
|
|
|
+ rdata_iter = answer.get_rdata_iterator()
|
|
|
+ rdata_iter.first()
|
|
|
+ self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
|
|
|
+
|
|
|
+ def test_get_message_len(self):
|
|
|
+ msg = self.getmsg()
|
|
|
+ msg.make_response()
|
|
|
+ self.assertEqual(self.xfrsess._get_message_len(msg), 29)
|
|
|
+
|
|
|
+ def test_zone_is_empty(self):
|
|
|
+ global sqlite3_ds
|
|
|
+ def mydb1(zone, file):
|
|
|
+ return True
|
|
|
+ sqlite3_ds.get_zone_soa = mydb1
|
|
|
+ self.assertEqual(self.xfrsess._zone_is_empty(""), False)
|
|
|
+ def mydb2(zone, file):
|
|
|
+ return False
|
|
|
+ sqlite3_ds.get_zone_soa = mydb2
|
|
|
+ self.assertEqual(self.xfrsess._zone_is_empty(""), True)
|
|
|
+
|
|
|
+ def test_zone_exist(self):
|
|
|
+ global sqlite3_ds
|
|
|
+ def zone_soa(zone, file):
|
|
|
+ return zone
|
|
|
+ sqlite3_ds.get_zone_soa = zone_soa
|
|
|
+ self.assertEqual(self.xfrsess._zone_exist(True), True)
|
|
|
+ self.assertEqual(self.xfrsess._zone_exist(False), False)
|
|
|
+
|
|
|
+ def test_check_xfrout_available(self):
|
|
|
+ def zone_exist(zone):
|
|
|
+ return zone
|
|
|
+ self.xfrsess._zone_exist = zone_exist
|
|
|
+ self.xfrsess._zone_is_empty = zone_exist
|
|
|
+ self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
|
|
|
+ self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
|
|
|
+
|
|
|
+ def zone_empty(zone):
|
|
|
+ return not zone
|
|
|
+ self.xfrsess._zone_is_empty = zone_empty
|
|
|
+ def false_func():
|
|
|
+ return False
|
|
|
+ self.xfrsess.server.increase_transfers_counter = false_func
|
|
|
+ self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
|
|
|
+ def true_func():
|
|
|
+ return True
|
|
|
+ self.xfrsess.server.increase_transfers_counter = true_func
|
|
|
+ self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
|
|
|
+
|
|
|
+ def test_dns_xfrout_start_formerror(self):
|
|
|
+ # formerror
|
|
|
+ self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
|
|
|
+ sent_data = self.sock.readsent()
|
|
|
+ self.assertEqual(len(sent_data), 0)
|
|
|
+
|
|
|
+ def default(self, param):
|
|
|
+ return "example.com"
|
|
|
+
|
|
|
+ def test_dns_xfrout_start_notauth(self):
|
|
|
+ self.xfrsess._get_query_zone_name = self.default
|
|
|
+ def notauth(formpara):
|
|
|
+ return rcode.NOTAUTH()
|
|
|
+ self.xfrsess._check_xfrout_available = notauth
|
|
|
+ self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
|
|
|
+ get_msg = self.sock.read_msg()
|
|
|
+ self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
|
|
|
+
|
|
|
+ def test_dns_xfrout_start_noerror(self):
|
|
|
+ self.xfrsess._get_query_zone_name = self.default
|
|
|
+ def noerror(form):
|
|
|
+ return rcode.NOERROR()
|
|
|
+ self.xfrsess._check_xfrout_available = noerror
|
|
|
+
|
|
|
+ def myreply(msg, sock, zonename):
|
|
|
+ self.sock.send(b"success")
|
|
|
+
|
|
|
+ self.xfrsess._reply_xfrout_query = myreply
|
|
|
+ self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
|
|
|
+ self.assertEqual(self.sock.readsent(), b"success")
|
|
|
+
|
|
|
+ def test_reply_xfrout_query_noerror(self):
|
|
|
+ global sqlite3_ds
|
|
|
+ def get_zone_soa(zonename, file):
|
|
|
+ return self.soa_record
|
|
|
+
|
|
|
+ def get_zone_datas(zone, file):
|
|
|
+ return [self.soa_record]
|
|
|
+
|
|
|
+ sqlite3_ds.get_zone_soa = get_zone_soa
|
|
|
+ sqlite3_ds.get_zone_datas = get_zone_datas
|
|
|
+ self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
|
|
|
+ reply_msg = self.sock.read_msg()
|
|
|
+ self.assertEqual(reply_msg.get_rr_count(section.ANSWER()), 2)
|
|
|
+
|
|
|
+ # set event
|
|
|
+ self.xfrsess.server._shutdown_event.set()
|
|
|
+ self.assertRaises(XfroutException, self.xfrsess._reply_xfrout_query, self.getmsg(), self.sock, "example.com.")
|
|
|
+
|
|
|
+class MyUnixSockServer(UnixSockServer):
|
|
|
+ def __init__(self):
|
|
|
+ self._lock = threading.Lock()
|
|
|
+ self._transfers_counter = 0
|
|
|
+ self._shutdown_event = threading.Event()
|
|
|
+ self._db_file = "initdb.file"
|
|
|
+ self._max_transfers_out = 10
|
|
|
+
|
|
|
+class TestUnixSockServer(unittest.TestCase):
|
|
|
+ def setUp(self):
|
|
|
+ self.unix = MyUnixSockServer()
|
|
|
+
|
|
|
+ def test_updata_config_data(self):
|
|
|
+ self.unix.update_config_data({'transfers_out':10, 'db_file':"db.file"})
|
|
|
+ self.assertEqual(self.unix._max_transfers_out, 10)
|
|
|
+ self.assertEqual(self.unix._db_file, "db.file")
|
|
|
+
|
|
|
+ def test_get_db_file(self):
|
|
|
+ self.assertEqual(self.unix.get_db_file(), "initdb.file")
|
|
|
+
|
|
|
+ def test_increase_transfers_counter(self):
|
|
|
+ self.unix._max_transfers_out = 10
|
|
|
+ count = self.unix._transfers_counter
|
|
|
+ self.assertEqual(self.unix.increase_transfers_counter(), True)
|
|
|
+ self.assertEqual(count + 1, self.unix._transfers_counter)
|
|
|
+
|
|
|
+ self.unix._max_transfers_out = 0
|
|
|
+ count = self.unix._transfers_counter
|
|
|
+ self.assertEqual(self.unix.increase_transfers_counter(), False)
|
|
|
+ self.assertEqual(count, self.unix._transfers_counter)
|
|
|
+
|
|
|
+ def test_decrease_transfers_counter(self):
|
|
|
+ count = self.unix._transfers_counter
|
|
|
+ self.unix.decrease_transfers_counter()
|
|
|
+ self.assertEqual(count - 1, self.unix._transfers_counter)
|
|
|
+
|
|
|
+
|
|
|
+if __name__== "__main__":
|
|
|
+ unittest.main()
|