123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- #!/usr/bin/env python
- import asyncio
- import random
- import functools
- import netaddr
- from netaddr import IPAddress
- import peerfinder_pb2 as pf
# Listening endpoint for measurement workers ("::" is the IPv6 wildcard address).
WORKER_BIND_ADDRESS, WORKER_BIND_PORT = "::", 9999
# Listening endpoint for command clients; bound to loopback only.
COMMAND_BIND_ADDRESS, COMMAND_BIND_PORT = "127.0.0.1", 9998
# Seconds to wait for a worker's answer before disconnecting that worker.
WORKER_TIMEOUT = 15
class Worker:
    """State kept for one connected measurement worker.

    Bundles the stream pair of the worker's connection with a FIFO of
    targets that are still waiting to be sent to that worker.
    """

    def __init__(self, reader, writer):
        # Stream endpoints of this worker's connection.
        self.reader, self.writer = reader, writer
        # Pending measurement targets for this worker (unbounded FIFO).
        self.queue = asyncio.Queue()
class Peerfinder(object):
    """Dispatch measurement targets from command clients to connected workers.

    Command clients submit IP addresses (``handle_commands``).  Every target
    is queued on every connected worker (``handle_worker``), and a shared
    semaphore bounds how many workers may have a target in flight at once.
    """

    def __init__(self, loop, parallel_workers=3):
        """Create the dispatcher.

        :param loop: event loop; stored as ``self.loop`` for compatibility,
            but not used by the methods of this class.
        :param parallel_workers: maximum number of workers allowed to be
            handling a target at the same time.
        """
        self.loop = loop
        self.workers = []
        # Limit the maximum number of parallel workers handling jobs (to
        # avoid overloading the targets of the jobs). For now, we have a
        # global semaphore for all measurements, but we might change it to
        # a semaphore for each measurement.
        self.parallel_workers = parallel_workers
        self.workers_semaphore = asyncio.Semaphore(parallel_workers)

    async def send_pong(self, writer):
        """Send a Pong message on *writer* and wait until it is flushed."""
        msg = pf.Message()
        msg.type = pf.Message.Pong
        writer.write(msg.SerializeToString())
        await writer.drain()

    def send_target(self, writer, target):
        """Wrap *target* in a Target message and write it to *writer*.

        Only buffers the bytes; the caller is responsible for draining.
        """
        msg = pf.Message()
        msg.type = pf.Message.Target
        msg.target.target_id = target.target_id
        msg.target.target.address = target.target.address
        msg.target.target.family = target.target.family
        writer.write(msg.SerializeToString())

    def status(self):
        """Print worker count, per-worker queue lengths and semaphore usage."""
        busy_workers = sum(1 for w in self.workers if not w.queue.empty())
        print("Currently having {} workers, {} have queued work.".format(
            len(self.workers),
            busy_workers))
        print("Queue lengths: {}".format([w.queue.qsize() for w in self.workers]))
        # asyncio.Semaphore exposes no public counter; ``_value`` is the
        # private CPython attribute holding the remaining count.
        print("Semaphore status: waiting for {}/{} workers.".format(
            self.parallel_workers - self.workers_semaphore._value,
            self.parallel_workers))

    def _broadcast(self, target):
        """Queue *target* on every currently connected worker."""
        # put_nowait() is synchronous, so self.workers cannot change while
        # this loop runs.
        for worker in self.workers:
            worker.queue.put_nowait(target)

    async def generate_measurements(self):
        """Generate dummy measurement at regular intervals, for debug.

        Every 0-10 seconds (uniformly random) the fixed target 2001:db8::1
        is queued on every connected worker.
        """
        target = pf.Target()
        target.target_id = 1
        target.target.address = "2001:db8::1"
        target.target.family = pf.IPAddress.IPV6
        while True:
            self.status()
            print("Sending measurement")
            self._broadcast(target)
            await asyncio.sleep(10 * random.random())

    async def handle_commands(self, reader, writer):
        """Serve one command client, queueing every address it sends.

        The client sends textual IPv4/IPv6 addresses, one per line; blank
        lines are ignored.  Each valid address becomes a Target with a
        per-connection ``target_id`` (starting at 1) that is queued on every
        connected worker.  The first line that is not valid UTF-8 or not a
        valid address disconnects the client.  The connection is closed on
        every exit path.
        """
        print("Client connecting")
        target_id = 0
        family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
        try:
            while not reader.at_eof():
                try:
                    data = await reader.read(1024)
                except ConnectionError:
                    break
                # One read may carry several newline-separated addresses.
                # NOTE: an address split across two reads is still rejected;
                # there is no framing beyond the newline.
                for line in data.splitlines():
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        addr = IPAddress(line.decode())
                    except (UnicodeDecodeError, netaddr.AddrFormatError):
                        print("Invalid command, disconnecting client")
                        return
                    target_id += 1
                    target = pf.Target()
                    target.target_id = target_id
                    target.target.address = str(addr)
                    target.target.family = family[addr.version]
                    self.status()
                    print("Queueing measurement for all workers...")
                    self._broadcast(target)
            print("Exiting commands handler")
        finally:
            writer.close()

    async def handle_worker(self, reader, writer):
        """Serve one worker connection until it disconnects or times out.

        The worker is registered in ``self.workers`` for the lifetime of
        this coroutine.  For each queued target, a semaphore slot is taken,
        the target is sent, and one reply is awaited for at most
        ``WORKER_TIMEOUT`` seconds.  On every exit path (disconnect,
        timeout, unexpected error, cancellation) the worker is unregistered
        and its connection closed.
        """
        print("Worker connecting")
        worker = Worker(reader, writer)
        self.workers.append(worker)
        try:
            while True:
                # Wait for a new target.  A worker that disconnects while
                # idle is only noticed once its next target is dequeued.
                target = await worker.queue.get()
                if reader.at_eof():
                    print("Worker disconnected, exiting")
                    return
                # Use a semaphore to avoid overloading targets; the slot is
                # held until the worker answers (or fails to).
                async with self.workers_semaphore:
                    self.send_target(writer, target)
                    try:
                        await writer.drain()
                        data = await asyncio.wait_for(reader.read(1024),
                                                      WORKER_TIMEOUT)
                    except ConnectionError:
                        print("Worker disconnected, exiting")
                        return
                    except asyncio.TimeoutError:
                        print("Worker timeout, exiting")
                        return
                # read() returns b"" only at EOF: the worker went away
                # without answering.
                if not data:
                    print("Worker disconnected, exiting")
                    return
                # NOTE(review): the reply is taken from a single read(1024);
                # a longer or fragmented message would be parsed incomplete
                # (ParseFromString may then raise; cleanup still runs below).
                answer = pf.Message()
                answer.ParseFromString(data)
                print("Received answer {}".format(answer))
        finally:
            self.workers.remove(worker)
            writer.close()
