|
@@ -17,11 +17,25 @@ COMMAND_BIND_ADDRESS = "127.0.0.1"
|
|
|
COMMAND_BIND_PORT = 9998
|
|
|
|
|
|
|
|
|
+class Worker(object):
|
|
|
+
|
|
|
+ def __init__(self, reader, writer):
|
|
|
+ self.reader = reader
|
|
|
+ self.writer = writer
|
|
|
+ # Queue of measurements to send to the worker
|
|
|
+ self.queue = asyncio.Queue()
|
|
|
+
|
|
|
+
|
|
|
class Peerfinder(object):
|
|
|
|
|
|
- def __init__(self, loop):
|
|
|
+ def __init__(self, loop, parallel_workers=3):
|
|
|
self.loop = loop
|
|
|
- self.clients = []
|
|
|
+ self.workers = []
|
|
|
+ # Limit the maximum number of parallel workers handling jobs (to
|
|
|
+ # avoid overloading the targets of the jobs). For now, we have a
|
|
|
+ # global semaphore for all measurements, but we might change it to
|
|
|
+ # a semaphore for each measurement.
|
|
|
+ self.workers_semaphore = asyncio.Semaphore(parallel_workers)
|
|
|
|
|
|
def send_pong(self, writer):
|
|
|
msg = pf.Message()
|
|
@@ -32,46 +46,40 @@ class Peerfinder(object):
|
|
|
def send_target(self, writer, target):
|
|
|
msg = pf.Message()
|
|
|
msg.type = pf.Message.Target
|
|
|
- #msg.target = target.SerializeToString()
|
|
|
msg.target.target_id = target.target_id
|
|
|
msg.target.target.address = target.target.address
|
|
|
msg.target.target.family = target.target.family
|
|
|
writer.write(msg.SerializeToString())
|
|
|
- #yield from writer.drain()
|
|
|
|
|
|
- def generate_data(self):
|
|
|
- print("Generate_data")
|
|
|
+ @asyncio.coroutine
|
|
|
+ def generate_measurements(self):
|
|
|
+ """Generate dummy measurement at regular intervals, for debug"""
|
|
|
target = pf.Target()
|
|
|
target.target_id = 1
|
|
|
target.target.address = "2001:db8::1"
|
|
|
target.target.family = pf.IPAddress.IPV6
|
|
|
while True:
|
|
|
- print("Entering write loop")
|
|
|
- for writer in list(self.clients):
|
|
|
- self.sent_target(writer, target)
|
|
|
- target.target_id += 1
|
|
|
- yield from asyncio.sleep(2)
|
|
|
+ print("Sending measurement")
|
|
|
+ for worker in list(self.workers):
|
|
|
+ worker.queue.put_nowait(target)
|
|
|
+ yield from asyncio.sleep(5 * random.random())
|
|
|
|
|
|
- def generate_data_cb(self):
|
|
|
- target = pf.Target()
|
|
|
- target.target_id = 1
|
|
|
- target.target.address = "2001:db8::1"
|
|
|
- target.target.family = pf.IPAddress.IPV6
|
|
|
- print("Sending data")
|
|
|
- for writer in list(self.clients):
|
|
|
- # writer.write(target.SerializeToString())
|
|
|
- self.send_target(writer, target)
|
|
|
- self.loop.call_later(random.randint(1, 10), self.generate_data_cb)
|
|
|
-
|
|
|
+ @asyncio.coroutine
|
|
|
def handle_commands(self, reader, writer):
|
|
|
- target = pf.Target()
|
|
|
- target.target_id = 1
|
|
|
+ print("Client connecting")
|
|
|
+ target_id = 0
|
|
|
family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
|
|
|
while True:
|
|
|
if reader.at_eof():
|
|
|
print("Exiting commands handler")
|
|
|
return
|
|
|
- data = yield from reader.read(1024)
|
|
|
+ try:
|
|
|
+ data = yield from reader.read(1024)
|
|
|
+ except ConnectionResetError:
|
|
|
+ print("Exiting commands handler")
|
|
|
+ return
|
|
|
+ if len(data) == 0:
|
|
|
+ continue
|
|
|
try:
|
|
|
data = data.strip().decode()
|
|
|
addr = IPAddress(data)
|
|
@@ -79,40 +87,42 @@ class Peerfinder(object):
|
|
|
print("Invalid command, disconnecting client")
|
|
|
writer.close()
|
|
|
return
|
|
|
+ target_id += 1
|
|
|
+ target = pf.Target()
|
|
|
+ target.target_id = target_id
|
|
|
target.target.address = str(addr)
|
|
|
target.target.family = family[addr.version]
|
|
|
- target.target_id += 1
|
|
|
- print("Sending data to all workers...")
|
|
|
- for writer in list(self.clients):
|
|
|
- #writer.write(target.SerializeToString())
|
|
|
- self.send_target(writer, target)
|
|
|
- yield from writer.drain()
|
|
|
+ print("Queueing measurement for all workers...")
|
|
|
+ for worker in list(self.workers):
|
|
|
+ worker.queue.put_nowait(target)
|
|
|
|
|
|
+ @asyncio.coroutine
|
|
|
def handle_worker(self, reader, writer):
|
|
|
- self.clients.append(writer)
|
|
|
+ print("Worker connecting")
|
|
|
+ worker = Worker(reader, writer)
|
|
|
+ self.workers.append(worker)
|
|
|
while True:
|
|
|
+ # Wait for a new target
|
|
|
+ target = yield from worker.queue.get()
|
|
|
if reader.at_eof():
|
|
|
print("Exiting worker handler")
|
|
|
- self.clients.remove(writer)
|
|
|
+ self.workers.remove(worker)
|
|
|
return
|
|
|
- try:
|
|
|
- data = yield from reader.read(1024)
|
|
|
- except ConnectionResetError:
|
|
|
- print("Exiting worker handler")
|
|
|
- self.clients.remove(writer)
|
|
|
- return
|
|
|
- msg = pf.Message()
|
|
|
- msg.ParseFromString(data)
|
|
|
- print("Receiving {!r}".format(msg))
|
|
|
- print("Currently having {} clients".format(len(self.clients)))
|
|
|
- if msg.type == pf.Message.Ping:
|
|
|
- yield from self.send_pong(writer)
|
|
|
-
|
|
|
-#@asyncio.coroutine
|
|
|
-def hello(loop):
|
|
|
- print("hello")
|
|
|
- #yield from asyncio.sleep(1)
|
|
|
- loop.call_later(2, hello, loop)
|
|
|
+ # Use a semaphore to avoid overloading targets
|
|
|
+ with (yield from self.workers_semaphore):
|
|
|
+ self.send_target(writer, target)
|
|
|
+ try:
|
|
|
+ yield from writer.drain()
|
|
|
+ # TODO: timeout
|
|
|
+ data = yield from reader.read(1024)
|
|
|
+ except ConnectionResetError:
|
|
|
+ print("Exiting worker handler")
|
|
|
+ self.workers.remove(worker)
|
|
|
+ return
|
|
|
+ msg = pf.Message()
|
|
|
+ answer = pf.Message()
|
|
|
+ answer.ParseFromString(data)
|
|
|
+ print("Received answer {}".format(answer))
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
@@ -123,11 +133,7 @@ if __name__ == '__main__':
|
|
|
command_coro = asyncio.start_server(p.handle_commands, COMMAND_BIND_ADDRESS, COMMAND_BIND_PORT, loop=loop)
|
|
|
command_server = loop.run_until_complete(command_coro)
|
|
|
|
|
|
- # Generate data (test)
|
|
|
- #loop.call_soon(hello, loop)
|
|
|
- loop.call_soon(p.generate_data_cb)
|
|
|
- #asyncio.async(hello)
|
|
|
- #loop.create_task(hello)
|
|
|
+ #asyncio.async(p.generate_measurements())
|
|
|
|
|
|
# Serve requests until CTRL+c is pressed
|
|
|
print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
|