Browse Source

Address entry can ask for its IP

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac408@3539 e5f2f494-b856-4b98-b285-d166d9295462
Michal Vaner 14 years ago
parent
commit
e6d8478ced

+ 116 - 0
src/lib/nsas/nameserver_entry.cc

@@ -28,14 +28,17 @@
 #include <dns/name.h>
 #include <dns/rrclass.h>
 #include <dns/rrttl.h>
+#include <dns/question.h>
 
 #include "address_entry.h"
 #include "nameserver_entry.h"
+#include "resolver_interface.h"
 
 using namespace asiolink;
 using namespace isc::nsas;
 using namespace isc::dns;
 using namespace std;
+using namespace boost;
 
 namespace isc {
 namespace nsas {
@@ -177,5 +180,118 @@ void NameserverEntry::setAddressUnreachable(const IOAddress& address) {
     setAddressRTT(address, AddressEntry::UNREACHABLE);
 }
 
+void NameserverEntry::ensureHasCallback(shared_ptr<ZoneEntry> zone,
+    NameserverEntry::Callback& callback)
+{
+    if (getState() != Fetchable::IN_PROGRESS) {
+        isc_throw(BadValue,
+            "Callbacks can be added only to IN_PROGRESS nameserver entries");
+    }
+    if (ipCallbacks_.find(zone) == ipCallbacks_.end()) {
+        ipCallbacks_[zone] = &callback;
+    }
+}
+
+class NameserverEntry::ResolverCallback : public ResolverInterface::Callback {
+    public:
+        ResolverCallback(shared_ptr<NameserverEntry> entry) :
+            entry_(entry),
+            rtt(0)
+        { }
+        virtual void success(const Message& response) {
+            map<shared_ptr<ZoneEntry>, NameserverEntry::Callback*> callbacks;
+            bool has_address(false);
+            {
+                mutex::scoped_lock lock(entry_->mutex_);
+                for (RRsetIterator set(
+                    // TODO Trunk does Section::ANSWER() by constant
+                    response.beginSection(Section::ANSWER()));
+                    set != response.endSection(Section::ANSWER()); ++ set)
+                {
+                    /**
+                     * TODO Move to common function, this is similar to
+                     * what is in constructor.
+                     */
+                    RdataIteratorPtr i((*set)->getRdataIterator());
+                    // TODO Remove at merge with #410
+                    i->first();
+                    while (! i->isLast()) {
+                        has_address = true;
+                        entry_->address_.push_back(AddressEntry(IOAddress(
+                        i->getCurrent().toText()), ++rtt));
+                        i->next();
+                    }
+                }
+                if (has_address) {
+                    callbacks.swap(entry_->ipCallbacks_);
+                    entry_->waiting_responses_ --;
+                    entry_->setState(Fetchable::READY);
+                }
+            } // Unlock
+            if (has_address) {
+                dispatchCallbacks(callbacks);
+            } else {
+                // No address there, so we take it as a failure
+                failure();
+            }
+        }
+        virtual void failure() {
+            map<shared_ptr<ZoneEntry>, NameserverEntry::Callback*> callbacks;
+            {
+                mutex::scoped_lock lock(entry_->mutex_);
+                entry_->waiting_responses_ --;
+                // Do we still have a chance to get the answer?
+                // Or did we already?
+                if (entry_->waiting_responses_ ||
+                    entry_->getState() != Fetchable::IN_PROGRESS)
+                {
+                    return;
+                }
+                // Remove the callbacks and call them now
+                callbacks.swap(entry_->ipCallbacks_);
+                entry_->setState(Fetchable::UNREACHABLE);
+            } // Unlock
+            dispatchCallbacks(callbacks);
+        }
+    private:
+        shared_ptr<NameserverEntry> entry_;
+        int rtt;
+        void dispatchCallbacks(map<shared_ptr<ZoneEntry>,
+            NameserverEntry::Callback*>& callbacks)
+        {
+            /*
+             * FIXME This approach is not completely exception safe.
+             * If we get an exception from callback, we lose the other
+             * callbacks.
+             */
+            for (map<shared_ptr<ZoneEntry>, NameserverEntry::Callback*>::
+                iterator i(callbacks.begin()); i != callbacks.end(); ++ i)
+            {
+                shared_ptr<ZoneEntry> zone(i->first);
+                (*i->second)(zone);
+            }
+        }
+};
+
+void NameserverEntry::askIP(ResolverInterface& resolver,
+    shared_ptr<ZoneEntry> zone, NameserverEntry::Callback& callback,
+    shared_ptr<NameserverEntry> self)
+{
+    if (getState() != Fetchable::NOT_ASKED) {
+        isc_throw(BadValue,
+            "Asking to resolve an IP address, but it was asked before");
+    }
+    setState(Fetchable::IN_PROGRESS);
+    ipCallbacks_[zone] = &callback;
+
+    shared_ptr<ResolverCallback> resolver_callback(new ResolverCallback(self));
+    waiting_responses_ = 2;
+    // TODO Should we ask for both A and AAAA in all occations?
+    resolver.resolve(QuestionPtr(new Question(Name(getName()),
+        RRClass(getClass()), RRType::A())), resolver_callback);
+    resolver.resolve(QuestionPtr(new Question(Name(getName()),
+        RRClass(getClass()), RRType::AAAA())), resolver_callback);
+}
+
 } // namespace dns
 } // namespace isc