if __name__ == '__main__':
    # Create the loop explicitly and make it current: asyncio.get_event_loop()
    # with no running loop is deprecated and, in recent releases, an error.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    p = Peerfinder(loop)
    # start_server() no longer accepts ``loop=`` (removed in Python 3.10); it
    # uses the loop that run_until_complete() is running.
    worker_coro = asyncio.start_server(p.handle_worker, WORKER_BIND_ADDRESS,
                                       WORKER_BIND_PORT)
    worker_server = loop.run_until_complete(worker_coro)
    command_coro = asyncio.start_server(p.handle_commands, COMMAND_BIND_ADDRESS,
                                        COMMAND_BIND_PORT)
    command_server = loop.run_until_complete(command_coro)
    # Debug aid: uncomment to queue a dummy measurement on every worker at
    # random intervals.
    # loop.create_task(p.generate_measurements())
    print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
    print('Listen to commands on {}'.format(command_server.sockets[0].getsockname()))
    # Serve requests until CTRL+c is pressed
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    # Stop accepting new connections.
    command_server.close()
    worker_server.close()
    # Cancel the per-connection handler tasks still running and let them
    # unwind: otherwise loop.close() destroys pending tasks, and on recent
    # Python versions Server.wait_closed() waits for active connections.
    pending = asyncio.all_tasks(loop)
    for task in pending:
        task.cancel()
    if pending:
        loop.run_until_complete(
            asyncio.gather(*pending, return_exceptions=True))
    loop.run_until_complete(command_server.wait_closed())
    loop.run_until_complete(worker_server.wait_closed())
    loop.close()
|