Parcourir la source

refactoring: pass IOMessage to AuthSrv instead of InputBuffer.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac221@2176 e5f2f494-b856-4b98-b285-d166d9295462
JINMEI Tatuya il y a 15 ans
Parent
commit
9a7ab5ecf1

+ 1 - 0
src/bin/auth/Makefile.am

@@ -1,6 +1,7 @@
 SUBDIRS = . tests
 
 AM_CPPFLAGS = -I$(top_srcdir)/src/lib -I$(top_builddir)/src/lib
+AM_CPPFLAGS += -I$(top_srcdir)/src/bin -I$(top_builddir)/src/bin
 AM_CPPFLAGS += -I$(top_srcdir)/src/lib/dns -I$(top_builddir)/src/lib/dns
 AM_CPPFLAGS += -I$(top_builddir)/src/lib/cc
 

+ 34 - 7
src/bin/auth/asio_link.cc

@@ -147,12 +147,42 @@ private:
     udp::socket& socket_;
 };
 
+class DummySocket : public IOSocket {
+private:
+    DummySocket(const DummySocket& source);
+    DummySocket& operator=(const DummySocket& source);
+public:
+    DummySocket(const int protocol) : protocol_(protocol) {}
+    virtual int getNative() const { return (-1); }
+    virtual int getProtocol() const { return (protocol_); }
+private:
+    const int protocol_;
+};
+
+IOSocket&
+IOSocket::getDummyUDPSocket() {
+    static DummySocket socket(IPPROTO_UDP);
+    return (socket);
+}
+
+IOSocket&
+IOSocket::getDummyTCPSocket() {
+    static DummySocket socket(IPPROTO_TCP);
+    return (socket);
+}
+
 IOMessage::IOMessage(const void* data, const size_t data_size,
                      IOSocket& io_socket, const ip::address& remote_address) :
     data_(data), data_size_(data_size), io_socket_(io_socket),
     remote_io_address_(remote_address)
 {}
 
+IOMessage::IOMessage(const void* data, const size_t data_size,
+                     IOSocket& io_socket, const string& remote_address) :
+    data_(data), data_size_(data_size), io_socket_(io_socket),
+    remote_io_address_(remote_address)
+{}
+
 //
 // Helper classes for asynchronous I/O using asio
 //
@@ -213,7 +243,6 @@ public:
                 return;
             }
 
