server.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #!/usr/bin/env python
  2. import asyncio
  3. import random
  4. import functools
  5. import netaddr
  6. from netaddr import IPAddress
  7. import peerfinder_pb2 as pf
  8. WORKER_BIND_ADDRESS = "::"
  9. WORKER_BIND_PORT = 9999
  10. COMMAND_BIND_ADDRESS = "127.0.0.1"
  11. COMMAND_BIND_PORT = 9998
  12. # How long (in seconds) to wait for a worker before disconnecting it
  13. WORKER_TIMEOUT = 15
  14. class Worker(object):
  15. def __init__(self, reader, writer):
  16. self.reader = reader
  17. self.writer = writer
  18. # Queue of measurements to send to the worker
  19. self.queue = asyncio.Queue()
  20. class Peerfinder(object):
  21. def __init__(self, loop, parallel_workers=3):
  22. self.loop = loop
  23. self.workers = []
  24. # Limit the maximum number of parallel workers handling jobs (to
  25. # avoid overloading the targets of the jobs). For now, we have a
  26. # global semaphore for all measurements, but we might change it to
  27. # a semaphore for each measurement.
  28. self.parallel_workers = parallel_workers
  29. self.workers_semaphore = asyncio.Semaphore(parallel_workers)
  30. def send_pong(self, writer):
  31. msg = pf.Message()
  32. msg.type = pf.Message.Pong
  33. writer.write(msg.SerializeToString())
  34. yield from writer.drain()
  35. def send_target(self, writer, target):
  36. msg = pf.Message()
  37. msg.type = pf.Message.Target
  38. msg.target.target_id = target.target_id
  39. msg.target.target.address = target.target.address
  40. msg.target.target.family = target.target.family
  41. writer.write(msg.SerializeToString())
  42. def status(self):
  43. busy_workers = len([None for w in self.workers if not w.queue.empty()])
  44. print("Currently having {} workers, {} have queued work.".format(
  45. len(self.workers),
  46. busy_workers))
  47. print("Queue lengths: {}".format([w.queue.qsize() for w in self.workers]))
  48. print("Semaphore status: waiting for {}/{} workers.".format(
  49. self.parallel_workers - self.workers_semaphore._value,
  50. self.parallel_workers))
  51. @asyncio.coroutine
  52. def generate_measurements(self):
  53. """Generate dummy measurement at regular intervals, for debug"""
  54. target = pf.Target()
  55. target.target_id = 1
  56. target.target.address = "2001:db8::1"
  57. target.target.family = pf.IPAddress.IPV6
  58. while True:
  59. self.status()
  60. print("Sending measurement")
  61. for worker in list(self.workers):
  62. worker.queue.put_nowait(target)
  63. yield from asyncio.sleep(10 * random.random())
  64. @asyncio.coroutine
  65. def handle_commands(self, reader, writer):
  66. print("Client connecting")
  67. target_id = 0
  68. family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
  69. while True:
  70. if reader.at_eof():
  71. print("Exiting commands handler")
  72. return
  73. try:
  74. data = yield from reader.read(1024)
  75. except ConnectionResetError:
  76. print("Exiting commands handler")
  77. return
  78. if len(data) == 0:
  79. continue
  80. try:
  81. data = data.strip().decode()
  82. addr = IPAddress(data)
  83. except netaddr.AddrFormatError:
  84. print("Invalid command, disconnecting client")
  85. writer.close()
  86. return
  87. target_id += 1
  88. target = pf.Target()
  89. target.target_id = target_id
  90. target.target.address = str(addr)
  91. target.target.family = family[addr.version]
  92. self.status()
  93. print("Queueing measurement for all workers...")
  94. for worker in list(self.workers):
  95. worker.queue.put_nowait(target)
  96. @asyncio.coroutine
  97. def handle_worker(self, reader, writer):
  98. print("Worker connecting")
  99. worker = Worker(reader, writer)
  100. self.workers.append(worker)
  101. while True:
  102. # Wait for a new target
  103. target = yield from worker.queue.get()
  104. if reader.at_eof():
  105. print("Worker disconnected, exiting")
  106. self.workers.remove(worker)
  107. return
  108. # Use a semaphore to avoid overloading targets
  109. with (yield from self.workers_semaphore):
  110. self.send_target(writer, target)
  111. try:
  112. yield from writer.drain()
  113. data = yield from asyncio.wait_for(reader.read(1024),
  114. WORKER_TIMEOUT)
  115. except ConnectionResetError:
  116. print("Worker disconnected, exiting")
  117. self.workers.remove(worker)
  118. return
  119. except asyncio.TimeoutError:
  120. print("Worker timeout, exiting")
  121. self.workers.remove(worker)
  122. writer.close()
  123. return
  124. msg = pf.Message()
  125. answer = pf.Message()
  126. answer.ParseFromString(data)
  127. print("Received answer {}".format(answer))
  128. if __name__ == '__main__':
  129. loop = asyncio.get_event_loop()
  130. p = Peerfinder(loop)
  131. worker_coro = asyncio.start_server(p.handle_worker, WORKER_BIND_ADDRESS, WORKER_BIND_PORT, loop=loop)
  132. worker_server = loop.run_until_complete(worker_coro)
  133. command_coro = asyncio.start_server(p.handle_commands, COMMAND_BIND_ADDRESS, COMMAND_BIND_PORT, loop=loop)
  134. command_server = loop.run_until_complete(command_coro)
  135. #asyncio.async(p.generate_measurements())
  136. # Serve requests until CTRL+c is pressed
  137. print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
  138. print('Listen to commands on {}'.format(command_server.sockets[0].getsockname()))
  139. try:
  140. loop.run_forever()
  141. except KeyboardInterrupt:
  142. pass
  143. # Close the server
  144. command_server.close()
  145. loop.run_until_complete(command_server.wait_closed())
  146. worker_server.close()
  147. loop.run_until_complete(worker_server.wait_closed())
  148. loop.close()