Browse Source

[1262] use XfrinConnection._diff for AXFR diffs. also delayed committing
the AXFR changes until XfrinAXFREnd.finish_message().

JINMEI Tatuya 13 years ago
parent
commit
5b302edc63
2 changed files with 92 additions and 73 deletions
  1. 76 58
      src/bin/xfrin/tests/xfrin_test.py
  2. 16 15
      src/bin/xfrin/xfrin.py.in

+ 76 - 58
src/bin/xfrin/tests/xfrin_test.py

@@ -66,6 +66,30 @@ example_soa_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA())
 default_questions = [example_axfr_question]
 default_questions = [example_axfr_question]
 default_answers = [soa_rrset]
 default_answers = [soa_rrset]
 
 
+def check_diffs(assert_fn, expected, actual):
+    '''A helper function checking the differences made in the XFR session.
+
+    This is expected called from some subclass of unittest.TestCase and
+    assert_fn is generally expected to be 'self.assertEqual' of that class.
+
+    '''
+    assert_fn(len(expected), len(actual))
+    for (diffs_exp, diffs_actual) in zip(expected, actual):
+        assert_fn(len(diffs_exp), len(diffs_actual))
+        for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
+            # operation should match
+            assert_fn(diff_exp[0], diff_actual[0])
+            # The diff as RRset should be equal (for simplicity we assume
+            # all RRsets contain exactly one RDATA)
+            assert_fn(diff_exp[1].get_name(), diff_actual[1].get_name())
+            assert_fn(diff_exp[1].get_type(), diff_actual[1].get_type())
+            assert_fn(diff_exp[1].get_class(), diff_actual[1].get_class())
+            assert_fn(diff_exp[1].get_rdata_count(),
+                      diff_actual[1].get_rdata_count())
+            assert_fn(1, diff_exp[1].get_rdata_count())
+            assert_fn(diff_exp[1].get_rdata()[0],
+                      diff_actual[1].get_rdata()[0])
+
 class XfrinTestException(Exception):
 class XfrinTestException(Exception):
     pass
     pass
 
 
@@ -264,8 +288,12 @@ class TestXfrinState(unittest.TestCase):
                               RRTTL(3600))
                               RRTTL(3600))
         self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
         self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
                                       'ns.example.com'))
                                       'ns.example.com'))
+        self.a_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(),
+                             RRTTL(3600))
+        self.a_rrset.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))
+
         self.conn._datasrc_client = MockDataSourceClient()
         self.conn._datasrc_client = MockDataSourceClient()
-        self.conn._diff = Diff(MockDataSourceClient(), TEST_ZONE_NAME)
+        self.conn._diff = Diff(self.conn._datasrc_client, TEST_ZONE_NAME)
 
 
 class TestXfrinInitialSOA(TestXfrinState):
 class TestXfrinInitialSOA(TestXfrinState):
     def setUp(self):
     def setUp(self):
@@ -293,6 +321,7 @@ class TestXfrinFirstData(TestXfrinState):
         self.state = XfrinFirstData()
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
         self.conn._request_type = RRType.IXFR()
         self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
         self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        self.conn._diff = None           # should be replaced in the AXFR case
 
 
     def test_handle_ixfr_begin_soa(self):
     def test_handle_ixfr_begin_soa(self):
         self.conn._request_type = RRType.IXFR()
         self.conn._request_type = RRType.IXFR()
