Browse Source

Merge branch trac420

Michal 'vorner' Vaner 14 years ago
parent
commit
93697f58e4
2 changed files with 203 additions and 11 deletions
  1. 82 9
      src/bin/msgq/msgq.py.in
  2. 121 2
      src/bin/msgq/tests/msgq_test.py

+ 82 - 9
src/bin/msgq/msgq.py.in

@@ -127,6 +127,7 @@ class MsgQ:
         self.hostname = socket.gethostname()
         self.subs = SubscriptionManager()
         self.lnames = {}
+        self.sendbuffs = {}
 
     def setup_poller(self):
         """Set up the poll thing.  Internal function."""
@@ -135,9 +136,11 @@ class MsgQ:
         except AttributeError:
             self.kqueue = select.kqueue()
     
-    def add_kqueue_socket(self, socket):
-        event = select.kevent(socket.fileno(),
-                              select.KQ_FILTER_READ,
+    def add_kqueue_socket(self, socket, enable_write = False):
+        filters = select.KQ_FILTER_READ
+        if enable_write:
+            filters |= select.KQ_FILTER_WRITE
+        event = select.kevent(socket.fileno(), filters,
                               select.KQ_EV_ADD | select.KQ_EV_ENABLE)
         self.kqueue.control([event], 0)
 
@@ -187,6 +190,12 @@ class MsgQ:
         # TODO: When we have logging, we might want
         # to add a debug message here that a new connection
         # was made
+        self.register_socket(self, newsocket)
+
+    def register_socket(self, newsocket):
+        """
+        Internal function to insert a socket. Used by process_accept and some tests.
+        """
         self.sockets[newsocket.fileno()] = newsocket
         lname = self.newlname()
         self.lnames[lname] = newsocket
@@ -198,10 +207,10 @@ class MsgQ:
 
     def process_socket(self, fd):
         """Process a read on a socket."""
-        sock = self.sockets[fd]
-        if sock == None:
+        if not fd in self.sockets:
             sys.stderr.write("[b10-msgq] Got read on Strange Socket fd %d\n" % fd)
             return
+        sock = self.sockets[fd]
 #        sys.stderr.write("[b10-msgq] Got read on fd %d\n" %fd)
         self.process_packet(fd, sock)
 
@@ -213,7 +222,9 @@ class MsgQ:
         lname = [ k for k, v in self.lnames.items() if v == sock ][0]
         del self.lnames[lname]
         sock.close()
-        self.sockets[fd] = None
+        del self.sockets[fd]
+        if fd in self.sendbuffs:
+            del self.sendbuffs[fd]
         sys.stderr.write("[b10-msgq] Closing socket fd %d\n" % fd)
 
     def getbytes(self, fd, sock, length):
@@ -287,6 +298,9 @@ class MsgQ:
             self.process_command_unsubscribe(sock, routing, data)
         elif cmd == 'getlname':
             self.process_command_getlname(sock, routing, data)
+        elif cmd == 'ping':
+            # Command for testing purposes
+            self.process_command_ping(sock, routing, data)
         else:
             sys.stderr.write("[b10-msgq] Invalid command: %s\n" % cmd)
 
@@ -305,10 +319,61 @@ class MsgQ:
         return ret
 
     def sendmsg(self, sock, env, msg = None):
-        sock.send(self.preparemsg(env, msg))
+        self.send_prepared_msg(sock, self.preparemsg(env, msg))
+
+    def __send_data(self, sock, data):
+        try:
+            return sock.send(data, socket.MSG_DONTWAIT)
+        except socket.error as e:
+            if e.errno == errno.EAGAIN or e.errno == errno.EWOULDBLOCK:
+                return 0
+            else:
+                raise e
 
     def send_prepared_msg(self, sock, msg):
-        sock.send(msg)
+        # Try to send the data, but only if there's nothing waiting
+        fileno = sock.fileno()
+        if fileno in self.sendbuffs:
+            amount_sent = 0
+        else:
+            amount_sent = self.__send_data(sock, msg)
+
+        # Still something to send
+        if amount_sent < len(msg):
+            now = time.clock()
+            # Append it to buffer (but check the data go away)
+            if fileno in self.sendbuffs:
+                (last_sent, buff) = self.sendbuffs[fileno]
+                if now - last_sent > 0.1:
+                    self.kill_socket(fileno, sock)
+                    return
+                buff += msg
+            else:
+                buff = msg[amount_sent:]
+                last_sent = now
+                if self.poller:
+                    self.poller.register(fileno, select.POLLIN |
+                        select.POLLOUT)
+                else:
+                    self.add_kqueue_socket(fileno, True)
+            self.sendbuffs[fileno] = (last_sent, buff)
+
+    def __process_write(self, fileno):
+        # Try to send some data from the buffer
+        (_, msg) = self.sendbuffs[fileno]
+        sock = self.sockets[fileno]
+        amount_sent = self.__send_data(sock, msg)
+        # Keep the rest
+        msg = msg[amount_sent:]
+        if len(msg) == 0:
+            # If there's no more, stop requesting for write availability
+            if self.poller:
+                self.poller.register(fileno, select.POLLIN)
+            else:
+                self.add_kqueue_socket(fileno)
+            del self.sendbuffs[fileno]
+        else:
+            self.sendbuffs[fileno] = (time.clock(), msg)
 
     def newlname(self):
         """Generate a unique connection identifier for this socket.
@@ -317,6 +382,9 @@ class MsgQ:
         self.connection_counter += 1
         return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
 
+    def process_command_ping(self, sock, routing, data):
+        self.sendmsg(sock, { "type" : "pong" }, data)
+
     def process_command_getlname(self, sock, routing, data):
         lname = [ k for k, v in self.lnames.items() if v == sock ][0]
         self.sendmsg(sock, { "type" : "getlname" }, { "lname" : lname })
@@ -379,7 +447,10 @@ class MsgQ:
                 if fd == self.listen_socket.fileno():
                     self.process_accept()
                 else:
-                    self.process_socket(fd)
+                    if event & select.POLLOUT:
+                        self.__process_write(fd)
+                    if event & select.POLLIN:
+                        self.process_socket(fd)
 
     def run_kqueue(self):
         while True:
@@ -391,6 +462,8 @@ class MsgQ:
                 if event.ident == self.listen_socket.fileno():
                     self.process_accept()
                 else:
+                    if event.flags & select.KQ_FILTER_WRITE:
+                        self.process_socket(event.ident)
                     if event.flags & select.KQ_FILTER_READ and event.data > 0:
                         self.process_socket(event.ident)
                     elif event.flags & select.KQ_EV_EOF:

+ 121 - 2
src/bin/msgq/tests/msgq_test.py

@@ -3,10 +3,14 @@ from msgq import SubscriptionManager, MsgQ
 import unittest
 import os
 import socket
+import signal
+import sys
+import time
+import isc.cc
 
 #
-# Currently only the subscription part is implemented...  I'd have to mock
-# out a socket, which, while not impossible, is not trivial.
+# Currently only the subscription part and some sending is implemented...
+# I'd have to mock out a socket, which, while not impossible, is not trivial.
 #
 
 class TestSubscriptionManager(unittest.TestCase):
@@ -108,5 +112,120 @@ class TestSubscriptionManager(unittest.TestCase):
         msgq = MsgQ("/does/not/exist")
         self.assertRaises(socket.error, msgq.setup)
 
+class SendNonblock(unittest.TestCase):
+    """
+    Tests that the whole thing will not get blocked if someone does not read.
+    """
+
+    def terminate_check(self, task, timeout = 1):
+        """
+        Runs task in separate process (task is a function) and checks
+        it terminates sooner than timeout.
+        """
+        task_pid = os.fork()
+        if task_pid == 0:
+            # Kill the forked process after timeout by SIGALRM
+            signal.alarm(timeout)
+            # Run the task
+            # If an exception happens or we run out of time, we terminate
+            # with non-zero
+            task()
+            # If we got here, then everything worked well and in time
+            # In that case, we terminate successfully
+            sys.exit()
+        else:
+            (pid, status) = os.waitpid(task_pid, 0)
+            self.assertEqual(0, status,
+                "The task did not complete successfully in time")
+
+    def infinite_sender(self, sender):
+        """
+        Sends data until an exception happens. socket.error is caught,
+        as it means the socket got closed. Sender is called to actually
+        send the data.
+        """
+        msgq = MsgQ()
+        # We do only partial setup, so we don't create the listening socket
+        msgq.setup_poller()
+        (read, write) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+        msgq.register_socket(write)
+        # Keep sending while it is not closed by the msgq
+        try:
+            while True:
+                sender(msgq, write)
+        except socket.error:
+            pass
+
+    def test_infinite_sendmsg(self):
+        """
+        Tries sending messages (and not reading them) until it either times
+        out (in blocking call, wrong) or closes it (correct).
+        """
+        self.terminate_check(lambda: self.infinite_sender(
+            lambda msgq, socket: msgq.sendmsg(socket, {}, {"message" : "x"})))
+
+    def test_infinite_sendprepared(self):
+        """
+        Tries sending data (and not reading them) until it either times
+        out (in blocking call, wrong) or closes it (correct).
+        """
+        self.terminate_check(lambda: self.infinite_sender(
+            lambda msgq, socket: msgq.send_prepared_msg(socket, b"data")))
+
+    def send_many(self, data):
+        """
+        Tries that sending a command many times and getting an answer works.
+        """
+        msgq = MsgQ()
+        msgq.setup_poller()
+        # msgq.run needs to compare with the listen_socket, so we provide
+        # a replacement
+        class DummySocket:
+            def fileno():
+                return -1
+        msgq.listen_socket = DummySocket
+        (queue, out) = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
+        def run():
+            length = len(data)
+            queue_pid = os.fork()
+            if queue_pid == 0:
+                signal.alarm(10)
+                msgq.register_socket(queue)
+                msgq.run()
+            else:
+                try:
+                    def killall(signum, frame):
+                        os.kill(queue_pid, signal.SIGTERM)
+                        sys.exit(1)
+                    signal.signal(signal.SIGALRM, killall)
+                    msg = msgq.preparemsg({"type" : "ping"}, data)
+                    now = time.clock()
+                    while time.clock() - now < 0.2:
+                        out.sendall(msg)
+                        # Check the answer
+                        (routing, received) = msgq.read_packet(out.fileno(),
+                            out)
+                        self.assertEqual({"type" : "pong"},
+                            isc.cc.message.from_wire(routing))
+                        self.assertEqual(data, received)
+                finally:
+                    os.kill(queue_pid, signal.SIGTERM)
+        self.terminate_check(run)
+
+    def test_small_sends(self):
+        """
+        Tests sending small data many times.
+        """
+        self.send_many(b"data")
+
+    def test_large_sends(self):
+        """
+        Tests sending large data many times.
+        """
+        data = b"data"
+        for i in range(1, 20):
+            data = data + data
+        self.send_many(data)
+
 if __name__ == '__main__':
     unittest.main()