Parcourir la source

[1593] Restructure code into a more logical order

Stephen Morris il y a 13 ans
Parent
commit
c5a39e1312
1 fichiers modifiés avec 103 ajouts et 86 suppressions
  1. 103 86
      src/bin/sockcreator/sockcreator.cc

+ 103 - 86
src/bin/sockcreator/sockcreator.cc

@@ -16,38 +16,16 @@
 
 #include <util/io/fd.h>
 
-#include <unistd.h>
 #include <cerrno>
-#include <string.h>
+#include <cstring>
+
+#include <unistd.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
 
 using namespace isc::util::io;
-
-namespace isc {
-namespace socket_creator {
-
-int
-get_sock(const int type, struct sockaddr *bind_addr, const socklen_t addr_len)
-{
-    int sock(socket(bind_addr->sa_family, type, 0));
-    if (sock == -1) {
-        return -1;
-    }
-    const int on(1);
-    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
-        return -2; // This is part of the binding process, so it's a bind error
-    }
-    if (bind_addr->sa_family == AF_INET6 &&
-        setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) {
-        return -2; // This is part of the binding process, so it's a bind error
-    }
-    if (bind(sock, bind_addr, addr_len) == -1) {
-        return -2;
-    }
-    return sock;
-}
+using namespace isc::socket_creator;
 
 namespace {
 
@@ -69,6 +47,7 @@ write_message(int fd, const void* what, const size_t length) {
 // Exit on a protocol error after informing the client of the problem.
 void
 protocol_error(int fd, const char reason = 'I') {
+
     // Tell client we have a problem
     char message[2];
     message[0] = 'F';
@@ -79,68 +58,78 @@ protocol_error(int fd, const char reason = 'I') {
     isc_throw(ProtocolError, "Fatal error, reason: " << reason);
 }
 
-// Handle the request to create a socket
-
+// Handle the request from the client.
+//
+// Reads the type and family of socket required, creates the socket, then
+// returns it to the client.
+//
+// The arguments are the same as those passed to run().
 void
-create_socket(const int input_fd, const int output_fd,
-              const get_sock_t get_sock, const send_fd_t send_fd_fun,
-              const close_t close_fun)
+handle_request(const int input_fd, const int output_fd,
+               const get_sock_t get_sock, const send_fd_t send_fd_fun,
+               const close_t close_fun)
 {
-    // Read what type of socket they want
+    // Read the message from the client
     char type[2];
     read_message(input_fd, type, sizeof(type));
-    // Read the address they ask for
-    struct sockaddr *addr(NULL);
-    size_t addr_len(0);
-    struct sockaddr_in addr_in;
-    struct sockaddr_in6 addr_in6;
+
+    // Decide what type of socket is being asked for
+    int sock_type = 0;
+    switch (type[0]) {
+        case 'T':
+            sock_type = SOCK_STREAM;
+            break;
+
+        case 'U':
+            sock_type = SOCK_DGRAM;
+            break;
+
+        default:
+            protocol_error(output_fd);
+    }
+
+    // Read the address they ask for depending on what address family was
+    // specified.
+    sockaddr* addr = NULL;
+    size_t addr_len = 0;
+    sockaddr_in addr_in;
+    sockaddr_in6 addr_in6;
     switch (type[1]) { // The address family
-        /*
-         * Here are some casts. They are required by C++ and
-         * the low-level interface (they are implicit in C).
-         */
+
+        // The casting to apparently incompatible types by reinterpret_cast
+        // is required by the C low-level interface. Unions are not used
+        // because of the possibility of alignment issues.
+
         case '4':
-            addr = static_cast<struct sockaddr *>(
-                static_cast<void *>(&addr_in));
-            addr_len = sizeof addr_in;
+            addr = reinterpret_cast<sockaddr*>(&addr_in);
+            addr_len = sizeof(addr_in);
             memset(&addr_in, 0, sizeof addr_in);
             addr_in.sin_family = AF_INET;
-            read_message(input_fd,
-                static_cast<char *>(static_cast<void *>(
-                &addr_in.sin_port)), 2);
-            read_message(input_fd,
-                static_cast<char *>(static_cast<void *>(
-                &addr_in.sin_addr.s_addr)), 4);
+            read_message(input_fd, static_cast<void *>(&addr_in.sin_port),
+                         sizeof(addr_in.sin_port));
+            read_message(input_fd, static_cast<void *>(&addr_in.sin_addr.s_addr),
+                         sizeof(addr_in.sin_addr.s_addr));
             break;
+
         case '6':
-            addr = static_cast<struct sockaddr *>(
-                static_cast<void *>(&addr_in6));
+            addr = reinterpret_cast<sockaddr*>(&addr_in6);
             addr_len = sizeof addr_in6;
             memset(&addr_in6, 0, sizeof addr_in6);
             addr_in6.sin6_family = AF_INET6;
-            read_message(input_fd,
-                static_cast<char *>(static_cast<void *>(
-                &addr_in6.sin6_port)), 2);
-            read_message(input_fd,
-                static_cast<char *>(static_cast<void *>(
-                &addr_in6.sin6_addr.s6_addr)), 16);
-            break;
-        default:
-            protocol_error(output_fd);
-    }
-    int sock_type = 0;
-    switch (type[0]) { // Translate the type
-        case 'T':
-            sock_type = SOCK_STREAM;
-            break;
-        case 'U':
-            sock_type = SOCK_DGRAM;
+            read_message(input_fd, static_cast<void *>(&addr_in6.sin6_port),
+                         sizeof(addr_in6.sin6_port));
+            read_message(input_fd, static_cast<void *>(&addr_in6.sin6_addr.s6_addr),
+                         sizeof(addr_in6.sin6_addr.s6_addr));
             break;
+
         default:
             protocol_error(output_fd);
     }
-    int result(get_sock(sock_type, addr, addr_len));
-    if (result >= 0) { // We got the socket
+
+    // Obtain the socket
+    int result = get_sock(sock_type, addr, addr_len);
+    if (result >= 0) {
+        // Got the socket, send it to the client.
         write_message(output_fd, "S", 1);
         if (send_fd_fun(output_fd, result) != 0) {
             // We'll soon abort ourselves, but make sure we still
@@ -149,51 +138,79 @@ create_socket(const int input_fd, const int output_fd,
             close_fun(result);
             isc_throw(InternalError, "Error sending descriptor");
         }
-        // Don't leak the socket
+
+        // Don't leak the socket used to send the acquired socket back to the
+        // client.
         if (close_fun(result) == -1) {
             isc_throw(InternalError, "Error closing socket");
         }
     } else {
+        // Error.  Tell the client.
         write_message(output_fd, "E", 1);
         switch (result) {
             case -1:
                 write_message(output_fd, "S", 1);
                 break;
+
             case -2:
                 write_message(output_fd, "B", 1);
                 break;
+
             default:
                 isc_throw(InternalError, "Error creating socket");
-    }
-    int error(errno);
-    write_message(output_fd,
-        static_cast<char *>(static_cast<void *>(&error)),
-        sizeof error);
+        }
+
+        // Error reason code.
+        int error = errno;
+        write_message(output_fd, static_cast<void *>(&error), sizeof error);
     }
 }
 
 } // Anonymous namespace
 
-// 
+namespace isc {
+namespace socket_creator {
 
+// Get the socket and bind to it.
+int
+get_sock(const int type, struct sockaddr *bind_addr, const socklen_t addr_len)
+{
+    int sock(socket(bind_addr->sa_family, type, 0));
+    if (sock == -1) {
+        return -1;
+    }
+    const int on(1);
+    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) == -1) {
+        return -2; // This is part of the binding process, so it's a bind error
+    }
+    if (bind_addr->sa_family == AF_INET6 &&
+        setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) {
+        return -2; // This is part of the binding process, so it's a bind error
+    }
+    if (bind(sock, bind_addr, addr_len) == -1) {
+        return -2;
+    }
+    return sock;
+}
+
+// Main run loop.
 void
 run(const int input_fd, const int output_fd, const get_sock_t get_sock,
     const send_fd_t send_fd_fun, const close_t close_fun)
 {
     for (;;) {
-        // Read the command
         char command;
         read_message(input_fd, &command, sizeof(command));
         switch (command) {
-            case 'T': // The "terminate" command
-                return;
-
-            case 'S':
-                create_socket(input_fd, output_fd, get_sock,
-                              send_fd_fun, close_fun);
+            case 'S':   // The "get socket" command
+                handle_request(input_fd, output_fd, get_sock,
+                               send_fd_fun, close_fun);
                 break;
 
-            default:
+            case 'T':   // The "terminate" command
+                return;
+
+            default:    // Don't recognise anything else
                 protocol_error(output_fd);
         }
     }