Parcourir la source

[1209] unified AXFR message handling logic into _handle_xfrin_responses().
removed now-unused code path, and adjusted existing test cases accordingly.

JINMEI Tatuya il y a 13 ans
Parent
commit
6aa910d630
2 fichiers modifiés avec 44 ajouts et 71 suppressions
  1. 41 36
      src/bin/xfrin/tests/xfrin_test.py
  2. 3 35
      src/bin/xfrin/xfrin.py.in

+ 41 - 36
src/bin/xfrin/tests/xfrin_test.py

@@ -93,6 +93,9 @@ def check_diffs(assert_fn, expected, actual):
 class XfrinTestException(Exception):
     pass
 
+class XfrinTestTimeoutException(Exception):
+    pass
+
 class MockCC():
     def get_default_value(self, identifier):
         if identifier == "zones/master_port":
@@ -219,6 +222,8 @@ class MockXfrinConnection(XfrinConnection):
     def recv(self, size):
         data = self.reply_data[:size]
         self.reply_data = self.reply_data[size:]
+        if len(data) == 0:
+            raise XfrinTestTimeoutException('Emulated timeout')
         if len(data) < size:
             raise XfrinTestException('cannot get reply data (' + str(size) +
                                      ' bytes)')
@@ -569,22 +574,14 @@ class TestXfrinConnection(unittest.TestCase):
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
 
-    def _handle_xfrin_response(self):
-        # This helper methods iterates over all RRs (excluding the ending SOA)
-        # transferred, and simply returns the number of RRs.  The return value
-        # may be used an assertion value for test cases.
-        rrs = 0
-        for rr in self.conn._handle_axfrin_response():
-            rrs += 1
-        return rrs
-
     def _create_normal_response_data(self):
         # This helper method creates a simple sequence of DNS messages that
-        # forms a valid AXFR transaction.  It consists of two messages, each
-        # containing just a single SOA RR.
+        # forms a valid AXFR transaction.  It consists of two messages: the
+        # first one containing SOA, NS, the second containing the trailing SOA.
         tsig_1st = self.axfr_response_params['tsig_1st']
         tsig_2nd = self.axfr_response_params['tsig_2nd']
-        self.conn.reply_data = self.conn.create_response_data(tsig_ctx=tsig_1st)
+        self.conn.reply_data = self.conn.create_response_data(
+            answers=[soa_rrset, self._create_ns()], tsig_ctx=tsig_1st)
         self.conn.reply_data += \
             self.conn.create_response_data(tsig_ctx=tsig_2nd)
 
@@ -644,6 +641,7 @@ class TestXfrinConnection(unittest.TestCase):
 class TestAXFR(TestXfrinConnection):
     def setUp(self):
         super().setUp()
+        XfrinInitialSOA().set_xfrstate(self.conn, XfrinInitialSOA())
 
     def __create_mock_tsig(self, key, error):
         # This helper function creates a MockTSIGContext for a given key
@@ -775,24 +773,28 @@ class TestAXFR(TestXfrinConnection):
 
     def test_response_with_invalid_msg(self):
         self.conn.reply_data = b'aaaxxxx'
