server.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. #!/usr/bin/env python
  2. import asyncio
  3. import random
  4. import functools
  5. import netaddr
  6. from netaddr import IPAddress
  7. import peerfinder_pb2 as pf
  8. WORKER_BIND_ADDRESS = "::"
  9. WORKER_BIND_PORT = 9999
  10. COMMAND_BIND_ADDRESS = "127.0.0.1"
  11. COMMAND_BIND_PORT = 9998
  12. class Worker(object):
  13. def __init__(self, reader, writer):
  14. self.reader = reader
  15. self.writer = writer
  16. # Queue of measurements to send to the worker
  17. self.queue = asyncio.Queue()
  18. class Peerfinder(object):
  19. def __init__(self, loop, parallel_workers=3):
  20. self.loop = loop
  21. self.workers = []
  22. # Limit the maximum number of parallel workers handling jobs (to
  23. # avoid overloading the targets of the jobs). For now, we have a
  24. # global semaphore for all measurements, but we might change it to
  25. # a semaphore for each measurement.
  26. self.workers_semaphore = asyncio.Semaphore(parallel_workers)
  27. def send_pong(self, writer):
  28. msg = pf.Message()
  29. msg.type = pf.Message.Pong
  30. writer.write(msg.SerializeToString())
  31. yield from writer.drain()
  32. def send_target(self, writer, target):
  33. msg = pf.Message()
  34. msg.type = pf.Message.Target
  35. msg.target.target_id = target.target_id
  36. msg.target.target.address = target.target.address
  37. msg.target.target.family = target.target.family
  38. writer.write(msg.SerializeToString())
  39. @asyncio.coroutine
  40. def generate_measurements(self):
  41. """Generate dummy measurement at regular intervals, for debug"""
  42. target = pf.Target()
  43. target.target_id = 1
  44. target.target.address = "2001:db8::1"
  45. target.target.family = pf.IPAddress.IPV6
  46. while True:
  47. print("Sending measurement")
  48. for worker in list(self.workers):
  49. worker.queue.put_nowait(target)
  50. yield from asyncio.sleep(5 * random.random())
  51. @asyncio.coroutine
  52. def handle_commands(self, reader, writer):
  53. print("Client connecting")
  54. target_id = 0
  55. family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
  56. while True:
  57. if reader.at_eof():
  58. print("Exiting commands handler")
  59. return
  60. try:
  61. data = yield from reader.read(1024)
  62. except ConnectionResetError:
  63. print("Exiting commands handler")
  64. return
  65. if len(data) == 0:
  66. continue
  67. try:
  68. data = data.strip().decode()
  69. addr = IPAddress(data)
  70. except netaddr.AddrFormatError:
  71. print("Invalid command, disconnecting client")
  72. writer.close()
  73. return
  74. target_id += 1
  75. target = pf.Target()
  76. target.target_id = target_id
  77. target.target.address = str(addr)
  78. target.target.family = family[addr.version]
  79. print("Queueing measurement for all workers...")
  80. for worker in list(self.workers):
  81. worker.queue.put_nowait(target)
  82. @asyncio.coroutine
  83. def handle_worker(self, reader, writer):
  84. print("Worker connecting")
  85. worker = Worker(reader, writer)
  86. self.workers.append(worker)
  87. while True:
  88. # Wait for a new target
  89. target = yield from worker.queue.get()
  90. if reader.at_eof():
  91. print("Exiting worker handler")
  92. self.workers.remove(worker)
  93. return
  94. # Use a semaphore to avoid overloading targets
  95. with (yield from self.workers_semaphore):
  96. self.send_target(writer, target)
  97. try:
  98. yield from writer.drain()
  99. # TODO: timeout
  100. data = yield from reader.read(1024)
  101. except ConnectionResetError:
  102. print("Exiting worker handler")
  103. self.workers.remove(worker)
  104. return
  105. msg = pf.Message()
  106. answer = pf.Message()
  107. answer.ParseFromString(data)
  108. print("Received answer {}".format(answer))
  109. if __name__ == '__main__':
  110. loop = asyncio.get_event_loop()
  111. p = Peerfinder(loop)
  112. worker_coro = asyncio.start_server(p.handle_worker, WORKER_BIND_ADDRESS, WORKER_BIND_PORT, loop=loop)
  113. worker_server = loop.run_until_complete(worker_coro)
  114. command_coro = asyncio.start_server(p.handle_commands, COMMAND_BIND_ADDRESS, COMMAND_BIND_PORT, loop=loop)
  115. command_server = loop.run_until_complete(command_coro)
  116. #asyncio.async(p.generate_measurements())
  117. # Serve requests until CTRL+c is pressed
  118. print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
  119. print('Listen to commands on {}'.format(command_server.sockets[0].getsockname()))
  120. try:
  121. loop.run_forever()
  122. except KeyboardInterrupt:
  123. pass
  124. # Close the server
  125. command_server.close()
  126. loop.run_until_complete(command_server.wait_closed())
  127. worker_server.close()
  128. loop.run_until_complete(worker_server.wait_closed())
  129. loop.close()