session.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright (C) 2009 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. import sys
  16. import socket
  17. import struct
  18. from ISC.CC import Message
  19. class ProtocolError(Exception): pass
  20. class NetworkError(Exception): pass
  21. class SessionError(Exception): pass
  22. class Session:
  23. def __init__(self, port=9912):
  24. self._socket = None
  25. self._lname = None
  26. self._recvbuffer = bytearray()
  27. self._recvlength = None
  28. self._sendbuffer = bytearray()
  29. self._sequence = 1
  30. self._closed = False
  31. try:
  32. self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  33. self._socket.connect(tuple(['127.0.0.1', port]))
  34. self.sendmsg({ "type": "getlname" })
  35. env, msg = self.recvmsg(False)
  36. self._lname = msg["lname"]
  37. if not self._lname:
  38. raise ProtocolError("Could not get local name")
  39. except socket.error as se:
  40. raise SessionError(se)
  41. @property
  42. def lname(self):
  43. return self._lname
  44. def close(self):
  45. self._socket.close()
  46. self._lname = None
  47. self._closed = True
  48. def sendmsg(self, env, msg = None):
  49. if self._closed:
  50. raise SessionError("Session has been closed.")
  51. if type(env) == dict:
  52. env = Message.to_wire(env)
  53. if type(msg) == dict:
  54. msg = Message.to_wire(msg)
  55. self._socket.setblocking(1)
  56. length = 2 + len(env);
  57. if msg:
  58. length += len(msg)
  59. self._socket.send(struct.pack("!I", length))
  60. self._socket.send(struct.pack("!H", len(env)))
  61. self._socket.send(env)
  62. if msg:
  63. self._socket.send(msg)
  64. def recvmsg(self, nonblock = True):
  65. if self._closed:
  66. raise SessionError("Session has been closed.")
  67. data = self._receive_full_buffer(nonblock)
  68. if data and len(data) > 2:
  69. header_length = struct.unpack('>H', data[0:2])[0]
  70. data_length = len(data) - 2 - header_length
  71. if data_length > 0:
  72. return Message.from_wire(data[2:header_length+2]), Message.from_wire(data[header_length + 2:])
  73. else:
  74. return Message.from_wire(data[2:header_length+2]), None
  75. return None, None
  76. def _receive_full_buffer(self, nonblock):
  77. if nonblock:
  78. self._socket.setblocking(0)
  79. else:
  80. self._socket.setblocking(1)
  81. if self._recvlength == None:
  82. length = 4
  83. length -= len(self._recvbuffer)
  84. try:
  85. data = self._socket.recv(length)
  86. except:
  87. return None
  88. if data == "": # server closed connection
  89. raise ProtocolError("Read of 0 bytes: connection closed")
  90. self._recvbuffer += data
  91. if len(self._recvbuffer) < 4:
  92. return None
  93. self._recvlength = struct.unpack('>I', self._recvbuffer)[0]
  94. self._recvbuffer = bytearray()
  95. length = self._recvlength - len(self._recvbuffer)
  96. while (length > 0):
  97. try:
  98. data = self._socket.recv(length)
  99. except:
  100. return None
  101. if data == "": # server closed connection
  102. raise ProtocolError("Read of 0 bytes: connection closed")
  103. self._recvbuffer += data
  104. length -= len(data)
  105. data = self._recvbuffer
  106. self._recvbuffer = bytearray()
  107. self._recvlength = None
  108. return (data)
  109. def _next_sequence(self):
  110. self._sequence += 1
  111. return self._sequence
  112. def group_subscribe(self, group, instance = "*"):
  113. self.sendmsg({
  114. "type": "subscribe",
  115. "group": group,
  116. "instance": instance,
  117. })
  118. def group_unsubscribe(self, group, instance = "*"):
  119. self.sendmsg({
  120. "type": "unsubscribe",
  121. "group": group,
  122. "instance": instance,
  123. })
  124. def group_sendmsg(self, msg, group, instance = "*", to = "*"):
  125. seq = self._next_sequence()
  126. self.sendmsg({
  127. "type": "send",
  128. "from": self._lname,
  129. "to": to,
  130. "group": group,
  131. "instance": instance,
  132. "seq": seq,
  133. }, Message.to_wire(msg))
  134. return seq
  135. def group_recvmsg(self, nonblock = True):
  136. env, msg = self.recvmsg(nonblock)
  137. if env == None:
  138. # return none twice to match normal return value
  139. # (so caller won't get a type error on no data)
  140. return (None, None)
  141. return (msg, env)
  142. def group_reply(self, routing, msg):
  143. seq = self._next_sequence()
  144. self.sendmsg({
  145. "type": "send",
  146. "from": self._lname,
  147. "to": routing["from"],
  148. "group": routing["group"],
  149. "instance": routing["instance"],
  150. "seq": seq,
  151. "reply": routing["seq"],
  152. }, Message.to_wire(msg))
  153. return seq
  154. if __name__ == "__main__":
  155. import doctest
  156. doctest.testmod()