Parcourir la source

[1288] make sure transfers_counter is always reset whateven happens within
XfroutSession(). The original code was already buggy in this sense, but
with the newer data source API it will be a bit more likely to happen
due to the generality of the API, so it would make sense to fix it here.

JINMEI Tatuya il y a 13 ans
Parent
commit
ca42fb6438
2 fichiers modifiés avec 112 ajouts et 52 suppressions
  1. 75 37
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 37 15
      src/bin/xfrout/xfrout.py.in

+ 75 - 37
src/bin/xfrout/tests/xfrout_test.py.in

@@ -64,10 +64,40 @@ class MySocket():
     def clear_send(self):
         del self.sendqueue[:]
 
-# We subclass the Session class we're testing here, only
-# to override the handle() and _send_data() method
+class MockDataSrcClient:
+    def __init__(self, type, config):
+        pass
+
+    def get_iterator(self, zone_name):
+        if zone_name == Name('notauth.example.com'):
+            raise isc.datasrc.Error('no such zone')
+        self._zone_name = zone_name
+        return self
+
+    def get_soa(self):  # emulate ZoneIterator.get_soa()
+        if self._zone_name == Name('nosoa.example.com'):
+            return None
+        soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
+                          RRType.SOA(), RRTTL(3600))
+        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                                  'master.example.com. ' +
+                                  'admin.example.com. 1234 ' +
+                                  '3600 1800 2419200 7200'))
+        if self._zone_name == Name('multisoa.example.com'):
+            soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                                      'master.example.com. ' +
+                                      'admin.example.com. 1300 ' +
+                                      '3600 1800 2419200 7200'))
+            return soa_rrset
+        return soa_rrset
+
+# We subclass the Session class we're testing here, only overriding a few
+# methods
 class MyXfroutSession(XfroutSession):
-    def handle(self):
+    def _handle(self):
+        pass
+
+    def _close_socket(self):
         pass
 
     def _send_data(self, sock, data):
@@ -80,12 +110,14 @@ class MyXfroutSession(XfroutSession):
 class Dbserver:
     def __init__(self):
         self._shutdown_event = threading.Event()
+        self.transfer_counter = 0
     def get_db_file(self):
         return 'test.sqlite3'
     def increase_transfers_counter(self):
+        self.transfer_counter += 1
         return True
     def decrease_transfers_counter(self):
-        pass
+        self.transfer_counter -= 1
 
 class TestXfroutSession(unittest.TestCase):
     def getmsg(self):
@@ -139,6 +171,45 @@ class TestXfroutSession(unittest.TestCase):
                                        'admin.exAmple.com. ' +
                                        '1234 3600 1800 2419200 7200'))
 
+    def tearDown(self):
+        # transfer_counter must be always be reset no matter happens within
+        # the XfroutSession object.  We check the condition here.
+        self.assertEqual(0, self.xfrsess._server.transfer_counter)
+
+    def test_quota_error(self):
+        '''Emulating the server being too busy.
+
+        '''
+        self.xfrsess._request_data = self.mdata
+        self.xfrsess._server.increase_transfers_counter = lambda : False
+        XfroutSession._handle(self.xfrsess)
+        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.REFUSED())
+
+    def test_quota_ok(self):
+        '''The default case in terms of the xfrout quota.
+
+        '''
+        # set up a bogus request, which should result in FORMERR. (it only
+        # has to be something that is different from the previous case)
+        self.xfrsess._request_data = \
+            self.create_request_data(with_question=False)
+        # Replace the data source client to avoid datasrc related exceptions
+        self.xfrsess.ClientClass = MockDataSrcClient
+        XfroutSession._handle(self.xfrsess)
+        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.FORMERR())
+
+    def test_exception_from_session(self):
+        '''Test the case where the main processing raises an exception.
+
+        We just check it doesn't any unexpected disruption and (in teraDown)
+        transfer_counter is correctly reset to 0.
+
+        '''
+        def dns_xfrout_start(fd, msg, quota):
+            raise ValueError('fake exception')
+        self.xfrsess.dns_xfrout_start = dns_xfrout_start
+        XfroutSession._handle(self.xfrsess)
+
     def test_parse_query_message(self):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -520,32 +591,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
 
     def test_check_xfrout_available(self):
