Browse Source

Merge branch 'work/sock/unix' into work/sock/cacheinter

Conflicts:
	src/bin/bind10/bind10_src.py.in
	src/bin/bind10/tests/bind10_test.py.in
Michal 'vorner' Vaner 13 years ago
parent
commit
2a6b5e55ca
2 changed files with 349 additions and 56 deletions
  1. 153 56
      src/bin/bind10/bind10_src.py.in
  2. 196 0
      src/bin/bind10/tests/bind10_test.py.in

+ 153 - 56
src/bin/bind10/bind10_src.py.in

@@ -74,6 +74,7 @@ import isc.bind10.component
 import isc.bind10.special_component
 import isc.bind10.socket_cache
 import libutil_io_python
+import tempfile
 
 isc.log.init("b10-boss")
 logger = isc.log.Logger("boss")
@@ -247,9 +248,12 @@ class BoB:
         # If -v was set, enable full debug logging.
         if self.verbose:
             logger.set_severity("DEBUG", 99)
-        self._socket_cache = None
-        # TODO: To be filled in by #1428
+        # This is set in init_socket_srv
         self._socket_path = None
+        self._socket_cache = None
+        self._tmpdir = None
+        self._srv_socket = None
+        self._unix_sockets = {}
 
     def __propagate_component_config(self, config):
         comps = dict(config)
@@ -915,6 +919,127 @@ class BoB:
             raise ValueError("A creator was inserted previously")
         self._socket_cache = isc.bind10.socket_cache.Cache(creator)
 
+    def init_socket_srv(self):
+        """
+        Creates and listens on a unix-domain socket to be able to send out
+        the sockets.
+
+        This method should be called after switching user, or the switched
+        applications won't be able to access the socket.
+        """
+        self._srv_socket = socket.socket(socket.AF_UNIX)
+        # We create a temporary directory somewhere safe and unique, to avoid
+        # the need to find the place ourself or bother users. Also, this
+        # secures the socket on some platforms, as it creates a private
+        # directory.
+        self._tmpdir = tempfile.mkdtemp()
+        # Get the name
+        self._socket_path = os.path.join(self._tmpdir, "sockcreator")
+        # And bind the socket to the name
+        self._srv_socket.bind(self._socket_path)
+        self._srv_socket.listen(5)
+
+    def remove_socket_srv(self):
+        """
+        Closes and removes the listening socket and the directory where it
+        lives, as we created both.
+
+        It does nothing if the _srv_socket is not set (eg. it was not yet
+        initialized).
+        """
+        if self._srv_socket is not None:
+            self._srv_socket.close()
+            os.remove(self._socket_path)
+            os.rmdir(self._tmpdir)
+
+    def _srv_accept(self):
+        """
+        Accept a socket from the unix domain socket server and put it to the
+        others we care about.
+        """
+        socket = self._srv_socket.accept()
+        self._unix_sockets[socket.fileno()] = (socket, b'')
+
+    def _socket_data(self, socket_fileno):
+        """
+        This is called when a socket identified by the socket_fileno needs
+        attention. We try to read data from there. If it is closed, we remove
+        it.
+        """
+        (sock, previous) = self._unix_sockets[socket_fileno]
+        while True:
+            try:
+                data = sock.recv(1, socket.MSG_DONTWAIT)
+            except socket.error as se:
+                # These two might be different on some systems
+                if se.errno == errno.EAGAIN or se.errno == errno.EWOULDBLOCK:
+                    # No more data now. Oh, well, just store what we have.
+                    self._unix_sockets[socket_fileno] = (sock, previous)
+                    return
+                else:
+                    data = b'' # Pretend it got closed
+            if len(data) == 0: # The socket got to it's end
+                del self._unix_sockets[socket_fileno]
+                self.socket_consumer_dead(sock)
+                sock.close()
+                return
+            else:
+                if data == b"\n":
+                    # Handle this token and clear it
+                    self.socket_request_handler(previous, sock)
+                    previous = b''
+                else:
+                    previous += data
+
+    def run(self, wakeup_fd):
+        """
+        The main loop, waiting for sockets, commands and dead processes.
+        Runs as long as the runnable is true.
+
+        The wakeup_fd descriptor is the read end of pipe where CHLD signal
+        handler writes.
+        """
+        ccs_fd = self.ccs.get_socket().fileno()
+        while self.runnable:
+            # clean up any processes that exited
+            self.reap_children()
+            next_restart = self.restart_processes()
+            if next_restart is None:
+                wait_time = None
+            else:
+                wait_time = max(next_restart - time.time(), 0)
+
+            # select() can raise EINTR when a signal arrives,
+            # even if they are resumable, so we have to catch
+            # the exception
+            try:
+                (rlist, wlist, xlist) = \
+                    select.select([wakeup_fd, ccs_fd,
+                                   self._srv_socket.fileno()] +
+                                   list(self._unix_sockets.keys()), [], [],
+                                  wait_time)
+            except select.error as err:
+                if err.args[0] == errno.EINTR:
+                    (rlist, wlist, xlist) = ([], [], [])
+                else:
+                    logger.fatal(BIND10_SELECT_ERROR, err)
+                    break
+
+            for fd in rlist + xlist:
+                if fd == ccs_fd:
+                    try:
+                        self.ccs.check_command()
+                    except isc.cc.session.ProtocolError:
+                        logger.fatal(BIND10_MSGQ_DISAPPEARED)
+                        self.runnable = False
+                        break
+                elif fd == wakeup_fd:
+                    os.read(wakeup_fd, 32)
+                elif fd == self._srv_socket.fileno():
+                    self._srv_accept()
+                elif fd in self._unix_sockets:
+                    self._socket_data(fd)
+
 # global variables, needed for signal handlers
 options = None
 boss_of_bind = None
