Browse Source

[master] Merge branch 'trac2398'

Jelte Jansen 12 years ago
parent
commit
9f6b45ee21
2 changed files with 267 additions and 30 deletions
  1. 54 27
      src/bin/msgq/msgq.py.in
  2. 213 3
      src/bin/msgq/tests/msgq_test.py

+ 54 - 27
src/bin/msgq/msgq.py.in

@@ -127,6 +127,7 @@ class MsgQ:
         self.subs = SubscriptionManager()
         self.lnames = {}
         self.sendbuffs = {}
+        self.running = False
 
     def setup_poller(self):
         """Set up the poll thing.  Internal function."""
@@ -238,6 +239,7 @@ class MsgQ:
         self.subs.unsubscribe_all(sock)
         lname = [ k for k, v in self.lnames.items() if v == sock ][0]
         del self.lnames[lname]
+        sock.shutdown(socket.SHUT_RDWR)
         sock.close()
         del self.sockets[fd]
         if fd in self.sendbuffs:
@@ -315,6 +317,8 @@ class MsgQ:
         elif cmd == 'ping':
             # Command for testing purposes
             self.process_command_ping(sock, routing, data)
+        elif cmd == 'stop':
+            self.stop()
         else:
             sys.stderr.write("[b10-msgq] Invalid command: %s\n" % cmd)
 
@@ -336,14 +340,34 @@ class MsgQ:
         self.send_prepared_msg(sock, self.preparemsg(env, msg))
 
     def __send_data(self, sock, data):
+        """
+        Send a piece of data to the given socket.
+        Parameters:
+        sock: The socket to send to
+        data: The list of bytes to send
+        Returns:
+        An integer or None. If an integer (which can be 0), it signals
+        the number of bytes sent. If None, the socket appears to have
+        been closed on the other end, and it has been killed on this
+        side too.
+        """
         try:
             # We set the socket nonblocking, MSG_DONTWAIT doesn't exist
             # on some OSes
             sock.setblocking(0)
             return sock.send(data)
         except socket.error as e:
-            if e.errno == errno.EAGAIN or e.errno == errno.EWOULDBLOCK:
+            if e.errno in [ errno.EAGAIN,
+                            errno.EWOULDBLOCK,
+                            errno.EINTR ]:
                 return 0
+            elif e.errno in [ errno.EPIPE,
+                              errno.ECONNRESET,
+                              errno.ENOBUFS ]:
+                print("[b10-msgq] " + errno.errorcode[e.errno] +
+                      " on send, dropping message and closing connection")
+                self.kill_socket(sock.fileno(), sock)
+                return None
             else:
                 raise e
         finally:
@@ -356,20 +380,12 @@ class MsgQ:
         if fileno in self.sendbuffs:
             amount_sent = 0
         else:
-            try:
-                amount_sent = self.__send_data(sock, msg)
-            except socket.error as sockerr:
-                # in the case the other side seems gone, kill the socket
-                # and drop the send action
-                if sockerr.errno == errno.EPIPE:
-                    print("[b10-msgq] SIGPIPE on send, dropping message " +
-                          "and closing connection")
-                    self.kill_socket(fileno, sock)
-                    return
-                else:
-                    raise
+            amount_sent = self.__send_data(sock, msg)
+            if amount_sent is None:
+                # Socket has been killed, drop the send
+                return
 
-        # Still something to send
+        # Still something to send, add it to outgoing queue
         if amount_sent < len(msg):
             now = time.clock()
             # Append it to buffer (but check the data go away)
@@ -394,17 +410,18 @@ class MsgQ:
         (_, msg) = self.sendbuffs[fileno]
         sock = self.sockets[fileno]
         amount_sent = self.__send_data(sock, msg)
-        # Keep the rest
-        msg = msg[amount_sent:]
-        if len(msg) == 0:
-            # If there's no more, stop requesting for write availability
-            if self.poller:
-                self.poller.register(fileno, select.POLLIN)
+        if amount_sent is not None:
+            # Keep the rest
+            msg = msg[amount_sent:]
+            if len(msg) == 0:
+                # If there's no more, stop requesting for write availability
+                if self.poller:
+                    self.poller.register(fileno, select.POLLIN)
+                else:
+                    self.delete_kqueue_socket(sock, True)
+                del self.sendbuffs[fileno]
             else:
