Browse Source

Merge branch 'work/xfroutacl'

Conflicts:
	src/bin/xfrout/tests/xfrout_test.py.in
Michal 'vorner' Vaner 14 years ago
parent
commit
50070c8242

+ 1 - 1
src/bin/xfrout/tests/Makefile.am

@@ -6,7 +6,7 @@ EXTRA_DIST = $(PYTESTS)
 # required by loadable python modules.
 LIBRARY_PATH_PLACEHOLDER =
 if SET_ENV_LIBRARY_PATH
-LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/dns/.libs:$(abs_top_builddir)/src/lib/cryptolink/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
+LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cc/.libs:$(abs_top_builddir)/src/lib/config/.libs:$(abs_top_builddir)/src/lib/log/.libs:$(abs_top_builddir)/src/lib/dns/.libs:$(abs_top_builddir)/src/lib/cryptolink/.libs:$(abs_top_builddir)/src/lib/acl/.libs:$(abs_top_builddir)/src/lib/util/.libs:$(abs_top_builddir)/src/lib/exceptions/.libs:$(abs_top_builddir)/src/lib/util/io/.libs:$$$(ENV_LIBRARY_PATH)
 endif
 
 # test using command-line arguments, so use check-local target instead of TESTS

+ 94 - 21
src/bin/xfrout/tests/xfrout_test.py.in

@@ -24,6 +24,7 @@ from pydnspp import *
 from xfrout import *
 import xfrout
 import isc.log
+import isc.acl.dns
 
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 
@@ -117,8 +118,11 @@ class TestXfroutSession(unittest.TestCase):
 
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
-        #self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), TSIGKeyRing())
+        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
+                                       TSIGKeyRing(), ('127.0.0.1', 12345),
+                                       # When not testing ACLs, simply accept
+                                       isc.acl.dns.REQUEST_LOADER.load(
+                                           [{"action": "ACCEPT"}]))
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
@@ -138,6 +142,36 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
 
+        # ACL checks, put some ACL inside
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+            {
+                "from": "127.0.0.1",
+                "action": "ACCEPT"
+            },
+            {
+                "from": "192.0.2.1",
+                "action": "DROP"
+            }
+        ])
+        # Localhost (the default in this test) is accepted
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "NOERROR")
+        # This should be dropped completely, therefore returning None
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(None, rcode)
+        # This should be refused, therefore REFUSED
+        self.xfrsess._remote = ('192.0.2.2', 12345)
+        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(rcode.to_text(), "REFUSED")
+        # If the TSIG check fails, it should not check ACL
+        # (If it checked ACL as well, it would just drop the request)
+        self.xfrsess._remote = ('192.0.2.1', 12345)
+        self.xfrsess._tsig_key_ring = TSIGKeyRing()
+        rcode, msg = self.xfrsess._parse_query_message(request_data)
+        self.assertEqual(rcode.to_text(), "NOTAUTH")
+        self.assertTrue(self.xfrsess._tsig_ctx is not None)
+
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
@@ -196,20 +230,6 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(msg.get_rcode(), rcode)
         self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
 
-    def test_reply_query_with_format_error(self):
-        msg = self.getmsg()
-        self.xfrsess._reply_query_with_format_error(msg, self.sock)
-        get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
-
-        # tsig signed message
-        msg = self.getmsg()
-        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
-        self.xfrsess._reply_query_with_format_error(msg, self.sock)
-        get_msg = self.sock.read_msg()
-        self.assertEqual(get_msg.get_rcode().to_text(), "FORMERR")
-        self.assertTrue(self.message_has_tsig(get_msg))
-
     def test_create_rrset_from_db_record(self):
         rrset = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(rrset.get_name().to_text(), "example.com.")
@@ -516,18 +536,42 @@ class MyCCSession():
 
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
-        self._lock = threading.Lock()
-        self._transfers_counter = 0
         self._shutdown_event = threading.Event()
         self._max_transfers_out = 10
         self._cc = MyCCSession()
-        #self._log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
+        self._common_init()
 
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
         self.write_sock, self.read_sock = socket.socketpair()
         self.unix = MyUnixSockServer()
 
