Browse Source

[1288] make sure transfers_counter is always reset whateven happens within
XfroutSession(). The original code was already buggy in this sense, but
with the newer data source API it will be a bit more likely to happen
due to the generality of the API, so it would make sense to fix it here.

JINMEI Tatuya 13 years ago
parent
commit
ca42fb6438
2 changed files with 112 additions and 52 deletions
  1. 75 37
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 37 15
      src/bin/xfrout/xfrout.py.in

+ 75 - 37
src/bin/xfrout/tests/xfrout_test.py.in

@@ -64,10 +64,40 @@ class MySocket():
     def clear_send(self):
     def clear_send(self):
         del self.sendqueue[:]
         del self.sendqueue[:]
 
 
-# We subclass the Session class we're testing here, only
-# to override the handle() and _send_data() method
+class MockDataSrcClient:
+    def __init__(self, type, config):
+        pass
+
+    def get_iterator(self, zone_name):
+        if zone_name == Name('notauth.example.com'):
+            raise isc.datasrc.Error('no such zone')
+        self._zone_name = zone_name
+        return self
+
+    def get_soa(self):  # emulate ZoneIterator.get_soa()
+        if self._zone_name == Name('nosoa.example.com'):
+            return None
+        soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
+                          RRType.SOA(), RRTTL(3600))
+        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                                  'master.example.com. ' +
+                                  'admin.example.com. 1234 ' +
+                                  '3600 1800 2419200 7200'))
+        if self._zone_name == Name('multisoa.example.com'):
+            soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                                      'master.example.com. ' +
+                                      'admin.example.com. 1300 ' +
+                                      '3600 1800 2419200 7200'))
+            return soa_rrset
+        return soa_rrset
+
+# We subclass the Session class we're testing here, only overriding a few
+# methods
 class MyXfroutSession(XfroutSession):
 class MyXfroutSession(XfroutSession):
-    def handle(self):
+    def _handle(self):
+        pass
+
+    def _close_socket(self):
         pass
         pass
 
 
     def _send_data(self, sock, data):
     def _send_data(self, sock, data):
@@ -80,12 +110,14 @@ class MyXfroutSession(XfroutSession):
 class Dbserver:
 class Dbserver:
     def __init__(self):
     def __init__(self):
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
+        self.transfer_counter = 0
     def get_db_file(self):
     def get_db_file(self):
         return 'test.sqlite3'
         return 'test.sqlite3'
     def increase_transfers_counter(self):
     def increase_transfers_counter(self):
+        self.transfer_counter += 1
         return True
         return True
     def decrease_transfers_counter(self):
     def decrease_transfers_counter(self):
-        pass
+        self.transfer_counter -= 1
 
 
 class TestXfroutSession(unittest.TestCase):
 class TestXfroutSession(unittest.TestCase):
     def getmsg(self):
     def getmsg(self):
@@ -139,6 +171,45 @@ class TestXfroutSession(unittest.TestCase):
                                        'admin.exAmple.com. ' +
                                        'admin.exAmple.com. ' +
                                        '1234 3600 1800 2419200 7200'))
                                        '1234 3600 1800 2419200 7200'))
 
 
+    def tearDown(self):
+        # transfer_counter must be always be reset no matter happens within
+        # the XfroutSession object.  We check the condition here.
+        self.assertEqual(0, self.xfrsess._server.transfer_counter)
+
+    def test_quota_error(self):
+        '''Emulating the server being too busy.
+
+        '''
+        self.xfrsess._request_data = self.mdata
+        self.xfrsess._server.increase_transfers_counter = lambda : False
+        XfroutSession._handle(self.xfrsess)
+        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.REFUSED())
+
+    def test_quota_ok(self):
+        '''The default case in terms of the xfrout quota.
+
+        '''
+        # set up a bogus request, which should result in FORMERR. (it only
+        # has to be something that is different from the previous case)
+        self.xfrsess._request_data = \
+            self.create_request_data(with_question=False)
+        # Replace the data source client to avoid datasrc related exceptions
+        self.xfrsess.ClientClass = MockDataSrcClient
+        XfroutSession._handle(self.xfrsess)
+        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.FORMERR())
+
+    def test_exception_from_session(self):
+        '''Test the case where the main processing raises an exception.
+
+        We just check it doesn't any unexpected disruption and (in teraDown)
+        transfer_counter is correctly reset to 0.
+
+        '''
+        def dns_xfrout_start(fd, msg, quota):
+            raise ValueError('fake exception')
+        self.xfrsess.dns_xfrout_start = dns_xfrout_start
+        XfroutSession._handle(self.xfrsess)
+
     def test_parse_query_message(self):
     def test_parse_query_message(self):
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
         self.assertEqual(get_rcode.to_text(), "NOERROR")
@@ -520,32 +591,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
 
 
     def test_check_xfrout_available(self):
     def test_check_xfrout_available(self):
