session.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (C) 2009 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. import sys
  16. import socket
  17. import struct
  18. import Message
  19. class ProtocolError(Exception): pass
  20. class Session:
  21. def __init__(self):
  22. self._socket = None
  23. self._lname = None
  24. self._recvbuffer = ""
  25. self._recvlength = None
  26. self._sendbuffer = ""
  27. self._sequence = 1
  28. self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  29. self._socket.connect(tuple(['127.0.0.1', 9912]))
  30. self.sendmsg({ "type": "getlname" })
  31. msg = self.recvmsg(False)
  32. self._lname = msg["lname"]
  33. if not self._lname:
  34. raise ProtocolError("Could not get local name")
  35. @property
  36. def lname(self):
  37. return self._lname
  38. def sendmsg(self, msg):
  39. if type(msg) == dict:
  40. msg = Message.to_wire(msg)
  41. self._socket.setblocking(1)
  42. self._socket.send(struct.pack("!I", len(msg)))
  43. self._socket.send(msg)
  44. def recvmsg(self, nonblock = True):
  45. data = self._receive_full_buffer(nonblock)
  46. if data:
  47. return Message.from_wire(data)
  48. return None
  49. def _receive_full_buffer(self, nonblock):
  50. if nonblock:
  51. self._socket.setblocking(0)
  52. else:
  53. self._socket.setblocking(1)
  54. if self._recvlength == None:
  55. length = 4
  56. length -= len(self._recvbuffer)
  57. try:
  58. data = self._socket.recv(length)
  59. except:
  60. return None
  61. if not data: # server closed connection
  62. return None
  63. self._recvbuffer += data
  64. if len(self._recvbuffer) < 4:
  65. return None
  66. self._recvlength = struct.unpack('>I', self._recvbuffer)[0]
  67. self._recvbuffer = ""
  68. length = self._recvlength - len(self._recvbuffer)
  69. while (length > 0):
  70. data = self._socket.recv(length)
  71. self._recvbuffer += data
  72. length -= len(data)
  73. data = self._recvbuffer
  74. self._recvbuffer = ""
  75. self._recvlength = None
  76. return (data)
  77. def _next_sequence(self):
  78. self._sequence += 1
  79. return self._sequence
  80. def group_subscribe(self, group, instance = "*", subtype = "normal"):
  81. self.sendmsg({
  82. "type": "subscribe",
  83. "group": group,
  84. "instance": instance,
  85. "subtype": subtype,
  86. })
  87. def group_unsubscribe(self, group, instance = "*"):
  88. self.sendmsg({
  89. "type": "unsubscribe",
  90. "group": group,
  91. "instance": instance,
  92. })
  93. def group_sendmsg(self, msg, group, instance = "*", to = "*"):
  94. self.sendmsg({
  95. "type": "send",
  96. "from": self._lname,
  97. "to": to,
  98. "group": group,
  99. "instance": instance,
  100. "seq": self._next_sequence(),
  101. "msg": Message.to_wire(msg),
  102. })
  103. def group_recvmsg(self, nonblock = True):
  104. msg = self.recvmsg(nonblock)
  105. if msg == None:
  106. return None
  107. data = Message.from_wire(msg["msg"])
  108. return (data, msg)
  109. if __name__ == "__main__":
  110. import doctest
  111. doctest.testmod()