Parcourir la source

[1261] refactoring: pass request type (IXFR or AXFR) from the command handler
to do_xfrin().

JINMEI Tatuya il y a 13 ans
Parent
commit
bff7aa9429
2 fichiers modifiés avec 23 ajouts et 17 suppressions
  1. 12 7
      src/bin/xfrin/tests/xfrin_test.py
  2. 11 10
      src/bin/xfrin/xfrin.py.in

+ 12 - 7
src/bin/xfrin/tests/xfrin_test.py

@@ -160,14 +160,15 @@ class MockXfrin(Xfrin):
             MockXfrin.check_command_hook()
 
     def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
-                    tsig_key, check_soa=True):
+                    tsig_key, request_type, check_soa=True):
         # store some of the arguments for verification, then call this
         # method in the superclass
         self.xfrin_started_master_addr = master_addrinfo[2][0]
         self.xfrin_started_master_port = master_addrinfo[2][1]
+        self.xfrin_started_request_type = request_type
         return Xfrin.xfrin_start(self, zone_name, rrclass, None,
                                  master_addrinfo, tsig_key,
-                                 check_soa)
+                                 request_type, check_soa)
 
 class MockXfrinConnection(XfrinConnection):
     def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
@@ -1164,7 +1165,7 @@ class TestIXFRSession(TestXfrinConnection):
                                     RRType.IXFR())],
                 answers=[soa_rrset, begin_soa_rrset, soa_rrset, soa_rrset])
         self.conn.response_generator = create_ixfr_response
-        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, True))
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
 
         # Check some details of the IXFR protocol processing
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
@@ -1190,14 +1191,14 @@ class TestIXFRSession(TestXfrinConnection):
                 answers=[soa_rrset, begin_soa_rrset, soa_rrset,
                          self._create_soa('1235')])
         self.conn.response_generator = create_ixfr_response
-        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, True))
+        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
     def test_do_xfrin_fail(self):
         '''IXFR fails due to a bogus DNS message.
 
         '''
         self._create_broken_response_data()
-        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, True))
+        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
 class TestIXFRSessionWithSQLite3(TestXfrinConnection):
     '''Tests for IXFR sessions using an SQLite3 DB.
@@ -1240,7 +1241,7 @@ class TestIXFRSessionWithSQLite3(TestXfrinConnection):
 
         # Confirm xfrin succeeds and SOA is updated
         self.assertEqual(1230, self.get_zone_serial())
-        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, True))
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(1234, self.get_zone_serial())
 
     def test_do_xfrin_sqlite3_fail(self):
@@ -1258,7 +1259,7 @@ class TestIXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
 
         self.assertEqual(1230, self.get_zone_serial())
-        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, True))
+        self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(1230, self.get_zone_serial())
 
 class TestXfrinRecorder(unittest.TestCase):
@@ -1386,6 +1387,8 @@ class TestXfrin(unittest.TestCase):
                                                   self.args)['result'][0], 0)
         self.assertEqual(self.args['master'], self.xfr.xfrin_started_master_addr)
         self.assertEqual(int(self.args['port']), self.xfr.xfrin_started_master_port)
+        # By default we use AXFR (for now)
+        self.assertEqual(RRType.AXFR(), self.xfr.xfrin_started_request_type)
 
     def test_command_handler_retransfer_short_command1(self):
         # try it when only specifying the zone name (of unknown zone)
@@ -1498,6 +1501,8 @@ class TestXfrin(unittest.TestCase):
                          self.xfr.xfrin_started_master_addr)
         self.assertEqual(int(TEST_MASTER_PORT),
                          self.xfr.xfrin_started_master_port)
+        # By default we use AXFR (for now)
+        self.assertEqual(RRType.AXFR(), self.xfr.xfrin_started_request_type)
 
     def test_command_handler_notify(self):
         # at this level, refresh is no different than retransfer.

+ 11 - 10
src/bin/xfrin/xfrin.py.in

@@ -481,25 +481,26 @@ class XfrinConnection(asyncore.dispatcher):
         # now.
         return XFRIN_OK
 
-    def do_xfrin(self, check_soa, ixfr_first=False):
+    def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
         '''Do an xfr session by sending xfr request and parsing responses.'''
 
         try:
             ret = XFRIN_OK
+            self._request_type = request_type
             if check_soa:
                 logstr = 'SOA check for \'%s\' ' % self.zone_str()
                 ret =  self._check_soa_serial()
 
             if ret == XFRIN_OK:
-                if ixfr_first:
+                if self._request_type == RRType.IXFR():
                     # TODO: log it
                     self._request_type = RRType.IXFR()
-                    self._send_query(RRType.IXFR())
+                    self._send_query(self._request_type)
                     self.__state = XfrinInitialSOA()
                     self._handle_xfrin_responses()
                 else:
                     logger.info(XFRIN_AXFR_TRANSFER_STARTED, self.zone_str())
-                    self._send_query(RRType.AXFR())
+                    self._send_query(self._request_type)
                     isc.datasrc.sqlite3_ds.load(self._db_file,
                                                 self._zone_name.to_text(),
                                                 self._handle_axfrin_response)
@@ -660,7 +661,7 @@ class XfrinConnection(asyncore.dispatcher):
 
 def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
                   shutdown_event, master_addrinfo, check_soa, verbose,
-                  tsig_key):
+                  tsig_key, request_type):
     xfrin_recorder.increment(zone_name)
 
     # Create a data source client used in this XFR session.  Right now we
@@ -680,7 +681,7 @@ def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
                            tsig_key, verbose)
     ret = XFRIN_FAIL
     if conn.connect_to_master():
-        ret = conn.do_xfrin(check_soa)
+        ret = conn.do_xfrin(check_soa, request_type)
 
     # Publish the zone transfer result news, so zonemgr can reset the
     # zone timer, and xfrout can notify the zone's slaves if the result
@@ -924,7 +925,7 @@ class Xfrin:
                                            rrclass,
                                            self._get_db_file(),
                                            master_addr,
-                                           zone_info.tsig_key,
+                                           zone_info.tsig_key, RRType.AXFR(),
                                            True)
                     answer = create_answer(ret[0], ret[1])
 
@@ -944,7 +945,7 @@ class Xfrin:
                                        rrclass,
                                        db_file,
                                        master_addr,
-                                       tsig_key,
+                                       tsig_key, RRType.AXFR(),
                                        (False if command == 'retransfer' else True))
                 answer = create_answer(ret[0], ret[1])
 
@@ -1062,7 +1063,7 @@ class Xfrin:
             self._cc_check_command()
 
     def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
-                    tsig_key, check_soa=True):
+                    tsig_key, request_type, check_soa=True):
         if "pydnspp" not in sys.modules:
             return (1, "xfrin failed, can't load dns message python library: 'pydnspp'")
 
@@ -1082,7 +1083,7 @@ class Xfrin:
                                                 self._shutdown_event,
                                                 master_addrinfo, check_soa,
                                                 self._verbose,
-                                                tsig_key))
+                                                tsig_key, request_type))
 
         xfrin_thread.start()
         return (0, 'zone xfrin is started')