Browse Source

merge branches/trac185 (trac #185): more tests for xfrin with some bug fixes.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@2000 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
85bc9eea0d
3 changed files with 346 additions and 126 deletions
  1. 12 0
      src/bin/xfrin/TODO
  2. 204 58
      src/bin/xfrin/tests/xfrin_test.py
  3. 130 68
      src/bin/xfrin/xfrin.py.in

+ 12 - 0
src/bin/xfrin/TODO

@@ -3,22 +3,34 @@
    occur at the same time.  (but testing it would be very difficult)
    occur at the same time.  (but testing it would be very difficult)
 3. It wouldn't support IPv6 because of the following line:
 3. It wouldn't support IPv6 because of the following line:
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+   [FIXED in r1851]
 4. Xfrin.retransfer and refresh share most of the code.  should be unified.
 4. Xfrin.retransfer and refresh share most of the code.  should be unified.
+   [FIXED in r1861]
 5. class IN is hardcoded.  bad.
 5. class IN is hardcoded.  bad.
         query_question = question(name(self._zone_name), rr_class.IN(), query_type)
         query_question = question(name(self._zone_name), rr_class.IN(), query_type)
+   [FIXED in r1889]
+   Note: we still hardcode it as the fixed default value for
+   retransfer/refresh commands.
+   we should fix this so that this is specifiable, so this TODO item is 
+   still open.
 6. QID 0 should be allowed:
 6. QID 0 should be allowed:
         query_id = random.randint(1, 0xFFFF)
         query_id = random.randint(1, 0xFFFF)
+   [FIXED in r1880]
 7. what if xfrin fails after opening a new DB?  looks like garbage
 7. what if xfrin fails after opening a new DB?  looks like garbage
    (intermediate) data remains in the DB file, although it's more about
    (intermediate) data remains in the DB file, although it's more about
    the data source implementation.  check it, and fix it if it's the case.
    the data source implementation.  check it, and fix it if it's the case.
 8. Xfrin.command_handler() ignores unknown commands.  should return an error.
 8. Xfrin.command_handler() ignores unknown commands.  should return an error.
+   [FIXED in r1882]
 9. XfrinConnection can leak sockets. (same problem as that Jelte mentioned
 9. XfrinConnection can leak sockets. (same problem as that Jelte mentioned
    on xfrout?)
    on xfrout?)
+   [FIXED in r1908]
 10. The following line of _check_soa_serial() is incorrect.
 10. The following line of _check_soa_serial() is incorrect.
         soa_reply = self._get_request_response(int(data_size))
         soa_reply = self._get_request_response(int(data_size))
     Unpack the data and convert it in the host by order.
     Unpack the data and convert it in the host by order.
+    [FIXED in r1866]
 11. if do_xfrin fails it should probably return a non "OK" value.
 11. if do_xfrin fails it should probably return a non "OK" value.
     (it's currently ignored anyway, though)
     (it's currently ignored anyway, though)
+    [FIXED in r1887]
 12. XfrinConnection should probably define handle_close().  Also, the
 12. XfrinConnection should probably define handle_close().  Also, the
     following part should be revised because this can also happen when the
     following part should be revised because this can also happen when the
     master closes the connection.
     master closes the connection.

+ 204 - 58
src/bin/xfrin/tests/xfrin_test.py

@@ -26,7 +26,13 @@ TEST_ZONE_NAME = "example.com"
 TEST_RRCLASS = rr_class.IN()
 TEST_RRCLASS = rr_class.IN()
 TEST_DB_FILE = 'db_file'
 TEST_DB_FILE = 'db_file'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
+TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
+                             socket.IPPROTO_TCP, '',
+                             (TEST_MASTER_IPV4_ADDRESS, 53))
 TEST_MASTER_IPV6_ADDRESS = '::1'
 TEST_MASTER_IPV6_ADDRESS = '::1'
+TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM,
+                             socket.IPPROTO_TCP, '',
+                             (TEST_MASTER_IPV6_ADDRESS, 53))
 # XXX: This should be a non priviledge port that is unlikely to be used.
 # XXX: This should be a non priviledge port that is unlikely to be used.
 # If some other process uses this port test will fail.
 # If some other process uses this port test will fail.
 TEST_MASTER_PORT = '53535'
 TEST_MASTER_PORT = '53535'
@@ -37,32 +43,45 @@ soa_rdata = create_rdata(rr_type.SOA(), TEST_RRCLASS,
 soa_rrset = rrset(name(TEST_ZONE_NAME), TEST_RRCLASS, rr_type.SOA(),
 soa_rrset = rrset(name(TEST_ZONE_NAME), TEST_RRCLASS, rr_type.SOA(),
                   rr_ttl(3600))
                   rr_ttl(3600))
 soa_rrset.add_rdata(soa_rdata)
 soa_rrset.add_rdata(soa_rdata)
-example_question = question(name(TEST_ZONE_NAME), TEST_RRCLASS, rr_type.AXFR())
-default_questions = [example_question]
+example_axfr_question = question(name(TEST_ZONE_NAME), TEST_RRCLASS,
+                                 rr_type.AXFR())
+example_soa_question = question(name(TEST_ZONE_NAME), TEST_RRCLASS,
+                                 rr_type.SOA())
+default_questions = [example_axfr_question]
 default_answers = [soa_rrset]
 default_answers = [soa_rrset]
 
 
 class XfrinTestException(Exception):
 class XfrinTestException(Exception):
     pass
     pass
 
 
-# Rewrite the class for unittest.
 class MockXfrin(Xfrin):
 class MockXfrin(Xfrin):
+    # This is a class attribute of a callable object that specifies a non
+    # default behavior triggered in _cc_check_command().  Specific test methods
+    # are expected to explicitly set this attribute before creating a
+    # MockXfrin object (when it needs a non default behavior).
+    # See the TestMain class.
+    check_command_hook = None
+
     def _cc_setup(self):
     def _cc_setup(self):
         pass
         pass
+    
+    def _cc_check_command(self):
+        self._shutdown_event.set()
+        if MockXfrin.check_command_hook:
+            MockXfrin.check_command_hook()
 
 
 class MockXfrinConnection(XfrinConnection):
 class MockXfrinConnection(XfrinConnection):
-    def __init__(self, TEST_ZONE_NAME, db_file, shutdown_event, master_addr):
-        super().__init__(TEST_ZONE_NAME, db_file, shutdown_event, master_addr)
+    def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
+                 master_addr):
+        super().__init__(sock_map, zone_name, rrclass, db_file, shutdown_event,
+                         master_addr)
         self.query_data = b''
         self.query_data = b''
         self.reply_data = b''
         self.reply_data = b''
         self.force_time_out = False
         self.force_time_out = False
         self.force_close = False
         self.force_close = False
+        self.qlen = None
         self.qid = None
         self.qid = None
         self.response_generator = None
         self.response_generator = None
 
 
-    def _handle_xfrin_response(self):
-        for rr in super()._handle_xfrin_response():
-            pass
-
     def _asyncore_loop(self):
     def _asyncore_loop(self):
         if self.force_close:
         if self.force_close:
             self.handle_close()
             self.handle_close()
@@ -80,11 +99,21 @@ class MockXfrinConnection(XfrinConnection):
         return data
         return data
 
 
     def send(self, data):
     def send(self, data):
+        if self.qlen != None and len(self.query_data) >= self.qlen:
+            # This is a new query.  reset the internal state.
+            self.qlen = None
+            self.qid = None
+            self.query_data = b''
         self.query_data += data
         self.query_data += data
-        # when the outgoing data is sufficiently large to contain the QID field
-        # (4 octets or more - 16-bit length field + 16-bit QID), extract the
-        # value so that we can construct a matching response.
+
+        # when the outgoing data is sufficiently large to contain the length
+        # and the QID fields (4 octets or more), extract these fields.
+        # The length will be reset the internal query data to support multiple
+        # queries in a single test.
+        # The QID will be used to construct a matching response.
         if len(self.query_data) >= 4 and self.qid == None:
         if len(self.query_data) >= 4 and self.qid == None:
+            self.qlen = socket.htons(struct.unpack('H',
+                                                   self.query_data[0:2])[0])
             self.qid = socket.htons(struct.unpack('H', self.query_data[2:4])[0])
             self.qid = socket.htons(struct.unpack('H', self.query_data[2:4])[0])
             # if the response generator method is specified, invoke it now.
             # if the response generator method is specified, invoke it now.
             if self.response_generator != None:
             if self.response_generator != None:
@@ -119,92 +148,150 @@ class TestXfrinConnection(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         if os.path.exists(TEST_DB_FILE):
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
             os.remove(TEST_DB_FILE)
-        self.conn = MockXfrinConnection('example.com.', TEST_DB_FILE,
+        self.sock_map = {}
+        self.conn = MockXfrinConnection(self.sock_map, 'example.com.',
+                                        TEST_RRCLASS, TEST_DB_FILE,
                                         threading.Event(),
                                         threading.Event(),
-                                        TEST_MASTER_IPV4_ADDRESS)
+                                        TEST_MASTER_IPV4_ADDRINFO)
+        self.axfr_after_soa = False
+        self.soa_response_params = {
+            'questions': [example_soa_question],
+            'bad_qid': False,
+            'response': True,
+            'rcode': rcode.NOERROR(),
+            'axfr_after_soa': self._create_normal_response_data
+            }
 
 
     def tearDown(self):
     def tearDown(self):
         self.conn.close()
         self.conn.close()
         if os.path.exists(TEST_DB_FILE):
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
             os.remove(TEST_DB_FILE)
 
 
+    def test_close(self):
+        # we shouldn't be using the global asyncore map.
+        self.assertEqual(len(asyncore.socket_map), 0)
+        # there should be exactly one entry in our local map
+        self.assertEqual(len(self.sock_map), 1)
+        # once closing the dispatch the map should become empty
+        self.conn.close()
+        self.assertEqual(len(self.sock_map), 0)
+
     def test_init_ip6(self):
     def test_init_ip6(self):
         # This test simply creates a new XfrinConnection object with an
         # This test simply creates a new XfrinConnection object with an
         # IPv6 address, tries to bind it to an IPv6 wildcard address/port
         # IPv6 address, tries to bind it to an IPv6 wildcard address/port
         # to confirm an AF_INET6 socket has been created.  A naive application
         # to confirm an AF_INET6 socket has been created.  A naive application
         # tends to assume it's IPv4 only and hardcode AF_INET.  This test
         # tends to assume it's IPv4 only and hardcode AF_INET.  This test
         # uncovers such a bug.
         # uncovers such a bug.
-        c = MockXfrinConnection('example.com.', TEST_DB_FILE,
+        c = MockXfrinConnection({}, 'example.com.', TEST_RRCLASS, TEST_DB_FILE,
                                 threading.Event(),
                                 threading.Event(),
-                                TEST_MASTER_IPV6_ADDRESS)
-        #This test currently fails.  Fix the code, then enable it
-        #c.bind(('::', 0))
+                                TEST_MASTER_IPV6_ADDRINFO)
+        c.bind(('::', 0))
+        c.close()
+
+    def test_init_chclass(self):
+        c = XfrinConnection({}, 'example.com.', rr_class.CH(), TEST_DB_FILE,
+                            threading.Event(), TEST_MASTER_IPV4_ADDRINFO)
+        axfrmsg = c._create_query(rr_type.AXFR())
+        self.assertEqual(question_iter(axfrmsg).get_question().get_class(),
+                         rr_class.CH())
         c.close()
         c.close()
 
 
     def test_response_with_invalid_msg(self):
     def test_response_with_invalid_msg(self):
         self.conn.reply_data = b'aaaxxxx'
         self.conn.reply_data = b'aaaxxxx'
-        self.assertRaises(XfrinTestException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinTestException, self._handle_xfrin_response)
 
 
     def test_response_without_end_soa(self):
     def test_response_without_end_soa(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data()
         self.conn.reply_data = self.conn.create_response_data()
-        self.assertRaises(XfrinTestException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinTestException, self._handle_xfrin_response)
 
 
     def test_response_bad_qid(self):
     def test_response_bad_qid(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data(bad_qid = True)
         self.conn.reply_data = self.conn.create_response_data(bad_qid = True)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_non_response(self):
     def test_response_non_response(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data(response = False)
         self.conn.reply_data = self.conn.create_response_data(response = False)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_error_code(self):
     def test_response_error_code(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
         self.conn.reply_data = self.conn.create_response_data(
-            rcode = rcode.SERVFAIL())
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+            rcode=rcode.SERVFAIL())
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_multi_question(self):
     def test_response_multi_question(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
         self.conn.reply_data = self.conn.create_response_data(
-            questions=[example_question, example_question])
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+            questions=[example_axfr_question, example_axfr_question])
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_empty_answer(self):
     def test_response_empty_answer(self):
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
         self.conn.reply_data = self.conn.create_response_data(answers=[])
         self.conn.reply_data = self.conn.create_response_data(answers=[])
         # Should an empty answer trigger an exception?  Even though it's very
         # Should an empty answer trigger an exception?  Even though it's very
         # unusual it's not necessarily invalid.  Need to revisit.
         # unusual it's not necessarily invalid.  Need to revisit.
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
+
+    def test_response_non_response(self):
+        self.conn._send_query(rr_type.AXFR())
+        self.conn.reply_data = self.conn.create_response_data(response = False)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
+
+    def test_soacheck(self):
+        # we need to defer the creation until we know the QID, which is
+        # determined in _check_soa_serial(), so we use response_generator.
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn._check_soa_serial(), XFRIN_OK)
+
+    def test_soacheck_with_bad_response(self):
+        self.conn.response_generator = self._create_broken_response_data
+        self.assertRaises(UserWarning, self.conn._check_soa_serial)
+
+    def test_soacheck_badqid(self):
+        self.soa_response_params['bad_qid'] = True
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_non_response(self):
+        self.soa_response_params['response'] = False
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_error_code(self):
+        self.soa_response_params['rcode'] = rcode.SERVFAIL()
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
 
 
     def test_response_shutdown(self):
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.conn._shutdown_event.set()
         self.conn._shutdown_event.set()
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_timeout(self):
     def test_response_timeout(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.conn.force_time_out = True
         self.conn.force_time_out = True
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_remote_close(self):
     def test_response_remote_close(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.conn.force_close = True
         self.conn.force_close = True
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
 
 
     def test_response_bad_message(self):
     def test_response_bad_message(self):
         self.conn.response_generator = self._create_broken_response_data
         self.conn.response_generator = self._create_broken_response_data
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
-        self.assertRaises(Exception, self.conn._handle_xfrin_response)
+        self.assertRaises(Exception, self._handle_xfrin_response)
 
 
     def test_response(self):
     def test_response(self):
-        # normal case.  should silently succeed.
+        # normal case.
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.conn._send_query(rr_type.AXFR())
         self.conn._send_query(rr_type.AXFR())
-        self.conn._handle_xfrin_response()
+        # two SOAs, and only these have been transfered.  the 2nd SOA is just
+        # a marker, so only 1 RR has been provided in the iteration.
+        self.assertEqual(self._handle_xfrin_response(), 1)
 
 
     def test_do_xfrin(self):
     def test_do_xfrin(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
@@ -212,30 +299,41 @@ class TestXfrinConnection(unittest.TestCase):
 
 
     def test_do_xfrin_empty_response(self):
     def test_do_xfrin_empty_response(self):
         # skipping the creation of response data, so the transfer will fail.
         # skipping the creation of response data, so the transfer will fail.
-        # (but do_xfrin() always return OK.)
-        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
-
-    def test_do_xfrin_empty_response(self):
-        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
 
 
     def test_do_xfrin_bad_response(self):
     def test_do_xfrin_bad_response(self):
         self.conn.response_generator = self._create_broken_response_data
         self.conn.response_generator = self._create_broken_response_data
-        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
 
 
     def test_do_xfrin_dberror(self):
     def test_do_xfrin_dberror(self):
         # DB file is under a non existent directory, so its creation will fail,
         # DB file is under a non existent directory, so its creation will fail,
         # which will make the transfer fail.
         # which will make the transfer fail.
         self.conn._db_file = "not_existent/" + TEST_DB_FILE
         self.conn._db_file = "not_existent/" + TEST_DB_FILE
-        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
 
 
-# This test currently doesn't work due to bug.  Fix it and then enable the test.
-#     def test_do_xfrin_with_soacheck(self):
-#         self.conn.response_generator = self._create_normal_response_data
-#         self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
+    def test_do_soacheck_and_xfrin(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
 
 
-#     def test_do_xfrin_with_soacheck_bad_response(self):
-#         self.conn.response_generator = self._create_broken_response_data
-#         self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
+    def test_do_soacheck_broken_response(self):
+        self.conn.response_generator = self._create_broken_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL)
+
+    def test_do_soacheck_badqid(self):
+        # the QID mismatch would internally trigger a XfrinException exception,
+        # and covers part of the code that other tests can't.
+        self.soa_response_params['bad_qid'] = True
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL)
+
+    def _handle_xfrin_response(self):
+        # This helper methods iterates over all RRs (excluding the ending SOA)
+        # transferred, and simply returns the number of RRs.  The return value
+        # may be used an assertion value for test cases.
+        rrs = 0
+        for rr in self.conn._handle_xfrin_response():
+            rrs += 1
+        return rrs
 
 
     def _create_normal_response_data(self):
     def _create_normal_response_data(self):
         # This helper method creates a simple sequence of DNS messages that
         # This helper method creates a simple sequence of DNS messages that
@@ -244,6 +342,19 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = self.conn.create_response_data()
         self.conn.reply_data = self.conn.create_response_data()
         self.conn.reply_data += self.conn.create_response_data()
         self.conn.reply_data += self.conn.create_response_data()
 
 
+    def _create_soa_response_data(self):
+        # This helper method creates a DNS message that is supposed to be
+        # used a valid response to SOA queries prior to XFR.
+        # If axfr_after_soa is True, it resets the response_generator so that
+        # a valid XFR messages will follow.
+        self.conn.reply_data = self.conn.create_response_data(
+            bad_qid=self.soa_response_params['bad_qid'],
+            response=self.soa_response_params['response'],
+            rcode=self.soa_response_params['rcode'],
+            questions=self.soa_response_params['questions'])
+        if self.soa_response_params['axfr_after_soa'] != None:
+            self.conn.response_generator = self.soa_response_params['axfr_after_soa']
+
     def _create_broken_response_data(self):
     def _create_broken_response_data(self):
         # This helper method creates a bogus "DNS message" that only contains
         # This helper method creates a bogus "DNS message" that only contains
         # 4 octets of data.  The DNS message parser will raise an exception.
         # 4 octets of data.  The DNS message parser will raise an exception.
@@ -284,10 +395,9 @@ class TestXfrinRecorder(unittest.TestCase):
         self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), False)
         self.assertEqual(self.recorder.xfrin_in_progress(TEST_ZONE_NAME), False)
 
 
 class TestXfrin(unittest.TestCase):
 class TestXfrin(unittest.TestCase):
-    args = {}
-
     def setUp(self):
     def setUp(self):
         self.xfr = MockXfrin()
         self.xfr = MockXfrin()
+        self.args = {}
         self.args['zone_name'] = TEST_ZONE_NAME
         self.args['zone_name'] = TEST_ZONE_NAME
         self.args['port'] = TEST_MASTER_PORT
         self.args['port'] = TEST_MASTER_PORT
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
@@ -300,19 +410,21 @@ class TestXfrin(unittest.TestCase):
         return self.xfr._parse_cmd_params(self.args)
         return self.xfr._parse_cmd_params(self.args)
 
 
     def test_parse_cmd_params(self):
     def test_parse_cmd_params(self):
-        name, master, port, db_file = self._do_parse()
-        self.assertEqual(port, int(TEST_MASTER_PORT))
+        name, master_addrinfo, db_file = self._do_parse()
+        self.assertEqual(master_addrinfo[4][1], int(TEST_MASTER_PORT))
         self.assertEqual(name, TEST_ZONE_NAME)
         self.assertEqual(name, TEST_ZONE_NAME)
-        self.assertEqual(master, TEST_MASTER_IPV4_ADDRESS)
+        self.assertEqual(master_addrinfo[4][0], TEST_MASTER_IPV4_ADDRESS)
         self.assertEqual(db_file, TEST_DB_FILE)
         self.assertEqual(db_file, TEST_DB_FILE)
 
 
     def test_parse_cmd_params_default_port(self):
     def test_parse_cmd_params_default_port(self):
         del self.args['port']
         del self.args['port']
-        self.assertEqual(self._do_parse()[2], 53)
+        master_addrinfo = self._do_parse()[1]
+        self.assertEqual(master_addrinfo[4][1], 53)
 
 
     def test_parse_cmd_params_ip6master(self):
     def test_parse_cmd_params_ip6master(self):
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
-        self.assertEqual(self._do_parse()[1], TEST_MASTER_IPV6_ADDRESS)
+        master_addrinfo = self._do_parse()[1]
+        self.assertEqual(master_addrinfo[4][0], TEST_MASTER_IPV6_ADDRESS)
 
 
     def test_parse_cmd_params_nozone(self):
     def test_parse_cmd_params_nozone(self):
         # zone name is mandatory.
         # zone name is mandatory.
@@ -325,7 +437,7 @@ class TestXfrin(unittest.TestCase):
         self.assertRaises(XfrinException, self._do_parse)
         self.assertRaises(XfrinException, self._do_parse)
 
 
     def test_parse_cmd_params_bad_ip4(self):
     def test_parse_cmd_params_bad_ip4(self):
-        self.args['master'] = '3.3.3'
+        self.args['master'] = '3.3.3.3.3'
         self.assertRaises(XfrinException, self._do_parse)
         self.assertRaises(XfrinException, self._do_parse)
 
 
     def test_parse_cmd_params_bad_ip6(self):
     def test_parse_cmd_params_bad_ip6(self):
@@ -339,6 +451,9 @@ class TestXfrin(unittest.TestCase):
         self.args['port'] = '65536'
         self.args['port'] = '65536'
         self.assertRaises(XfrinException, self._do_parse)
         self.assertRaises(XfrinException, self._do_parse)
 
 
+        self.args['port'] = 'http'
+        self.assertRaises(XfrinException, self._do_parse)
+
     def test_command_handler_shutdown(self):
     def test_command_handler_shutdown(self):
         self.assertEqual(self.xfr.command_handler("shutdown",
         self.assertEqual(self.xfr.command_handler("shutdown",
                                                   None)['result'][0], 0)
                                                   None)['result'][0], 0)
@@ -346,9 +461,6 @@ class TestXfrin(unittest.TestCase):
         self.assertEqual(self.xfr.command_handler("shutdown",
         self.assertEqual(self.xfr.command_handler("shutdown",
                                                   "unused")['result'][0], 0)
                                                   "unused")['result'][0], 0)
 
 
-        self.assertEqual(self.xfr.command_handler("Shutdown",
-                                                  "unused")['result'][0], 0)
-
     def test_command_handler_retransfer(self):
     def test_command_handler_retransfer(self):
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
@@ -390,6 +502,40 @@ class TestXfrin(unittest.TestCase):
         self.assertEqual(self.xfr.command_handler("refresh",
         self.assertEqual(self.xfr.command_handler("refresh",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
 
 
+    def test_command_handler_unknown(self):
+        self.assertEqual(self.xfr.command_handler("xxx", None)['result'][0], 1)
+
+def raise_interrupt():
+    raise KeyboardInterrupt()
+
+def raise_ccerror():
+    raise isc.cc.session.SessionError('test error')
+
+def raise_excpetion():
+    raise Exception('test exception')
+
+class TestMain(unittest.TestCase):
+    def setUp(self):
+        MockXfrin.check_command_hook = None
+
+    def tearDown(self):
+        MockXfrin.check_command_hook = None
+
+    def test_startup(self):
+        main(MockXfrin, False)
+
+    def test_startup_interrupt(self):
+        MockXfrin.check_command_hook = raise_interrupt
+        main(MockXfrin, False)
+
+    def test_startup_ccerror(self):
+        MockXfrin.check_command_hook = raise_ccerror
+        main(MockXfrin, False)
+
+    def test_startup_generalerror(self):
+        MockXfrin.check_command_hook = raise_excpetion
+        main(MockXfrin, False)
+
 if __name__== "__main__":
 if __name__== "__main__":
     try:
     try:
         unittest.main()
         unittest.main()

+ 130 - 68
src/bin/xfrin/xfrin.py.in

@@ -50,6 +50,9 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"
 __version__ = 'BIND10'
 __version__ = 'BIND10'
 # define xfrin rcode
 # define xfrin rcode
 XFRIN_OK = 0
 XFRIN_OK = 0
+XFRIN_FAIL = 1
+
+DEFAULT_MASTER_PORT = '53'
 
 
 def log_error(msg):
 def log_error(msg):
     sys.stderr.write("[b10-xfrin] ")
     sys.stderr.write("[b10-xfrin] ")
@@ -62,46 +65,48 @@ class XfrinException(Exception):
 class XfrinConnection(asyncore.dispatcher):
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''    
     '''Do xfrin in this class. '''    
 
 
-    def __init__(self, 
-                 zone_name, db_file, shutdown_event, master_addr, 
-                 port = 53, verbose = False, idle_timeout = 60): 
+    def __init__(self,
+                 sock_map, zone_name, rrclass, db_file, shutdown_event,
+                 master_addrinfo, verbose = False, idle_timeout = 60): 
         ''' idle_timeout: max idle time for read data from socket.
         ''' idle_timeout: max idle time for read data from socket.
             db_file: specify the data source file.
             db_file: specify the data source file.
             check_soa: when it's true, check soa first before sending xfr query
             check_soa: when it's true, check soa first before sending xfr query
         '''
         '''
 
 
-        asyncore.dispatcher.__init__(self)
-        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
+        asyncore.dispatcher.__init__(self, map=sock_map)
+        self.create_socket(master_addrinfo[0], master_addrinfo[1])
         self._zone_name = zone_name
         self._zone_name = zone_name
+        self._rrclass = rrclass
         self._db_file = db_file
         self._db_file = db_file
         self._soa_rr_count = 0
         self._soa_rr_count = 0
         self._idle_timeout = idle_timeout
         self._idle_timeout = idle_timeout
         self.setblocking(1)
         self.setblocking(1)
         self._shutdown_event = shutdown_event
         self._shutdown_event = shutdown_event
         self._verbose = verbose
         self._verbose = verbose
-        self._master_addr = master_addr
-        self._port = port
+        self._master_address = master_addrinfo[4]
 
 
     def connect_to_master(self):
     def connect_to_master(self):
         '''Connect to master in TCP.'''
         '''Connect to master in TCP.'''
 
 
         try:
         try:
-            self.connect((self._master_addr, self._port))
+            self.connect(self._master_address)
             return True
             return True
         except socket.error as e:
         except socket.error as e:
-            self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr, self._port, str(e)))
+            self.log_msg('Failed to connect:(%s), %s' % (self._master_address,
+                                                            str(e)))
             return False
             return False
 
 
     def _create_query(self, query_type):
     def _create_query(self, query_type):
         '''Create dns query message. '''
         '''Create dns query message. '''
 
 
         msg = message(message_mode.RENDER)
         msg = message(message_mode.RENDER)
-        query_id = random.randint(1, 0xFFFF)
+        query_id = random.randint(0, 0xFFFF)
         self._query_id = query_id
         self._query_id = query_id
         msg.set_qid(query_id)
         msg.set_qid(query_id)
         msg.set_opcode(op_code.QUERY())
         msg.set_opcode(op_code.QUERY())
         msg.set_rcode(rcode.NOERROR())
         msg.set_rcode(rcode.NOERROR())
-        query_question = question(name(self._zone_name), rr_class.IN(), query_type)
+        query_question = question(name(self._zone_name), self._rrclass,
+                                  query_type)
         msg.add_question(query_question)
         msg.add_question(query_question)
         return msg
         return msg
 
 
@@ -155,10 +160,19 @@ involving actual communication with a remote server.
         '''
         '''
 
 
         self._send_query(rr_type.SOA())
         self._send_query(rr_type.SOA())
-        data_size = self._get_request_response(2)
-        soa_reply = self._get_request_response(int(data_size))
-        #TODO, need select soa record from data source then compare the two 
-        #serial, current just return OK, since this function hasn't been used now 
+        data_len = self._get_request_response(2)
+        msg_len = socket.htons(struct.unpack('H', data_len)[0])
+        soa_response = self._get_request_response(msg_len)
+        msg = message(message_mode.PARSE)
+        msg.from_wire(input_buffer(soa_response))
+
+        # perform some minimal level validation.  It's an open issue how
+        # strict we should be (see the comment in _check_response_header())
+        self._check_response_header(msg)
+
+        # TODO, need select soa record from data source then compare the two 
+        # serial, current just return OK, since this function hasn't been used
+        # now.
         return XFRIN_OK
         return XFRIN_OK
 
 
     def do_xfrin(self, check_soa, ixfr_first = False):
     def do_xfrin(self, check_soa, ixfr_first = False):
@@ -167,33 +181,52 @@ involving actual communication with a remote server.
         try:
         try:
             ret = XFRIN_OK
             ret = XFRIN_OK
             if check_soa:
             if check_soa:
+                logstr = 'SOA check for \'%s\' ' % self._zone_name
                 ret =  self._check_soa_serial()
                 ret =  self._check_soa_serial()
             
             
             logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
             logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
-            if ret == XFRIN_OK:    
+            if ret == XFRIN_OK:
                 self.log_msg(logstr + 'started')
                 self.log_msg(logstr + 'started')
                 self._send_query(rr_type.AXFR())
                 self._send_query(rr_type.AXFR())
                 isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                 isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                                             self._handle_xfrin_response)
                                             self._handle_xfrin_response)
 
 
                 self.log_msg(logstr + 'succeeded')
                 self.log_msg(logstr + 'succeeded')
+                ret = XFRIN_OK
 
 
         except XfrinException as e:
         except XfrinException as e:
             self.log_msg(e)
             self.log_msg(e)
             self.log_msg(logstr + 'failed')
             self.log_msg(logstr + 'failed')
+            ret = XFRIN_FAIL
             #TODO, recover data source.
             #TODO, recover data source.
         except isc.datasrc.sqlite3_ds.Sqlite3DSError as e:
         except isc.datasrc.sqlite3_ds.Sqlite3DSError as e:
             self.log_msg(e)
             self.log_msg(e)
             self.log_msg(logstr + 'failed')
             self.log_msg(logstr + 'failed')
+            ret = XFRIN_FAIL
+        except UserWarning as e:
+            # XXX: this is an exception from our C++ library via the
+            # Boost.Python binding.  It would be better to have more more
+            # specific exceptions, but at this moment this is the finest
+            # granularity.
+            self.log_msg(e)
+            self.log_msg(logstr + 'failed')
+            ret = XFRIN_FAIL
         finally:
         finally:
            self.close()
            self.close()
 
 
         return ret
         return ret
-    
-    def _check_response_status(self, msg):
-        '''Check validation of xfr response. '''
 
 
-        #TODO, check more?
+    def _check_response_header(self, msg):
+        '''Perform minimal validation on responses'''
+
+        # It's not clear how strict we should be about response validation.
+        # BIND 9 ignores some cases where it would normally be considered a
+        # bogus response.  For example, it accepts a response even if its
+        # opcode doesn't match that of the corresponding request.
+        # According to an original developer of BIND 9 some of the missing
+        # checks are deliberate to be kind to old implementations that would
+        # cause interoperability trouble with stricter checks.
+
         msg_rcode = msg.get_rcode()
         msg_rcode = msg.get_rcode()
         if msg_rcode != rcode.NOERROR():
         if msg_rcode != rcode.NOERROR():
             raise XfrinException('error response: %s' % msg_rcode.to_text())
             raise XfrinException('error response: %s' % msg_rcode.to_text())
@@ -204,6 +237,11 @@ involving actual communication with a remote server.
         if msg.get_qid() != self._query_id:
         if msg.get_qid() != self._query_id:
             raise XfrinException('bad query id')
             raise XfrinException('bad query id')
 
 
+    def _check_response_status(self, msg):
+        '''Check validation of xfr response. '''
+
+        self._check_response_header(msg)
+
         if msg.get_rr_count(section.ANSWER()) == 0:
         if msg.get_rr_count(section.ANSWER()) == 0:
             raise XfrinException('answer section is empty')
             raise XfrinException('answer section is empty')
 
 
@@ -286,19 +324,19 @@ involving actual communication with a remote server.
             sys.stdout.write('\n')
             sys.stdout.write('\n')
 
 
 
 
-def process_xfrin(xfrin_recorder, zone_name, db_file, 
-                  shutdown_event, master_addr, port, check_soa, verbose):
-    port = int(port)
+def process_xfrin(xfrin_recorder, zone_name, rrclass, db_file, 
+                  shutdown_event, master_addrinfo, check_soa, verbose):
     xfrin_recorder.increment(zone_name)
     xfrin_recorder.increment(zone_name)
-    conn = XfrinConnection(zone_name, db_file, shutdown_event, 
-                           master_addr, port, verbose)
+    sock_map = {}
+    conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
+                           shutdown_event, master_addrinfo, verbose)
     if conn.connect_to_master():
     if conn.connect_to_master():
         conn.do_xfrin(check_soa)
         conn.do_xfrin(check_soa)
 
 
     xfrin_recorder.decrement(zone_name)
     xfrin_recorder.decrement(zone_name)
 
 
 
 
-class XfrinRecorder():
+class XfrinRecorder:
     def __init__(self):
     def __init__(self):
         self._lock = threading.Lock()
         self._lock = threading.Lock()
         self._zones = []
         self._zones = []
@@ -326,7 +364,7 @@ class XfrinRecorder():
         self._lock.release()
         self._lock.release()
         return ret
         return ret
 
 
-class Xfrin():
+class Xfrin:
     def __init__(self, verbose = False):
     def __init__(self, verbose = False):
         self._cc_setup()
         self._cc_setup()
         self._max_transfers_in = 10
         self._max_transfers_in = 10
@@ -345,6 +383,13 @@ this method we can test most of this class without requiring a command channel.
                                               self.command_handler)
                                               self.command_handler)
         self._cc.start()
         self._cc.start()
 
 
+    def _cc_check_command(self):
+        '''
+This is a straightforward wrapper for cc.check_command, but provided as
+a separate method for the convenience of unit tests.
+'''
+        self._cc.check_command()
+
     def config_handler(self, new_config):
     def config_handler(self, new_config):
         # TODO, process new config data
         # TODO, process new config data
         return create_answer(0)
         return create_answer(0)
@@ -363,20 +408,20 @@ this method we can test most of this class without requiring a command channel.
 
 
     def command_handler(self, command, args):
     def command_handler(self, command, args):
         answer = create_answer(0)
         answer = create_answer(0)
-        cmd = command
         try:
         try:
-            if cmd == 'shutdown':
+            if command == 'shutdown':
                 self._shutdown_event.set()
                 self._shutdown_event.set()
-
-            elif cmd == 'retransfer':
-                zone_name, master, port, db_file = self._parse_cmd_params(args)
-                ret = self.xfrin_start(zone_name, db_file, master, port, False)
-                answer = create_answer(ret[0], ret[1])
-
-            elif cmd == 'refresh':
-                zone_name, master, port, db_file = self._parse_cmd_params(args)
-                ret = self.xfrin_start(zone_name, db_file, master, port)
+            elif command == 'retransfer' or command == 'refresh':
+                # The default RR class is IN.  We should fix this so that
+                # the class is passed in the command arg (where we specify
+                # the default)
+                rrclass = rr_class.IN()
+                zone_name, master_addr, db_file = self._parse_cmd_params(args)
+                ret = self.xfrin_start(zone_name, rrclass, db_file, master_addr,
+                                   False if command == 'retransfer' else True)
                 answer = create_answer(ret[0], ret[1])
                 answer = create_answer(ret[0], ret[1])
+            else:
+                answer = create_answer(1, 'unknown command: ' + command)
 
 
         except XfrinException as err:
         except XfrinException as err:
             answer = create_answer(1, str(err))
             answer = create_answer(1, str(err))
@@ -392,28 +437,23 @@ this method we can test most of this class without requiring a command channel.
         if not master:
         if not master:
             raise XfrinException('master address should be provided')
             raise XfrinException('master address should be provided')
 
 
-        check_addr(master)
-        port = 53
         port_str = args.get('port')
         port_str = args.get('port')
-        if port_str:
-            port = int(port_str)
-            check_port(port)
+        if not port_str:
+            port_str = DEFAULT_MASTER_PORT
+        master_addrinfo = check_addr_port(master, port_str)
 
 
         db_file = args.get('db_file')
         db_file = args.get('db_file')
         if not db_file:
         if not db_file:
             #TODO, the db file path should be got in auth server's configuration
             #TODO, the db file path should be got in auth server's configuration
             db_file = '@@LOCALSTATEDIR@@/@PACKAGE@/zone.sqlite3'
             db_file = '@@LOCALSTATEDIR@@/@PACKAGE@/zone.sqlite3'
 
 
-        return (zone_name, master, port, db_file)
-            
+        return (zone_name, master_addrinfo, db_file)
 
 
     def startup(self):
     def startup(self):
         while not self._shutdown_event.is_set():
         while not self._shutdown_event.is_set():
-            self._cc.check_command()
+            self._cc_check_command()
 
 
-
-    def xfrin_start(self, zone_name, db_file, master_addr, 
-                    port = 53, 
+    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
                     check_soa = True):
                     check_soa = True):
         if "bind10_dns" not in sys.modules:
         if "bind10_dns" not in sys.modules:
             return (1, "xfrin failed, can't load dns message python library: 'bind10_dns'")
             return (1, "xfrin failed, can't load dns message python library: 'bind10_dns'")
@@ -427,11 +467,11 @@ this method we can test most of this class without requiring a command channel.
 
 
         xfrin_thread = threading.Thread(target = process_xfrin,
         xfrin_thread = threading.Thread(target = process_xfrin,
                                         args = (self.recorder,
                                         args = (self.recorder,
-                                                zone_name,
+                                                zone_name, rrclass,
                                                 db_file,
                                                 db_file,
                                                 self._shutdown_event,
                                                 self._shutdown_event,
-                                                master_addr,
-                                                port, check_soa, self._verbose))
+                                                master_addrinfo, check_soa,
+                                                self._verbose))
 
 
         xfrin_thread.start()
         xfrin_thread.start()
         return (0, 'zone xfrin is started')
         return (0, 'zone xfrin is started')
@@ -448,34 +488,53 @@ def set_signal_handler():
     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)
 
 
-def check_port(value):
-    if (value < 0) or (value > 65535):
-        raise XfrinException('requires a port number (0-65535)')
-
-def check_addr(ipstr):
-    ip_family = socket.AF_INET
-    if (ipstr.find(':') != -1):
-        ip_family = socket.AF_INET6
-
+def check_addr_port(addrstr, portstr):
+    # XXX: Linux (glibc)'s getaddrinfo incorrectly accepts numeric port
+    # string larger than 65535.  So we need to explicit validate it separately.
     try:
     try:
