Browse Source

New wire format, which makes things more sane for processing envelope apart from messages. No API changes. The current msgq does not support this, but the pymsgq I'm hoping to finish up tomorrow will.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@332 e5f2f494-b856-4b98-b285-d166d9295462
Michael Graff 15 years ago
parent
commit
6383f1fda5
4 changed files with 208 additions and 34 deletions
  1. 92 9
      src/bin/pymsgq/msgq.py
  2. 89 9
      src/lib/cc/cpp/session.cc
  3. 4 0
      src/lib/cc/cpp/session.h
  4. 23 16
      src/lib/cc/python/ISC/CC/session.py

+ 92 - 9
src/bin/pymsgq/msgq.py

@@ -10,7 +10,7 @@ import signal
 import os
 import socket
 import sys
-import re
+import struct
 import errno
 import time
 import select
@@ -19,6 +19,8 @@ from optparse import OptionParser, OptionValueError
 
 import ISC.CC
 
+class MsgQReceiveError(Exception): pass
+
 # This is the version that gets displayed to the user.
 __version__ = "v20091030 (Paving the DNS Parking Lot)"
 
@@ -63,12 +65,14 @@ class MsgQ:
         self.runnable = True
 
     def process_accept(self):
+        """Process an accept on the listening socket."""
         newsocket, ipaddr = self.listen_socket.accept()
         sys.stderr.write("Connection\n")
         self.sockets[newsocket.fileno()] = newsocket
         self.poller.register(newsocket, select.POLLIN)
 
     def process_socket(self, fd):
+        """Process a read on a socket."""
         sock = self.sockets[fd]
         if sock == None:
             sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
@@ -76,19 +80,98 @@ class MsgQ:
         sys.stderr.write("Got read on fd %d\n" %fd)
         self.process_packet(fd, sock)
 
+    def kill_socket(self, fd, sock):
+        """Fully close down the socket."""
+        self.poller.unregister(sock)
+        sock.close()
+        self.sockets[fd] = None
+        sys.stderr.write("Closing socket fd %d\n" % fd)
+
+    def getbytes(self, fd, sock, length):
+        """Get exactly the requested bytes, or raise an exception if
+           EOF."""
+        received = b''
+        while len(received) < length:
+            data = sock.recv(length - len(received))
+            if len(data) == 0:
+                raise MsgQReceiveError("EOF")
+            received += data
+        return received
+
+    def read_packet(self, fd, sock):
+        """Read a correctly formatted packet.  Will raise exceptions if
+           something fails."""
+        lengths = self.getbytes(fd, sock, 6)
+        overall_length, routing_length = struct.unpack(">IH", lengths)
+        if overall_length < 2:
+            raise MsgQReceiveError("overall_length < 2")
+        overall_length -= 2
+        sys.stderr.write("overall length: %d, routing_length %d\n"
+                         % (overall_length, routing_length))
+        if routing_length > overall_length:
+            raise MsgQReceiveError("routing_length > overall_length")
+        if routing_length == 0:
+            raise MsgQReceiveError("routing_length == 0")
+        data_length = overall_length - routing_length
+        # probably need to sanity check lengths here...
+        routing = self.getbytes(fd, sock, routing_length)
+        if data_length > 0:
+            data = self.getbytes(fd, sock, data_length)
+        else:
+            data = None
+        return (routing, data)
+
     def process_packet(self, fd, sock):
-        data = sock.recv(4)
-        if len(data) == 0:
-            self.poller.unregister(sock)
-            sock.close()
-            self.sockets[fd] = None
-            sys.stderr.write("Closing socket fd %d\n" % fd)
+        """Process one packet."""
+        try:
+            routing, data = self.read_packet(fd, sock)
+        except MsgQReceiveError as err:
+            self.kill_socket(fd, sock)
+            sys.stderr.write("Receive error: %s\n" % err)
+            return
+
+        try:
+            routingmsg = ISC.CC.Message.from_wire(routing)
+        except DecodeError as err:
+            self.kill_socket(fd, sock)
+            sys.stderr.write("Routing decode error: %s\n" % err)
             return