+ 55 - 0
src/lib/nsas/nameserver_entry.h

@@ -57,6 +57,7 @@ public:
 };
 
 class ZoneEntry;
+class ResolverInterface;
 
 /// \brief Nameserver Entry
 ///
@@ -196,6 +197,53 @@ public:
         }
     };
 
+    /// \name Obtaining the IP addresses from resolver
+    //@{
+    /// \short A callback that some information here arrived (or are unavailable).
+    struct Callback {
+        virtual void operator()(boost::shared_ptr<ZoneEntry>) = 0;
+    };
+
+    /**
+     * \short Asks the resolver for IP address (or addresses).
+     *
+     * Adds a callback for given zone when they are ready or the information
+     * is found unreachable.
+     *
+     * This does not lock and expects that the entry is already locked.
+     *
+     * Expects that the nameserver entry is in NOT_ASKED state,
+     * throws BadValue otherwise.
+     *
+     * \param resolver Who to ask.
+     * \param zone The callbacks are named, so we can check if we already have
+     *     a callback for given zone. This is the name and the zone will be
+     *     passed to the callback when called.
+     * \param callback The callback.
+     * \param self Since we need to pass a shared pointer to the resolver, we
+     *     need to get one. However, we can not create one from this, because
+     *     it would have different reference count. So the caller must pass it.
+     */
+    void askIP(ResolverInterface& resolver, boost::shared_ptr<ZoneEntry> zone,
+        Callback& callback, boost::shared_ptr<NameserverEntry> self);
+    /**
+     * \short Ensures that zone has a callback registered.
+     *
+     * This adds a given callback to this nameserver entry, but only if
+     * the zone does not have one already.
+     *
+     * Does not lock and expects that the entry is already locked.
+     *
+     * Expects that the nameserver entri is in IN_PROGRESS state, throws
+     * BadValue otherwise.
+     *
+     * \param zone Whose callback we add.
+     * \param callback The callback.
+     */
+    void ensureHasCallback(boost::shared_ptr<ZoneEntry> zone,
+        Callback& callback);
+    //@}
+
 private:
     boost::mutex    mutex_;             ///< Mutex protecting this object
     std::string     name_;              ///< Canonical name of the nameserver
@@ -205,6 +253,13 @@ private:
     time_t          last_access_;       ///< Last access time to the structure
     // We allow ZoneEntry to lock us
     friend class ZoneEntry;
+    // We store the callbacks of zones asking for addresses here
+    std::map<boost::shared_ptr<ZoneEntry>, Callback*> ipCallbacks_;
+    // This is our callback class to resolver
+    class ResolverCallback;
+    friend class ResolverCallback;
+    // How many responses from resolver do we expect?
+    size_t waiting_responses_;
 };
 
 }   // namespace dns

+ 3 - 39
src/lib/nsas/tests/nameserver_address_store_unittest.cc

@@ -29,7 +29,6 @@
 #include <boost/foreach.hpp>
 
 #include <string.h>
-#include <vector>
 #include <cassert>
 
 #include "../nameserver_address_store.h"
@@ -117,42 +116,7 @@ protected:
 
     RRsetPtr authority_, empty_authority_;
 