@@ -1077,60 +1202,32 @@ def main():
     # Block SIGPIPE, as we don't want it to end this process
     signal.signal(signal.SIGPIPE, signal.SIG_IGN)
 
-    # Go bob!
-    boss_of_bind = BoB(options.msgq_socket_file, options.data_path,
-                       options.config_file, options.nocache, options.verbose,
-                       setuid, username, options.cmdctl_port,
-                       options.wait_time)
-    startup_result = boss_of_bind.startup()
-    if startup_result:
-        logger.fatal(BIND10_STARTUP_ERROR, startup_result)
-        sys.exit(1)
-    logger.info(BIND10_STARTUP_COMPLETE)
-    dump_pid(options.pid_file)
-
-    # In our main loop, we check for dead processes or messages 
-    # on the c-channel.
-    wakeup_fd = wakeup_pipe[0]
-    ccs_fd = boss_of_bind.ccs.get_socket().fileno()
-    while boss_of_bind.runnable:
-        # clean up any processes that exited
-        boss_of_bind.reap_children()
-        next_restart = boss_of_bind.restart_processes()
-        if next_restart is None:
-            wait_time = None
-        else:
-            wait_time = max(next_restart - time.time(), 0)
-
-        # select() can raise EINTR when a signal arrives, 
-        # even if they are resumable, so we have to catch
-        # the exception
-        try:
-            (rlist, wlist, xlist) = select.select([wakeup_fd, ccs_fd], [], [], 
-                                                  wait_time)
-        except select.error as err:
-            if err.args[0] == errno.EINTR:
-                (rlist, wlist, xlist) = ([], [], [])
-            else:
-                logger.fatal(BIND10_SELECT_ERROR, err)
-                break
-
-        for fd in rlist + xlist:
-            if fd == ccs_fd:
-                try:
-                    boss_of_bind.ccs.check_command()
-                except isc.cc.session.ProtocolError:
-                    logger.fatal(BIND10_MSGQ_DISAPPEARED)
-                    self.runnable = False
-                    break
-            elif fd == wakeup_fd:
-                os.read(wakeup_fd, 32)
-
-    # shutdown
-    signal.signal(signal.SIGCHLD, signal.SIG_DFL)
-    boss_of_bind.shutdown()
-    unlink_pid_file(options.pid_file)
-    sys.exit(0)
+    try:
+        # Go bob!
+        boss_of_bind = BoB(options.msgq_socket_file, options.data_path,
+                           options.config_file, options.nocache,
+                           options.verbose, setuid, username,
+                           options.cmdctl_port, options.wait_time)
+        startup_result = boss_of_bind.startup()
+        if startup_result:
+            logger.fatal(BIND10_STARTUP_ERROR, startup_result)
+            sys.exit(1)
+        boss_of_bind.init_socket_srv()
+        logger.info(BIND10_STARTUP_COMPLETE)
+        dump_pid(options.pid_file)
+
+        # Let it run
+        boss_of_bind.run(wakeup_pipe[0])
+
+        # shutdown
+        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
+        boss_of_bind.shutdown()
+    finally:
+        # Clean up the filesystem
+        unlink_pid_file(options.pid_file)
+        if boss_of_bind is not None:
+            boss_of_bind.remove_socket_srv()
+    sys.exit(boss_of_bind.exitcode)
 
 if __name__ == "__main__":
     main()

