Browse Source

Substantial further (but still very incomplete) work on ASIO structure.
DNS lookup calls can now be asynchronous, calling back into the UDPServer
or TCPServer coroutine that originated them via io_service::post().

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac327@3084 e5f2f494-b856-4b98-b285-d166d9295462

Evan Hunt 14 years ago
parent
commit
bbb9277960

+ 84 - 37
src/bin/auth/auth_srv.cc

@@ -14,6 +14,8 @@
 
 // $Id$
 
+#include <config.h>
+
 #include <netinet/in.h>
 
 #include <algorithm>
@@ -123,28 +125,58 @@ AuthSrvImpl::~AuthSrvImpl() {
     }
 }
 
-// This is a derived class of \c DNSProvider, to serve as a
+// This is a derived class of \c DNSLookup, to serve as a
 // callback in the asiolink module.  It calls
 // AuthSrv::processMessage() on a single DNS message.
-class MessageProcessor : public DNSProvider {
+class MessageLookup : public DNSLookup {
 public:
-    MessageProcessor(AuthSrv* srv) : server_(srv) {}
-    virtual bool operator()(const IOMessage& io_message,
+    MessageLookup(AuthSrv* srv) : server_(srv) {}
+    virtual void operator()(const IOMessage& io_message,
                             isc::dns::Message& dns_message,
-                            isc::dns::MessageRenderer& renderer) const {
-        return (server_->processMessage(io_message, dns_message, renderer));
+                            isc::dns::MessageRenderer& renderer,
+                            BasicServer* server, bool& complete) const
+    {
+        server_->processMessage(io_message, dns_message, renderer,
+                                server, complete);
+    }
+private:
+    AuthSrv* server_;
+};
+
+// This is a derived class of \c DNSAnswer, to serve as a
+// callback in the asiolink module.  It takes a completed
+// set of answer data from the DNS lookup and assembles it
+// into a wire-format response.
+class MessageAnswer : public DNSAnswer {
+public:
+    MessageAnswer(AuthSrv* srv) : server_(srv) {}
+    virtual void operator()(const IOMessage& io_message,
+                            isc::dns::Message& message,
+                            isc::dns::MessageRenderer& renderer) const
+    {
+        if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
+            renderer.setLengthLimit(message.getUDPSize());
+        } else {
+            renderer.setLengthLimit(65535);
+        }
+        message.toWire(renderer);
+        if (server_->getVerbose()) {
+            cerr << "[b10-recurse] sending a response (" << renderer.getLength()
+                 << " bytes):\n" << message.toText() << endl;
+        }
     }
+
 private:
     AuthSrv* server_;
 };
 
-// This is a derived class of \c CheckinProvider, to serve
+// This is a derived class of \c IOCallback, to serve
 // as a callback in the asiolink module.  It checks for queued
 // configuration messages, and executes them if found.
-class ConfigChecker : public CheckinProvider {
+class ConfigChecker : public IOCallback {
 public:
     ConfigChecker(AuthSrv* srv) : server_(srv) {}
-    virtual void operator()(void) const {
+    virtual void operator()(const IOMessage& io_message UNUSED_PARAM) const {
         if (server_->configSession()->hasQueuedMsgs()) {
             server_->configSession()->checkCommand();
         }
@@ -156,13 +188,15 @@ private:
 AuthSrv::AuthSrv(const bool use_cache, AbstractXfroutClient& xfrout_client) :
     impl_(new AuthSrvImpl(use_cache, xfrout_client)),
     checkin_provider_(new ConfigChecker(this)),
-    dns_provider_(new MessageProcessor(this))
+    dns_lookup_(new MessageLookup(this)),
+    dns_answer_(new MessageAnswer(this))
 {}
 
 AuthSrv::~AuthSrv() {
     delete impl_;
     delete checkin_provider_;
-    delete dns_provider_;
+    delete dns_lookup_;
+    delete dns_answer_;
 }
 
 namespace {
@@ -241,9 +275,10 @@ AuthSrv::configSession() const {
     return (impl_->config_session_);
 }
 
-bool
+void
 AuthSrv::processMessage(const IOMessage& io_message, Message& message,
-                        MessageRenderer& response_renderer)
+                        MessageRenderer& response_renderer,
+                        BasicServer* server, bool& complete)
 {
     InputBuffer request_buffer(io_message.getData(), io_message.getDataSize());
 
@@ -258,14 +293,21 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
                 cerr << "[b10-auth] received unexpected response, ignoring"
                      << endl;
             }
-            return (false);
+            complete = false;
+            server->resume();
+            return;
         }
     } catch (const Exception& ex) {
-        return (false);
+        if (impl_->verbose_mode_) {
+            cerr << "[b10-auth] DNS packet exception: " << ex.what() << endl;
+        }
+        complete = false;
+        server->resume();
+        return;
     }
 
-    // Parse the message.  On failure, return an appropriate error.
     try {
+        // Parse the message.
         message.fromWire(request_buffer);
     } catch (const DNSProtocolError& error) {
         if (impl_->verbose_mode_) {
@@ -274,14 +316,18 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
         }
         makeErrorMessage(message, response_renderer, error.getRcode(),
                          impl_->verbose_mode_);
-        return (true);
+        complete = true;
+        server->resume();
+        return;
     } catch (const Exception& ex) {
         if (impl_->verbose_mode_) {
             cerr << "[b10-auth] returning SERVFAIL: " << ex.what() << endl;
         }
         makeErrorMessage(message, response_renderer, Rcode::SERVFAIL(),
                          impl_->verbose_mode_);
-        return (true);
+        complete = true;
+        server->resume();
+        return;
     } // other exceptions will be handled at a higher layer.
 
     if (impl_->verbose_mode_) {
@@ -291,35 +337,36 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
     // Perform further protocol-level validation.
 
     if (message.getOpcode() == Opcode::NOTIFY()) {
-        return (impl_->processNotify(io_message, message, response_renderer));
+        complete = impl_->processNotify(io_message, message,
+                                         response_renderer);
     } else if (message.getOpcode() != Opcode::QUERY()) {
         if (impl_->verbose_mode_) {
             cerr << "[b10-auth] unsupported opcode" << endl;
         }
         makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
                          impl_->verbose_mode_);
-        return (true);
-    }
-
-    if (message.getRRCount(Section::QUESTION()) != 1) {
+        complete = true;
+    } else if (message.getRRCount(Section::QUESTION()) != 1) {
         makeErrorMessage(message, response_renderer, Rcode::FORMERR(),
                          impl_->verbose_mode_);
-        return (true);
-    }
-
-    ConstQuestionPtr question = *message.beginQuestion();
-    const RRType &qtype = question->getType();
-    if (qtype == RRType::AXFR()) {
-        return (impl_->processAxfrQuery(io_message, message,
-                                        response_renderer));
-    } else if (qtype == RRType::IXFR()) {
-        makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
-                         impl_->verbose_mode_);
-        return (true);
+        complete = true;
     } else {
-        return (impl_->processNormalQuery(io_message, message,
-                                          response_renderer));
+        ConstQuestionPtr question = *message.beginQuestion();
+        const RRType &qtype = question->getType();
+        if (qtype == RRType::AXFR()) {
+            complete = impl_->processAxfrQuery(io_message, message,
+                                                response_renderer);
+        } else if (qtype == RRType::IXFR()) {
+            makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
+                         impl_->verbose_mode_);
+            complete = true;
+        } else {
+            complete = impl_->processNormalQuery(io_message, message,
+                                               response_renderer);
+        }
     }
+
+    server->resume();
 }
 
 bool

+ 16 - 7
src/bin/auth/auth_srv.h

@@ -66,19 +66,26 @@ public:
     //@}
     /// \return \c true if the \message contains a response to be returned;
     /// otherwise \c false.
-    bool processMessage(const asiolink::IOMessage& io_message,
+    void processMessage(const asiolink::IOMessage& io_message,
                         isc::dns::Message& message,
-                        isc::dns::MessageRenderer& response_renderer);
+                        isc::dns::MessageRenderer& response_renderer,
+                        asiolink::BasicServer* server, bool& complete);
     void setVerbose(bool on);
     bool getVerbose() const;
     isc::data::ConstElementPtr updateConfig(isc::data::ConstElementPtr config);
     isc::config::ModuleCCSession* configSession() const;
     void setConfigSession(isc::config::ModuleCCSession* config_session);
 