-        socket.inet_pton(ip_family, ipstr)
-    except:
-        raise XfrinException("%s invalid ip address" % ipstr)
+        portnum = int(portstr)
+        if portnum < 0 or portnum > 65535:
+            raise ValueError("invalid port number (out of range): " + portstr)
+    except ValueError as err:
+        raise XfrinException("failed to resolve master address/port=%s/%s: %s" %
+                             (addrstr, portstr, str(err)))
 
 
+    try:
+        addrinfo = socket.getaddrinfo(addrstr, portstr, socket.AF_UNSPEC,
+                                      socket.SOCK_STREAM, socket.IPPROTO_TCP,
+                                      socket.AI_NUMERICHOST|
+                                      socket.AI_NUMERICSERV)
+    except socket.gaierror as err:
+        raise XfrinException("failed to resolve master address/port=%s/%s: %s" %
+                             (addrstr, portstr, str(err)))
+    if len(addrinfo) != 1:
+        # with the parameters above the result must be uniquely determined.
+        errmsg = "unexpected result for address/port resolution for %s:%s"
+        raise XfrinException(errmsg % (addrstr, portstr))
+    return addrinfo[0]
 
 
 def set_cmd_options(parser):
 def set_cmd_options(parser):
     parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
     parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
             help="display more about what is going on")
             help="display more about what is going on")
 
 
