Browse Source

update

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac312@2916 e5f2f494-b856-4b98-b285-d166d9295462
Jelte Jansen 14 years ago
parent
commit
d6585dcead
2 changed files with 75 additions and 28 deletions
  1. 67 25
      src/lib/python/isc/cc/session.py
  2. 8 3
      src/lib/python/isc/cc/tests/session_test.py

+ 67 - 25
src/lib/python/isc/cc/session.py

@@ -39,6 +39,8 @@ class Session:
         self._queue = []
         self._lock = threading.RLock()
         self.set_timeout(self.MSGQ_DEFAULT_TIMEOUT);
+        self._recv_len_size = 0
+        self._recv_size = 0
 
         if socket_file is None:
             if "BIND10_MSGQ_SOCKET_FILE" in os.environ:
@@ -120,26 +122,42 @@ class Session:
                     return isc.cc.message.from_wire(data[2:header_length+2]), None
             return None, None
 
-    def _receive_bytes(self, length, nonblock):
-        """Returns a bytearray of length bytes as read from the socket.
-           Raises a ProtocolError if it reads 0 bytes, unless nonblock
-           is True.
-           Re-raises errors raised by recv().
-           Returns either a bytearray of length bytes, or None (if
-           nonblock is True, and less than length bytes of data is
-           available)
-        """
-        data = bytearray()
-        while length > 0:
-            new_data = self._socket.recv(length)
-            if len(new_data) == 0: # server closed connection
-                if nonblock:
-                    return None
-                else:
-                    raise ProtocolError("Read of 0 bytes: connection closed")
-            data += new_data
-            length -= len(new_data)
+    def _receive_bytes(self, size):
+        """Try to get size bytes of data from the socket.
+           Raises a ProtocolError if the size is 0.
+           Raises any error from recv().
+           Returns whatever data was available (if >0 bytes).
+           """
+        data = self._socket.recv(size)
+        if len(data) == 0: # server closed connection
+            raise ProtocolError("Read of 0 bytes: connection closed")
         return data
+        
+    def _receive_len_data(self):
+        """Reads self._recv_len_size bytes of data from the socket into
+           self._recv_len_data
+           This is done through class variables so in the case of
+           an EAGAIN we can continue on a subsequent call.
+           Raises a ProtocolError, a socket.error (which may be
+           timeout or eagain), or reads until we have all data we need.
+           """
+        while self._recv_len_size > 0:
+            new_data = self._receive_bytes(self._recv_len_size)
+            self._recv_len_data += new_data
+            self._recv_len_size -= len(new_data)
+
+    def _receive_data(self):
+        """Reads self._recv_size bytes of data from the socket into
+           self._recv_data.
+           This is done through class variables so in the case of
+           an EAGAIN we can continue on a subsequent call.
+           Raises a ProtocolError, a socket.error (which may be
+           timeout or eagain), or reads until we have all data we need.
+        """
+        while self._recv_size > 0:
+            new_data = self._receive_bytes(self._recv_size)
+            self._recv_data += new_data
+            self._recv_size -= len(new_data)
 
     def _receive_full_buffer(self, nonblock):
         if nonblock:
@@ -152,16 +170,40 @@ class Session:
                 self._socket.settimeout(self._socket_timeout)
 
         try:
-            data = self._receive_bytes(4, nonblock)
-            if data is not None:
-                data_length = struct.unpack('>I', data)[0]
-                data = self._receive_bytes(data_length, nonblock)
+            # we might be in a call following an EAGAIN, in which case
+            # we simply continue. In the first case, either
+            # recv_size or recv_len size are not zero
+            if self._recv_size == 0:
+                if self._recv_len_size == 0:
+                    # both zero, start a new full read
+                    self._recv_len_size = 4
+                    self._recv_len_data = bytearray()
+                self._receive_len_data()
+
+                self._recv_size = struct.unpack('>I', self._recv_len_data)[0]
+                self._recv_data = bytearray()
+            self._receive_data()
+
+            # no EAGAIN, so copy data and reset internal counters
+            data = self._recv_data
+
+            self._recv_len_size = 0
+            self._recv_size = 0
+
             return (data)
+
         except socket.timeout:
             raise SessionTimeout("recv() on cc session timed out")
         except socket.error as se:
-            if se.errno == errno.EINTR or \
-               (nonblock and se.errno) == errno.EAGAIN:
+            # Only keep data in case of EAGAIN
+            if se.errno == errno.EAGAIN:
+                return None
+            # unknown state otherwise, best to drop data
+            self._recv_len_size = 0
+            self._recv_size = 0
+            # ctrl-c can result in EINTR, return None to prevent
+            # stacktrace output
+            if se.errno == errno.EINTR:
                 return None
             raise se
 

+ 8 - 3
src/lib/python/isc/cc/tests/session_test.py

@@ -28,6 +28,7 @@ class MySocket():
         self.type = type
         self.recvqueue = bytearray()
         self.sendqueue = bytearray()
+        self._blocking = True
 
     def connect(self, to):
         pass
@@ -36,7 +37,7 @@ class MySocket():
         pass
 
     def setblocking(self, val):
-        pass
+        self._blocking = val
 
     def send(self, data):
         self.sendqueue.extend(data);
@@ -68,7 +69,10 @@ class MySocket():
 
     def recv(self, length):
         if len(self.recvqueue) == 0:
-            return bytes()
+            if self._blocking:
+                return bytes()
+            else:
+                raise socket.error(errno.EAGAIN, "Resource temporarily unavailable")
         if length > len(self.recvqueue):
             raise Exception("Buffer underrun in test, does the test provide the right data?")
         result = self.recvqueue[:length]
@@ -107,7 +111,8 @@ class MySession(Session):
         self._socket_timeout = 1
         self._lname = None
         self._recvbuffer = bytearray()
-        self._recvlength = 0
+        self._recv_len_size = 0
+        self._recv_size = 0
         self._sequence = 1
         self._closed = False
         self._queue = []