Browse Source

[1165] make sure the ACL returned by _get_transfer_alc() so that per zone
ACL will be used when configured.
An unrelated change was piggy-backed: the ACL check was moved outside of
the try-except block. it doesn't make sense to return FORMERR when an
exception is raised in the ACL check.
Also some minor style fixes were made (folding some long lines)

JINMEI Tatuya 13 years ago
parent
commit
219818389c
2 changed files with 65 additions and 34 deletions
  1. 26 1
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 39 33
      src/bin/xfrout/xfrout.py.in

+ 26 - 1
src/bin/xfrout/tests/xfrout_test.py.in

@@ -126,7 +126,8 @@ class TestXfroutSession(unittest.TestCase):
                                        TSIGKeyRing(), ('127.0.0.1', 12345),
                                        # When not testing ACLs, simply accept
                                        isc.acl.dns.REQUEST_LOADER.load(
-                                           [{"action": "ACCEPT"}]))
+                                           [{"action": "ACCEPT"}]),
+                                       {})
         self.mdata = self.create_request_data(False)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
@@ -231,10 +232,34 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rcode.to_text(), "REFUSED")
 
     def test_transfer_acl(self):
+        # ACL checks only with the default ACL
         def acl_setter(acl):
             self.xfrsess._acl = acl
         self.check_transfer_acl(acl_setter)
 
+    def test_transfer_zoneacl(self):
+        # ACL check with a per zone ACL + default ACL.  The per zone ACL
+        # should match the queryied zone, so it should be used.
+        def acl_setter(acl):
+            zone_key = ('example.com.', 'IN')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl
+            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+        self.check_transfer_acl(acl_setter)
+
+    def test_transfer_zoneacl_nomatch(self):
+        # similar to the previous one, but the per zone doesn't match the
+        # query.  The default should be used.
+        def acl_setter(acl):
+            zone_key = ('example.org.', 'IN')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = \
+                isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+            self.xfrsess._acl = acl
+        self.check_transfer_acl(acl_setter)
+
     def test_get_transfer_acl(self):
         # set the default ACL.  If there's no specific zone ACL, this one
         # should be used.

+ 39 - 33
src/bin/xfrout/xfrout.py.in

@@ -99,7 +99,7 @@ def get_rrset_len(rrset):
 
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
-                 acl):
+                 default_acl, zone_config):
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
@@ -107,8 +107,8 @@ class XfroutSession():
         self._tsig_ctx = None
         self._tsig_len = 0
         self._remote = remote
-        self._acl = acl
-        self._zone_config = {}
+        self._acl = default_acl
+        self._zone_config = zone_config
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -144,32 +144,30 @@ class XfroutSession():
         try:
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
-
-            # TSIG related checks
-            rcode = self._check_request_tsig(msg, mdata)
-
-            if rcode == Rcode.NOERROR():
-                # ACL checks
-                acl_result = self._acl.execute(
-                    isc.acl.dns.RequestContext(self._remote,
-                                               msg.get_tsig_record()))
-                if acl_result == DROP:
-                    logger.info(XFROUT_QUERY_DROPPED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return None, None
-                elif acl_result == REJECT:
-                    logger.info(XFROUT_QUERY_REJECTED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return Rcode.REFUSED(), msg
-
-        except Exception as err:
+        except Exception as err: # Exception is too broad
             logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             return Rcode.FORMERR(), None
 
+        # TSIG related checks
+        rcode = self._check_request_tsig(msg, mdata)
+
+        if rcode == Rcode.NOERROR():
+            # ACL checks
+            zone_name = msg.get_question()[0].get_name()
+            zone_class = msg.get_question()[0].get_class()
+            acl = self._get_transfer_acl(zone_name, zone_class)
+            acl_result = acl.execute(
+                isc.acl.dns.RequestContext(self._remote,
+                                           msg.get_tsig_record()))
+            if acl_result == DROP:
+                logger.info(XFROUT_QUERY_DROPPED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return None, None
+            elif acl_result == REJECT:
+                logger.info(XFROUT_QUERY_REJECTED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return Rcode.REFUSED(), msg
+
         return rcode, msg
 
     def _get_transfer_acl(self, zone_name, zone_class):
@@ -406,10 +404,12 @@ class XfroutSession():
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
                                          count_since_last_tsig_sign)
 
-class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
+class UnixSockServer(socketserver_mixin.NoPollMixIn,
+                     ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
 
-    def __init__(self, sock_file, handle_class, shutdown_event, config_data, cc):
+    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
+                 cc):
         self._remove_unused_sock_file(sock_file)
         self._sock_file = sock_file
         socketserver_mixin.NoPollMixIn.__init__(self)
@@ -505,7 +505,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         if not request_data:
             return
 
-        t = threading.Thread(target = self.finish_request,
+        t = threading.Thread(target=self.finish_request,
                              args = (sock_fd, request_data))
         if self.daemon_threads:
             t.daemon = True
@@ -529,10 +529,14 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         return sock.getpeername()
 
     def finish_request(self, sock_fd, request_data):
-        '''Finish one request by instantiating RequestHandlerClass.'''
+        '''Finish one request by instantiating RequestHandlerClass.
+
+        This method creates a XfroutSession object.
+        '''
         self.RequestHandlerClass(sock_fd, request_data, self,
                                  self.tsig_key_ring,
-                                 self._guess_remote(sock_fd), self._acl)
+                                 self._guess_remote(sock_fd), self._acl,
+                                 self._zone_config)
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
@@ -673,8 +677,10 @@ class XfroutServer:
 
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
-        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession,
-                                                  self._shutdown_event, self._config_data,
+        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
+                                                  XfroutSession,
+                                                  self._shutdown_event,
+                                                  self._config_data,
                                                   self._cc)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener.start()