-        self.assertRaises(XfrinTestException, self._handle_xfrin_response)
+        self.assertRaises(XfrinTestException,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_with_tsigfail(self):
         self.conn._tsig_key = TSIG_KEY
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         self.conn._send_query(RRType.SOA())
         self.conn.reply_data = self.conn.create_response_data(rcode=Rcode.NOTAUTH())
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_without_end_soa(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data()
-        self.assertRaises(XfrinTestException, self._handle_xfrin_response)
+        # This should result in timeout in the asyncore loop.  We emulate
+        # that situation in recv() by emptying the reply data buffer.
+        self.assertRaises(XfrinTestTimeoutException,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_bad_qid(self):
         self.conn._send_query(RRType.AXFR())
-        self.conn.reply_data = self.conn.create_response_data(bad_qid = True)
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_error_code_bad_sig(self):
         self.conn._tsig_key = TSIG_KEY
@@ -805,7 +807,7 @@ class TestAXFR(TestXfrinConnection):
         # validate log message for XfrinException
         self.__match_exception(XfrinException,
                                "TSIG verify fail: BADSIG",
-                               self._handle_xfrin_response)
+                               self.conn._handle_xfrin_responses)
 
     def test_response_bad_qid_bad_key(self):
         self.conn._tsig_key = TSIG_KEY
@@ -817,36 +819,36 @@ class TestAXFR(TestXfrinConnection):
         # validate log message for XfrinException
         self.__match_exception(XfrinException,
                                "TSIG verify fail: BADKEY",
-                               self._handle_xfrin_response)
+                               self.conn._handle_xfrin_responses)
 
     def test_response_non_response(self):
         self.conn._send_query(RRType.AXFR())
-        self.conn.reply_data = self.conn.create_response_data(response = False)
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.conn.reply_data = self.conn.create_response_data(response=False)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_error_code(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
             rcode=Rcode.SERVFAIL())
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_multi_question(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
             questions=[example_axfr_question, example_axfr_question])
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_empty_answer(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(answers=[])
         # Should an empty answer trigger an exception?  Even though it's very
         # unusual it's not necessarily invalid.  Need to revisit.
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_non_response(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(response = False)
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_soacheck(self):
         # we need to defer the creation until we know the QID, which is
@@ -937,30 +939,32 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_normal_response_data
         self.conn._shutdown_event.set()
         self.conn._send_query(RRType.AXFR())
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_timeout(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.force_time_out = True
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_remote_close(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.force_close = True
-        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
 
     def test_response_bad_message(self):
         self.conn.response_generator = self._create_broken_response_data
         self.conn._send_query(RRType.AXFR())
-        self.assertRaises(Exception, self._handle_xfrin_response)
+        self.assertRaises(Exception, self.conn._handle_xfrin_responses)
 
     def test_axfr_response(self):
-        # normal case.
+        # A simple normal case: AXFR consists of SOA, NS, then trailing SOA.
         self.conn.response_generator = self._create_normal_response_data
         self.conn._send_query(RRType.AXFR())
-        # two SOAs, and only these have been transfered.  the 2nd SOA is just
-        # a marker, so only 1 RR has been provided in the iteration.
-        self.assertEqual(self._handle_xfrin_response(), 1)
+        self.conn._handle_xfrin_responses()
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        check_diffs(self.assertEqual,
+                    [[('add', self._create_ns()), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_do_xfrin(self):
         self.conn.response_generator = self._create_normal_response_data
@@ -974,9 +978,10 @@ class TestAXFR(TestXfrinConnection):
             lambda key: self.__create_mock_tsig(key, TSIGError.NOERROR)
         self.conn.response_generator = self._create_normal_response_data
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
-        # We use two messages in the tests.  The same context should have been
-        # usef for both.
-        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        check_diffs(self.assertEqual,
+                    [[('add', self._create_ns()), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_do_xfrin_with_tsig_fail(self):
         # TSIG verify will fail for the first message.  xfrin should fail

+ 3 - 35
src/bin/xfrin/xfrin.py.in

@@ -586,16 +586,9 @@ class XfrinConnection(asyncore.dispatcher):
             if ret == XFRIN_OK:
                 logger.info(XFRIN_XFR_TRANSFER_STARTED, request_str,
                             self.zone_str())
-                if self._request_type == RRType.IXFR():
-                    self._request_type = RRType.IXFR()
-                    self._send_query(self._request_type)
-                    self.__state = XfrinInitialSOA()
-                    self._handle_xfrin_responses()
-                else:
-                    self._send_query(self._request_type)
-                    isc.datasrc.sqlite3_ds.load(self._db_file,
-                                                self._zone_name.to_text(),
-                                                self._handle_axfrin_response)
+                self._send_query(self._request_type)
+                self.__state = XfrinInitialSOA()
+                self._handle_xfrin_responses()
                 logger.info(XFRIN_XFR_TRANSFER_SUCCESS, request_str,
                             self.zone_str())
 
@@ -713,31 +706,6 @@ class XfrinConnection(asyncore.dispatcher):
             if self._shutdown_event.is_set():
                 raise XfrinException('xfrin is forced to stop')
 
-    def _handle_axfrin_response(self):
-        '''Return a generator for the response to a zone transfer. '''
-        while True:
-            data_len = self._get_request_response(2)
-            msg_len = socket.htons(struct.unpack('H', data_len)[0])
-            recvdata = self._get_request_response(msg_len)
-            msg = Message(Message.PARSE)
-            msg.from_wire(recvdata)
-
-            # TSIG related checks, including an unexpected signed response
-            self._check_response_tsig(msg, recvdata)
-
-            # Perform response status validation
-            self._check_response_status(msg)
-
-            answer_section = msg.get_section(Message.SECTION_ANSWER)
-            for rr in self._handle_answer_section(answer_section):
-                yield rr
-
-            if self._soa_rr_count == 2:
-                break
-
-            if self._shutdown_event.is_set():
-                raise XfrinException('xfrin is forced to stop')
-
     def handle_read(self):
         '''Read query's response from socket. '''