Browse Source

Merge #1262

Conflicts:
	src/bin/xfrin/tests/xfrin_test.py
	src/lib/python/isc/xfrin/diff.py
	src/lib/python/isc/xfrin/tests/diff_tests.py
Michal 'vorner' Vaner 13 years ago
parent
commit
929daeade2

+ 209 - 49
src/bin/xfrin/tests/xfrin_test.py

@@ -66,6 +66,30 @@ example_soa_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA())
 default_questions = [example_axfr_question]
 default_answers = [soa_rrset]
 
+def check_diffs(assert_fn, expected, actual):
+    '''A helper function checking the differences made in the XFR session.
+
+    This is expected called from some subclass of unittest.TestCase and
+    assert_fn is generally expected to be 'self.assertEqual' of that class.
+
+    '''
+    assert_fn(len(expected), len(actual))
+    for (diffs_exp, diffs_actual) in zip(expected, actual):
+        assert_fn(len(diffs_exp), len(diffs_actual))
+        for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
+            # operation should match
+            assert_fn(diff_exp[0], diff_actual[0])
+            # The diff as RRset should be equal (for simplicity we assume
+            # all RRsets contain exactly one RDATA)
+            assert_fn(diff_exp[1].get_name(), diff_actual[1].get_name())
+            assert_fn(diff_exp[1].get_type(), diff_actual[1].get_type())
+            assert_fn(diff_exp[1].get_class(), diff_actual[1].get_class())
+            assert_fn(diff_exp[1].get_rdata_count(),
+                      diff_actual[1].get_rdata_count())
+            assert_fn(1, diff_exp[1].get_rdata_count())
+            assert_fn(diff_exp[1].get_rdata()[0],
+                      diff_actual[1].get_rdata()[0])
+
 class XfrinTestException(Exception):
     pass
 
@@ -274,8 +298,12 @@ class TestXfrinState(unittest.TestCase):
                               RRTTL(3600))
         self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
                                       'ns.example.com'))
+        self.a_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(),
+                             RRTTL(3600))
+        self.a_rrset.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))
+
         self.conn._datasrc_client = MockDataSourceClient()
-        self.conn._diff = Diff(MockDataSourceClient(), TEST_ZONE_NAME)
+        self.conn._diff = Diff(self.conn._datasrc_client, TEST_ZONE_NAME)
 
 class TestXfrinStateBase(TestXfrinState):
     def setUp(self):
@@ -312,6 +340,7 @@ class TestXfrinFirstData(TestXfrinState):
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
         self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        self.conn._diff = None           # should be replaced in the AXFR case
 
     def test_handle_ixfr_begin_soa(self):
         self.conn._request_type = RRType.IXFR()
