Browse Source

New wire format, which makes things more sane for processing envelope apart from messages. No API changes. The current msgq does not support this, but the pymsgq I'm hoping to finish up tomorrow will.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@332 e5f2f494-b856-4b98-b285-d166d9295462
Michael Graff 15 years ago
parent
commit
6383f1fda5
4 changed files with 208 additions and 34 deletions
  1. 92 9
      src/bin/pymsgq/msgq.py
  2. 89 9
      src/lib/cc/cpp/session.cc
  3. 4 0
      src/lib/cc/cpp/session.h
  4. 23 16
      src/lib/cc/python/ISC/CC/session.py

+ 92 - 9
src/bin/pymsgq/msgq.py

@@ -10,7 +10,7 @@ import signal
 import os
 import os
 import socket
 import socket
 import sys
 import sys
-import re
+import struct
 import errno
 import errno
 import time
 import time
 import select
 import select
@@ -19,6 +19,8 @@ from optparse import OptionParser, OptionValueError
 
 
 import ISC.CC
 import ISC.CC
 
 
+class MsgQReceiveError(Exception): pass
+
 # This is the version that gets displayed to the user.
 # This is the version that gets displayed to the user.
 __version__ = "v20091030 (Paving the DNS Parking Lot)"
 __version__ = "v20091030 (Paving the DNS Parking Lot)"
 
 
@@ -63,12 +65,14 @@ class MsgQ:
         self.runnable = True
         self.runnable = True
 
 
     def process_accept(self):
     def process_accept(self):
+        """Process an accept on the listening socket."""
         newsocket, ipaddr = self.listen_socket.accept()
         newsocket, ipaddr = self.listen_socket.accept()
         sys.stderr.write("Connection\n")
         sys.stderr.write("Connection\n")
         self.sockets[newsocket.fileno()] = newsocket
         self.sockets[newsocket.fileno()] = newsocket
         self.poller.register(newsocket, select.POLLIN)
         self.poller.register(newsocket, select.POLLIN)
 
 
     def process_socket(self, fd):
     def process_socket(self, fd):
+        """Process a read on a socket."""
         sock = self.sockets[fd]
         sock = self.sockets[fd]
         if sock == None:
         if sock == None:
             sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
             sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
@@ -76,19 +80,98 @@ class MsgQ:
         sys.stderr.write("Got read on fd %d\n" %fd)
         sys.stderr.write("Got read on fd %d\n" %fd)
         self.process_packet(fd, sock)
         self.process_packet(fd, sock)
 
 
+    def kill_socket(self, fd, sock):
+        """Fully close down the socket."""
+        self.poller.unregister(sock)
+        sock.close()
+        self.sockets[fd] = None
+        sys.stderr.write("Closing socket fd %d\n" % fd)
+
+    def getbytes(self, fd, sock, length):
+        """Get exactly the requested bytes, or raise an exception if
+           EOF."""
+        received = b''
+        while len(received) < length:
+            data = sock.recv(length - len(received))
+            if len(data) == 0:
+                raise MsgQReceiveError("EOF")
+            received += data
+        return received
+
+    def read_packet(self, fd, sock):
+        """Read a correctly formatted packet.  Will raise exceptions if
+           something fails."""
+        lengths = self.getbytes(fd, sock, 6)
+        overall_length, routing_length = struct.unpack(">IH", lengths)
+        if overall_length < 2:
+            raise MsgQReceiveError("overall_length < 2")
+        overall_length -= 2
+        sys.stderr.write("overall length: %d, routing_length %d\n"
+                         % (overall_length, routing_length))
+        if routing_length > overall_length:
+            raise MsgQReceiveError("routing_length > overall_length")
+        if routing_length == 0:
+            raise MsgQReceiveError("routing_length == 0")
+        data_length = overall_length - routing_length
+        # probably need to sanity check lengths here...
+        routing = self.getbytes(fd, sock, routing_length)
+        if data_length > 0:
+            data = self.getbytes(fd, sock, data_length)
+        else:
+            data = None
+        return (routing, data)
+
     def process_packet(self, fd, sock):
     def process_packet(self, fd, sock):