@@ -312,6 +341,8 @@ class TestXfrinFirstData(TestXfrinState):
         # the initial SOA.  Should switch to AXFR.
         # the initial SOA.  Should switch to AXFR.
         self.assertFalse(self.state.handle_rr(self.conn, self.ns_rrset))
         self.assertFalse(self.state.handle_rr(self.conn, self.ns_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        # The Diff for AXFR should be created at this point
+        self.assertNotEqual(None, self.conn._diff)
 
 
     def test_handle_ixfr_to_axfr_by_different_soa(self):
     def test_handle_ixfr_to_axfr_by_different_soa(self):
         # Response contains two consecutive SOA but the serial of the second
         # Response contains two consecutive SOA but the serial of the second
@@ -321,6 +352,7 @@ class TestXfrinFirstData(TestXfrinState):
         # be rejected anyway, but at this point we should switch to AXFR.
         # be rejected anyway, but at this point we should switch to AXFR.
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        self.assertNotEqual(None, self.conn._diff)
 
 
     def test_finish_message(self):
     def test_finish_message(self):
         self.assertTrue(self.state.finish_message(self.conn))
         self.assertTrue(self.state.finish_message(self.conn))
@@ -462,23 +494,18 @@ class TestXfrinAXFR(TestXfrinState):
         Test we can put data inside.
         Test we can put data inside.
         """
         """
         # Put some data inside
         # Put some data inside
-        data = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(), RRTTL(3600))
+        self.assertTrue(self.state.handle_rr(self.conn, self.a_rrset))
-        data.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))
-        self.assertTrue(self.state.handle_rr(self.conn, data))
         # This test uses internal Diff structure to check the behaviour of
         # This test uses internal Diff structure to check the behaviour of
         # XfrinAXFR. Maybe there could be a cleaner way, but it would be more
         # XfrinAXFR. Maybe there could be a cleaner way, but it would be more
         # complicated.
         # complicated.
-        self.assertEqual([('add', data)],
+        self.assertEqual([('add', self.a_rrset)], self.conn._diff.get_buffer())
-                         self.state._XfrinAXFR__diff.get_buffer())
+        # This SOA terminates the transfer
-        # This SOA terminates the transef
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         # It should have changed the state
         # It should have changed the state
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
-        # The data should have been commited
+        # At this point, the data haven't been committed yet
-        # FIXME: Or, should we wait and see there are no data after this,
+        self.assertEqual([('add', self.a_rrset), ('add', soa_rrset)],
-        # to do the commit? Pass it to the AXFREnd?
+                         self.conn._diff.get_buffer())
-        self.assertEqual([], self.state._XfrinAXFR__diff.get_buffer())
-        self.assertRaises(ValueError, self.state._XfrinAXFR__diff.commit)
 
 
     def test_finish_message(self):
     def test_finish_message(self):
         """
         """
@@ -497,8 +524,17 @@ class TestXfrinAXFREnd(TestXfrinState):
                           self.ns_rrset)
                           self.ns_rrset)
 
 
     def test_finish_message(self):
     def test_finish_message(self):
+        self.conn._diff.add_data(self.a_rrset)
+        self.conn._diff.add_data(soa_rrset)
         self.assertFalse(self.state.finish_message(self.conn))
         self.assertFalse(self.state.finish_message(self.conn))
 
 