@@ -331,6 +360,8 @@ class TestXfrinFirstData(TestXfrinState):
         # the initial SOA.  Should switch to AXFR.
         self.assertFalse(self.state.handle_rr(self.conn, self.ns_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        # The Diff for AXFR should be created at this point
+        self.assertNotEqual(None, self.conn._diff)
 
     def test_handle_ixfr_to_axfr_by_different_soa(self):
         # An unusual case: Response contains two consecutive SOA but the
@@ -338,6 +369,7 @@ class TestXfrinFirstData(TestXfrinState):
         # the documentation for XfrinFirstData.handle_rr().
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
+        self.assertNotEqual(None, self.conn._diff)
 
     def test_finish_message(self):
         self.assertTrue(self.state.finish_message(self.conn))
@@ -475,12 +507,51 @@ class TestXfrinAXFR(TestXfrinState):
         self.state = XfrinAXFR()
 
     def test_handle_rr(self):
-        self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
-                          soa_rrset)
+        """
+        Test we can put data inside.
+        """
+        # Put some data inside
+        self.assertTrue(self.state.handle_rr(self.conn, self.a_rrset))
+        # This test uses internal Diff structure to check the behaviour of
+        # XfrinAXFR. Maybe there could be a cleaner way, but it would be more
+        # complicated.
+        self.assertEqual([('add', self.a_rrset)], self.conn._diff.get_buffer())
+        # This SOA terminates the transfer
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        # It should have changed the state
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        # At this point, the data haven't been committed yet
+        self.assertEqual([('add', self.a_rrset), ('add', soa_rrset)],
+                         self.conn._diff.get_buffer())
 
     def test_finish_message(self):
+        """
+        Check normal end of message.
+        """
+        # When a message ends, nothing happens usually
         self.assertTrue(self.state.finish_message(self.conn))
 
+class TestXfrinAXFREnd(TestXfrinState):
+    def setUp(self):
+        super().setUp()
+        self.state = XfrinAXFREnd()
+
+    def test_handle_rr(self):
+        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
+                          self.ns_rrset)
+
+    def test_finish_message(self):
+        self.conn._diff.add_data(self.a_rrset)
+        self.conn._diff.add_data(soa_rrset)
+        self.assertFalse(self.state.finish_message(self.conn))
+
+        # The data should have been committed
+        self.assertEqual([], self.conn._diff.get_buffer())
+        check_diffs(self.assertEqual, [[('add', self.a_rrset),
+                                        ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+        self.assertRaises(ValueError, self.conn._diff.commit)
+
 class TestXfrinConnection(unittest.TestCase):
     '''Convenient parent class for XFR-protocol tests.
 
@@ -569,30 +640,6 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
 
-    def check_diffs(self, expected, actual):
-        '''A helper method checking the differences made in the IXFR session.
-
-        '''
-        self.assertEqual(len(expected), len(actual))
-        for (diffs_exp, diffs_actual) in zip(expected, actual):
-            self.assertEqual(len(diffs_exp), len(diffs_actual))
-            for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
-                # operation should match
-                self.assertEqual(diff_exp[0], diff_actual[0])
-                # The diff as RRset should be equal (for simplicity we assume
-                # all RRsets contain exactly one RDATA)
-                self.assertEqual(diff_exp[1].get_name(),
-                                 diff_actual[1].get_name())
-                self.assertEqual(diff_exp[1].get_type(),
-                                 diff_actual[1].get_type())
-                self.assertEqual(diff_exp[1].get_class(),
-                                 diff_actual[1].get_class())
-                self.assertEqual(diff_exp[1].get_rdata_count(),
-                                 diff_actual[1].get_rdata_count())
-                self.assertEqual(1, diff_exp[1].get_rdata_count())
-                self.assertEqual(diff_exp[1].get_rdata()[0],
-                                 diff_actual[1].get_rdata()[0])
-
     def _create_a(self, address):
         rrset = RRset(Name('a.example.com'), TEST_RRCLASS, RRType.A(),
                       RRTTL(3600))
@@ -606,6 +653,11 @@ class TestXfrinConnection(unittest.TestCase):
         rrset.add_rdata(Rdata(RRType.SOA(), TEST_RRCLASS, rdata_str))
         return rrset
 
+    def _create_ns(self, nsname='ns.'+TEST_ZONE_NAME_STR):
+        rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.NS(), RRTTL(3600))
+        rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS, nsname))
+        return rrset
+
 class TestAXFR(TestXfrinConnection):
     def setUp(self):
         super().setUp()
@@ -1081,8 +1133,9 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('delete', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_multi_sequences(self):
         '''Similar to the previous case, but with multiple diff seqs.
@@ -1107,19 +1160,20 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
-        self.check_diffs([[('delete', begin_soa_rrset),
-                           ('delete', self._create_a('192.0.2.1')),
-                           ('add', self._create_soa('1231')),
-                           ('add', self._create_a('192.0.2.2'))],
-                          [('delete', self._create_soa('1231')),
-                           ('delete', self._create_a('192.0.2.3')),
-                           ('add', self._create_soa('1232')),
-                           ('add', self._create_a('192.0.2.4'))],
-                          [('delete', self._create_soa('1232')),
-                           ('delete', self._create_a('192.0.2.5')),
-                           ('add', soa_rrset),
-                           ('add', self._create_a('192.0.2.6'))]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('delete', begin_soa_rrset),
+                      ('delete', self._create_a('192.0.2.1')),
+                      ('add', self._create_soa('1231')),
+                      ('add', self._create_a('192.0.2.2'))],
+                     [('delete', self._create_soa('1231')),
+                      ('delete', self._create_a('192.0.2.3')),
+                      ('add', self._create_soa('1232')),
+                      ('add', self._create_a('192.0.2.4'))],
+                     [('delete', self._create_soa('1232')),
+                      ('delete', self._create_a('192.0.2.5')),
+                      ('add', soa_rrset),
+                      ('add', self._create_a('192.0.2.6'))]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_multi_messages(self):
         '''Similar to the first case, but RRs span over multiple messages.
@@ -1133,8 +1187,9 @@ class TestIXFRResponse(TestXfrinConnection):
             answers=[soa_rrset])
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('delete', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_broken(self):
         '''Test with a broken response.
@@ -1148,7 +1203,8 @@ class TestIXFRResponse(TestXfrinConnection):
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
         # no diffs should have been committed
-        self.check_diffs([], self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
 
     def test_ixfr_response_extra(self):
         '''Test with an extra RR after the end of IXFR diff sequences.
@@ -1163,8 +1219,64 @@ class TestIXFRResponse(TestXfrinConnection):
                      self._create_a('192.0.2.1')])
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
-        self.check_diffs([[('delete', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+
+    def test_ixfr_to_axfr_response(self):
+        '''AXFR-style IXFR response.
+
+        It simply updates the zone's SOA one time.
+
+        '''
+        ns_rr = self._create_ns()
+        a_rr = self._create_a('192.0.2.1')
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[soa_rrset, ns_rr, a_rr, soa_rrset])
+        self.conn._handle_xfrin_responses()
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        self.assertEqual([], self.conn._datasrc_client.diffs)
+        # The SOA should be added exactly once, and in our implementation
+        # it should be added at the end of the sequence.
+        check_diffs(self.assertEqual,
+                    [[('add', ns_rr), ('add', a_rr), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+
+    def test_ixfr_to_axfr_response_mismatch_soa(self):
+        '''AXFR-style IXFR response, but the two SOA are not the same.
+
+        In the current implementation, we accept it and use the second SOA.
+
+        '''
+        ns_rr = self._create_ns()
+        a_rr = self._create_a('192.0.2.1')
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[soa_rrset, ns_rr, a_rr, begin_soa_rrset])
+        self.conn._handle_xfrin_responses()
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        self.assertEqual([], self.conn._datasrc_client.diffs)
+        check_diffs(self.assertEqual,
+                    [[('add', ns_rr), ('add', a_rr),
+                      ('add', begin_soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
+
+    def test_ixfr_to_axfr_response_extra(self):
+        '''Test with an extra RR after the end of AXFR-style IXFR session.
+
+        The session should be rejected, and nothing should be committed.
+
+        '''
+        ns_rr = self._create_ns()
+        a_rr = self._create_a('192.0.2.1')
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[soa_rrset, ns_rr, a_rr, soa_rrset, a_rr])
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
+        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
+        self.assertEqual([], self.conn._datasrc_client.committed_diffs)
 
 class TestIXFRSession(TestXfrinConnection):
     '''Tests for a full IXFR session (query and response).
@@ -1187,8 +1299,9 @@ class TestIXFRSession(TestXfrinConnection):
 
         # Check some details of the IXFR protocol processing
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
-        self.check_diffs([[('delete', begin_soa_rrset), ('add', soa_rrset)]],
-                         self.conn._datasrc_client.committed_diffs)
+        check_diffs(self.assertEqual,
+                    [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
+                    self.conn._datasrc_client.committed_diffs)
 
         # Check if the query was IXFR.
         qdata = self.conn.query_data[2:]
@@ -1249,6 +1362,12 @@ class TestIXFRSessionWithSQLite3(TestXfrinConnection):
         self.assertEqual(1, soa.get_rdata_count())
         return get_soa_serial(soa.get_rdata()[0])
 
+    def record_exist(self, name, type):
+        result, finder = self.conn._datasrc_client.find_zone(TEST_ZONE_NAME)
+        self.assertEqual(DataSourceClient.SUCCESS, result)
+        result, soa = finder.find(name, type, None, ZoneFinder.FIND_DEFAULT)
+        return result == ZoneFinder.SUCCESS
+
     def test_do_xfrin_sqlite3(self):
         def create_ixfr_response():
             self.conn.reply_data = self.conn.create_response_data(
@@ -1280,6 +1399,47 @@ class TestIXFRSessionWithSQLite3(TestXfrinConnection):
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(1230, self.get_zone_serial())
 
+    def test_do_xfrin_axfr_sqlite3(self):
+        '''AXFR-style IXFR.
+
+        '''
+        def create_ixfr_response():
+            self.conn.reply_data = self.conn.create_response_data(
+                questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS,
+                                    RRType.IXFR())],
+                answers=[soa_rrset, self._create_ns(), soa_rrset])
+        self.conn.response_generator = create_ixfr_response
+
+        # Confirm xfrin succeeds and SOA is updated, A RR is deleted.
+        self.assertEqual(1230, self.get_zone_serial())
+        self.assertTrue(self.record_exist(Name('dns01.example.com'),
+                                          RRType.A()))
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
+        self.assertEqual(1234, self.get_zone_serial())
+        self.assertFalse(self.record_exist(Name('dns01.example.com'),
+                                           RRType.A()))
+
+    def test_do_xfrin_axfr_sqlite3_fail(self):
+        '''Similar to the previous test, but xfrin fails due to error.
+
+        Check the DB is not changed.
+
+        '''
+        def create_ixfr_response():
+            self.conn.reply_data = self.conn.create_response_data(
+                questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS,
+                                    RRType.IXFR())],
+                answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
+        self.conn.response_generator = create_ixfr_response
+
+        self.assertEqual(1230, self.get_zone_serial())
+        self.assertTrue(self.record_exist(Name('dns01.example.com'),
+                                          RRType.A()))
+        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
+        self.assertEqual(1230, self.get_zone_serial())
+        self.assertTrue(self.record_exist(Name('dns01.example.com'),
+                                          RRType.A()))
+
 class TestXfrinRecorder(unittest.TestCase):
     def setUp(self):
         self.recorder = XfrinRecorder()

+ 40 - 7
src/bin/xfrin/xfrin.py.in

@@ -142,17 +142,20 @@ class XfrinState:
     machine because they cannot be distinguished immediately - an AXFR
     response to an IXFR request can only be detected when the first two (2)
     response RRs have already been received.
-    NOTE: the AXFR part of the state machine is incomplete at this point.
 
     The following diagram summarizes the state transition.  After sending
     the query, xfrin starts the process with the InitialSOA state (all
     IXFR/AXFR response begins with an SOA).  When it reaches IXFREnd
-    (or AXFREnd, which is not yet implemented and not shown here), the
-    process successfully completes.
+    or AXFREnd, the process successfully completes.
 
 
-            (recv SOA)       (AXFR-style IXFR)
-    InitialSOA------->FirstData------------->AXFR
+            (recv SOA)       (AXFR-style IXFR)   (SOA, add)
+    InitialSOA------->FirstData------------->AXFR--------->AXFREnd
+                          |                  |  ^         (post xfr
+                          |                  |  |        checks, then
+                          |                  +--+        commit)
+                          |            (non SOA, add)
+                          |
                           |                     (non SOA, delete)
                (pure IXFR,|                           +-------+
             keep handling)|             (Delete SOA)  V       |
@@ -318,6 +321,9 @@ class XfrinFirstData(XfrinState):
         else:
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_NONINCREMENTAL_RESP,
                  conn.zone_str())
+            # We are now goint to add RRs to the new zone.  We need create
+            # a Diff object.  It will be used throughtout the XFR session.
+            conn._diff = Diff(conn._datasrc_client, conn._zone_name, True)
             self.set_xfrstate(conn, XfrinAXFR())
         return False
 
@@ -392,8 +398,35 @@ class XfrinIXFREnd(XfrinState):
 
 class XfrinAXFR(XfrinState):
     def handle_rr(self, conn, rr):
-        raise XfrinException('Falling back from IXFR to AXFR not ' +
-                             'supported yet')
+        """
+        Handle the RR by putting it into the zone.
+        """
+        conn._diff.add_data(rr)
+        if rr.get_type() == RRType.SOA():
+            # SOA means end.  Don't commit it yet - we need to perform
+            # post-transfer checks
+            self.set_xfrstate(conn, XfrinAXFREnd())
+        # Yes, we've eaten this RR.
+        return True
+
+class XfrinAXFREnd(XfrinState):
+    def handle_rr(self, conn, rr):
+        raise XfrinProtocolError('Extra data after the end of AXFR: ' +
+                                 rr.to_text())
+
+    def finish_message(self, conn):
+        """
+        Final processing after processing an entire AXFR session.
+
+        In this process all the AXFR changes are committed to the
+        data source.
+
+        There might be more actions here, but for now we simply return False,
+        indicating there will be no more message to receive.
+
+        """
+        conn._diff.commit()
+        return False
 
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''

+ 4 - 2
src/lib/python/isc/xfrin/diff.py

@@ -59,7 +59,7 @@ class Diff:
     the changes to underlying data source right away, but keeps them for
     a while.
     """
-    def __init__(self, ds_client, zone):
+    def __init__(self, ds_client, zone, replace=False):
         """
         Initializes the diff to a ready state. It checks the zone exists
         in the datasource and if not, NoSuchZone is raised. This also creates
@@ -67,11 +67,13 @@ class Diff:
 
         The ds_client is the datasource client containing the zone. Zone is
         isc.dns.Name object representing the name of the zone (its apex).
+        If replace is true, the content of the whole zone is wiped out before
+        applying the diff.
 
         You can also expect isc.datasrc.Error or isc.datasrc.NotImplemented
         exceptions.
         """
-        self.__updater = ds_client.get_updater(zone, False)
+        self.__updater = ds_client.get_updater(zone, replace)
         if self.__updater is None:
             # The no such zone case
             raise NoSuchZone("Zone " + str(zone) +

+ 10 - 1
src/lib/python/isc/xfrin/tests/diff_tests.py

@@ -46,6 +46,7 @@ class DiffTest(unittest.TestCase):
         self.__commit_called = False
         self.__broken_called = False
         self.__warn_called = False
+        self.__should_replace = False
         # Some common values
         self.__rrclass = RRClass.IN()
         self.__type = RRType.A()
@@ -135,7 +136,7 @@ class DiffTest(unittest.TestCase):
         it returns self.
         """
         # The diff should not delete the old data.
-        self.assertFalse(replace)
+        self.assertEqual(self.__should_replace, replace)
         self.__updater_requested = True
         # Pretend this zone doesn't exist
         if zone_name == Name('none.example.org.'):
@@ -432,6 +433,14 @@ class DiffTest(unittest.TestCase):
         finally:
             isc.xfrin.diff.logger = orig_logger
 
+    def test_relpace(self):
+        """
+        Test that when we want to replace the whole zone, it is propagated.
+        """
+        self.__should_replace = True
+        diff = Diff(self, "example.org.", True)
+        self.assertTrue(self.__updater_requested)
+
 if __name__ == "__main__":
     isc.log.init("bind10")
     unittest.main()