-    asiolink::DNSProvider* getDNSProvider() {
-        return (dns_provider_);
+    void setIOService(asiolink::IOService& ios) { io_service_ = &ios; }
+    asiolink::IOService& getIOService() const { return (*io_service_); }
+
+    asiolink::DNSLookup* getDNSLookupProvider() const {
+        return (dns_lookup_);
+    }
+    asiolink::DNSAnswer* getDNSAnswerProvider() const {
+        return (dns_answer_);
     }
-    asiolink::CheckinProvider* getCheckinProvider() {
+    asiolink::IOCallback* getCheckinProvider() const {
         return (checkin_provider_);
     }
 
@@ -98,8 +105,10 @@ public:
     void setXfrinSession(isc::cc::AbstractSession* xfrin_session);
 private:
     AuthSrvImpl* impl_;
-    asiolink::CheckinProvider* checkin_provider_;
-    asiolink::DNSProvider* dns_provider_;
+    asiolink::IOService* io_service_;
+    asiolink::IOCallback* checkin_provider_;
+    asiolink::DNSLookup* dns_lookup_;
+    asiolink::DNSAnswer* dns_answer_;
 };
 
 #endif // __AUTH_SRV_H

+ 6 - 4
src/bin/auth/main.cc

@@ -177,8 +177,9 @@ main(int argc, char* argv[]) {
         auth_server->setVerbose(verbose_mode);
         cout << "[b10-auth] Server created." << endl;
 
-        CheckinProvider* checkin = auth_server->getCheckinProvider();
-        DNSProvider* process = auth_server->getDNSProvider();
+        IOCallback* checkin = auth_server->getCheckinProvider();
+        DNSLookup* lookup = auth_server->getDNSLookupProvider();
+        DNSAnswer* answer = auth_server->getDNSAnswerProvider();
 
         if (address != NULL) {
             // XXX: we can only specify at most one explicit address.
@@ -188,11 +189,12 @@ main(int argc, char* argv[]) {
             // is a short term workaround until we support dynamic listening
             // port allocation.
             io_service = new IOService(*port, *address,
-                                                 checkin, process);
+                                       checkin, lookup, answer);
         } else {
             io_service = new IOService(*port, use_ipv4, use_ipv6,
-                                                 checkin, process);
+                                       checkin, lookup, answer);
         }
+        auth_server->setIOService(*io_service);
         cout << "[b10-auth] IOService created." << endl;
 
         cc_session = new Session(io_service->get_io_service());

+ 122 - 59
src/bin/auth/tests/auth_srv_unittest.cc

@@ -122,6 +122,13 @@ private:
         bool receive_ok_;
     };
 
+    // A nonoperative task object to be used in calls to processMessage()
+    class MockTask : public BasicServer {
+        void operator()(asio::error_code ec UNUSED_PARAM,
+                        size_t length UNUSED_PARAM)
+        {}
+    };
+
 protected:
     AuthSrvTest() : server(true, xfrout),
                     request_message(Message::RENDER),
@@ -140,6 +147,7 @@ protected:
     }
     MockSession notify_session;
     MockXfroutClient xfrout;
+    MockTask noOp;
     AuthSrv server;
     Message request_message;
     Message parse_message;
@@ -273,7 +281,6 @@ AuthSrvTest::createDataFromFile(const char* const datafile,
                                 const int protocol = IPPROTO_UDP)
 {
     delete io_message;
-    delete io_sock;
     data.clear();
 
     delete endpoint;
@@ -359,8 +366,10 @@ TEST_F(AuthSrvTest, unsupportedRequest) {
         data[2] = ((i << 3) & 0xff);
 
         parse_message.clear(Message::PARSE);
-        EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                              response_renderer));
+        bool done;
+        server.processMessage(*io_message, parse_message, response_renderer,
+                              &noOp, done);
+    EXPECT_TRUE(done);
         headerCheck(parse_message, default_qid, Rcode::NOTIMP(), i, QR_FLAG,
                     0, 0, 0, 0);
     }