+    def test_guess_remote(self):
+        """Test we can guess the remote endpoint when we have only the
+           file descriptor. This is needed, because we get only that one
+           from auth."""
+        # We test with UDP, as it can be "connected" without other
+        # endpoint
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.connect(('127.0.0.1', 12345))
+        self.assertEqual(('127.0.0.1', 12345),
+                         self.unix._guess_remote(sock.fileno()))
+        if socket.has_ipv6:
+            # Don't check IPv6 address on hosts not supporting them
+            sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+            sock.connect(('::1', 12345))
+            self.assertEqual(('::1', 12345, 0, 0),
+                             self.unix._guess_remote(sock.fileno()))
+            # Try when pretending there's no IPv6 support
+            # (No need to pretend when there's really no IPv6)
+            xfrout.socket.has_ipv6 = False
+            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            sock.connect(('127.0.0.1', 12345))
+            self.assertEqual(('127.0.0.1', 12345),
+                             self.unix._guess_remote(sock.fileno()))
+            # Return it back
+            xfrout.socket.has_ipv6 = True
+
     def test_receive_query_message(self):
         send_msg = b"\xd6=\x00\x00\x00\x01\x00"
         msg_len = struct.pack('H', socket.htons(len(send_msg)))
@@ -536,15 +580,34 @@ class TestUnixSockServer(unittest.TestCase):
         recv_msg = self.unix._receive_query_message(self.read_sock)
         self.assertEqual(recv_msg, send_msg)
 
-    def test_updata_config_data(self):
+    def check_default_ACL(self):
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
+                                             1234, 0, 0, 0,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
+
+    def check_loaded_ACL(self):
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
+                                             1234, 0, 0, 0,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
+        context = isc.acl.dns.RequestContext(socket.getaddrinfo("192.0.2.1",
+                                             1234, 0, 0, 0,
+                                             socket.AI_NUMERICHOST)[0][4])
+        self.assertEqual(isc.acl.acl.REJECT, self.unix._acl.execute(context))
+
+    def test_update_config_data(self):
+        self.check_default_ACL()
         tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
         tsig_key_list = [tsig_key_str]
         bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
         self.unix.update_config_data({'transfers_out':10 })
         self.assertEqual(self.unix._max_transfers_out, 10)
         self.assertTrue(self.unix.tsig_key_ring is not None)
+        self.check_default_ACL()
 
-        self.unix.update_config_data({'transfers_out':9, 'tsig_key_ring':tsig_key_list})
+        self.unix.update_config_data({'transfers_out':9,
+                                      'tsig_key_ring':tsig_key_list})
         self.assertEqual(self.unix._max_transfers_out, 9)
         self.assertEqual(self.unix.tsig_key_ring.size(), 1)
         self.unix.tsig_key_ring.remove(Name("example.com."))
@@ -555,6 +618,16 @@ class TestUnixSockServer(unittest.TestCase):
         self.assertRaises(None, self.unix.update_config_data(config_data))
         self.assertEqual(self.unix.tsig_key_ring.size(), 0)
 
+        # Load the ACL
+        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
+                                               'action': 'ACCEPT'}]})
+        self.check_loaded_ACL()
+        # Pass a wrong data there and check it does not replace the old one
+        self.assertRaises(isc.acl.acl.LoaderError,
+                          self.unix.update_config_data,
+                          {'query_acl': ['Something bad']})
+        self.check_loaded_ACL()
+
     def test_get_db_file(self):
         self.assertEqual(self.unix.get_db_file(), "initdb.file")
 

+ 69 - 29
src/bin/xfrout/xfrout.py.in

@@ -48,6 +48,9 @@ except ImportError as e:
     # must keep running, so we warn about it and move forward.
     log.error(XFROUT_IMPORT, str(e))
 
+from isc.acl.acl import ACCEPT, REJECT, DROP
+from isc.acl.dns import REQUEST_LOADER
+
 isc.util.process.rename()
 
 def init_paths():
@@ -92,16 +95,16 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
-    def __init__(self, sock_fd, request_data, server, tsig_key_ring):
-        # The initializer for the superclass may call functions
-        # that need _log to be set, so we set it first
+    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
+                 acl):
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
-        #self._log = log
         self._tsig_key_ring = tsig_key_ring
         self._tsig_ctx = None
         self._tsig_len = 0
+        self._remote = remote
+        self._acl = acl
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -114,7 +117,7 @@ class XfroutSession():
             self.dns_xfrout_start(self._sock_fd, self._request_data)
             #TODO, avoid catching all exceptions
         except Exception as e:
