Browse Source

[1165] make sure the ACL returned by _get_transfer_alc() so that per zone
ACL will be used when configured.
An unrelated change was piggy-backed: the ACL check was moved outside of
the try-except block. it doesn't make sense to return FORMERR when an
exception is raised in the ACL check.
Also some minor style fixes were made (folding some long lines)

JINMEI Tatuya 13 years ago
parent
commit
219818389c
2 changed files with 65 additions and 34 deletions
  1. 26 1
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 39 33
      src/bin/xfrout/xfrout.py.in

+ 26 - 1
src/bin/xfrout/tests/xfrout_test.py.in

@@ -126,7 +126,8 @@ class TestXfroutSession(unittest.TestCase):
                                        TSIGKeyRing(), ('127.0.0.1', 12345),
                                        TSIGKeyRing(), ('127.0.0.1', 12345),
                                        # When not testing ACLs, simply accept
                                        # When not testing ACLs, simply accept
                                        isc.acl.dns.REQUEST_LOADER.load(
                                        isc.acl.dns.REQUEST_LOADER.load(
-                                           [{"action": "ACCEPT"}]))
+                                           [{"action": "ACCEPT"}]),
+                                       {})
         self.mdata = self.create_request_data(False)
         self.mdata = self.create_request_data(False)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
 
@@ -231,10 +232,34 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rcode.to_text(), "REFUSED")
         self.assertEqual(rcode.to_text(), "REFUSED")
 
 
     def test_transfer_acl(self):
     def test_transfer_acl(self):
+        # ACL checks only with the default ACL
         def acl_setter(acl):
         def acl_setter(acl):
             self.xfrsess._acl = acl
             self.xfrsess._acl = acl
         self.check_transfer_acl(acl_setter)
         self.check_transfer_acl(acl_setter)
 
 
+    def test_transfer_zoneacl(self):
+        # ACL check with a per zone ACL + default ACL.  The per zone ACL
+        # should match the queryied zone, so it should be used.
+        def acl_setter(acl):
+            zone_key = ('example.com.', 'IN')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl
+            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+        self.check_transfer_acl(acl_setter)
+
+    def test_transfer_zoneacl_nomatch(self):
+        # similar to the previous one, but the per zone doesn't match the
+        # query.  The default should be used.
+        def acl_setter(acl):
+            zone_key = ('example.org.', 'IN')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = \
+                isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+            self.xfrsess._acl = acl
+        self.check_transfer_acl(acl_setter)
+
     def test_get_transfer_acl(self):
     def test_get_transfer_acl(self):
         # set the default ACL.  If there's no specific zone ACL, this one
         # set the default ACL.  If there's no specific zone ACL, this one
         # should be used.
         # should be used.

+ 39 - 33
src/bin/xfrout/xfrout.py.in

@@ -99,7 +99,7 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
-                 acl):
+                 default_acl, zone_config):
         self._sock_fd = sock_fd
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._request_data = request_data
         self._server = server
         self._server = server
@@ -107,8 +107,8 @@ class XfroutSession():
         self._tsig_ctx = None
         self._tsig_ctx = None
         self._tsig_len = 0
         self._tsig_len = 0
         self._remote = remote
         self._remote = remote
-        self._acl = acl
+        self._acl = default_acl
-        self._zone_config = {}
+        self._zone_config = zone_config
         self.handle()
         self.handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -144,32 +144,30 @@ class XfroutSession():
         try:
         try:
             msg = Message(Message.PARSE)
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
             Message.from_wire(msg, mdata)
-
+        except Exception as err: # Exception is too broad
-            # TSIG related checks
-            rcode = self._check_request_tsig(msg, mdata)
-
-            if rcode == Rcode.NOERROR():
-                # ACL checks
-                acl_result = self._acl.execute(
-                    isc.acl.dns.RequestContext(self._remote,
-                                               msg.get_tsig_record()))
-                if acl_result == DROP:
-                    logger.info(XFROUT_QUERY_DROPPED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return None, None
-                elif acl_result == REJECT:
-                    logger.info(XFROUT_QUERY_REJECTED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return Rcode.REFUSED(), msg
-
-        except Exception as err:
             logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             return Rcode.FORMERR(), None
             return Rcode.FORMERR(), None
 
 
+        # TSIG related checks
+        rcode = self._check_request_tsig(msg, mdata)
+
+        if rcode == Rcode.NOERROR():
+            # ACL checks
+            zone_name = msg.get_question()[0].get_name()
+            zone_class = msg.get_question()[0].get_class()
+            acl = self._get_transfer_acl(zone_name, zone_class)
+            acl_result = acl.execute(
+                isc.acl.dns.RequestContext(self._remote,
+                                           msg.get_tsig_record()))
+            if acl_result == DROP:
+                logger.info(XFROUT_QUERY_DROPPED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return None, None
+            elif acl_result == REJECT:
+                logger.info(XFROUT_QUERY_REJECTED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return Rcode.REFUSED(), msg
+
         return rcode, msg
         return rcode, msg
 
 
     def _get_transfer_acl(self, zone_name, zone_class):
     def _get_transfer_acl(self, zone_name, zone_class):
@@ -406,10 +404,12 @@ class XfroutSession():
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
                                          count_since_last_tsig_sign)
                                          count_since_last_tsig_sign)
 
 
-class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
+class UnixSockServer(socketserver_mixin.NoPollMixIn,
+                     ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
     '''The unix domain socket server which accept xfr query sent from auth server.'''
 
 
-    def __init__(self, sock_file, handle_class, shutdown_event, config_data, cc):
+    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
+                 cc):
         self._remove_unused_sock_file(sock_file)
         self._remove_unused_sock_file(sock_file)
         self._sock_file = sock_file
         self._sock_file = sock_file
         socketserver_mixin.NoPollMixIn.__init__(self)
         socketserver_mixin.NoPollMixIn.__init__(self)
@@ -505,7 +505,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         if not request_data:
         if not request_data:
             return
             return
 
 
-        t = threading.Thread(target = self.finish_request,
+        t = threading.Thread(target=self.finish_request,
                              args = (sock_fd, request_data))
                              args = (sock_fd, request_data))
         if self.daemon_threads:
         if self.daemon_threads:
             t.daemon = True
             t.daemon = True
@@ -529,10 +529,14 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         return sock.getpeername()
         return sock.getpeername()
 
 
     def finish_request(self, sock_fd, request_data):
     def finish_request(self, sock_fd, request_data):
-        '''Finish one request by instantiating RequestHandlerClass.'''
+        '''Finish one request by instantiating RequestHandlerClass.
+
+        This method creates a XfroutSession object.
+        '''
         self.RequestHandlerClass(sock_fd, request_data, self,
         self.RequestHandlerClass(sock_fd, request_data, self,
                                  self.tsig_key_ring,
                                  self.tsig_key_ring,
-                                 self._guess_remote(sock_fd), self._acl)
+                                 self._guess_remote(sock_fd), self._acl,
+                                 self._zone_config)
 
 
     def _remove_unused_sock_file(self, sock_file):
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
         '''Try to remove the socket file. If the file is being used
@@ -673,8 +677,10 @@ class XfroutServer:
 
 
     def _start_xfr_query_listener(self):
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
         '''Start a new thread to accept xfr query. '''
-        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession,
+        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
-                                                  self._shutdown_event, self._config_data,
+                                                  XfroutSession,
+                                                  self._shutdown_event,
+                                                  self._config_data,
                                                   self._cc)
                                                   self._cc)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener.start()
         listener.start()