Browse Source

refactored the recv() code; fixed short reads on length bytes, removed the catch-all (currently there is a bit of special casing going on to keep the calling modules working, we may need to take a look at that), and removed the class variables, where data was stored but in the end that data was only used locally

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac312@2906 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 14 years ago
parent
commit
a951cdc5ee
2 changed files with 47 additions and 45 deletions
  1. 34 34
      src/lib/python/isc/cc/session.py
  2. 13 11
      src/lib/python/isc/cc/tests/session_test.py

+ 34 - 34
src/lib/python/isc/cc/session.py

@@ -16,6 +16,7 @@
 import sys
 import socket
 import struct
+import errno
 import os
 import threading
 import bind10_config
@@ -33,8 +34,6 @@ class Session:
     def __init__(self, socket_file=None):
         self._socket = None
         self._lname = None
-        self._recvbuffer = bytearray()
-        self._recvlength = 0
         self._sequence = 1
         self._closed = False
         self._queue = []
@@ -121,6 +120,27 @@ class Session:
                     return isc.cc.message.from_wire(data[2:header_length+2]), None
             return None, None
 
+    def _receive_bytes(self, length, nonblock):
+        """Returns a bytearray of length bytes as read from the socket.
+           Raises a ProtocolError if it reads 0 bytes, unless nonblock
+           is True.
+           Re-raises errors raised by recv().
+           Returns either a bytearray of length bytes, or None (if
+           nonblock is True, and less than length bytes of data is
+           available)
+        """
+        data = bytearray()
+        while length > 0:
+            new_data = self._socket.recv(length)
+            if len(new_data) == 0: # server closed connection
+                if nonblock:
+                    return None
+                else:
+                    raise ProtocolError("Read of 0 bytes: connection closed")
+            data += new_data
+            length -= len(new_data)
+        return data
+
     def _receive_full_buffer(self, nonblock):
         if nonblock:
             self._socket.setblocking(0)
@@ -131,39 +151,19 @@ class Session:
             else:
                 self._socket.settimeout(self._socket_timeout)
 
-        if self._recvlength == 0:
-            length = 4
-            length -= len(self._recvbuffer)
-            try:
-                data = self._socket.recv(length)
-            except socket.timeout:
-                raise SessionTimeout("recv() on cc session timed out")
-            except:
-                return None
-            if data == "": # server closed connection
-                raise ProtocolError("Read of 0 bytes: connection closed")
-            self._recvbuffer += data
-            if len(self._recvbuffer) < 4:
-                return None
-            self._recvlength = struct.unpack('>I', self._recvbuffer)[0]
-            self._recvbuffer = bytearray()
-
-        length = self._recvlength - len(self._recvbuffer)
-        while (length > 0):
-            try:
-                data = self._socket.recv(length)
-            except socket.timeout:
-                raise SessionTimeout("recv() on cc session timed out")
-            except:
+        try:
+            data = self._receive_bytes(4, nonblock)
+            if data is not None:
+                data_length = struct.unpack('>I', data)[0]
+                data = self._receive_bytes(data_length, nonblock)
+            return (data)
+        except socket.timeout:
+            raise SessionTimeout("recv() on cc session timed out")
+        except socket.error as se:
+            if se.errno == errno.EINTR or \
+               (nonblock and se.errno) == errno.EAGAIN:
                 return None
-            if data == "": # server closed connection
-                raise ProtocolError("Read of 0 bytes: connection closed")
-            self._recvbuffer += data
-            length -= len(data)
-        data = self._recvbuffer
-        self._recvbuffer = bytearray()
-        self._recvlength = 0
-        return (data)
+            raise se
 
     def _next_sequence(self):
         self._sequence += 1

+ 13 - 11
src/lib/python/isc/cc/tests/session_test.py

@@ -67,6 +67,8 @@ class MySocket():
         return result
 
     def recv(self, length):
