Browse Source

[1261] added finish_message() method for the xfrin state classes and
supported multiple-message IXFR session.

JINMEI Tatuya 13 years ago
parent
commit
9163b38338
2 changed files with 59 additions and 3 deletions
  1. 37 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 22 3
      src/bin/xfrin/xfrin.py.in

+ 37 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -283,6 +283,9 @@ class TestXfrinInitialSOA(TestXfrinState):
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                           self.ns_rrset)
                           self.ns_rrset)
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinFirstData(TestXfrinState):
 class TestXfrinFirstData(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -318,6 +321,9 @@ class TestXfrinFirstData(TestXfrinState):
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinIXFRDeleteSOA(TestXfrinState):
 class TestXfrinIXFRDeleteSOA(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -337,6 +343,9 @@ class TestXfrinIXFRDeleteSOA(TestXfrinState):
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
                           self.ns_rrset)
                           self.ns_rrset)
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinIXFRDelete(TestXfrinState):
 class TestXfrinIXFRDelete(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -364,6 +373,9 @@ class TestXfrinIXFRDelete(TestXfrinState):
         self.assertEqual(type(XfrinIXFRAddSOA()),
         self.assertEqual(type(XfrinIXFRAddSOA()),
                          type(self.conn.get_xfrstate()))
                          type(self.conn.get_xfrstate()))
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinIXFRAddSOA(TestXfrinState):
 class TestXfrinIXFRAddSOA(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -379,6 +391,9 @@ class TestXfrinIXFRAddSOA(TestXfrinState):
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
                           self.ns_rrset)
                           self.ns_rrset)
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinIXFRAdd(TestXfrinState):
 class TestXfrinIXFRAdd(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -420,6 +435,9 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
                           self.conn, soa_rrset)
                           self.conn, soa_rrset)
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinIXFREnd(TestXfrinState):
 class TestXfrinIXFREnd(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -429,6 +447,9 @@ class TestXfrinIXFREnd(TestXfrinState):
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                           self.ns_rrset)
                           self.ns_rrset)
 
 
+    def test_finish_message(self):
+        self.assertFalse(self.state.finish_message(self.conn))
+
 class TestXfrinAXFR(TestXfrinState):
 class TestXfrinAXFR(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -438,6 +459,9 @@ class TestXfrinAXFR(TestXfrinState):
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
                           soa_rrset)
                           soa_rrset)
 
 
+    def test_finish_message(self):
+        self.assertTrue(self.state.finish_message(self.conn))
+
 class TestXfrinConnection(unittest.TestCase):
 class TestXfrinConnection(unittest.TestCase):
     '''Convenient parent class for XFR-protocol tests.
     '''Convenient parent class for XFR-protocol tests.
 
 
@@ -1081,6 +1105,19 @@ class TestIXFR(TestXfrinConnection):
                            ('add', self.create_a('192.0.2.6'))]],
                            ('add', self.create_a('192.0.2.6'))]],
                          self.conn._datasrc_client.committed_diffs)
                          self.conn._datasrc_client.committed_diffs)
 
 
+    def test_ixfr_response_multi_messages(self):
+        '''Similar to the first case, but RRs span over multiple messages.
+
+        '''
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[soa_rrset, begin_soa_rrset, soa_rrset])
+        self.conn.reply_data += self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[soa_rrset])
+        self.conn._handle_xfrin_responses()
+        self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
+
 class TestXfrinRecorder(unittest.TestCase):
 class TestXfrinRecorder(unittest.TestCase):
     def setUp(self):
     def setUp(self):
         self.recorder = XfrinRecorder()
         self.recorder = XfrinRecorder()

+ 22 - 3
src/bin/xfrin/xfrin.py.in

@@ -146,6 +146,17 @@ class XfrinState:
         '''
         '''
         conn._XfrinConnection__set_xfrstate(new_state)
         conn._XfrinConnection__set_xfrstate(new_state)
 
 
+    def finish_message(self, conn):
+        '''Perform any final processing after handling all RRs of a response.
+
+        This method then returns a boolean indicating whether to continue
+        receiving the message.  Unless it's in the end of the entire XFR
+        session, we should continue, so this default method simply returns
+        True.
+
+        '''
+        return True
+
 class XfrinInitialSOA(XfrinState):
 class XfrinInitialSOA(XfrinState):
     def handle_rr(self, conn, rr):
     def handle_rr(self, conn, rr):
         if rr.get_type() != RRType.SOA():
         if rr.get_type() != RRType.SOA():
@@ -238,6 +249,15 @@ class XfrinIXFREnd(XfrinState):
         raise XfrinProtocolError('Extra data after the end of IXFR diffs: ' + \
         raise XfrinProtocolError('Extra data after the end of IXFR diffs: ' + \
                                      rr.to_text())
                                      rr.to_text())
 
 
+    def finish_message(self, conn):
+        '''Final processing after processing an entire IXFR session.
+
+        There will be more actions here, but for now we simply return False,
+        indicating there will be no more message to receive.
+
+        '''
+        return False
+
 class XfrinAXFR(XfrinState):
 class XfrinAXFR(XfrinState):
     def handle_rr(self, conn, rr):
     def handle_rr(self, conn, rr):
         raise XfrinException('Falling back from IXFR to AXFR not ' + \
         raise XfrinException('Falling back from IXFR to AXFR not ' + \
@@ -559,12 +579,11 @@ class XfrinConnection(asyncore.dispatcher):
                 while not rr_handled:
                 while not rr_handled:
                     rr_handled = self.__state.handle_rr(self, rr)
                     rr_handled = self.__state.handle_rr(self, rr)
 
 
+            read_next_msg = self.__state.finish_message(self)
+
             if self._shutdown_event.is_set():
             if self._shutdown_event.is_set():
                 raise XfrinException('xfrin is forced to stop')
                 raise XfrinException('xfrin is forced to stop')
 
 
-            # placeholder
-            read_next_msg = False
-
     def _handle_axfrin_response(self):
     def _handle_axfrin_response(self):
         '''Return a generator for the response to a zone transfer. '''
         '''Return a generator for the response to a zone transfer. '''
         while True:
         while True: