Browse Source

[1262] use XfrinConnection._diff for AXFR diffs. also delayed committing
the AXFR changes until XfrinAXFREnd.finish_message().

JINMEI Tatuya 13 years ago
parent
commit
5b302edc63
2 changed files with 92 additions and 73 deletions
  1. 76 58
      src/bin/xfrin/tests/xfrin_test.py
  2. 16 15
      src/bin/xfrin/xfrin.py.in

+ 76 - 58
src/bin/xfrin/tests/xfrin_test.py

@@ -66,6 +66,30 @@ example_soa_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA())
 default_questions = [example_axfr_question]
 default_answers = [soa_rrset]
 
+def check_diffs(assert_fn, expected, actual):
+    '''A helper function checking the differences made in the XFR session.
+
+    This is expected called from some subclass of unittest.TestCase and
+    assert_fn is generally expected to be 'self.assertEqual' of that class.
+
+    '''
+    assert_fn(len(expected), len(actual))
+    for (diffs_exp, diffs_actual) in zip(expected, actual):
+        assert_fn(len(diffs_exp), len(diffs_actual))
+        for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
+            # operation should match
+            assert_fn(diff_exp[0], diff_actual[0])
+            # The diff as RRset should be equal (for simplicity we assume
+            # all RRsets contain exactly one RDATA)
+            assert_fn(diff_exp[1].get_name(), diff_actual[1].get_name())
+            assert_fn(diff_exp[1].get_type(), diff_actual[1].get_type())
+            assert_fn(diff_exp[1].get_class(), diff_actual[1].get_class())
+            assert_fn(diff_exp[1].get_rdata_count(),
+                      diff_actual[1].get_rdata_count())
+            assert_fn(1, diff_exp[1].get_rdata_count())
+            assert_fn(diff_exp[1].get_rdata()[0],
+                      diff_actual[1].get_rdata()[0])
+
 class XfrinTestException(Exception):
     pass
 
@@ -264,8 +288,12 @@ class TestXfrinState(unittest.TestCase):
                               RRTTL(3600))
         self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
                                       'ns.example.com'))
+        self.a_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(),
+                             RRTTL(3600))
+        self.a_rrset.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))
+
         self.conn._datasrc_client = MockDataSourceClient()
-        self.conn._diff = Diff(MockDataSourceClient(), TEST_ZONE_NAME)
+        self.conn._diff = Diff(self.conn._datasrc_client, TEST_ZONE_NAME)
 
 class TestXfrinInitialSOA(TestXfrinState):
     def setUp(self):
@@ -293,6 +321,7 @@ class TestXfrinFirstData(TestXfrinState):
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
         self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        self.conn._diff = None           # should be replaced in the AXFR case
 
     def test_handle_ixfr_begin_soa(self):
         self.conn._request_type = RRType.IXFR()
