Browse Source

[1452] made the forwarder socket nonblocking and test the case we push sessions
too fast.

JINMEI Tatuya 13 years ago
parent
commit
8cd3a3f503
2 changed files with 86 additions and 15 deletions
  1. 50 15
      src/lib/util/io/socketsession.cc
  2. 36 0
      src/lib/util/tests/socketsession_unittest.cc

+ 50 - 15
src/lib/util/io/socketsession.cc

@@ -14,11 +14,13 @@
 
 
 #include <sys/types.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/socket.h>
+#include <sys/uio.h>
 #include <sys/un.h>
 #include <sys/un.h>
 
 
 #include <netinet/in.h>
 #include <netinet/in.h>
 
 
 #include <errno.h>
 #include <errno.h>
+#include <fcntl.h>
 #include <signal.h>
 #include <signal.h>
 #include <stdint.h>
 #include <stdint.h>
 #include <string.h>
 #include <string.h>
@@ -51,6 +53,17 @@ struct SocketSessionForwarder::ForwarderImpl {
     OutputBuffer buf_;
     OutputBuffer buf_;
 };
 };
 
 
+// The expected max size of the session header: 2-byte header length,
+// 6 32-bit fields, and 2 sockaddr structure.  sizeof sockaddr_storage
+// should be the possible max of any sockaddr structure.
+const size_t DEFAULT_HEADER_BUFLEN = 2 + sizeof(uint32_t) * 6 +
+    sizeof(struct sockaddr_storage) * 2;
+
+// The (default) socket buffer size for the forwarder.  This is chosen to
+// be sufficiently large to store two full-size DNS messages.  We may want to
+// customize this value in future.
+const int FORWARDER_BUFSIZE = (DEFAULT_HEADER_BUFLEN + 65536) * 2;
+
 SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) :
 SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) :
     impl_(NULL)
     impl_(NULL)
 {
 {
@@ -98,18 +111,30 @@ SocketSessionForwarder::connectToReceptor() {
         isc_throw(SocketSessionError, "Failed to create a UNIX domain socket: "
         isc_throw(SocketSessionError, "Failed to create a UNIX domain socket: "
                   << strerror(errno));
                   << strerror(errno));
     }
     }
+    // Make the socket non blocking
+    int fcntl_flags = fcntl(impl_->fd_, F_GETFL, 0);
+    if (fcntl_flags != -1) {
+        fcntl_flags |= O_NONBLOCK;
+        fcntl_flags = fcntl(impl_->fd_, F_SETFL, fcntl_flags);
+    }
+    if (fcntl_flags == -1) {
+        close();   // note: this is the internal method, not ::close()
+        isc_throw(SocketSessionError,
+                  "Failed to make UNIX domain socket non blocking: " <<
+                  strerror(errno));
+    }
+    if (setsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &FORWARDER_BUFSIZE,
+                   sizeof(FORWARDER_BUFSIZE)) == -1) {
+        close();
+        isc_throw(SocketSessionError, "Failed to enlarge send buffer size");
+    }
     if (connect(impl_->fd_, convertSockAddr(&impl_->sock_un_),
     if (connect(impl_->fd_, convertSockAddr(&impl_->sock_un_),
                 impl_->sock_un_len_) == -1) {
                 impl_->sock_un_len_) == -1) {
-        close();   // note: this is the internal method, not ::close()
+        close();
         isc_throw(SocketSessionError, "Failed to connect to UNIX domain "
         isc_throw(SocketSessionError, "Failed to connect to UNIX domain "
                   "endpoint " << impl_->sock_un_.sun_path << ": " <<
                   "endpoint " << impl_->sock_un_.sun_path << ": " <<
                   strerror(errno));
                   strerror(errno));
     }
     }
-    int bufsize = 65536 * 2;
-    if (setsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &bufsize,
-                   sizeof(bufsize)) == -1) {
-        isc_throw(SocketSessionError, "failed to enlarge receive buffer size");
-    }
 }
 }
 
 
 void
 void
@@ -144,6 +169,10 @@ SocketSessionForwarder::push(int sock, int family, int sock_type, int protocol,
                   << static_cast<int>(local_end.sa_family) << ", "
                   << static_cast<int>(local_end.sa_family) << ", "
                   << static_cast<int>(remote_end.sa_family) << " given");
                   << static_cast<int>(remote_end.sa_family) << " given");
     }
     }
