Browse Source

[1427] Finish get_token

Michal 'vorner' Vaner 13 years ago
parent
commit
b586771730

+ 20 - 3
src/lib/python/isc/bind10/socket_cache.py

@@ -19,6 +19,7 @@ Here's the cache for sockets from socket creator.
 
 
 import os
 import os
 import random
 import random
+import isc.bind10.sockcreator
 
 
 class SocketError(Exception):
 class SocketError(Exception):
     """
     """
@@ -26,8 +27,12 @@ class SocketError(Exception):
     socket. Possible reasons might be the address it should be bound to
     socket. Possible reasons might be the address it should be bound to
     is already taken, the permissions are insufficient, the address family
     is already taken, the permissions are insufficient, the address family
     is not supported on this computer and many more.
     is not supported on this computer and many more.
+
+    The errno, if not None, is passed from the socket creator.
     """
     """
-    pass
+    def __init__(self, message, errno):
+        Exception.__init__(self, message)
+        self.errno = errno
 
 
 class ShareError(Exception):
 class ShareError(Exception):
     """
     """
@@ -180,8 +185,20 @@ class Cache:
         except KeyError:
         except KeyError:
             # Something in the dicts is not there, so socket is to be
             # Something in the dicts is not there, so socket is to be
             # created
             # created
