Parcourir la source

overall cleanup and minor bug fixes for xfrout:
- pass xfr request data length in network byte order, and receive it as such. receiving this as native data (unpack('H')) would be a bit naive because it may assume some padding.
- avoid cast as a bonus side effect of the first fix
- make cmsg space handling more portable. use CMSG_SPACE when possible; otherwise calculate the space using dummy data
- overall simplied fd_share code: trying to handle multiple FDs is not correct in terms of API, and in any case we don't need that; verify cmsg_level and cmsg_type on reception; avoid unnecessary data initialization; stop naively assume the exact memory allocation of cmsg_data (assuming no padding) - use cmsg_space() so that padding, if any, will be taken intou account
- catch the error case of recv_fd failure in XfroutSession.handle() explicitly, and raise an exception


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1694 e5f2f494-b856-4b98-b285-d166d9295462

JINMEI Tatuya il y a 15 ans
Parent
commit
ce604aaf99
4 fichiers modifiés avec 87 ajouts et 92 suppressions
  1. 4 3
      src/bin/xfrout/xfrout.py.in
  2. 73 69
      src/lib/xfr/fd_share.cc
  3. 10 17
      src/lib/xfr/xfrout_client.cc
  4. 0 3
      src/lib/xfr/xfrout_client.h

+ 4 - 3
src/bin/xfrout/xfrout.py.in

@@ -53,8 +53,10 @@ class XfroutException(Exception): pass
 class XfroutSession(BaseRequestHandler):
     def handle(self):
         fd = recv_fd(self.request.fileno())
+        if fd < 0:
+            raise XfroutException("failed to receive the FD for XFR connection")
         data_len = self.request.recv(2)
-        msg_len = struct.unpack('H', data_len)[0]
+        msg_len = struct.unpack('!H', data_len)[0]
         msgdata = self.request.recv(msg_len)
         sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
         try:
@@ -64,7 +66,7 @@ class XfroutSession(BaseRequestHandler):
                 self.log_msg(str(e))
 
         sock.close()
-               
+
     def _parse_query_message(self, mdata):
         ''' parse query message to [socket,message]'''
         #TODO, need to add parseHeader() in case the message header is invalid 
@@ -78,7 +80,6 @@ class XfroutSession(BaseRequestHandler):
 
         return rcode.NOERROR(), msg
 
-
     def _get_query_zone_name(self, msg):
         q_iter = question_iter(msg)
         question = q_iter.get_question()

+ 73 - 69
src/lib/xfr/fd_share.cc

@@ -14,7 +14,9 @@
 
 // $Id$
 
