Parcourir la source

[1534] Close sockets on errors

Michal 'vorner' Vaner il y a 13 ans
Parent
commit
bc4fc18659

+ 18 - 6
src/bin/sockcreator/sockcreator.cc

@@ -126,7 +126,7 @@ handleRequest(const int input_fd, const int output_fd,
     }
 
     // Obtain the socket
-    const int result = get_sock(sock_type, addr, addr_len);
+    const int result = get_sock(sock_type, addr, addr_len, close_fun);
     if (result >= 0) {
         // Got the socket, send it to the client.
         writeMessage(output_fd, "S", 1);
@@ -198,6 +198,17 @@ mtu(int fd) {
     return (fd);
 }
 
+// This one closes the socket if result is negative. Used not to leak socket
+// on error.
+int maybeClose(const int result, const int socket, const close_t close_fun) {
+    if (result < 0) {
+        if (close_fun(socket) == -1) {
+            isc_throw(InternalError, "Error closing socket");
+        }
+    }
+    return (result);
+}
+
 } // Anonymous namespace
 
 namespace isc {
@@ -205,7 +216,8 @@ namespace socket_creator {
 
 // Get the socket and bind to it.
 int
-getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len) {
+getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len,
+        const close_t close_fun) {
     const int sock = socket(bind_addr->sa_family, type, 0);
     if (sock == -1) {
         return (-1);
@@ -213,19 +225,19 @@ getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len) {
     const int on = 1;
     if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
         // This is part of the binding process, so it's a bind error
-        return (-2);
+        return (maybeClose(-2, sock, close_fun));
     }
     if (bind_addr->sa_family == AF_INET6 &&
         setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) {
         // This is part of the binding process, so it's a bind error
-        return (-2);
+        return (maybeClose(-2, sock, close_fun));
     }
     if (bind(sock, bind_addr, addr_len) == -1) {
-        return (-2);
+        return (maybeClose(-2, sock, close_fun));
     }
     if (type == SOCK_DGRAM && bind_addr->sa_family == AF_INET6) {
         // Set some MTU flags on IPv6 UDP sockets.
-        return (mtu(sock));
+        return (maybeClose(mtu(sock), sock, close_fun));
     }
     return (sock);
 }

+ 9 - 6
src/bin/sockcreator/sockcreator.h

@@ -73,6 +73,9 @@ public:
 };
 
 
+// Type of the close() function, so it can be passed as a parameter.
+// Argument is the same as that for close(2).
+typedef int (*close_t)(int);
 
 /// \short Create a socket and bind it.
 ///
@@ -82,13 +85,16 @@ public:
 /// \param type The type of socket to create (SOCK_STREAM, SOCK_DGRAM, etc).
 /// \param bind_addr The address to bind.
 /// \param addr_len The actual length of bind_addr.
+/// \param close_fun The furction used to close a socket if there's an error
+///     after the creation.
 ///
 /// \return The file descriptor of the newly created socket, if everything
 ///         goes well. A negative number is returned if an error occurs -
 ///         -1 if the socket() call fails or -2 if bind() fails. In case of
 ///         error, errno is set (or left intact from socket() or bind()).
 int
-getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len);
+getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len,
+        const close_t close_fun);
 
 // Define some types for functions used to perform socket-related operations.
 // These are typedefed so that alternatives can be passed through to the
@@ -96,16 +102,13 @@ getSock(const int type, struct sockaddr* bind_addr, const socklen_t addr_len);
 
 // Type of the function to get a socket and to pass it as parameter.
 // Arguments are those described above for getSock().
-typedef int (*get_sock_t)(const int, struct sockaddr *, const socklen_t);
+typedef int (*get_sock_t)(const int, struct sockaddr *, const socklen_t,
+                          const close_t close_fun);
 
 // Type of the send_fd() function, so it can be passed as a parameter.
 // Arguments are the same as those of the send_fd() function.
 typedef int (*send_fd_t)(const int, const int);
 
