Browse Source

[1452] complete data validation in SocketSessionReceptor::pop with detailed
tests.

JINMEI Tatuya 13 years ago
parent
commit
1e4d796212

+ 1 - 1
src/bin/xfrout/xfrout.py.in

@@ -716,7 +716,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn,
             # This may happen when one xfrout process try to connect to
             # xfrout unix socket server, to check whether there is another
             # xfrout running.
-            if sock_fd == FD_COMM_ERROR:
+            if sock_fd == FD_SYSTEM_ERROR:
                 logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
             return
 

+ 8 - 6
src/lib/util/io/fd_share.cc

@@ -88,12 +88,16 @@ recv_fd(const int sock) {
     msghdr.msg_controllen = cmsg_space(sizeof(int));
     msghdr.msg_control = malloc(msghdr.msg_controllen);
     if (msghdr.msg_control == NULL) {
-        return (FD_OTHER_ERROR);
+        return (FD_SYSTEM_ERROR);
     }
 
-    if (recvmsg(sock, &msghdr, 0) < 0) {
+    const int cc = recvmsg(sock, &msghdr, 0);
+    if (cc <= 0) {
         free(msghdr.msg_control);
-        return (FD_COMM_ERROR);
+        if (cc == 0) {
+            errno = ECONNRESET;
+        }
+        return (FD_SYSTEM_ERROR);
     }
     const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
     int fd = FD_OTHER_ERROR;
@@ -131,10 +135,8 @@ send_fd(const int sock, const int fd) {
     *(int*)CMSG_DATA(cmsg) = fd;
 
     const int ret = sendmsg(sock, &msghdr, 0);
-    const int e = errno;
     free(msghdr.msg_control);
-    errno = e;                  // recover errno in case free() changed it
-    return (ret >= 0 ? 0 : FD_COMM_ERROR);
+    return (ret >= 0 ? 0 : FD_SYSTEM_ERROR);
 }
 
 } // End for namespace io

+ 9 - 6
src/lib/util/io/fd_share.h

@@ -25,7 +25,7 @@ namespace isc {
 namespace util {
 namespace io {
 
-const int FD_COMM_ERROR = -2;
+const int FD_SYSTEM_ERROR = -2;
 const int FD_OTHER_ERROR = -1;
 
 /**
@@ -33,8 +33,11 @@ const int FD_OTHER_ERROR = -1;
  * This receives a file descriptor sent over an unix domain socket. This
  * is the counterpart of send_fd().
  *
- * \return FD_COMM_ERROR when there's error receiving the socket, FD_OTHER_ERROR
- *     when there's a different error.
+ * \return FD_SYSTEM_ERROR when there's an error at the operating system
+ * level (such as a system call failure).  The global 'errno' variable
+ * indicates the specific error.  FD_OTHER_ERROR when there's a different
+ * error.
+ *
  * \param sock The unix domain socket to read from. Tested and it does
  *     not work with a pipe.
  */
@@ -45,9 +48,9 @@ int recv_fd(const int sock);
  * This sends a file descriptor over an unix domain socket. This is the
  * counterpart of recv_fd().
  *
- * \return FD_COMM_ERROR when there's error sending the socket, FD_OTHER_ERROR
- *     for all other possible errors.  The global 'errno' variable indicates
- *     the corresponding system error.
+ * \return FD_SYSTEM_ERROR when there's an error at the operating system
+ * level (such as a system call failure).  The global 'errno' variable
+ * indicates the specific error.
  * \param sock The unix domain socket to send to. Tested and it does not
  *     work with a pipe.
  * \param fd The file descriptor to send. It should work with any valid

+ 6 - 5
src/lib/util/io/fdshare_python.cc

@@ -67,14 +67,15 @@ PyInit_libutil_io_python(void) {
         return (NULL);
     }
 
-    PyObject* FD_COMM_ERROR = Py_BuildValue("i", isc::util::io::FD_COMM_ERROR);
-    if (FD_COMM_ERROR == NULL) {
+    PyObject* FD_SYSTEM_ERROR = Py_BuildValue("i",
+                                              isc::util::io::FD_SYSTEM_ERROR);
+    if (FD_SYSTEM_ERROR == NULL) {
         Py_XDECREF(mod);
         return (NULL);
     }
-    int ret = PyModule_AddObject(mod, "FD_COMM_ERROR", FD_COMM_ERROR);
-    if (-1 == ret) {
-        Py_XDECREF(FD_COMM_ERROR);
+    int ret = PyModule_AddObject(mod, "FD_SYSTEM_ERROR", FD_SYSTEM_ERROR);
+    if (ret == -1) {
+        Py_XDECREF(FD_SYSTEM_ERROR);
         Py_XDECREF(mod);
         return (NULL);
     }

+ 98 - 31
src/lib/util/io/socketsession.cc

@@ -45,24 +45,30 @@ namespace io {
 
 using namespace internal;
 
-struct SocketSessionForwarder::ForwarderImpl {
-    ForwarderImpl() : buf_(512) {}
-    struct sockaddr_un sock_un_;
-    socklen_t sock_un_len_;
-    int fd_;
-    OutputBuffer buf_;
-};
-
 // The expected max size of the session header: 2-byte header length,
 // 6 32-bit fields, and 2 sockaddr structure.  sizeof sockaddr_storage
 // should be the possible max of any sockaddr structure.
 const size_t DEFAULT_HEADER_BUFLEN = 2 + sizeof(uint32_t) * 6 +
     sizeof(struct sockaddr_storage) * 2;
 
+// The allowable maximum size of data passed with the socket FD.  For now
+// we use a fixed value of 65535, the largest possible size of valid DNS
+// messages.  We may enlarge it or make it configurable as we see the need
+// for more flexibility.
+const int MAX_DATASIZE = 65535;
+
 // The (default) socket buffer size for the forwarder and receptor.  This is
 // chosen to be sufficiently large to store two full-size DNS messages.  We
 // may want to customize this value in future.
-const int SOCKSESSION_BUFSIZE = (DEFAULT_HEADER_BUFLEN + 65536) * 2;
+const int SOCKSESSION_BUFSIZE = (DEFAULT_HEADER_BUFLEN + MAX_DATASIZE) * 2;
+
+struct SocketSessionForwarder::ForwarderImpl {
+    ForwarderImpl() : buf_(512) {}
+    struct sockaddr_un sock_un_;
+    socklen_t sock_un_len_;
+    int fd_;
+    OutputBuffer buf_;
+};
 
 SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) :
     impl_(NULL)
@@ -265,39 +271,100 @@ SocketSessionReceptor::~SocketSessionReceptor() {
     delete impl_;
 }
 
+namespace {
+// A shortcut to throw common exception on failure of recv(2)
+void
+readFail(int actual_len, int expected_len) {
+    if (expected_len < 0) {
+        isc_throw(SocketSessionError, "Failed to receive data from "
+                  "SocketSessionForwarder: " << strerror(errno));
+    }
+    isc_throw(SocketSessionError, "Incomplete data from "
+              "SocketSessionForwarder: " << actual_len << "/" <<
+              expected_len);
+}
+}
+
 SocketSession
 SocketSessionReceptor::pop() {
     const int passed_fd = recv_fd(impl_->fd_);
-    // TODO: error check
+    if (passed_fd == FD_SYSTEM_ERROR) {
+        isc_throw(SocketSessionError, "Receiving a forwarded FD failed: " <<
+                  strerror(errno));
+    } else if (passed_fd < 0) {
+        isc_throw(SocketSessionError, "No FD forwarded");
+    }
 
     uint16_t header_len;
-    const int cc = recv(impl_->fd_, &header_len, sizeof(header_len),
+    const int cc_hlen = recv(impl_->fd_, &header_len, sizeof(header_len),
                         MSG_WAITALL);
-    assert(cc == sizeof(header_len)); // XXX
+    if (cc_hlen < sizeof(header_len)) {
+        readFail(cc_hlen, sizeof(header_len));
+    }
     header_len = InputBuffer(&header_len, sizeof(header_len)).readUint16();
+    if (header_len > DEFAULT_HEADER_BUFLEN) {
+        isc_throw(SocketSessionError, "Too large header length: " <<
+                  header_len);
+    }
     impl_->header_buf_.clear();
     impl_->header_buf_.resize(header_len);
-    recv(impl_->fd_, &impl_->header_buf_[0], header_len, MSG_WAITALL);
+    const int cc_hdr = recv(impl_->fd_, &impl_->header_buf_[0], header_len,
+                            MSG_WAITALL);
+    if (cc_hdr < header_len) {
+        readFail(cc_hdr, header_len);
+    }
 
     InputBuffer ibuffer(&impl_->header_buf_[0], header_len);
-    const int family = static_cast<int>(ibuffer.readUint32());
-    const int type = static_cast<int>(ibuffer.readUint32());
-    const int protocol = static_cast<int>(ibuffer.readUint32());
-    const socklen_t local_end_len = ibuffer.readUint32();
-    assert(local_end_len <= sizeof(impl_->ss_local_)); // XXX
-    ibuffer.readData(&impl_->ss_local_, local_end_len);
-    const socklen_t remote_end_len = ibuffer.readUint32();
-    assert(remote_end_len <= sizeof(impl_->ss_remote_)); // XXX
-    ibuffer.readData(&impl_->ss_remote_, remote_end_len);
-    const size_t data_len = ibuffer.readUint32();
-
-    impl_->data_buf_.clear();
-    impl_->data_buf_.resize(data_len);
-    recv(impl_->fd_, &impl_->data_buf_[0], data_len, MSG_WAITALL);
-
-    return (SocketSession(passed_fd, family, type, protocol,
-                          impl_->sa_local_, impl_->sa_remote_, data_len,
-                          &impl_->data_buf_[0]));
+    try {
+        const int family = static_cast<int>(ibuffer.readUint32());
+        if (family != AF_INET && family != AF_INET6) {
+            isc_throw(SocketSessionError,
+                      "Unsupported address family is passed: " << family);
+        }
+        const int type = static_cast<int>(ibuffer.readUint32());
+        const int protocol = static_cast<int>(ibuffer.readUint32());
+        const socklen_t local_end_len = ibuffer.readUint32();
+        if (local_end_len > sizeof(impl_->ss_local_)) {
+            isc_throw(SocketSessionError, "Local SA length too large: " <<
+                      local_end_len);
+        }
+        ibuffer.readData(&impl_->ss_local_, local_end_len);
+        const socklen_t remote_end_len = ibuffer.readUint32();
+        if (remote_end_len > sizeof(impl_->ss_remote_)) {
+            isc_throw(SocketSessionError, "Remote SA length too large: " <<
+                      remote_end_len);
+        }
+        ibuffer.readData(&impl_->ss_remote_, remote_end_len);
+        if (family != impl_->sa_local_->sa_family) {
+            isc_throw(SocketSessionError, "SA family inconsistent: " <<
+                      static_cast<int>(impl_->sa_local_->sa_family) << ", " <<
+                      static_cast<int>(impl_->sa_remote_->sa_family) <<
+                      " given, must be " << family);
+        }
+        const size_t data_len = ibuffer.readUint32();
+        if (data_len == 0 || data_len > MAX_DATASIZE) {
+            isc_throw(SocketSessionError,
+                      "Invalid socket session data size: " << data_len <<
+                      ", must be > 0 and <= " << MAX_DATASIZE);
+        }
+
+        impl_->data_buf_.clear();
+        impl_->data_buf_.resize(data_len);
+        const int cc_data = recv(impl_->fd_, &impl_->data_buf_[0], data_len,
+                                 MSG_WAITALL);
+        if (cc_data < data_len) {
+            readFail(cc_data, data_len);
+        }
+
+        return (SocketSession(passed_fd, family, type, protocol,
+                              impl_->sa_local_, impl_->sa_remote_, data_len,
+                              &impl_->data_buf_[0]));
+    } catch (const InvalidBufferPosition& ex) {
+        // We catch the case where the given header is too short and convert
+        // the exception to SocketSessionError.
+        isc_throw(SocketSessionError, "bogus socket session header: " <<
+                  ex.what());
+    }
 }
 
 }

+ 171 - 5
src/lib/util/tests/socketsession_unittest.cc

@@ -27,16 +27,20 @@
 #include <vector>
 
 #include <boost/noncopyable.hpp>
+#include <boost/scoped_ptr.hpp>
 
 #include <gtest/gtest.h>
 
 #include <exceptions/exceptions.h>
 
+#include <util/buffer.h>
+#include <util/io/fd_share.h>
 #include <util/io/socketsession.h>
 #include <util/io/sockaddr_util.h>
 
 using namespace std;
 using namespace isc;
+using boost::scoped_ptr;
 using namespace isc::util::io;
 using namespace isc::util::io::internal;
 
@@ -185,6 +189,20 @@ protected:
         }
     }
 
+    int dummyConnect() const {
+        const int s = socket(AF_UNIX, SOCK_STREAM, 0);
+        if (s == -1) {
+            isc_throw(isc::Unexpected,
+                      "failed to create a test UNIX domain socket");
+        }
+        setNonBlock(s, true);
+        if (connect(s, convertSockAddr(&test_un_), sizeof(test_un_)) == -1) {
+            isc_throw(isc::Unexpected,
+                      "failed to connect to the test SocketSessionForwarder");
+        }
+        return (s);
+    }
+
     // Accept a new connection from a SocketSessionForwarder and return
     // the socket FD of the new connection.  This assumes startListen()
     // has been called.
@@ -196,6 +214,10 @@ protected:
         if (s == -1) {
             isc_throw(isc::Unexpected, "accept failed: " << strerror(errno));
         }
+        // Make sure the socket is *blocking*.  We may pass large data, through
+        // it, and apparently non blocking read could cause some unexpected
+        // partial read on some systems.
+        setNonBlock(s, false);
         return (s);
     }
 
@@ -236,6 +258,60 @@ protected:
         return (s);
     }
 
+    // A helper method to push some (normally bogus) socket session header
+    // via a Unix domain socket that pretends to be a valid
+    // SocketSessionForwarder.  It first opens the Unix domain socket,
+    // and connect to the test receptor server (startListen() is expected to
+    // be called beforehand), forwards a valid file descriptor ("stdin" is
+    // used for simplicity), the pushed a 2-byte header length field of the
+    // session header.  The internal receptor_ pointer will be set to a
+    // newly created receptor object for the connection.
+    //
+    // \param hdrlen: The header length to be pushed.  It may or may not be
+    //                valid.
+    // \param hdrlen_len: The length of the actually pushed data as "header
+    //                    length".  Normally it should be 2 (the default), but
+    //                    could be a bogus value for testing.
+    // \param push_fd: Whether to forward the FD.  Normally it should be true,
+    //                 but can be false for testing.
+    void pushSessionHeader(uint16_t hdrlen,
+                           size_t hdrlen_len = sizeof(uint16_t),
+                           bool push_fd = true)
+    {
+        isc::util::OutputBuffer obuffer(0);
+        obuffer.clear();
+
+        dummy_forwarder_.reset(dummyConnect());
+        if (push_fd && send_fd(dummy_forwarder_.fd, 0) != 0) {
+            isc_throw(isc::Unexpected, "Failed to pass FD");
+        }
+        obuffer.writeUint16(hdrlen);
+        if (hdrlen_len > 0) {
+            send(dummy_forwarder_.fd, obuffer.getData(), hdrlen_len, 0);
+        }
+        accept_sock_.reset(acceptForwarder());
+        receptor_.reset(new SocketSessionReceptor(accept_sock_.fd));
+    }
+
+    void pushSession(int family, int type, int protocol, socklen_t local_len,
+                     const sockaddr& local, socklen_t remote_len,
+                     const sockaddr& remote,
+                     size_t data_len = sizeof(TEST_DATA))
+    {
+        isc::util::OutputBuffer obuffer(0);
+        obuffer.writeUint32(static_cast<uint32_t>(family));
+        obuffer.writeUint32(static_cast<uint32_t>(type));
+        obuffer.writeUint32(static_cast<uint32_t>(protocol));
+        obuffer.writeUint32(static_cast<uint32_t>(local_len));
+        obuffer.writeData(&local, getSALength(local));
+        obuffer.writeUint32(static_cast<uint32_t>(remote_len));
+        obuffer.writeData(&remote, getSALength(remote));
+        obuffer.writeUint32(static_cast<uint32_t>(data_len));
+        pushSessionHeader(obuffer.getLength());
+        send(dummy_forwarder_.fd, obuffer.getData(), obuffer.getLength(), 0);
+        send(dummy_forwarder_.fd, TEST_DATA, sizeof(TEST_DATA), 0);
+    }
+
     // See below
     void checkPushAndPop(int family, int type, int protocoal,
                          const SockAddrInfo& local,
@@ -245,6 +321,8 @@ protected:
 protected:
     int listen_fd_;
     SocketSessionForwarder forwarder_;
+    ScopedSocket dummy_forwarder_; // forwarder "like" socket to pass bad data
+    scoped_ptr<SocketSessionReceptor> receptor_;
     ScopedSocket accept_sock_;
     const string large_text_;
 
@@ -380,11 +458,6 @@ ForwarderTest::checkPushAndPop(int family, int type, int protocol,
         startListen();
         forwarder_.connectToReceptor();
         accept_sock_.reset(acceptForwarder());
-
-        // Make sure the socket is *blocking*.  We may pass large data, through
-        // it, and apparently non blocking read could cause some unexpected
-        // partial read on some systems.
-        setNonBlock(accept_sock_.fd, false);
     }
 
     // Then push one socket session via the forwarder.
@@ -588,6 +661,99 @@ TEST_F(ForwarderTest, pushTooFast) {
                  SocketSessionError);
 }
 
+TEST_F(ForwarderTest, badPop) {
+    startListen();
+
+    // Close the forwarder socket before pop() without sending anything.
+    pushSessionHeader(0, 0, false);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Pretending to be a forwarder but don't actually pass FD.
+    pushSessionHeader(0, 1, false);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Pass a valid FD (stdin), but provide short data for the hdrlen
+    pushSessionHeader(0, 1);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Pass a valid FD, but provides too large hdrlen
+    pushSessionHeader(0xffff);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Don't provide full header
+    pushSessionHeader(sizeof(uint32_t));
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Pushed header is too short
+    const uint8_t dummy_data = 0;
+    pushSessionHeader(1);
+    send(dummy_forwarder_.fd, &dummy_data, 1, 0);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // socket addresses commonly used below (the values don't matter).
+    const SockAddrInfo sai_local(getSockAddr("192.0.2.1", "53535"));
+    const SockAddrInfo sai_remote(getSockAddr("192.0.2.2", "53536"));
+    const SockAddrInfo sai6(getSockAddr("2001:db8::1", "53537"));
+
+    // Pass invalid address family (AF_UNSPEC)
+    pushSession(AF_UNSPEC, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sai_remote.second, *sai_remote.first);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Pass inconsistent address family for local
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai6.second,
+                *sai6.first, sai_remote.second, *sai_remote.first);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Same for remote
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sai6.second, *sai6.first);
+    dummy_forwarder_.reset(-1);
+
+    // Pass too big sa length for local
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP,
+                sizeof(struct sockaddr_storage) + 1, *sai_local.first,
+                sai_remote.second, *sai_remote.first);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Same for remote
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sizeof(struct sockaddr_storage) + 1,
+                *sai_remote.first);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Data length is too large
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sai_remote.second,
+                *sai_remote.first, 65536);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Empty data
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sai_remote.second,
+                *sai_remote.first, 0);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+
+    // Not full data are passed
+    pushSession(AF_INET, SOCK_DGRAM, IPPROTO_UDP, sai_local.second,
+                *sai_local.first, sai_remote.second,
+                *sai_remote.first, sizeof(TEST_DATA) + 1);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receptor_->pop(), SocketSessionError);
+}
+
 TEST(SocketSession, badValue) {
     // normal cases are confirmed in ForwarderTest.  We only check some
     // abnormal cases here.