Browse Source

[trac800] WrappedSocket

Michal 'vorner' Vaner 13 years ago
parent
commit
3da7e8747d

+ 26 - 1
src/bin/bind10/sockcreator.py

@@ -17,6 +17,7 @@ import socket
 import struct
 import os
 from bind10_messages import *
+from libutil_io_python import recv_fd
 
 logger = isc.log.Logger("boss")
 
@@ -60,7 +61,10 @@ class Parser:
         Creates the parser. The creator_socket is socket to the socket creator
         process that will be used for communication. However, the object must
         have a read_fd() method to read the file descriptor. This slightly
-        unusual modification of socket object is used to easy up testing.
+        unusual trick with modifying an object is used to easy up testing.
+
+        You can use WrappedSocket in production code to add the method to any
+        ordinary socket.
         """
         self.__socket = creator_socket
         logger.info(BIND10_SOCKCREATOR_INIT)
@@ -166,3 +170,24 @@ class Parser:
                 raise CreatorError('Unexpected EOF', True)
             result += data
         return result
+
+class WrappedSocket:
+    """
+    This class wraps a socket and adds a read_fd method, so it can be used
+    for the Parser class conveniently. It simply copies all it's guts into
+    itself and implements the method.
+    """
+    def __init__(self, socket):
+        # Copy whatever can be copied from the socket
+        for name in dir(socket):
+            if name not in ['__class__', '__weakref__']:
+                setattr(self, name, getattr(socket, name))
+        # Keep the socket, so we can prevent it from being garbage-collected
+        # and closed before we are removed ourself
+        self.__orig_socket = socket
+
+    def read_fd(self):
+        """
+        Read the file descriptor from the socket.
+        """
+        return recv_fd(self.fileno())

+ 1 - 1
src/bin/bind10/tests/Makefile.am

@@ -21,7 +21,7 @@ endif
 	for pytest in $(PYTESTS) ; do \
 	echo Running test: $$pytest ; \
 	$(LIBRARY_PATH_PLACEHOLDER) \
-	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_srcdir)/src/bin:$(abs_top_builddir)/src/bin/bind10 \
+	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_srcdir)/src/bin:$(abs_top_builddir)/src/bin/bind10:$(abs_top_builddir)/src/lib/util/io/.libs \
 	BIND10_MSGQ_SOCKET_FILE=$(abs_top_builddir)/msgq_socket \
 		$(PYCOVERAGE_RUN) $(abs_builddir)/$$pytest || exit ; \
 	done

+ 35 - 1
src/bin/bind10/tests/sockcreator_test.py

@@ -20,9 +20,10 @@ Tests for the bind10.sockcreator module.
 import unittest
 import struct
 import socket
-from bind10.sockcreator import Parser, CreatorError
 from isc.net.addr import IPAddr
 import isc.log
+from libutil_io_python import send_fd
+from bind10.sockcreator import Parser, CreatorError, WrappedSocket
 
 class FakeCreator:
     """
@@ -272,6 +273,39 @@ class ParserTests(unittest.TestCase):
         self.assertRaises(ValueError, Parser(FakeCreator([])).get_socket,
                           addr, 42, socket.SOCK_DGRAM)
 
+class WrapTests(unittest.TestCase):
+    """
+    Tests for the wrap_socket function.
+    """
+    def test_wrap(self):
+        # We construct two pairs of socket. The receiving side of one pair will
+        # be wrapped. Then we send one of the other pair through this pair and
+        # check the received one can be used as a socket
+
+        # The transport socket
+        (t1, t2) = socket.socketpair()
+        # The payload socket
+        (p1, p2) = socket.socketpair()
+
+        t2 = WrappedSocket(t2)
+
+        # Transfer the descriptor
+        send_fd(t1.fileno(), p1.fileno())
+        p1 = socket.fromfd(t2.read_fd(), socket.AF_UNIX, socket.SOCK_STREAM)
+
+        # Now, pass some data trough the socket
+        p1.send(b'A')
+        data = p2.recv(1)
+        self.assertEqual(b'A', data)
+
+        # Test the wrapping didn't hurt the socket's usual methods
+        t1.send(b'B')
+        data = t2.recv(1)
+        self.assertEqual(b'B', data)
+        t2.send(b'C')
+        data = t1.recv(1)
+        self.assertEqual(b'C', data)
+
 if __name__ == '__main__':
     isc.log.init("bind10") # FIXME Should this be needed?
     isc.log.resetUnitTestRootLogger()