Browse Source

checkpoint work; Python-based msgq mostly works. Bad input will crash it, which should be fixed, probably by wrapping the entire message processing in a try loop. Gross, but...

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@333 e5f2f494-b856-4b98-b285-d166d9295462
Michael Graff 15 years ago
parent
commit
a6823843ad
3 changed files with 217 additions and 16 deletions
  1. 111 16
      src/bin/pymsgq/msgq.py
  2. 62 0
      src/bin/pymsgq/msgq_test.py
  3. 44 0
      src/lib/cc/python/test_session.py

+ 111 - 16
src/bin/pymsgq/msgq.py

@@ -15,15 +15,60 @@ import errno
 import time
 import select
 import pprint
+import random
 from optparse import OptionParser, OptionValueError
 
 import ISC.CC
 
-class MsgQReceiveError(Exception): pass
-
 # This is the version that gets displayed to the user.
 __version__ = "v20091030 (Paving the DNS Parking Lot)"
 
+class MsgQReceiveError(Exception): pass
+
+class SubscriptionManager:
+    def __init__(self):
+        self.subscriptions = {}
+
+    def subscribe(self, group, instance, socket):
+        """Add a subscription."""
+        target = ( group, instance )
+        if target in self.subscriptions:
+            print("Appending to existing target")
+            self.subscriptions[target].append(socket)
+        else:
+            print("Creating new target")
+            self.subscriptions[target] = [ socket ]
+
+    def unsubscribe(self, group, instance, socket):
+        """Remove the socket from the one specific subscription."""
+        target = ( group, instance )
+        if target in self.subscriptions:
+            while socket in self.subscriptions[target]:
+                self.subscriptions[target].remove(socket)
+
+    def unsubscribe_all(self, socket):
+        """Remove the socket from all subscriptions."""
+        for socklist in self.subscriptions.values():
+            while socket in socklist:
+                socklist.remove(socket)
+
+    def find_sub(self, group, instance):
+        """Return an array of sockets which want this specific group,
+        instance."""
+        target = (group, instance)
+        if target in self.subscriptions:
+            return self.subscriptions[target]
+        else:
+            return []
+
+    def find(self, group, instance):
+        """Return an array of sockets who should get something sent to
+        this group, instance pair.  This includes wildcard subscriptions."""
+        target = (group, instance)
+        partone = self.find_sub(group, instance)
+        parttwo = self.find_sub(group, "*")
+        return list(set(partone + parttwo))
+
 class MsgQ:
     """Message Queue class."""
     def __init__(self, c_channel_port=9912, verbose=False):
@@ -39,6 +84,9 @@ class MsgQ:
         self.runnable = False
         self.listen_socket = False
         self.sockets = {}
+        self.connection_counter = random.random()
+        self.hostname = socket.gethostname()
+        self.subs = SubscriptionManager()
 
     def setup_poller(self):
         """Set up the poll thing.  Internal function."""
@@ -77,12 +125,13 @@ class MsgQ:
         if sock == None:
             sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
             return
-        sys.stderr.write("Got read on fd %d\n" %fd)
+#        sys.stderr.write("Got read on fd %d\n" %fd)
         self.process_packet(fd, sock)
 
     def kill_socket(self, fd, sock):
         """Fully close down the socket."""
         self.poller.unregister(sock)
+        self.subs.unsubscribe_all(sock)
         sock.close()
         self.sockets[fd] = None
         sys.stderr.write("Closing socket fd %d\n" % fd)
@@ -106,8 +155,6 @@ class MsgQ:
         if overall_length < 2:
             raise MsgQReceiveError("overall_length < 2")
         overall_length -= 2
-        sys.stderr.write("overall length: %d, routing_length %d\n"
-                         % (overall_length, routing_length))
         if routing_length > overall_length:
             raise MsgQReceiveError("routing_length > overall_length")
         if routing_length == 0:
@@ -137,8 +184,8 @@ class MsgQ:
             sys.stderr.write("Routing decode error: %s\n" % err)
             return
 