-    class TestResolver : public ResolverInterface {
-        public:
-            typedef pair<QuestionPtr, CallbackPtr> Request;
-            vector<Request> requests;
-            virtual void resolve(QuestionPtr q, CallbackPtr c) {
-                requests.push_back(Request(q, c));
-            }
-            QuestionPtr operator[](size_t index) {
-                return (requests[index].first);
-            }
-    } defaultTestResolver;
-
-    /**
-     * Looks if the two provided requests in resolver are A and AAAA.
-     * Sorts them so index1 is A.
-     */
-    void asksIPs(const Name& name, size_t index1, size_t index2) {
-        size_t max = (index1 < index2) ? index2 : index1;
-        ASSERT_GT(defaultTestResolver.requests.size(), max);
-        EXPECT_EQ(name, defaultTestResolver[index1]->getName());
-        EXPECT_EQ(name, defaultTestResolver[index2]->getName());
-        EXPECT_EQ(RRClass::IN(), defaultTestResolver[index1]->getClass());
-        EXPECT_EQ(RRClass::IN(), defaultTestResolver[index2]->getClass());
-        // If they are the other way around, swap
-        if (defaultTestResolver[index1]->getType() == RRType::AAAA() &&
-            defaultTestResolver[index2]->getType() == RRType::A())
-        {
-            TestResolver::Request tmp(defaultTestResolver.requests[index1]);
-            defaultTestResolver.requests[index1] =
-                defaultTestResolver.requests[index2];
-            defaultTestResolver.requests[index2] = tmp;
-        }
-        // Check the correct addresses
-        EXPECT_EQ(RRType::A(), defaultTestResolver[index1]->getType());
-        EXPECT_EQ(RRType::AAAA(), defaultTestResolver[index1]->getType());
-    }
+    TestResolver defaultTestResolver;
 
     class NSASCallback : public AddressRequestCallback {
         public:
@@ -247,7 +211,7 @@ TEST_F(NameserverAddressStoreTest, emptyLookup) {
         vector<AbstractRRset>(), getCallback());
     // It should ask for IP addresses for example.com.
     ASSERT_EQ(2, defaultTestResolver.requests.size());
-    asksIPs(Name("example.com."), 0, 1);
+    defaultTestResolver.asksIPs(Name("example.com."), 0, 1);
 
     // Ask another question for the same zone
     nsas.lookup("example.net.", RRClass::IN().getCode(), *authority_,
@@ -310,7 +274,7 @@ TEST_F(NameserverAddressStoreTest, unreachableNS) {
         vector<AbstractRRset>(), getCallback());
     // It should ask for IP addresses for example.com.
     ASSERT_EQ(2, defaultTestResolver.requests.size());
-    asksIPs(Name("example.com."), 0, 1);
+    defaultTestResolver.asksIPs(Name("example.com."), 0, 1);
 
     // Ask another question with different zone but the same nameserver
     authority_->setName(Name("example.com."));

+ 80 - 0
src/lib/nsas/tests/nameserver_entry_unittest.cc

@@ -19,17 +19,21 @@
 
 #include <limits.h>
 #include <boost/foreach.hpp>
+#include <boost/shared_ptr.hpp>
 #include <gtest/gtest.h>
 
 #include <dns/rdata.h>
 #include <dns/rrset.h>
 #include <dns/rrclass.h>
+#include <dns/rdataclass.h>
 #include <dns/rrttl.h>
 #include <dns/name.h>
+#include <exceptions/exceptions.h>
 
 #include "../asiolink.h"
 #include "../address_entry.h"
 #include "../nameserver_entry.h"
+#include "../zone_entry.h"
 
 #include "nsas_test.h"
 
@@ -37,6 +41,7 @@ using namespace asiolink;
 using namespace std;
 using namespace isc::dns;
 using namespace rdata;
+using namespace boost;
 
 namespace isc {
 namespace nsas {
@@ -102,6 +107,15 @@ protected:
     BasicRRset rrns_;           ///< NS RRset
     BasicRRset rrv6_;           ///< Standard RRset, IN, AAAA, lowercase name
     BasicRRset rrnet_;          ///< example.net A RRset
+
+    /// \short Just a really stupid callback counting times called
+    struct Callback : public NameserverEntry::Callback {
+        size_t count;
+        virtual void operator()(shared_ptr<ZoneEntry>) {
+            count ++;
+        }
+        Callback() : count(0) { }
+    };
 };
 
 /// \brief Compare Vectors of String
@@ -451,6 +465,72 @@ TEST_F(NameserverEntryTest, CheckClass) {
 
 }
 
+// Tests if it asks the IP addresses and calls callbacks when it comes
+TEST_F(NameserverEntryTest, IPCallbacks) {
+    shared_ptr<NameserverEntry> entry(new NameserverEntry(EXAMPLE_CO_UK,
+        RRClass::IN().getCode()));
+    Callback callback;
+    TestResolver resolver;
+    shared_ptr<ZoneEntry> no_zone, zone(new ZoneEntry(EXAMPLE_CO_UK,
+        RRClass::IN().getCode()));
+
+    // Ensure that we do not add callbacks now
+    EXPECT_THROW(entry->ensureHasCallback(no_zone, callback), isc::BadValue);
+
+    entry->askIP(resolver, no_zone, callback, entry);
+    // Ensure it becomes IN_PROGRESS
+    EXPECT_EQ(Fetchable::IN_PROGRESS, entry->getState());
+    // Ensure we can ask for IP address only once
+    EXPECT_THROW(entry->askIP(resolver, no_zone, callback, entry),
+        isc::BadValue);
+
+    // Add a callback for a different zone
+    entry->ensureHasCallback(zone, callback);
+
+    // Now, there should be two queries in the resolver
+    ASSERT_EQ(2, resolver.requests.size());
+    resolver.asksIPs(Name(EXAMPLE_CO_UK), 0, 1);
+
+    // Answer one and see that the callbacks are called
+    RRsetPtr answer(new RRset(Name(EXAMPLE_CO_UK), RRClass::IN(), RRType::A(),
+        RRTTL(100)));
+    answer->addRdata(rdata::in::A("192.0.2.1"));
+    Message address(Message::RENDER); // Not able to create different one
+    address.addRRset(Section::ANSWER(), answer);
+    address.addQuestion(resolver[0]);
+    resolver.requests[0].second->success(address);
+
+    // Both callbacks should be called by now
+    EXPECT_EQ(2, callback.count);
+    // It should contain one IP address
+    NameserverEntry::AddressVector addresses;
+    entry->getAddresses(addresses);
+    EXPECT_EQ(1, addresses.size());
+    EXPECT_EQ(Fetchable::READY, entry->getState());
+}
+
+// Test the callback is called even when the address is unreachable
+TEST_F(NameserverEntryTest, IPCallbacksUnreachable) {
+    shared_ptr<NameserverEntry> entry(new NameserverEntry(EXAMPLE_CO_UK,
+        RRClass::IN().getCode()));
+    Callback callback;
+    TestResolver resolver;
+    shared_ptr<ZoneEntry> no_zone;
+
+    // Ask for its IP
+    entry->askIP(resolver, no_zone, callback, entry);
+    // Check it asks the resolver
+    ASSERT_EQ(2, resolver.requests.size());
+    resolver.asksIPs(Name(EXAMPLE_CO_UK), 0, 1);
+    resolver.requests[0].second->failure();
+    // It should still wait for the second one
+    EXPECT_EQ(0, callback.count);
+    EXPECT_EQ(Fetchable::IN_PROGRESS, entry->getState());
+    // It should call the callback now and be unrechable
+    resolver.requests[1].second->failure();
+    EXPECT_EQ(1, callback.count);
+    EXPECT_EQ(Fetchable::UNREACHABLE, entry->getState());
+}
 
 }   // namespace nsas
 }   // namespace isc

+ 44 - 0
src/lib/nsas/tests/nsas_test.h

@@ -23,6 +23,7 @@
 /// address store tests.
 
 #include <string>
+#include <vector>
 
 #include <config.h>
 
@@ -31,6 +32,7 @@
 #include <dns/rrtype.h>
 #include <dns/messagerenderer.h>
 #include "../nsas_entry.h"
+#include "../resolver_interface.h"
 
 using namespace isc::dns::rdata;
 using namespace isc::dns;
@@ -213,4 +215,46 @@ static const uint32_t HASHTABLE_DEFAULT_SIZE = 1009; ///< First prime above 1000
 } // namespace nsas
 } // namespace isc
 
