session_test.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. # Copyright (C) 2010 Internet Systems Consortium.
  2. #
  3. # Permission to use, copy, modify, and distribute this software for any
  4. # purpose with or without fee is hereby granted, provided that the above
  5. # copyright notice and this permission notice appear in all copies.
  6. #
  7. # THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
  8. # DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
  9. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
  10. # INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
  11. # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
  12. # FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
  13. # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
  14. # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. #
  16. # Tests for the Session class (isc.cc.session), using a fake socket
  17. #
  18. import unittest
  19. import os
  20. import json
  21. from isc.cc.session import *
  22. # our fake socket, where we can read and insert messages
  23. class MySocket():
  24. def __init__(self, family, type):
  25. self.family = family
  26. self.type = type
  27. self.recvqueue = bytearray()
  28. self.sendqueue = bytearray()
  29. self._blocking = True
  30. self.send_limit = None
  31. def connect(self, to):
  32. pass
  33. def close(self):
  34. pass
  35. def setblocking(self, val):
  36. self._blocking = val
  37. def send(self, data):
  38. # If the upper limit is specified, only "send" up to the specified
  39. # limit
  40. if self.send_limit is not None and len(data) > self.send_limit:
  41. self.sendqueue.extend(data[0:self.send_limit])
  42. return self.send_limit
  43. else:
  44. self.sendqueue.extend(data)
  45. return len(data)
  46. def readsent(self, length):
  47. if length > len(self.sendqueue):
  48. raise Exception("readsent(" + str(length) + ") called, but only " + str(len(self.sendqueue)) + " in queue")
  49. result = self.sendqueue[:length]
  50. del self.sendqueue[:length]
  51. return result
  52. def readsentmsg(self):
  53. """return bytearray of the full message include length specifiers"""
  54. result = bytearray()
  55. length_buf = self.readsent(4)
  56. result.extend(length_buf)
  57. length = struct.unpack('>I', length_buf)[0]
  58. header_length_buf = self.readsent(2)
  59. header_length = struct.unpack('>H', header_length_buf)[0]
  60. result.extend(header_length_buf)
  61. data_length = length - 2 - header_length
  62. result.extend(self.readsent(header_length))
  63. result.extend(self.readsent(data_length))
  64. return result
  65. def readsentmsg_parsed(self):
  66. length_buf = self.readsent(4)
  67. length = struct.unpack('>I', length_buf)[0]
  68. header_length_buf = self.readsent(2)
  69. header_length = struct.unpack('>H', header_length_buf)[0]
  70. data_length = length - 2 - header_length
  71. env = json.loads(self.readsent(header_length).decode('utf-8'), strict=False)
  72. if (data_length > 0):
  73. msg = json.loads(self.readsent(data_length).decode('utf-8'), strict=False)
  74. else:
  75. msg = {}
  76. return (env, msg)
  77. def recv(self, length):
  78. if len(self.recvqueue) == 0:
  79. if self._blocking:
  80. return bytes()
  81. else:
  82. raise socket.error(errno.EAGAIN, "Resource temporarily unavailable")
  83. if length > len(self.recvqueue):
  84. raise Exception("Buffer underrun in test, does the test provide the right data?")
  85. result = self.recvqueue[:length]
  86. del self.recvqueue[:length]
  87. #print("[XX] returning: " + str(result))
  88. #print("[XX] queue now: " + str(self.recvqueue))
  89. return result
  90. def addrecv(self, env, msg = None):
  91. if type(env) == dict:
  92. env = isc.cc.message.to_wire(env)
  93. if type(msg) == dict:
  94. msg = isc.cc.message.to_wire(msg)
  95. length = 2 + len(env);
  96. if msg:
  97. length += len(msg)
  98. self.recvqueue.extend(struct.pack("!I", length))
  99. self.recvqueue.extend(struct.pack("!H", len(env)))
  100. self.recvqueue.extend(env)
  101. if msg:
  102. self.recvqueue.extend(msg)
  103. def settimeout(self, val):
  104. pass
  105. def gettimeout(self):
  106. return 0
  107. def set_send_limit(self, limit):
  108. '''Specify the upper limit of the transmittable data at once.
  109. By default, the send() method of this class "sends" all given data.
  110. If this method is called and the its parameter is not None,
  111. subsequent calls to send() will only transmit the specified amount
  112. of data. This can be used to emulate the situation where send()
  113. on a real socket object results in partial write.
  114. '''
  115. self.send_limit = limit
  116. #
  117. # We subclass the Session class we're testing here, only
  118. # to override the __init__() method, which wants a socket,
  119. # and we need to use our fake socket
  120. class MySession(Session):
  121. def __init__(self, port=9912, s=None):
  122. self._socket = None
  123. self._socket_timeout = 1
  124. self._lname = None
  125. self._recvbuffer = bytearray()
  126. self._recv_len_size = 0
  127. self._recv_size = 0
  128. self._sequence = 1
  129. self._closed = False
  130. self._queue = []
  131. self._lock = threading.RLock()
  132. if s is not None:
  133. self._socket = s
  134. else:
  135. try:
  136. self._socket = MySocket(socket.AF_INET, socket.SOCK_STREAM)
  137. self._socket.connect(tuple(['127.0.0.1', port]))
  138. self._lname = "test_name"
  139. # testing getlname here isn't useful, code removed
  140. except socket.error as se:
  141. raise SessionError(se)
  142. class testSession(unittest.TestCase):
  143. def test_session_close(self):
  144. sess = MySession()
  145. self.assertEqual("test_name", sess.lname)
  146. sess.close()
  147. self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  148. def test_env_too_large(self):
  149. sess = MySession()
  150. largeenv = { "a": "b"*65535 }
  151. self.assertRaises(ProtocolError, sess.sendmsg, largeenv, {"hello": "a"})
  152. def test_session_sendmsg(self):
  153. sess = MySession()
  154. sess.sendmsg({}, {"hello": "a"})
  155. sent = sess._socket.readsentmsg();
  156. self.assertEqual(sent, b'\x00\x00\x00\x12\x00\x02{}{"hello": "a"}')
  157. sess.close()
  158. self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  159. def test_session_sendmsg2(self):
  160. sess = MySession()
  161. sess.sendmsg({'to': 'someone', 'reply': 1}, {"hello": "a"})
  162. sent = sess._socket.readsentmsg();
  163. #print(sent)
  164. #self.assertRaises(SessionError, sess.sendmsg, {}, {"hello": "a"})
  165. def test_session_sendmsg_shortwrite(self):
  166. sess = MySession()
  167. # Specify the upper limit of the size that can be transmitted at
  168. # a single send() call on the faked socket (10 is an arbitrary choice,
  169. # just reasonably small).
  170. sess._socket.set_send_limit(10)
  171. sess.sendmsg({'to': 'someone', 'reply': 1}, {"hello": "a"})
  172. # The complete message should still have been transmitted in the end.
  173. sent = sess._socket.readsentmsg();
  174. def recv_and_compare(self, session, bytes, env, msg):
  175. """Adds bytes to the recvqueue (which will be read by the
  176. session object, and compare the resultinv env and msg to
  177. the ones given."""
  178. session._socket.addrecv(bytes)
  179. s_env, s_msg = session.recvmsg(False)
  180. self.assertEqual(env, s_env)
  181. self.assertEqual(msg, s_msg)
  182. # clear the recv buffer in case a malformed message left garbage
  183. # (actually, shouldn't that case provide some error instead of
  184. # None?)
  185. session._socket.recvqueue = bytearray()
  186. def test_session_recvmsg(self):
  187. sess = MySession()
  188. # {'to': "someone"}, {"hello": "a"}
  189. #self.recv_and_compare(sess,
  190. # b'\x00\x00\x00\x1f\x00\x10Skan\x02to(\x07someoneSkan\x05hello(\x01a',
  191. # {'to': "someone"}, {"hello": "a"})
  192. # 'malformed' messages
  193. # shouldn't some of these raise exceptions?
  194. #self.recv_and_compare(sess,
  195. # b'\x00',
  196. # None, None)
  197. #self.recv_and_compare(sess,
  198. # b'\x00\x00\x00\x10',
  199. # None, None)
  200. #self.recv_and_compare(sess,
  201. # b'\x00\x00\x00\x02\x00\x00',
  202. # None, None)
  203. #self.recv_and_compare(sess,
  204. # b'\x00\x00\x00\x02\x00\x02',
  205. # None, None)
  206. #self.recv_and_compare(sess,
  207. # b'',
  208. # None, None)
  209. # need to clear
  210. sess._socket.recvqueue = bytearray()
  211. # 'queueing' system
  212. # sending message {'to': 'someone', 'reply': 1}, {"hello": "a"}
  213. #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}")
  214. # get no message without asking for a specific sequence number reply
  215. self.assertFalse(sess.has_queued_msgs())
  216. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  217. env, msg = sess.recvmsg(True)
  218. self.assertEqual(None, env)
  219. self.assertTrue(sess.has_queued_msgs())
  220. env, msg = sess.recvmsg(True, 1)
  221. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  222. self.assertEqual({"hello": "a"}, msg)
  223. self.assertFalse(sess.has_queued_msgs())
  224. # ask for a differe sequence number reply (that doesn't exist)
  225. # then ask for the one that is there
  226. self.assertFalse(sess.has_queued_msgs())
  227. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  228. env, msg = sess.recvmsg(True, 2)
  229. self.assertEqual(None, env)
  230. self.assertEqual(None, msg)
  231. self.assertTrue(sess.has_queued_msgs())
  232. env, msg = sess.recvmsg(True, 1)
  233. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  234. self.assertEqual({"hello": "a"}, msg)
  235. self.assertFalse(sess.has_queued_msgs())
  236. # ask for a differe sequence number reply (that doesn't exist)
  237. # then ask for any message
  238. self.assertFalse(sess.has_queued_msgs())
  239. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
  240. env, msg = sess.recvmsg(True, 2)
  241. self.assertEqual(None, env)
  242. self.assertEqual(None, msg)
  243. self.assertTrue(sess.has_queued_msgs())
  244. env, msg = sess.recvmsg(True, 1)
  245. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  246. self.assertEqual({"hello": "a"}, msg)
  247. self.assertFalse(sess.has_queued_msgs())
  248. #print("sending message {'to': 'someone', 'reply': 1}, {'hello': 'a'}")
  249. # ask for a differe sequence number reply (that doesn't exist)
  250. # send a new message, ask for specific message (get the first)
  251. # then ask for any message (get the second)
  252. self.assertFalse(sess.has_queued_msgs())
  253. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  254. env, msg = sess.recvmsg(True, 2)
  255. self.assertEqual(None, env)
  256. self.assertEqual(None, msg)
  257. self.assertTrue(sess.has_queued_msgs())
  258. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  259. env, msg = sess.recvmsg(True, 1)
  260. self.assertEqual({'to': 'someone', 'reply': 1 }, env)
  261. self.assertEqual({"hello": "a"}, msg)
  262. self.assertFalse(sess.has_queued_msgs())
  263. env, msg = sess.recvmsg(True)
  264. self.assertEqual({'to': 'someone'}, env)
  265. self.assertEqual({"hello": "b"}, msg)
  266. self.assertFalse(sess.has_queued_msgs())
  267. # send a message, then one with specific reply value
  268. # ask for that specific message (get the second)
  269. # then ask for any message (get the first)
  270. self.assertFalse(sess.has_queued_msgs())
  271. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  272. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  273. env, msg = sess.recvmsg(True, 1)
  274. self.assertEqual({'to': 'someone', 'reply': 1}, env)
  275. self.assertEqual({"hello": "a"}, msg)
  276. self.assertTrue(sess.has_queued_msgs())
  277. env, msg = sess.recvmsg(True)
  278. self.assertEqual({'to': 'someone'}, env)
  279. self.assertEqual({"hello": "b"}, msg)
  280. self.assertFalse(sess.has_queued_msgs())
  281. def test_recv_bad_msg(self):
  282. sess = MySession()
  283. self.assertFalse(sess.has_queued_msgs())
  284. sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
  285. sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
  286. # mangle the bytes a bit
  287. sess._socket.recvqueue[5] = sess._socket.recvqueue[5] - 2
  288. sess._socket.recvqueue = sess._socket.recvqueue[:-2]
  289. self.assertRaises(SessionError, sess.recvmsg, True, 1)
  290. def test_next_sequence(self):
  291. sess = MySession()
  292. self.assertEqual(sess._sequence, 1)
  293. self.assertEqual(sess._next_sequence(), 2)
  294. self.assertEqual(sess._sequence, 2)
  295. sess._sequence = 47805
  296. self.assertEqual(sess._sequence, 47805)
  297. self.assertEqual(sess._next_sequence(), 47806)
  298. self.assertEqual(sess._sequence, 47806)
  299. def test_group_subscribe(self):
  300. sess = MySession()
  301. sess.group_subscribe("mygroup")
  302. sent = sess._socket.readsentmsg_parsed()
  303. self.assertEqual(sent, ({"group": "mygroup", "type": "subscribe",
  304. "instance": "*"}, {}))
  305. sess.group_subscribe("mygroup")
  306. sent = sess._socket.readsentmsg_parsed()
  307. self.assertEqual(sent, ({"group": "mygroup", "type": "subscribe",
  308. "instance": "*"}, {}))
  309. sess.group_subscribe("mygroup", "my_instance")
  310. sent = sess._socket.readsentmsg_parsed()
  311. self.assertEqual(sent, ({"group": "mygroup", "type": "subscribe",
  312. "instance": "my_instance"}, {}))
  313. def test_group_unsubscribe(self):
  314. sess = MySession()
  315. sess.group_unsubscribe("mygroup")
  316. sent = sess._socket.readsentmsg_parsed()
  317. self.assertEqual(sent, ({"group": "mygroup", "type": "unsubscribe",
  318. "instance": "*"}, {}))
  319. sess.group_unsubscribe("mygroup")
  320. sent = sess._socket.readsentmsg_parsed()
  321. self.assertEqual(sent, ({"group": "mygroup", "type": "unsubscribe",
  322. "instance": "*"}, {}))
  323. sess.group_unsubscribe("mygroup", "my_instance")
  324. sent = sess._socket.readsentmsg_parsed()
  325. self.assertEqual(sent, ({"group": "mygroup", "type": "unsubscribe",
  326. "instance": "my_instance"}, {}))
  327. def test_group_sendmsg(self):
  328. sess = MySession()
  329. self.assertEqual(sess._sequence, 1)
  330. sess.group_sendmsg({ 'hello': 'a' }, "my_group")
  331. sent = sess._socket.readsentmsg_parsed()
  332. self.assertEqual(sent, ({"from": "test_name", "seq": 2, "to": "*",
  333. "instance": "*", "group": "my_group",
  334. "type": "send"}, {"hello": "a"}))
  335. self.assertEqual(sess._sequence, 2)
  336. sess.group_sendmsg({ 'hello': 'a' }, "my_group", "my_instance")
  337. sent = sess._socket.readsentmsg_parsed()
  338. self.assertEqual(sent, ({"from": "test_name", "seq": 3, "to": "*", "instance": "my_instance", "group": "my_group", "type": "send"}, {"hello": "a"}))
  339. self.assertEqual(sess._sequence, 3)
  340. sess.group_sendmsg({ 'hello': 'a' }, "your_group", "your_instance")
  341. sent = sess._socket.readsentmsg_parsed()
  342. self.assertEqual(sent, ({"from": "test_name", "seq": 4, "to": "*", "instance": "your_instance", "group": "your_group", "type": "send"}, {"hello": "a"}))
  343. self.assertEqual(sess._sequence, 4)
  344. def test_group_recvmsg(self):
  345. # must this one do anything except not return messages with
  346. # no header?
  347. pass
  348. def test_group_reply(self):
  349. sess = MySession()
  350. sess.group_reply({ 'from': 'me', 'group': 'our_group',
  351. 'instance': 'other_instance', 'seq': 4},
  352. {"hello": "a"})
  353. sent = sess._socket.readsentmsg_parsed();
  354. self.assertEqual(sent, ({"from": "test_name", "seq": 2,
  355. "to": "me", "instance": "other_instance",
  356. "reply": 4, "group": "our_group",
  357. "type": "send"},
  358. {"hello": "a"}))
  359. sess.group_reply({ 'from': 'me', 'group': 'our_group',
  360. 'instance': 'other_instance', 'seq': 9},
  361. {"hello": "a"})
  362. sent = sess._socket.readsentmsg_parsed();
  363. self.assertEqual(sent, ({"from": "test_name", "seq": 3,
  364. "to": "me", "instance": "other_instance",
  365. "reply": 9, "group": "our_group",
  366. "type": "send"},
  367. {"hello": "a"}))
  368. def test_timeout(self):
  369. if "BIND10_TEST_SOCKET_FILE" not in os.environ:
  370. self.assertEqual("", "This test can only run if the value BIND10_TEST_SOCKET_FILE is set in the environment")
  371. TEST_SOCKET_FILE = os.environ["BIND10_TEST_SOCKET_FILE"]
  372. # create a read domain socket to pass into the session
  373. s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  374. if os.path.exists(TEST_SOCKET_FILE):
  375. os.remove(TEST_SOCKET_FILE)
  376. s1.bind(TEST_SOCKET_FILE)
  377. try:
  378. s1.listen(1)
  379. s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
  380. s2.connect(TEST_SOCKET_FILE)
  381. sess = MySession(1, s2)
  382. # set timeout to 100 msec, so test does not take too long
  383. sess.set_timeout(100)
  384. self.assertRaises(SessionTimeout, sess.group_recvmsg, False)
  385. finally:
  386. os.remove(TEST_SOCKET_FILE)
# Run all tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()