Browse Source

Merge #401 into #327

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac327@3660 e5f2f494-b856-4b98-b285-d166d9295462
Michal Vaner 14 years ago
parent
commit
aa71aa7ecf

+ 12 - 0
src/bin/recurse/recurse.spec.pre.in

@@ -4,6 +4,18 @@
     "module_description": "Recursive service",
     "config_data": [
       {
+        "item_name": "timeout",
+        "item_type": "integer",
+        "item_optional": False,
+        "item_default": 2000
+      },
+      {
+        "item_name": "retries",
+        "item_type": "integer",
+        "item_optional": False,
+        "item_default": 0
+      },
+      {
         "item_name": "forward_addresses",
         "item_type": "list",
         "item_optional": True,

+ 57 - 16
src/bin/recurse/recursor.cc

@@ -20,6 +20,7 @@
 
 #include <algorithm>
 #include <vector>
+#include <cassert>
 
 #include <asiolink/asiolink.h>
 #include <asiolink/ioaddress.h>
@@ -65,7 +66,7 @@ private:
 public:
     RecursorImpl() :
         config_session_(NULL),
-        rec_query_()
+        rec_query_(NULL)
     {}
 
     ~RecursorImpl() {
@@ -73,6 +74,7 @@ public:
     }
 
     void querySetup(DNSService& dnss) {
+        assert(!rec_query_); // queryShutdown must be called first
         dlog("Query setup");
         rec_query_ = new RecursiveQuery(dnss, upstream_);
     }
@@ -117,12 +119,22 @@ public:
     /// Addresses we listen on
     vector<addr_t> listen_;
 
+    /// Time in milliseconds, to timeout
+    int timeout_;
+    /// Number of retries after timeout
+    unsigned retries_;
+
 private:
 
     /// Object to handle upstream queries
     RecursiveQuery* rec_query_;
 };
 
+/*
+ * std::for_each has a broken interface. It makes no sense in a language
+ * without lambda functions/closures. These two classes emulate the lambda
+ * functions so for_each can be used.
+ */
 class QuestionInserter {
 public:
     QuestionInserter(MessagePtr message) : message_(message) {}
@@ -136,9 +148,8 @@ public:
 
 class SectionInserter {
 public:
-    SectionInserter(MessagePtr message, const Message::Section sect,
-        bool sign) :
-        message_(message), section_(sect), sign_(sign)
+    SectionInserter(MessagePtr message, const Message::Section sect) :
+        message_(message), section_(sect)
     {}
     void operator()(const RRsetPtr rrset) {
         dlog("Adding RRSet to message section " +
@@ -147,7 +158,6 @@ public:
     }
     MessagePtr message_;
     const Message::Section section_;
-    bool sign_;
 };
 
 void
@@ -212,7 +222,6 @@ private:
 // into a wire-format response.
 class MessageAnswer : public DNSAnswer {
 public:
-    MessageAnswer(Recursor* srv) : server_(srv) {}
     virtual void operator()(const IOMessage& io_message,
                             MessagePtr message,
                             OutputBufferPtr buffer) const
@@ -253,16 +262,13 @@ public:
                 incoming.fromWire(ibuf);
                 for_each(incoming.beginSection(Message::SECTION_ANSWER),
                          incoming.endSection(Message::SECTION_ANSWER),
-                         SectionInserter(message, Message::SECTION_ANSWER,
-                         true));
+                         SectionInserter(message, Message::SECTION_ANSWER));
                 for_each(incoming.beginSection(Message::SECTION_ADDITIONAL),
                          incoming.endSection(Message::SECTION_ADDITIONAL),
-                         SectionInserter(message, Message::SECTION_ADDITIONAL,
-                         true));
+                         SectionInserter(message, Message::SECTION_ADDITIONAL));
                 for_each(incoming.beginSection(Message::SECTION_AUTHORITY),
                          incoming.endSection(Message::SECTION_ADDITIONAL),
-                         SectionInserter(message, Message::SECTION_AUTHORITY,
-                         true));
+                         SectionInserter(message, Message::SECTION_AUTHORITY));
             } catch (const Exception& ex) {
                 // Incoming message couldn't be read, we just SERVFAIL
                 message->setRcode(Rcode::SERVFAIL());
@@ -288,9 +294,6 @@ public:
             boost::lexical_cast<string>(renderer.getLength()) + "bytes): \n" +
             message->toText());
     }
-
-private:
-    Recursor* server_;
 };
 
 // This is a derived class of \c SimpleCallback, to serve