-        class MockDataSrcClient:
-            def __init__(self, type, config): pass
-
-            def get_iterator(self, zone_name):
-                if zone_name == Name('notauth.example.com'):
-                    raise isc.datasrc.Error('no such zone')
-                self._zone_name = zone_name
-                return self
-
-            def get_soa(self):  # emulate ZoneIterator.get_soa()
-                if self._zone_name == Name('nosoa.example.com'):
-                    return None
-                soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
-                                  RRType.SOA(), RRTTL(3600))
-                soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                          'master.example.com. ' +
-                                          'admin.example.com. 1234 ' +
-                                          '3600 1800 2419200 7200'))
-                if self._zone_name == Name('multisoa.example.com'):
-                    soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                                   'master.example.com. ' +
-                                                   'admin.example.com. 1300 ' +
-                                                   '3600 1800 2419200 7200'))
-                    return soa_rrset
-                return soa_rrset
-
         self.xfrsess.ClientClass = MockDataSrcClient
         self.assertEqual(self.xfrsess._check_xfrout_available(
                 Name('notauth.example.com')), Rcode.NOTAUTH())
@@ -554,13 +599,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(self.xfrsess._check_xfrout_available(
                 Name('multisoa.example.com')), Rcode.SERVFAIL())
 
-        self.xfrsess._server.increase_transfers_counter = lambda : False
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.REFUSED())
-        self.xfrsess._server.increase_transfers_counter = lambda : True
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.NOERROR())
-
     def test_dns_xfrout_start_formerror(self):
         # formerror
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")

+ 37 - 15
src/bin/xfrout/xfrout.py.in

@@ -128,21 +128,46 @@ class XfroutSession():
         self._zone_config = zone_config
         self.ClientClass = client_class # parameterize this for testing
         self._soa = None # will be set in _check_xfrout_available or in tests
-        self.handle()
+        self._handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
         return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
                            tsig_key_ring)
 
-    def handle(self):
-        ''' Handle a xfrout query, send xfrout response '''
+    def _handle(self):
+        ''' Handle a xfrout query, send xfrout response(s).
+
+        This is separated from the constructor so that we can override
+        it from tests.
+
+        '''
+        # Check the xfrout quota.  We do both increase/decrease in this
+        # method so it's clear we always release it once acuired.
+        quota_ok = self._server.increase_transfers_counter()
+        ex = None
         try:
-            self.dns_xfrout_start(self._sock_fd, self._request_data)
-            #TODO, avoid catching all exceptions
+            self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
         except Exception as e:
-            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
-            pass
+            # To avoid resource leak we need catch all possible exceptions
+            # We log it later to exclude the case where even logger raises
+            # an exception.
+            ex = e
+
+        # Release any critical resources
+        if quota_ok:
+            self._server.decrease_transfers_counter()
+        self._close_socket()
+
+        if ex is not None:
+            logger.error(XFROUT_HANDLE_QUERY_ERROR, ex)
+
+    def _close_socket(self):
+        '''Simply close the socket via the given FD.
 
+        This is a dedicated subroutine of handle() and is sepsarated from it
+        for the convenience of tests.
+
+        '''
         os.close(self._sock_fd)
 
     def _check_request_tsig(self, msg, request_data):
@@ -252,12 +277,8 @@ class XfroutSession():
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
            eg. check allow_transfer setting,
-        '''
 
-        # Reject the attempt if we are too busy.  Check this first to avoid
-        # unnecessary resource consumption even if we discard it soon.
-        if not self._server.increase_transfers_counter():
-            return Rcode.REFUSED()
+        '''
 
         # Identify the data source for the requested zone and see if it has
         # SOA while initializing objects used for request processing later.
@@ -292,7 +313,7 @@ class XfroutSession():
         return Rcode.NOERROR()
 
 
-    def dns_xfrout_start(self, sock_fd, msg_query):
+    def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
         if rcode_ is None: # Dropped by ACL
@@ -302,6 +323,9 @@ class XfroutSession():
         elif rcode_ != Rcode.NOERROR():
             return self._reply_query_with_error_rcode(msg, sock_fd,
                                                       Rcode.FORMERR())
+        elif not quota_ok:
+            return self._reply_query_with_error_rcode(msg, sock_fd,
+                                                      Rcode.REFUSED())
 
         question = msg.get_question()[0]
         zone_name = question.get_name()
@@ -322,8 +346,6 @@ class XfroutSession():
             pass
         logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_str)
 
-        self._server.decrease_transfers_counter()
-
     def _clear_message(self, msg):
         qid = msg.get_qid()
         opcode = msg.get_opcode()