-        sys.stderr.write("Got data: %s\n" % data)
+
+        sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
+        sys.stdout.write("\t" + pprint.pformat(data) + "\n")
+
+        self.process_command(fd, sock, routingmsg, data)
+
+    def process_command(self, fd, sock, routing, data):
+        """Process a single command.  This will split out into one of the
+           other functions, above."""
+        cmd = routing["type"]
+        if cmd == 'getlname':
+            self.process_command_getlname(sock, routing, data)
+        elif cmd == 'send':
+            self.process_command_send(sock, routing, data)
+        else:
+            sys.stderr.write("Invalid command: %s\n" % cmd)
+
+    def sendmsg(self, sock, env, msg = None):
+        if type(env) == dict:
+            env = ISC.CC.Message.to_wire(env)
+        if type(msg) == dict:
+            msg = ISC.CC.Message.to_wire(msg)
+        sock.setblocking(1)
+        length = 2 + len(env);
+        if msg:
+            length += len(msg)
+        sock.send(struct.pack("!IH", length, len(env)))
+        sock.send(env)
+        if msg:
+            sock.send(msg)
+
+    def process_command_getlname(self, sock, routing, data):
+        self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
 
     def run(self):
         """Process messages.  Forever.  Mostly."""
-
         while True:
             try:
                 events = self.poller.poll()

+ 89 - 9
src/lib/cc/cpp/session.cc

@@ -73,28 +73,70 @@ Session::sendmsg(ElementPtr& msg)
     std::string wire = msg->to_wire();
     unsigned int length = wire.length();
     unsigned int length_net = htonl(length);
+    unsigned short header_length_net = htons(length);
     unsigned int ret;
 
     ret = write(sock, &length_net, 4);
     if (ret != 4)
         throw SessionError("Short write");
 
+    ret = write(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short write");
+
     ret = write(sock, wire.c_str(), length);
     if (ret != length)
         throw SessionError("Short write");
 }
 
+void
+Session::sendmsg(ElementPtr& env, ElementPtr& msg)
+{
+    std::string header_wire = env->to_wire();
+    std::string body_wire = msg->to_wire();
+    unsigned int length = 2 + header_wire.length() + body_wire.length();
+    unsigned int length_net = htonl(length);
+    unsigned short header_length = header_wire.length();
+    unsigned short header_length_net = htons(header_length);
+    unsigned int ret;
+
+    ret = write(sock, &length_net, 4);
+    if (ret != 4)
+        throw SessionError("Short write");
+
+    ret = write(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short write");
+
+    std::cout << "[XX] Header length sending: " << header_length << std::endl;
+
+    ret = write(sock, header_wire.c_str(), header_length);
+    ret = write(sock, body_wire.c_str(), body_wire.length());
+    if (ret != length)
+        throw SessionError("Short write");
+}
+
 bool
 Session::recvmsg(ElementPtr& msg, bool nonblock)
 {
     unsigned int length_net;
+    unsigned short header_length_net;
     unsigned int ret;
 
     ret = read(sock, &length_net, 4);
     if (ret != 4)
         throw SessionError("Short read");
 
-    unsigned int length = ntohl(length_net);
+    ret = read(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short read");
+
+    unsigned int length = ntohl(length_net) - 2;
+    unsigned short header_length = ntohs(header_length_net);
+    if (header_length != length) {
+        throw SessionError("Received non-empty body where only a header expected");
+    }
+    
     char *buffer = new char[length];
     ret = read(sock, buffer, length);
     if (ret != length)
@@ -112,6 +154,48 @@ Session::recvmsg(ElementPtr& msg, bool nonblock)
     // XXXMLG handle non-block here, and return false for short reads
 }
 
+bool
+Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock)
+{
+    unsigned int length_net;
+    unsigned short header_length_net;
+    unsigned int ret;
+
+    ret = read(sock, &length_net, 4);
+    if (ret != 4)
+        throw SessionError("Short read");
+
+    ret = read(sock, &header_length_net, 2);
+    if (ret != 2)
+        throw SessionError("Short read");
+
+    unsigned int length = ntohl(length_net);
+    unsigned short header_length = ntohs(header_length_net);
+
+    if (header_length > length)
+        throw SessionError("Bad header length");
+    
+    char *buffer = new char[length];
+    ret = read(sock, buffer, length);
+    if (ret != length)
+        throw SessionError("Short read");
+
+    std::string header_wire = std::string(buffer, header_length);
+    std::string body_wire = std::string(buffer, length - header_length);
+    delete [] buffer;
+
+    std::stringstream header_wire_stream;
+    header_wire_stream << header_wire;
+    env = Element::from_wire(header_wire_stream, length);
+
+    std::stringstream body_wire_stream;
+    body_wire_stream << body_wire;
+    msg = Element::from_wire(body_wire_stream, length - header_length);
+
+    return (true);
+    // XXXMLG handle non-block here, and return false for short reads
+}
+
 void
 Session::subscribe(std::string group, std::string instance, std::string subtype)
 {
@@ -148,9 +232,9 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
     env->set("group", Element::create(group));
     env->set("instance", Element::create(instance));
     env->set("seq", Element::create(sequence));
-    env->set("msg", Element::create(msg->to_wire()));
+    //env->set("msg", Element::create(msg->to_wire()));
 
-    sendmsg(env);
+    sendmsg(env, msg);
 
     return (sequence++);
 }
@@ -158,14 +242,11 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
 bool
 Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, bool nonblock)
 {
-    bool got_message = recvmsg(envelope, nonblock);
+    bool got_message = recvmsg(envelope, msg, nonblock);
     if (!got_message) {
         return false;
     }
 
-    msg = Element::from_wire(envelope->get("msg")->string_value());
-    envelope->remove("msg");
-
     return (true);
 }
 
