Browse Source

Implement directed messages, and a test for it

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@334 e5f2f494-b856-4b98-b285-d166d9295462
Michael Graff 15 years ago
parent
commit
f222c66705
3 changed files with 40 additions and 6 deletions
  1. 14 6
      src/bin/pymsgq/msgq.py
  2. 10 0
      src/lib/cc/python/ISC/CC/session.py
  3. 16 0
      src/lib/cc/python/test_session.py

+ 14 - 6
src/bin/pymsgq/msgq.py

@@ -87,6 +87,7 @@ class MsgQ:
         self.connection_counter = random.random()
         self.hostname = socket.gethostname()
         self.subs = SubscriptionManager()
+        self.lnames = {}
 
     def setup_poller(self):
         """Set up the poll thing.  Internal function."""
@@ -117,6 +118,8 @@ class MsgQ:
         newsocket, ipaddr = self.listen_socket.accept()
         sys.stderr.write("Connection\n")
         self.sockets[newsocket.fileno()] = newsocket
+        lname = self.newlname()
+        self.lnames[lname] = newsocket
         self.poller.register(newsocket, select.POLLIN)
 
     def process_socket(self, fd):
@@ -132,6 +135,8 @@ class MsgQ:
         """Fully close down the socket."""
         self.poller.unregister(sock)
         self.subs.unsubscribe_all(sock)
+        lname = [ k for k, v in self.lnames.items() if v == sock ][0]
+        del self.lnames[lname]
         sock.close()
         self.sockets[fd] = None
         sys.stderr.write("Closing socket fd %d\n" % fd)
@@ -232,16 +237,20 @@ class MsgQ:
         return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
 
     def process_command_getlname(self, sock, routing, data):
-        env = { "type" : "getlname" }
-        reply = { "lname" : self.newlname() }
-        self.sendmsg(sock, env, reply)
+        lname = [ k for k, v in self.lnames.items() if v == sock ][0]
+        self.sendmsg(sock, { "type" : "getlname" }, { "lname" : lname })
 
     def process_command_send(self, sock, routing, data):
         group = routing["group"]
         instance = routing["instance"]
+        to = routing["to"]
         if group == None or instance == None:
             return  # ignore invalid packets entirely
-        sockets = self.subs.find(group, instance)
+
+        if to == "*":
+            sockets = self.subs.find(group, instance)
+        else:
+            sockets = [ self.lnames[to] ]
 
         msg = self.preparemsg(routing, data)
 
@@ -253,8 +262,7 @@ class MsgQ:
     def process_command_subscribe(self, sock, routing, data):
         group = routing["group"]
         instance = routing["instance"]
-        subtype = routing["subtype"]
-        if group == None or instance == None or subtype == None:
+        if group == None or instance == None:
             return  # ignore invalid packets entirely
         self.subs.subscribe(group, instance, sock)
 

+ 10 - 0
src/lib/cc/python/ISC/CC/session.py

@@ -31,6 +31,7 @@ class Session:
         self._recvlength = None
         self._sendbuffer = bytearray()
         self._sequence = 1
+        self._closed = False
 
         try:
             self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -48,7 +49,14 @@ class Session:
     def lname(self):
         return self._lname
 
+    def close(self):
+        self._socket.close()
+        self._lname = None
+        self._closed = True
+
     def sendmsg(self, env, msg = None):
+        if self._closed:
+            raise SessionError("Session has been closed.")
         if type(env) == dict:
             env = Message.to_wire(env)
         if type(msg) == dict:
@@ -64,6 +72,8 @@ class Session:
             self._socket.send(msg)
 
     def recvmsg(self, nonblock = True):
+        if self._closed:
+            raise SessionError("Session has been closed.")
         data = self._receive_full_buffer(nonblock)
         if data and len(data) > 2:
             header_length = struct.unpack('>H', data[0:2])[0]

+ 16 - 0
src/lib/cc/python/test_session.py

@@ -15,6 +15,10 @@ class TestCCWireEncoding(unittest.TestCase):
         self.s1 = ISC.CC.Session()
         self.s2 = ISC.CC.Session()
 
+    def tearDown(self):
+        self.s1.close()
+        self.s2.close()
+
     def test_lname(self):
         self.assertTrue(self.s1.lname)
         self.assertTrue(self.s2.lname)
@@ -40,5 +44,17 @@ class TestCCWireEncoding(unittest.TestCase):
         msg, env = self.s2.group_recvmsg()
         self.assertFalse(env)
 
+    def test_directed_recipient(self):
+        self.s1.group_subscribe("g1", "i1")
+        time.sleep(0.5)
+        outmsg = { "data" : "foo" }
+        self.s1.group_sendmsg(outmsg, "g4", "i4", self.s2.lname)
+        time.sleep(0.5)
+        msg, env = self.s2.group_recvmsg()
+        self.assertEqual(env["from"], self.s1.lname)
+        self.assertEqual(env["to"], self.s2.lname)
+        self.assertEqual(env["group"], "g4")
+        self.assertEqual(env["instance"], "i4")
+
 if __name__ == '__main__':
     unittest.main()