-#include <stdlib.h>
+#include <cstring>
+#include <cstdlib>
+
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/uio.h>
@@ -23,98 +25,100 @@
 namespace isc {
 namespace xfr {
 
-#define FD_BUFFER_CREATE(n) \
-    struct { \
-        struct cmsghdr h; \
-        int fd[n]; \
+namespace {
+// Not all OSes support advanced CMSG macros: CMSG_LEN and CMSG_SPACE.
+// In order to ensure as much portability as possible, we provide wrapper
+// functions of these macros.
+// Note that cmsg_space() could run slow on OSes that do not have
+// CMSG_SPACE.
+inline socklen_t
+cmsg_space(socklen_t len) {
+#ifdef CMSG_SPACE
+    return (CMSG_SPACE(len));
+#else
+    struct msghdr msg;
+    struct cmsghdr* cmsgp;
+    // XXX: The buffer length is an ad hoc value, but should be enough
+    // in a practical sense.
+    char dummybuf[sizeof(struct cmsghdr) + 1024];
+
+    memset(&msg, 0, sizeof(msg));
+    msg.msg_control = dummybuf;
+    msg.msg_controllen = sizeof(dummybuf);
+
+    cmsgp = (struct cmsghdr*)dummybuf;
+    cmsgp->cmsg_len = cmsg_len(len);
+
+    cmsgp = CMSG_NXTHDR(&msg, cmsgp);
+    if (cmsgp != NULL) {
+        return ((char*)cmsgp - (char*)msg.msg_control);
+    } else {
+        return (0);
     }
+#endif  // CMSG_SPACE
+}
+}
 
-namespace {
 int
-send_fds_with_buffer(const int sock, const int* fds, const unsigned n_fds,
-                     void* buffer)
-{
+recv_fd(const int sock) {
     struct msghdr msghdr;
-    char nothing = '!';
-    struct iovec nothing_ptr;
-    struct cmsghdr* cmsg;
+    struct iovec iov_dummy;
+    unsigned char dummy_data;
 
-    nothing_ptr.iov_base = &nothing;
-    nothing_ptr.iov_len = 1;
+    iov_dummy.iov_base = &dummy_data;
+    iov_dummy.iov_len = sizeof(dummy_data);
     msghdr.msg_name = NULL;
     msghdr.msg_namelen = 0;
-    msghdr.msg_iov = &nothing_ptr;
+    msghdr.msg_iov = &iov_dummy;
     msghdr.msg_iovlen = 1;
     msghdr.msg_flags = 0;
-    msghdr.msg_control = buffer;
-    msghdr.msg_controllen = sizeof(struct cmsghdr) + sizeof(int) * n_fds;
-    cmsg = CMSG_FIRSTHDR(&msghdr);
-    cmsg->cmsg_len = msghdr.msg_controllen;
-    cmsg->cmsg_level = SOL_SOCKET;
-    cmsg->cmsg_type = SCM_RIGHTS;
-    for (int i = 0; i < n_fds; ++i) {
-        ((int *)CMSG_DATA(cmsg))[i] = fds[i];
+    msghdr.msg_controllen = cmsg_space(sizeof(int));
+    msghdr.msg_control = malloc(msghdr.msg_controllen);
+    if (msghdr.msg_control == NULL) {
+        return (-1);
     }
 
-    const int ret =  sendmsg(sock, &msghdr, 0);
-    return (ret >= 0 ? 0 : -1);
+    if (recvmsg(sock, &msghdr, 0) < 0) {
+        free(msghdr.msg_control);
+        return (-1);
+    }
+    const struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
+    int fd = -1;
+    if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
+        fd = *(const int *)CMSG_DATA(cmsg);
+    }
+    free(msghdr.msg_control);
+    return (fd);
 }
 
 int
-recv_fds_with_buffer(const int sock, int* fds, const unsigned n_fds,
-                     void* buffer)
-{
+send_fd(const int sock, const int fd) {
     struct msghdr msghdr;
-    char nothing;
-    struct iovec nothing_ptr;
-    struct cmsghdr *cmsg;
-    int i;
+    struct iovec iov_dummy;
+    unsigned char dummy_data = 0;
 
-    nothing_ptr.iov_base = &nothing;
-    nothing_ptr.iov_len = 1;
+    iov_dummy.iov_base = &dummy_data;
+    iov_dummy.iov_len = sizeof(dummy_data);
     msghdr.msg_name = NULL;
     msghdr.msg_namelen = 0;
-    msghdr.msg_iov = &nothing_ptr;
+    msghdr.msg_iov = &iov_dummy;
     msghdr.msg_iovlen = 1;
     msghdr.msg_flags = 0;
-    msghdr.msg_control = buffer;
-    msghdr.msg_controllen = sizeof(struct cmsghdr) + sizeof(int) * n_fds;
-    cmsg = CMSG_FIRSTHDR(&msghdr);
-    cmsg->cmsg_len = msghdr.msg_controllen;
-    cmsg->cmsg_level = SOL_SOCKET;
-    cmsg->cmsg_type = SCM_RIGHTS;
-    for (i = 0; i < n_fds; i++) {
-        ((int *)CMSG_DATA(cmsg))[i] = -1;
-    }
-
-    if (recvmsg(sock, &msghdr, 0) < 0) {
+    msghdr.msg_controllen = cmsg_space(sizeof(int));
+    msghdr.msg_control = malloc(msghdr.msg_controllen);
+    if (msghdr.msg_control == NULL) {
         return (-1);
     }
 
-    for (i = 0; i < n_fds; i++) {
-        fds[i] = ((int *)CMSG_DATA(cmsg))[i];
-    }
-
-    return ((msghdr.msg_controllen - sizeof(struct cmsghdr)) / sizeof(int));
-}
-}
-
-int
-recv_fd(const int sock) {
-    FD_BUFFER_CREATE(1) buffer;
-    int fd = 0;
-    if (recv_fds_with_buffer(sock, &fd, 1, &buffer) == -1) {
-        return -1;
-    }
-
-    return fd;
-}
+    struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msghdr);
+    cmsg->cmsg_len = msghdr.msg_controllen;
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_type = SCM_RIGHTS;
+    *(int *)CMSG_DATA(cmsg) = fd;
 
-int
-send_fd(const int sock, const int fd) {
-    FD_BUFFER_CREATE(1) buffer;
-    int ret = send_fds_with_buffer(sock, &fd, 1, &buffer);
-    return ((ret < 0) ? -1 : ret);
+    const int ret = sendmsg(sock, &msghdr, 0);
+    free(msghdr.msg_control);
+    return (ret >= 0 ? 0 : -1);
 }
 
 } // End for namespace xfr

+ 10 - 17
src/lib/xfr/xfrout_client.cc

@@ -35,21 +35,6 @@ XfroutClient::disconnect() {
     socket_.close();
 }
 
-void
-XfroutClient::sendData(const uint8_t* msg_data, const uint16_t msg_len) {
-    int count = 0;
-    while (count < msg_len) {
-        const int size = send(socket_.native(), msg_data + count,
-                              msg_len - count, 0);
-        if (size == -1) {
-            isc_throw(XfroutError, "auth failed to send data to xfrout module");
-        }
-        count += size;
-    }
-
-    return;
-}
-
 int 
 XfroutClient::sendXfroutRequestInfo(const int tcp_sock, uint8_t* msg_data,
                                     const uint16_t msg_len)
@@ -59,8 +44,16 @@ XfroutClient::sendXfroutRequestInfo(const int tcp_sock, uint8_t* msg_data,
                   "Fail to send socket descriptor to xfrout module");
     }
 
-    sendData((uint8_t*)&msg_len, 2);
-    sendData(msg_data, msg_len);
+    // XXX: this shouldn't be blocking send, even though it's unlikely to block.
+    const uint8_t lenbuf[2] = { msg_len >> 8, msg_len & 0xff };
+    if (send(socket_.native(), lenbuf, sizeof(lenbuf), 0) != sizeof(lenbuf)) {
+        isc_throw(XfroutError,
+                  "failed to send XFR request length to xfrout module");
+    }
+    if (send(socket_.native(), msg_data, msg_len, 0) != msg_len) {
+        isc_throw(XfroutError,
+                  "failed to send XFR request data to xfrout module");
+    }
     
     int databuf = 0;
     if (recv(socket_.native(), &databuf, sizeof(int), 0) != 0) {

+ 0 - 3
src/lib/xfr/xfrout_client.h

@@ -43,9 +43,6 @@ public:
                               uint16_t msg_len);
 
 private:
-    void sendData(const uint8_t *msg_data, uint16_t msg_len);
-
-private:
     boost::asio::io_service io_service_;
     // The socket used to communicate with the xfrout server.
     stream_protocol::socket socket_;