-            logger.error(XFROUT_HANDLE_QUERY_ERROR, str(e))
+            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
             pass
 
         os.close(self._sock_fd)
@@ -141,8 +144,25 @@ class XfroutSession():
             # TSIG related checks
             rcode = self._check_request_tsig(msg, mdata)
 
+            if rcode == Rcode.NOERROR():
+                # ACL checks
+                acl_result = self._acl.execute(
+                    isc.acl.dns.RequestContext(self._remote))
+                if acl_result == DROP:
+                    logger.info(XFROUT_QUERY_DROPPED,
+                                self._get_query_zone_name(msg),
+                                self._get_query_zone_class(msg),
+                                self._remote[0], self._remote[1])
+                    return None, None
+                elif acl_result == REJECT:
+                    logger.info(XFROUT_QUERY_REJECTED,
+                                self._get_query_zone_name(msg),
+                                self._get_query_zone_class(msg),
+                                self._remote[0], self._remote[1])
+                    return Rcode.REFUSED(), msg
+
         except Exception as err:
-            logger.error(XFROUT_PARSE_QUERY_ERROR, str(err))
+            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             return Rcode.FORMERR(), None
 
         return rcode, msg
@@ -183,18 +203,11 @@ class XfroutSession():
 
 
     def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
-        msg.make_response()
-        msg.set_rcode(rcode_)
-        self._send_message(sock_fd, msg, self._tsig_ctx)
-
-
-    def _reply_query_with_format_error(self, msg, sock_fd):
-        '''query message format isn't legal.'''
         if not msg:
             return # query message is invalid. send nothing back.
 
         msg.make_response()
-        msg.set_rcode(Rcode.FORMERR())
+        msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
     def _zone_has_soa(self, zone):
@@ -244,10 +257,13 @@ class XfroutSession():
     def dns_xfrout_start(self, sock_fd, msg_query):
         rcode_, msg = self._parse_query_message(msg_query)
         #TODO. create query message and parse header
-        if rcode_ == Rcode.NOTAUTH():
+        if rcode_ is None: # Dropped by ACL
+            return
+        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
         elif rcode_ != Rcode.NOERROR():
-            return self._reply_query_with_format_error(msg, sock_fd)
+            return self._reply_query_with_error_rcode(msg, sock_fd,
+                                                      Rcode.FORMERR())
 
         zone_name = self._get_query_zone_name(msg)
         zone_class_str = self._get_query_zone_class(msg)
@@ -257,7 +273,7 @@ class XfroutSession():
         if rcode_ != Rcode.NOERROR():
             logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
                         zone_class_str, rcode_.to_text())
-            return self. _reply_query_with_error_rcode(msg, sock_fd, rcode_)
+            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
         try:
             logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
@@ -375,14 +391,20 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._sock_file = sock_file
         socketserver_mixin.NoPollMixIn.__init__(self)
         ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
-        self._lock = threading.Lock()
-        self._transfers_counter = 0
         self._shutdown_event = shutdown_event
         self._write_sock, self._read_sock = socket.socketpair()
-        #self._log = log
+        self._common_init()
         self.update_config_data(config_data)
         self._cc = cc
 
+    def _common_init(self):
+        self._lock = threading.Lock()
+        self._transfers_counter = 0
+        # This default value will probably get overwritten by the (same)
+        # default value from the spec file. This is here just to make
+        # sure and to make the default value in tests consistent.
+        self._acl = REQUEST_LOADER.load('[{"action": "ACCEPT"}]')
+
     def _receive_query_message(self, sock):
         ''' receive request message from sock'''
         # receive data length
@@ -465,10 +487,28 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             t.daemon = True
         t.start()
 
+    def _guess_remote(self, sock_fd):
+        """
+           Guess remote address and port of the socket. The sock_fd must be a
+           socket
+        """
+        # This uses a trick. If the socket is IPv4 in reality and we pretend
+        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
+        # to care about the SOCK_STREAM parameter at all (which it really is,
+        # except for testing)
+        if socket.has_ipv6:
+            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
+        else:
+            # To make it work even on hosts without IPv6 support
+            # (Any idea how to simulate this in test?)
+            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
+        return sock.getpeername()
 
     def finish_request(self, sock_fd, request_data):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(sock_fd, request_data, self, self.tsig_key_ring)
