Browse Source

completed IOSocket and IOMessage with tests.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac221@2175 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya 15 years ago
parent
commit
92fbe8e211
3 changed files with 312 additions and 26 deletions
  1. 117 16
      src/bin/auth/asio_link.cc
  2. 30 10
      src/bin/auth/asio_link.h
  3. 165 0
      src/bin/auth/tests/asio_link_unittest.cc

+ 117 - 16
src/bin/auth/asio_link.cc

@@ -16,6 +16,9 @@
 
 #include <config.h>
 
+#include <sys/socket.h>
+#include <netinet/in.h>
+
 #include <asio.hpp>
 #include <boost/bind.hpp>
 
@@ -34,8 +37,8 @@
 #include "auth_srv.h"
 
 using namespace asio;
-using ip::udp;
-using ip::tcp;
+using asio::ip::udp;
+using asio::ip::tcp;
 
 using namespace std;
 using namespace isc::dns;
@@ -120,6 +123,36 @@ IOAddress::toText() const {
     return (asio_address_.to_string());
 }
 
+class TCPSocket : public IOSocket {
+private:
+    TCPSocket(const TCPSocket& source);
+    TCPSocket& operator=(const TCPSocket& source);
+public:
+    TCPSocket(tcp::socket& socket) : socket_(socket) {}
+    virtual int getNative() const { return (socket_.native()); }
+    virtual int getProtocol() const { return (IPPROTO_TCP); }
+private:
+    tcp::socket& socket_;
+};
+
+class UDPSocket : public IOSocket {
+private:
+    UDPSocket(const UDPSocket& source);
+    UDPSocket& operator=(const UDPSocket& source);
+public:
+    UDPSocket(udp::socket& socket) : socket_(socket) {}
+    virtual int getNative() const { return (socket_.native()); }
+    virtual int getProtocol() const { return (IPPROTO_UDP); }
+private:
+    udp::socket& socket_;
+};
+
+IOMessage::IOMessage(const void* data, const size_t data_size,
+                     IOSocket& io_socket, const ip::address& remote_address) :
+    data_(data), data_size_(data_size), io_socket_(io_socket),
+    remote_io_address_(remote_address)
+{}
+
 //
 // Helper classes for asynchronous I/O using asio
 //
@@ -128,15 +161,18 @@ public:
     TCPClient(AuthSrv* auth_server, io_service& io_service) :
         auth_server_(auth_server),
         socket_(io_service),
+        io_socket_(socket_),
         response_buffer_(0),
         responselen_buffer_(TCP_MESSAGE_LENGTHSIZE),
         response_renderer_(response_buffer_),
-        dns_message_(Message::PARSE)
+        dns_message_(Message::PARSE),
+        custom_callback_(NULL)
     {}
 
     void start() {
         // Check for queued configuration commands
-        if (auth_server_->configSession()->hasQueuedMsgs()) {
+        if (auth_server_ != NULL &&
+            auth_server_->configSession()->hasQueuedMsgs()) {
             auth_server_->configSession()->checkCommand();
         }
         async_read(socket_, asio::buffer(data_, TCP_MESSAGE_LENGTHSIZE),
@@ -145,7 +181,7 @@ public:
                                placeholders::bytes_transferred));
     }
 
-    tcp::socket& getSocket() { return (socket_); }
+    ip::tcp::socket& getSocket() { return (socket_); }
 
     void headerRead(const asio::error_code& error,
                     size_t bytes_transferred)
