Browse Source

[1261] updated do_xfrin() for the IXFR case

JINMEI Tatuya 13 years ago
parent
commit
461acc1e4b
2 changed files with 100 additions and 66 deletions
  1. 93 65
      src/bin/xfrin/tests/xfrin_test.py
  2. 7 1
      src/bin/xfrin/xfrin.py.in

+ 93 - 65
src/bin/xfrin/tests/xfrin_test.py

@@ -84,6 +84,43 @@ class MockDataSourceClient():
         self.committed_diffs = []
         self.diffs = []
 
+    def find_zone(self, zone_name):
+        '''Mock version of find_zone().
+
+        It returns itself (subsequently acting as a mock ZoneFinder) for
+        some test zone names.  For some others it returns either NOTFOUND
+        or PARTIALMATCH.
+
+        '''
+        if zone_name == TEST_ZONE_NAME or \
+                zone_name == Name('no-soa.example') or \
+                zone_name == Name('dup-soa.example'):
+            return (isc.datasrc.DataSourceClient.SUCCESS, self)
+        elif zone_name == Name('no-such-zone.example'):
+            return (DataSourceClient.NOTFOUND, None)
+        elif zone_name == Name('partial-match-zone.example'):
+            return (DataSourceClient.PARTIALMATCH, self)
+        raise ValueError('Unexpected input to mock client: bug in test case?')
+
+    def find(self, name, rrtype, target, options):
+        '''Mock ZoneFinder.find().
+
+        It returns the predefined SOA RRset to queries for SOA of the common
+        test zone name.  It also emulates some unusual cases for special
+        zone names.
+
+        '''
+        if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
+            return (ZoneFinder.SUCCESS, begin_soa_rrset)
+        if name == Name('no-soa.example'):
+            return (ZoneFinder.NXDOMAIN, None)
+        if name == Name('dup-soa.example'):
+            dup_soa_rrset = RRset(name, TEST_RRCLASS, RRType.SOA(), RRTTL(0))
+            dup_soa_rrset.add_rdata(begin_soa_rdata)
+            dup_soa_rrset.add_rdata(soa_rdata)
+            return (ZoneFinder.SUCCESS, dup_soa_rrset)
+        raise ValueError('Unexpected input to mock finder: bug in test case?')
+
     def get_updater(self, zone_name, replace):
         return self
 
@@ -144,44 +181,7 @@ class MockXfrinConnection(XfrinConnection):
     # The following three implement a simplified mock of DataSourceClient
     # and ZoneFinder classes for testing purposes.
     def _get_datasrc_client(self, rrclass):
-        return self
-
-    def find_zone(self, zone_name):
-        '''Mock DataSourceClient.find_zone().
-
-        It returns itself (subsequently acting as a mock ZoneFinder) for
-        some test zone names.  For some others it returns either NOTFOUND
-        or PARTIALMATCH.
-
-        '''
-        if zone_name == TEST_ZONE_NAME or \
-                zone_name == Name('no-soa.example') or \
-                zone_name == Name('dup-soa.example'):
-            return (isc.datasrc.DataSourceClient.SUCCESS, self)
-        elif zone_name == Name('no-such-zone.example'):
-            return (DataSourceClient.NOTFOUND, None)
-        elif zone_name == Name('partial-match-zone.example'):
-            return (DataSourceClient.PARTIALMATCH, self)
-        raise ValueError('Unexpected input to mock client: bug in test case?')
-
-    def find(self, name, rrtype, target, options):
-        '''Mock ZoneFinder.find().
-
-        It returns the predefined SOA RRset to queries for SOA of the common
-        test zone name.  It also emulates some unusual cases for special
-        zone names.
-
-        '''
-        if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
-            return (ZoneFinder.SUCCESS, begin_soa_rrset)
-        if name == Name('no-soa.example'):
-            return (ZoneFinder.NXDOMAIN, None)
-        if name == Name('dup-soa.example'):
-            dup_soa_rrset = RRset(name, TEST_RRCLASS, RRType.SOA(), RRTTL(0))
-            dup_soa_rrset.add_rdata(begin_soa_rdata)
-            dup_soa_rrset.add_rdata(soa_rdata)
-            return (ZoneFinder.SUCCESS, dup_soa_rrset)
-        raise ValueError('Unexpected input to mock finder: bug in test case?')
+        return MockDataSourceClient()
 
     def _asyncore_loop(self):
         if self.force_close:
@@ -196,7 +196,8 @@ class MockXfrinConnection(XfrinConnection):
         data = self.reply_data[:size]
         self.reply_data = self.reply_data[size:]
         if len(data) < size:
-            raise XfrinTestException('cannot get reply data')
+            raise XfrinTestException('cannot get reply data (' + str(size) +
+                                     ' bytes)')
         return data
 
     def send(self, data):
@@ -507,7 +508,7 @@ class TestXfrinConnection(unittest.TestCase):
 
     def _create_normal_response_data(self):
         # This helper method creates a simple sequence of DNS messages that
-        # forms a valid XFR transaction.  It consists of two messages, each
+        # forms a valid AXFR transaction.  It consists of two messages, each
         # containing just a single SOA RR.
         tsig_1st = self.axfr_response_params['tsig_1st']
         tsig_2nd = self.axfr_response_params['tsig_2nd']
