Browse Source

Add unittest code for xfrout.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1643 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 15 years ago
parent
commit
3bbc64dc58

+ 2 - 0
configure.ac

@@ -388,6 +388,7 @@ AC_OUTPUT([src/bin/cfgmgr/b10-cfgmgr.py
            src/bin/xfrin/run_b10-xfrin.sh
            src/bin/xfrout/xfrout.py
            src/bin/xfrout/xfrout.pre
+           src/bin/xfrout/tests/xfrout_test
            src/bin/xfrout/run_b10-xfrout.sh
            src/bin/bind10/bind10.py
            src/bin/bind10/tests/bind10_test
@@ -416,6 +417,7 @@ AC_OUTPUT([src/bin/cfgmgr/b10-cfgmgr.py
            chmod +x src/bin/bind10/run_bind10.sh
            chmod +x src/bin/cmdctl/unittest/cmdctl_test
            chmod +x src/bin/xfrin/unittest/xfrin_test
+           chmod +x src/bin/xfrout/tests/xfrout_test
            chmod +x src/bin/bindctl/unittest/bindctl_test
            chmod +x src/bin/bindctl/run_bindctl.sh
            chmod +x src/bin/loadzone/run_loadzone.sh

+ 12 - 0
src/bin/xfrout/tests/xfrout_test.in

@@ -0,0 +1,12 @@
+#! /bin/sh
+
+PYTHON_EXEC=${PYTHON_EXEC:-@PYTHON@}
+export PYTHON_EXEC
+
+TEST_PATH=@abs_top_srcdir@/src/bin/xfrout/tests
+PYTHONPATH=@abs_top_srcdir@/src/bin/xfrout:@abs_top_srcdir@/src/lib/python
+export PYTHONPATH
+
+cd ${TEST_PATH}
+exec ${PYTHON_EXEC} -O xfrout_test.py $*
+

+ 283 - 0
src/bin/xfrout/tests/xfrout_test.py

@@ -0,0 +1,283 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''Tests for the XfroutSession and UnixSockServer classes '''
+
+
+import unittest
+import os
+from isc.cc.session import *
+from bind10_dns import *
+from xfrout import *
+
+# our fake socket, where we can read and insert messages
+class MySocket():
+    def __init__(self, family, type):
+        self.family = family
+        self.type = type
+        self.sendqueue = bytearray()
+
+    def connect(self, to):
+        pass
+
+    def close(self):
+        pass
+
+    def send(self, data):
+        self.sendqueue.extend(data);
+        return len(data)
+
+    def readsent(self):
+        result = self.sendqueue[:]
+        del self.sendqueue[:]
+        return result
+    
+    def read_msg(self):
+        sent_data = self.readsent()
+        get_msg = message(message_mode.PARSE)
+        get_msg.from_wire(input_buffer(bytes(sent_data[2:])))
+        return get_msg
+    
+    def clear_send(self):
+        del self.sendqueue[:]
+
+# We subclass the Session class we're testing here, only
+# to override the __init__() method, which wants a socket,
+class MyXfroutSession(XfroutSession):
+    def handle(self):
+        pass
+    
+class Dbserver:
+    def __init__(self):
+        self._shutdown_event = threading.Event()
+    def get_db_file(self):
+        return None
+    def decrease_transfers_counter(self):
+        pass
+
+class TestXfroutSession(unittest.TestCase):
+    def getmsg(self):
+        msg = message(message_mode.PARSE)
+        msg.from_wire(input_buffer(self.mdata))
+        return msg
+
+    def setUp(self):
+        request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
+        self.xfrsess = MyXfroutSession(request, None, None)
+        self.xfrsess.server = Dbserver()
+        self.mdata = b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01'
+        self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
+        self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
+
+    def test_parse_query_message(self):
+        [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(get_rcode.to_text(), "NOERROR")
+
+    def test_get_query_zone_name(self):
+        msg = self.getmsg()
+        self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
+  
+    def test_send_data(self):
+        self.xfrsess._send_data(self.sock, self.mdata)
+        senddata = self.sock.readsent()
+        self.assertEqual(senddata, self.mdata)
+
+    def test_reply_xfrout_query_with_error_rcode(self):
+        msg = self.getmsg()
+        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, rcode(3))
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN") 
+     
+    def test_clear_message(self):
+        msg = self.getmsg()
+        qid = msg.get_qid()
+        opcode = msg.get_opcode()
+        rcode = msg.get_rcode()
+
+        self.xfrsess._clear_message(msg)
+        self.assertEqual(msg.get_qid(), qid)
+        self.assertEqual(msg.get_opcode(), opcode)
+        self.assertEqual(msg.get_rcode(), rcode)
+        self.assertTrue(msg.get_header_flag(message_flag.AA()))
+
+    def test_reply_query_with_format_error(self):
+         
+        msg = self.getmsg()
+        self.xfrsess._reply_query_with_format_error(msg, self.sock)
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
+
+    def test_create_rrset_from_db_record(self):
+        rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+        self.assertEqual(rrset.get_name().to_text(), "example.com.")
+        self.assertEqual(rrset.get_class(), rr_class.IN())
+        self.assertEqual(rrset.get_type().to_text(), "SOA")
+        rdata_iter = rrset.get_rdata_iterator()
+        rdata_iter.first()
+        self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
+
+    def test_send_message_with_last_soa(self):
+        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
+
+        msg = self.getmsg()
+        msg.make_response()
+        self.xfrsess._send_message_with_last_soa(msg, self.sock, rrset_soa)
+        get_msg = self.sock.read_msg()
+
+        self.assertEqual(get_msg.get_rr_count(section.QUESTION()), 1)
+        self.assertEqual(get_msg.get_rr_count(section.ANSWER()), 1)
+        self.assertEqual(get_msg.get_rr_count(section.AUTHORITY()), 0)
+
+        answer_rrset_iter = section_iter(get_msg, section.ANSWER())
+        answer = answer_rrset_iter.get_rrset()
+        self.assertEqual(answer.get_name().to_text(), "example.com.")
+        self.assertEqual(answer.get_class(), rr_class.IN())
+        self.assertEqual(answer.get_type().to_text(), "SOA")
+        rdata_iter = answer.get_rdata_iterator()
+        rdata_iter.first()
+        self.assertEqual(rdata_iter.get_current().to_text(), self.soa_record[7])
+
+    def test_get_message_len(self):
+        msg = self.getmsg()
+        msg.make_response()  
+        self.assertEqual(self.xfrsess._get_message_len(msg), 29)
+
+    def test_zone_is_empty(self):
+        global sqlite3_ds
+        def mydb1(zone, file):
+            return True
+        sqlite3_ds.get_zone_soa = mydb1
+        self.assertEqual(self.xfrsess._zone_is_empty(""), False)
+        def mydb2(zone, file):
+            return False
+        sqlite3_ds.get_zone_soa = mydb2
+        self.assertEqual(self.xfrsess._zone_is_empty(""), True)
+
+    def test_zone_exist(self):
+        global sqlite3_ds
+        def zone_soa(zone, file):
+            return zone
+        sqlite3_ds.get_zone_soa = zone_soa
+        self.assertEqual(self.xfrsess._zone_exist(True), True)
+        self.assertEqual(self.xfrsess._zone_exist(False), False)
+    
+    def test_check_xfrout_available(self):
+        def zone_exist(zone):
+            return zone
+        self.xfrsess._zone_exist = zone_exist
+        self.xfrsess._zone_is_empty = zone_exist
+        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
+        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
+
+        def zone_empty(zone):
+            return not zone
+        self.xfrsess._zone_is_empty = zone_empty
+        def false_func():
+            return False
+        self.xfrsess.server.increase_transfers_counter = false_func
+        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
+        def true_func():
+            return True
+        self.xfrsess.server.increase_transfers_counter = true_func
+        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
+
+    def test_dns_xfrout_start_formerror(self):
+        # formerror
+        self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
+        sent_data = self.sock.readsent()
+        self.assertEqual(len(sent_data), 0)
+    
+    def default(self, param):
+        return "example.com"
+
+    def test_dns_xfrout_start_notauth(self):
+        self.xfrsess._get_query_zone_name = self.default
+        def notauth(formpara):
+            return rcode.NOTAUTH()
+        self.xfrsess._check_xfrout_available = notauth
+        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
+        get_msg = self.sock.read_msg()
+        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
+    
+    def test_dns_xfrout_start_noerror(self):
+        self.xfrsess._get_query_zone_name = self.default
+        def noerror(form):
+            return rcode.NOERROR() 
+        self.xfrsess._check_xfrout_available = noerror
+        
+        def myreply(msg, sock, zonename):
+            self.sock.send(b"success")
+        
+        self.xfrsess._reply_xfrout_query = myreply
+        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
+        self.assertEqual(self.sock.readsent(), b"success")
+    
+    def test_reply_xfrout_query_noerror(self):
+        global sqlite3_ds
+        def get_zone_soa(zonename, file):
+            return self.soa_record
+
+        def get_zone_datas(zone, file):
+            return [self.soa_record]
+
+        sqlite3_ds.get_zone_soa = get_zone_soa
+        sqlite3_ds.get_zone_datas = get_zone_datas
+        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
+        reply_msg = self.sock.read_msg()
+        self.assertEqual(reply_msg.get_rr_count(section.ANSWER()), 2)
+
+        # set event
+        self.xfrsess.server._shutdown_event.set()
+        self.assertRaises(XfroutException, self.xfrsess._reply_xfrout_query, self.getmsg(), self.sock, "example.com.")
+
+class MyUnixSockServer(UnixSockServer):
+    def __init__(self):
+        self._lock = threading.Lock()
+        self._transfers_counter = 0
+        self._shutdown_event = threading.Event()
+        self._db_file = "initdb.file"
+        self._max_transfers_out = 10
+
+class TestUnixSockServer(unittest.TestCase):
+    def setUp(self):
+        self.unix = MyUnixSockServer()
+     
+    def test_updata_config_data(self):
+        self.unix.update_config_data({'transfers_out':10, 'db_file':"db.file"})
+        self.assertEqual(self.unix._max_transfers_out, 10)
+        self.assertEqual(self.unix._db_file, "db.file")
+
+    def test_get_db_file(self):
+        self.assertEqual(self.unix.get_db_file(), "initdb.file")
+
+    def test_increase_transfers_counter(self):
+        self.unix._max_transfers_out = 10
+        count = self.unix._transfers_counter
+        self.assertEqual(self.unix.increase_transfers_counter(), True)
+        self.assertEqual(count + 1, self.unix._transfers_counter)
+
+        self.unix._max_transfers_out = 0
+        count = self.unix._transfers_counter
+        self.assertEqual(self.unix.increase_transfers_counter(), False)
+        self.assertEqual(count, self.unix._transfers_counter)
+ 
+    def test_decrease_transfers_counter(self):
+        count = self.unix._transfers_counter
+        self.unix.decrease_transfers_counter()
+        self.assertEqual(count - 1, self.unix._transfers_counter)
+
+
+if __name__== "__main__":
+    unittest.main()

+ 5 - 3
src/bin/xfrout/xfrout.py.in

@@ -43,7 +43,7 @@ else:
 SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
 
 MAX_TRANSFERS_OUT = 10
-verbose_mode = True
+verbose_mode = False
 
 
 class XfroutException(Exception): pass
@@ -59,7 +59,8 @@ class XfroutSession(BaseRequestHandler):
         try:
             self.dns_xfrout_start(sock, msgdata)
         except Exception as e:
-            self.log_msg(str(e))
+            if verbose_mode:
+                self.log_msg(str(e))
 
         sock.close()
                
@@ -70,7 +71,8 @@ class XfroutSession(BaseRequestHandler):
             msg = message(message_mode.PARSE)
             msg.from_wire(input_buffer(mdata))
         except Exception as err:
-            self.log_msg(str(err))
+            if verbose_mode:
+                self.log_msg(str(err))
             return rcode.FORMERR(), None
 
         return rcode.NOERROR(), msg