-                self.delete_kqueue_socket(sock, True)
-            del self.sendbuffs[fileno]
-        else:
-            self.sendbuffs[fileno] = (time.clock(), msg)
+                self.sendbuffs[fileno] = (time.clock(), msg)
 
     def newlname(self):
         """Generate a unique connection identifier for this socket.
@@ -458,6 +475,7 @@ class MsgQ:
 
     def run(self):
         """Process messages.  Forever.  Mostly."""
+        self.running = True
 
         if self.poller:
             self.run_poller()
@@ -465,8 +483,10 @@ class MsgQ:
             self.run_kqueue()
 
     def run_poller(self):
-        while True:
+        while self.running:
             try:
+                # Poll with a timeout so that every once in a while,
+                # the loop checks for self.running.
                 events = self.poller.poll()
             except select.error as err:
                 if err.args[0] == errno.EINTR:
@@ -480,11 +500,15 @@ class MsgQ:
                 else:
                     if event & select.POLLOUT:
                         self.__process_write(fd)
-                    if event & select.POLLIN:
+                    elif event & select.POLLIN:
                         self.process_socket(fd)
+                    else:
+                        print("[b10-msgq] Error: Unknown even in run_poller()")
 
     def run_kqueue(self):
-        while True:
+        while self.running:
+            # Check with a timeout so that every once in a while,
+            # the loop checks for self.running.
             events = self.kqueue.control(None, 10)
             if not events:
                 raise RuntimeError('serve: kqueue returned no events')
@@ -502,6 +526,9 @@ class MsgQ:
                         self.kill_socket(event.ident,
                                          self.sockets[event.ident])
 
+    def stop(self):
+        self.running = False
+
     def shutdown(self):
         """Stop the MsgQ master."""
         if self.verbose:

+ 213 - 3
src/bin/msgq/tests/msgq_test.py

@@ -6,6 +6,8 @@ import socket
 import signal
 import sys
 import time
+import errno
+import threading
 import isc.cc
 
 #
@@ -112,6 +114,85 @@ class TestSubscriptionManager(unittest.TestCase):
         msgq = MsgQ("/does/not/exist")
         self.assertRaises(socket.error, msgq.setup)
 
+class DummySocket:
+    """
+    Dummy socket class.
+    This one does nothing at all, but some calls are used.
+    It is mainly intended to override the listen socket for msgq, which
+    we do not need in these tests.
+    """
+    def fileno():
+        return -1
+
+    def close():
+        pass
+
+class BadSocket:
+    """
+    Special socket wrapper class. Once given a socket in its constructor,
+    it completely behaves like that socket, except that its send() call
+    will only actually send one byte per call, and optionally raise a given
+    exception at a given time.
+    """
+    def __init__(self, real_socket, raise_on_send=0, send_exception=None):
+        """
+        Parameters:
+        real_socket: The actual socket to wrap
+        raise_on_send: integer. If higher than 0, and send_exception is
+                       not None, send_exception will be raised on the
+                       'raise_on_send'th call to send().
+        send_exception: if not None, this exception will be raised
+                        (if raise_on_send is not 0)
+        """
+        self.socket = real_socket
+        self.send_count = 0
+        self.raise_on_send = raise_on_send
+        self.send_exception = send_exception
+
+    # completely wrap all calls and member access
+    # (except explicitely overridden ones)
+    def __getattr__(self, name, *args):
+        attr = getattr(self.socket, name)
+        if callable(attr):
+            def callable_attr(*args):
+                return attr.__call__(*args)
+            return callable_attr
+        else:
+            return attr
+
+    def send(self, data):
+        self.send_count += 1
+        if self.send_exception is not None and\
+           self.send_count == self.raise_on_send:
+            raise self.send_exception
+
+        if len(data) > 0:
+            return self.socket.send(data[:1])
+        else:
+            return 0
+
+class MsgQThread(threading.Thread):
+    """
+    Very simple thread class that runs msgq.run() when started,
+    and stores the exception that msgq.run() raises, if any.
+    """
+    def __init__(self, msgq):
+        threading.Thread.__init__(self)
+        self.msgq_ = msgq
+        self.caught_exception = None
+        self.lock = threading.Lock()
+
+    def run(self):
+        try:
+            self.msgq_.run()
+        except Exception as exc:
+            # Store the exception to make the test fail if necessary
+            self.caught_exception = exc
+
+    def stop(self):
+        self.msgq_.stop()
+
+
 class SendNonblock(unittest.TestCase):
     """
     Tests that the whole thing will not get blocked if someone does not read.
@@ -191,9 +272,6 @@ class SendNonblock(unittest.TestCase):
         msgq = MsgQ()
         # msgq.run needs to compare with the listen_socket, so we provide
         # a replacement
-        class DummySocket:
-            def fileno():
-                return -1
         msgq.listen_socket = DummySocket
         (queue, out) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         def run():
@@ -245,5 +323,137 @@ class SendNonblock(unittest.TestCase):
             data = data + data
         self.send_many(data)
 
