session_test.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. # Copyright (C) 2010 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. #
  16. # Tests for the Session class in isc.cc.session (run against a fake socket)
  17. #
  18. import unittest
  19. import os
  20. from isc.cc.session import *
  21. # our fake socket, where we can read and insert messages
  22. class MySocket():
  23. def __init__(self, family, type):
  24. self.family = family
  25. self.type = type
  26. self.recvqueue = bytearray()
  27. self.sendqueue = bytearray()
  28. self._blocking = True
  29. def connect(self, to):
  30. pass
  31. def close(self):
  32. pass
  33. def setblocking(self, val):
  34. self._blocking = val
  35. def send(self, data):
  36. self.sendqueue.extend(data);
  37. def readsent(self, length):
  38. if length > len(self.sendqueue):
  39. raise Exception("readsent(" + str(length) + ") called, but only " + str(len(self.sendqueue)) + " in queue")
  40. result = self.sendqueue[:length]
  41. del self.sendqueue[:length]
  42. return result
  43. def readsentmsg(self):
  44. """return bytearray of the full message include length specifiers"""
  45. result = bytearray()
  46. length_buf = self.readsent(4)
  47. result.extend(length_buf)
  48. length = struct.unpack('>I', length_buf)[0]
  49. header_length_buf = self.readsent(2)
  50. header_length = struct.unpack('>H', header_length_buf)[0]
  51. result.extend(header_length_buf)
  52. data_length = length - 2 - header_length
  53. result.extend(self.readsent(header_length))
  54. result.extend(self.readsent(data_length))
  55. return result
  56. def recv(self, length):
  57. if len(self.recvqueue) == 0:
  58. if self._blocking:
  59. return bytes()
  60. else:
  61. raise socket.error(errno.EAGAIN, "Resource temporarily unavailable")
  62. if length > len(self.recvqueue):
  63. raise Exception("Buffer underrun in test, does the test provide the right data?")
  64. result = self.recvqueue[:length]
  65. del self.recvqueue[:length]
  66. #print("[XX] returning: " + str(result))
  67. #print("[XX] queue now: " + str(self.recvqueue))
  68. return result
  69. def addrecv(self, env, msg = None):
  70. if type(env) == dict:
  71. env = isc.cc.message.to_wire(env)
  72. if type(msg) == dict:
  73. msg = isc.cc.message.to_wire(msg)
  74. length = 2 + len(env);
  75. if msg:
  76. length += len(msg)
  77. self.recvqueue.extend(struct.pack("!I", length))
  78. self.recvqueue.extend(struct.pack("!H", len(env)))
  79. self.recvqueue.extend(env)
  80. if msg:
  81. self.recvqueue.extend(msg)
  82. def settimeout(self, val):
  83. pass
  84. def gettimeout(self):
  85. return 0
  86. #
  87. # We subclass the Session class we're testing here, only
  88. # to override the __init__() method, which wants a socket,
  89. # and we need to use our fake socket
  90. class MySession(Session):
  91. def __init__(self, port=9912, s=None):
  92. self._socket = None
  93. self._socket_timeout = 1
  94. self._lname = None
  95. self._recvbuffer = bytearray()
  96. self._recv_len_size = 0
  97. self._recv_size = 0
  98. self._sequence = 1
  99. self._closed = False
  100. self._queue = []
  101. self._lock = threading.RLock()
  102. if s is not None:
  103. self._socket = s
  104. else:
  105. try:
  106. self._socket = MySocket(socket.AF_INET, socket.SOCK_STREAM)
  107. self._socket.connect(tuple(['127.0.0.1', port]))
  108. self._lname = "test_name"
  109. # testing getlname here isn't useful, code removed
  110. except socket.error as se:
  111. raise SessionError(se)
  112. class testSession(unittest.TestCase):
  113. def test_session_close(self):
  114. sess = MySession()
  115. self.assertEqual("test_name", sess.lname)
  116. sess.close()
  117. self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  118. def test_env_too_large(self):
  119. sess = MySession()
  120. largeenv = { "a": "b"*65535 }
  121. self.assertRaises(ProtocolError, sess.sendmsg, largeenv, {"hello": "a"})
  122. def test_session_sendmsg(self):
  123. sess = MySession()
  124. sess.sendmsg({}, {"hello": "a"})
  125. sent = sess._socket.readsentmsg();
  126. self.assertEqual(sent, b'\x00\x00\x00\x12\x00\x02{}{"hello": "a"}')
  127. sess.close()
  128. self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  129. def test_session_sendmsg2(self):
  130. sess = MySession()
  131. sess.sendmsg({'to': 'someone', 'reply': 1}, {"hello": "a"})
  132. sent = sess._socket.readsentmsg();
  133. #print(sent)
  134. #self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  135. def recv_and_compare(self, session, bytes, env, msg):
  136. """Adds bytes to the recvqueue (which will be read by the
  137. session object, and compare the resultinv env and msg to
  138. the ones given."""
  139. session._socket.addrecv(bytes)
  140. s_env, s_msg = session.recvmsg(False)
  141. self.assertEqual(env, s_env)
  142. self.assertEqual(msg, s_msg)
  143. # clear the recv buffer in case a malformed message left garbage
  144. # (actually, shouldn't that case provide some error instead of
  145. # None?)
  146. session._socket.recvqueue = bytearray()
  147. def test_session_recvmsg(self):
  148. sess = MySession()
  149. # {'to': "someone"}, {"hello": "a"}
  150. #self.recv_and_compare(sess,
  151. # b'\x00\x00\x00\x1f\x00\x10Skan\x02to(\x07someoneSkan\x05hello(\x01a',
  152. # {'to': "someone"}, {"hello": "a"})
  153. # 'malformed' messages
  154. # shouldn't some of these raise exceptions?
  155. #self.recv_and_compare(sess,
  156. # b'\x00',
  157. # None, None)
  158. #self.recv_and_compare(sess,
  159. # b'\x00\x00\x00\x10',
  160. # None, None)
  161. #self.recv_and_compare(sess,
  162. # b'\x00\x00\x00\x02\x00\x00',
  163. # None, None)
  164. #self.recv_and_compare(sess,
  165. # b'\x00\x00\x00\x02\x00\x02',
  166. # None, None)
  167. #self.recv_and_compare(sess,
  168. # b'',
  169. # None, None)
  170. # need to clear
  171. sess._socket.recvqueue = bytearray()
  172. # 'queueing' system
  173. # sending message {'to': 'someone', 'reply': 1}, {"hello": "a"}
  174. #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}")
  175. # get no message without asking for a specific sequence number reply
  176. self.assertFalse(sess.has_queued_msgs())
  177. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  178. env, msg = sess.recvmsg(True)
  179. self.assertEqual(None, env)
  180. self.assertTrue(sess.has_queued_msgs())
  181. env, msg = sess.recvmsg(True, 1)
  182. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  183. self.assertEqual({"hello": "a"}, msg)
  184. self.assertFalse(sess.has_queued_msgs())
  185. # ask for a differe sequence number reply (that doesn't exist)
  186. # then ask for the one that is there
  187. self.assertFalse(sess.has_queued_msgs())
  188. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  189. env, msg = sess.recvmsg(True, 2)
  190. self.assertEqual(None, env)
  191. self.assertEqual(None, msg)
  192. self.assertTrue(sess.has_queued_msgs())
  193. env, msg = sess.recvmsg(True, 1)
  194. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  195. self.assertEqual({"hello": "a"}, msg)
  196. self.assertFalse(sess.has_queued_msgs())
  197. # ask for a differe sequence number reply (that doesn't exist)
  198. # then ask for any message
  199. self.assertFalse(sess.has_queued_msgs())
  200. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  201. env, msg = sess.recvmsg(True, 2)
  202. self.assertEqual(None, env)
  203. self.assertEqual(None, msg)
  204. self.assertTrue(sess.has_queued_msgs())
  205. env, msg = sess.recvmsg(True, 1)
  206. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  207. self.assertEqual({"hello": "a"}, msg)
  208. self.assertFalse(sess.has_queued_msgs())
  209. #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}")
  210. # ask for a differe sequence number reply (that doesn't exist)
  211. # send a new message, ask for specific message (get the first)
  212. # then ask for any message (get the second)
  213. self.assertFalse(sess.has_queued_msgs())
  214. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  215. env, msg = sess.recvmsg(True, 2)
  216. self.assertEqual(None, env)
  217. self.assertEqual(None, msg)
  218. self.assertTrue(sess.has_queued_msgs())
  219. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  220. env, msg = sess.recvmsg(True, 1)
  221. self.assertEqual({'to': 'someone', 'reply': 1 }, env)
  222. self.assertEqual({"hello": "a"}, msg)
  223. self.assertFalse(sess.has_queued_msgs())
  224. env, msg = sess.recvmsg(True)
  225. self.assertEqual({'to': 'someone'}, env)
  226. self.assertEqual({"hello": "b"}, msg)
  227. self.assertFalse(sess.has_queued_msgs())
  228. # send a message, then one with specific reply value
  229. # ask for that specific message (get the second)
  230. # then ask for any message (get the first)
  231. self.assertFalse(sess.has_queued_msgs())
  232. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  233. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  234. env, msg = sess.recvmsg(True, 1)
  235. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  236. self.assertEqual({"hello": "a"}, msg)
  237. self.assertTrue(sess.has_queued_msgs())
  238. env, msg = sess.recvmsg(True)
  239. self.assertEqual({'to': 'someone'}, env)
  240. self.assertEqual({"hello": "b"}, msg)
  241. self.assertFalse(sess.has_queued_msgs())
  242. def test_recv_bad_msg(self):
  243. sess = MySession()
  244. self.assertFalse(sess.has_queued_msgs())
  245. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  246. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  247. # mangle the bytes a bit
  248. sess._socket.recvqueue[5] = sess._socket.recvqueue[5] - 2
  249. sess._socket.recvqueue = sess._socket.recvqueue[:-2]
  250. self.assertRaises(SessionError, sess.recvmsg, True, 1)
  251. def test_next_sequence(self):
  252. sess = MySession()
  253. self.assertEqual(sess._sequence, 1)
  254. self.assertEqual(sess._next_sequence(), 2)
  255. self.assertEqual(sess._sequence, 2)
  256. sess._sequence = 47805
  257. self.assertEqual(sess._sequence, 47805)
  258. self.assertEqual(sess._next_sequence(), 47806)
  259. self.assertEqual(sess._sequence, 47806)
  260. def test_group_subscribe(self):
  261. sess = MySession()
  262. sess.group_subscribe("mygroup")
  263. sent = sess._socket.readsentmsg()
  264. self.assertEqual(sent, b'\x00\x00\x00<\x00:{"group": "mygroup", "type": "subscribe", "instance": "*"}')
  265. sess.group_subscribe("mygroup")
  266. sent = sess._socket.readsentmsg()
  267. self.assertEqual(sent, b'\x00\x00\x00<\x00:{"group": "mygroup", "type": "subscribe", "instance": "*"}')
  268. sess.group_subscribe("mygroup", "my_instance")
  269. sent = sess._socket.readsentmsg()
  270. self.assertEqual(sent, b'\x00\x00\x00F\x00D{"group": "mygroup", "type": "subscribe", "instance": "my_instance"}')
  271. def test_group_unsubscribe(self):
  272. sess = MySession()
  273. sess.group_unsubscribe("mygroup")
  274. sent = sess._socket.readsentmsg()
  275. self.assertEqual(sent, b'\x00\x00\x00>\x00<{"group": "mygroup", "type": "unsubscribe", "instance": "*"}')
  276. sess.group_unsubscribe("mygroup")
  277. sent = sess._socket.readsentmsg()
  278. self.assertEqual(sent, b'\x00\x00\x00>\x00<{"group": "mygroup", "type": "unsubscribe", "instance": "*"}')
  279. sess.group_unsubscribe("mygroup", "my_instance")
  280. sent = sess._socket.readsentmsg()
  281. self.assertEqual(sent, b'\x00\x00\x00H\x00F{"group": "mygroup", "type": "unsubscribe", "instance": "my_instance"}')
  282. def test_group_sendmsg(self):
  283. sess = MySession()
  284. self.assertEqual(sess._sequence, 1)
  285. sess.group_sendmsg({ 'hello': 'a' }, "my_group")
  286. sent = sess._socket.readsentmsg()
  287. self.assertEqual(sent, b'\x00\x00\x00p\x00`{"from": "test_name", "seq": 2, "to": "*", "instance": "*", "group": "my_group", "type": "send"}{"hello": "a"}')
  288. self.assertEqual(sess._sequence, 2)
  289. sess.group_sendmsg({ 'hello': 'a' }, "my_group", "my_instance")
  290. sent = sess._socket.readsentmsg()
  291. self.assertEqual(sent, b'\x00\x00\x00z\x00j{"from": "test_name", "seq": 3, "to": "*", "instance": "my_instance", "group": "my_group", "type": "send"}{"hello": "a"}')
  292. self.assertEqual(sess._sequence, 3)
  293. sess.group_sendmsg({ 'hello': 'a' }, "your_group", "your_instance")
  294. sent = sess._socket.readsentmsg()
  295. self.assertEqual(sent, b'\x00\x00\x00~\x00n{"from": "test_name", "seq": 4, "to": "*", "instance": "your_instance", "group": "your_group", "type": "send"}{"hello": "a"}')
  296. self.assertEqual(sess._sequence, 4)
  297. def test_group_recvmsg(self):
  298. # must this one do anything except not return messages with
  299. # no header?
  300. pass
  301. def test_group_reply(self):
  302. sess = MySession()
  303. sess.group_reply({ 'from': 'me', 'group': 'our_group', 'instance': 'other_instance', 'seq': 4}, {"hello": "a"})
  304. sent = sess._socket.readsentmsg();
  305. self.assertEqual(sent, b'\x00\x00\x00\x8b\x00{{"from": "test_name", "seq": 2, "to": "me", "instance": "other_instance", "reply": 4, "group": "our_group", "type": "send"}{"hello": "a"}')
  306. sess.group_reply({ 'from': 'me', 'group': 'our_group', 'instance': 'other_instance', 'seq': 9}, {"hello": "a"})
  307. sent = sess._socket.readsentmsg();
  308. self.assertEqual(sent, b'\x00\x00\x00\x8b\x00{{"from": "test_name", "seq": 3, "to": "me", "instance": "other_instance", "reply": 9, "group": "our_group", "type": "send"}{"hello": "a"}')
  309. def test_timeout(self):
  310. if "BIND10_TEST_SOCKET_FILE" not in os.environ:
  311. self.assertEqual("", "This test can only run if the value BIND10_TEST_SOCKET_FILE is set in the environment")
  312. TEST_SOCKET_FILE = os.environ["BIND10_TEST_SOCKET_FILE"]
  313. # create a read domain socket to pass into the session
  314. s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  315. if os.path.exists(TEST_SOCKET_FILE):
  316. os.remove(TEST_SOCKET_FILE)
  317. s1.bind(TEST_SOCKET_FILE)
  318. try:
  319. s1.listen(1)
  320. s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  321. s2.connect(TEST_SOCKET_FILE)
  322. sess = MySession(1, s2)
  323. # set timeout to 100 msec, so test does not take too long
  324. sess.set_timeout(100)
  325. self.assertRaises(SessionTimeout, sess.group_recvmsg, False)
  326. finally:
  327. os.remove(TEST_SOCKET_FILE)
  328. if __name__ == "__main__":
  329. unittest.main()