@@ -312,7 +315,7 @@ Recursor::Recursor() :
     impl_(new RecursorImpl()),
     checkin_(new ConfigCheck(this)),
     dns_lookup_(new MessageLookup(this)),
-    dns_answer_(new MessageAnswer(this))
+    dns_answer_(new MessageAnswer)
 {}
 
 Recursor::~Recursor() {
@@ -489,6 +492,27 @@ Recursor::updateConfig(ConstElementPtr config) {
         vector<addr_t> forwardAddresses(parseAddresses(forwardAddressesE));
         ConstElementPtr listenAddressesE(config->get("listen_on"));
         vector<addr_t> listenAddresses(parseAddresses(listenAddressesE));
+        bool set_timeouts(false);
+        int timeout = impl_->timeout_;
+        unsigned retries = impl_->retries_;
+        ConstElementPtr timeoutE(config->get("timeout")),
+            retriesE(config->get("retries"));
+        if (timeoutE) {
+            // It should be safe to just get it, the config manager should
+            // check for us
+            timeout = timeoutE->intValue();
+            if (timeout < -1) {
+                isc_throw(BadValue, "Timeout too small");
+            }
+            set_timeouts = true;
+        }
+        if (retriesE) {
+            if (retriesE->intValue() < 0) {
+                isc_throw(BadValue, "Negative number of retries");
+            }
+            retries = retriesE->intValue();
+            set_timeouts = true;
+        }
         // Everything OK, so commit the changes
         // listenAddresses can fail to bind, so try them first
         if (listenAddressesE) {
@@ -497,6 +521,9 @@ Recursor::updateConfig(ConstElementPtr config) {
         if (forwardAddressesE) {
             setForwardAddresses(forwardAddresses);
         }
+        if (set_timeouts) {
+            setTimeouts(timeout, retries);
+        }
         return (isc::config::createAnswer());
     } catch (const isc::Exception& error) {
         dlog(string("error in config: ") + error.what());
@@ -562,6 +589,20 @@ Recursor::setListenAddresses(const vector<addr_t>& addresses) {
     }
 }
 
+void
+Recursor::setTimeouts(int timeout, unsigned retries) {
+    dlog("Setting timeout to " + boost::lexical_cast<string>(timeout) +
+        " and retry count to " + boost::lexical_cast<string>(retries));
+    impl_->timeout_ = timeout;
+    impl_->retries_ = retries;
+    impl_->queryShutdown();
+    impl_->querySetup(*dnss_);
+}
+pair<int, unsigned>
+Recursor::getTimeouts() const {
+    return (pair<int, unsigned>(impl_->timeout_, impl_->retries_));
+}
+
 vector<addr_t>
 Recursor::getListenAddresses() const {
     return (impl_->listen_);

+ 26 - 0
src/bin/recurse/recursor.h

@@ -28,6 +28,15 @@
 
 class RecursorImpl;
 
+/**
+ * \short The recursive nameserver.
+ *
+ * It is a concreate class implementing recursive DNS server protocol
+ * processing. It is responsible for handling incoming DNS requests. It parses
+ * them, passes them deeper into the resolving machinery and then creates the
+ * answer. It doesn't really know about chasing referrals and similar, it
+ * simply plugs the parts that know into the network handling code.
+ */
 class Recursor {
     ///
     /// \name Constructors, Assignment Operator and Destructor.
@@ -110,6 +119,23 @@ public:
         uint16_t> >& addresses);
     std::vector<std::pair<std::string, uint16_t> > getListenAddresses() const;
 
+    /**
+     * \short Set options related to timeouts.
+     *
+     * This sets the time of timeout and number of retries.
+     * \param timeout The time in milliseconds. The value -1 disables timeouts.
+     * \param retries The number of retries (0 means try the first time only,
+     *     do not retry).
+     */
+    void setTimeouts(int timeout = -1, unsigned retries = 0);
+
+    /**
+     * \short Get info about timeouts.
+     *
+     * \returns Timeout and retries (as described in setTimeouts).
+     */
+    std::pair<int, unsigned> getTimeouts() const;
+
 private:
     RecursorImpl* impl_;
     asiolink::DNSService* dnss_;

+ 38 - 1
src/bin/recurse/tests/recursor_unittest.cc

@@ -165,7 +165,8 @@ TEST_F(RecursorConfig, forwardAddressConfig) {
     EXPECT_EQ(0, server.getForwardAddresses().size());
 }
 
-void RecursorConfig::invalidTest(const string &JOSN) {
+void
+RecursorConfig::invalidTest(const string &JOSN) {
     ElementPtr config(Element::fromJSON(JOSN));
     EXPECT_FALSE(server.updateConfig(config)->equals(
         *isc::config::createAnswer())) << "Accepted config " << JOSN << endl;
@@ -278,4 +279,40 @@ TEST_F(RecursorConfig, invalidListenAddresses) {
         "}]}");
 }
 
+// Just test it sets and gets the values correctly
+TEST_F(RecursorConfig, timeouts) {
+    server.setTimeouts(0, 1);
+    EXPECT_EQ(0, server.getTimeouts().first);
+    EXPECT_EQ(1, server.getTimeouts().second);
+    server.setTimeouts();
+    EXPECT_EQ(-1, server.getTimeouts().first);
+    EXPECT_EQ(0, server.getTimeouts().second);
+}
+
+TEST_F(RecursorConfig, timeoutsConfig) {
+    ElementPtr config = Element::fromJSON("{"
+            "\"timeout\": 1000,"
+            "\"retries\": 3"
+            "}");
+    ConstElementPtr result(server.updateConfig(config));
+    EXPECT_EQ(result->toWire(), isc::config::createAnswer()->toWire());
+    EXPECT_EQ(1000, server.getTimeouts().first);
+    EXPECT_EQ(3, server.getTimeouts().second);
+}
+
+TEST_F(RecursorConfig, invalidTimeoutsConfig) {
+    invalidTest("{"
+        "\"timeout\": \"error\""
+        "}");
+    invalidTest("{"
+        "\"timeout\": -2"
+        "}");
+    invalidTest("{"
+        "\"retries\": \"error\""
+        "}");
+    invalidTest("{"
+        "\"retries\": -1"
+        "}");
+}
+
 }

+ 88 - 12
src/lib/asiolink/asiolink.cc

@@ -45,6 +45,7 @@ using asio::ip::tcp;
 using namespace std;
 using namespace isc::dns;
 using isc::log::dlog;
+using namespace boost;
 
 namespace asiolink {
 
@@ -237,6 +238,18 @@ DNSService::~DNSService() {
 
 namespace {
 
+typedef std::vector<std::pair<std::string, uint16_t> > AddressVector;
+
+}
+
+RecursiveQuery::RecursiveQuery(DNSService& dns_service,
+    const AddressVector& upstream, int timeout, unsigned retries) :
+    dns_service_(dns_service), upstream_(new AddressVector(upstream)),
+    timeout_(timeout), retries_(retries)
+{}
+
+namespace {
+
 ip::address
 convertAddr(const string& address) {
     error_code err;
@@ -256,7 +269,7 @@ DNSService::addServer(const char& port, const string& address) {
 }
 
 void
-DNSService::addServer(uint16_t port, const string &address) {
+DNSService::addServer(uint16_t port, const string& address) {
     impl_->addServer(port, convertAddr(address));
 }
 
@@ -267,27 +280,90 @@ DNSService::clearServers() {
     impl_->servers_.clear();
 }
 
-RecursiveQuery::RecursiveQuery(DNSService& dns_service,
-        const std::vector<std::pair<std::string, uint16_t> >& upstream) :
-    dns_service_(dns_service), upstream_(upstream)
-{}
+namespace {
+
+/*
+ * This is a query in progress. When a new query is made, this one holds
+ * the context information about it, like how many times we are allowed
+ * to retry on failure, what to do when we succeed, etc.
+ *
+ * Used by RecursiveQuery::sendQuery.
+ */
+class RunningQuery : public UDPQuery::Callback {
+        private:
+            // The io service to handle async calls
+            asio::io_service& io_;
+            // Info for (re)sending the query (the question and destination)
+            Question question_;
+            shared_ptr<AddressVector> upstream_;
+            // Buffer to store the result.
+            OutputBufferPtr buffer_;
+            /*
+             * FIXME This is said it does problems when it is shared pointer, as
+             *     it is destroyed too soon. But who deletes it now?
+             */
+            // Server to notify when we succeed or fail
+            shared_ptr<DNSServer> server_;
+            /*
+             * TODO Do something more clever with timeouts. In the long term, some
+             *     computation of average RTT, increase with each retry, etc.
+             */
+            // Timeout information
+            int timeout_;
+            unsigned retries_;
+            // (re)send the query to the server.
+            void send() {
+                int serverIndex(random() % upstream_->size());
+                dlog("Sending upstream query (" + question_.toText() +
+                    ") to " + upstream_->at(serverIndex).first);
+                UDPQuery query(io_, question_,
+                    upstream_->at(serverIndex).first,
+                    upstream_->at(serverIndex).second, buffer_, this,
+                    timeout_);
+                io_.post(query);
+            }
+        public:
+            RunningQuery(asio::io_service& io, const Question &question,
+                shared_ptr<AddressVector> upstream,
+                OutputBufferPtr buffer, DNSServer* server, int timeout,
+                unsigned retries) :
+                io_(io),
+                question_(question),
+                upstream_(upstream),
+                buffer_(buffer),
+                server_(server->clone()),
+                timeout_(timeout),
+                retries_(retries)
+            {
+                send();
+            }
+            // This function is used as callback from DNSQuery.
+            virtual void operator()(UDPQuery::Result result) {
+                if (result == UDPQuery::TIME_OUT && retries_ --) {
+                    dlog("Resending query");
+                    // We timed out, but we have some retries, so send again
+                    send();
+                } else {
+                    server_->resume(result == UDPQuery::SUCCESS);
+                    delete this;
+                }
+            }
+};
+
+}
 
 void
 RecursiveQuery::sendQuery(const Question& question, OutputBufferPtr buffer,
                           DNSServer* server)
 {
-    int serverIndex(random() % upstream_.size());
-    dlog("Sending upstream query (" + question.toText() + ") to " +
-        upstream_[serverIndex].first);
     // XXX: eventually we will need to be able to determine whether
     // the message should be sent via TCP or UDP, or sent initially via
     // UDP and then fall back to TCP on failure, but for the moment
     // we're only going to handle UDP.
     asio::io_service& io = dns_service_.get_io_service();
-    // TODO: Better way to choose the server
-    UDPQuery q(io, question, upstream_[serverIndex].first,
-        upstream_[serverIndex].second, buffer, server);
-    io.post(q);
+    // It will delete itself when it is done
+    new RunningQuery(io, question, upstream_, buffer, server->clone(),
+         timeout_, retries_);
 }
 
 }

+ 11 - 3
src/lib/asiolink/asiolink.h

@@ -22,6 +22,7 @@
 // See the description of the namespace below.
 #include <unistd.h>             // for some network system calls
 #include <asio/ip/address.hpp>
+#include <boost/shared_ptr.hpp>
 
 #include <functional>
 #include <string>
@@ -38,7 +39,6 @@
 #include <asiolink/ioendpoint.h>
 #include <asiolink/iomessage.h>
 #include <asiolink/iosocket.h>
-//#include <asio/io_service.hpp>
 
 namespace asio {
 // forward declaration for IOService::get_io_service() below
@@ -529,9 +529,14 @@ public:
     ///        query on.
     /// \param upstream Addresses and ports of the upstream servers
     ///        to forward queries to.
+    /// \param timeout How long to timeout the query, in ms
+    ///     -1 means never timeout (but do not use that).
+    ///     TODO: This should be computed somehow dynamically in future
+    /// \param retries how many times we try again (0 means just send and
+    ///     and return if it returs).
     RecursiveQuery(DNSService& dns_service,
                    const std::vector<std::pair<std::string, uint16_t> >&
-                   upstream);
+                   upstream, int timeout = -1, unsigned retries = 0);
     //@}
 
     /// \brief Initiates an upstream query in the \c RecursiveQuery object.
@@ -549,7 +554,10 @@ public:
                    DNSServer* server);
 private:
     DNSService& dns_service_;
-    std::vector<std::pair<std::string, uint16_t> > upstream_;
+    boost::shared_ptr<std::vector<std::pair<std::string, uint16_t> > >
+        upstream_;
+    int timeout_;
+    unsigned retries_;
 };
 
 }      // asiolink

+ 41 - 32
src/lib/asiolink/internal/udpdns.h

@@ -207,46 +207,55 @@ private:
 //
 class UDPQuery : public coroutine {
 public:
+    // TODO Maybe this should be more generic than just for UDPQuery?
+    /**
+     * \short Result of the query
+     *
+     * This is related only to contacting the remote server. If the answer
+     * indicates error, it is still counted as SUCCESS here, if it comes back.
+     */
+    enum Result {
+        SUCCESS,
+        TIME_OUT,
+        STOPPED
+    };
+    /// Abstract callback for the UDPQuery.
+    class Callback {
+        public:
+            /// This will be called when the UDPQuery is completed
+            virtual void operator()(Result result) = 0;
+    };
+    /**
+     * \short Constructor.
+     *
+     * It creates the query.
+     * @param callback will be called when we terminate. It is your task to
+     *     delete it if allocated on heap.
+     * @param timeout in ms.
+     */
     explicit UDPQuery(asio::io_service& io_service,
                       const isc::dns::Question& q,
                       const IOAddress& addr, uint16_t port,
                       isc::dns::OutputBufferPtr buffer,
-                      DNSServer* server);
+                      Callback* callback, int timeout = -1);
     void operator()(asio::error_code ec = asio::error_code(),
-                    size_t length = 0); 
+                    size_t length = 0);
+    /// Terminate the query.
+    void stop(Result reason = STOPPED);
 private:
     enum { MAX_LENGTH = 4096 };
 
-    // The \c UDPQuery coroutine never forks, but it is copied whenever
-    // it calls an async_*() function, so it's best to keep copy overhead
-    // small by using pointers or references when possible.  However, this
-    // is not always possible.
-    //
-    // Socket used to for upstream queries. Created in the
-    // constructor and stored in a shared_ptr because socket objects
-    // are not copyable.
-    boost::shared_ptr<asio::ip::udp::socket> socket_;
-
-    // The remote endpoint.  Instantiated in the constructor.  Not
-    // stored as a shared_ptr because copy overhead of an endpoint
-    // object is no larger than that of a shared_ptr.
-    asio::ip::udp::endpoint remote_;
-
-    // The question being answered.  Copied rather than referenced
-    // because the object that created it is not guaranteed to persist.
-    isc::dns::Question question_;
-
-    // The output buffer supplied by the caller.  The resposne frmo
-    // the upstream server will be copied here.
-    isc::dns::OutputBufferPtr buffer_;;
-
-    // These are allocated for each new query and are stored as
-    // shared pointers to minimize copy overhead.
-    isc::dns::OutputBufferPtr msgbuf_;
-    boost::shared_array<char> data_;
-
-    // The UDP or TCP Server object from which the query originated.
-    boost::shared_ptr<DNSServer> server_;
+    /**
+     * \short Private data
+     *
+     * They are not private because of stability of the
+     * interface (this is private class anyway), but because this class
+     * will be copyed often (it is used as a coroutine and passed as callback
+     * to many async_*() functions) and we want keep the same data. Some of
+     * the data is not copyable too.
+     */
+    struct PrivateData;
+    boost::shared_ptr<PrivateData> data_;
 };
 }
 

+ 1 - 0
src/lib/asiolink/tests/Makefile.am

@@ -17,6 +17,7 @@ TESTS += run_unittests
 run_unittests_SOURCES = $(top_srcdir)/src/lib/dns/tests/unittest_util.h
 run_unittests_SOURCES += $(top_srcdir)/src/lib/dns/tests/unittest_util.cc
 run_unittests_SOURCES += asiolink_unittest.cc
+run_unittests_SOURCES += udpdns_unittest.cc
 run_unittests_SOURCES += run_unittests.cc
 run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
 run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS)

+ 70 - 2
src/lib/asiolink/tests/asiolink_unittest.cc

@@ -20,6 +20,7 @@
 #include <string.h>
 
 #include <boost/lexical_cast.hpp>
+#include <boost/bind.hpp>
 
 #include <gtest/gtest.h>
 
@@ -35,10 +36,14 @@
 #include <asiolink/internal/tcpdns.h>
 #include <asiolink/internal/udpdns.h>
 
+#include <asio.hpp>
+
 using isc::UnitTestUtil;
 using namespace std;
 using namespace asiolink;
 using namespace isc::dns;
+using namespace asio;
+using asio::ip::udp;
 
 namespace {
 const char* const TEST_SERVER_PORT = "53535";
@@ -354,12 +359,12 @@ protected:
                                       NULL, NULL);
     }
 