@@ -168,6 +204,15 @@ public:
                      size_t bytes_transferred)
     {
         if (!error) {
+            const IOMessage io_message(data_, bytes_transferred, io_socket_,
+                                       socket_.remote_endpoint().address());
+            // currently, for testing purpose only
+            if (custom_callback_ != NULL) {
+                (*custom_callback_)(io_message);
+                start();
+                return;
+            }
+
             InputBuffer dnsbuffer(data_, bytes_transferred);
 #ifdef USE_XFROUT
             if (check_axfr_query(data_, bytes_transferred)) {
@@ -176,14 +221,6 @@ public:
                 start();
             } else {
 #endif
-#ifdef notyet
-                IOMessage io_message(data_, bytes_transferred,
-                                     remote_endpoint, socket);
-                if (auth_server_->processMessage(IOMessage(message), ..)) {
-                    //...
-                    message.getIOService().
-                }
-#endif
                 if (auth_server_->processMessage(dnsbuffer, dns_message_,
                                                 response_renderer_, false)) {
                     responselen_buffer_.writeUint16(
@@ -225,9 +262,15 @@ public:
       }
     }
 
+    // Currently this is for tests only
+    void setCallBack(const IOService::IOCallBack* callback) {
+        custom_callback_ = callback;
+    }
+
 private:
     AuthSrv* auth_server_;
     tcp::socket socket_;
+    TCPSocket io_socket_;
     OutputBuffer response_buffer_;
     OutputBuffer responselen_buffer_;
     MessageRenderer response_renderer_;
@@ -235,6 +278,9 @@ private:
     enum { MAX_LENGTH = 65535 };
     static const size_t TCP_MESSAGE_LENGTHSIZE = 2;
     char data_[MAX_LENGTH];
+
+    // currently, for testing purpose only.
+    const IOService::IOCallBack* custom_callback_;
 };
 
 class TCPServer {
@@ -243,7 +289,8 @@ public:
               int af, short port) :
         auth_server_(auth_server), io_service_(io_service),
         acceptor_(io_service_), listening_(new TCPClient(auth_server_,
-                                                         io_service_))
+                                                         io_service_)),
+        custom_callback_(NULL)
     {
         tcp::endpoint endpoint(af == AF_INET6 ? tcp::v6() : tcp::v4(), port);
         acceptor_.open(endpoint.protocol());
@@ -267,6 +314,7 @@ public:
     {
         if (!error) {
             assert(new_client == listening_);
+            new_client->setCallBack(custom_callback_);
             new_client->start();
             listening_ = new TCPClient(auth_server_, io_service_);
             acceptor_.async_accept(listening_->getSocket(),
@@ -278,11 +326,19 @@ public:
         }
     }
 
+    // Currently this is for tests only
+    void setCallBack(const IOService::IOCallBack* callback) {
+        custom_callback_ = callback;
+    }
+
 private:
     AuthSrv* auth_server_;
     io_service& io_service_;
     tcp::acceptor acceptor_;
     TCPClient* listening_;
+
+    // currently, for testing purpose only.
+    const IOService::IOCallBack* custom_callback_;
 };
 
 class UDPServer {
@@ -292,9 +348,11 @@ public:
         auth_server_(auth_server),
         io_service_(io_service),
         socket_(io_service, af == AF_INET6 ? udp::v6() : udp::v4()),
+        io_socket_(socket_),
         response_buffer_(0),
         response_renderer_(response_buffer_),
-        dns_message_(Message::PARSE)
+        dns_message_(Message::PARSE),
+        custom_callback_(NULL)
     {
         // Set v6-only (we use a different instantiation for v4,
         // otherwise asio will bind to both v4 and v6
@@ -311,10 +369,20 @@ public:
                        size_t bytes_recvd)
     {
         // Check for queued configuration commands
-        if (auth_server_->configSession()->hasQueuedMsgs()) {
+        if (auth_server_ != NULL &&
+            auth_server_->configSession()->hasQueuedMsgs()) {
             auth_server_->configSession()->checkCommand();
         }
         if (!error && bytes_recvd > 0) {
+            const IOMessage io_message(data_, bytes_recvd, io_socket_,
+                                       sender_endpoint_.address());
+            // currently, for testing purpose only
+            if (custom_callback_ != NULL) {
+                (*custom_callback_)(io_message);
+                startReceive();
+                return;
+            }
+
             InputBuffer request_buffer(data_, bytes_recvd);
 
             dns_message_.clear(Message::PARSE);
@@ -344,6 +412,11 @@ public:
         // the next request.
         startReceive();
     }
+
+    // Currently this is for tests only
+    void setCallBack(const IOService::IOCallBack* callback) {
+        custom_callback_ = callback;
+    }
 private:
     void startReceive() {
         socket_.async_receive_from(
@@ -357,12 +430,16 @@ private:
     AuthSrv* auth_server_;
     io_service& io_service_;
     udp::socket socket_;
+    UDPSocket io_socket_;
     OutputBuffer response_buffer_;
     MessageRenderer response_renderer_;
     Message dns_message_;
     udp::endpoint sender_endpoint_;
     enum { MAX_LENGTH = 4096 };
     char data_[MAX_LENGTH];
+
+    // currently, for testing purpose only.
+    const IOService::IOCallBack* custom_callback_;
 };
 
 // This is a helper structure just to make the construction of IOServiceImpl
@@ -396,6 +473,9 @@ public:
     UDPServer* udp6_server_;
     TCPServer* tcp4_server_;
     TCPServer* tcp6_server_;
+
+    // This member is used only for testing at the moment.
+    IOService::IOCallBack callback_;
 };
 
 IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char* const port,
@@ -409,14 +489,18 @@ IOServiceImpl::IOServiceImpl(AuthSrv* auth_server, const char* const port,
     if (use_ipv4) {
         servers.udp4_server = new UDPServer(auth_server, io_service_,
                                             AF_INET, portnum);
+        udp4_server_ = servers.udp4_server;
         servers.tcp4_server = new TCPServer(auth_server, io_service_,
                                             AF_INET, portnum);
+        tcp4_server_ = servers.tcp4_server;
     }
     if (use_ipv6) {
         servers.udp6_server = new UDPServer(auth_server, io_service_,
                                             AF_INET6, portnum);
+        udp6_server_ = servers.udp6_server;
         servers.tcp6_server = new TCPServer(auth_server, io_service_,
                                             AF_INET6, portnum);
+        tcp6_server_ = servers.tcp6_server;
     }
 
     // Now we don't have to worry about exception, and need to make sure that
@@ -457,4 +541,21 @@ asio::io_service&
 IOService::get_io_service() {
     return impl_->io_service_;
 }
+
+void
+IOService::setCallBack(const IOCallBack callback) {
+    impl_->callback_ = callback;
+    if (impl_->udp4_server_ != NULL) {
+        impl_->udp4_server_->setCallBack(&impl_->callback_);
+    }
+    if (impl_->udp6_server_ != NULL) {
+        impl_->udp6_server_->setCallBack(&impl_->callback_);
+    }
+    if (impl_->tcp4_server_ != NULL) {
+        impl_->tcp4_server_->setCallBack(&impl_->callback_);
+    }
+    if (impl_->tcp6_server_ != NULL) {
+        impl_->tcp6_server_->setCallBack(&impl_->callback_);
+    }
+}
 }

+ 30 - 10
src/bin/auth/asio_link.h

@@ -17,13 +17,15 @@
 #ifndef __ASIO_LINK_H
 #define __ASIO_LINK_H 1
 
+#include <functional>
 #include <string>
 
+#include <boost/function.hpp>
+
 #include <exceptions/exceptions.h>
 
 namespace asio {
 class io_service;
-
 namespace ip {
 class address;
 }
@@ -34,7 +36,7 @@ class AuthSrv;
 namespace asio_link {
 struct IOServiceImpl;
 
-/// \brief An exception that is thrown if an error occurs with in the IO
+/// \brief An exception that is thrown if an error occurs within the IO
 /// module.  This is mainly intended to be a wrapper exception class for
 /// ASIO specific exceptions.
 class IOError : public isc::Exception {
@@ -57,20 +59,34 @@ private:
     const asio::ip::address& asio_address_;
 };
 
+class IOSocket {
+private:
+    IOSocket(const IOSocket& source);
+    IOSocket& operator=(const IOSocket& source);
+protected:
+    IOSocket() {}
+public:
+    virtual ~IOSocket() {}
+    virtual int getNative() const = 0;
+    virtual int getProtocol() const = 0;
+};
+
 class IOMessage {
 private:
     IOMessage(const IOMessage& source);
     IOMessage& operator=(const IOMessage& source);
 public:
-    IOMessage();
-    ~IOMessage();
-    const void* getData() const;
-    size_t getDataSize() const;
-    int getNative() const;
-    const IOAddress& getRemoteAddress() const;
+    IOMessage(const void* data, size_t data_size, IOSocket& io_socket,
+              const asio::ip::address& remote_address);
+    const void* getData() const { return (data_); }
+    size_t getDataSize() const { return (data_size_); }
+    const IOSocket& getSocket() const { return (io_socket_); }
+    const IOAddress& getRemoteAddress() const { return (remote_io_address_); }
 private:
-    class IOMessageImpl;
-    IOMessageImpl* impl_;
+    const void* data_;
+    const size_t data_size_;
+    IOSocket& io_socket_;
+    IOAddress remote_io_address_;
 };
 
 class IOService {
@@ -81,6 +97,10 @@ public:
     void run();
     void stop();
     asio::io_service& get_io_service();
+    /// Right now this method is only for testing, but will eventually be
+    /// generalized.
+    typedef boost::function<void(const IOMessage& io_message)> IOCallBack;
+    void setCallBack(IOCallBack callback);
 private:
     IOServiceImpl* impl_;
 };

+ 165 - 0
src/bin/auth/tests/asio_link_unittest.cc

@@ -14,13 +14,37 @@
 
 // $Id$
 
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netdb.h>
+
+#include <stdint.h>
+
+#include <functional>
+#include <string>
+#include <vector>
+
 #include <gtest/gtest.h>
 
+#include <exceptions/exceptions.h>
+
+#include <dns/tests/unittest_util.h>
+
 #include <auth/asio_link.h>
 
+using isc::UnitTestUtil;
+using namespace std;
 using namespace asio_link;
 
 namespace {
+const char* const TEST_PORT = "53535";
+const char* const TEST_IPV6_ADDR = "::1";
+const char* const TEST_IPV4_ADDR = "127.0.0.1";
+// This data is intended to be valid as a DNS/TCP-like message: the first
+// two octets encode the length of the rest of the data.  This is crucial
+// for the tests below.
+const uint8_t test_data[] = {0, 4, 1, 2, 3, 4};
+
 TEST(IOAddressTest, fromText) {
     IOAddress io_address_v4("192.0.2.1");
     EXPECT_EQ("192.0.2.1", io_address_v4.toText());
@@ -34,4 +58,145 @@ TEST(IOAddressTest, fromText) {
     // bogus IPv6 address-like input
     EXPECT_THROW(IOAddress("2001:db8:::1234"), IOError);
 }
+
+struct addrinfo*
+resolveAddress(const int family, const int sock_type, const int protocol) {
+    const char* const addr = (family == AF_INET6) ?
+        TEST_IPV6_ADDR : TEST_IPV4_ADDR;
+
+    struct addrinfo hints;
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = family;
+    hints.ai_socktype = sock_type;
+    hints.ai_protocol = protocol;
+
+    struct addrinfo* res;
+    const int error = getaddrinfo(addr, TEST_PORT, &hints, &res);
+    if (error != 0) {
+        isc_throw(IOError, "getaddrinfo failed: " << gai_strerror(error));
+    }
+
+    return (res);
+}
+
+class ASIOLinkTest : public ::testing::Test {
+protected:
+    ASIOLinkTest();
+    ~ASIOLinkTest() {
+        if (res_ != NULL) {
+            freeaddrinfo(res_);
+        }
+        if (sock_ != -1) {
+            close(sock_);
+        }
+    }
+    void sendUDP(const int family) {
+        res_ = resolveAddress(family, SOCK_DGRAM, IPPROTO_UDP);
+
+        sock_ = socket(res_->ai_family, res_->ai_socktype, res_->ai_protocol);
+        if (sock_ < 0) {
+            isc_throw(IOError, "failed to open test socket");
+        }
+        const int cc = sendto(sock_, test_data, sizeof(test_data), 0,
+                              res_->ai_addr, res_->ai_addrlen);
+        if (cc != sizeof(test_data)) {
+            isc_throw(IOError, "unexpected sendto result: " << cc);
+        }
+        io_service_.run();
+    }
+    void sendTCP(const int family) {
+        res_ = resolveAddress(family, SOCK_STREAM, IPPROTO_TCP);
+
+        sock_ = socket(res_->ai_family, res_->ai_socktype, res_->ai_protocol);
+        if (sock_ < 0) {
+            isc_throw(IOError, "failed to open test socket");
+        }
+        if (connect(sock_, res_->ai_addr, res_->ai_addrlen) < 0) {
+            isc_throw(IOError, "failed to connect to the test server");
+        }
+        const int cc = send(sock_, test_data, sizeof(test_data), 0);
+        if (cc != sizeof(test_data)) {
+            isc_throw(IOError, "unexpected sendto result: " << cc);
+        }
+        io_service_.run();
+    }
+public:
+    void callBack(const IOMessage& io_message) {
+        callback_protocoal_ = io_message.getSocket().getProtocol();
+        callback_native_ = io_message.getSocket().getNative();
+        callback_address_ = io_message.getRemoteAddress().toText();
+        callback_data_.assign(
+            static_cast<const uint8_t*>(io_message.getData()),
+            static_cast<const uint8_t*>(io_message.getData()) +
+            io_message.getDataSize());
+        io_service_.stop();
+    }
+protected:
+    IOService io_service_;
+    int callback_protocoal_;
+    int callback_native_;
+    string callback_address_;
+    vector<uint8_t> callback_data_;
+    int sock_;
+private:
+    struct addrinfo* res_;
+};
+
+class ASIOCallBack : public std::unary_function<IOMessage, void> {
+public:
+    ASIOCallBack(ASIOLinkTest* test_obj) : test_obj_(test_obj) {}
+    void operator()(const IOMessage& io_message) const {
+        test_obj_->callBack(io_message);
+    }
+private:
+    ASIOLinkTest* test_obj_;
+};
+
+ASIOLinkTest::ASIOLinkTest() :
+    io_service_(NULL, TEST_PORT, true, true),
+    sock_(-1), res_(NULL)
+{
+    io_service_.setCallBack(ASIOCallBack(this));
+}
+
+TEST_F(ASIOLinkTest, v6UDPSend) {
+    sendUDP(AF_INET6);
+    // There doesn't seem to be an effective test for the validity of 'native'.
+    // One thing we are sure is it must be different from our local socket.
+    EXPECT_NE(callback_native_, sock_);
+    EXPECT_EQ(IPPROTO_UDP, callback_protocoal_);
+    EXPECT_EQ(TEST_IPV6_ADDR, callback_address_);
+    EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, &callback_data_[0],
+                        callback_data_.size(), test_data, sizeof(test_data));
+}
+
+TEST_F(ASIOLinkTest, v6TCPSend) {
+    sendTCP(AF_INET6);
+    EXPECT_NE(callback_native_, sock_);
+    EXPECT_EQ(IPPROTO_TCP, callback_protocoal_);
+    EXPECT_EQ(TEST_IPV6_ADDR, callback_address_);
+    EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, &callback_data_[0],
+                        callback_data_.size(),
+                        test_data + 2, sizeof(test_data) - 2);
+}
+
+TEST_F(ASIOLinkTest, v4UDPSend) {
+    sendUDP(AF_INET);
+    EXPECT_NE(callback_native_, sock_);
+    EXPECT_EQ(IPPROTO_UDP, callback_protocoal_);
+    EXPECT_EQ(TEST_IPV4_ADDR, callback_address_);
+    EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, &callback_data_[0],
+                        callback_data_.size(), test_data, sizeof(test_data));
+}
+
+TEST_F(ASIOLinkTest, v4TCPSend) {
+    sendTCP(AF_INET);
+    EXPECT_NE(callback_native_, sock_);
+    EXPECT_EQ(IPPROTO_TCP, callback_protocoal_);
+    EXPECT_EQ(TEST_IPV4_ADDR, callback_address_);
+    EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, &callback_data_[0],
+                        callback_data_.size(),
+                        test_data + 2, sizeof(test_data) - 2);
+}
+
 }