-        data = sock.recv(4)
+        """Process one packet."""
-        if len(data) == 0:
+        try:
-            self.poller.unregister(sock)
+            routing, data = self.read_packet(fd, sock)
-            sock.close()
+        except MsgQReceiveError as err:
-            self.sockets[fd] = None
+            self.kill_socket(fd, sock)
-            sys.stderr.write("Closing socket fd %d\n" % fd)
+            sys.stderr.write("Receive error: %s\n" % err)
+            return
+
+        try:
+            routingmsg = ISC.CC.Message.from_wire(routing)
+        except DecodeError as err:
+            self.kill_socket(fd, sock)
+            sys.stderr.write("Routing decode error: %s\n" % err)
             return
             return
-        sys.stderr.write("Got data: %s\n" % data)
+
+        sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
+        sys.stdout.write("\t" + pprint.pformat(data) + "\n")
+
+        self.process_command(fd, sock, routingmsg, data)
+
+    def process_command(self, fd, sock, routing, data):
+        """Process a single command.  This will split out into one of the
+           other functions, above."""
+        cmd = routing["type"]
+        if cmd == 'getlname':
+            self.process_command_getlname(sock, routing, data)
+        elif cmd == 'send':
+            self.process_command_send(sock, routing, data)
+        else:
+            sys.stderr.write("Invalid command: %s\n" % cmd)
+
+    def sendmsg(self, sock, env, msg = None):
+        if type(env) == dict:
+            env = ISC.CC.Message.to_wire(env)
+        if type(msg) == dict:
+            msg = ISC.CC.Message.to_wire(msg)
+        sock.setblocking(1)
+        length = 2 + len(env);
+        if msg:
+            length += len(msg)
+        sock.send(struct.pack("!IH", length, len(env)))
+        sock.send(env)
+        if msg:
+            sock.send(msg)
+
+    def process_command_getlname(self, sock, routing, data):
+        self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
 
 
     def run(self):
     def run(self):
         """Process messages.  Forever.  Mostly."""
         """Process messages.  Forever.  Mostly."""
-
         while True:
         while True:
             try:
             try:
                 events = self.poller.poll()
                 events = self.poller.poll()

+ 89 - 9
src/lib/cc/cpp/session.cc

@@ -73,28 +73,70 @@ Session::sendmsg(ElementPtr& msg)
     std::string wire = msg->to_wire();
     std::string wire = msg->to_wire();
     unsigned int length = wire.length();
     unsigned int length = wire.length();
     unsigned int length_net = htonl(length);
     unsigned int length_net = htonl(length);
+    unsigned short header_length_net = htons(length);
     unsigned int ret;
     unsigned int ret;
 
 
     ret = write(sock, &length_net, 4);
     ret = write(sock, &length_net, 4);
     if (ret != 4)
     if (ret != 4)
         throw SessionError("Short write");
         throw SessionError("Short write");
 
 
+    ret = write(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short write");
+
     ret = write(sock, wire.c_str(), length);
     ret = write(sock, wire.c_str(), length);
     if (ret != length)
     if (ret != length)
         throw SessionError("Short write");
         throw SessionError("Short write");
 }
 }
 
 