+    // Set up empty DNS Service
     // Set up an IO Service queue without any addresses
     void setDNSService() {
         delete dns_service_;
         dns_service_ = NULL;
         delete io_service_;
-        io_service_ = NULL;
         io_service_ = new IOService();
         callback_ = new ASIOCallBack(this);
         dns_service_ = new DNSService(*io_service_, callback_, NULL, NULL);
@@ -433,10 +438,11 @@ protected:
             }
         }
 
-    private:
+    protected:
         asio::io_service& io_;
         bool done_;
 
+    private:
         // Currently unused; these will be used for testing
         // asynchronous lookup calls via the asyncLookup() method
         boost::shared_ptr<asiolink::IOMessage> io_message_;
@@ -449,6 +455,26 @@ protected:
         const DNSAnswer* answer_;
     };
 
+    // This version of mock server just stops the io_service when it is resumed
+    class MockServerStop : public MockServer {
+        public:
+            explicit MockServerStop(asio::io_service& io_service, bool* done) :
+                MockServer(io_service, asio::ip::address(), 0),
+                done_(done)
+            {}
+
+            void resume(const bool done) {
+                *done_ = done;
+                io_.stop();
+            }
+
+            DNSServer* clone() {
+                return (new MockServerStop(*this));
+            }
+        private:
+            bool* done_;
+    };
+
 private:
     class ASIOCallBack : public SimpleCallback {
     public:
@@ -642,4 +668,46 @@ TEST_F(ASIOLinkTest, recursiveSend) {
     EXPECT_EQ(q.getClass(), q2->getClass());
 }
 
+void
+receive_and_inc(udp::socket* socket, int* num) {
+    (*num) ++;
+    static char inbuff[512];
+    socket->async_receive(asio::buffer(inbuff, 512),
+        boost::bind(receive_and_inc, socket, num));
+}
+
+// Test it tries the correct amount of times before giving up
+TEST_F(ASIOLinkTest, recursiveTimeout) {
+    // Prepare the service (we do not use the common setup, we do not answer
+    setDNSService();
+    asio::io_service& service = io_service_->get_io_service();
+
+    // Prepare the socket
+    uint16_t port = boost::lexical_cast<uint16_t>(TEST_CLIENT_PORT);
+    udp::socket socket(service, udp::v4());
+    socket.set_option(socket_base::reuse_address(true));
+    socket.bind(udp::endpoint(ip::address::from_string(TEST_IPV4_ADDR), port));
+    // And count the answers
+    int num = -1; // One is counted before the receipt of the first one
+    receive_and_inc(&socket, &num);
+
+    // Prepare the server
+    bool done(true);
+    MockServerStop server(service, &done);
+
+    // Do the answer
+    RecursiveQuery query(*dns_service_, singleAddress(TEST_IPV4_ADDR, port),
+        10, 2);
+    Question question(Name("example.net"), RRClass::IN(), RRType::A());
+    OutputBufferPtr buffer(new OutputBuffer(0));
+    query.sendQuery(question, buffer, &server);
+
+    // Run the test
+    service.run();
+
+    // The query should fail
+    EXPECT_FALSE(done);
+    EXPECT_EQ(3, num);
+}
+
 }