-        sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
-        sys.stdout.write("\t" + pprint.pformat(data) + "\n")
+#        sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
+#        sys.stdout.write("\t" + pprint.pformat(data) + "\n")
 
         self.process_command(fd, sock, routingmsg, data)
 
@@ -146,29 +193,77 @@ class MsgQ:
         """Process a single command.  This will split out into one of the
            other functions, above."""
         cmd = routing["type"]
-        if cmd == 'getlname':
-            self.process_command_getlname(sock, routing, data)
-        elif cmd == 'send':
+        if cmd == 'send':
             self.process_command_send(sock, routing, data)
+        elif cmd == 'subscribe':
+            self.process_command_subscribe(sock, routing, data)
+        elif cmd == 'unsubscribe':
+            self.process_command_unsubscribe(sock, routing, data)
+        elif cmd == 'getlname':
+            self.process_command_getlname(sock, routing, data)
         else:
             sys.stderr.write("Invalid command: %s\n" % cmd)
 
-    def sendmsg(self, sock, env, msg = None):
+    def preparemsg(self, env, msg = None):
         if type(env) == dict:
             env = ISC.CC.Message.to_wire(env)
         if type(msg) == dict:
             msg = ISC.CC.Message.to_wire(msg)
-        sock.setblocking(1)
         length = 2 + len(env);
         if msg:
             length += len(msg)
-        sock.send(struct.pack("!IH", length, len(env)))
-        sock.send(env)
+        ret = struct.pack("!IH", length, len(env))
+        ret += env
         if msg:
-            sock.send(msg)
+            ret += msg
+        return ret
+
+    def sendmsg(self, sock, env, msg = None):
+        sock.send(self.preparemsg(env, msg))
+
+    def send_prepared_msg(self, sock, msg):
+        sock.send(msg)
+
+    def newlname(self):
+        """Generate a unique conenction identifier for this socket.
+        This is done by using an increasing counter and the current
+        time."""
+        self.connection_counter += 1
+        return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
 
     def process_command_getlname(self, sock, routing, data):
-        self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
+        env = { "type" : "getlname" }
+        reply = { "lname" : self.newlname() }
+        self.sendmsg(sock, env, reply)
+
+    def process_command_send(self, sock, routing, data):
+        group = routing["group"]
+        instance = routing["instance"]
+        if group == None or instance == None:
+            return  # ignore invalid packets entirely
+        sockets = self.subs.find(group, instance)
+
+        msg = self.preparemsg(routing, data)
+
+        if sock in sockets:
+            sockets.remove(sock)
+        for socket in sockets:
+            self.send_prepared_msg(socket, msg)
+
+    def process_command_subscribe(self, sock, routing, data):
+        group = routing["group"]
+        instance = routing["instance"]
+        subtype = routing["subtype"]
+        if group == None or instance == None or subtype == None:
+            return  # ignore invalid packets entirely
+        self.subs.subscribe(group, instance, sock)
+
+    def process_command_unsubscribe(self, sock, routing, data):
+        group = routing["group"]
+        instance = routing["instance"]
+        if group == None or instance == None:
+            return  # ignore invalid packets entirely
+        self.subs.unsubscribe(group, instance, sock)
 
     def run(self):
         """Process messages.  Forever.  Mostly."""

+ 62 - 0
src/bin/pymsgq/msgq_test.py

