#!/usr/bin/env python
"""Peerfinder server: accepts measurement workers and command clients.

Command clients (on COMMAND_BIND_ADDRESS:COMMAND_BIND_PORT) send one IP
address per line; each address is queued as a measurement target for every
connected worker (on WORKER_BIND_ADDRESS:WORKER_BIND_PORT), which receives it
as a protobuf ``Target`` message and replies with a protobuf ``Message``.
"""

import asyncio
import random
import functools

import netaddr
from netaddr import IPAddress

import peerfinder_pb2 as pf

WORKER_BIND_ADDRESS = "::"
WORKER_BIND_PORT = 9999
COMMAND_BIND_ADDRESS = "127.0.0.1"
COMMAND_BIND_PORT = 9998


class Worker(object):
    """Server-side state for one connected measurement worker."""

    def __init__(self, reader, writer):
        self.reader = reader
        self.writer = writer
        # Queue of measurements to send to the worker
        self.queue = asyncio.Queue()


class Peerfinder(object):
    """Fan measurement targets out from command clients to all workers.

    :param loop: the event loop the server runs on (stored, not otherwise used
        here).
    :param parallel_workers: maximum number of workers allowed to handle a
        target at the same time.
    """

    def __init__(self, loop, parallel_workers=3):
        self.loop = loop
        self.workers = []
        # Target ids are allocated server-wide (not per command connection)
        # so that targets queued by different clients never share an id.
        self.last_target_id = 0
        # Limit the maximum number of parallel workers handling jobs (to
        # avoid overloading the targets of the jobs). For now, we have a
        # global semaphore for all measurements, but we might change it to
        # a semaphore for each measurement.
        self.workers_semaphore = asyncio.Semaphore(parallel_workers)

    @asyncio.coroutine
    def send_pong(self, writer):
        """Send a Pong message on *writer* and wait until it is flushed.

        Generator-based coroutine: callers must ``yield from`` it.
        """
        msg = pf.Message()
        msg.type = pf.Message.Pong
        writer.write(msg.SerializeToString())
        yield from writer.drain()

    def send_target(self, writer, target):
        """Write *target* to *writer* as a Target message.

        Only buffers the data; the caller is responsible for draining.
        """
        msg = pf.Message()
        msg.type = pf.Message.Target
        msg.target.target_id = target.target_id
        msg.target.target.address = target.target.address
        msg.target.target.family = target.target.family
        writer.write(msg.SerializeToString())

    def _broadcast(self, target):
        """Queue *target* for every currently connected worker."""
        for worker in list(self.workers):
            worker.queue.put_nowait(target)

    @asyncio.coroutine
    def generate_measurements(self):
        """Generate dummy measurement at regular intervals, for debug"""
        target = pf.Target()
        target.target_id = 1
        target.target.address = "2001:db8::1"
        target.target.family = pf.IPAddress.IPV6
        while True:
            print("Sending measurement")
            self._broadcast(target)
            yield from asyncio.sleep(5 * random.random())

    @asyncio.coroutine
    def handle_commands(self, reader, writer):
        """Serve one command client.

        Each non-blank line is parsed as an IPv4/IPv6 address and queued as a
        target for all workers. The connection is closed on EOF, on a
        connection error, or on the first line that is not a valid address.
        """
        print("Client connecting")
        family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
        try:
            while True:
                try:
                    # One command per line: a raw read() could return a
                    # partial address or several addresses at once.
                    line = yield from reader.readline()
                except ConnectionError:
                    print("Exiting commands handler")
                    return
                except ValueError:
                    # readline() raises ValueError when a line exceeds the
                    # stream's buffer limit.
                    print("Invalid command, disconnecting client")
                    return
                if not line:
                    # readline() returns b"" only at EOF.
                    print("Exiting commands handler")
                    return
                command = line.strip()
                if not command:
                    continue
                try:
                    addr = IPAddress(command.decode())
                except (UnicodeDecodeError, netaddr.AddrFormatError):
                    print("Invalid command, disconnecting client")
                    return
                self.last_target_id += 1
                target = pf.Target()
                target.target_id = self.last_target_id
                target.target.address = str(addr)
                target.target.family = family[addr.version]
                print("Queueing measurement for all workers...")
                self._broadcast(target)
        finally:
            writer.close()

    @asyncio.coroutine
    def handle_worker(self, reader, writer):
        """Serve one measurement worker.

        Forwards every target queued for this worker and prints the answer
        it sends back. The worker is unregistered and its connection closed
        however the handler exits.
        """
        print("Worker connecting")
        worker = Worker(reader, writer)
        self.workers.append(worker)
        try:
            while True:
                # Wait for a new target
                target = yield from worker.queue.get()
                # NOTE: a disconnected worker is only noticed here, once a
                # new target has been queued for it.
                if reader.at_eof():
                    return
                # Use a semaphore to avoid overloading targets. Explicit
                # acquire/release: `with (yield from sem)` was removed in 3.9.
                yield from self.workers_semaphore.acquire()
                try:
                    self.send_target(writer, target)
                    yield from writer.drain()
                    # TODO: timeout
                    # NOTE(review): no message framing; assumes one answer
                    # arrives in a single read of at most 1024 bytes.
                    data = yield from reader.read(1024)
                except ConnectionError:
                    # Covers ConnectionResetError and BrokenPipeError.
                    return
                finally:
                    self.workers_semaphore.release()
                if not data:
                    # read() returns b"" at EOF: the worker went away.
                    return
                answer = pf.Message()
                answer.ParseFromString(data)
                print("Received answer {}".format(answer))
        finally:
            print("Exiting worker handler")
            self.workers.remove(worker)
            writer.close()


if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    p = Peerfinder(loop)
    # start_server() defaults to the current event loop, which is `loop`.
    worker_coro = asyncio.start_server(p.handle_worker,
                                       WORKER_BIND_ADDRESS, WORKER_BIND_PORT)
    worker_server = loop.run_until_complete(worker_coro)
    command_coro = asyncio.start_server(p.handle_commands,
                                        COMMAND_BIND_ADDRESS,
                                        COMMAND_BIND_PORT)
    command_server = loop.run_until_complete(command_coro)
    # For debugging, dummy targets can be generated with:
    # asyncio.ensure_future(p.generate_measurements())

    # Serve requests until CTRL+c is pressed
    print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
    print('Listen to commands on {}'.format(command_server.sockets[0].getsockname()))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass

    # Close the server
    command_server.close()
    loop.run_until_complete(command_server.wait_closed())
    worker_server.close()
    loop.run_until_complete(worker_server.wait_closed())
    loop.close()