|
@@ -15,15 +15,60 @@ import errno
|
|
|
import time
|
|
|
import select
|
|
|
import pprint
|
|
|
+import random
|
|
|
from optparse import OptionParser, OptionValueError
|
|
|
|
|
|
import ISC.CC
|
|
|
|
|
|
-class MsgQReceiveError(Exception): pass
|
|
|
-
|
|
|
# This is the version that gets displayed to the user.
|
|
|
__version__ = "v20091030 (Paving the DNS Parking Lot)"
|
|
|
|
|
|
+class MsgQReceiveError(Exception): pass
|
|
|
+
|
|
|
+class SubscriptionManager:
|
|
|
+ def __init__(self):
|
|
|
+ self.subscriptions = {}
|
|
|
+
|
|
|
+ def subscribe(self, group, instance, socket):
|
|
|
+ """Add a subscription."""
|
|
|
+ target = ( group, instance )
|
|
|
+ if target in self.subscriptions:
|
|
|
+ print("Appending to existing target")
|
|
|
+ self.subscriptions[target].append(socket)
|
|
|
+ else:
|
|
|
+ print("Creating new target")
|
|
|
+ self.subscriptions[target] = [ socket ]
|
|
|
+
|
|
|
+ def unsubscribe(self, group, instance, socket):
|
|
|
+ """Remove the socket from the one specific subscription."""
|
|
|
+ target = ( group, instance )
|
|
|
+ if target in self.subscriptions:
|
|
|
+ while socket in self.subscriptions[target]:
|
|
|
+ self.subscriptions[target].remove(socket)
|
|
|
+
|
|
|
+ def unsubscribe_all(self, socket):
|
|
|
+ """Remove the socket from all subscriptions."""
|
|
|
+ for socklist in self.subscriptions.values():
|
|
|
+ while socket in socklist:
|
|
|
+ socklist.remove(socket)
|
|
|
+
|
|
|
+ def find_sub(self, group, instance):
|
|
|
+ """Return an array of sockets which want this specific group,
|
|
|
+ instance."""
|
|
|
+ target = (group, instance)
|
|
|
+ if target in self.subscriptions:
|
|
|
+ return self.subscriptions[target]
|
|
|
+ else:
|
|
|
+ return []
|
|
|
+
|
|
|
+ def find(self, group, instance):
|
|
|
+ """Return an array of sockets who should get something sent to
|
|
|
+ this group, instance pair. This includes wildcard subscriptions."""
|
|
|
+ target = (group, instance)
|
|
|
+ partone = self.find_sub(group, instance)
|
|
|
+ parttwo = self.find_sub(group, "*")
|
|
|
+ return list(set(partone + parttwo))
|
|
|
+
|
|
|
class MsgQ:
|
|
|
"""Message Queue class."""
|
|
|
def __init__(self, c_channel_port=9912, verbose=False):
|
|
@@ -39,6 +84,9 @@ class MsgQ:
|
|
|
self.runnable = False
|
|
|
self.listen_socket = False
|
|
|
self.sockets = {}
|
|
|
+ self.connection_counter = random.random()
|
|
|
+ self.hostname = socket.gethostname()
|
|
|
+ self.subs = SubscriptionManager()
|
|
|
|
|
|
def setup_poller(self):
|
|
|
"""Set up the poll thing. Internal function."""
|
|
@@ -77,12 +125,13 @@ class MsgQ:
|
|
|
if sock == None:
|
|
|
sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
|
|
|
return
|
|
|
- sys.stderr.write("Got read on fd %d\n" %fd)
|
|
|
+# sys.stderr.write("Got read on fd %d\n" %fd)
|
|
|
self.process_packet(fd, sock)
|
|
|
|
|
|
def kill_socket(self, fd, sock):
|
|
|
"""Fully close down the socket."""
|
|
|
self.poller.unregister(sock)
|
|
|
+ self.subs.unsubscribe_all(sock)
|
|
|
sock.close()
|
|
|
self.sockets[fd] = None
|
|
|
sys.stderr.write("Closing socket fd %d\n" % fd)
|
|
@@ -106,8 +155,6 @@ class MsgQ:
|
|
|
if overall_length < 2:
|
|
|
raise MsgQReceiveError("overall_length < 2")
|
|
|
overall_length -= 2
|
|
|
- sys.stderr.write("overall length: %d, routing_length %d\n"
|
|
|
- % (overall_length, routing_length))
|
|
|
if routing_length > overall_length:
|
|
|
raise MsgQReceiveError("routing_length > overall_length")
|
|
|
if routing_length == 0:
|
|
@@ -137,8 +184,8 @@ class MsgQ:
|
|
|
sys.stderr.write("Routing decode error: %s\n" % err)
|
|
|
return
|
|
|
|
|
|
- sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
|
|
|
- sys.stdout.write("\t" + pprint.pformat(data) + "\n")
|
|
|
+# sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
|
|
|
+# sys.stdout.write("\t" + pprint.pformat(data) + "\n")
|
|
|
|
|
|
self.process_command(fd, sock, routingmsg, data)
|
|
|
|
|
@@ -146,29 +193,77 @@ class MsgQ:
|
|
|
"""Process a single command. This will split out into one of the
|
|
|
other functions, above."""
|
|
|
cmd = routing["type"]
|
|
|
- if cmd == 'getlname':
|
|
|
- self.process_command_getlname(sock, routing, data)
|
|
|
- elif cmd == 'send':
|
|
|
+ if cmd == 'send':
|
|
|
self.process_command_send(sock, routing, data)
|
|
|
+ elif cmd == 'subscribe':
|
|
|
+ self.process_command_subscribe(sock, routing, data)
|
|
|
+ elif cmd == 'unsubscribe':
|
|
|
+ self.process_command_unsubscribe(sock, routing, data)
|
|
|
+ elif cmd == 'getlname':
|
|
|
+ self.process_command_getlname(sock, routing, data)
|
|
|
else:
|
|
|
sys.stderr.write("Invalid command: %s\n" % cmd)
|
|
|
|
|
|
- def sendmsg(self, sock, env, msg = None):
|
|
|
+ def preparemsg(self, env, msg = None):
|
|
|
if type(env) == dict:
|
|
|
env = ISC.CC.Message.to_wire(env)
|
|
|
if type(msg) == dict:
|
|
|
msg = ISC.CC.Message.to_wire(msg)
|
|
|
- sock.setblocking(1)
|
|
|
length = 2 + len(env);
|
|
|
if msg:
|
|
|
length += len(msg)
|
|
|
- sock.send(struct.pack("!IH", length, len(env)))
|
|
|
- sock.send(env)
|
|
|
+ ret = struct.pack("!IH", length, len(env))
|
|
|
+ ret += env
|
|
|
if msg:
|
|
|
- sock.send(msg)
|
|
|
+ ret += msg
|
|
|
+ return ret
|
|
|
+
|
|
|
+ def sendmsg(self, sock, env, msg = None):
|
|
|
+ sock.send(self.preparemsg(env, msg))
|
|
|
+
|
|
|
+ def send_prepared_msg(self, sock, msg):
|
|
|
+ sock.send(msg)
|
|
|
+
|
|
|
+ def newlname(self):
|
|
|
+ """Generate a unique conenction identifier for this socket.
|
|
|
+ This is done by using an increasing counter and the current
|
|
|
+ time."""
|
|
|
+ self.connection_counter += 1
|
|
|
+ return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
|
|
|
|
|
|
def process_command_getlname(self, sock, routing, data):
|
|
|
- self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
|
|
|
+ env = { "type" : "getlname" }
|
|
|
+ reply = { "lname" : self.newlname() }
|
|
|
+ self.sendmsg(sock, env, reply)
|
|
|
+
|
|
|
+ def process_command_send(self, sock, routing, data):
|
|
|
+ group = routing["group"]
|
|
|
+ instance = routing["instance"]
|
|
|
+ if group == None or instance == None:
|
|
|
+ return # ignore invalid packets entirely
|
|
|
+ sockets = self.subs.find(group, instance)
|
|
|
+
|
|
|
+ msg = self.preparemsg(routing, data)
|
|
|
+
|
|
|
+ if sock in sockets:
|
|
|
+ sockets.remove(sock)
|
|
|
+ for socket in sockets:
|
|
|
+ self.send_prepared_msg(socket, msg)
|
|
|
+
|
|
|
+ def process_command_subscribe(self, sock, routing, data):
|
|
|
+ group = routing["group"]
|
|
|
+ instance = routing["instance"]
|
|
|
+ subtype = routing["subtype"]
|
|
|
+ if group == None or instance == None or subtype == None:
|
|
|
+ return # ignore invalid packets entirely
|
|
|
+ self.subs.subscribe(group, instance, sock)
|
|
|
+
|
|
|
+ def process_command_unsubscribe(self, sock, routing, data):
|
|
|
+ group = routing["group"]
|
|
|
+ instance = routing["instance"]
|
|
|
+ if group == None or instance == None:
|
|
|
+ return # ignore invalid packets entirely
|
|
|
+ self.subs.unsubscribe(group, instance, sock)
|
|
|
|
|
|
def run(self):
|
|
|
"""Process messages. Forever. Mostly."""
|