-    
-if __name__ == '__main__':
+def main(xfrin_class, use_signal = True):
+    """The main loop of the Xfrin daemon.
+
+    @param xfrin_class: A class of the Xfrin object.  This is normally Xfrin,
+    but can be a subclass of it for customization.
+    @param use_signal: True if this process should catch signals.  This is
+    normally True, but may be disabled when this function is called in a
+    testing context."""
+    global xfrind
+
     try:
     try:
         parser = OptionParser(version = __version__)
         parser = OptionParser(version = __version__)
         set_cmd_options(parser)
         set_cmd_options(parser)
         (options, args) = parser.parse_args()
         (options, args) = parser.parse_args()
 
 
-        set_signal_handler()
-        xfrind = Xfrin(verbose = options.verbose)
+        if use_signal:
+            set_signal_handler()
+        xfrind = xfrin_class(verbose = options.verbose)
         xfrind.startup()
         xfrind.startup()
     except KeyboardInterrupt:
     except KeyboardInterrupt:
         log_error("exit b10-xfrin")
         log_error("exit b10-xfrin")
@@ -487,3 +546,6 @@ if __name__ == '__main__':
 
 
     if xfrind:
     if xfrind:
         xfrind.shutdown()
         xfrind.shutdown()
+
+if __name__ == '__main__':
+    main(Xfrin)