Parcourir la source

[1820] use the ScopedSocket wrapper throughout the code for more safety.

JINMEI Tatuya il y a 13 ans
Parent
commit
9e77d650e3
1 fichiers modifiés avec 43 ajouts et 33 suppressions
  1. 43 33
      src/lib/resolve/tests/recursive_query_unittest.cc

+ 43 - 33
src/lib/resolve/tests/recursive_query_unittest.cc

@@ -20,6 +20,7 @@
 
 #include <cstring>
 
+#include <boost/noncopyable.hpp>
 #include <boost/lexical_cast.hpp>
 #include <boost/bind.hpp>
 #include <boost/scoped_ptr.hpp>
@@ -129,13 +130,22 @@ struct ScopedAddrInfo {
 
 // Similar to ScopedAddrInfo but for socket FD.  It also supports the "release"
 // operation so it can release the ownership of the FD.
-struct ScopedSocket {
+// This is made non copyable to avoid making an accidental copy, which could
+// result in duplicate close.
+struct ScopedSocket : private boost::noncopyable {
+    ScopedSocket() : s_(-1) {}
     ScopedSocket(int s) : s_(s) {}
     ~ScopedSocket() {
         if (s_ >= 0) {
             close(s_);
         }
     }
+    void reset(int new_s) {
+        if (s_ >= 0) {
+            close(s_);
+        }
+        s_ = new_s;
+    }
     int release() {
         int s = s_;
         s_ = -1;
@@ -163,9 +173,6 @@ protected:
         // It would delete itself, but after the io_service_, which could
         // segfailt in case there were unhandled requests
         resolver_.reset();
-        if (sock_ != -1) {
-            close(sock_);
-        }
     }
 
     // Send a test UDP packet to a mock server
@@ -173,11 +180,12 @@ protected:
         ScopedAddrInfo sai(resolveAddress(family, IPPROTO_UDP, false));
         struct addrinfo* res = sai.res_;
 
-        sock_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-        if (sock_ < 0) {
+        sock_.reset(socket(res->ai_family, res->ai_socktype,
+                           res->ai_protocol));
+        if (sock_.s_ < 0) {
             isc_throw(IOError, "failed to open test socket");
         }
-        const int cc = sendto(sock_, test_data, sizeof(test_data), 0,
+        const int cc = sendto(sock_.s_, test_data, sizeof(test_data), 0,
                               res->ai_addr, res->ai_addrlen);
         if (cc != sizeof(test_data)) {
             isc_throw(IOError, "unexpected sendto result: " << cc);
@@ -190,14 +198,15 @@ protected:
         ScopedAddrInfo sai(resolveAddress(family, IPPROTO_TCP, false));
         struct addrinfo* res = sai.res_;
 
-        sock_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-        if (sock_ < 0) {
+        sock_.reset(socket(res->ai_family, res->ai_socktype,
+                           res->ai_protocol));
+        if (sock_.s_ < 0) {
             isc_throw(IOError, "failed to open test socket");
         }
-        if (connect(sock_, res->ai_addr, res->ai_addrlen) < 0) {
+        if (connect(sock_.s_, res->ai_addr, res->ai_addrlen) < 0) {
             isc_throw(IOError, "failed to connect to the test server");
         }
-        const int cc = send(sock_, test_data, sizeof(test_data), 0);
+        const int cc = send(sock_.s_, test_data, sizeof(test_data), 0);
         if (cc != sizeof(test_data)) {
             isc_throw(IOError, "unexpected send result: " << cc);
         }
@@ -211,12 +220,13 @@ protected:
         ScopedAddrInfo sai(resolveAddress(family, IPPROTO_UDP, true));
         struct addrinfo* res = sai.res_;
 
-        sock_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-        if (sock_ < 0) {
+        sock_.reset(socket(res->ai_family, res->ai_socktype,
+                           res->ai_protocol));
+        if (sock_.s_ < 0) {
             isc_throw(IOError, "failed to open test socket");
         }
 
-        if (bind(sock_, res->ai_addr, res->ai_addrlen) < 0) {
+        if (bind(sock_.s_, res->ai_addr, res->ai_addrlen) < 0) {
             isc_throw(IOError, "bind failed: " << strerror(errno));
         }
 
@@ -236,7 +246,7 @@ protected:
         // we add an ad hoc timeout.
         const struct timeval timeo = { 10, 0 };
         int recv_options = 0;
-        if (setsockopt(sock_, SOL_SOCKET, SO_RCVTIMEO, &timeo,
+        if (setsockopt(sock_.s_, SOL_SOCKET, SO_RCVTIMEO, &timeo,
                        sizeof(timeo))) {
             if (errno == ENOPROTOOPT) {
                 // Workaround for Solaris: it doesn't accept SO_RCVTIMEO
@@ -249,7 +259,7 @@ protected:
                 isc_throw(IOError, "set RCVTIMEO failed: " << strerror(errno));
             }
         }
-        const int ret = recv(sock_, buffer, size, recv_options);
+        const int ret = recv(sock_.s_, buffer, size, recv_options);
         if (ret < 0) {
             isc_throw(IOError, "recvfrom failed: " << strerror(errno));
         }
@@ -338,7 +348,7 @@ protected:
         // There doesn't seem to be an effective test for the validity of
         // 'native'.
         // One thing we are sure is it must be different from our local socket.
-        EXPECT_NE(sock_, callback_native_);
+        EXPECT_NE(sock_.s_, callback_native_);
         EXPECT_EQ(protocol, callback_protocol_);
         EXPECT_EQ(family == AF_INET6 ? TEST_IPV6_ADDR : TEST_IPV4_ADDR,
                   callback_address_);
@@ -495,14 +505,13 @@ protected:
     int callback_native_;
     string callback_address_;
     vector<uint8_t> callback_data_;
-    int sock_;
+    ScopedSocket sock_;
     boost::shared_ptr<isc::util::unittests::TestResolver> resolver_;
 };
 
 RecursiveQueryTest::RecursiveQueryTest() :
     dns_service_(NULL), callback_(NULL), callback_protocol_(0),
-    callback_native_(-1), sock_(-1),
-    resolver_(new isc::util::unittests::TestResolver())
+    callback_native_(-1), resolver_(new isc::util::unittests::TestResolver())
 {
     io_service_.reset(new IOService());
     setDNSService(true, true);
@@ -671,38 +680,39 @@ createTestSocket() {
     ScopedAddrInfo sai(resolveAddress(AF_INET, IPPROTO_UDP, true));
     struct addrinfo* res = sai.res_;
 
-    int sock_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
-    if (sock_ < 0) {
+    ScopedSocket sock(socket(res->ai_family, res->ai_socktype,
+                             res->ai_protocol));
+    if (sock.s_ < 0) {
         isc_throw(IOError, "failed to open test socket");
     }
-    if (bind(sock_, res->ai_addr, res->ai_addrlen) < 0) {
+    if (bind(sock.s_, res->ai_addr, res->ai_addrlen) < 0) {
         isc_throw(IOError, "failed to bind test socket");
     }
-    return (sock_);
+    return (sock.release());
 }
 
 int
-setSocketTimeout(int sock_, size_t tv_sec, size_t tv_usec) {
+setSocketTimeout(int sock, size_t tv_sec, size_t tv_usec) {
     const struct timeval timeo = { tv_sec, tv_usec };
     int recv_options = 0;
-    if (setsockopt(sock_, SOL_SOCKET, SO_RCVTIMEO, &timeo, sizeof(timeo))) {
+    if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeo, sizeof(timeo))) {
         if (errno == ENOPROTOOPT) { // see RecursiveQueryTest::recvUDP()
             recv_options = MSG_DONTWAIT;
         } else {
             isc_throw(IOError, "set RCVTIMEO failed: " << strerror(errno));
         }
     }
-    return recv_options;
+    return (recv_options);
 }
 
 // try to read from the socket max time
 // *num is incremented for every succesfull read
 // returns true if it can read max times, false otherwise
-bool tryRead(int sock_, int recv_options, size_t max, int* num) {
+bool tryRead(int sock, int recv_options, size_t max, int* num) {
     size_t i = 0;
     do {
         char inbuff[512];
-        if (recv(sock_, inbuff, sizeof(inbuff), recv_options) < 0) {
+        if (recv(sock, inbuff, sizeof(inbuff), recv_options) < 0) {
             return false;
         } else {
             ++i;
@@ -752,7 +762,7 @@ TEST_F(RecursiveQueryTest, forwardQueryTimeout) {
     setDNSService();
 
     // Prepare the socket
-    sock_ = createTestSocket();
+    sock_.reset(createTestSocket());
 
     // Prepare the server
     bool done(true);
@@ -786,7 +796,7 @@ TEST_F(RecursiveQueryTest, forwardClientTimeout) {
     // Prepare the service (we do not use the common setup, we do not answer
     setDNSService();
 
-    sock_ = createTestSocket();
+    sock_.reset(createTestSocket());
 
     // Prepare the server
     bool done1(true);
@@ -820,7 +830,7 @@ TEST_F(RecursiveQueryTest, forwardLookupTimeout) {
     setDNSService();
 
     // Prepare the socket
-    sock_ = createTestSocket();
+    sock_.reset(createTestSocket());
 
     // Prepare the server
     bool done(true);
@@ -855,7 +865,7 @@ TEST_F(RecursiveQueryTest, lowtimeouts) {
     setDNSService();
 
     // Prepare the socket
-    sock_ = createTestSocket();
+    sock_.reset(createTestSocket());
 
     // Prepare the server
     bool done(true);