+namespace {
+
+using namespace std;
+
+class TestResolver : public isc::nsas::ResolverInterface {
+    public:
+        typedef pair<QuestionPtr, CallbackPtr> Request;
+        vector<Request> requests;
+        virtual void resolve(QuestionPtr q, CallbackPtr c) {
+            requests.push_back(Request(q, c));
+        }
+        QuestionPtr operator[](size_t index) {
+            return (requests[index].first);
+        }
+        /**
+         * Looks if the two provided requests in resolver are A and AAAA.
+         * Sorts them so index1 is A.
+         */
+        void asksIPs(const Name& name, size_t index1, size_t index2) {
+            size_t max = (index1 < index2) ? index2 : index1;
+            ASSERT_GT(requests.size(), max);
+            EXPECT_EQ(name, (*this)[index1]->getName());
+            EXPECT_EQ(name, (*this)[index2]->getName());
+            EXPECT_EQ(RRClass::IN(), (*this)[index1]->getClass());
+            EXPECT_EQ(RRClass::IN(), (*this)[index2]->getClass());
+            // If they are the other way around, swap
+            if ((*this)[index1]->getType() == RRType::AAAA() &&
+                (*this)[index2]->getType() == RRType::A())
+            {
+                TestResolver::Request tmp((*this).requests[index1]);
+                (*this).requests[index1] =
+                    (*this).requests[index2];
+                (*this).requests[index2] = tmp;
+            }
+            // Check the correct addresses
+            EXPECT_EQ(RRType::A(), (*this)[index1]->getType());
+            EXPECT_EQ(RRType::AAAA(), (*this)[index2]->getType());
+        }
+};
+
+} // Empty namespace
+
 #endif // __NSAS_TEST_H