+        self.RequestHandlerClass(sock_fd, request_data, self,
+                                 self.tsig_key_ring,
+                                 self._guess_remote(sock_fd), self._acl)
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
@@ -512,6 +552,8 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     def update_config_data(self, new_config):
         '''Apply the new config setting of xfrout module. '''
         logger.info(XFROUT_NEW_CONFIG)
+        if 'query_acl' in new_config:
+            self._acl = REQUEST_LOADER.load(new_config['query_acl'])
         self._lock.acquire()
         self._max_transfers_out = new_config.get('transfers_out')
         self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
@@ -563,16 +605,12 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
 class XfroutServer:
     def __init__(self):
         self._unix_socket_server = None
-        #self._log = None
         self._listen_sock_file = UNIX_SOCKET_FILE
         self._shutdown_event = threading.Event()
         self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
         self._config_data = self._cc.get_full_config()
         self._cc.start()
         self._cc.add_remote_config(AUTH_SPECFILE_LOCATION);
-        #self._log = isc.log.NSLogger(self._config_data.get('log_name'), self._config_data.get('log_file'),
-        #                        self._config_data.get('log_severity'), self._config_data.get('log_versions'),
-        #                        self._config_data.get('log_max_bytes'), True)
         self._start_xfr_query_listener()
         self._start_notifier()
 
@@ -601,11 +639,13 @@ class XfroutServer:
                 continue
             self._config_data[key] = new_config[key]
 
-        #if self._log:
-        #    self._log.update_config(new_config)
-
         if self._unix_socket_server:
-            self._unix_socket_server.update_config_data(self._config_data)
+            try:
+                self._unix_socket_server.update_config_data(self._config_data)
+            except Exception as e:
+                answer = create_answer(1,
+                                       "Failed to handle new configuration: " +
+                                       str(e))
 
         return answer
 

+ 19 - 7
src/bin/xfrout/xfrout.spec.pre.in

@@ -16,27 +16,27 @@
        },
        {
          "item_name": "log_file",
-    	 "item_type": "string",
+         "item_type": "string",
          "item_optional": false,
          "item_default": "@@LOCALSTATEDIR@@/@PACKAGE@/log/Xfrout.log"
        },
        {
          "item_name": "log_severity",
-    	 "item_type": "string",
+         "item_type": "string",
          "item_optional": false,
-    	 "item_default": "debug"
+         "item_default": "debug"
        },
        {
          "item_name": "log_versions",
-    	 "item_type": "integer",
+         "item_type": "integer",
          "item_optional": false,
-    	 "item_default": 5
+         "item_default": 5
        },
        {
          "item_name": "log_max_bytes",
-    	 "item_type": "integer",
+         "item_type": "integer",
          "item_optional": false,
-    	 "item_default": 1048576
+         "item_default": 1048576
        },
        {
          "item_name": "tsig_key_ring",
@@ -49,6 +49,18 @@
              "item_type": "string",
              "item_optional": true
          }
+       },
+       {
+         "item_name": "query_acl",
+         "item_type": "list",
+         "item_optional": false,
+         "item_default": [{"action": "ACCEPT"}],
+         "list_item_spec":
+         {
+             "item_name": "acl_element",
+             "item_type": "any",
+             "item_optional": true
+         }
        }
       ],
       "commands": [

+ 11 - 0
src/bin/xfrout/xfrout_messages.mes

@@ -95,6 +95,17 @@ in the log message, but at this point no specific information other
 than that could be given. This points to incomplete exception handling
 in the code.
 
+% XFROUT_QUERY_DROPPED request to transfer %1/%2 to [%3]:%4 dropped
+The xfrout process silently dropped a request to transfer zone to given host.
+This is required by the ACLs. The %1 and %2 represent the zone name and class,
+the %3 and %4 the IP address and port of the peer requesting the transfer.
+
+% XFROUT_QUERY_REJECTED request to transfer %1/%2 to [%3]:%4 rejected
+The xfrout process rejected (by REFUSED rcode) a request to transfer zone to
+given host. This is because of ACLs. The %1 and %2 represent the zone name and
+class, the %3 and %4 the IP address and port of the peer requesting the
+transfer.
+
 % XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR error receiving the file descriptor for an XFR connection
 There was an error receiving the file descriptor for the transfer
 request. Normally, the request is received by b10-auth, and passed on