server.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python
  2. import asyncio
  3. import random
  4. import functools
  5. import netaddr
  6. from netaddr import IPAddress
  7. import peerfinder_pb2 as pf
  8. WORKER_BIND_ADDRESS = "::"
  9. WORKER_BIND_PORT = 9999
  10. COMMAND_BIND_ADDRESS = "127.0.0.1"
  11. COMMAND_BIND_PORT = 9998
  12. class Worker(object):
  13. def __init__(self, reader, writer):
  14. self.reader = reader
  15. self.writer = writer
  16. # Queue of measurements to send to the worker
  17. self.queue = asyncio.Queue()
  18. class Peerfinder(object):
  19. def __init__(self, loop, parallel_workers=3):
  20. self.loop = loop
  21. self.workers = []
  22. # Limit the maximum number of parallel workers handling jobs (to
  23. # avoid overloading the targets of the jobs). For now, we have a
  24. # global semaphore for all measurements, but we might change it to
  25. # a semaphore for each measurement.
  26. self.parallel_workers = parallel_workers
  27. self.workers_semaphore = asyncio.Semaphore(parallel_workers)
  28. def send_pong(self, writer):
  29. msg = pf.Message()
  30. msg.type = pf.Message.Pong
  31. writer.write(msg.SerializeToString())
  32. yield from writer.drain()
  33. def send_target(self, writer, target):
  34. msg = pf.Message()
  35. msg.type = pf.Message.Target
  36. msg.target.target_id = target.target_id
  37. msg.target.target.address = target.target.address
  38. msg.target.target.family = target.target.family
  39. writer.write(msg.SerializeToString())
  40. def status(self):
  41. busy_workers = len([None for w in self.workers if not w.queue.empty()])
  42. print("Currently having {} workers, {} have queued work.".format(
  43. len(self.workers),
  44. busy_workers))
  45. print("Queue lengths: {}".format([w.queue.qsize() for w in self.workers]))
  46. print("Semaphore status: waiting for {}/{} workers.".format(
  47. self.parallel_workers - self.workers_semaphore._value,
  48. self.parallel_workers))
  49. @asyncio.coroutine
  50. def generate_measurements(self):
  51. """Generate dummy measurement at regular intervals, for debug"""
  52. target = pf.Target()
  53. target.target_id = 1
  54. target.target.address = "2001:db8::1"
  55. target.target.family = pf.IPAddress.IPV6
  56. while True:
  57. self.status()
  58. print("Sending measurement")
  59. for worker in list(self.workers):
  60. worker.queue.put_nowait(target)
  61. yield from asyncio.sleep(10 * random.random())
  62. @asyncio.coroutine
  63. def handle_commands(self, reader, writer):
  64. print("Client connecting")
  65. target_id = 0
  66. family = {4: pf.IPAddress.IPV4, 6: pf.IPAddress.IPV6}
  67. while True:
  68. if reader.at_eof():
  69. print("Exiting commands handler")
  70. return
  71. try:
  72. data = yield from reader.read(1024)
  73. except ConnectionResetError:
  74. print("Exiting commands handler")
  75. return
  76. if len(data) == 0:
  77. continue
  78. try:
  79. data = data.strip().decode()
  80. addr = IPAddress(data)
  81. except netaddr.AddrFormatError:
  82. print("Invalid command, disconnecting client")
  83. writer.close()
  84. return
  85. target_id += 1
  86. target = pf.Target()
  87. target.target_id = target_id
  88. target.target.address = str(addr)
  89. target.target.family = family[addr.version]
  90. self.status()
  91. print("Queueing measurement for all workers...")
  92. for worker in list(self.workers):
  93. worker.queue.put_nowait(target)
  94. @asyncio.coroutine
  95. def handle_worker(self, reader, writer):
  96. print("Worker connecting")
  97. worker = Worker(reader, writer)
  98. self.workers.append(worker)
  99. while True:
  100. # Wait for a new target
  101. target = yield from worker.queue.get()
  102. if reader.at_eof():
  103. print("Exiting worker handler")
  104. self.workers.remove(worker)
  105. return
  106. # Use a semaphore to avoid overloading targets
  107. with (yield from self.workers_semaphore):
  108. self.send_target(writer, target)
  109. try:
  110. yield from writer.drain()
  111. # TODO: timeout
  112. data = yield from reader.read(1024)
  113. except ConnectionResetError:
  114. print("Exiting worker handler")
  115. self.workers.remove(worker)
  116. return
  117. msg = pf.Message()
  118. answer = pf.Message()
  119. answer.ParseFromString(data)
  120. print("Received answer {}".format(answer))
  121. if __name__ == '__main__':
  122. loop = asyncio.get_event_loop()
  123. p = Peerfinder(loop)
  124. worker_coro = asyncio.start_server(p.handle_worker, WORKER_BIND_ADDRESS, WORKER_BIND_PORT, loop=loop)
  125. worker_server = loop.run_until_complete(worker_coro)
  126. command_coro = asyncio.start_server(p.handle_commands, COMMAND_BIND_ADDRESS, COMMAND_BIND_PORT, loop=loop)
  127. command_server = loop.run_until_complete(command_coro)
  128. #asyncio.async(p.generate_measurements())
  129. # Serve requests until CTRL+c is pressed
  130. print('Serving workers on {}'.format(worker_server.sockets[0].getsockname()))
  131. print('Listen to commands on {}'.format(command_server.sockets[0].getsockname()))
  132. try:
  133. loop.run_forever()
  134. except KeyboardInterrupt:
  135. pass
  136. # Close the server
  137. command_server.close()
  138. loop.run_until_complete(command_server.wait_closed())
  139. worker_server.close()
  140. loop.run_until_complete(worker_server.wait_closed())
  141. loop.close()