+ 196 - 0
src/bin/bind10/tests/bind10_test.py.in

@@ -33,6 +33,7 @@ import time
 import isc
 import isc.log
 import isc.bind10.socket_cache
+import errno
 
 from isc.testutils.parse_args import TestOptParser, OptsError
 
@@ -1195,6 +1196,201 @@ class TestBossComponents(unittest.TestCase):
         bob.start_all_components()
         self.__check_extended(self.__param)
 
+class SocketSrvTest(unittest.TestCase):
+    """
+    This tests some methods of boss related to the unix domain sockets used
+    to transfer other sockets to applications.
+    """
+    def setUp(self):
+        """
+        Create the boss to test, testdata and backup some functions.
+        """
+        self.__boss = BoB()
+        self.__select_backup = bind10_src.select.select
+        self.__select_called = None
+        self.__socket_data_called = None
+        self.__consumer_dead_called = None
+        self.__socket_request_handler_called = None
+
+    def tearDown(self):
+        """
+        Restore functions.
+        """
+        bind10_src.select.select = self.__select_backup
+
+    class __FalseSocket:
+        """
+        A mock socket for the select and accept and stuff like that.
+        """
+        def __init__(self, owner, fileno=42):
+            self.__owner = owner
+            self.__fileno = fileno
+            self.data = None
+            self.closed = False
+
+        def fileno(self):
+            return self.__fileno
+
+        def accept(self):
+            return self.__class__(self.__owner, 13)
+
+        def recv(self, bufsize, flags=0):
+            self.__owner.assertEqual(1, bufsize)
+            self.__owner.assertEqual(socket.MSG_DONTWAIT, flags)
+            if isinstance(self.data, socket.error):
+                raise self.data
+            elif self.data is not None:
+                if len(self.data):
+                    result = self.data[0:1]
+                    self.data = self.data[1:]
+                    return result
+                else:
+                    raise socket.error(errno.EAGAIN, "Would block")
+            else:
+                return b''
+
+        def close(self):
+            self.closed = True
+
+    class __CCS:
+        """
+        A mock CCS, just to provide the socket file number.
+        """
+        class __Socket:
+            def fileno(self):
+                return 1
+        def get_socket(self):
+            return self.__Socket()
+
+    def __select_accept(self, r, w, x, t):
+        self.__select_called = (r, w, x, t)
+        return ([42], [], [])
+
+    def __select_data(self, r, w, x, t):
+        self.__select_called = (r, w, x, t)
+        return ([13], [], [])
+
+    def __accept(self):
+        """
+        Hijact the accept method of the boss.
+
+        Notes down it was called and stops the boss.
+        """
+        self.__accept_called = True
+        self.__boss.runnable = False
+
+    def test_srv_accept_called(self):
+        """
+        Test that the _srv_accept method of boss is called when the listening
+        socket is readable.
+        """
+        self.__boss.runnable = True
+        self.__boss._srv_socket = self.__FalseSocket(self)
+        self.__boss._srv_accept = self.__accept
+        self.__boss.ccs = self.__CCS()
+        bind10_src.select.select = self.__select_accept
+        self.__boss.run(2)
+        # It called the accept
+        self.assertTrue(self.__accept_called)
+        # And the select had the right parameters
+        self.assertEqual(([2, 1, 42], [], [], None), self.__select_called)
+
+    def test_srv_accept(self):
+        """
+        Test how the _srv_accept method works.
+        """
+        self.__boss._srv_socket = self.__FalseSocket(self)
+        self.__boss._srv_accept()
+        # After we accepted, a new socket is added there
+        socket = self.__boss._unix_sockets[13][0]
+        # The socket is properly stored there
+        self.assertIsInstance(socket, self.__FalseSocket)
+        # And the buffer (yet empty) is there
+        self.assertEqual({13: (socket, b'')}, self.__boss._unix_sockets)
+
+    def __socket_data(self, socket):
+        self.__boss.runnable = False
+        self.__socket_data_called = socket
+
+    def test_socket_data(self):
+        """
+        Test that a socket that wants attention gets it.
+        """
+        self.__boss._srv_socket = self.__FalseSocket(self)
+        self.__boss._socket_data = self.__socket_data
+        self.__boss.ccs = self.__CCS()
+        self.__boss._unix_sockets = {13: (self.__FalseSocket(self, 13), b'')}
+        self.__boss.runnable = True
+        bind10_src.select.select = self.__select_data
+        self.__boss.run(2)
+        self.assertEqual(13, self.__socket_data_called)
+        self.assertEqual(([2, 1, 42, 13], [], [], None), self.__select_called)
+
+    def __prepare_data(self, data):
+        socket = self.__FalseSocket(self, 13)
+        self.__boss._unix_sockets = {13: (socket, b'')}
+        socket.data = data
+        self.__boss.socket_consumer_dead = self.__consumer_dead
+        self.__boss.socket_request_handler = self.__socket_request_handler
+        return socket
+
+    def __consumer_dead(self, socket):
+        self.__consumer_dead_called = socket
+
+    def __socket_request_handler(self, token, socket):
+        self.__socket_request_handler_called = (token, socket)
+
+    def test_socket_closed(self):
+        """
+        Test that a socket is removed and the socket_consumer_dead is called
+        when it is closed.
+        """
+        socket = self.__prepare_data(None)
+        self.__boss._socket_data(13)
+        self.assertEqual(socket, self.__consumer_dead_called)
+        self.assertEqual({}, self.__boss._unix_sockets)
+        self.assertTrue(socket.closed)
+
+    def test_socket_short(self):
+        """
+        Test that if there's not enough data to get the whole socket, it is
+        kept there, but nothing is called.
+        """
+        socket = self.__prepare_data(b'tok')
+        self.__boss._socket_data(13)
+        self.assertEqual({13: (socket, b'tok')}, self.__boss._unix_sockets)
+        self.assertFalse(socket.closed)
+        self.assertIsNone(self.__consumer_dead_called)
+        self.assertIsNone(self.__socket_request_handler_called)
+
+    def test_socket_continue(self):
+        """
+        Test that we call the token handling function when the whole token
+        comes. This test pretends to continue reading where the previous one
+        stopped.
+        """
+        socket = self.__prepare_data(b"en\nanothe")
+        # The data to finish
+        self.__boss._unix_sockets[13] = (socket, b'tok')
+        self.__boss._socket_data(13)
+        self.assertEqual({13: (socket, b'anothe')}, self.__boss._unix_sockets)
+        self.assertFalse(socket.closed)
+        self.assertIsNone(self.__consumer_dead_called)
+        self.assertEqual((b'token', socket),
+                         self.__socket_request_handler_called)
+
+    def test_broken_socket(self):
+        """
+        If the socket raises an exception during the read other than EAGAIN,
+        it is broken and we remove it.
+        """
+        sock = self.__prepare_data(socket.error(errno.ENOMEM,
+            "There's more memory available, but not for you"))
+        self.__boss._socket_data(13)
+        self.assertEqual(sock, self.__consumer_dead_called)
+        self.assertEqual({}, self.__boss._unix_sockets)
+        self.assertTrue(sock.closed)
+
 if __name__ == '__main__':
     # store os.environ for test_unchanged_environment
     original_os_environ = copy.deepcopy(os.environ)