+void
+Session::sendmsg(ElementPtr& env, ElementPtr& msg)
+{
+    std::string header_wire = env->to_wire();
+    std::string body_wire = msg->to_wire();
+    unsigned int length = 2 + header_wire.length() + body_wire.length();
+    unsigned int length_net = htonl(length);
+    unsigned short header_length = header_wire.length();
+    unsigned short header_length_net = htons(header_length);
+    unsigned int ret;
+
+    ret = write(sock, &length_net, 4);
+    if (ret != 4)
+        throw SessionError("Short write");
+
+    ret = write(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short write");
+
+    std::cout << "[XX] Header length sending: " << header_length << std::endl;
+
+    ret = write(sock, header_wire.c_str(), header_length);
+    ret = write(sock, body_wire.c_str(), body_wire.length());
+    if (ret != length)
+        throw SessionError("Short write");
+}
+
 bool
 bool
 Session::recvmsg(ElementPtr& msg, bool nonblock)
 Session::recvmsg(ElementPtr& msg, bool nonblock)
 {
 {
     unsigned int length_net;
     unsigned int length_net;
+    unsigned short header_length_net;
     unsigned int ret;
     unsigned int ret;
 
 
     ret = read(sock, &length_net, 4);
     ret = read(sock, &length_net, 4);
     if (ret != 4)
     if (ret != 4)
         throw SessionError("Short read");
         throw SessionError("Short read");
 
 
-    unsigned int length = ntohl(length_net);
+    ret = read(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short read");
+
+    unsigned int length = ntohl(length_net) - 2;
+    unsigned short header_length = ntohs(header_length_net);
+    if (header_length != length) {
+        throw SessionError("Received non-empty body where only a header expected");
+    }
+    
     char *buffer = new char[length];
     char *buffer = new char[length];
     ret = read(sock, buffer, length);
     ret = read(sock, buffer, length);
     if (ret != length)
     if (ret != length)
@@ -112,6 +154,48 @@ Session::recvmsg(ElementPtr& msg, bool nonblock)
     // XXXMLG handle non-block here, and return false for short reads
     // XXXMLG handle non-block here, and return false for short reads
 }
 }
 
 
+bool
+Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock)
+{
+    unsigned int length_net;
+    unsigned short header_length_net;
+    unsigned int ret;
+
+    ret = read(sock, &length_net, 4);
+    if (ret != 4)
+        throw SessionError("Short read");
+
+    ret = read(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short read");
+
+    unsigned int length = ntohl(length_net);
+    unsigned short header_length = ntohs(header_length_net);
+
+    if (header_length > length)
+        throw SessionError("Bad header length");
+    
+    char *buffer = new char[length];
+    ret = read(sock, buffer, length);
+    if (ret != length)
+        throw SessionError("Short read");
+
+    std::string header_wire = std::string(buffer, header_length);
+    std::string body_wire = std::string(buffer, length - header_length);
+    delete [] buffer;
+
+    std::stringstream header_wire_stream;
+    header_wire_stream << header_wire;
+    env = Element::from_wire(header_wire_stream, length);
+
+    std::stringstream body_wire_stream;
+    body_wire_stream << body_wire;
+    msg = Element::from_wire(body_wire_stream, length - header_length);
+
+    return (true);
+    // XXXMLG handle non-block here, and return false for short reads
+}
+
 void
 void
 Session::subscribe(std::string group, std::string instance, std::string subtype)
 Session::subscribe(std::string group, std::string instance, std::string subtype)
 {
 {
@@ -148,9 +232,9 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
     env->set("group", Element::create(group));
     env->set("group", Element::create(group));
     env->set("instance", Element::create(instance));
     env->set("instance", Element::create(instance));
     env->set("seq", Element::create(sequence));
     env->set("seq", Element::create(sequence));
-    env->set("msg", Element::create(msg->to_wire()));
+    //env->set("msg", Element::create(msg->to_wire()));
 
 
-    sendmsg(env);
+    sendmsg(env, msg);
 
 
     return (sequence++);
     return (sequence++);
 }
 }
@@ -158,14 +242,11 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
 bool
 bool
 Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, bool nonblock)
 Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, bool nonblock)
 {
 {
-    bool got_message = recvmsg(envelope, nonblock);
+    bool got_message = recvmsg(envelope, msg, nonblock);
     if (!got_message) {
     if (!got_message) {
         return false;
         return false;
     }
     }
 
 
-    msg = Element::from_wire(envelope->get("msg")->string_value());
-    envelope->remove("msg");
-
     return (true);
     return (true);
 }
 }
 
 
@@ -180,10 +261,9 @@ Session::reply(ElementPtr& envelope, ElementPtr& newmsg)
     env->set("group", Element::create(envelope->get("group")->string_value()));
     env->set("group", Element::create(envelope->get("group")->string_value()));
     env->set("instance", Element::create(envelope->get("instance")->string_value()));
     env->set("instance", Element::create(envelope->get("instance")->string_value()));
     env->set("seq", Element::create(sequence));
     env->set("seq", Element::create(sequence));
-    env->set("msg", Element::create(newmsg->to_wire()));
     env->set("reply", Element::create(envelope->get("seq")->string_value()));
     env->set("reply", Element::create(envelope->get("seq")->string_value()));
 
 
-    sendmsg(env);
+    sendmsg(env, newmsg);
 
 
     return (sequence++);
     return (sequence++);
 }
 }

+ 4 - 0
src/lib/cc/cpp/session.h