+        if len(self.recvqueue) == 0:
+            return bytes()
         if length > len(self.recvqueue):
             raise Exception("Buffer underrun in test, does the test provide the right data?")
         result = self.recvqueue[:length]
@@ -192,10 +194,10 @@ class testSession(unittest.TestCase):
         # get no message without asking for a specific sequence number reply
         self.assertFalse(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
-        env, msg = sess.recvmsg(False)
+        env, msg = sess.recvmsg(True)
         self.assertEqual(None, env)
         self.assertTrue(sess.has_queued_msgs())
-        env, msg = sess.recvmsg(False, 1)
+        env, msg = sess.recvmsg(True, 1)
         self.assertEqual({'to': 'someone', 'reply': 1}, env)
         self.assertEqual({"hello": "a"}, msg)
         self.assertFalse(sess.has_queued_msgs())
@@ -204,11 +206,11 @@ class testSession(unittest.TestCase):
         # then ask for the one that is there
         self.assertFalse(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
-        env, msg = sess.recvmsg(False, 2)
+        env, msg = sess.recvmsg(True, 2)
         self.assertEqual(None, env)
         self.assertEqual(None, msg)
         self.assertTrue(sess.has_queued_msgs())
-        env, msg = sess.recvmsg(False, 1)
+        env, msg = sess.recvmsg(True, 1)
         self.assertEqual({'to': 'someone', 'reply': 1}, env)
         self.assertEqual({"hello": "a"}, msg)
         self.assertFalse(sess.has_queued_msgs())
@@ -217,11 +219,11 @@ class testSession(unittest.TestCase):
         # then ask for any message
         self.assertFalse(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone', 'reply': 1}, {"hello": "a"})
-        env, msg = sess.recvmsg(False, 2)
+        env, msg = sess.recvmsg(True, 2)
         self.assertEqual(None, env)
         self.assertEqual(None, msg)
         self.assertTrue(sess.has_queued_msgs())
-        env, msg = sess.recvmsg(False, 1)
+        env, msg = sess.recvmsg(True, 1)
         self.assertEqual({'to': 'someone', 'reply': 1}, env)
         self.assertEqual({"hello": "a"}, msg)
         self.assertFalse(sess.has_queued_msgs())
@@ -233,16 +235,16 @@ class testSession(unittest.TestCase):
         # then ask for any message (get the second)
         self.assertFalse(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
-        env, msg = sess.recvmsg(False, 2)
+        env, msg = sess.recvmsg(True, 2)
         self.assertEqual(None, env)
         self.assertEqual(None, msg)
         self.assertTrue(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
-        env, msg = sess.recvmsg(False, 1)
+        env, msg = sess.recvmsg(True, 1)
         self.assertEqual({'to': 'someone', 'reply': 1 }, env)
         self.assertEqual({"hello": "a"}, msg)
         self.assertFalse(sess.has_queued_msgs())
-        env, msg = sess.recvmsg(False)
+        env, msg = sess.recvmsg(True)
         self.assertEqual({'to': 'someone'}, env)
         self.assertEqual({"hello": "b"}, msg)
         self.assertFalse(sess.has_queued_msgs())
@@ -253,11 +255,11 @@ class testSession(unittest.TestCase):
         self.assertFalse(sess.has_queued_msgs())
         sess._socket.addrecv({'to': 'someone' }, {'hello': 'b'})
         sess._socket.addrecv({'to': 'someone', 'reply': 1}, {'hello': 'a'})
-        env, msg = sess.recvmsg(False, 1)
+        env, msg = sess.recvmsg(True, 1)
         self.assertEqual({'to': 'someone', 'reply': 1}, env)
         self.assertEqual({"hello": "a"}, msg)
         self.assertTrue(sess.has_queued_msgs())
-        env, msg = sess.recvmsg(False)
+        env, msg = sess.recvmsg(True)
         self.assertEqual({'to': 'someone'}, env)
         self.assertEqual({"hello": "b"}, msg)
         self.assertFalse(sess.has_queued_msgs())