+        # The data should have been committed
+        self.assertEqual([], self.conn._diff.get_buffer())
+        check_diffs(self.assertEqual, [[('add', self.a_rrset),
+                                        ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+        self.assertRaises(ValueError, self.conn._diff.commit)
+
 class TestXfrinConnection(unittest.TestCase):
 class TestXfrinConnection(unittest.TestCase):
     '''Convenient parent class for XFR-protocol tests.
     '''Convenient parent class for XFR-protocol tests.
 
 
@@ -587,30 +623,6 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
         self.conn.reply_data += bogus_data
 
 
-    def check_diffs(self, expected, actual):
-        '''A helper method checking the differences made in the IXFR session.
-
-        '''
-        self.assertEqual(len(expected), len(actual))
-        for (diffs_exp, diffs_actual) in zip(expected, actual):
-            self.assertEqual(len(diffs_exp), len(diffs_actual))
-            for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
-                # operation should match
-                self.assertEqual(diff_exp[0], diff_actual[0])
-                # The diff as RRset should be equal (for simplicity we assume
-                # all RRsets contain exactly one RDATA)
-                self.assertEqual(diff_exp[1].get_name(),
-                                 diff_actual[1].get_name())
-                self.assertEqual(diff_exp[1].get_type(),
-                                 diff_actual[1].get_type())
-                self.assertEqual(diff_exp[1].get_class(),
-                                 diff_actual[1].get_class())
-                self.assertEqual(diff_exp[1].get_rdata_count(),
-                                 diff_actual[1].get_rdata_count())
-                self.assertEqual(1, diff_exp[1].get_rdata_count())
-                self.assertEqual(diff_exp[1].get_rdata()[0],
-                                 diff_actual[1].get_rdata()[0])
-
     def _create_a(self, address):
     def _create_a(self, address):
         rrset = RRset(Name('a.example.com'), TEST_RRCLASS, RRType.A(),
         rrset = RRset(Name('a.example.com'), TEST_RRCLASS, RRType.A(),
                       RRTTL(3600))
                       RRTTL(3600))
@@ -1099,8 +1111,9 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+        check_diffs(self.assertEqual,
-                         self.conn._datasrc_client.committed_diffs)
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
 
     def test_ixfr_response_multi_sequences(self):
     def test_ixfr_response_multi_sequences(self):
         '''Similar to the previous case, but with multiple diff seqs.
         '''Similar to the previous case, but with multiple diff seqs.
@@ -1125,19 +1138,20 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('remove', begin_soa_rrset),
+        check_diffs(self.assertEqual,
-                           ('remove', self._create_a('192.0.2.1')),
+                    [[('remove', begin_soa_rrset),
-                           ('add', self._create_soa('1231')),
+                      ('remove', self._create_a('192.0.2.1')),
-                           ('add', self._create_a('192.0.2.2'))],
+                      ('add', self._create_soa('1231')),
-                          [('remove', self._create_soa('1231')),
+                      ('add', self._create_a('192.0.2.2'))],
-                           ('remove', self._create_a('192.0.2.3')),
+                     [('remove', self._create_soa('1231')),
-                           ('add', self._create_soa('1232')),
+                      ('remove', self._create_a('192.0.2.3')),
-                           ('add', self._create_a('192.0.2.4'))],
+                      ('add', self._create_soa('1232')),
-                          [('remove', self._create_soa('1232')),
+                      ('add', self._create_a('192.0.2.4'))],
-                           ('remove', self._create_a('192.0.2.5')),
+                     [('remove', self._create_soa('1232')),
-                           ('add', soa_rrset),
+                      ('remove', self._create_a('192.0.2.5')),
-                           ('add', self._create_a('192.0.2.6'))]],
+                      ('add', soa_rrset),
-                         self.conn._datasrc_client.committed_diffs)
+                      ('add', self._create_a('192.0.2.6'))]],
+                    self.conn._datasrc_client.committed_diffs)
 
 
     def test_ixfr_response_multi_messages(self):
     def test_ixfr_response_multi_messages(self):
         '''Similar to the first case, but RRs span over multiple messages.
         '''Similar to the first case, but RRs span over multiple messages.
@@ -1151,8 +1165,9 @@ class TestIXFRResponse(TestXfrinConnection):
             answers=[soa_rrset])
             answers=[soa_rrset])
         self.conn._handle_xfrin_responses()
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+        check_diffs(self.assertEqual,
-                         self.conn._datasrc_client.committed_diffs)
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
 
     def test_ixfr_response_broken(self):
     def test_ixfr_response_broken(self):
         '''Test with a broken response.
         '''Test with a broken response.
@@ -1166,7 +1181,8 @@ class TestIXFRResponse(TestXfrinConnection):
         self.assertRaises(XfrinProtocolError,
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
                           self.conn._handle_xfrin_responses)
         # no diffs should have been committed
         # no diffs should have been committed
-        self.check_diffs([], self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
 
 
     def test_ixfr_response_extra(self):
     def test_ixfr_response_extra(self):
         '''Test with an extra RR after the end of IXFR diff sequences.
         '''Test with an extra RR after the end of IXFR diff sequences.
@@ -1181,8 +1197,9 @@ class TestIXFRResponse(TestXfrinConnection):
                      self._create_a('192.0.2.1')])
                      self._create_a('192.0.2.1')])
         self.assertRaises(XfrinProtocolError,
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
                           self.conn._handle_xfrin_responses)
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+        check_diffs(self.assertEqual,
-                         self.conn._datasrc_client.committed_diffs)
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
 
 class TestIXFRSession(TestXfrinConnection):
 class TestIXFRSession(TestXfrinConnection):
     '''Tests for a full IXFR session (query and response).
     '''Tests for a full IXFR session (query and response).
