Browse Source

[master] [1299] cleanup, and added one more test (the case where the local SOA is unknown)

JINMEI Tatuya 13 years ago
parent
commit
50fdb098fc
2 changed files with 29 additions and 21 deletions
  1. 11 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 18 21
      src/bin/xfrin/xfrin.py.in

+ 11 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -980,6 +980,17 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
         self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
         self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
 
 
+    def test_soacheck_newzone(self):
+        # Primary's SOA is 'old', but this secondary doesn't know anything
+        # about the zone yet, so it should accept it.
+        def response_generator():
+            # _request_serial is set in _check_soa_serial().  Reset it here.
+            self.conn._request_serial = None
+            self._create_soa_response_data()
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = response_generator
+        self.assertEqual(XFRIN_OK, self.conn._check_soa_serial())
+
     def test_soacheck_question_empty(self):
     def test_soacheck_question_empty(self):
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
         self.soa_response_params['questions'] = []
         self.soa_response_params['questions'] = []

+ 18 - 21
src/bin/xfrin/xfrin.py.in

@@ -764,10 +764,12 @@ class XfrinConnection(asyncore.dispatcher):
 
 
 
 
     def _check_soa_serial(self):
     def _check_soa_serial(self):
-        ''' Compare the soa serial, if soa serial in master is less than
-        the soa serial in local, Finish xfrin.
-        False: soa serial in master is less or equal to the local one.
-        True: soa serial in master is bigger
+        '''Send SOA query and compare the local and remote serials.
+
+        If we know our local serial and the remote serial isn't newer
+        than ours, we abort the session with XfrinZoneUptodate.
+        On success it returns XFRIN_OK for testing.  The caller won't use it.
+
         '''
         '''
 
 
         self._send_query(RRType.SOA())
         self._send_query(RRType.SOA())
@@ -781,6 +783,7 @@ class XfrinConnection(asyncore.dispatcher):
         # from the answer section
         # from the answer section
         soa = self.__parse_soa_response(msg, soa_response)
         soa = self.__parse_soa_response(msg, soa_response)
 
 
+        # Compare the two serials.  If ours is 'new', abort with ZoneUptodate.
         primary_serial = get_soa_serial(soa.get_rdata()[0])
         primary_serial = get_soa_serial(soa.get_rdata()[0])
         if self._request_serial is not None and \
         if self._request_serial is not None and \
                 self._request_serial >= primary_serial:
                 self._request_serial >= primary_serial:
@@ -791,9 +794,6 @@ class XfrinConnection(asyncore.dispatcher):
                             self._request_serial)
                             self._request_serial)
             raise XfrinZoneUptodate
             raise XfrinZoneUptodate
 
 
-        # TODO, need select soa record from data source then compare the two
-        # serial, current just return OK, since this function hasn't been used
-        # now.
         return XFRIN_OK
         return XFRIN_OK
 
 
     def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
     def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
@@ -804,31 +804,28 @@ class XfrinConnection(asyncore.dispatcher):
             self._request_type = request_type
             self._request_type = request_type
             # Right now RRType.[IA]XFR().to_text() is 'TYPExxx', so we need
             # Right now RRType.[IA]XFR().to_text() is 'TYPExxx', so we need
             # to hardcode here.
             # to hardcode here.
-            request_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
+            req_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
             if check_soa:
             if check_soa:
-                ret =  self._check_soa_serial()
+                self._check_soa_serial()
 
 
-            if ret == XFRIN_OK:
-                logger.info(XFRIN_XFR_TRANSFER_STARTED, request_str,
-                            self.zone_str())
-                self._send_query(self._request_type)
-                self.__state = XfrinInitialSOA()
-                self._handle_xfrin_responses()
-                logger.info(XFRIN_XFR_TRANSFER_SUCCESS, request_str,
-                            self.zone_str())
+            logger.info(XFRIN_XFR_TRANSFER_STARTED, req_str, self.zone_str())
+            self._send_query(self._request_type)
+            self.__state = XfrinInitialSOA()
+            self._handle_xfrin_responses()
+            logger.info(XFRIN_XFR_TRANSFER_SUCCESS, req_str, self.zone_str())
 
 
         except XfrinZoneUptodate:
         except XfrinZoneUptodate:
             # Eventually we'll probably have to treat this case as a trigger
             # Eventually we'll probably have to treat this case as a trigger
-            # of trying another primary server, etc, but for now We treat it
+            # of trying another primary server, etc, but for now we treat it
             # as "success".
             # as "success".
             pass
             pass
         except XfrinProtocolError as e:
         except XfrinProtocolError as e:
-            logger.info(XFRIN_XFR_TRANSFER_PROTOCOL_ERROR, request_str,
+            logger.info(XFRIN_XFR_TRANSFER_PROTOCOL_ERROR, req_str,
                         self.zone_str(),
                         self.zone_str(),
                         format_addrinfo(self._master_addrinfo), str(e))
                         format_addrinfo(self._master_addrinfo), str(e))
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
         except XfrinException as e:
         except XfrinException as e:
-            logger.error(XFRIN_XFR_TRANSFER_FAILURE, request_str,
+            logger.error(XFRIN_XFR_TRANSFER_FAILURE, req_str,
                          self.zone_str(),
                          self.zone_str(),
                          format_addrinfo(self._master_addrinfo), str(e))
                          format_addrinfo(self._master_addrinfo), str(e))
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
@@ -841,7 +838,7 @@ class XfrinConnection(asyncore.dispatcher):
             # catch it here, but until then we need broadest coverage so that
             # catch it here, but until then we need broadest coverage so that
             # we won't miss anything.
             # we won't miss anything.
 
 
-            logger.error(XFRIN_XFR_OTHER_FAILURE, request_str,
+            logger.error(XFRIN_XFR_OTHER_FAILURE, req_str,
                          self.zone_str(), str(e))
                          self.zone_str(), str(e))
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
         finally:
         finally: