Browse Source

timeouts in python cc session too + test

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac296@2653 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 14 years ago
parent
commit
e8acfc19af

+ 18 - 1
src/lib/python/isc/cc/session.py

@@ -29,6 +29,10 @@ class SessionError(Exception): pass
 class Session:
     def __init__(self, socket_file=None):
         self._socket = None
+        # store the current timeout value in seconds (the way
+        # settimeout() wants them, our API takes milliseconds
+        # so that it is consistent with the C++ version)
+        self._socket_timeout = 4;
         self._lname = None
         self._recvbuffer = bytearray()
         self._recvlength = 0
@@ -36,7 +40,7 @@ class Session:
         self._closed = False
         self._queue = []
         self._lock = threading.RLock()
-
+        
         if socket_file is None:
             if "BIND10_MSGQ_SOCKET_FILE" in os.environ:
                 self.socket_file = os.environ["BIND10_MSGQ_SOCKET_FILE"]
@@ -123,6 +127,10 @@ class Session:
             self._socket.setblocking(0)
         else:
             self._socket.setblocking(1)
+            if self._socket_timeout == 0.0:
+                self._socket.settimeout(None)
+            else:
+                self._socket.settimeout(self._socket_timeout)
 
         if self._recvlength == 0:
             length = 4
@@ -208,6 +216,15 @@ class Session:
         }, isc.cc.message.to_wire(msg))
         return seq
 
+    def set_timeout(self, milliseconds):
+        """Sets the socket timeout for blocking reads to the given
+           number of milliseconds"""
+        self._socket_timeout = milliseconds / 1000.0
+
+    def get_timeout(self):
+        """Returns the current timeout for blocking reads (in milliseconds)"""
+        return self._socket_timeout * 1000.0
+
 if __name__ == "__main__":
     import doctest
     doctest.testmod()

+ 1 - 0
src/lib/python/isc/cc/tests/Makefile.am

@@ -11,5 +11,6 @@ check-local:
 	for pytest in $(PYTESTS) ; do \
 	echo Running test: $$pytest ; \
 	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python \
+	BIND10_TEST_SOCKET_FILE=$(builddir)/test_socket.sock \
 	$(PYCOVERAGE) $(abs_srcdir)/$$pytest ; \
 	done

+ 39 - 9
src/lib/python/isc/cc/tests/session_test.py

@@ -89,13 +89,20 @@ class MySocket():
         if msg:
             self.recvqueue.extend(msg)
 
+    def settimeout(self, val):
+        pass
+
+    def gettimeout(self):
+        return 0
+
 #
 # We subclass the Session class we're testing here, only
 # to override the __init__() method, which wants a socket,
 # and we need to use our fake socket
 class MySession(Session):
-    def __init__(self, port=9912):
+    def __init__(self, port=9912, s = None):
         self._socket = None
+        self._socket_timeout = 1
         self._lname = None
         self._recvbuffer = bytearray()
         self._recvlength = 0
@@ -104,13 +111,16 @@ class MySession(Session):
         self._queue = []
         self._lock = threading.RLock()
 
-        try:
-            self._socket = MySocket(socket.AF_INET, socket.SOCK_STREAM)
-            self._socket.connect(tuple(['127.0.0.1', port]))
-            self._lname = "test_name"
-            # testing getlname here isn't useful, code removed
-        except socket.error as se:
-                raise SessionError(se)
+        if s is not None:
+            self._socket = s
+        else:
+            try:
+                self._socket = MySocket(socket.AF_INET, socket.SOCK_STREAM)
+                self._socket.connect(tuple(['127.0.0.1', port]))
+                self._lname = "test_name"
+                # testing getlname here isn't useful, code removed
+            except socket.error as se:
+                    raise SessionError(se)
 
 class testSession(unittest.TestCase):
 
@@ -323,7 +333,27 @@ class testSession(unittest.TestCase):
         sess.group_reply({ 'from': 'me', 'group': 'our_group', 'instance': 'other_instance', 'seq': 9}, {"hello": "a"})
         sent = sess._socket.readsentmsg();
         self.assertEqual(sent, b'\x00\x00\x00\x8b\x00{{"from": "test_name", "seq": 3, "to": "me", "instance": "other_instance", "reply": 9, "group": "our_group", "type": "send"}{"hello": "a"}')
-        
+
+    def test_timeout(self):
+        if "BIND10_TEST_SOCKET_FILE" not in os.environ:
+            self.assertEqual("", "This test can only run if the value BIND10_TEST_SOCKET_FILE is set in the environment")
+        TEST_SOCKET_FILE = os.environ["BIND10_TEST_SOCKET_FILE"]
+
+        # create a read domain socket to pass into the session
+        s1 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        if os.path.exists(TEST_SOCKET_FILE):
+            os.remove(TEST_SOCKET_FILE)
+        s1.bind(TEST_SOCKET_FILE)
+        s1.listen(1)
+
+        s2 = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        s2.connect(TEST_SOCKET_FILE)
+        sess = MySession(1, s2)
+        # set timeout to 100 msec, so test does not take too long
+        sess.set_timeout(100)
+        env, msg = sess.group_recvmsg(False)
+        self.assertEqual(None, env)
+        self.assertEqual(None, msg)
         
 if __name__ == "__main__":
     unittest.main()