-// Type of the close() function, so it can be passed as a parameter.
-// Argument is the same as that for close(2).
-typedef int (*close_t)(int);
-
 
 /// \brief Infinite loop parsing commands and returning the sockets.
 ///

+ 36 - 9
src/bin/sockcreator/tests/sockcreator_tests.cc

@@ -22,6 +22,7 @@
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
+#include <arpa/inet.h>
 #include <unistd.h>
 
 #include <iostream>
@@ -140,6 +141,11 @@ void addressFamilySpecificCheck(const sockaddr_in6*, const int socknum,
     }
 }
 
+// Just ignore the fd and pretend success. We close invalid fds in the tests.
+int
+closeIgnore(int) {
+    return (0);
+}
 
 // Generic version of the socket test.  It creates the socket and checks that
 // it is a valid descriptor.  The family-specific check functions are called
@@ -157,7 +163,8 @@ void testAnyCreate(int socket_type, socket_check_t socket_check) {
     memset(&addr, 0, sizeof(addr));
     setAddressFamilyFields(&addr);
     sockaddr* addr_ptr = reinterpret_cast<sockaddr*>(&addr);
-    const int socket = getSock(socket_type, addr_ptr, sizeof(addr));
+    const int socket = getSock(socket_type, addr_ptr, sizeof(addr),
+                               closeIgnore);
     ASSERT_GE(socket, 0) << "Couldn't create socket: failed with " <<
         "return code " << socket << " and error " << strerror(errno);
 
@@ -195,12 +202,37 @@ TEST(get_sock, tcp6_create) {
     testAnyCreate<sockaddr_in6>(SOCK_STREAM, tcpCheck);
 }
 
+bool close_called(false);
+
+int closeCall(int socket) {
+    close(socket);
+    close_called = true;
+    return (0);
+}
+
 // Ask the get_sock function for some nonsense and test if it is able to report
 // an error.
 TEST(get_sock, fail_with_nonsense) {
     sockaddr addr;
     memset(&addr, 0, sizeof(addr));
-    ASSERT_LT(getSock(0, &addr, sizeof addr), 0);
+    close_called = false;
+    ASSERT_EQ(-1, getSock(0, &addr, sizeof addr, closeCall));
+    ASSERT_FALSE(close_called); // The "socket" call should have failed already
+}
+
+// Bind should have failed here
+TEST(get_sock, fail_with_bind) {
+    sockaddr_in addr;
+    memset(&addr, 0, sizeof(addr));
+    addr.sin_family = AF_INET;
+    addr.sin_port = 1;
+    // No host should have this address on the interface, so it should not be
+    // possible to bind it.
+    addr.sin_addr.s_addr = inet_addr("192.0.2.1");
+    close_called = false;
+    ASSERT_EQ(-2, getSock(SOCK_STREAM, reinterpret_cast<sockaddr*>(&addr),
+                          sizeof addr, closeCall));
+    ASSERT_TRUE(close_called); // The "socket" call should have failed already
 }
 
 // The main run() function in the socket creator takes three functions to
@@ -222,7 +254,8 @@ TEST(get_sock, fail_with_nonsense) {
 // -1: The simulated bind() call has failed
 // -2: The simulated socket() call has failed
 int
-getSockDummy(const int type, struct sockaddr* addr, const socklen_t) {
+getSockDummy(const int type, struct sockaddr* addr, const socklen_t,
+             const close_t) {
     int result = 0;
     int port = 0;
 
@@ -277,12 +310,6 @@ send_FdDummy(const int destination, const int what) {
     return (status ? 0 : -1);
 }
 
-// Just ignore the fd and pretend success. We close invalid fds in the tests.
-int
-closeIgnore(int) {
-    return (0);
-}
-
 // Generic test that it works, with various inputs and outputs.
 // It uses different functions to create the socket and send it and pass
 // data to it and check it returns correct data back, to see if the run()