+ 145 - 0
src/lib/asiolink/tests/udpdns_unittest.cc

@@ -0,0 +1,145 @@
+// Copyright (C) 2010  CZ.NIC
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#include <gtest/gtest.h>
+#include <asio.hpp>
+#include <boost/bind.hpp>
+#include <cstdlib>
+
+#include <dns/question.h>
+
+#include <asiolink/internal/udpdns.h>
+
+using namespace asio;
+using namespace isc::dns;
+using asio::ip::udp;
+
+namespace {
+
+const asio::ip::address TEST_HOST(asio::ip::address::from_string("127.0.0.1"));
+const uint16_t TEST_PORT(5301);
+// FIXME Shouldn't we send something that is real message?
+const char TEST_DATA[] = "TEST DATA";
+
+// Test fixture for the asiolink::UDPQuery.
+class UDPQueryTest : public ::testing::Test,
+    public asiolink::UDPQuery::Callback
+{
+    public:
+        // Expected result of the callback
+        asiolink::UDPQuery::Result expected_;
+        // Did the callback run already?
+        bool run_;
+        // We use an io_service to run the query
+        io_service service_;
+        // Something to ask
+        Question question_;
+        // Buffer where the UDPQuery will store response
+        OutputBufferPtr buffer_;
+        // The query we are testing
+        asiolink::UDPQuery query_;
+
+        UDPQueryTest() :
+            run_(false),
+            question_(Name("example.net"), RRClass::IN(), RRType::A()),
+            buffer_(new OutputBuffer(512)),
+            query_(service_, question_, asiolink::IOAddress(TEST_HOST),
+                TEST_PORT, buffer_, this, 100)
+        { }
+
+        // This is the callback's (), so it can be called.
+        void operator()(asiolink::UDPQuery::Result result) {
+            // We check the query returns the correct result
+            EXPECT_EQ(expected_, result);
+            // Check it is called only once
+            EXPECT_FALSE(run_);
+            // And mark the callback was called
+            run_ = true;
+        }
+        // A response handler, pretending to be remote DNS server
+        void respond(udp::endpoint* remote, udp::socket* socket) {
+            // Some data came, just send something back.
+            socket->send_to(asio::buffer(TEST_DATA, sizeof TEST_DATA),
+                *remote);
+            socket->close();
+        }
+};
+
+/*
+ * Test that when we run the query and stop it after it was run,
+ * it returns "stopped" correctly.
+ *
+ * That is why stop() is posted to the service_ as well instead
+ * of calling it.
+ */
+TEST_F(UDPQueryTest, stop) {
+    expected_ = asiolink::UDPQuery::STOPPED;
+    // Post the query
+    service_.post(query_);
+    // Post query_.stop() (yes, the boost::bind thing is just
+    // query_.stop()).
+    service_.post(boost::bind(&asiolink::UDPQuery::stop, query_,
+        asiolink::UDPQuery::STOPPED));
+    // Run both of them
+    service_.run();
+    EXPECT_TRUE(run_);
+}
+
+/*
+ * Test that when we queue the query to service_ and call stop()
+ * before it gets executed, it acts sanely as well (eg. has the
+ * same result as running stop() after - calls the callback).
+ */
+TEST_F(UDPQueryTest, prematureStop) {
+    expected_ = asiolink::UDPQuery::STOPPED;
+    // Stop before it is started
+    query_.stop();
+    service_.post(query_);
+    service_.run();
+    EXPECT_TRUE(run_);
+}
+
+/*
+ * Test that it will timeout when no answer will arrive.
+ */
+TEST_F(UDPQueryTest, timeout) {
+    expected_ = asiolink::UDPQuery::TIME_OUT;
+    service_.post(query_);
+    service_.run();
+    EXPECT_TRUE(run_);
+}
+
+/*
+ * Test that it will succeed when we fake an answer and
+ * stores the same data we send.
+ *
+ * This is done through a real socket on loopback address.
+ */
+TEST_F(UDPQueryTest, receive) {
+    expected_ = asiolink::UDPQuery::SUCCESS;
+    udp::socket socket(service_, udp::v4());
+    socket.set_option(socket_base::reuse_address(true));
+    socket.bind(udp::endpoint(TEST_HOST, TEST_PORT));
+    char inbuff[512];
+    udp::endpoint remote;
+    socket.async_receive_from(asio::buffer(inbuff, 512), remote, boost::bind(
+        &UDPQueryTest::respond, this, &remote, &socket));
+    service_.post(query_);
+    service_.run();
+    EXPECT_TRUE(run_);
+    ASSERT_EQ(sizeof TEST_DATA, buffer_->getLength());
+    EXPECT_EQ(0, memcmp(TEST_DATA, buffer_->getData(), sizeof TEST_DATA));
+}
+
+}