@@ -36,8 +36,12 @@ namespace ISC {
             void establish();
             void establish();
             void disconnect();
             void disconnect();
             void sendmsg(ISC::Data::ElementPtr& msg);
             void sendmsg(ISC::Data::ElementPtr& msg);
+            void sendmsg(ISC::Data::ElementPtr& env, ISC::Data::ElementPtr& msg);
             bool recvmsg(ISC::Data::ElementPtr& msg,
             bool recvmsg(ISC::Data::ElementPtr& msg,
                          bool nonblock = true);
                          bool nonblock = true);
+            bool recvmsg(ISC::Data::ElementPtr& env,
+                         ISC::Data::ElementPtr& msg,
+                         bool nonblock = true);
             void subscribe(std::string group,
             void subscribe(std::string group,
                            std::string instance = "*",
                            std::string instance = "*",
                            std::string subtype = "normal");
                            std::string subtype = "normal");

+ 23 - 16
src/lib/cc/python/ISC/CC/session.py

@@ -37,7 +37,7 @@ class Session:
             self._socket.connect(tuple(['127.0.0.1', port]))
             self._socket.connect(tuple(['127.0.0.1', port]))
 
 
             self.sendmsg({ "type": "getlname" })
             self.sendmsg({ "type": "getlname" })
-            msg = self.recvmsg(False)
+            env, msg = self.recvmsg(False)
             self._lname = msg["lname"]
             self._lname = msg["lname"]
             if not self._lname:
             if not self._lname:
                 raise ProtocolError("Could not get local name")
                 raise ProtocolError("Could not get local name")
@@ -48,18 +48,31 @@ class Session:
     def lname(self):
     def lname(self):
         return self._lname
         return self._lname
 
 
-    def sendmsg(self, msg):
+    def sendmsg(self, env, msg = None):
+        if type(env) == dict:
+            env = Message.to_wire(env)
         if type(msg) == dict:
         if type(msg) == dict:
             msg = Message.to_wire(msg)
             msg = Message.to_wire(msg)
         self._socket.setblocking(1)
         self._socket.setblocking(1)
-        self._socket.send(struct.pack("!I", len(msg)))
+        length = 2 + len(env);
-        self._socket.send(msg)
+        if msg:
+            length += len(msg)
+        self._socket.send(struct.pack("!I", length))
+        self._socket.send(struct.pack("!H", len(env)))
+        self._socket.send(env)
+        if msg:
+            self._socket.send(msg)
 
 
     def recvmsg(self, nonblock = True):
     def recvmsg(self, nonblock = True):
         data = self._receive_full_buffer(nonblock)
         data = self._receive_full_buffer(nonblock)
-        if data:
+        if data and len(data) > 2:
-            return Message.from_wire(data)
+            header_length = struct.unpack('>H', data[0:2])[0]
-        return None
+            data_length = len(data) - 2 - header_length
+            if data_length > 0:
+                return Message.from_wire(data[2:header_length+2]), Message.from_wire(data[header_length + 2:])
+            else:
+                return Message.from_wire(data[2:header_length+2]), None
+        return None, None
 
 
     def _receive_full_buffer(self, nonblock):
     def _receive_full_buffer(self, nonblock):
         if nonblock:
         if nonblock:
@@ -127,20 +140,15 @@ class Session:
             "group": group,
             "group": group,
             "instance": instance,
             "instance": instance,
             "seq": seq,
             "seq": seq,
-            "msg": Message.to_wire(msg),
+        }, Message.to_wire(msg))
-        })
         return seq
         return seq
 
 
     def group_recvmsg(self, nonblock = True):
     def group_recvmsg(self, nonblock = True):
-        env = self.recvmsg(nonblock)
+        env, msg  = self.recvmsg(nonblock)
         if env == None:
         if env == None:
             # return none twice to match normal return value
             # return none twice to match normal return value
             # (so caller won't get a type error on no data)
             # (so caller won't get a type error on no data)
             return (None, None)
             return (None, None)
-        if type(env["msg"]) != bytearray:
-            msg = Message.from_wire(env["msg"].encode('ascii'))
-        else:
-            msg = Message.from_wire(env["msg"])
         return (msg, env)
         return (msg, env)
 
 
     def group_reply(self, routing, msg):
     def group_reply(self, routing, msg):
@@ -153,8 +161,7 @@ class Session:
             "instance": routing["instance"],
             "instance": routing["instance"],
             "seq": seq,
             "seq": seq,
             "reply": routing["seq"],
             "reply": routing["seq"],
-            "msg": Message.to_wire(msg),
+        }, Message.to_wire(msg))
-        })
         return seq
         return seq
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":