-        class MockDataSrcClient:
-            def __init__(self, type, config): pass
-
-            def get_iterator(self, zone_name):
-                if zone_name == Name('notauth.example.com'):
-                    raise isc.datasrc.Error('no such zone')
-                self._zone_name = zone_name
-                return self
-
-            def get_soa(self):  # emulate ZoneIterator.get_soa()
-                if self._zone_name == Name('nosoa.example.com'):
-                    return None
-                soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
-                                  RRType.SOA(), RRTTL(3600))
-                soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                          'master.example.com. ' +
-                                          'admin.example.com. 1234 ' +
-                                          '3600 1800 2419200 7200'))
-                if self._zone_name == Name('multisoa.example.com'):
-                    soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                                   'master.example.com. ' +
-                                                   'admin.example.com. 1300 ' +
-                                                   '3600 1800 2419200 7200'))
-                    return soa_rrset
-                return soa_rrset
-
         self.xfrsess.ClientClass = MockDataSrcClient
         self.xfrsess.ClientClass = MockDataSrcClient
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
                 Name('notauth.example.com')), Rcode.NOTAUTH())
                 Name('notauth.example.com')), Rcode.NOTAUTH())
@@ -554,13 +599,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
                 Name('multisoa.example.com')), Rcode.SERVFAIL())
                 Name('multisoa.example.com')), Rcode.SERVFAIL())
 
 
-        self.xfrsess._server.increase_transfers_counter = lambda : False
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.REFUSED())
-        self.xfrsess._server.increase_transfers_counter = lambda : True
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.NOERROR())
-
     def test_dns_xfrout_start_formerror(self):
     def test_dns_xfrout_start_formerror(self):
         # formerror
         # formerror
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
         self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")

+ 37 - 15
src/bin/xfrout/xfrout.py.in

@@ -128,21 +128,46 @@ class XfroutSession():
         self._zone_config = zone_config
         self._zone_config = zone_config
         self.ClientClass = client_class # parameterize this for testing
         self.ClientClass = client_class # parameterize this for testing
         self._soa = None # will be set in _check_xfrout_available or in tests
         self._soa = None # will be set in _check_xfrout_available or in tests
-        self.handle()
+        self._handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
         return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
         return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
                            tsig_key_ring)
                            tsig_key_ring)
 
 
-    def handle(self):
-        ''' Handle a xfrout query, send xfrout response '''
+    def _handle(self):
+        ''' Handle a xfrout query, send xfrout response(s).
+
+        This is separated from the constructor so that we can override
+        it from tests.
+
+        '''
+        # Check the xfrout quota.  We do both increase/decrease in this
+        # method so it's clear we always release it once acuired.
+        quota_ok = self._server.increase_transfers_counter()
+        ex = None
         try:
         try:
-            self.dns_xfrout_start(self._sock_fd, self._request_data)
-            #TODO, avoid catching all exceptions
+            self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
         except Exception as e:
         except Exception as e:
-            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
-            pass
+            # To avoid resource leak we need catch all possible exceptions
+            # We log it later to exclude the case where even logger raises
+            # an exception.
+            ex = e
+
+        # Release any critical resources
+        if quota_ok:
+            self._server.decrease_transfers_counter()
+        self._close_socket()
+
+        if ex is not None:
+            logger.error(XFROUT_HANDLE_QUERY_ERROR, ex)
+
+    def _close_socket(self):
+        '''Simply close the socket via the given FD.
 
 
+        This is a dedicated subroutine of handle() and is sepsarated from it
+        for the convenience of tests.
+
+        '''
         os.close(self._sock_fd)
         os.close(self._sock_fd)
 
 
     def _check_request_tsig(self, msg, request_data):
     def _check_request_tsig(self, msg, request_data):
@@ -252,12 +277,8 @@ class XfroutSession():
         '''Check if xfr request can be responsed.
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
            TODO, Get zone's configuration from cfgmgr or some other place
            eg. check allow_transfer setting,
            eg. check allow_transfer setting,
-        '''
 
 
-        # Reject the attempt if we are too busy.  Check this first to avoid
-        # unnecessary resource consumption even if we discard it soon.
-        if not self._server.increase_transfers_counter():
-            return Rcode.REFUSED()
+        '''
 
 
         # Identify the data source for the requested zone and see if it has
         # Identify the data source for the requested zone and see if it has
         # SOA while initializing objects used for request processing later.
         # SOA while initializing objects used for request processing later.
@@ -292,7 +313,7 @@ class XfroutSession():
         return Rcode.NOERROR()
         return Rcode.NOERROR()
 
 
 
 
-    def dns_xfrout_start(self, sock_fd, msg_query):
+    def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
         rcode_, msg = self._parse_query_message(msg_query)
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
         #TODO. create query message and parse header
         if rcode_ is None: # Dropped by ACL
         if rcode_ is None: # Dropped by ACL
@@ -302,6 +323,9 @@ class XfroutSession():
         elif rcode_ != Rcode.NOERROR():
         elif rcode_ != Rcode.NOERROR():
             return self._reply_query_with_error_rcode(msg, sock_fd,
             return self._reply_query_with_error_rcode(msg, sock_fd,
                                                       Rcode.FORMERR())
                                                       Rcode.FORMERR())
+        elif not quota_ok:
+            return self._reply_query_with_error_rcode(msg, sock_fd,
+                                                      Rcode.REFUSED())
 
 
         question = msg.get_question()[0]
         question = msg.get_question()[0]
         zone_name = question.get_name()
         zone_name = question.get_name()
@@ -322,8 +346,6 @@ class XfroutSession():
             pass
             pass
         logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_str)
         logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_str)
 
 
-        self._server.decrease_transfers_counter()
-
     def _clear_message(self, msg):
     def _clear_message(self, msg):
         qid = msg.get_qid()
         qid = msg.get_qid()
         opcode = msg.get_opcode()
         opcode = msg.get_opcode()