Browse Source

[1452] fixed another error handling issue: make sure the passed FD (if
it succeeds) is closed if an exception is thrown in the middle of
SocketSessionReceiver::pop(). Also clarified the responsibility of
the caller of pop on which resource it should release which it shouldn't.

JINMEI Tatuya 13 years ago
parent
commit
f709af9e07

+ 24 - 4
src/lib/util/io/socketsession.cc

@@ -33,6 +33,8 @@
 #include <string>
 #include <string>
 #include <vector>
 #include <vector>
 
 
+#include <boost/noncopyable.hpp>
+
 #include <exceptions/exceptions.h>
 #include <exceptions/exceptions.h>
 
 
 #include <util/buffer.h>
 #include <util/buffer.h>
@@ -320,15 +322,33 @@ readFail(int actual_len, int expected_len) {
               "SocketSessionForwarder: " << actual_len << "/" <<
               "SocketSessionForwarder: " << actual_len << "/" <<
               expected_len);
               expected_len);
 }
 }
+
+// A helper container for a (socket) file descriptor used in
+// SocketSessionReceiver::pop that ensures the socket is closed unless it
+// can be safely passed to the caller via release().
+struct ScopedSocket : boost::noncopyable {
+    ScopedSocket(int fd) : fd_(fd) {}
+    ~ScopedSocket() {
+        if (fd_ >= 0) {
+            close(fd_);
+        }
+    }
+    int release() {
+        const int fd = fd_;
+        fd_ = -1;
+        return (fd);
+    }
+    int fd_;
+};
 }
 }
 
 
 SocketSession
 SocketSession
 SocketSessionReceiver::pop() {
 SocketSessionReceiver::pop() {
-    const int passed_fd = recv_fd(impl_->fd_);
-    if (passed_fd == FD_SYSTEM_ERROR) {
+    ScopedSocket passed_sock(recv_fd(impl_->fd_));
+    if (passed_sock.fd_ == FD_SYSTEM_ERROR) {
         isc_throw(SocketSessionError, "Receiving a forwarded FD failed: " <<
         isc_throw(SocketSessionError, "Receiving a forwarded FD failed: " <<
                   strerror(errno));
                   strerror(errno));
-    } else if (passed_fd < 0) {
+    } else if (passed_sock.fd_ < 0) {
         isc_throw(SocketSessionError, "No FD forwarded");
         isc_throw(SocketSessionError, "No FD forwarded");
     }
     }
 
 
@@ -398,7 +418,7 @@ SocketSessionReceiver::pop() {
             readFail(cc_data, data_len);
             readFail(cc_data, data_len);
         }
         }
 
 
-        return (SocketSession(passed_fd, family, type, protocol,
+        return (SocketSession(passed_sock.release(), family, type, protocol,
                               impl_->sa_local_, impl_->sa_remote_,
                               impl_->sa_local_, impl_->sa_remote_,
                               &impl_->data_buf_[0], data_len));
                               &impl_->data_buf_[0], data_len));
     } catch (const InvalidBufferPosition& ex) {
     } catch (const InvalidBufferPosition& ex) {

+ 8 - 0
src/lib/util/io/socketsession.h

@@ -421,6 +421,14 @@ public:
     /// this method is called or until the \c SocketSessionReceiver object is
     /// this method is called or until the \c SocketSessionReceiver object is
     /// destructed.
     /// destructed.
     ///
     ///
+    /// The caller is responsible for closing the received socket (whose
+    /// file descriptor is accessible via \c SocketSession::getSocket()).
+    /// If the caller copies the returned \c SocketSession object, it's also
+    /// responsible for making sure the descriptor is closed at most once.
+    /// On the other hand, the caller is not responsible for freeing the
+    /// socket session data (accessible via \c SocketSession::getData());
+    /// the \c SocketSessionReceiver object will clean it up automatically.
+    ///
     /// It ensures the following:
     /// It ensures the following:
     /// - The address family is either \c AF_INET or \c AF_INET6
     /// - The address family is either \c AF_INET or \c AF_INET6
     /// - The address family (\c sa_family) member of the local and remote
     /// - The address family (\c sa_family) member of the local and remote

+ 18 - 2
src/lib/util/tests/socketsession_unittest.cc

@@ -280,13 +280,14 @@ protected:
     //                 but can be false for testing.
     //                 but can be false for testing.
     void pushSessionHeader(uint16_t hdrlen,
     void pushSessionHeader(uint16_t hdrlen,
                            size_t hdrlen_len = sizeof(uint16_t),
                            size_t hdrlen_len = sizeof(uint16_t),
-                           bool push_fd = true)
+                           bool push_fd = true,
+                           int fd = 0)
     {
     {
         isc::util::OutputBuffer obuffer(0);
         isc::util::OutputBuffer obuffer(0);
         obuffer.clear();
         obuffer.clear();
 
 
         dummy_forwarder_.reset(dummyConnect());
         dummy_forwarder_.reset(dummyConnect());
-        if (push_fd && send_fd(dummy_forwarder_.fd, 0) != 0) {
+        if (push_fd && send_fd(dummy_forwarder_.fd, fd) != 0) {
             isc_throw(isc::Unexpected, "Failed to pass FD");
             isc_throw(isc::Unexpected, "Failed to pass FD");
         }
         }
         obuffer.writeUint16(hdrlen);
         obuffer.writeUint16(hdrlen);
@@ -802,6 +803,21 @@ TEST_F(ForwardTest, badPop) {
                 *sai_remote.first, sizeof(TEST_DATA) + 1);
                 *sai_remote.first, sizeof(TEST_DATA) + 1);
     dummy_forwarder_.reset(-1);
     dummy_forwarder_.reset(-1);
     EXPECT_THROW(receiver_->pop(), SocketSessionError);
     EXPECT_THROW(receiver_->pop(), SocketSessionError);
+
+    // Check the forwarded FD is closed on failure
+    ScopedSocket sock(createSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP,
+                                   getSockAddr("127.0.0.1", TEST_PORT),
+                                   false));
+    pushSessionHeader(0, 1, true, sock.fd);
+    dummy_forwarder_.reset(-1);
+    EXPECT_THROW(receiver_->pop(), SocketSessionError);
+    // Close the original socket
+    sock.reset(-1);
+    // The passed one should have been closed, too, so we should be able
+    // to bind a new socket to the same port.
+    ScopedSocket(createSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP,
+                              getSockAddr("127.0.0.1", TEST_PORT),
+                              false));
 }
 }
 
 
 TEST(SocketSessionTest, badValue) {
 TEST(SocketSessionTest, badValue) {