@@ -550,6 +551,30 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = struct.pack('H', socket.htons(len(bogus_data)))
         self.conn.reply_data += bogus_data
 
+    def check_diffs(self, expected, actual):
+        '''A helper method checking the differences made in the IXFR session.
+
+        '''
+        self.assertEqual(len(expected), len(actual))
+        for (diffs_exp, diffs_actual) in zip(expected, actual):
+            self.assertEqual(len(diffs_exp), len(diffs_actual))
+            for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
+                # operation should match
+                self.assertEqual(diff_exp[0], diff_actual[0])
+                # The diff as RRset should be equal (for simplicity we assume
+                # all RRsets contain exactly one RDATA)
+                self.assertEqual(diff_exp[1].get_name(),
+                                 diff_actual[1].get_name())
+                self.assertEqual(diff_exp[1].get_type(),
+                                 diff_actual[1].get_type())
+                self.assertEqual(diff_exp[1].get_class(),
+                                 diff_actual[1].get_class())
+                self.assertEqual(diff_exp[1].get_rdata_count(),
+                                 diff_actual[1].get_rdata_count())
+                self.assertEqual(1, diff_exp[1].get_rdata_count())
+                self.assertEqual(diff_exp[1].get_rdata()[0],
+                                 diff_actual[1].get_rdata()[0])
+
 class TestAXFR(TestXfrinConnection):
     def setUp(self):
         super().setUp()
@@ -1006,7 +1031,7 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL)
 
-class TestIXFR(TestXfrinConnection):
+class TestIXFRResponse(TestXfrinConnection):
     def setUp(self):
         super().setUp()
         self.conn._query_id = self.conn.qid = 1035
@@ -1029,30 +1054,6 @@ class TestIXFR(TestXfrinConnection):
         rrset.add_rdata(Rdata(RRType.SOA(), TEST_RRCLASS, rdata_str))
         return rrset
 
-    def check_diffs(self, expected, actual):
-        '''A helper method checking the differences made in the IXFR session.
-
-        '''
-        self.assertEqual(len(expected), len(actual))
-        for (diffs_exp, diffs_actual) in zip(expected, actual):
-            self.assertEqual(len(diffs_exp), len(diffs_actual))
-            for (diff_exp, diff_actual) in zip(diffs_exp, diffs_actual):
-                # operation should match
-                self.assertEqual(diff_exp[0], diff_actual[0])
-                # The diff as RRset should be equal (for simplicity we assume
-                # all RRsets contain exactly one RDATA)
-                self.assertEqual(diff_exp[1].get_name(),
-                                 diff_actual[1].get_name())
-                self.assertEqual(diff_exp[1].get_type(),
-                                 diff_actual[1].get_type())
-                self.assertEqual(diff_exp[1].get_class(),
-                                 diff_actual[1].get_class())
-                self.assertEqual(diff_exp[1].get_rdata_count(),
-                                 diff_actual[1].get_rdata_count())
-                self.assertEqual(1, diff_exp[1].get_rdata_count())
-                self.assertEqual(diff_exp[1].get_rdata()[0],
-                                 diff_actual[1].get_rdata()[0])
-
     def test_ixfr_response(self):
         '''A simplest form of IXFR response.
 
@@ -1118,6 +1119,33 @@ class TestIXFR(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
 
+class TestIXFRSession(TestXfrinConnection):
+    def setUp(self):
+        super().setUp()
+
+    def test_do_xfrin(self):
+        def create_ixfr_response():
+            self.conn.reply_data = self.conn.create_response_data(
+                questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS,
+                                    RRType.IXFR())],
+                answers=[soa_rrset, begin_soa_rrset, soa_rrset, soa_rrset])
+        self.conn.response_generator = create_ixfr_response
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, True))
+
+        # Check some details of the IXFR protocol processing
+        self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
+        self.check_diffs([[('remove', begin_soa_rrset), ('add', soa_rrset)]],
+                         self.conn._datasrc_client.committed_diffs)
+
+        # Check if the query was IXFR.  We only check for the RR type.  Other
+        # details are tested in test_create_query().
+        qdata = self.conn.query_data[2:]
+        qmsg = Message(Message.PARSE)
+        qmsg.from_wire(qdata, len(qdata))
+        self.assertEqual(1, qmsg.get_rr_count(Message.SECTION_QUESTION))
+        self.assertEqual(TEST_ZONE_NAME, qmsg.get_question()[0].get_name())
+        self.assertEqual(RRType.IXFR(), qmsg.get_question()[0].get_type())
+
 class TestXfrinRecorder(unittest.TestCase):
     def setUp(self):
         self.recorder = XfrinRecorder()

+ 7 - 1
src/bin/xfrin/xfrin.py.in

@@ -473,7 +473,13 @@ class XfrinConnection(asyncore.dispatcher):
                 ret =  self._check_soa_serial()
 
             if ret == XFRIN_OK:
-                if not ixfr_first:
+                if ixfr_first:
+                    # TODO: log it
+                    self._request_type = RRType.IXFR()
+                    self._send_query(RRType.IXFR())
+                    self.__state = XfrinInitialSOA()
+                    self._handle_xfrin_responses()
+                else:
                     logger.info(XFRIN_AXFR_TRANSFER_STARTED, self.zone_str())
                     self._send_query(RRType.AXFR())
                     isc.datasrc.sqlite3_ds.load(self._db_file,