-            # TODO
-            pass
+            try:
+                fileno = self._creator.get_socket(address, port, protocol)
+            except isc.bind10.sockcreator.CreatorError as ce:
+                if ce.fatal:
+                    raise
+                else:
+                    raise SocketError(str(ce), ce.errno)
+            socket = Socket(protocol, address, port, fileno)
+            # And cache it
+            if protocol not in self._sockets:
+                self._sockets[protocol] = {}
+            if addr_str not in self._sockets[protocol]:
+                self._sockets[protocol][addr_str] = {}
+            self._sockets[protocol][addr_str][port] = socket
         # Now we get the token, check it is compatible
         # Now we get the token, check it is compatible
         if not socket.shareCompatible(share_mode, share_name):
         if not socket.shareCompatible(share_mode, share_name):
             raise ShareError("Cached socket not compatible with mode " +
             raise ShareError("Cached socket not compatible with mode " +

+ 67 - 3
src/lib/python/isc/bind10/tests/socket_cache_test.py

@@ -16,6 +16,7 @@
 import unittest
 import unittest
 import isc.log
 import isc.log
 import isc.bind10.socket_cache
 import isc.bind10.socket_cache
+import isc.bind10.sockcreator
 from isc.net.addr import IPAddr
 from isc.net.addr import IPAddr
 import os
 import os
 
 
@@ -142,6 +143,7 @@ class SocketCacheTest(Test):
         self.__address = IPAddr("192.0.2.1")
         self.__address = IPAddr("192.0.2.1")
         self.__socket = isc.bind10.socket_cache.Socket('Test', self.__address,
         self.__socket = isc.bind10.socket_cache.Socket('Test', self.__address,
                                                        1024, 42)
                                                        1024, 42)
+        self.__get_socket_called = False
 
 
     def test_init(self):
     def test_init(self):
         """
         """
@@ -154,16 +156,32 @@ class SocketCacheTest(Test):
         self.assertEqual({}, self.__cache._sockets)
         self.assertEqual({}, self.__cache._sockets)
         self.assertEqual(set(), self.__cache._live_tokens)
         self.assertEqual(set(), self.__cache._live_tokens)
 
 
+    def get_socket(self, address, port, socktype):
+        """
+        Pretend to be a socket creator.
+
+        This expects to be called with the _address, port 1024 and 'UDP'.
+
+        Returns 42 and notes down it was called.
+        """
+        self.assertEqual(self.__address, address)
+        self.assertEqual(1024, port)
+        self.assertEqual('UDP', socktype)
+        self.__get_socket_called = True
+        return 42
+
     def test_get_token_cached(self):
     def test_get_token_cached(self):
         """
         """
         Check the behaviour of get_token when the requested socket is already
         Check the behaviour of get_token when the requested socket is already
         cached inside.
         cached inside.
         """
         """
         self.__cache._sockets = {
         self.__cache._sockets = {
-            'UDP': {'192.0.2.1': {42: self.__socket}}
+            'UDP': {'192.0.2.1': {1024: self.__socket}}
         }
         }
-        token = self.__cache.get_token('UDP', self.__address, 42, 'ANY',
+        token = self.__cache.get_token('UDP', self.__address, 1024, 'ANY',
                                        'test')
                                        'test')
+        # It didn't call get_socket
+        self.assertFalse(self.__get_socket_called)
         # It returned something
         # It returned something
         self.assertIsNotNone(token)
         self.assertIsNotNone(token)
         # The token is both in the waiting sockets and the live tokens
         # The token is both in the waiting sockets and the live tokens
@@ -176,7 +194,7 @@ class SocketCacheTest(Test):
 
 
         # If we request one more, with incompatible share, it is rejected
         # If we request one more, with incompatible share, it is rejected
         self.assertRaises(isc.bind10.socket_cache.ShareError,
         self.assertRaises(isc.bind10.socket_cache.ShareError,
-                          self.__cache.get_token, 'UDP', self.__address, 42,
+                          self.__cache.get_token, 'UDP', self.__address, 1024,
                           'NO', 'test')
                           'NO', 'test')
         # The internals are not changed, so the same checks
         # The internals are not changed, so the same checks
         self.assertEqual({token: self.__socket}, self.__cache._waiting_tokens)
         self.assertEqual({token: self.__socket}, self.__cache._waiting_tokens)
@@ -184,6 +202,52 @@ class SocketCacheTest(Test):
         self.assertEqual({token: ('ANY', 'test')}, self.__socket.shares)
         self.assertEqual({token: ('ANY', 'test')}, self.__socket.shares)
         self.assertEqual(set([token]), self.__socket.waiting_tokens)
         self.assertEqual(set([token]), self.__socket.waiting_tokens)
 
 
+    def test_get_token_uncached(self):
+        """
+        Check a new socket is created when a corresponding one is missing.
+        """
+        token = self.__cache.get_token('UDP', self.__address, 1024, 'ANY',
+                                       'test')
+        # The get_socket was called
+        self.assertTrue(self.__get_socket_called)
+        # It returned something
+        self.assertIsNotNone(token)
+        # Get the socket and check it looks OK
+        socket = self.__cache._waiting_tokens[token]
+        self.assertEqual(self.__address, socket.address)
+        self.assertEqual(1024, socket.port)
+        self.assertEqual(42, socket.fileno)
+        self.assertEqual('UDP', socket.protocol)
+        # The socket is properly cached
+        self.assertEqual({
+            'UDP': {'192.0.2.1': {1024: socket}}
+        }, self.__cache._sockets)
+        # The token is both in the waiting sockets and the live tokens
+        self.assertEqual({token: socket}, self.__cache._waiting_tokens)
+        self.assertEqual(set([token]), self.__cache._live_tokens)
+        # The token got the new share to block any relevant queries
+        self.assertEqual({token: ('ANY', 'test')}, socket.shares)
+        # The socket knows the token is waiting in it
+        self.assertEqual(set([token]), socket.waiting_tokens)
+
+    def test_get_token_excs(self):
+        """
+        Test that it is handled properly if the socket creator raises
+        some exceptions.
+        """
+        def raiseCreatorError(fatal):
+            raise isc.bind10.sockcreator.CreatorError('test error', fatal)
+        # First, fatal socket creator errors are passed through
+        self.get_socket = lambda addr, port, proto: raiseCreatorError(True)
+        self.assertRaises(isc.bind10.sockcreator.CreatorError,
+                          self.__cache.get_token, 'UDP', self.__address, 1024,
+                          'NO', 'test')
+        # And nonfatal are converted to SocketError
+        self.get_socket = lambda addr, port, proto: raiseCreatorError(False)
+        self.assertRaises(isc.bind10.socket_cache.SocketError,
+                          self.__cache.get_token, 'UDP', self.__address, 1024,
+                          'NO', 'test')
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     isc.log.init("bind10")
     isc.log.init("bind10")
     isc.log.resetUnitTestRootLogger()
     isc.log.resetUnitTestRootLogger()