Browse Source

[trac772] Propagate the remote endpoint to XfrOut session

Michal 'vorner' Vaner 13 years ago
parent
commit
ed9c17ed16
2 changed files with 38 additions and 4 deletions
  1. 17 2
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 21 2
      src/bin/xfrout/xfrout.py.in

+ 17 - 2
src/bin/xfrout/tests/xfrout_test.py.in

@@ -117,8 +117,8 @@ class TestXfroutSession(unittest.TestCase):
 
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
-        #self.log = isc.log.NSLogger('xfrout', '',  severity = 'critical', log_to_console = False )
-        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(), TSIGKeyRing())
+        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
+                                       TSIGKeyRing(), ('127.0.0.1', 12345))
         self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
@@ -527,6 +527,21 @@ class TestUnixSockServer(unittest.TestCase):
         self.write_sock, self.read_sock = socket.socketpair()
         self.unix = MyUnixSockServer()
 
+    def test_guess_remote(self):
+        """Test we can guess the remote endpoint when we have only the
+           file descriptor. This is needed, because we get only that one
+           from auth."""
+        # We test with UDP, as it can be "connected" without other
+        # endpoint
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.connect(('127.0.0.1', 12345))
+        self.assertEqual(('127.0.0.1', 12345),
+                         self.unix._guess_remote(sock.fileno()))
+        sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        sock.connect(('::1', 12345))
+        self.assertEqual(('::1', 12345, 0, 0),
+                         self.unix._guess_remote(sock.fileno()))
+
     def test_receive_query_message(self):
         send_msg = b"\xd6=\x00\x00\x00\x01\x00"
         msg_len = struct.pack('H', socket.htons(len(send_msg)))

+ 21 - 2
src/bin/xfrout/xfrout.py.in

@@ -95,13 +95,14 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
-    def __init__(self, sock_fd, request_data, server, tsig_key_ring):
+    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote):
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
         self._tsig_key_ring = tsig_key_ring
         self._tsig_ctx = None
         self._tsig_len = 0
+        self._remote = remote
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -471,10 +472,28 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             t.daemon = True
         t.start()
 
+    def _guess_remote(self, sock_fd):
+        """
+           Guess remote address and port of the socket. The socket must be a
+           socket
+        """
+        # This uses a trick. If the socket is IPv4 in reality and we pretend
+        # it to to be IPv6, it returns IPv4 address anyway. This doesn't seem
+        # to care about the SOCK_STREAM parameter at all (which it really is,
+        # except for testing)
+        if socket.has_ipv6:
+            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
+        else:
+            # To make it work even on hosts without IPv6 support
+            # (Any idea how to simulate this in test?)
+            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
+        return sock.getpeername()
 
     def finish_request(self, sock_fd, request_data):
         '''Finish one request by instantiating RequestHandlerClass.'''
-        self.RequestHandlerClass(sock_fd, request_data, self, self.tsig_key_ring)
+        self.RequestHandlerClass(sock_fd, request_data, self,
+                                 self.tsig_key_ring,
+                                 self._guess_remote(sock_fd))
 
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used