@@ -180,10 +261,9 @@ Session::reply(ElementPtr& envelope, ElementPtr& newmsg)
     env->set("group", Element::create(envelope->get("group")->string_value()));
     env->set("instance", Element::create(envelope->get("instance")->string_value()));
     env->set("seq", Element::create(sequence));
-    env->set("msg", Element::create(newmsg->to_wire()));
     env->set("reply", Element::create(envelope->get("seq")->string_value()));
 
-    sendmsg(env);
+    sendmsg(env, newmsg);
 
     return (sequence++);
 }

+ 4 - 0
src/lib/cc/cpp/session.h

@@ -36,8 +36,12 @@ namespace ISC {
             void establish();
             void disconnect();
             void sendmsg(ISC::Data::ElementPtr& msg);
+            void sendmsg(ISC::Data::ElementPtr& env, ISC::Data::ElementPtr& msg);
             bool recvmsg(ISC::Data::ElementPtr& msg,
                          bool nonblock = true);
+            bool recvmsg(ISC::Data::ElementPtr& env,
+                         ISC::Data::ElementPtr& msg,
+                         bool nonblock = true);
             void subscribe(std::string group,
                            std::string instance = "*",
                            std::string subtype = "normal");

+ 23 - 16
src/lib/cc/python/ISC/CC/session.py

@@ -37,7 +37,7 @@ class Session:
             self._socket.connect(tuple(['127.0.0.1', port]))
 
             self.sendmsg({ "type": "getlname" })
-            msg = self.recvmsg(False)
+            env, msg = self.recvmsg(False)
             self._lname = msg["lname"]
             if not self._lname:
                 raise ProtocolError("Could not get local name")
@@ -48,18 +48,31 @@ class Session:
     def lname(self):
         return self._lname
 
-    def sendmsg(self, msg):
+    def sendmsg(self, env, msg = None):
+        if type(env) == dict:
+            env = Message.to_wire(env)
         if type(msg) == dict:
             msg = Message.to_wire(msg)
         self._socket.setblocking(1)
-        self._socket.send(struct.pack("!I", len(msg)))
-        self._socket.send(msg)
+        length = 2 + len(env);
+        if msg:
+            length += len(msg)
+        self._socket.send(struct.pack("!I", length))
+        self._socket.send(struct.pack("!H", len(env)))
+        self._socket.send(env)
+        if msg:
+            self._socket.send(msg)
 
     def recvmsg(self, nonblock = True):
         data = self._receive_full_buffer(nonblock)
-        if data:
-            return Message.from_wire(data)
-        return None
+        if data and len(data) > 2:
+            header_length = struct.unpack('>H', data[0:2])[0]
+            data_length = len(data) - 2 - header_length
+            if data_length > 0:
+                return Message.from_wire(data[2:header_length+2]), Message.from_wire(data[header_length + 2:])
+            else:
+                return Message.from_wire(data[2:header_length+2]), None
+        return None, None
 
     def _receive_full_buffer(self, nonblock):
         if nonblock:
@@ -127,20 +140,15 @@ class Session:
             "group": group,
             "instance": instance,
             "seq": seq,
-            "msg": Message.to_wire(msg),
-        })
+        }, Message.to_wire(msg))
         return seq
 
     def group_recvmsg(self, nonblock = True):
-        env = self.recvmsg(nonblock)
+        env, msg  = self.recvmsg(nonblock)
         if env == None:
             # return none twice to match normal return value
             # (so caller won't get a type error on no data)
             return (None, None)
-        if type(env["msg"]) != bytearray:
-            msg = Message.from_wire(env["msg"].encode('ascii'))
-        else:
-            msg = Message.from_wire(env["msg"])
         return (msg, env)
 
     def group_reply(self, routing, msg):
@@ -153,8 +161,7 @@ class Session:
             "instance": routing["instance"],
             "seq": seq,
             "reply": routing["seq"],
-            "msg": Message.to_wire(msg),
-        })
+        }, Message.to_wire(msg))
         return seq
 
 if __name__ == "__main__":