+ 90 - 21
src/lib/asiolink/udpdns.cc

@@ -23,8 +23,10 @@
 #include <boost/bind.hpp>
 
 #include <asio.hpp>
+#include <asio/deadline_timer.hpp>
 
 #include <boost/shared_ptr.hpp>
+#include <boost/date_time/posix_time/posix_time_types.hpp>
 
 #include <dns/buffer.h>
 #include <dns/message.h>
@@ -173,27 +175,64 @@ UDPServer::resume(const bool done) {
     io_.post(*this);
 }
 
+// Private UDPQuery data (see internal/udpdns.h for reasons)
+struct UDPQuery::PrivateData {
+    // Socket we send query to and expect reply from there
+    udp::socket socket;
+    // Where was the query sent
+    udp::endpoint remote;
+    // What we ask the server
+    Question question;
+    // We will store the answer here
+    OutputBufferPtr buffer;
+    OutputBufferPtr msgbuf;
+    // Temporary buffer for answer
+    boost::shared_array<char> data;
+    // This will be called when the data arrive or timeouts
+    Callback* callback;
+    // Did we already stop operating (data arrived, we timed out, someone
+    // called stop). This can be so when we are cleaning up/there are
+    // still pointers to us.
+    bool stopped;
+    // Timer to measure timeouts.
+    deadline_timer timer;
+    // How many milliseconds are we willing to wait for answer?
+    int timeout;
+
+    PrivateData(io_service& service,
+        const udp::socket::protocol_type& protocol, const Question &q,
+        OutputBufferPtr b, Callback *c) :
+        socket(service, protocol),
+        question(q),
+        buffer(b),
+        msgbuf(new OutputBuffer(512)),
+        callback(c),
+        stopped(false),
+        timer(service)
+    { }
+};
+
 /// The following functions implement the \c UDPQuery class.
 ///
 /// The constructor
 UDPQuery::UDPQuery(io_service& io_service,
                    const Question& q, const IOAddress& addr, uint16_t port,
-                   OutputBufferPtr buffer, DNSServer* server) :
-    question_(q), buffer_(buffer), server_(server->clone())
+                   OutputBufferPtr buffer, Callback *callback, int timeout) :
+    data_(new PrivateData(io_service,
+        addr.getFamily() == AF_INET ? udp::v4() : udp::v6(), q, buffer,
+        callback))
 {
-    udp proto = (addr.getFamily() == AF_INET) ? udp::v4() : udp::v6();
-    socket_.reset(new udp::socket(io_service, proto));
-    msgbuf_.reset(new OutputBuffer(512));
-    remote_ = UDPEndpoint(addr, port).getASIOEndpoint();
+    data_->remote = UDPEndpoint(addr, port).getASIOEndpoint();
+    data_->timeout = timeout;
 }
 
 /// The function operator is implemented with the "stackless coroutine"
 /// pattern; see internal/coroutine.h for details.
 void
 UDPQuery::operator()(error_code ec, size_t length) {
-    if (ec) {
+    if (ec || data_->stopped) {
         return;
-    } 
+    }
 
     CORO_REENTER (this) {
         /// Generate the upstream query and render it to wire format
@@ -207,39 +246,69 @@ UDPQuery::operator()(error_code ec, size_t length) {
             msg.setOpcode(Opcode::QUERY());
             msg.setRcode(Rcode::NOERROR());
             msg.setHeaderFlag(Message::HEADERFLAG_RD);
-            msg.addQuestion(question_);
-            MessageRenderer renderer(*msgbuf_);
+            msg.addQuestion(data_->question);
+            MessageRenderer renderer(*data_->msgbuf);
             msg.toWire(renderer);
             dlog("Sending " + msg.toText() + " to " +
-                remote_.address().to_string());
+                data_->remote.address().to_string());
+        }
+
+        // If we timeout, we stop, which will shutdown everything and
+        // cancel all other attempts to run inside the coroutine
+        if (data_->timeout != -1) {
+            data_->timer.expires_from_now(boost::posix_time::milliseconds(
+                data_->timeout));
+            data_->timer.async_wait(boost::bind(&UDPQuery::stop, *this,
+                TIME_OUT));
         }
 
         // Begin an asynchronous send, and then yield.  When the
         // send completes, we will resume immediately after this point.
-        CORO_YIELD socket_->async_send_to(buffer(msgbuf_->getData(),
-                                                 msgbuf_->getLength()),
-                                           remote_, *this);
+        CORO_YIELD data_->socket.async_send_to(buffer(data_->msgbuf->getData(),
+            data_->msgbuf->getLength()), data_->remote, *this);
 
         /// Allocate space for the response.  (XXX: This should be
         /// optimized by maintaining a free list of pre-allocated blocks)
-        data_.reset(new char[MAX_LENGTH]);
+        data_->data.reset(new char[MAX_LENGTH]);
 
         /// Begin an asynchronous receive, and yield.  When the receive
         /// completes, we will resume immediately after this point.
-        CORO_YIELD socket_->async_receive_from(buffer(data_.get(), MAX_LENGTH),
-                                               remote_, *this);
+        CORO_YIELD data_->socket.async_receive_from(buffer(data_->data.get(),
+            MAX_LENGTH), data_->remote, *this);
         // The message is not rendered yet, so we can't print it easilly
-        dlog("Received response from " + remote_.address().to_string());
+        dlog("Received response from " + data_->remote.address().to_string());
 
         /// Copy the answer into the response buffer.  (XXX: If the
         /// OutputBuffer object were made to meet the requirements of
         /// a MutableBufferSequence, then it could be written to directly
         /// by async_recieve_from() and this additional copy step would
         /// be unnecessary.)
-        buffer_->writeData(data_.get(), length);
+        data_->buffer->writeData(data_->data.get(), length);
 
-        /// Signal the DNSServer object to resume processing.
-        server_->resume(true);
+        /// We are done
+        stop(SUCCESS);
+    }
+}
+
+void
+UDPQuery::stop(Result result) {
+    if (!data_->stopped) {
+        switch (result) {
+            case TIME_OUT:
+                dlog("Query timed out");
+                break;
+            case STOPPED:
+                dlog("Query stopped");
+                break;
+            default:;
+        }
+        data_->stopped = true;
+        data_->socket.cancel();
+        data_->socket.close();
+        data_->timer.cancel();
+        if (data_->callback) {
+            (*data_->callback)(result);
+        }
     }
 }