@@ -312,6 +341,8 @@ class TestXfrinFirstData(TestXfrinState):
         # the initial SOA.  Should switch to AXFR.
         self.assertFalse(self.state.handle_rr(self.conn, self.ns_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        # The Diff for AXFR should be created at this point
+        self.assertNotEqual(None, self.conn._diff)
 
     def test_handle_ixfr_to_axfr_by_different_soa(self):
         # Response contains two consecutive SOA but the serial of the second
@@ -321,6 +352,7 @@ class TestXfrinFirstData(TestXfrinState):
         # be rejected anyway, but at this point we should switch to AXFR.
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        self.assertNotEqual(None, self.conn._diff)
 
     def test_finish_message(self):
         self.assertTrue(self.state.finish_message(self.conn))
@@ -462,23 +494,18 @@ class TestXfrinAXFR(TestXfrinState):
         Test we can put data inside.
         """
         # Put some data inside
-        data = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(), RRTTL(3600))
-        data.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))
-        self.assertTrue(self.state.handle_rr(self.conn, data))
+        self.assertTrue(self.state.handle_rr(self.conn, self.a_rrset))
         # This test uses internal Diff structure to check the behaviour of
         # XfrinAXFR. Maybe there could be a cleaner way, but it would be more
         # complicated.
-        self.assertEqual([('add', data)],
-                         self.state._XfrinAXFR__diff.get_buffer())
-        # This SOA terminates the transef
+        self.assertEqual([('add', self.a_rrset)], self.conn._diff.get_buffer())
+        # This SOA terminates the transfer
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         # It should have changed the state
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
-        # The data should have been commited
-        # FIXME: Or, should we wait and see there are no data after this,
-        # to do the commit? Pass it to the AXFREnd?
-        self.assertEqual([], self.state._XfrinAXFR__diff.get_buffer())
-        self.assertRaises(ValueError, self.state._XfrinAXFR__diff.commit)
+        # At this point, the data haven't been committed yet
+        self.assertEqual([('add', self.a_rrset), ('add', soa_rrset)],
+                         self.conn._diff.get_buffer())
 
     def test_finish_message(self):
         """
@@ -497,8 +524,17 @@ class TestXfrinAXFREnd(TestXfrinState):
                           self.ns_rrset)
 
     def test_finish_message(self):
+        self.conn._diff.add_data(self.a_rrset)
+        self.conn._diff.add_data(soa_rrset)
         self.assertFalse(self.state.finish_message(self.conn))
 
+        # The data should have been committed
+        self.assertEqual([], self.conn._diff.get_buffer())
+        check_diffs(self.assertEqual, [[('add', self.a_rrset),
+                                        ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+        self.assertRaises(ValueError, self.conn._diff.commit)
+
 class TestXfrinConnection(unittest.TestCase):
     '''Convenient parent class for XFR-protocol tests.
 
@@ -587,30 +623,6 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
 
-    def check_diffs(self, expected, actual):
-        '''A helper method checking the differences made in the IXFR session.
-
-        '''
-        self.assertEqual(len(expected), len(actual))
-        for (diffs_exp, diffs_actual) in zip(expected, actual):
-            self.assertEqual(len(diffs_exp), len(diffs_actual))
-            for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
-                # operation should match
-                self.assertEqual(diff_exp[0], diff_actual[0])
-                # The diff as RRset should be equal (for simplicity we assume
-                # all RRsets contain exactly one RDATA)
-                self.assertEqual(diff_exp[1].get_name(),
-                                 diff_actual[1].get_name())
-                self.assertEqual(diff_exp[1].get_type(),
-                                 diff_actual[1].get_type())
-                self.assertEqual(diff_exp[1].get_class(),
-                                 diff_actual[1].get_class())
-                self.assertEqual(diff_exp[1].get_rdata_count(),
-                                 diff_actual[1].get_rdata_count())
-                self.assertEqual(1, diff_exp[1].get_rdata_count())
-                self.assertEqual(diff_exp[1].get_rdata()[0],
-                                 diff_actual[1].get_rdata()[0])
-
     def _create_a(self, address):
         rrset = RRset(Name('a.example.com'), TEST_RRCLASS, RRType.A(),
                       RRTTL(3600))
@@ -1099,8 +1111,9 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_multi_sequences(self):
         '''Similar to the previous case, but with multiple diff seqs.
@@ -1125,19 +1138,20 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('remove', begin_soa_rrset),
-                           ('remove', self._create_a('192.0.2.1')),
-                           ('add', self._create_soa('1231')),
-                           ('add', self._create_a('192.0.2.2'))],
-                          [('remove', self._create_soa('1231')),
-                           ('remove', self._create_a('192.0.2.3')),
-                           ('add', self._create_soa('1232')),
-                           ('add', self._create_a('192.0.2.4'))],
-                          [('remove', self._create_soa('1232')),
-                           ('remove', self._create_a('192.0.2.5')),
-                           ('add', soa_rrset),
-                           ('add', self._create_a('192.0.2.6'))]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('remove', begin_soa_rrset),
+                      ('remove', self._create_a('192.0.2.1')),
+                      ('add', self._create_soa('1231')),
+                      ('add', self._create_a('192.0.2.2'))],
+                     [('remove', self._create_soa('1231')),
+                      ('remove', self._create_a('192.0.2.3')),
+                      ('add', self._create_soa('1232')),
+                      ('add', self._create_a('192.0.2.4'))],
+                     [('remove', self._create_soa('1232')),
+                      ('remove', self._create_a('192.0.2.5')),
+                      ('add', soa_rrset),
+                      ('add', self._create_a('192.0.2.6'))]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_multi_messages(self):
         '''Similar to the first case, but RRs span over multiple messages.
@@ -1151,8 +1165,9 @@ class TestIXFRResponse(TestXfrinConnection):
             answers=[soa_rrset])
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_broken(self):
         '''Test with a broken response.
@@ -1166,7 +1181,8 @@ class TestIXFRResponse(TestXfrinConnection):
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
         # no diffs should have been committed
-        self.check_diffs([], self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_extra(self):
         '''Test with an extra RR after the end of IXFR diff sequences.
@@ -1181,8 +1197,9 @@ class TestIXFRResponse(TestXfrinConnection):
                      self._create_a('192.0.2.1')])
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
 class TestIXFRSession(TestXfrinConnection):
     '''Tests for a full IXFR session (query and response).
@@ -1205,8 +1222,9 @@ class TestIXFRSession(TestXfrinConnection):
 
         # Check some details of the IXFR protocol processing
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
         # Check if the query was IXFR.
         qdata = self.conn.query_data[2:]

+ 16 - 15
src/bin/xfrin/xfrin.py.in

@@ -150,12 +150,12 @@ class XfrinState:
     process successfully completes.
 
 
-            (recv SOA)       (AXFR-style IXFR)      (SOA)
+            (recv SOA)       (AXFR-style IXFR)   (SOA, add)
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
-                          |                  |  ^
-                          |                  |  |
-                          |                  +--+
-                          |                (non SOA)
+                          |                  |  ^         (post xfr
+                          |                  |  |        checks, then
+                          |                  +--+        commit)
+                          |            (non SOA, add)
                           |
                           |                     (non SOA, delete)
                (pure IXFR,|                           +-------+
@@ -259,6 +259,9 @@ class XfrinFirstData(XfrinState):
         else:
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_NONINCREMENTAL_RESP,
                  conn.zone_str())
+            # We are now goint to add RRs to the new zone.  We need create
+            # a Diff object.  It will be used throughtout the XFR session.
+            conn._diff = Diff(conn._datasrc_client, conn._zone_name, True)
             self.set_xfrstate(conn, XfrinAXFR())
         return False    # need to revisit this RR in an update context
 
@@ -332,22 +335,15 @@ class XfrinIXFREnd(XfrinState):
         return False
 
 class XfrinAXFR(XfrinState):
-    def __init__(self):
-        self.__diff = None
-
     def handle_rr(self, conn, rr):
         """
         Handle the RR by putting it into the zone.
         """
-        if self.__diff is None:
-            # This is the first RR there. So create the diff to accumulate
-            # data
-            self.__diff = Diff(conn._datasrc_client, conn._zone_name, True)
-        self.__diff.add_data(rr)
+        conn._diff.add_data(rr)
         if rr.get_type() == RRType.SOA():
-            # Soa means end. We commit the data and move to the final state
+            # SOA means end.  Don't commit it yet - we need to perform
+            # post-transfer checks
             self.set_xfrstate(conn, XfrinAXFREnd())
-            self.__diff.commit()
         # Yes, we've eaten this RR.
         return True
 
@@ -360,9 +356,14 @@ class XfrinAXFREnd(XfrinState):
         """
         Final processing after processing an entire AXFR session.
 
+        In this process all the AXFR changes are committed to the
+        data source.
+
         There might be more actions here, but for now we simply return False,
         indicating there will be no more message to receive.
+
         """
+        conn._diff.commit()
         return False
 
 class XfrinConnection(asyncore.dispatcher):