@@ -0,0 +1,62 @@
+from msgq import SubscriptionManager, MsgQ
+
+import unittest
+
+#
+# Currently only the subscription part is implemented...  I'd have to mock
+# out a socket, which, while not impossible, is not trivial.
+#
+
+class TestSubscriptionManager(unittest.TestCase):
+    def setUp(self):
+        self.sm = SubscriptionManager()
+
+    def test_subscription_add_delete_manager(self):
+        self.sm.subscribe("a", "*", 'sock1')
+        self.assertEqual(self.sm.find_sub("a", "*"), [ 'sock1' ])
+
+    def test_subscription_add_delete_other(self):
+        self.sm.subscribe("a", "*", 'sock1')
+        self.sm.unsubscribe("a", "*", 'sock2')
+        self.assertEqual(self.sm.find_sub("a", "*"), [ 'sock1' ])
+
+    def test_subscription_add_several_sockets(self):
+        socks = [ 's1', 's2', 's3', 's4', 's5' ]
+        for s in socks:
+            self.sm.subscribe("a", "*", s)
+        self.assertEqual(self.sm.find_sub("a", "*"), socks)
+
+    def test_unsubscribe(self):
+        socks = [ 's1', 's2', 's3', 's4', 's5' ]
+        for s in socks:
+            self.sm.subscribe("a", "*", s)
+        self.sm.unsubscribe("a", "*", 's3')
+        self.assertEqual(self.sm.find_sub("a", "*"), [ 's1', 's2', 's4', 's5' ])
+
+    def test_unsubscribe_all(self):
+        self.sm.subscribe('g1', 'i1', 's1')
+        self.sm.subscribe('g1', 'i1', 's2')
+        self.sm.subscribe('g1', 'i2', 's1')
+        self.sm.subscribe('g1', 'i2', 's2')
+        self.sm.subscribe('g2', 'i1', 's1')
+        self.sm.subscribe('g2', 'i1', 's2')
+        self.sm.subscribe('g2', 'i2', 's1')
+        self.sm.subscribe('g2', 'i2', 's2')
+        self.sm.unsubscribe_all('s1')
+        self.assertEqual(self.sm.find_sub("g1", "i1"), [ 's2' ])
+        self.assertEqual(self.sm.find_sub("g1", "i2"), [ 's2' ])
+        self.assertEqual(self.sm.find_sub("g2", "i1"), [ 's2' ])
+        self.assertEqual(self.sm.find_sub("g2", "i2"), [ 's2' ])
+
+    def test_find(self):
+        self.sm.subscribe('g1', 'i1', 's1')
+        self.sm.subscribe('g1', '*', 's2')
+        self.assertEqual(set(self.sm.find("g1", "i1")), set([ 's1', 's2' ]))
+
+    def test_find_sub(self):
+        self.sm.subscribe('g1', 'i1', 's1')
+        self.sm.subscribe('g1', '*', 's2')
+        self.assertEqual(self.sm.find_sub("g1", "i1"), [ 's1' ])
+
+if __name__ == '__main__':
+    unittest.main()

+ 44 - 0
src/lib/cc/python/test_session.py

@@ -0,0 +1,44 @@
+import ISC
+
+import time
+import pprint
+import unittest
+
+#
+# This test requires the MsgQ daemon to be running.  We are doing nasty
+# tricks here, and so insert sleeps to give things time to migrate from
+# this process, to the MsgQ, and back to this process.
+#
+
+class TestCCWireEncoding(unittest.TestCase):
+    def setUp(self):
+        self.s1 = ISC.CC.Session()
+        self.s2 = ISC.CC.Session()
+
+    def test_lname(self):
+        self.assertTrue(self.s1.lname)
+        self.assertTrue(self.s2.lname)
+
+    def test_subscribe(self):
+        self.s1.group_subscribe("g1", "i1")
+        self.s2.group_subscribe("g1", "i1")
+        time.sleep(0.5)
+        outmsg = { "data" : "foo" }
+        self.s1.group_sendmsg(outmsg, "g1", "i1")
+        time.sleep(0.5)
+        msg, env = self.s2.group_recvmsg()
+        self.assertEqual(env["from"], self.s1.lname)
+
+    def test_unsubscribe(self):
+        self.s1.group_subscribe("g1", "i1")
+        self.s2.group_subscribe("g1", "i1")
+        time.sleep(0.5)
+        self.s2.group_unsubscribe("g1", "i1")
+        outmsg = { "data" : "foo" }
+        self.s1.group_sendmsg(outmsg, "g1", "i1")
+        time.sleep(0.5)
+        msg, env = self.s2.group_recvmsg()
+        self.assertFalse(env)
+
+if __name__ == '__main__':
+    unittest.main()