@@ -378,8 +387,10 @@ TEST_F(AuthSrvTest, verbose) {
 // Multiple questions.  Should result in FORMERR.
 TEST_F(AuthSrvTest, multiQuestion) {
     createDataFromFile("multiquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
                 QR_FLAG, 2, 0, 0, 0);
 
@@ -399,8 +410,10 @@ TEST_F(AuthSrvTest, multiQuestion) {
 // dropped.
 TEST_F(AuthSrvTest, shortMessage) {
     createDataFromFile("shortmessage_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 // Response messages.  Must be silently dropped, whether it's a valid response
@@ -408,26 +421,32 @@ TEST_F(AuthSrvTest, shortMessage) {
 TEST_F(AuthSrvTest, response) {
     // A valid (although unusual) response
     createDataFromFile("simpleresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 
     // A response with a broken question section.  must be dropped rather than
     // returning FORMERR.
     createDataFromFile("shortresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 
     // A response to iquery.  must be dropped rather than returning NOTIMP.
     createDataFromFile("iqueryresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 // Query with a broken question
 TEST_F(AuthSrvTest, shortQuestion) {
     createDataFromFile("shortquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     // Since the query's question is broken, the question section of the
     // response should be empty.
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
@@ -437,8 +456,10 @@ TEST_F(AuthSrvTest, shortQuestion) {
 // Query with a broken answer section
 TEST_F(AuthSrvTest, shortAnswer) {
     createDataFromFile("shortanswer_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
 
     // This is a bogus query, but question section is valid.  So the response
     // should copy the question section.
@@ -456,8 +477,10 @@ TEST_F(AuthSrvTest, shortAnswer) {
 // Query with unsupported version of EDNS.
 TEST_F(AuthSrvTest, ednsBadVers) {
     createDataFromFile("queryBadEDNS_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
 
     // The response must have an EDNS OPT RR in the additional section.
     // Note that the DNSSEC DO bit is cleared even if this bit in the query
@@ -472,8 +495,10 @@ TEST_F(AuthSrvTest, AXFROverUDP) {
     // AXFR over UDP is invalid and should result in FORMERR.
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
                 QR_FLAG, 1, 0, 0, 0);
 }
@@ -484,8 +509,10 @@ TEST_F(AuthSrvTest, AXFRSuccess) {
                         RRType::AXFR(), IPPROTO_TCP);
     // On success, the AXFR query has been passed to a separate process,
     // so we shouldn't have to respond.
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
     EXPECT_FALSE(xfrout.isConnected());
 }
 
@@ -494,8 +521,10 @@ TEST_F(AuthSrvTest, AXFRConnectFail) {
     xfrout.disableConnect();
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
-    EXPECT_TRUE(server.processMessage(*io_message, parse_message,
-                                      response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
     // For a shot term workaround with xfrout we currently close the connection
@@ -508,7 +537,9 @@ TEST_F(AuthSrvTest, AXFRSendFail) {
     // open.
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
-    server.processMessage(*io_message, parse_message, response_renderer);
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
     EXPECT_FALSE(xfrout.isConnected()); // see above
 
     xfrout.disableSend();
@@ -516,8 +547,9 @@ TEST_F(AuthSrvTest, AXFRSendFail) {
     response_renderer.clear();
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
-    EXPECT_TRUE(server.processMessage(*io_message, parse_message,
-                                      response_renderer));
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
 
@@ -532,8 +564,9 @@ TEST_F(AuthSrvTest, AXFRDisconnectFail) {
     xfrout.disableDisconnect();
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
+    bool done;
     EXPECT_THROW(server.processMessage(*io_message, parse_message,
-                                       response_renderer),
+                                       response_renderer, &noOp, done),
                  XfroutError);
     EXPECT_TRUE(xfrout.isConnected());
     // XXX: we need to re-enable disconnect.  otherwise an exception would be
@@ -546,8 +579,10 @@ TEST_F(AuthSrvTest, notify) {
                         RRType::SOA());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
 
     // An internal command message should have been created and sent to an
     // external module.  Check them.
@@ -578,8 +613,10 @@ TEST_F(AuthSrvTest, notifyForCHClass) {
                         RRType::SOA());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
 
     // Other conditions should be the same, so simply confirm the RR class is
     // set correctly.
@@ -595,8 +632,10 @@ TEST_F(AuthSrvTest, notifyEmptyQuestion) {
     request_message.setQid(default_qid);
     request_message.toWire(request_renderer);
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(),
                 Opcode::NOTIFY().getCode(), QR_FLAG, 0, 0, 0, 0);
 }
@@ -609,8 +648,10 @@ TEST_F(AuthSrvTest, notifyMultiQuestions) {
                                          RRType::SOA()));
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(),
                 Opcode::NOTIFY().getCode(), QR_FLAG, 2, 0, 0, 0);
 }
@@ -620,8 +661,10 @@ TEST_F(AuthSrvTest, notifyNonSOAQuestion) {
                         RRType::NS());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(),
                 Opcode::NOTIFY().getCode(), QR_FLAG, 1, 0, 0, 0);
 }
@@ -630,8 +673,10 @@ TEST_F(AuthSrvTest, notifyWithoutAA) {
     // implicitly leave the AA bit off.  our implementation will accept it.
     createRequestPacket(Opcode::NOTIFY(), Name("example.com"), RRClass::IN(),
                         RRType::SOA());
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOERROR(),
                 Opcode::NOTIFY().getCode(), QR_FLAG | AA_FLAG, 1, 0, 0, 0);
 }
@@ -642,8 +687,10 @@ TEST_F(AuthSrvTest, notifyWithErrorRcode) {
     request_message.setHeaderFlag(MessageFlag::AA());
     request_message.setRcode(Rcode::SERVFAIL());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOERROR(),
                 Opcode::NOTIFY().getCode(), QR_FLAG | AA_FLAG, 1, 0, 0, 0);
 }
@@ -658,8 +705,10 @@ TEST_F(AuthSrvTest, notifyWithoutSession) {
 
     // we simply ignore the notify and let it be resent if an internal error
     // happens.
-    EXPECT_FALSE(server.processMessage(*io_message, parse_message,
-                                       response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 TEST_F(AuthSrvTest, notifySendFail) {
@@ -670,8 +719,10 @@ TEST_F(AuthSrvTest, notifySendFail) {
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
 
-    EXPECT_FALSE(server.processMessage(*io_message, parse_message,
-                                       response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 TEST_F(AuthSrvTest, notifyReceiveFail) {
@@ -681,8 +732,10 @@ TEST_F(AuthSrvTest, notifyReceiveFail) {
                         RRType::SOA());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_FALSE(server.processMessage(*io_message, parse_message,
-                                       response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 TEST_F(AuthSrvTest, notifyWithBogusSessionMessage) {
@@ -692,8 +745,10 @@ TEST_F(AuthSrvTest, notifyWithBogusSessionMessage) {
                         RRType::SOA());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_FALSE(server.processMessage(*io_message, parse_message,
-                                       response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 TEST_F(AuthSrvTest, notifyWithSessionMessageError) {
@@ -704,8 +759,10 @@ TEST_F(AuthSrvTest, notifyWithSessionMessageError) {
                         RRType::SOA());
     request_message.setHeaderFlag(MessageFlag::AA());
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_FALSE(server.processMessage(*io_message, parse_message,
-                                       response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 void
@@ -730,8 +787,10 @@ TEST_F(AuthSrvTest, updateConfig) {
     // response should have the AA flag on, and have an RR in each answer
     // and authority section.
     createDataFromFile("examplequery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOERROR(), opcode.getCode(),
                 QR_FLAG | AA_FLAG, 1, 1, 1, 0);
 }
@@ -744,8 +803,10 @@ TEST_F(AuthSrvTest, datasourceFail) {
     // in a SERVFAIL response, and the answer and authority sections should
     // be empty.
     createDataFromFile("badExampleQuery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
                 opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
 }
@@ -759,8 +820,10 @@ TEST_F(AuthSrvTest, updateConfigFail) {
 
     // The original data source should still exist.
     createDataFromFile("examplequery_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer,
+                          &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOERROR(), opcode.getCode(),
                 QR_FLAG | AA_FLAG, 1, 1, 1, 0);
 }

+ 10 - 9
src/bin/recurse/main.cc

@@ -61,8 +61,8 @@ bool verbose_mode = false;
 const string PROGRAM = "Recurse";
 const char* DNSPORT = "5300";
 
-Recursor *recursor;
 IOService* io_service;
+Recursor *recursor;
 
 ConstElementPtr
 my_config_handler(ConstElementPtr new_config) {
@@ -161,12 +161,9 @@ main(int argc, char* argv[]) {
             specfile = string(RECURSE_SPECFILE_LOCATION);
         }
 
-        recursor = new Recursor();
-        recursor ->setVerbose(verbose_mode);
-        cout << "[b10-recurse] Server created." << endl;
-
-        CheckinProvider* checkin = recursor->getCheckinProvider();
-        DNSProvider* process = recursor->getDNSProvider();
+        IOCallback* checkin = recursor->getCheckinProvider();
+        DNSLookup* lookup = recursor->getDNSLookupProvider();
+        DNSAnswer* answer = recursor->getDNSAnswerProvider();
 
         if (address != NULL) {
             // XXX: we can only specify at most one explicit address.
@@ -176,13 +173,17 @@ main(int argc, char* argv[]) {
             // is a short term workaround until we support dynamic listening
             // port allocation.
             io_service = new IOService(*port, *address,
-                                                 checkin, process);
+                                       checkin, lookup, answer);
         } else {
             io_service = new IOService(*port, use_ipv4, use_ipv6,
-                                                 checkin, process);
+                                       checkin, lookup, answer);
         }
         cout << "[b10-recurse] IOService created." << endl;
 
+        recursor = new Recursor(*io_service);
+        recursor ->setVerbose(verbose_mode);
+        cout << "[b10-recurse] Server created." << endl;
+
         cc_session = new Session(io_service->get_io_service());
         cout << "[b10-recurse] Configuration session channel created." << endl;
 

+ 110 - 80
src/bin/recurse/recursor.cc

@@ -61,41 +61,78 @@ private:
     RecursorImpl(const RecursorImpl& source);
     RecursorImpl& operator=(const RecursorImpl& source);
 public:
-    RecursorImpl();
-    bool processNormalQuery(const IOMessage& io_message, Message& message,
-                            MessageRenderer& response_renderer);
+    RecursorImpl(asiolink::IOService& io_service);
+    bool processNormalQuery(const IOMessage& io_message,
+                            const Question& question, Message& message,
+                            MessageRenderer& renderer,
+                            BasicServer* server);
     ModuleCCSession* config_session_;
 
     bool verbose_mode_;
 
+    /// Object to handle upstream queries
+    IOQuery ioquery_;
+
     /// Currently non-configurable, but will be.
     static const uint16_t DEFAULT_LOCAL_UDPSIZE = 4096;
 };
 
-RecursorImpl::RecursorImpl() : config_session_(NULL), verbose_mode_(false) {}
+RecursorImpl::RecursorImpl(asiolink::IOService& io_service) :
+    config_session_(NULL), verbose_mode_(false), ioquery_(io_service)
+{}
 
-// This is a derived class of \c DNSProvider, to serve as a
+// This is a derived class of \c DNSLookup, to serve as a
 // callback in the asiolink module.  It calls
 // Recursor::processMessage() on a single DNS message.
-class MessageProcessor : public DNSProvider {
+class MessageLookup : public DNSLookup {
 public:
-    MessageProcessor(Recursor* srv) : server_(srv) {}
-    virtual bool operator()(const IOMessage& io_message,
+    MessageLookup(Recursor* srv) : server_(srv) {}
+    virtual void operator()(const IOMessage& io_message,
                             isc::dns::Message& dns_message,
-                            isc::dns::MessageRenderer& renderer) const {
-        return (server_->processMessage(io_message, dns_message, renderer));
+                            isc::dns::MessageRenderer& renderer,
+                            BasicServer* server, bool& complete) const
+    {
+        server_->processMessage(io_message, dns_message, renderer,
+                                server, complete);
+    }
+private:
+    Recursor* server_;
+};
+
+// This is a derived class of \c DNSAnswer, to serve as a
+// callback in the asiolink module.  It takes a completed
+// set of answer data from the DNS lookup and assembles it
+// into a wire-format response.
+class MessageAnswer : public DNSAnswer {
+public:
+    MessageAnswer(Recursor* srv) : server_(srv) {}
+    virtual void operator()(const IOMessage& io_message,
+                            isc::dns::Message& message,
+                            isc::dns::MessageRenderer& renderer) const
+    {
+        if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
+            renderer.setLengthLimit(message.getUDPSize());
+        } else {
+            renderer.setLengthLimit(65535);
+        }
+        message.toWire(renderer);
+        if (server_->getVerbose()) {
+            cerr << "[b10-recurse] sending a response (" << renderer.getLength()
+                 << " bytes):\n" << message.toText() << endl;
+        }
     }
+
 private:
     Recursor* server_;
 };
 
-// This is a derived class of \c CheckinProvider, to serve
+// This is a derived class of \c IOCallback, to serve
 // as a callback in the asiolink module.  It checks for queued
 // configuration messages, and executes them if found.
-class ConfigChecker : public CheckinProvider {
+class ConfigChecker : public IOCallback {
 public:
     ConfigChecker(Recursor* srv) : server_(srv) {}
-    virtual void operator()(void) const {
+    virtual void operator()(const IOMessage& io_message UNUSED_PARAM) const {
         if (server_->configSession()->hasQueuedMsgs()) {
             server_->configSession()->checkCommand();
         }
@@ -104,16 +141,18 @@ private:
     Recursor* server_;
 };
 
-Recursor::Recursor() :
-    impl_(new RecursorImpl()),
-    checkin_provider_(new ConfigChecker(this)),
-    dns_provider_(new MessageProcessor(this))
+Recursor::Recursor(asiolink::IOService& io_service) :
+    impl_(new RecursorImpl(io_service)),
+    checkin_(new ConfigChecker(this)),
+    dns_lookup_(new MessageLookup(this)),
+    dns_answer_(new MessageAnswer(this))
 {}
 
 Recursor::~Recursor() {
     delete impl_;
-    delete checkin_provider_;
-    delete dns_provider_;
+    delete checkin_;
+    delete dns_lookup_;
+    delete dns_answer_;
 }
 
 namespace {
@@ -187,9 +226,10 @@ Recursor::configSession() const {
     return (impl_->config_session_);
 }
 
-bool
+void
 Recursor::processMessage(const IOMessage& io_message, Message& message,
-                        MessageRenderer& response_renderer)
+                        MessageRenderer& renderer,
+                        BasicServer* server, bool& complete)
 {
     InputBuffer request_buffer(io_message.getData(), io_message.getDataSize());
 
@@ -204,10 +244,17 @@ Recursor::processMessage(const IOMessage& io_message, Message& message,
                 cerr << "[b10-recurse] received unexpected response, ignoring"
                      << endl;
             }
-            return (false);
+            complete = false;
+            server->resume();
+            return;
         }
     } catch (const Exception& ex) {
-        return (false);
+        if (impl_->verbose_mode_) {
+            cerr << "[b10-recurse] DNS packet exception: " << ex.what() << endl;
+        }
+        complete = false;
+        server->resume();
+        return;
     }
 
     // Parse the message.  On failure, return an appropriate error.
@@ -218,16 +265,20 @@ Recursor::processMessage(const IOMessage& io_message, Message& message,
             cerr << "[b10-recurse] returning " <<  error.getRcode().toText()
                  << ": " << error.what() << endl;
         }
-        makeErrorMessage(message, response_renderer, error.getRcode(),
+        makeErrorMessage(message, renderer, error.getRcode(),
                          impl_->verbose_mode_);
-        return (true);
+        complete = true;
+        server->resume();
+        return;
     } catch (const Exception& ex) {
         if (impl_->verbose_mode_) {
             cerr << "[b10-recurse] returning SERVFAIL: " << ex.what() << endl;
         }
-        makeErrorMessage(message, response_renderer, Rcode::SERVFAIL(),
+        makeErrorMessage(message, renderer, Rcode::SERVFAIL(),
                          impl_->verbose_mode_);
-        return (true);
+        complete = true;
+        server->resume();
+        return;
     } // other exceptions will be handled at a higher layer.
 
     if (impl_->verbose_mode_) {
@@ -237,79 +288,58 @@ Recursor::processMessage(const IOMessage& io_message, Message& message,
 
     // Perform further protocol-level validation.
     if (message.getOpcode() == Opcode::NOTIFY()) {
-        makeErrorMessage(message, response_renderer, Rcode::NOTAUTH(),
+        makeErrorMessage(message, renderer, Rcode::NOTAUTH(),
                          impl_->verbose_mode_);
-        return (true);
+        complete = true;
     } else if (message.getOpcode() != Opcode::QUERY()) {
         if (impl_->verbose_mode_) {
             cerr << "[b10-recurse] unsupported opcode" << endl;
         }
-        makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
+        makeErrorMessage(message, renderer, Rcode::NOTIMP(),
                          impl_->verbose_mode_);
-        return (true);
-    }
-
-    if (message.getRRCount(Section::QUESTION()) != 1) {
-        makeErrorMessage(message, response_renderer, Rcode::FORMERR(),
+        complete = true;
+    } else if (message.getRRCount(Section::QUESTION()) != 1) {
+        makeErrorMessage(message, renderer, Rcode::FORMERR(),
                          impl_->verbose_mode_);
-        return (true);
-    }
-
-    ConstQuestionPtr question = *message.beginQuestion();
-    const RRType &qtype = question->getType();
-    if (qtype == RRType::AXFR()) {
-        if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
-            makeErrorMessage(message, response_renderer, Rcode::FORMERR(),
-                             impl_->verbose_mode_);
+        complete = true;
+    } else {
+        ConstQuestionPtr question = *message.beginQuestion();
+        const RRType &qtype = question->getType();
+        if (qtype == RRType::AXFR()) {
+            if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
+                makeErrorMessage(message, renderer, Rcode::FORMERR(),
+                                 impl_->verbose_mode_);
+            } else {
+                makeErrorMessage(message, renderer, Rcode::NOTIMP(),
+                                 impl_->verbose_mode_);
+            }
+            complete = true;
+        } else if (qtype == RRType::IXFR()) {
+            makeErrorMessage(message, renderer, Rcode::NOTIMP(),
+                         impl_->verbose_mode_);
+            complete = true;
         } else {
-            makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
-                             impl_->verbose_mode_);
+            complete = impl_->processNormalQuery(io_message, *question,
+                                                 message, renderer, server);
         }
-        return (true);
-    } else if (qtype == RRType::IXFR()) {
-        makeErrorMessage(message, response_renderer, Rcode::NOTIMP(),
-                         impl_->verbose_mode_);
-        return (true);
-    } else {
-        return (impl_->processNormalQuery(io_message, message,
-                                          response_renderer));
     }
+
+    server->resume();
 }
 
 bool
-RecursorImpl::processNormalQuery(const IOMessage& io_message, Message& message,
-                                MessageRenderer& response_renderer)
+RecursorImpl::processNormalQuery(const IOMessage& io_message,
+                                 const Question& question, Message& message,
+                                 MessageRenderer& renderer,
+                                 BasicServer* server)
 {
     const bool dnssec_ok = message.isDNSSECSupported();
-    const uint16_t remote_bufsize = message.getUDPSize();
 
     message.makeResponse();
     message.setRcode(Rcode::NOERROR());
     message.setDNSSECSupported(dnssec_ok);
     message.setUDPSize(RecursorImpl::DEFAULT_LOCAL_UDPSIZE);
-
-    try {
-        // HERE: initiate forward query, construct a reply
-    } catch (const Exception& ex) {
-        if (verbose_mode_) {
-            cerr << "[b10-recurse] Internal error, returning SERVFAIL: " <<
-                ex.what() << endl;
-        }
-        makeErrorMessage(message, response_renderer, Rcode::SERVFAIL(),
-                         verbose_mode_);
-        return (true);
-    }
-
-    const bool udp_buffer =
-        (io_message.getSocket().getProtocol() == IPPROTO_UDP);
-    response_renderer.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
-    message.toWire(response_renderer);
-    if (verbose_mode_) {
-        cerr << "[b10-recurse] sending a response ("
-             << response_renderer.getLength()
-             << " bytes):\n" << message.toText() << endl;
-    }
-
+    ioquery_.sendQuery(io_message, question, renderer, server);
     return (true);
 }
 

+ 14 - 9
src/bin/recurse/recursor.h

@@ -56,31 +56,36 @@ public:
     /// process.  It's normally a reference to an xfr::XfroutClient object,
     /// but can refer to a local mock object for testing (or other
     /// experimental) purposes.
-    Recursor();
+    Recursor(asiolink::IOService& io_service);
     ~Recursor();
     //@}
     /// \return \c true if the \message contains a response to be returned;
     /// otherwise \c false.
-    bool processMessage(const asiolink::IOMessage& io_message,
+    void processMessage(const asiolink::IOMessage& io_message,
                         isc::dns::Message& message,
-                        isc::dns::MessageRenderer& response_renderer);
+                        isc::dns::MessageRenderer& response_renderer,
+                        asiolink::BasicServer* server, bool& complete);
     void setVerbose(bool on);
     bool getVerbose() const;
     isc::data::ConstElementPtr updateConfig(isc::data::ConstElementPtr config);
     isc::config::ModuleCCSession* configSession() const;
     void setConfigSession(isc::config::ModuleCCSession* config_session);
 
-    asiolink::DNSProvider* getDNSProvider() {
-        return (dns_provider_);
+    asiolink::DNSLookup* getDNSLookupProvider() {
+        return (dns_lookup_);
     }
-    asiolink::CheckinProvider* getCheckinProvider() {
-        return (checkin_provider_);
+    asiolink::DNSAnswer* getDNSAnswerProvider() {
+        return (dns_answer_);
+    }
+    asiolink::IOCallback* getCheckinProvider() {
+        return (checkin_);
     }
 
 private:
     RecursorImpl* impl_;
-    asiolink::CheckinProvider* checkin_provider_;
-    asiolink::DNSProvider* dns_provider_;
+    asiolink::IOCallback* checkin_;
+    asiolink::DNSLookup* dns_lookup_;
+    asiolink::DNSAnswer* dns_answer_;
 };
 
 #endif // __RECURSOR_H

+ 51 - 27
src/bin/recurse/tests/recursor_unittest.cc

@@ -44,6 +44,7 @@ using namespace asiolink;
 
 namespace {
 const char* const DEFAULT_REMOTE_ADDRESS = "192.0.2.1";
+const char* const TEST_PORT = "53535";
 
 class DummySocket : public IOSocket {
 private:
@@ -93,11 +94,20 @@ private:
         bool receive_ok_;
     };
 
+    // A nonoperative task object to be used in calls to processMessage()
+    class MockTask : public BasicServer {
+        void operator()(asio::error_code ec UNUSED_PARAM,
+                        size_t length UNUSED_PARAM)
+        {}
+    };
+
 protected:
-    RecursorTest() : server(),
+    RecursorTest() : ios(*TEST_PORT, true, false, NULL, NULL, NULL),
+                    server(ios),
                     request_message(Message::RENDER),
-                    parse_message(Message::PARSE), default_qid(0x1035),
-                    opcode(Opcode(Opcode::QUERY())), qname("www.example.com"),
+                    parse_message(Message::PARSE),
+                    default_qid(0x1035), opcode(Opcode(Opcode::QUERY())),
+                    qname("www.example.com"),
                     qclass(RRClass::IN()), qtype(RRType::A()),
                     io_message(NULL), endpoint(NULL), request_obuffer(0),
                     request_renderer(request_obuffer),
@@ -108,6 +118,8 @@ protected:
         delete endpoint;
     }
     MockSession notify_session;
+    MockTask noOp;
+    IOService ios;
     Recursor server;
     Message request_message;
     Message parse_message;
@@ -294,8 +306,11 @@ TEST_F(RecursorTest, unsupportedRequest) {
         data[2] = ((i << 3) & 0xff);
 
         parse_message.clear(Message::PARSE);
-        EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                              response_renderer));
+        bool done;
+        server.processMessage(*io_message, parse_message,
+                              response_renderer, &noOp,
+                              done);
+        EXPECT_TRUE(done);
         headerCheck(parse_message, default_qid, Rcode::NOTIMP(), i, QR_FLAG,
                     0, 0, 0, 0);
     }
@@ -313,8 +328,9 @@ TEST_F(RecursorTest, verbose) {
 // Multiple questions.  Should result in FORMERR.
 TEST_F(RecursorTest, multiQuestion) {
     createDataFromFile("multiquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
                 QR_FLAG, 2, 0, 0, 0);
 
@@ -334,8 +350,9 @@ TEST_F(RecursorTest, multiQuestion) {
 // dropped.
 TEST_F(RecursorTest, shortMessage) {
     createDataFromFile("shortmessage_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 // Response messages.  Must be silently dropped, whether it's a valid response
@@ -343,26 +360,28 @@ TEST_F(RecursorTest, shortMessage) {
 TEST_F(RecursorTest, response) {
     // A valid (although unusual) response
     createDataFromFile("simpleresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_FALSE(done);
 
     // A response with a broken question section.  must be dropped rather than
     // returning FORMERR.
     createDataFromFile("shortresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_FALSE(done);
 
     // A response to iquery.  must be dropped rather than returning NOTIMP.
     createDataFromFile("iqueryresponse_fromWire");
-    EXPECT_EQ(false, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_FALSE(done);
 }
 
 // Query with a broken question
 TEST_F(RecursorTest, shortQuestion) {
     createDataFromFile("shortquestion_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
     // Since the query's question is broken, the question section of the
     // response should be empty.
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
@@ -372,8 +391,9 @@ TEST_F(RecursorTest, shortQuestion) {
 // Query with a broken answer section
 TEST_F(RecursorTest, shortAnswer) {
     createDataFromFile("shortanswer_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
 
     // This is a bogus query, but question section is valid.  So the response
     // should copy the question section.
@@ -391,8 +411,9 @@ TEST_F(RecursorTest, shortAnswer) {
 // Query with unsupported version of EDNS.
 TEST_F(RecursorTest, ednsBadVers) {
     createDataFromFile("queryBadEDNS_fromWire");
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
 
     // The response must have an EDNS OPT RR in the additional section.
     // Note that the DNSSEC DO bit is cleared even if this bit in the query
@@ -407,8 +428,9 @@ TEST_F(RecursorTest, AXFROverUDP) {
     // AXFR over UDP is invalid and should result in FORMERR.
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::FORMERR(), opcode.getCode(),
                 QR_FLAG, 1, 0, 0, 0);
 }
@@ -417,8 +439,9 @@ TEST_F(RecursorTest, AXFRFail) {
     createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
                         RRType::AXFR(), IPPROTO_TCP);
     // AXFR is not implemented and should always send NOTIMP.
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                           response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOTIMP(), opcode.getCode(),
                 QR_FLAG, 1, 0, 0, 0);
 }
@@ -431,8 +454,9 @@ TEST_F(RecursorTest, notifyFail) {
     request_message.setQid(default_qid);
     request_message.toWire(request_renderer);
     createRequestPacket(IPPROTO_UDP);
-    EXPECT_EQ(true, server.processMessage(*io_message, parse_message,
-                                          response_renderer));
+    bool done;
+    server.processMessage(*io_message, parse_message, response_renderer, &noOp, done);
+    EXPECT_TRUE(done);
     headerCheck(parse_message, default_qid, Rcode::NOTAUTH(),
                 Opcode::NOTIFY().getCode(), QR_FLAG, 0, 0, 0, 0);
 }

+ 37 - 12
src/lib/asiolink/asiolink.cc

@@ -94,11 +94,32 @@ IOMessage::IOMessage(const void* data, const size_t data_size,
     remote_endpoint_(remote_endpoint)
 {}
 
+IOQuery::IOQuery(IOService& io_service) : io_service_(io_service) {}
+
+void
+IOQuery::sendQuery(const IOMessage& io_message,
+                   const Question& question, MessageRenderer& renderer,
+                   BasicServer* completer)
+{
+    error_code err;
+    // XXX: hard-code the address for now:
+    const ip::address addr = ip::address::from_string("192.168.1.12", err);
+
+    // XXX: eventually we will need to be able to determine whether
+    // the message should be sent via TCP or UDP, or sent initially via
+    // UDP and then fall back to TCP on failure, but for the moment
+    // we're only going to handle UDP.
+    UDPQuery* query = new UDPQuery(io_service_.get_io_service(), io_message,
+                                   question, addr, renderer, completer);
+    (*query)();
+}
+
 class IOServiceImpl {
 public:
     IOServiceImpl(const char& port,
                   const ip::address* v4addr, const ip::address* v6addr,
-                  CheckinProvider* checkin, DNSProvider* process);
+                  IOCallback* checkin, DNSLookup* lookup,
+                  DNSAnswer* answer);
     asio::io_service io_service_;
 
     typedef boost::shared_ptr<UDPServer> UDPServerPtr;
@@ -112,12 +133,13 @@ public:
 IOServiceImpl::IOServiceImpl(const char& port,
                              const ip::address* const v4addr,
                              const ip::address* const v6addr,
-                             CheckinProvider* checkin, DNSProvider* process) :
+                             IOCallback* checkin,
+                             DNSLookup* lookup,
+                             DNSAnswer* answer) :
     udp4_server_(UDPServerPtr()), udp6_server_(UDPServerPtr()),
     tcp4_server_(TCPServerPtr()), tcp6_server_(TCPServerPtr())
 {
     uint16_t portnum;
-
     try {
         // XXX: SunStudio with stlport4 doesn't reject some invalid
         // representation such as "-1" by lexical_cast<uint16_t>, so
@@ -137,21 +159,21 @@ IOServiceImpl::IOServiceImpl(const char& port,
         if (v4addr != NULL) {
             udp4_server_ = UDPServerPtr(new UDPServer(io_service_,
                                                       *v4addr, portnum,
-                                                      checkin, process));
+                                                      checkin, lookup, answer));
             (*udp4_server_)();
             tcp4_server_ = TCPServerPtr(new TCPServer(io_service_,
                                                       *v4addr, portnum,
-                                                      checkin, process));
+                                                      checkin, lookup, answer));
             (*tcp4_server_)();
         }
         if (v6addr != NULL) {
             udp6_server_ = UDPServerPtr(new UDPServer(io_service_,
                                                       *v6addr, portnum,
-                                                      checkin, process));
+                                                      checkin, lookup, answer));
             (*udp6_server_)();
             tcp6_server_ = TCPServerPtr(new TCPServer(io_service_,
                                                       *v6addr, portnum,
-                                                      checkin, process));
+                                                      checkin, lookup, answer));
             (*tcp6_server_)();
         }
     } catch (const asio::system_error& err) {
@@ -164,7 +186,9 @@ IOServiceImpl::IOServiceImpl(const char& port,
 }
 
 IOService::IOService(const char& port, const char& address,
-                     CheckinProvider* checkin, DNSProvider* process) :
+                     IOCallback* checkin,
+                     DNSLookup* lookup,
+                     DNSAnswer* answer) :
     impl_(NULL)
 {
     error_code err;
@@ -177,20 +201,21 @@ IOService::IOService(const char& port, const char& address,
     impl_ = new IOServiceImpl(port,
                               addr.is_v4() ? &addr : NULL,
                               addr.is_v6() ? &addr : NULL,
-                              checkin, process);
+                              checkin, lookup, answer);
 }
 
 IOService::IOService(const char& port,
                      const bool use_ipv4, const bool use_ipv6,
-                     CheckinProvider* checkin, DNSProvider* process) :
+                     IOCallback* checkin,
+                     DNSLookup* lookup,
+                     DNSAnswer* answer) :
     impl_(NULL)
 {
     const ip::address v4addr_any = ip::address(ip::address_v4::any());
     const ip::address* const v4addrp = use_ipv4 ? &v4addr_any : NULL; 
     const ip::address v6addr_any = ip::address(ip::address_v6::any());
     const ip::address* const v6addrp = use_ipv6 ? &v6addr_any : NULL;
-    impl_ = new IOServiceImpl(port, v4addrp, v6addrp,
-                              checkin, process);
+    impl_ = new IOServiceImpl(port, v4addrp, v6addrp, checkin, lookup, answer);
 }
 
 IOService::~IOService() {

+ 117 - 17
src/lib/asiolink/asiolink.h

@@ -30,6 +30,7 @@
 
 #include <dns/message.h>
 #include <dns/messagerenderer.h>
+#include <dns/question.h>
 
 #include <exceptions/exceptions.h>
 
@@ -386,7 +387,34 @@ private:
     const IOEndpoint& remote_endpoint_;
 };
 
-/// \brief The \c DNSProvider class is an abstract base class for a DNS
+/// XXX: need to add doc
+class BasicServer {
+public:
+    BasicServer() : self(this) {}
+    virtual void operator()(asio::error_code ec = asio::error_code(),
+                            size_t length = 0)
+    {
+        (*self)(ec, length);
+    }
+
+    virtual void doLookup() {}
+    virtual void resume() {}
+private:
+    BasicServer* self;
+};
+
+template <typename T>
+class LookupHandler {
+public:
+    LookupHandler(T caller) : caller_(caller) {}
+    void operator()() {
+        caller_.doLookup();
+    }
+private:
+    T caller_;
+};
+
+/// \brief The \c DNSLookup class is an abstract base class for a DNS
 /// provider function.
 ///
 /// Specific derived class implementations are hidden within the
@@ -394,7 +422,7 @@ private:
 /// as functions via the operator() interface.  Pointers to these
 /// instances can then be provided to the \c IOService class
 /// via its constructor.
-class DNSProvider {
+class DNSLookup {
     ///
     /// \name Constructors and Destructor
     ///
@@ -402,32 +430,78 @@ class DNSProvider {
     /// intentionally defined as private, making this class non-copyable.
     //@{
 private:
-    DNSProvider(const DNSProvider& source);
-    DNSProvider& operator=(const DNSProvider& source);
+    DNSLookup(const DNSLookup& source);
+    DNSLookup& operator=(const DNSLookup& source);
 protected:
     /// \brief The default constructor.
     ///
     /// This is intentionally defined as \c protected as this base class
     /// should never be instantiated (except as part of a derived class).
-    DNSProvider() {}
+    DNSLookup() : self(this) {}
 public:
     /// \brief The destructor
-    virtual ~DNSProvider() {}
+    virtual ~DNSLookup() {}
+    /// \brief The function operator
+    ///
+    /// This makes its call indirectly via the "self" pointer, ensuring
+    /// that the function ultimately invoked will be the one in the derived
+    /// class.
+    virtual void operator()(const IOMessage& io_message,
+                            isc::dns::Message& dns_message,
+                            isc::dns::MessageRenderer& renderer,
+                            BasicServer* server, bool& success)
+                            const
+    {
+        (*self)(io_message, dns_message, renderer, server, success);
+    }
     //@}
-    virtual bool operator()(const IOMessage& io_message,
+private:
+    DNSLookup* self;
+};
+
+/// \brief The \c DNSAnswer class is an abstract base class for a DNS
+/// provider function.
+///
+/// Specific derived class implementations are hidden within the
+/// implementation.  Instances of the derived classes can be called
+/// as functions via the operator() interface.  Pointers to these
+/// instances can then be provided to the \c IOService class
+/// via its constructor.
+class DNSAnswer {
+    ///
+    /// \name Constructors and Destructor
+    ///
+    /// Note: The copy constructor and the assignment operator are
+    /// intentionally defined as private, making this class non-copyable.
+    //@{
+private:
+    DNSAnswer(const DNSAnswer& source);
+    DNSAnswer& operator=(const DNSAnswer& source);
+protected:
+    /// \brief The default constructor.
+    ///
+    /// This is intentionally defined as \c protected as this base class
+    /// should never be instantiated (except as part of a derived class).
+    DNSAnswer() {}
+public:
+    /// \brief The destructor
+    virtual ~DNSAnswer() {}
+    /// \brief The function operator
+    virtual void operator()(const IOMessage& io_message,
                             isc::dns::Message& dns_message,
                             isc::dns::MessageRenderer& renderer) const = 0;
+    //@}
 };
 
-/// \brief The \c CheckinProvider class is an abstract base class for a
-/// checkin function.
+/// \brief The \c IOCallback class is an abstract base class for a
+/// simple callback function with the signature:
 ///
 /// Specific derived class implementations are hidden within the
 /// implementation.  Instances of the derived classes can be called
 /// as functions via the operator() interface.  Pointers to these
 /// instances can then be provided to the \c IOService class
 /// via its constructor.
-class CheckinProvider {
+class IOCallback {
     ///
     /// \name Constructors and Destructor
     ///
@@ -435,19 +509,28 @@ class CheckinProvider {
     /// intentionally defined as private, making this class non-copyable.
     //@{
 private:
-    CheckinProvider(const CheckinProvider& source);
-    CheckinProvider& operator=(const CheckinProvider& source);
+    IOCallback(const IOCallback& source);
+    IOCallback& operator=(const IOCallback& source);
 protected:
     /// \brief The default constructor.
     ///
     /// This is intentionally defined as \c protected as this base class
     /// should never be instantiated (except as part of a derived class).
-    CheckinProvider() {}
+    IOCallback() : self(this) {}
 public:
     /// \brief The destructor
-    virtual ~CheckinProvider() {}
+    virtual ~IOCallback() {}
+    /// \brief The function operator
+    ///
+    /// This makes its call indirectly via the "self" pointer, ensuring
+    /// that the function ultimately invoked will be the one in the derived
+    /// class.
+    virtual void operator()(const IOMessage& io_message) const {
+        (*self)(io_message);
+    }
     //@}
-    virtual void operator()(void) const = 0;
+private:
+    IOCallback* self;
 };
 
 /// \brief The \c IOService class is a wrapper for the ASIO \c io_service
@@ -474,7 +557,9 @@ public:
     /// \brief The constructor with a specific IP address and port on which
     /// the services listen on.
     IOService(const char& port, const char& address,
-              CheckinProvider* checkin, DNSProvider* process);
+              IOCallback* checkin,
+              DNSLookup* lookup,
+              DNSAnswer* answer);
     /// \brief The constructor with a specific port on which the services
     /// listen on.
     ///
@@ -482,7 +567,9 @@ public:
     /// IPv4/IPv6 services will be available if and only if \c use_ipv4
     /// or \c use_ipv6 is \c true, respectively.
     IOService(const char& port, const bool use_ipv4, const bool use_ipv6,
-              CheckinProvider* checkin, DNSProvider* process);
+              IOCallback* checkin,
+              DNSLookup* lookup,
+              DNSAnswer* answer);
     /// \brief The destructor.
     ~IOService();
     //@}
@@ -509,6 +596,19 @@ private:
     IOServiceImpl* impl_;
 };
 
+/// \brief The \c IOQuery class provides a layer of abstraction around
+/// the ASIO code that carries out upstream queries.
+class IOQuery {
+public:
+    IOQuery(IOService& io_service);
+    void sendQuery(const IOMessage& io_message,
+                   const isc::dns::Question& question,
+                   isc::dns::MessageRenderer& renderer,
+                   BasicServer* caller);
+private:
+    IOService& io_service_;
+};
+
 }      // asiolink
 #endif // __ASIOLINK_H
 

+ 1 - 0
src/lib/asiolink/internal/coroutine.h

@@ -18,6 +18,7 @@ public:
   bool is_child() const { return value_ < 0; }
   bool is_parent() const { return !is_child(); }
   bool is_complete() const { return value_ == -1; }
+  int get_value() const { return value_; }
 private:
   friend class coroutine_ref;
   int value_;

+ 19 - 7
src/lib/asiolink/internal/tcpdns.h

@@ -73,20 +73,26 @@ private:
 //
 // Asynchronous TCP server coroutine
 //
-class TCPServer : public coroutine {
+class TCPServer : public virtual BasicServer, public virtual coroutine {
 public:
     explicit TCPServer(asio::io_service& io_service,
                        const asio::ip::address& addr, const uint16_t port, 
-                       CheckinProvider* checkin = NULL,
-                       DNSProvider* process = NULL);
+                       const IOCallback* checkin = NULL,
+                       const DNSLookup* lookup = NULL,
+                       const DNSAnswer* answer = NULL);
 
     void operator()(asio::error_code ec = asio::error_code(),
                     size_t length = 0);
 
+    void doLookup();
+    void resume();
+
 private:
     enum { MAX_LENGTH = 65535 };
     static const size_t TCP_MESSAGE_LENGTHSIZE = 2;
 
+    asio::io_service& io_;
+
     // Class member variables which are dynamic, and changes to which
     // need to accessible from both sides of a coroutine fork or from
     // outside of the coroutine (i.e., from an asynchronous I/O call),
@@ -95,16 +101,22 @@ private:
     boost::shared_ptr<asio::ip::tcp::acceptor> acceptor_;
     boost::shared_ptr<asio::ip::tcp::socket> socket_;
     boost::shared_ptr<isc::dns::MessageRenderer> renderer_;
+    boost::shared_ptr<isc::dns::OutputBuffer> lenbuf_;
+    boost::shared_ptr<isc::dns::OutputBuffer> respbuf_;
+    boost::shared_ptr<asiolink::IOEndpoint> peer_;
+    boost::shared_ptr<asiolink::IOSocket> iosock_;
+    boost::shared_ptr<asiolink::IOMessage> io_message_;
     boost::shared_ptr<char> data_;
 
     // State information that is entirely internal to a given instance
     // of the coroutine can be declared here.
-    isc::dns::OutputBuffer respbuf_;
-    isc::dns::OutputBuffer lenbuf_;
+    size_t bytes_;
+    bool done_;
 
     // Callbacks
-    const CheckinProvider* checkin_callback_;
-    const DNSProvider* dns_callback_;
+    const IOCallback* checkin_callback_;
+    const DNSLookup* lookup_callback_;
+    const DNSAnswer* answer_callback_;
 };
 
 }

+ 44 - 7
src/lib/asiolink/internal/udpdns.h

@@ -57,7 +57,6 @@ private:
     const asio::ip::udp::endpoint& asio_endpoint_;
 };
 
-class UDPBuffers;
 class UDPSocket : public IOSocket {
 private:
     UDPSocket(const UDPSocket& source);
@@ -69,20 +68,30 @@ public:
 private:
     asio::ip::udp::socket& socket_;
 };
+
 //
 // Asynchronous UDP server coroutine
 //
-class UDPServer : public coroutine {
+class UDPServer : public virtual BasicServer, public virtual coroutine {
 public:
     explicit UDPServer(asio::io_service& io_service,
                        const asio::ip::address& addr, const uint16_t port,
-                       CheckinProvider* checkin = NULL,
-                       DNSProvider* process = NULL);
+                       IOCallback* checkin = NULL,
+                       DNSLookup* lookup = NULL,
+                       DNSAnswer* answer = NULL);
+
     void operator()(asio::error_code ec = asio::error_code(),
                     size_t length = 0);
 
-private:
     enum { MAX_LENGTH = 4096 };
+    char answer[MAX_LENGTH];
+    asio::ip::udp::endpoint peer;
+
+    void doLookup();
+    void resume();
+
+private:
+    asio::io_service& io_;
 
     // Class member variables which are dynamic, and changes to which
     // need to accessible from both sides of a coroutine fork or from
@@ -93,17 +102,45 @@ private:
     boost::shared_ptr<char> data_;
     boost::shared_ptr<asio::ip::udp::endpoint> sender_;
     boost::shared_ptr<isc::dns::MessageRenderer> renderer_;
+    boost::shared_ptr<isc::dns::Message> message_;
+    boost::shared_ptr<asiolink::IOEndpoint> peer_;
+    boost::shared_ptr<asiolink::IOSocket> iosock_;
+    boost::shared_ptr<asiolink::IOMessage> io_message_;
 
     // State information that is entirely internal to a given instance
     // of the coroutine can be declared here.
     isc::dns::OutputBuffer respbuf_;
     size_t bytes_;
+    bool done_;
 
     // Callbacks
-    const CheckinProvider* checkin_callback_;
-    const DNSProvider* dns_callback_;
+    const IOCallback* checkin_callback_;
+    const DNSLookup* lookup_callback_;
+    const DNSAnswer* answer_callback_;
 };
 
+//
+// Asynchronous UDP coroutine for upstream queries
+//
+class UDPQuery : public coroutine {
+public:
+    explicit UDPQuery(asio::io_service& io_service,
+                      const IOMessage& io_message,
+                      const isc::dns::Question& q,
+                      const asio::ip::address& addr,
+                      isc::dns::MessageRenderer& renderer,
+                      BasicServer* caller);
+    void operator()(asio::error_code ec = asio::error_code(),
+                    size_t length = 0); 
+private:
+    boost::shared_ptr<asio::ip::udp::socket> socket_;
+    asio::ip::udp::endpoint server_;
+    isc::dns::Question question_;
+    char* data_;
+    size_t datalen_;
+    isc::dns::OutputBuffer msgbuf_;
+    BasicServer* caller_;
+};
 }
 
 #endif // __UDPDNS_H

+ 61 - 44
src/lib/asiolink/tcpdns.cc

@@ -76,9 +76,12 @@ TCPSocket::getProtocol() const {
 
 TCPServer::TCPServer(io_service& io_service,
                      const ip::address& addr, const uint16_t port, 
-                     CheckinProvider* checkin, DNSProvider* process) :
-    respbuf_(0), lenbuf_(TCP_MESSAGE_LENGTHSIZE),
-    checkin_callback_(checkin), dns_callback_(process)
+                     const IOCallback* checkin,
+                     const DNSLookup* lookup,
+                     const DNSAnswer* answer) :
+    io_(io_service), done_(false),
+    checkin_callback_(checkin), lookup_callback_(lookup),
+    answer_callback_(answer)
 {
     tcp::endpoint endpoint(addr, port);
     acceptor_.reset(new tcp::acceptor(io_service));
@@ -91,6 +94,8 @@ TCPServer::TCPServer(io_service& io_service,
     acceptor_->set_option(tcp::acceptor::reuse_address(true));
     acceptor_->bind(endpoint);
     acceptor_->listen();
+    lenbuf_.reset(new OutputBuffer(TCP_MESSAGE_LENGTHSIZE));
+    respbuf_.reset(new OutputBuffer(0));
 }
 
 void
@@ -99,71 +104,83 @@ TCPServer::operator()(error_code ec, size_t length) {
         return;
     }
 
-    bool done = false;
+    boost::array<const_buffer,2> bufs;
     CORO_REENTER (this) {
         do {
             socket_.reset(new tcp::socket(acceptor_->get_io_service()));
             CORO_YIELD acceptor_->async_accept(*socket_, *this);
-            CORO_FORK TCPServer(*this)();
-        } while (is_child());
-
-        // Perform any necessary operations prior to processing an incoming
-        // packet (e.g., checking for queued configuration messages).
-        //
-        // (XXX: it may be a performance issue to have this called for
-        // every single incoming packet; we may wish to throttle it somehow
-        // in the future.)
-        if (checkin_callback_ != NULL) {
-            (*checkin_callback_)();
-        }
+            CORO_FORK io_.post(TCPServer(*this));
+        } while (is_parent());
 
         // Instantiate the data buffer that will be used by the
         // asynchronous read call.
         data_ = boost::shared_ptr<char>(new char[MAX_LENGTH]);
+
+        // Read the message length.
         CORO_YIELD async_read(*socket_, asio::buffer(data_.get(),
-                                                     TCP_MESSAGE_LENGTHSIZE),
-                              *this);
+                              TCP_MESSAGE_LENGTHSIZE), *this);
 
+        // Now read the message itself. (This is done in a different scope
+        // because CORO_REENTER is implemented as a switch statement; the
+        // inline variable declaration of "msglen" and "dnsbuffer" are
+        // therefore not permitted in this scope.)
         CORO_YIELD {
             InputBuffer dnsbuffer((const void *) data_.get(), length);
             uint16_t msglen = dnsbuffer.readUint16();
             async_read(*socket_, asio::buffer(data_.get(), msglen), *this);
         }
 
-        // Stop here if we don't have a DNS callback function
-        if (dns_callback_ == NULL) {
-            CORO_YIELD return;
-        }
+        // Store the io_message data.
+        peer_.reset(new TCPEndpoint(socket_->remote_endpoint()));
+        iosock_.reset(new TCPSocket(*socket_));
+        io_message_.reset(new IOMessage(data_.get(), length, *iosock_, *peer_));
 
-        // Instantiate the objects that will be needed by the
-        // DNS callback and the asynchronous write calls.
-        respbuf_.clear();
-        renderer_.reset(new MessageRenderer(respbuf_));
-
-        // Process the DNS message.  (Must be done in a separate scope 
-        // because CORO_REENTER is implemented with a switch statement
-        // and inline variable declaration isn't allowed.)
-        {
-            TCPEndpoint peer(socket_->remote_endpoint());
-            TCPSocket iosock(*socket_);
-            IOMessage io_message(data_.get(), length, iosock, peer);
-            Message message(Message::PARSE);
-            done = (*dns_callback_)(io_message, message, *renderer_);
+        // Perform any necessary operations prior to processing the incoming
+        // packet (e.g., checking for queued configuration messages).
+        //
+        // (XXX: it may be a performance issue to have this called for
+        // every single incoming packet; we may wish to throttle it somehow
+        // in the future.)
+        if (checkin_callback_ != NULL) {
+            (*checkin_callback_)(*io_message_);
         }
 
-        if (!done) {
+        // Just stop here if we don't have a DNS callback function.
+        if (lookup_callback_ == NULL) {
             CORO_YIELD return;
         }
 
-        CORO_YIELD {
-            lenbuf_.clear();
-            lenbuf_.writeUint16(respbuf_.getLength());
-            boost::array<const_buffer,2> bufs;
-            bufs[0] = buffer(lenbuf_.getData(), lenbuf_.getLength());
-            bufs[1] = buffer(respbuf_.getData(), respbuf_.getLength());
-            async_write(*socket_, bufs, *this);
+        // Reset or instantiate objects that will be needed by the
+        // DNS lookup and the write call.
+        respbuf_->clear();
+        renderer_.reset(new MessageRenderer(*respbuf_));
+
+        // Process the DNS message.
+        bytes_ = length;
+        CORO_YIELD io_.post(LookupHandler<TCPServer>(*this));
+
+        if (!done_) {
+            CORO_YIELD return;
         }
+
+        // Send the response.
+        lenbuf_->clear();
+        lenbuf_->writeUint16(respbuf_->getLength());
+        bufs[0] = buffer(lenbuf_->getData(), lenbuf_->getLength());
+        bufs[1] = buffer(respbuf_->getData(), respbuf_->getLength());
+        CORO_YIELD async_write(*socket_, bufs, *this);
     }
 }
 
+void
+TCPServer::doLookup() {
+    Message message(Message::PARSE);
+    (*lookup_callback_)(*io_message_, message, *renderer_, this, done_);
+}
+
+void
+TCPServer::resume() {
+    io_.post(*this);
+}
+
 }

+ 32 - 29
src/lib/asiolink/tests/asio_link_unittest.cc

@@ -120,52 +120,52 @@ TEST(IOSocketTest, dummySockets) {
 }
 
 TEST(IOServiceTest, badPort) {
-    EXPECT_THROW(IOService(*"65536", true, false, NULL, NULL), IOError);
-    EXPECT_THROW(IOService(*"5300.0", true, false, NULL, NULL), IOError);
-    EXPECT_THROW(IOService(*"-1", true, false, NULL, NULL), IOError);
-    EXPECT_THROW(IOService(*"domain", true, false, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*"65536", true, false, NULL, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*"5300.0", true, false, NULL, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*"-1", true, false, NULL, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*"domain", true, false, NULL, NULL, NULL), IOError);
 }
 
 TEST(IOServiceTest, badAddress) {
-    EXPECT_THROW(IOService(*TEST_PORT, *"192.0.2.1.1", NULL, NULL), IOError);
-    EXPECT_THROW(IOService(*TEST_PORT, *"2001:db8:::1", NULL, NULL), IOError);
-    EXPECT_THROW(IOService(*TEST_PORT, *"localhost", NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*TEST_PORT, *"192.0.2.1.1", NULL, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*TEST_PORT, *"2001:db8:::1", NULL, NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*TEST_PORT, *"localhost", NULL, NULL, NULL), IOError);
 }
 
 TEST(IOServiceTest, unavailableAddress) {
     // These addresses should generally be unavailable as a valid local
     // address, although there's no guarantee in theory.
-    EXPECT_THROW(IOService(*TEST_PORT, *"255.255.0.0", NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*TEST_PORT, *"255.255.0.0", NULL, NULL, NULL), IOError);
 
     // Some OSes would simply reject binding attempt for an AF_INET6 socket
     // to an IPv4-mapped IPv6 address.  Even if those that allow it, since
     // the corresponding IPv4 address is the same as the one used in the
     // AF_INET socket case above, it should at least show the same result
     // as the previous one.
-    EXPECT_THROW(IOService(*TEST_PORT, *"::ffff:255.255.0.0", NULL, NULL), IOError);
+    EXPECT_THROW(IOService(*TEST_PORT, *"::ffff:255.255.0.0", NULL, NULL, NULL), IOError);
 }
 
 TEST(IOServiceTest, duplicateBind) {
     // In each sub test case, second attempt should fail due to duplicate bind
 
     // IPv6, "any" address
-    IOService* io_service = new IOService(*TEST_PORT, false, true, NULL, NULL);
-    EXPECT_THROW(IOService(*TEST_PORT, false, true, NULL, NULL), IOError);
+    IOService* io_service = new IOService(*TEST_PORT, false, true, NULL, NULL, NULL);
+    EXPECT_THROW(IOService(*TEST_PORT, false, true, NULL, NULL, NULL), IOError);
     delete io_service;
 
     // IPv6, specific address
-    io_service = new IOService(*TEST_PORT, *TEST_IPV6_ADDR, NULL, NULL);
-    EXPECT_THROW(IOService(*TEST_PORT, *TEST_IPV6_ADDR, NULL, NULL), IOError);
+    io_service = new IOService(*TEST_PORT, *TEST_IPV6_ADDR, NULL, NULL, NULL);
+    EXPECT_THROW(IOService(*TEST_PORT, *TEST_IPV6_ADDR, NULL, NULL, NULL), IOError);
     delete io_service;
 
     // IPv4, "any" address
-    io_service = new IOService(*TEST_PORT, true, false, NULL, NULL);
-    EXPECT_THROW(IOService(*TEST_PORT, true, false, NULL, NULL), IOError);
+    io_service = new IOService(*TEST_PORT, true, false, NULL, NULL, NULL);
+    EXPECT_THROW(IOService(*TEST_PORT, true, false, NULL, NULL, NULL), IOError);
     delete io_service;
 
     // IPv4, specific address
-    io_service = new IOService(*TEST_PORT, *TEST_IPV4_ADDR, NULL, NULL);
-    EXPECT_THROW(IOService(*TEST_PORT, *TEST_IPV4_ADDR, NULL, NULL), IOError);
+    io_service = new IOService(*TEST_PORT, *TEST_IPV4_ADDR, NULL, NULL, NULL);
+    EXPECT_THROW(IOService(*TEST_PORT, *TEST_IPV4_ADDR, NULL, NULL, NULL), IOError);
     delete io_service;
 }
 
@@ -211,7 +211,12 @@ protected:
         if (sock_ != -1) {
             close(sock_);
         }
-        delete io_service_;
+        if (io_service_ != NULL) {
+            delete io_service_;
+        }
+        if (callback_ != NULL) {
+            delete callback_;
+        }
     }
     void sendUDP(const int family) {
         res_ = resolveAddress(family, SOCK_DGRAM, IPPROTO_UDP);
@@ -246,14 +251,15 @@ protected:
     void setIOService(const char& address) {
         delete io_service_;
         io_service_ = NULL;
-        ASIOCallBack* cb = new ASIOCallBack(this);
-        io_service_ = new IOService(*TEST_PORT, address, NULL, cb);
+        callback_ = new ASIOCallBack(this);
+        io_service_ = new IOService(*TEST_PORT, address, callback_, NULL, NULL);
     }
     void setIOService(const bool use_ipv4, const bool use_ipv6) {
         delete io_service_;
         io_service_ = NULL;
-        ASIOCallBack* cb = new ASIOCallBack(this);
-        io_service_ = new IOService(*TEST_PORT, use_ipv4, use_ipv6, NULL, cb);
+        callback_ = new ASIOCallBack(this);
+        io_service_ = new IOService(*TEST_PORT, use_ipv4, use_ipv6, callback_,
+                                    NULL, NULL);
     }
     void doTest(const int family, const int protocol) {
         if (protocol == IPPROTO_UDP) {
@@ -280,15 +286,11 @@ protected:
                             expected_data, expected_datasize);
     }
 private:
-    class ASIOCallBack : public DNSProvider {
+    class ASIOCallBack : public IOCallback {
     public:
         ASIOCallBack(ASIOLinkTest* test_obj) : test_obj_(test_obj) {}
-        bool operator()(const IOMessage& io_message,
-                        isc::dns::Message& dns_message UNUSED_PARAM,
-                        isc::dns::MessageRenderer& renderer UNUSED_PARAM) const
-        {
+        void operator()(const IOMessage& io_message) const {
             test_obj_->callBack(io_message);
-            return (true);
         }
     private:
         ASIOLinkTest* test_obj_;
@@ -306,6 +308,7 @@ private:
     }
 protected:
     IOService* io_service_;
+    ASIOCallBack* callback_;
     int callback_protocol_;
     int callback_native_;
     string callback_address_;
@@ -316,7 +319,7 @@ private:
 };
 
 ASIOLinkTest::ASIOLinkTest() :
-    io_service_(NULL), sock_(-1), res_(NULL)
+    io_service_(NULL), callback_(NULL), sock_(-1), res_(NULL)
 {
     setIOService(true, true);
 }

+ 72 - 18
src/lib/asiolink/udpdns.cc

@@ -22,6 +22,8 @@
 #include <sys/socket.h>
 #include <netinet/in.h>
 
+#include <boost/bind.hpp>
+
 #include <asio.hpp>
 #include <boost/lexical_cast.hpp>
 
@@ -75,8 +77,12 @@ UDPSocket::getProtocol() const {
 
 UDPServer::UDPServer(io_service& io_service,
                      const ip::address& addr, const uint16_t port,
-                     CheckinProvider* checkin, DNSProvider* process) :
-    respbuf_(0), checkin_callback_(checkin), dns_callback_(process)
+                     IOCallback* checkin,
+                     DNSLookup* lookup,
+                     DNSAnswer* answer) :
+    io_(io_service), respbuf_(0), done_(false),
+    checkin_callback_(checkin), lookup_callback_(lookup),
+    answer_callback_(answer)
 {
     // Wwe use a different instantiation for v4,
     // otherwise asio will bind to both v4 and v6
@@ -91,7 +97,6 @@ UDPServer::UDPServer(io_service& io_service,
 
 void
 UDPServer::operator()(error_code ec, size_t length) {
-    bool done = false;
     CORO_REENTER (this) {
         do {
             // Instantiate the data buffer and endpoint that will
@@ -105,8 +110,13 @@ UDPServer::operator()(error_code ec, size_t length) {
             } while (ec || length == 0);
 
             bytes_ = length;
-            CORO_FORK UDPServer(*this)();
-        } while (is_child());
+            CORO_FORK io_.post(UDPServer(*this));
+        } while (is_parent());
+
+        // Store the io_message data.
+        peer_.reset(new UDPEndpoint(*sender_));
+        iosock_.reset(new UDPSocket(*socket_));
+        io_message_.reset(new IOMessage(data_.get(), bytes_, *iosock_, *peer_));
 
         // Perform any necessary operations prior to processing an incoming
         // packet (e.g., checking for queued configuration messages).
@@ -115,11 +125,11 @@ UDPServer::operator()(error_code ec, size_t length) {
         // every single incoming packet; we may wish to throttle it somehow
         // in the future.)
         if (checkin_callback_ != NULL) {
-            (*checkin_callback_)();
+            (*checkin_callback_)(*io_message_);
         }
 
         // Stop here if we don't have a DNS callback function
-        if (dns_callback_ == NULL) {
+        if (lookup_callback_ == NULL) {
             CORO_YIELD return;
         }
 
@@ -127,26 +137,70 @@ UDPServer::operator()(error_code ec, size_t length) {
         // asynchronous send call.
         respbuf_.clear();
         renderer_.reset(new MessageRenderer(respbuf_));
+        message_.reset(new Message(Message::PARSE));
 
-        // Process the DNS message.  (Must be done in a separate scope 
-        // because CORO_REENTER is implemented with a switch statement,
-        // and thus normal inline variable declaration isn't allowed.)
-        {
-            UDPEndpoint peer(*sender_);
-            UDPSocket iosock(*socket_);
-            IOMessage io_message(data_.get(), bytes_, iosock, peer);
-            Message message(Message::PARSE);
-            done = (*dns_callback_)(io_message, message, *renderer_);
-        }
+        CORO_YIELD io_.post(LookupHandler<UDPServer>(*this));
 
-        if (!done) {
+        if (!done_) {
             CORO_YIELD return;
         }
 
+        (*answer_callback_)(*io_message_, *message_, *renderer_);
         CORO_YIELD socket_->async_send_to(buffer(respbuf_.getData(),
                                                  respbuf_.getLength()),
                                      *sender_, *this);
     }
 }
 
+void
+UDPServer::doLookup() {
+    (*lookup_callback_)(*io_message_, *message_, *renderer_, this, done_);
+}
+
+void
+UDPServer::resume() {
+    io_.post(*this);
+}
+
+UDPQuery::UDPQuery(io_service& io_service, const IOMessage& io_message,
+                   const Question& q, const ip::address& addr,
+                   MessageRenderer& renderer, BasicServer* caller) :
+    question_(q),
+    data_((char*) renderer.getData()), datalen_(renderer.getLength()),
+    msgbuf_(512), caller_(caller)
+{
+    udp proto = addr.is_v4() ? udp::v4() : udp::v6();
+    socket_.reset(new udp::socket(io_service, proto));
+    server_ = udp::endpoint(addr, 53);
+}
+
+void
+UDPQuery::operator()(error_code ec, size_t length) {
+    if (ec) {
+        return;
+    }
+
+    CORO_REENTER (this) {
+        {
+            Message msg(Message::RENDER);
+            msg.setQid(0);
+            msg.setOpcode(Opcode::QUERY());
+            msg.setRcode(Rcode::NOERROR());
+            msg.setHeaderFlag(MessageFlag::RD());
+            msg.addQuestion(question_);
+            MessageRenderer renderer(msgbuf_);
+            msg.toWire(renderer);
+        }
+
+        CORO_YIELD socket_->async_send_to(buffer(msgbuf_.getData(),
+                                                 msgbuf_.getLength()),
+                                           server_, *this);
+
+        CORO_YIELD socket_->async_receive_from(buffer(data_, datalen_),
+                                               server_, *this);
+    }
+
+    caller_->resume();
+}
+
 }