-            InputBuffer dnsbuffer(data_, bytes_transferred);
 #ifdef USE_XFROUT
             if (check_axfr_query(data_, bytes_transferred)) {
                 dispatch_axfr_query(socket_.native(), data_, bytes_transferred); 
@@ -221,8 +250,8 @@ public:
                 start();
             } else {
 #endif
-                if (auth_server_->processMessage(dnsbuffer, dns_message_,
-                                                response_renderer_, false)) {
+                if (auth_server_->processMessage(io_message, dns_message_,
+                                                response_renderer_)) {
                     responselen_buffer_.writeUint16(
                         response_buffer_.getLength());
                     async_write(socket_,
@@ -383,12 +412,10 @@ public:
                 return;
             }
 
-            InputBuffer request_buffer(data_, bytes_recvd);
-
             dns_message_.clear(Message::PARSE);
             response_renderer_.clear();
-            if (auth_server_->processMessage(request_buffer, dns_message_,
-                                            response_renderer_, true)) {
+            if (auth_server_->processMessage(io_message, dns_message_,
+                                             response_renderer_)) {
                 socket_.async_send_to(
                     asio::buffer(response_buffer_.getData(),
                                         response_buffer_.getLength()),

+ 4 - 0
src/bin/auth/asio_link.h

@@ -69,6 +69,8 @@ public:
     virtual ~IOSocket() {}
     virtual int getNative() const = 0;
     virtual int getProtocol() const = 0;
+    static IOSocket& getDummyUDPSocket();
+    static IOSocket& getDummyTCPSocket();
 };
 
 class IOMessage {
@@ -78,6 +80,8 @@ private:
 public:
     IOMessage(const void* data, size_t data_size, IOSocket& io_socket,
               const asio::ip::address& remote_address);
+    IOMessage(const void* data, size_t data_size, IOSocket& io_socket,
+              const std::string& remote_address);
     const void* getData() const { return (data_); }
     size_t getDataSize() const { return (data_size_); }
     const IOSocket& getSocket() const { return (io_socket_); }

+ 12 - 5
src/bin/auth/auth_srv.cc

@@ -14,6 +14,8 @@
 
 // $Id$
 
+#include <netinet/in.h>
+
 #include <algorithm>
 #include <cassert>
 #include <iostream>
@@ -40,8 +42,9 @@
 
 #include <cc/data.h>
 
-#include "common.h"
-#include "auth_srv.h"
+#include <auth/common.h>
+#include <auth/auth_srv.h>
+#include <auth/asio_link.h>
 
 #include <boost/lexical_cast.hpp>
 
@@ -53,6 +56,7 @@ using namespace isc::dns;
 using namespace isc::dns::rdata;
 using namespace isc::data;
 using namespace isc::config;
+using namespace asio_link;
 
 class AuthSrvImpl {
 private:
@@ -167,10 +171,11 @@ AuthSrv::configSession() const {
 }
 
 bool
-AuthSrv::processMessage(InputBuffer& request_buffer, Message& message,
-                        MessageRenderer& response_renderer,
-                        const bool udp_buffer)
+AuthSrv::processMessage(const IOMessage& io_message, Message& message,
+                        MessageRenderer& response_renderer)
 {
+    InputBuffer request_buffer(io_message.getData(), io_message.getDataSize());
+
     // First, check the header part.  If we fail even for the base header,
     // just drop the message.
     try {
@@ -250,6 +255,8 @@ AuthSrv::processMessage(InputBuffer& request_buffer, Message& message,
         return (true);
     }
 
+    const bool udp_buffer =
+        (io_message.getSocket().getProtocol() == IPPROTO_UDP);
     response_renderer.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
     message.toWire(response_renderer);
     if (impl_->verbose_mode_) {

+ 6 - 3
src/bin/auth/auth_srv.h

@@ -30,6 +30,10 @@ class MessageRenderer;
 }
 }
 
+namespace asio_link {
+class IOMessage;
+}
+
 class AuthSrvImpl;
 
 class AuthSrv {
@@ -48,10 +52,9 @@ public:
     //@}
     /// \return \c true if the \message contains a response to be returned;
     /// otherwise \c false.
-    bool processMessage(isc::dns::InputBuffer& request_buffer,
+    bool processMessage(const asio_link::IOMessage& io_message,
                         isc::dns::Message& message,
-                        isc::dns::MessageRenderer& response_renderer,
-                        bool udp_buffer);
+                        isc::dns::MessageRenderer& response_renderer);
     void setVerbose(bool on);
     bool getVerbose() const;
     void serve(std::string zone_name);

+ 7 - 0
src/bin/auth/tests/asio_link_unittest.cc

@@ -59,6 +59,13 @@ TEST(IOAddressTest, fromText) {
     EXPECT_THROW(IOAddress("2001:db8:::1234"), IOError);
 }
 
+TEST(IOSocketTest, dummySockets) {
+    EXPECT_EQ(IPPROTO_UDP, IOSocket::getDummyUDPSocket().getProtocol());
+    EXPECT_EQ(IPPROTO_TCP, IOSocket::getDummyTCPSocket().getProtocol());
+    EXPECT_EQ(-1, IOSocket::getDummyUDPSocket().getNative());
+    EXPECT_EQ(-1, IOSocket::getDummyTCPSocket().getNative());
+}
+
 struct addrinfo*
 resolveAddress(const int family, const int sock_type, const int protocol) {
     const char* const addr = (family == AF_INET6) ?

+ 36 - 29
src/bin/auth/tests/auth_srv_unittest.cc

@@ -26,6 +26,7 @@
 #include <cc/data.h>
 
 #include <auth/auth_srv.h>
+#include <auth/asio_link.h>
 
 #include <dns/tests/unittest_util.h>
 
@@ -33,6 +34,7 @@ using isc::UnitTestUtil;
 using namespace std;
 using namespace isc::dns;
 using namespace isc::data;
+using namespace asio_link;
 
 namespace {
 const char* CONFIG_TESTDB =
@@ -47,10 +49,14 @@ protected:
     AuthSrvTest() : request_message(Message::RENDER),
                     parse_message(Message::PARSE), default_qid(0x1035),
                     opcode(Opcode(Opcode::QUERY())), qname("www.example.com"),
-                    qclass(RRClass::IN()), qtype(RRType::A()), ibuffer(NULL),
-                    request_obuffer(0), request_renderer(request_obuffer),
+                    qclass(RRClass::IN()), qtype(RRType::A()),
+                    io_message(NULL), request_obuffer(0),
+                    request_renderer(request_obuffer),
                     response_obuffer(0), response_renderer(response_obuffer)
     {}
+    ~AuthSrvTest() {
+        delete io_message;
+    }
     AuthSrv server;
     Message request_message;
     Message parse_message;
@@ -59,7 +65,7 @@ protected:
     const Name qname;
     const RRClass qclass;
     const RRType qtype;
-    InputBuffer* ibuffer;
+    IOMessage *io_message;
     OutputBuffer request_obuffer;
     MessageRenderer request_renderer;
     OutputBuffer response_obuffer;
@@ -82,11 +88,12 @@ const unsigned int CD_FLAG = 0x40;
 
 void
 AuthSrvTest::createDataFromFile(const char* const datafile) {
-    delete ibuffer;
+    delete io_message;
     data.clear();
 
     UnitTestUtil::readWireData(datafile, data);
-    ibuffer = new InputBuffer(&data[0], data.size());
+    io_message = new IOMessage(&data[0], data.size(),
+                               IOSocket::getDummyUDPSocket(), "192.0.2.1");
 }
 
 void
@@ -122,8 +129,8 @@ TEST_F(AuthSrvTest, unsupportedRequest) {
         data[2] = ((i << 3) & 0xff);
 
         parse_message.clear(Message::PARSE);
-        EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                              response_renderer, true));
+        EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                              response_renderer));
         headerCheck(parse_message, default_qid, Rcode::NOTIMP(), i, QR_FLAG,
                     0, 0, 0, 0);
     }
@@ -141,8 +148,8 @@ TEST_F(AuthSrvTest, verbose) {
 // Multiple questions.  Should result in FORMERR.
 TEST_F(AuthSrvTest, multiQuestion) {
     createDataFromFile("multiquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
                 QR_FLAG, 2, 0, 0, 0);
 
@@ -162,8 +169,8 @@ TEST_F(AuthSrvTest, multiQuestion) {
 // dropped.
 TEST_F(AuthSrvTest, shortMessage) {
     createDataFromFile("shortmessage_fromWire");
-    EXPECT_EQ(false, server.processMessage(*ibuffer, parse_message,
-                                           response_renderer, true));
+    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
+                                           response_renderer));
 }
 
 // Response messages.  Must be silently dropped, whether it's a valid response
@@ -171,26 +178,26 @@ TEST_F(AuthSrvTest, shortMessage) {
 TEST_F(AuthSrvTest, response) {
     // A valid (although unusual) response
     createDataFromFile("simpleresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*ibuffer, parse_message,
-                                           response_renderer, true));
+    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
+                                           response_renderer));
 
     // A response with a broken question section.  must be dropped rather than
     // returning FORMERR.
     createDataFromFile("shortresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*ibuffer, parse_message,
-                                           response_renderer, true));
+    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
+                                           response_renderer));
 
     // A response to iquery.  must be dropped rather than returning NOTIMP.
     createDataFromFile("iqueryresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*ibuffer, parse_message,
-                                           response_renderer, true));
+    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
+                                           response_renderer));
 }
 
 // Query with a broken question
 TEST_F(AuthSrvTest, shortQuestion) {
     createDataFromFile("shortquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
     // Since the query's question is broken, the question section of the
     // response should be empty.
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
@@ -200,8 +207,8 @@ TEST_F(AuthSrvTest, shortQuestion) {
 // Query with a broken answer section
 TEST_F(AuthSrvTest, shortAnswer) {
     createDataFromFile("shortanswer_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
 
     // This is a bogus query, but question section is valid.  So the response
     // should copy the question section.
@@ -219,8 +226,8 @@ TEST_F(AuthSrvTest, shortAnswer) {
 // Query with unsupported version of EDNS.
 TEST_F(AuthSrvTest, ednsBadVers) {
     createDataFromFile("queryBadEDNS_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
 
     // The response must have an EDNS OPT RR in the additional section.
     // Note that the DNSSEC DO bit is cleared even if this bit in the query
@@ -253,8 +260,8 @@ TEST_F(AuthSrvTest, updateConfig) {
     // response should have the AA flag on, and have an RR in each answer
     // and authority section.
     createDataFromFile("examplequery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
     headerCheck(parse_message, default_qid, Rcode::NOERROR(), opcode.getCode(),
                 QR_FLAG | AA_FLAG, 1, 1, 1, 0);
 }
@@ -267,8 +274,8 @@ TEST_F(AuthSrvTest, datasourceFail) {
     // in a SERVFAIL response, and the answer and authority sections should
     // be empty.
     createDataFromFile("badExampleQuery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(), opcode.getCode(),
                 QR_FLAG, 1, 0, 0, 0);
 }
@@ -282,8 +289,8 @@ TEST_F(AuthSrvTest, updateConfigFail) {
 
     // The original data source should still exist.
     createDataFromFile("examplequery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*ibuffer, parse_message,
-                                          response_renderer, true));
+    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
+                                          response_renderer));
     headerCheck(parse_message, default_qid, Rcode::NOERROR(), opcode.getCode(),
                 QR_FLAG | AA_FLAG, 1, 1, 1, 0);
 }