+    def do_send(self, write, read, control_write, control_read,
+                expect_arrive=True, expect_send_exception=None):
+        """
+        Makes a msgq object that is talking to itself,
+        run it in a separate thread so we can use and
+        test run().
+        It is given two sets of connected sockets; write/read, and
+        control_write/control_read. The former may be throwing errors
+        and mangle data to test msgq. The second is mainly used to
+        send msgq the stop command.
+        (Note that the terms 'read' and 'write' are from the msgq
+        point of view, so the test itself writes to 'control_read')
+        Parameters:
+        write: a socket that is used to send the data to
+        read: a socket that is used to read the data from
+        control_write: a second socket for communication with msgq
+        control_read: a second socket for communication with msgq
+        expect_arrive: if True, the read socket is read from, and the data
+                       that is read is expected to be the same as the data
+                       that has been sent to the write socket.
+        expect_send_exception: if not None, this is the exception that is
+                               expected to be raised by msgq
+        """
+
+        # Some message and envelope data to send and check
+        env = b'{"env": "foo"}'
+        msg = b'{"msg": "bar"}'
+
+        msgq = MsgQ()
+        # Don't need a listen_socket
+        msgq.listen_socket = DummySocket
+        msgq.setup_poller()
+        msgq.register_socket(write)
+        msgq.register_socket(control_write)
+        # Queue the message for sending
+        msgq.sendmsg(write, env, msg)
+
+        # Run it in a thread
+        msgq_thread = MsgQThread(msgq)
+        # If we're done, just kill it
+        msgq_thread.start()
+
+        if expect_arrive:
+            (recv_env, recv_msg) = msgq.read_packet(read.fileno(),
+                read)
+            self.assertEqual(env, recv_env)
+            self.assertEqual(msg, recv_msg)
+
+        # Tell msgq to stop
+        msg = msgq.preparemsg({"type" : "stop"})
+        control_read.sendall(msg)
+
+        # Wait for thread to stop if it hasn't already.
+        # Put in a (long) timeout; the thread *should* stop, but if it
+        # does not, we don't want the test to hang forever
+        msgq_thread.join(60)
+        # Fail the test if it didn't stop
+        self.assertFalse(msgq_thread.isAlive(), "Thread did not stop")
+
+        # Check the exception from the thread, if any
+        # First, if we didn't expect it; reraise it (to make test fail and
+        # show the stacktrace for debugging)
+        if expect_send_exception is None:
+            if msgq_thread.caught_exception is not None:
+                raise msgq_thread.caught_exception
+        else:
+            # If we *did* expect it, fail it there was none
+            self.assertIsNotNone(msgq_thread.caught_exception)
+
+    def do_send_with_send_error(self, raise_on_send, send_exception,
+                                expect_answer=True,
+                                expect_send_exception=None):
+        """
+        Sets up two connected sockets, wraps the sender socket into a BadSocket
+        class, then performs a do_send() test.
+        Parameters:
+        raise_on_send: the byte at which send_exception should be raised
+                       (see BadSocket)
+        send_exception: the exception to raise (see BadSocket)
+        expect_answer: whether the send is expected to complete (and hence
+                       the read socket should get the message)
+        expect_send_exception: the exception msgq is expected to raise when
+                               send_exception is raised by BadSocket.
+        """
+        (write, read) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+        (control_write, control_read) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+        badwrite = BadSocket(write, raise_on_send, send_exception)
+        self.do_send(badwrite, read, control_write, control_read, expect_answer, expect_send_exception)
+        write.close()
+        read.close()
+        control_write.close()
+        control_read.close()
+
+    def test_send_raise_recoverable(self):
+        """
+        Test whether msgq survices a recoverable socket errors when sending.
+        Two tests are done: one where the error is raised on the 3rd octet,
+                            and one on the 23rd.
+        """
+        sockerr = socket.error
+        for err in [ errno.EAGAIN, errno.EWOULDBLOCK, errno.EINTR ]:
+            sockerr.errno = err
+            self.do_send_with_send_error(3, sockerr)
+            self.do_send_with_send_error(23, sockerr)
+
+    def test_send_raise_nonrecoverable(self):
+        """
+        Test whether msgq survives socket errors that are nonrecoverable
+        (for said socket that is, i.e. EPIPE etc).
+        Two tests are done: one where the error is raised on the 3rd octet,
+                            and one on the 23rd.
+        """
+        sockerr = socket.error
+        for err in [ errno.EPIPE, errno.ENOBUFS, errno.ECONNRESET ]:
+            sockerr.errno = err
+            self.do_send_with_send_error(3, sockerr, False)
+            self.do_send_with_send_error(23, sockerr, False)
+
+    def otest_send_raise_crash(self):
+        """
+        Test whether msgq does NOT survive on a general exception.
+        Note, perhaps it should; but we'd have to first discuss and decide
+        how it should recover (i.e. drop the socket and consider the client
+        dead?
+        It may be a coding problem in msgq itself, and we certainly don't
+        want to ignore those.
+        """
+        sockerr = Exception("just some general exception")
+        self.do_send_with_send_error(3, sockerr, False, sockerr)
+        self.do_send_with_send_error(23, sockerr, False, sockerr)
+
+
 if __name__ == '__main__':
     unittest.main()