@@ -1205,8 +1222,9 @@ class TestIXFRSession(TestXfrinConnection):
 
 
         # Check some details of the IXFR protocol processing
         # Check some details of the IXFR protocol processing
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+        check_diffs(self.assertEqual,
-                         self.conn._datasrc_client.committed_diffs)
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
 
         # Check if the query was IXFR.
         # Check if the query was IXFR.
         qdata = self.conn.query_data[2:]
         qdata = self.conn.query_data[2:]

+ 16 - 15
src/bin/xfrin/xfrin.py.in

@@ -150,12 +150,12 @@ class XfrinState:
     process successfully completes.
     process successfully completes.
 
 
 
 
-            (recv SOA)       (AXFR-style IXFR)      (SOA)
+            (recv SOA)       (AXFR-style IXFR)   (SOA, add)
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
-                          |                  |  ^
+                          |                  |  ^         (post xfr
-                          |                  |  |
+                          |                  |  |        checks, then
-                          |                  +--+
+                          |                  +--+        commit)
-                          |                (non SOA)
+                          |            (non SOA, add)
                           |
                           |
                           |                     (non SOA, delete)
                           |                     (non SOA, delete)
                (pure IXFR,|                           +-------+
                (pure IXFR,|                           +-------+
@@ -259,6 +259,9 @@ class XfrinFirstData(XfrinState):
         else:
         else:
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_NONINCREMENTAL_RESP,
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_NONINCREMENTAL_RESP,
                  conn.zone_str())
                  conn.zone_str())
+            # We are now goint to add RRs to the new zone.  We need create
+            # a Diff object.  It will be used throughtout the XFR session.
+            conn._diff = Diff(conn._datasrc_client, conn._zone_name, True)
             self.set_xfrstate(conn, XfrinAXFR())
             self.set_xfrstate(conn, XfrinAXFR())
         return False    # need to revisit this RR in an update context
         return False    # need to revisit this RR in an update context
 
 
@@ -332,22 +335,15 @@ class XfrinIXFREnd(XfrinState):
         return False
         return False
 
 
 class XfrinAXFR(XfrinState):
 class XfrinAXFR(XfrinState):
-    def __init__(self):
-        self.__diff = None
-
     def handle_rr(self, conn, rr):
     def handle_rr(self, conn, rr):
         """
         """
         Handle the RR by putting it into the zone.
         Handle the RR by putting it into the zone.
         """
         """
-        if self.__diff is None:
+        conn._diff.add_data(rr)
-            # This is the first RR there. So create the diff to accumulate
-            # data
-            self.__diff = Diff(conn._datasrc_client, conn._zone_name, True)
-        self.__diff.add_data(rr)
         if rr.get_type() == RRType.SOA():
         if rr.get_type() == RRType.SOA():
-            # Soa means end. We commit the data and move to the final state
+            # SOA means end.  Don't commit it yet - we need to perform
+            # post-transfer checks
             self.set_xfrstate(conn, XfrinAXFREnd())
             self.set_xfrstate(conn, XfrinAXFREnd())
-            self.__diff.commit()
         # Yes, we've eaten this RR.
         # Yes, we've eaten this RR.
         return True
         return True
 
 
@@ -360,9 +356,14 @@ class XfrinAXFREnd(XfrinState):
         """
         """
         Final processing after processing an entire AXFR session.
         Final processing after processing an entire AXFR session.
 
 
+        In this process all the AXFR changes are committed to the
+        data source.
+
         There might be more actions here, but for now we simply return False,
         There might be more actions here, but for now we simply return False,
         indicating there will be no more message to receive.
         indicating there will be no more message to receive.
+
         """
         """
+        conn._diff.commit()
         return False
         return False
 
 
 class XfrinConnection(asyncore.dispatcher):
 class XfrinConnection(asyncore.dispatcher):