+    if (data_len == 0 || data == NULL) {
+        isc_throw(SocketSessionError,
+                  "Data for a socket session must not be empty");
+    }
 
 
     if (send_fd(impl_->fd_, sock) != 0) {
     if (send_fd(impl_->fd_, sock) != 0) {
         isc_throw(SocketSessionError, "FD passing failed: " <<
         isc_throw(SocketSessionError, "FD passing failed: " <<
@@ -168,12 +197,21 @@ SocketSessionForwarder::push(int sock, int family, int sock_type, int protocol,
     // Write the resulting header length at the beginning of the buffer
     // Write the resulting header length at the beginning of the buffer
     impl_->buf_.writeUint16At(impl_->buf_.getLength() - sizeof(uint16_t), 0);
     impl_->buf_.writeUint16At(impl_->buf_.getLength() - sizeof(uint16_t), 0);
 
 
-    const int cc = write(impl_->fd_, impl_->buf_.getData(),
-                         impl_->buf_.getLength());
-    assert(cc == impl_->buf_.getLength());
-
-    const int cc_data = write(impl_->fd_, data, data_len);
-    assert(cc_data == data_len);
+    const struct iovec iov[2] = {
+        { const_cast<void*>(impl_->buf_.getData()), impl_->buf_.getLength() },
+        { const_cast<void*>(data), data_len }
+    };
+    const int cc = writev(impl_->fd_, iov, 2);
+    if (cc != impl_->buf_.getLength() + data_len) {
+        if (cc < 0) {
+            isc_throw(SocketSessionError,
+                      "Write failed in forwarding a socket session: " <<
+                      strerror(errno));
+        }
+        isc_throw(SocketSessionError,
+                  "Incomplete write in forwarding a socket session: " << cc <<
+                  "/" << (impl_->buf_.getLength() + data_len));
+    }
 }
 }
 
 
 SocketSession::SocketSession(int sock, int family, int type, int protocol,
 SocketSession::SocketSession(int sock, int family, int type, int protocol,
@@ -195,9 +233,6 @@ SocketSession::SocketSession(int sock, int family, int type, int protocol,
     }
     }
 }
 }
 
 
-const size_t DEFAULT_HEADER_BUFLEN = sizeof(struct sockaddr_storage) * 2 +
-    sizeof(uint32_t) * 6;
-
 struct SocketSessionReceptor::ReceptorImpl {
 struct SocketSessionReceptor::ReceptorImpl {
     ReceptorImpl(int fd) : fd_(fd),
     ReceptorImpl(int fd) : fd_(fd),
                            sa_local_(convertSockAddr(&ss_local_)),
                            sa_local_(convertSockAddr(&ss_local_)),

+ 36 - 0
src/lib/util/tests/socketsession_unittest.cc

@@ -532,6 +532,18 @@ TEST_F(ForwarderTest, badPush) {
                                  TEST_DATA, sizeof(TEST_DATA)),
                                  TEST_DATA, sizeof(TEST_DATA)),
                  SocketSessionError);
                  SocketSessionError);
 
 
+    // Empty data: we reject them at least for now
+    EXPECT_THROW(forwarder_.push(1, AF_INET, SOCK_DGRAM, IPPROTO_UDP,
+                                 *getSockAddr("192.0.2.1", "53").first,
+                                 *getSockAddr("192.0.2.2", "53").first,
+                                 TEST_DATA, 0),
+                 SocketSessionError);
+    EXPECT_THROW(forwarder_.push(1, AF_INET, SOCK_DGRAM, IPPROTO_UDP,
+                                 *getSockAddr("192.0.2.1", "53").first,
+                                 *getSockAddr("192.0.2.2", "53").first,
+                                 NULL, sizeof(TEST_DATA)),
+                 SocketSessionError);
+
     // Close the acceptor before push.  It will result in SIGPIPE (should be
     // Close the acceptor before push.  It will result in SIGPIPE (should be
     // ignored) and EPIPE, which will be converted to SocketSessionError.
     // ignored) and EPIPE, which will be converted to SocketSessionError.
     const int receptor_fd = acceptForwarder();
     const int receptor_fd = acceptForwarder();
@@ -543,6 +555,30 @@ TEST_F(ForwarderTest, badPush) {
                  SocketSessionError);
                  SocketSessionError);
 }
 }
 
 
+// A subroutine for pushTooFast.  Due to the fixed configuration of the
+// send buffer size, we shouldn't be able to forward 3 full-size DNS messages
+// without receiving them.  Exactly how many we can forward depends on the
+// internal system implementation, so we'll at least confirm we can't do for 3.
+void
+multiPush(SocketSessionForwarder& forwarder, const struct sockaddr& sa,
+          const void* data, size_t data_len)
+{
+    for (int i = 0; i < 3; ++i) {
+        forwarder.push(1, AF_INET, SOCK_DGRAM, IPPROTO_UDP, sa, sa,
+                       data, data_len);
+    }
+}
+
+TEST_F(ForwarderTest, pushTooFast) {
+    // Emulate the situation where the forwarder is pushing sessions too fast.
+    // It should eventually fail without blocking.
+    startListen();
+    forwarder_.connectToReceptor();
+    EXPECT_THROW(multiPush(forwarder_, *getSockAddr("192.0.2.1", "53").first,
+                           large_text_.c_str(), large_text_.length()),
+                 SocketSessionError);
+}
+
 TEST(SocketSession, badValue) {
 TEST(SocketSession, badValue) {
     // normal cases are confirmed in ForwarderTest.  We only check some
     // normal cases are confirmed in ForwarderTest.  We only check some
     // abnormal cases here.
     // abnormal cases here.