Browse Source

Merge branch 'trac497'

Jelte Jansen 14 years ago
parent
commit
af0e5cd93b

+ 0 - 1
src/bin/resolver/Makefile.am

@@ -37,7 +37,6 @@ spec_config.h: spec_config.h.pre
 BUILT_SOURCES = spec_config.h 
 pkglibexec_PROGRAMS = b10-resolver
 b10_resolver_SOURCES = resolver.cc resolver.h
-b10_resolver_SOURCES += response_classifier.cc response_classifier.h
 b10_resolver_SOURCES += response_scrubber.cc response_scrubber.h
 b10_resolver_SOURCES += $(top_builddir)/src/bin/auth/change_user.h
 b10_resolver_SOURCES += $(top_builddir)/src/bin/auth/common.h

+ 9 - 2
src/bin/resolver/resolver.h

@@ -65,7 +65,10 @@ public:
     /// send the reply.
     ///
     /// \param io_message The raw message received
-    /// \param message Pointer to the \c Message object
+    /// \param query_message Pointer to the query Message object we
+    /// received from the client
+    /// \param answer_message Pointer to the anwer Message object we
+    /// shall return to the client
     /// \param buffer Pointer to an \c OutputBuffer for the resposne
     /// \param server Pointer to the \c DNSServer
     void processMessage(const asiolink::IOMessage& io_message,
@@ -146,7 +149,11 @@ public:
      * \short Set options related to timeouts.
      *
      * This sets the time of timeout and number of retries.
-     * \param timeout The time in milliseconds. The value -1 disables timeouts.
+     * \param query_timeout The timeout we use for queries we send
+     * \param client_timeout The timeout at which point we send back a
+     * SERVFAIL (while continuing to resolve the query)
+     * \param lookup_timeout The timeout at which point we give up and
+     * stop.
      * \param retries The number of retries (0 means try the first time only,
      *     do not retry).
      */

+ 2 - 4
src/bin/resolver/tests/Makefile.am

@@ -19,11 +19,9 @@ TESTS += run_unittests
 run_unittests_SOURCES = $(top_srcdir)/src/lib/dns/tests/unittest_util.h
 run_unittests_SOURCES += $(top_srcdir)/src/lib/dns/tests/unittest_util.cc
 run_unittests_SOURCES += ../resolver.h ../resolver.cc
-run_unittests_SOURCES += ../response_classifier.h ../response_classifier.cc
 run_unittests_SOURCES += ../response_scrubber.h ../response_scrubber.cc
 run_unittests_SOURCES += resolver_unittest.cc
 run_unittests_SOURCES += resolver_config_unittest.cc
-run_unittests_SOURCES += response_classifier_unittest.cc
 run_unittests_SOURCES += response_scrubber_unittest.cc
 run_unittests_SOURCES += run_unittests.cc
 run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
@@ -31,8 +29,8 @@ run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS)
 run_unittests_LDADD = $(GTEST_LDADD)
 run_unittests_LDADD += $(SQLITE_LIBS)
 run_unittests_LDADD += $(top_builddir)/src/lib/testutils/libtestutils.la
-run_unittests_LDADD +=  $(top_builddir)/src/lib/datasrc/libdatasrc.la
-run_unittests_LDADD +=  $(top_builddir)/src/lib/dns/libdns++.la
+run_unittests_LDADD += $(top_builddir)/src/lib/datasrc/libdatasrc.la
+run_unittests_LDADD += $(top_builddir)/src/lib/dns/libdns++.la
 run_unittests_LDADD += $(top_builddir)/src/lib/asiolink/libasiolink.la
 run_unittests_LDADD += $(top_builddir)/src/lib/config/libcfgclient.la
 run_unittests_LDADD += $(top_builddir)/src/lib/cc/libcc.la

+ 115 - 78
src/lib/asiolink/asiolink.cc

@@ -50,42 +50,6 @@ using namespace isc::dns;
 using isc::log::dlog;
 using namespace boost;
 
-// Is this something we can use in libdns++?
-namespace {
-    class SectionInserter {
-    public:
-        SectionInserter(MessagePtr message, const Message::Section sect) :
-            message_(message), section_(sect)
-        {}
-        void operator()(const RRsetPtr rrset) {
-            message_->addRRset(section_, rrset, true);
-        }
-        MessagePtr message_;
-        const Message::Section section_;
-    };
-
-
-    /// \brief Copies the parts relevant for a DNS answer to the
-    /// target message
-    ///
-    /// This adds all the RRsets in the answer, authority and
-    /// additional sections to the target, as well as the response
-    /// code
-    void copyAnswerMessage(const Message& source, MessagePtr target) {
-        target->setRcode(source.getRcode());
-
-        for_each(source.beginSection(Message::SECTION_ANSWER),
-                 source.endSection(Message::SECTION_ANSWER),
-                 SectionInserter(target, Message::SECTION_ANSWER));
-        for_each(source.beginSection(Message::SECTION_AUTHORITY),
-                 source.endSection(Message::SECTION_AUTHORITY),
-                 SectionInserter(target, Message::SECTION_AUTHORITY));
-        for_each(source.beginSection(Message::SECTION_ADDITIONAL),
-                 source.endSection(Message::SECTION_ADDITIONAL),
-                 SectionInserter(target, Message::SECTION_ADDITIONAL));
-    }
-}
-
 namespace asiolink {
 
 typedef pair<string, uint16_t> addr_t;
@@ -365,6 +329,12 @@ private:
     //shared_ptr<DNSServer> server_;
     isc::resolve::ResolverInterface::CallbackPtr resolvercallback_;
 
+    // To prevent both unreasonably long cname chains and cname loops,
+    // we simply keep a counter of the number of CNAMEs we have
+    // followed so far (and error if it exceeds RESOLVER_MAX_CNAME_CHAIN
+    // from lib/resolve/response_classifier.h)
+    unsigned cname_count_;
+
     /*
      * TODO Do something more clever with timeouts. In the long term, some
      *     computation of average RTT, increase with each retry, etc.
@@ -392,6 +362,11 @@ private:
     // If we timed out ourselves (lookup timeout), stop issuing queries
     bool done_;
 
+    // If we have a client timeout, we send back an answer, but don't
+    // stop. We use this variable to make sure we don't send another
+    // answer if we do find one later (or if we have a lookup_timeout)
+    bool answer_sent_;
+
     // (re)send the query to the server.
     void send() {
         const int uc = upstream_->size();
@@ -429,25 +404,61 @@ private:
     // Note that the footprint may change as this function may
     // need to append data to the answer we are building later.
     //
-    // returns true if we are done
+    // returns true if we are done (either we have an answer or an
+    //              error message)
     // returns false if we are not done
     bool handleRecursiveAnswer(const Message& incoming) {
-        if (incoming.getRRCount(Message::SECTION_ANSWER) > 0) {
-            dlog("Got final result, copying answer.");
-            copyAnswerMessage(incoming, answer_message_);
+        dlog("Handle response");
+        // In case we get a CNAME, we store the target
+        // here (classify() will set it when it walks through
+        // the cname chain to verify it).
+        Name cname_target(question_.getName());
+        
+        isc::resolve::ResponseClassifier::Category category =
+            isc::resolve::ResponseClassifier::classify(
+                question_, incoming, cname_target, cname_count_, true);
+
+        bool found_ns_address = false;
+
+        switch (category) {
+        case isc::resolve::ResponseClassifier::ANSWER:
+        case isc::resolve::ResponseClassifier::ANSWERCNAME:
+            // Done. copy and return.
+            isc::resolve::copyResponseMessage(incoming, answer_message_);
             return true;
-        } else {
-            dlog("Got delegation, continuing");
-            // ok we need to do some more processing.
-            // the ns list should contain all nameservers
-            // while the additional may contain addresses for
-            // them.
-            // this needs to tie into NSAS of course
-            // for this very first mockup, hope there is an
-            // address in additional and just use that
-
-            // send query to the addresses in the delegation
-            bool found_ns_address = false;
+            break;
+        case isc::resolve::ResponseClassifier::CNAME:
+            dlog("Response is CNAME!");
+            // (unfinished) CNAME. We set our question_ to the CNAME
+            // target, then start over at the beginning (for now, that
+            // is, we reset our 'current servers' to the root servers).
+            if (cname_count_ >= RESOLVER_MAX_CNAME_CHAIN) {
+                // just give up
+                dlog("CNAME chain too long");
+                isc::resolve::makeErrorMessage(answer_message_,
+                                               Rcode::SERVFAIL());
+                return true;
+            }
+
+            answer_message_->appendSection(Message::SECTION_ANSWER,
+                                           incoming);
+            setZoneServersToRoot();
+
+            question_ = Question(cname_target, question_.getClass(),
+                                 question_.getType());
+
+            dlog("Following CNAME chain to " + question_.toText());
+            send();
+            return false;
+            break;
+        case isc::resolve::ResponseClassifier::NXDOMAIN:
+            // NXDOMAIN, just copy and return.
+            isc::resolve::copyResponseMessage(incoming, answer_message_);
+            return true;
+            break;
+        case isc::resolve::ResponseClassifier::REFERRAL:
+            // Referral. For now we just take the first glue address
+            // we find and continue with that
             zone_servers_.clear();
 
             for (RRsetIterator rrsi = incoming.beginSection(Message::SECTION_ADDITIONAL);
@@ -466,7 +477,7 @@ private:
                         // to that address and yield, when it
                         // returns, loop again.
                         
-                        // should use NSAS
+                        // TODO should use NSAS
                         zone_servers_.push_back(addr_t(addr_str, 53));
                         found_ns_address = true;
                     }
@@ -478,14 +489,34 @@ private:
                 return false;
             } else {
                 dlog("[XX] no ready-made addresses in additional. need nsas.");
-                // this will result in answering with the delegation. oh well
-                copyAnswerMessage(incoming, answer_message_);
+                // TODO this will result in answering with the delegation. oh well
+                isc::resolve::copyResponseMessage(incoming, answer_message_);
                 return true;
             }
+            break;
+        case isc::resolve::ResponseClassifier::EMPTY:
+        case isc::resolve::ResponseClassifier::EXTRADATA:
+        case isc::resolve::ResponseClassifier::INVNAMCLASS:
+        case isc::resolve::ResponseClassifier::INVTYPE:
+        case isc::resolve::ResponseClassifier::MISMATQUEST:
+        case isc::resolve::ResponseClassifier::MULTICLASS:
+        case isc::resolve::ResponseClassifier::NOTONEQUEST:
+        case isc::resolve::ResponseClassifier::NOTRESPONSE:
+        case isc::resolve::ResponseClassifier::NOTSINGLE:
+        case isc::resolve::ResponseClassifier::OPCODE:
+        case isc::resolve::ResponseClassifier::RCODE:
+        case isc::resolve::ResponseClassifier::TRUNCATED:
+            // Should we try a different server rather than SERVFAIL?
+            isc::resolve::makeErrorMessage(answer_message_,
+                                           Rcode::SERVFAIL());
+            return true;
+            break;
         }
+        // should not be reached. assert here?
+        dlog("[FATAL] unreachable code");
+        return true;
     }
     
-
 public:
     RunningQuery(asio::io_service& io, const Question &question,
         MessagePtr answer_message, shared_ptr<AddressVector> upstream,
@@ -501,12 +532,14 @@ public:
         upstream_root_(upstream_root),
         buffer_(buffer),
         resolvercallback_(cb),
+        cname_count_(0),
         query_timeout_(query_timeout),
         retries_(retries),
         client_timer(io),
         lookup_timer(io),
         queries_out_(0),
-        done_(false)
+        done_(false),
+        answer_sent_(false)
     {
         // Setup the timer to stop trying (lookup_timeout)
         if (lookup_timeout >= 0) {
@@ -525,31 +558,35 @@ public:
         // should use NSAS for root servers
         // Adding root servers if not a forwarder
         if (upstream_->empty()) {
-            if (upstream_root_->empty()) { //if no root ips given, use this
-                zone_servers_.push_back(addr_t("192.5.5.241", 53));
-            }
-            else
-            {
-              //copy the list
-              dlog("Size is " + 
-                    boost::lexical_cast<string>(upstream_root_->size()) + 
-                    "\n");
-              //Use BOOST_FOREACH here? Is it faster?
-              for(AddressVector::iterator it = upstream_root_->begin();
-                   it < upstream_root_->end(); it++) {
-                zone_servers_.push_back(addr_t(it->first,it->second));
-                dlog("Put " + zone_servers_.back().first + "into root list\n");
-              }
-            }
+            setZoneServersToRoot();
         }
 
         send();
     }
 
+    void setZoneServersToRoot() {
+        zone_servers_.clear();
+        if (upstream_root_->empty()) { //if no root ips given, use this
+            zone_servers_.push_back(addr_t("192.5.5.241", 53));
+        } else {
+            // copy the list
+            dlog("Size is " + 
+                boost::lexical_cast<string>(upstream_root_->size()) + 
+                "\n");
+            for(AddressVector::iterator it = upstream_root_->begin();
+                it < upstream_root_->end(); ++it) {
+            zone_servers_.push_back(addr_t(it->first,it->second));
+            dlog("Put " + zone_servers_.back().first + "into root list\n");
+            }
+        }
+    }
     virtual void clientTimeout() {
-        // right now, just stop (should make SERVFAIL and send that
-        // back, but not stop)
-        stop(false);
+        // Return a SERVFAIL, but do not stop until
+        // we have an answer or timeout ourselves
+        isc::resolve::makeErrorMessage(answer_message_,
+                                       Rcode::SERVFAIL());
+        resolvercallback_->success(answer_message_);
+        answer_sent_ = true;
     }
 
     virtual void stop(bool resume) {
@@ -561,7 +598,7 @@ public:
         // same goes if we have an outstanding query (can't delete
         // until that one comes back to us)
         done_ = true;
-        if (resume) {
+        if (resume && !answer_sent_) {
             resolvercallback_->success(answer_message_);
         } else {
             resolvercallback_->failure();
@@ -592,7 +629,7 @@ public:
                 incoming.getRcode() == Rcode::NOERROR()) {
                 done_ = handleRecursiveAnswer(incoming);
             } else {
-                copyAnswerMessage(incoming, answer_message_);
+                isc::resolve::copyResponseMessage(incoming, answer_message_);
                 done_ = true;
             }
             

+ 9 - 6
src/lib/asiolink/asiolink.h

@@ -466,10 +466,12 @@ public:
     /// class.
     ///
     /// \param io_message The event message to handle
-    /// \param message The DNS MessagePtr that needs handling
-    /// \param buffer The result is put here
+    /// \param query_message The DNS MessagePtr of the original query
+    /// \param answer_message The DNS MessagePtr of the answer we are
+    /// building
+    /// \param buffer Intermediate data results are put here
     virtual void operator()(const IOMessage& io_message,
-                            isc::dns::MessagePtr message,
+                            isc::dns::MessagePtr query_message,
                             isc::dns::MessagePtr answer_message,
                             isc::dns::OutputBufferPtr buffer) const = 0;
 };
@@ -546,9 +548,10 @@ public:
     ///        to forward queries to.
     /// \param upstream_root Addresses and ports of the root servers
     ///        to use when resolving.
-    /// \param timeout How long to timeout the query, in ms
-    ///     -1 means never timeout (but do not use that).
-    ///     TODO: This should be computed somehow dynamically in future
+    /// \param query_timeout Timeout value for queries we sent, in ms
+    /// \param client_timeout Timeout value for when we send back an
+    ///        error, in ms
+    /// \param lookup_timeout Timeout value for when we give up, in ms
     /// \param retries how many times we try again (0 means just send and
     ///     and return if it returs).
     RecursiveQuery(DNSService& dns_service,

+ 44 - 12
src/lib/asiolink/tests/asiolink_unittest.cc

@@ -520,6 +520,39 @@ protected:
             bool* done_;
     };
 
+    // This version of mock server just stops the io_service when it is resumed
+    // the second time. (Used in the clientTimeout test, where resume
+    // is called initially with the error answer, and later when the
+    // lookup times out, it is called without an answer to send back)
+    class MockServerStop2 : public MockServer {
+        public:
+            explicit MockServerStop2(IOService& io_service,
+                                     bool* done1, bool* done2) :
+                MockServer(io_service),
+                done1_(done1),
+                done2_(done2),
+                stopped_once_(false)
+            {}
+
+            void resume(const bool done) {
+                if (stopped_once_) {
+                    *done2_ = done;
+                    io_.stop();
+                } else {
+                    *done1_ = done;
+                    stopped_once_ = true;
+                }
+            }
+
+            DNSServer* clone() {
+                return (new MockServerStop2(*this));
+            }
+        private:
+            bool* done1_;
+            bool* done2_;
+            bool stopped_once_;
+    };
+
 private:
     class ASIOCallBack : public SimpleCallback {
     public:
@@ -809,8 +842,9 @@ TEST_F(ASIOLinkTest, forwardClientTimeout) {
     sock_ = createTestSocket();
 
     // Prepare the server
-    bool done(true);
-    MockServerStop server(*io_service_, &done);
+    bool done1(true);
+    bool done2(true);
+    MockServerStop2 server(*io_service_, &done1, &done2);
 
     MessagePtr answer(new Message(Message::RENDER));
 
@@ -818,11 +852,11 @@ TEST_F(ASIOLinkTest, forwardClientTimeout) {
     const uint16_t port = boost::lexical_cast<uint16_t>(TEST_CLIENT_PORT);
     // Set it up to retry twice before client timeout fires
     // Since the lookup timer has not fired, it should retry
-    // a third time
+    // four times
     RecursiveQuery query(*dns_service_,
                          singleAddress(TEST_IPV4_ADDR, port),
                          singleAddress(TEST_IPV4_ADDR, port),
-                         50, 120, 1000, 3);
+                         50, 120, 1000, 4);
     Question question(Name("example.net"), RRClass::IN(), RRType::A());
     OutputBufferPtr buffer(new OutputBuffer(0));
     query.resolve(question, answer, buffer, &server);
@@ -833,17 +867,15 @@ TEST_F(ASIOLinkTest, forwardClientTimeout) {
     // we know it'll fail, so make it a shorter timeout
     int recv_options = setSocketTimeout(sock_, 1, 0);
 
-    // Try to read 5 times, should stop after 3 reads
+    // Try to read 5 times
     int num = 0;
     bool read_success = tryRead(sock_, recv_options, 5, &num);
 
-    // The query should fail (for resolver it should send back servfail,
-    // but currently, and perhaps for forwarder in general, the effect
-    // will be the same as on a lookup timeout, i.e. no answer is sent
-    // back)
-    EXPECT_FALSE(done);
-    EXPECT_EQ(3, num);
-    EXPECT_FALSE(read_success);
+    // The query should fail, but we should have kept on trying
+    EXPECT_TRUE(done1);
+    EXPECT_FALSE(done2);
+    EXPECT_EQ(5, num);
+    EXPECT_TRUE(read_success);
 }
 
 // If we set lookup timeout to lower than querytimeout*retries, we should

+ 29 - 0
src/lib/dns/message.cc

@@ -338,6 +338,14 @@ Message::removeRRset(const Section section, RRsetIterator& iterator) {
     return (removed);
 }
 
+void
+Message::clearSection(const Section section) {
+    if (section >= MessageImpl::NUM_SECTIONS) {
+        isc_throw(OutOfRange, "Invalid message section: " << section);
+    }
+    impl_->rrsets_[section].clear();
+    impl_->counts_[section] = 0;
+}
 
 void
 Message::addQuestion(const QuestionPtr question) {
@@ -769,6 +777,27 @@ Message::clear(Mode mode) {
 }
 
 void
+Message::appendSection(const Section section, const Message& source) {
+    if (section >= MessageImpl::NUM_SECTIONS) {
+        isc_throw(OutOfRange, "Invalid message section: " << section);
+    }
+
+    if (section == SECTION_QUESTION) {
+        for (QuestionIterator qi = source.beginQuestion();
+             qi != source.endQuestion();
+             ++qi) {
+            addQuestion(*qi);
+        }
+    } else {
+        for (RRsetIterator rrsi = source.beginSection(section);
+             rrsi != source.endSection(section);
+             ++rrsi) {
+            addRRset(section, *rrsi);
+        }
+    }
+}
+
+void
 Message::makeResponse() {
     if (impl_->mode_ != Message::PARSE) {
         isc_throw(InvalidMessageOperation,

+ 12 - 0
src/lib/dns/message.h

@@ -483,6 +483,11 @@ public:
     /// found in the specified section.
     bool removeRRset(const Section section, RRsetIterator& iterator);
 
+    /// \brief Remove all RRSets from the given Section
+    ///
+    /// \param section Section to remove all rrsets from
+    void clearSection(const Section section);
+
     // The following methods are not currently implemented.
     //void removeQuestion(QuestionPtr question);
     // notyet:
@@ -493,6 +498,13 @@ public:
     /// specified mode.
     void clear(Mode mode);
 
+    /// \brief Adds all rrsets from the source the given section in the
+    /// source message to the same section of this message
+    ///
+    /// \param section the section to append
+    /// \param target The source Message
+    void appendSection(const Section section, const Message& source);
+
     /// \brief Prepare for making a response from a request.
     ///
     /// This will clear the DNS header except those fields that should be kept

+ 126 - 0
src/lib/dns/tests/message_unittest.cc

@@ -297,6 +297,75 @@ TEST_F(MessageTest, removeRRset) {
     EXPECT_EQ(2, message_render.getRRCount(Message::SECTION_ANSWER));
 }
 
+TEST_F(MessageTest, clearQuestionSection) {
+    QuestionPtr q(new Question(Name("www.example.com"), RRClass::IN(),
+                               RRType::A()));
+    message_render.addQuestion(q);
+    ASSERT_EQ(1, message_render.getRRCount(Message::SECTION_QUESTION));
+
+    message_render.clearSection(Message::SECTION_QUESTION);
+    EXPECT_EQ(0, message_render.getRRCount(Message::SECTION_QUESTION));
+}
+
+
+TEST_F(MessageTest, clearAnswerSection) {
+    // Add two RRsets, check they are present, clear the section,
+    // check if they are gone.
+    message_render.addRRset(Message::SECTION_ANSWER, rrset_a);
+    message_render.addRRset(Message::SECTION_ANSWER, rrset_aaaa);
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::A()));
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    ASSERT_EQ(3, message_render.getRRCount(Message::SECTION_ANSWER));
+
+    message_render.clearSection(Message::SECTION_ANSWER);
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::A()));
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    EXPECT_EQ(0, message_render.getRRCount(Message::SECTION_ANSWER));
+}
+
+TEST_F(MessageTest, clearAuthoritySection) {
+    // Add two RRsets, check they are present, clear the section,
+    // check if they are gone.
+    message_render.addRRset(Message::SECTION_AUTHORITY, rrset_a);
+    message_render.addRRset(Message::SECTION_AUTHORITY, rrset_aaaa);
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_AUTHORITY, test_name,
+        RRClass::IN(), RRType::A()));
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_AUTHORITY, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    ASSERT_EQ(3, message_render.getRRCount(Message::SECTION_AUTHORITY));
+
+    message_render.clearSection(Message::SECTION_AUTHORITY);
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_AUTHORITY, test_name,
+        RRClass::IN(), RRType::A()));
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_AUTHORITY, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    EXPECT_EQ(0, message_render.getRRCount(Message::SECTION_AUTHORITY));
+}
+
+TEST_F(MessageTest, clearAdditionalSection) {
+    // Add two RRsets, check they are present, clear the section,
+    // check if they are gone.
+    message_render.addRRset(Message::SECTION_ADDITIONAL, rrset_a);
+    message_render.addRRset(Message::SECTION_ADDITIONAL, rrset_aaaa);
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::A()));
+    ASSERT_TRUE(message_render.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    ASSERT_EQ(3, message_render.getRRCount(Message::SECTION_ADDITIONAL));
+
+    message_render.clearSection(Message::SECTION_ADDITIONAL);
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::A()));
+    EXPECT_FALSE(message_render.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    EXPECT_EQ(0, message_render.getRRCount(Message::SECTION_ADDITIONAL));
+}
+
+
 TEST_F(MessageTest, badBeginSection) {
     // valid cases are tested via other tests
     EXPECT_THROW(message_render.beginSection(Message::SECTION_QUESTION),
@@ -311,6 +380,63 @@ TEST_F(MessageTest, badEndSection) {
     EXPECT_THROW(message_render.endSection(bogus_section), OutOfRange);
 }
 
+TEST_F(MessageTest, appendSection) {
+    Message target(Message::RENDER);
+
+    // Section check
+    EXPECT_THROW(target.appendSection(bogus_section, message_render),
+                 OutOfRange);
+
+    // Make sure nothing is copied if there is nothing to copy
+    target.appendSection(Message::SECTION_QUESTION, message_render);
+    EXPECT_EQ(0, target.getRRCount(Message::SECTION_QUESTION));
+    target.appendSection(Message::SECTION_ANSWER, message_render);
+    EXPECT_EQ(0, target.getRRCount(Message::SECTION_ANSWER));
+    target.appendSection(Message::SECTION_AUTHORITY, message_render);
+    EXPECT_EQ(0, target.getRRCount(Message::SECTION_AUTHORITY));
+    target.appendSection(Message::SECTION_ADDITIONAL, message_render);
+    EXPECT_EQ(0, target.getRRCount(Message::SECTION_ADDITIONAL));
+
+    // Now add some data, copy again, and see if it got added
+    message_render.addQuestion(Question(Name("test.example.com"),
+                                        RRClass::IN(), RRType::A()));
+    message_render.addRRset(Message::SECTION_ANSWER, rrset_a);
+    message_render.addRRset(Message::SECTION_AUTHORITY, rrset_a);
+    message_render.addRRset(Message::SECTION_ADDITIONAL, rrset_a);
+    message_render.addRRset(Message::SECTION_ADDITIONAL, rrset_aaaa);
+
+    target.appendSection(Message::SECTION_QUESTION, message_render);
+    EXPECT_EQ(1, target.getRRCount(Message::SECTION_QUESTION));
+
+    target.appendSection(Message::SECTION_ANSWER, message_render);
+    EXPECT_EQ(2, target.getRRCount(Message::SECTION_ANSWER));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::A()));
+
+    target.appendSection(Message::SECTION_AUTHORITY, message_render);
+    EXPECT_EQ(2, target.getRRCount(Message::SECTION_AUTHORITY));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_AUTHORITY, test_name,
+        RRClass::IN(), RRType::A()));
+
+    target.appendSection(Message::SECTION_ADDITIONAL, message_render);
+    EXPECT_EQ(3, target.getRRCount(Message::SECTION_ADDITIONAL));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::A()));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_ADDITIONAL, test_name,
+        RRClass::IN(), RRType::AAAA()));
+
+    // One more test, test to see if the section gets added, not replaced
+    Message source2(Message::RENDER);
+    source2.addRRset(Message::SECTION_ANSWER, rrset_aaaa);
+    target.appendSection(Message::SECTION_ANSWER, source2);
+    EXPECT_EQ(3, target.getRRCount(Message::SECTION_ANSWER));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::A()));
+    EXPECT_TRUE(target.hasRRset(Message::SECTION_ANSWER, test_name,
+        RRClass::IN(), RRType::AAAA()));
+    
+}
+
 TEST_F(MessageTest, fromWire) {
     factoryFromFile(message_parse, "message_fromWire1");
     EXPECT_EQ(0x1035, message_parse.getQid());

+ 3 - 0
src/lib/nsas/zone_entry.h

@@ -74,6 +74,9 @@ public:
      * \param name Name of the zone
      * \param class_code Class of this zone (zones of different classes have
      *     different objects.
+     * \param nameserver_table Hashtable of NameServerEntry objects for
+     *     this zone
+     * \param namesever_lru LRU for the nameserver entries
      * \todo Move to cc file, include the lookup (if NSAS uses resolver for
      *     everything)
      */

+ 2 - 1
src/lib/resolve/Makefile.am

@@ -10,8 +10,9 @@ AM_CXXFLAGS = $(B10_CXXFLAGS)
 CLEANFILES = *.gcno *.gcda
 
 lib_LTLIBRARIES = libresolve.la
-libresolve_la_SOURCES = resolve.h
+libresolve_la_SOURCES = resolve.h resolve.cc
 libresolve_la_SOURCES += resolver_interface.h
 libresolve_la_SOURCES += resolver_callback.h resolver_callback.cc
+libresolve_la_SOURCES += response_classifier.cc response_classifier.h
 libresolve_la_LIBADD = $(top_builddir)/src/lib/dns/libdns++.la
 libresolve_la_LIBADD += $(top_builddir)/src/lib/exceptions/libexceptions.la

+ 58 - 0
src/lib/resolve/resolve.cc

@@ -0,0 +1,58 @@
+// Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#include <resolve/resolve.h>
+
+using namespace isc::dns;
+
+namespace {
+    class SectionInserter {
+    public:
+        SectionInserter(MessagePtr message, const Message::Section sect) :
+            message_(message), section_(sect)
+        {}
+        void operator()(const RRsetPtr rrset) {
+            message_->addRRset(section_, rrset, true);
+        }
+        MessagePtr message_;
+        const Message::Section section_;
+    };
+}
+
+namespace isc {
+namespace resolve {
+
+void
+makeErrorMessage(MessagePtr answer_message,
+                 const Rcode::Rcode& error_code)
+{
+    answer_message->clearSection(Message::SECTION_ANSWER);
+    answer_message->clearSection(Message::SECTION_AUTHORITY);
+    answer_message->clearSection(Message::SECTION_ADDITIONAL);
+
+    answer_message->setRcode(error_code);
+}
+
+void copyResponseMessage(const Message& source, MessagePtr target) {
+    target->setRcode(source.getRcode());
+
+    target->appendSection(Message::SECTION_ANSWER, source);
+    target->appendSection(Message::SECTION_AUTHORITY, source);
+    target->appendSection(Message::SECTION_ADDITIONAL, source);
+}
+
+
+} // namespace resolve
+} // namespace isc
+

+ 39 - 0
src/lib/resolve/resolve.h

@@ -15,6 +15,45 @@
 #ifndef _ISC_RESOLVE_H
 #define _ISC_RESOLVE_H 1
 
+/// This file includes all other libresolve headers, and provides
+/// several helper functions used in resolving.
+
 #include <resolve/resolver_interface.h>
 #include <resolve/resolver_callback.h>
+#include <resolve/response_classifier.h>
+
+namespace isc {
+namespace resolve {
+
+/// \brief Create an error response
+///
+/// Clears the answer, authority, and additional section of the
+/// given MessagePtr and sets the given error code
+///
+/// Notes: Assuming you have already done initial preparations
+/// on the given answer message (copy the opcode, qid and question
+/// section), you can simply use this to create an error response.
+///
+/// \param answer_message The message to clear and place the error in
+/// \param question The question to add to the
+/// \param error_code The error Rcode
+void makeErrorMessage(isc::dns::MessagePtr answer_message,
+                      const isc::dns::Rcode::Rcode& error_code);
+
+
+/// \brief Copies the parts relevant for a DNS response to the
+/// target message
+///
+/// This adds all the RRsets in the answer, authority and
+/// additional sections to the target, as well as the response
+/// code
+/// \param source The Message to copy the data from
+/// \param target The Message to copy the data to
+void copyResponseMessage(const isc::dns::Message& source,
+                         isc::dns::MessagePtr target);
+
+
+} // namespace resolve
+} // namespace isc
+
 #endif // ISC_RESOLVE_H_

+ 35 - 20
src/bin/resolver/response_classifier.cc

@@ -17,7 +17,7 @@
 #include <cstddef>
 #include <vector>
 
-#include <resolver/response_classifier.h>
+#include <resolve/response_classifier.h>
 #include <dns/name.h>
 #include <dns/opcode.h>
 #include <dns/rcode.h>
@@ -26,24 +26,29 @@
 using namespace isc::dns;
 using namespace std;
 
+namespace isc {
+namespace resolve {
+
 // Classify the response in the "message" object.
 
 ResponseClassifier::Category ResponseClassifier::classify(
-    const Question& question, const MessagePtr& message, bool tcignore)
+    const Question& question, const Message& message, 
+    Name& cname_target, unsigned int& cname_count, bool tcignore
+    )
 {
     // Check header bits
-    if (!message->getHeaderFlag(Message::HEADERFLAG_QR)) {
+    if (!message.getHeaderFlag(Message::HEADERFLAG_QR)) {
         return (NOTRESPONSE);   // Query-response bit not set in the response
     }
 
     // We only recognise responses to queries here
-    if (message->getOpcode() != Opcode::QUERY()) {
+    if (message.getOpcode() != Opcode::QUERY()) {
         return (OPCODE);
     }
 
     // Apparently have a response.  There must be a single question in it...
-    const vector<QuestionPtr> msgquestion(message->beginQuestion(),
-            message->endQuestion());
+    const vector<QuestionPtr> msgquestion(message.beginQuestion(),
+            message.endQuestion());
     if (msgquestion.size() != 1) {
         return (NOTONEQUEST); // Not one question in response question section
     }
@@ -57,7 +62,7 @@ ResponseClassifier::Category ResponseClassifier::classify(
     }
 
     // Check for Rcode-related errors.
-    const Rcode& rcode = message->getRcode();
+    const Rcode& rcode = message.getRcode();
     if (rcode != Rcode::NOERROR()) {
         if (rcode == Rcode::NXDOMAIN()) {
 
@@ -91,7 +96,7 @@ ResponseClassifier::Category ResponseClassifier::classify(
     // probably want to re-query over TCP.  However, in some circumstances we
     // might want to go with what we have.  So give the caller the option of
     // ignoring the TC bit.
-    if (message->getHeaderFlag(Message::HEADERFLAG_TC) && (!tcignore)) {
+    if (message.getHeaderFlag(Message::HEADERFLAG_TC) && (!tcignore)) {
         return (TRUNCATED);
     }
 
@@ -100,12 +105,12 @@ ResponseClassifier::Category ResponseClassifier::classify(
     // referral.  For this, we need to inspect the contents of the answer
     // and authority sections.
     const vector<RRsetPtr> answer(
-            message->beginSection(Message::SECTION_ANSWER),
-            message->endSection(Message::SECTION_ANSWER)
+            message.beginSection(Message::SECTION_ANSWER),
+            message.endSection(Message::SECTION_ANSWER)
             );
     const vector<RRsetPtr> authority(
-            message->beginSection(Message::SECTION_AUTHORITY),
-            message->endSection(Message::SECTION_AUTHORITY)
+            message.beginSection(Message::SECTION_AUTHORITY),
+            message.endSection(Message::SECTION_AUTHORITY)
             );
 
     // If there is nothing in the answer section, it is a referral - unless
@@ -134,6 +139,9 @@ ResponseClassifier::Category ResponseClassifier::classify(
                 (question.getType() == RRType::ANY())) {
                 return (ANSWER);
             } else if (answer[0]->getType() == RRType::CNAME()) {
+                RdataIteratorPtr it = answer[0]->getRdataIterator();
+                cname_target = Name(it->getCurrent().toText());
+                ++cname_count;
                 return (CNAME);
             } else {
                 return (INVTYPE);
@@ -186,14 +194,16 @@ ResponseClassifier::Category ResponseClassifier::classify(
 
     vector<RRsetPtr> ansrrset(answer);
     vector<int> present(ansrrset.size(), 1);
-    return cnameChase(question.getName(), question.getType(), ansrrset, present,
-        ansrrset.size());
+    return cnameChase(question.getName(), question.getType(),
+        cname_target, cname_count,
+        ansrrset, present, ansrrset.size());
 }
 
 // Search the CNAME chain.
 ResponseClassifier::Category ResponseClassifier::cnameChase(
-    const Name& qname, const RRType& qtype, vector<RRsetPtr>& ansrrset,
-    vector<int>& present, size_t size)
+    const Name& qname, const RRType& qtype,
+    Name& cname_target, unsigned int& cname_count,
+    vector<RRsetPtr>& ansrrset, vector<int>& present, size_t size)
 {
     // Search through the vector of RRset pointers until we find one with the
     // right QNAME.
@@ -215,9 +225,10 @@ ResponseClassifier::Category ResponseClassifier::cnameChase(
                     present[i] = 0;
                     --size;
                     if (size == 0) {
+                        RdataIteratorPtr it = ansrrset[i]->getRdataIterator();
+                        cname_target = Name(it->getCurrent().toText());
                         return (CNAME);
-                    }
-                    else {
+                    } else {
                         if (ansrrset[i]->getRdataCount() != 1) {
 
                             // Multiple RDATA for a CNAME?  This is invalid.
@@ -227,8 +238,9 @@ ResponseClassifier::Category ResponseClassifier::cnameChase(
                         RdataIteratorPtr it = ansrrset[i]->getRdataIterator();
                         Name newname(it->getCurrent().toText());
 
-                        return cnameChase(newname, qtype, ansrrset, present,
-                            size);
+                        // Increase CNAME count, and continue
+                        return cnameChase(newname, qtype, cname_target,
+                            ++cname_count, ansrrset, present, size);
                     }
 
                 } else {
@@ -257,3 +269,6 @@ ResponseClassifier::Category ResponseClassifier::cnameChase(
 
     return (EXTRADATA);
 }
+
+} // namespace resolve
+} // namespace isc

+ 22 - 4
src/bin/resolver/response_classifier.h

@@ -23,14 +23,16 @@
 #include <dns/message.h>
 #include <dns/question.h>
 
+#define RESOLVER_MAX_CNAME_CHAIN    16
+
+namespace isc {
+namespace resolve {
+
 /// \brief Classify Server Response
 ///
 /// This class is used in the recursive server.  It is passed an answer received
 /// from an upstream server and categorises it.
 ///
-/// TODO: It is unlikely that the code can be used in this form.  Some adaption
-/// of it will be required to put it in the server.
-///
 /// TODO: The code here does not take into account any EDNS0 fields.
 
 class ResponseClassifier {
@@ -89,13 +91,20 @@ public:
     ///
     /// \param question Question that was sent to the server
     /// \param message Pointer to the associated response from the server.
+    /// \param cname_target If the message contains an (unfinished) CNAME
+    /// chain, this Name will be replaced by the target of the last CNAME
+    /// in the chain
+    /// \param cname_count This unsigned int will be incremented with
+    /// the number of CNAMEs followed
     /// \param tcignore If set, the TC bit in a response packet is
     /// ignored.  Otherwise the error code TRUNCATED will be returned.  The
     /// only time this is likely to be used is in development where we are not
     /// going to fail over to TCP and will want to use what is returned, even
     /// if some of the response was lost.
     static Category classify(const isc::dns::Question& question,
-            const isc::dns::MessagePtr& message, bool tcignore = false);
+            const isc::dns::Message& message, 
+            isc::dns::Name& cname_target, unsigned int& cname_count,
+            bool tcignore = false);
 
 private:
     /// \brief Follow CNAMEs
@@ -127,12 +136,21 @@ private:
     /// management seemed high.  This solution imposes some additional loop
     /// cycles, but that should be minimal compared with the overhead of the
     /// memory management.
+    /// \param cname_target If the message contains an (unfinished) CNAME
+    /// chain, this Name will be replaced by the target of the last CNAME
+    /// in the chain
+    /// \param cname_count This unsigned int will be incremented with
+    /// the number of CNAMEs followed
     /// \param size Number of elements to check.  See description of \c present
     /// for details.
     static Category cnameChase(const isc::dns::Name& qname,
         const isc::dns::RRType& qtype,
+        isc::dns::Name& cname_target, unsigned int& cname_count,
         std::vector<isc::dns::RRsetPtr>& ansrrset, std::vector<int>& present,
         size_t size);
 };
 
 #endif // __RESPONSE_CLASSIFIER_H
+
+} // namespace resolve
+} // namespace isc

+ 3 - 0
src/lib/resolve/tests/Makefile.am

@@ -14,10 +14,13 @@ TESTS += run_unittests
 run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
 run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS)
 run_unittests_SOURCES = run_unittests.cc
+run_unittests_SOURCES += resolve_unittest.cc
 run_unittests_SOURCES += resolver_callback_unittest.cc
+run_unittests_SOURCES += response_classifier_unittest.cc
 
 run_unittests_LDADD = $(GTEST_LDADD)
 run_unittests_LDADD +=  $(top_builddir)/src/lib/resolve/libresolve.la
+run_unittests_LDADD +=  $(top_builddir)/src/lib/dns/libdns++.la
 
 endif
 

+ 172 - 0
src/lib/resolve/tests/resolve_unittest.cc

@@ -0,0 +1,172 @@
+// Copyright (C) 2010  Internet Systems Consortium, Inc. ("ISC")
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
+// REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+// AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
+// INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+// LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
+// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+// PERFORMANCE OF THIS SOFTWARE.
+
+#include <iostream>
+#include <gtest/gtest.h>
+
+#include <dns/message.h>
+#include <dns/question.h>
+#include <dns/opcode.h>
+#include <dns/rcode.h>
+#include <dns/rrttl.h>
+#include <dns/rdata.h>
+#include <resolve/resolve.h>
+
+using namespace isc::dns;
+
+namespace {
+
+class ResolveHelperFunctionsTest : public ::testing::Test {
+public:
+    ResolveHelperFunctionsTest() :
+        message_a_(new Message(Message::RENDER)),
+        message_b_(new Message(Message::RENDER)),
+        question_(new Question(Name("www.example.com"), RRClass::IN(), RRType::A()))
+    {
+        createMessageA();
+        createMessageB();
+    };
+
+    void createMessageA() {
+        message_a_->setOpcode(Opcode::QUERY());
+        message_a_->setRcode(Rcode::NOERROR());
+        message_a_->addQuestion(question_);
+    }
+
+    void createMessageB() {
+        message_b_->setOpcode(Opcode::QUERY());
+        message_b_->setRcode(Rcode::NOERROR());
+        message_b_->addQuestion(question_);
+
+        // We could reuse the same rrset in the different sections,
+        // but to be sure, we create separate ones
+        RRsetPtr answer_rrset(new RRset(Name("www.example.com"),
+                                        RRClass::IN(), RRType::TXT(),
+                                        RRTTL(3600)));
+        answer_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                                  RRClass::IN(),
+                                                  "Answer"));
+        message_b_->addRRset(Message::SECTION_ANSWER, answer_rrset);
+    
+        RRsetPtr auth_rrset(new RRset(Name("www.example.com"),
+                                        RRClass::IN(), RRType::TXT(),
+                                        RRTTL(3600)));
+        auth_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                                  RRClass::IN(),
+                                                  "Authority"));
+        auth_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                                  RRClass::IN(),
+                                                  "Rdata"));
+        message_b_->addRRset(Message::SECTION_AUTHORITY, auth_rrset);
+    
+        RRsetPtr add_rrset(new RRset(Name("www.example.com"),
+                                     RRClass::IN(), RRType::TXT(),
+                                     RRTTL(3600)));
+        add_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                               RRClass::IN(),
+                                               "Additional"));
+        add_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                               RRClass::IN(),
+                                               "Rdata"));
+        add_rrset->addRdata(rdata::createRdata(RRType::TXT(),
+                                               RRClass::IN(),
+                                               "fields."));
+        message_b_->addRRset(Message::SECTION_ADDITIONAL, add_rrset);
+    };
+
+    MessagePtr message_a_;
+    MessagePtr message_b_;
+    QuestionPtr question_;
+};
+
+TEST_F(ResolveHelperFunctionsTest, makeErrorMessageEmptyMessage) {
+    ASSERT_EQ(Rcode::NOERROR(), message_a_->getRcode());
+    ASSERT_EQ(1, message_a_->getRRCount(Message::SECTION_QUESTION));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_ANSWER));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_AUTHORITY));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_ADDITIONAL));
+
+    isc::resolve::makeErrorMessage(message_a_, Rcode::SERVFAIL());
+    EXPECT_EQ(Rcode::SERVFAIL(), message_a_->getRcode());
+    EXPECT_EQ(1, message_a_->getRRCount(Message::SECTION_QUESTION));
+    EXPECT_EQ(0, message_a_->getRRCount(Message::SECTION_ANSWER));
+    EXPECT_EQ(0, message_a_->getRRCount(Message::SECTION_AUTHORITY));
+    EXPECT_EQ(0, message_a_->getRRCount(Message::SECTION_ADDITIONAL));
+}
+
+TEST_F(ResolveHelperFunctionsTest, makeErrorMessageNonEmptyMessage) {
+
+    ASSERT_EQ(Rcode::NOERROR(), message_b_->getRcode());
+    ASSERT_EQ(1, message_b_->getRRCount(Message::SECTION_QUESTION));
+    ASSERT_EQ(1, message_b_->getRRCount(Message::SECTION_ANSWER));
+    ASSERT_EQ(2, message_b_->getRRCount(Message::SECTION_AUTHORITY));
+    ASSERT_EQ(3, message_b_->getRRCount(Message::SECTION_ADDITIONAL));
+
+    isc::resolve::makeErrorMessage(message_b_, Rcode::FORMERR());
+    EXPECT_EQ(Rcode::FORMERR(), message_b_->getRcode());
+    EXPECT_EQ(1, message_b_->getRRCount(Message::SECTION_QUESTION));
+    EXPECT_EQ(0, message_b_->getRRCount(Message::SECTION_ANSWER));
+    EXPECT_EQ(0, message_b_->getRRCount(Message::SECTION_AUTHORITY));
+    EXPECT_EQ(0, message_b_->getRRCount(Message::SECTION_ADDITIONAL));
+}
+
+void
+compareSections(const MessagePtr message_a, const MessagePtr message_b,
+                Message::Section section)
+{
+    RRsetIterator rrs_a = message_a->beginSection(section);
+    RRsetIterator rrs_b = message_b->beginSection(section);
+    while (rrs_a != message_a->endSection(section) &&
+           rrs_b != message_b->endSection(section)
+          ) {
+        EXPECT_EQ(*rrs_a, *rrs_b);
+        ++rrs_a;
+        ++rrs_b;
+    }
+    // can't use EXPECT_EQ here, no eqHelper for endsection comparison
+    EXPECT_TRUE(rrs_a == message_a->endSection(section));
+    EXPECT_TRUE(rrs_b == message_b->endSection(section));
+}
+
+TEST_F(ResolveHelperFunctionsTest, copyAnswerMessage) {
+    message_b_->setRcode(Rcode::NXDOMAIN());
+    
+    ASSERT_NE(message_b_->getRcode(), message_a_->getRcode());
+    ASSERT_NE(message_b_->getRRCount(Message::SECTION_ANSWER),
+              message_a_->getRRCount(Message::SECTION_ANSWER));
+    ASSERT_NE(message_b_->getRRCount(Message::SECTION_AUTHORITY),
+              message_a_->getRRCount(Message::SECTION_AUTHORITY));
+    ASSERT_NE(message_b_->getRRCount(Message::SECTION_ADDITIONAL),
+              message_a_->getRRCount(Message::SECTION_ADDITIONAL));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_ANSWER));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_AUTHORITY));
+    ASSERT_EQ(0, message_a_->getRRCount(Message::SECTION_ADDITIONAL));
+
+    isc::resolve::copyResponseMessage(*message_b_, message_a_);
+
+    EXPECT_EQ(message_b_->getRcode(), message_a_->getRcode());
+    ASSERT_EQ(message_b_->getRRCount(Message::SECTION_ANSWER),
+              message_a_->getRRCount(Message::SECTION_ANSWER));
+    ASSERT_EQ(message_b_->getRRCount(Message::SECTION_AUTHORITY),
+              message_a_->getRRCount(Message::SECTION_AUTHORITY));
+    ASSERT_EQ(message_b_->getRRCount(Message::SECTION_ADDITIONAL),
+              message_a_->getRRCount(Message::SECTION_ADDITIONAL));
+
+    
+    compareSections(message_a_, message_b_, Message::SECTION_ANSWER);
+    compareSections(message_a_, message_b_, Message::SECTION_AUTHORITY);
+    compareSections(message_a_, message_b_, Message::SECTION_ADDITIONAL);
+}
+
+} // Anonymous namespace

+ 185 - 142
src/bin/resolver/tests/response_classifier_unittest.cc

@@ -17,7 +17,7 @@
 
 #include <dns/tests/unittest_util.h>
 
-#include <resolver/response_classifier.h>
+#include <resolve/response_classifier.h>
 
 #include <dns/name.h>
 #include <dns/opcode.h>
@@ -35,6 +35,8 @@ using namespace isc::dns;
 using namespace rdata;
 using namespace isc::dns::rdata::generic;
 using namespace isc::dns::rdata::in;
+using namespace isc::resolve;
+
 
 namespace {
 class ResponseClassifierTest : public ::testing::Test {
@@ -57,8 +59,8 @@ public:
     /// in the early tests where simple messages are required.
 
     ResponseClassifierTest() :
-        msg_a(new Message(Message::RENDER)),
-        msg_any(new Message(Message::RENDER)),
+        msg_a(Message::RENDER),
+        msg_any(Message::RENDER),
         qu_ch_a_www(Name("www.example.com"), RRClass::CH(), RRType::A()),
         qu_in_any_www(Name("www.example.com"), RRClass::IN(), RRType::ANY()),
         qu_in_a_www2(Name("www2.example.com"), RRClass::IN(), RRType::A()),
@@ -79,20 +81,22 @@ public:
         rrs_in_ns_(new RRset(Name("example.com"), RRClass::IN(),
             RRType::NS(), RRTTL(300))),
         rrs_in_txt_www(new RRset(Name("www.example.com"), RRClass::IN(),
-            RRType::TXT(), RRTTL(300)))
+            RRType::TXT(), RRTTL(300))),
+        cname_target("."),
+        cname_count(0)
     {
         // Set up the message to indicate a successful response to the question
         // "www.example.com A", but don't add in any response sections.
-        msg_a->setHeaderFlag(Message::HEADERFLAG_QR);
-        msg_a->setOpcode(Opcode::QUERY());
-        msg_a->setRcode(Rcode::NOERROR());
-        msg_a->addQuestion(qu_in_a_www);
+        msg_a.setHeaderFlag(Message::HEADERFLAG_QR);
+        msg_a.setOpcode(Opcode::QUERY());
+        msg_a.setRcode(Rcode::NOERROR());
+        msg_a.addQuestion(qu_in_a_www);
 
         // ditto for the query "www.example.com ANY"
-        msg_any->setHeaderFlag(Message::HEADERFLAG_QR);
-        msg_any->setOpcode(Opcode::QUERY());
-        msg_any->setRcode(Rcode::NOERROR());
-        msg_any->addQuestion(qu_in_any_www);
+        msg_any.setHeaderFlag(Message::HEADERFLAG_QR);
+        msg_any.setOpcode(Opcode::QUERY());
+        msg_any.setRcode(Rcode::NOERROR());
+        msg_any.addQuestion(qu_in_any_www);
 
         // The next set of assignments set up the following zone records
         //
@@ -127,8 +131,8 @@ public:
             new CNAME("www1.example.com")));
     }
 
-    MessagePtr  msg_a;              // Pointer to message in RENDER state
-    MessagePtr  msg_any;            // Pointer to message in RENDER state
+    Message     msg_a;              // Pointer to message in RENDER state
+    Message     msg_any;            // Pointer to message in RENDER state
     Question    qu_ch_a_www;        // www.example.com CH A
     Question    qu_in_any_www;      // www.example.com IN ANY
     Question    qu_in_a_www2;       // www.example.com IN ANY
@@ -143,6 +147,10 @@ public:
     RRsetPtr    rrs_in_cname_www2;  // www2.example.com IN CNAME
     RRsetPtr    rrs_in_ns_;         // example.com IN NS
     RRsetPtr    rrs_in_txt_www;     // www.example.com IN TXT
+    Name        cname_target;       // Used in response classifier to
+                                    // store the target of a possible
+                                    // CNAME chain
+    unsigned int cname_count;       // Used to count cnames in a chain
 };
 
 // Test that the error() function categorises the codes correctly.
@@ -174,11 +182,12 @@ TEST_F(ResponseClassifierTest, Query) {
 
     // Set up message to indicate a query (QR flag = 0, one question).  By
     // default the opcode will be 0 (query)
-    msg_a->setHeaderFlag(Message::HEADERFLAG_QR, false);
+    msg_a.setHeaderFlag(Message::HEADERFLAG_QR, false);
 
     // Should be rejected as it is a query, not a response
     EXPECT_EQ(ResponseClassifier::NOTRESPONSE,
-        ResponseClassifier::classify(qu_in_a_www, msg_a));
+        ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                     cname_target, cname_count));
 }
 
 // Check that we get an OPCODE error on all but QUERY opcodes.
@@ -188,13 +197,15 @@ TEST_F(ResponseClassifierTest, Opcode) {
     uint8_t query = static_cast<uint8_t>(Opcode::QUERY().getCode());
 
     for (uint8_t i = 0; i < (1 << 4); ++i) {
-        msg_a->setOpcode(Opcode(i));
+        msg_a.setOpcode(Opcode(i));
         if (i == query) {
             EXPECT_NE(ResponseClassifier::OPCODE,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         } else {
             EXPECT_EQ(ResponseClassifier::OPCODE,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         }
     }
 }
@@ -205,29 +216,33 @@ TEST_F(ResponseClassifierTest, Opcode) {
 TEST_F(ResponseClassifierTest, MultipleQuestions) {
 
     // Create a message object for this test that has no question section.
-    MessagePtr message(new Message(Message::RENDER));
-    message->setHeaderFlag(Message::HEADERFLAG_QR);
-    message->setOpcode(Opcode::QUERY());
-    message->setRcode(Rcode::NOERROR());
+    Message message(Message::RENDER);
+    message.setHeaderFlag(Message::HEADERFLAG_QR);
+    message.setOpcode(Opcode::QUERY());
+    message.setRcode(Rcode::NOERROR());
 
     // Zero questions
     EXPECT_EQ(ResponseClassifier::NOTONEQUEST,
-        ResponseClassifier::classify(qu_in_a_www, message));
+        ResponseClassifier::classify(qu_in_a_www, message,
+                                     cname_target, cname_count));
 
     // One question
-    message->addQuestion(qu_in_a_www);
+    message.addQuestion(qu_in_a_www);
     EXPECT_NE(ResponseClassifier::NOTONEQUEST,
-        ResponseClassifier::classify(qu_in_a_www, message));
+        ResponseClassifier::classify(qu_in_a_www, message,
+                                     cname_target, cname_count));
 
     // Two questions
-    message->addQuestion(qu_in_ns_);
+    message.addQuestion(qu_in_ns_);
     EXPECT_EQ(ResponseClassifier::NOTONEQUEST,
-        ResponseClassifier::classify(qu_in_a_www, message));
+        ResponseClassifier::classify(qu_in_a_www, message,
+                                     cname_target, cname_count));
 
     // And finish the check with three questions
-    message->addQuestion(qu_in_txt_www);
+    message.addQuestion(qu_in_txt_www);
     EXPECT_EQ(ResponseClassifier::NOTONEQUEST,
-        ResponseClassifier::classify(qu_in_a_www, message));
+        ResponseClassifier::classify(qu_in_a_www, message,
+                                     cname_target, cname_count));
 }
 
 // Test that the question in the question section in the message response
@@ -236,9 +251,11 @@ TEST_F(ResponseClassifierTest, MultipleQuestions) {
 TEST_F(ResponseClassifierTest, SameQuestion) {
 
     EXPECT_EQ(ResponseClassifier::MISMATQUEST,
-        ResponseClassifier::classify(qu_in_ns_, msg_a));
+        ResponseClassifier::classify(qu_in_ns_, msg_a,
+                                     cname_target, cname_count));
     EXPECT_NE(ResponseClassifier::MISMATQUEST,
-        ResponseClassifier::classify(qu_in_a_www, msg_a));
+        ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                     cname_target, cname_count));
 }
 
 // Should get an NXDOMAIN response only on an NXDOMAIN RCODE.
@@ -248,13 +265,15 @@ TEST_F(ResponseClassifierTest, NXDOMAIN) {
     uint16_t nxdomain = static_cast<uint16_t>(Rcode::NXDOMAIN().getCode());
 
     for (uint8_t i = 0; i < (1 << 4); ++i) {
-        msg_a->setRcode(Rcode(i));
+        msg_a.setRcode(Rcode(i));
         if (i == nxdomain) {
             EXPECT_EQ(ResponseClassifier::NXDOMAIN,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         } else {
             EXPECT_NE(ResponseClassifier::NXDOMAIN,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         }
     }
 }
@@ -267,13 +286,15 @@ TEST_F(ResponseClassifierTest, RCODE) {
     uint16_t noerror = static_cast<uint16_t>(Rcode::NOERROR().getCode());
 
     for (uint8_t i = 0; i < (1 << 4); ++i) {
-        msg_a->setRcode(Rcode(i));
+        msg_a.setRcode(Rcode(i));
         if ((i == nxdomain) || (i == noerror)) {
             EXPECT_NE(ResponseClassifier::RCODE,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         } else {
             EXPECT_EQ(ResponseClassifier::RCODE,
-                ResponseClassifier::classify(qu_in_a_www, msg_a));
+                ResponseClassifier::classify(qu_in_a_www, msg_a,
+                                             cname_target, cname_count));
         }
     }
 }
@@ -287,19 +308,23 @@ TEST_F(ResponseClassifierTest, Truncated) {
 
     // Don't expect the truncated code whatever option we ask for if the TC
     // bit is not set.
-    msg_a->setHeaderFlag(Message::HEADERFLAG_TC, false);
+    msg_a.setHeaderFlag(Message::HEADERFLAG_TC, false);
     EXPECT_NE(ResponseClassifier::TRUNCATED,
-        ResponseClassifier::classify(qu_in_a_www, msg_a, true));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count, true));
     EXPECT_NE(ResponseClassifier::TRUNCATED,
-        ResponseClassifier::classify(qu_in_a_www, msg_a, false));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count, false));
 
     // Expect the truncated code if the TC bit is set, only if we don't ignore
     // it.
-    msg_a->setHeaderFlag(Message::HEADERFLAG_TC, true);
+    msg_a.setHeaderFlag(Message::HEADERFLAG_TC, true);
     EXPECT_NE(ResponseClassifier::TRUNCATED,
-        ResponseClassifier::classify(qu_in_a_www, msg_a, true));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count, true));
     EXPECT_EQ(ResponseClassifier::TRUNCATED,
-        ResponseClassifier::classify(qu_in_a_www, msg_a, false));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count, false));
 }
 
 // Check for an empty packet (i.e. no error, but with the answer and additional
@@ -308,7 +333,8 @@ TEST_F(ResponseClassifierTest, Truncated) {
 TEST_F(ResponseClassifierTest, Empty) {
 
     EXPECT_EQ(ResponseClassifier::EMPTY,
-        ResponseClassifier::classify(qu_in_a_www, msg_a));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count));
 }
 
 // Anything where we have an empty answer section but something in the
@@ -316,9 +342,10 @@ TEST_F(ResponseClassifierTest, Empty) {
 
 TEST_F(ResponseClassifierTest, EmptyAnswerReferral) {
 
-    msg_a->addRRset(Message::SECTION_AUTHORITY, rrs_in_ns_);
+    msg_a.addRRset(Message::SECTION_AUTHORITY, rrs_in_ns_);
     EXPECT_EQ(ResponseClassifier::REFERRAL,
-        ResponseClassifier::classify(qu_in_a_www, msg_a));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count));
 
 }
 
@@ -330,60 +357,66 @@ TEST_F(ResponseClassifierTest, EmptyAnswerReferral) {
 TEST_F(ResponseClassifierTest, SingleAnswer) {
 
     // Check a question that matches the answer
-    msg_a->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    msg_a.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
     EXPECT_EQ(ResponseClassifier::ANSWER,
-        ResponseClassifier::classify(qu_in_a_www, msg_a));
+        ResponseClassifier::classify(qu_in_a_www, msg_a, cname_target,
+                                     cname_count));
 
     // Check an ANY question that matches the answer
-    msg_any->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    msg_any.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
     EXPECT_EQ(ResponseClassifier::ANSWER,
-        ResponseClassifier::classify(qu_in_any_www, msg_any));
+        ResponseClassifier::classify(qu_in_any_www, msg_any, cname_target,
+                                     cname_count));
 
     // Check a CNAME response that matches the QNAME.
-    MessagePtr message_a(new Message(Message::RENDER));
-    message_a->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_a->setOpcode(Opcode::QUERY());
-    message_a->setRcode(Rcode::NOERROR());
-    message_a->addQuestion(qu_in_cname_www1);
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
+    Message message_a(Message::RENDER);
+    message_a.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_a.setOpcode(Opcode::QUERY());
+    message_a.setRcode(Rcode::NOERROR());
+    message_a.addQuestion(qu_in_cname_www1);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
     EXPECT_EQ(ResponseClassifier::CNAME,
-        ResponseClassifier::classify(qu_in_cname_www1, message_a));
+        ResponseClassifier::classify(qu_in_cname_www1, message_a,
+                                     cname_target, cname_count));
 
     // Check if the answer QNAME does not match the question
     // Q: www.example.com  IN A
     // A: mail.example.com IN A
-    MessagePtr message_b(new Message(Message::RENDER));
-    message_b->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_b->setOpcode(Opcode::QUERY());
-    message_b->setRcode(Rcode::NOERROR());
-    message_b->addQuestion(qu_in_a_www);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
+    Message message_b(Message::RENDER);
+    message_b.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_b.setOpcode(Opcode::QUERY());
+    message_b.setRcode(Rcode::NOERROR());
+    message_b.addQuestion(qu_in_a_www);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
     EXPECT_EQ(ResponseClassifier::INVNAMCLASS,
-        ResponseClassifier::classify(qu_in_a_www, message_b));
+        ResponseClassifier::classify(qu_in_a_www, message_b,
+                                     cname_target, cname_count));
 
     // Check if the answer class does not match the question
     // Q: www.example.com CH A
     // A: www.example.com IN A
-    MessagePtr message_c(new Message(Message::RENDER));
-    message_c->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_c->setOpcode(Opcode::QUERY());
-    message_c->setRcode(Rcode::NOERROR());
-    message_c->addQuestion(qu_ch_a_www);
-    message_c->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    Message message_c(Message::RENDER);
+    message_c.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_c.setOpcode(Opcode::QUERY());
+    message_c.setRcode(Rcode::NOERROR());
+    message_c.addQuestion(qu_ch_a_www);
+    message_c.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
     EXPECT_EQ(ResponseClassifier::INVNAMCLASS,
-        ResponseClassifier::classify(qu_ch_a_www, message_c));
+        ResponseClassifier::classify(qu_ch_a_www, message_c,
+                                     cname_target, cname_count));
 
     // Check if the answer type does not match the question
     // Q: www.example.com IN A
     // A: www.example.com IN TXT
-    MessagePtr message_d(new Message(Message::RENDER));
-    message_d->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_d->setOpcode(Opcode::QUERY());
-    message_d->setRcode(Rcode::NOERROR());
-    message_d->addQuestion(qu_in_a_www);
-    message_d->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    Message message_d(Message::RENDER);
+    message_d.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_d.setOpcode(Opcode::QUERY());
+    message_d.setRcode(Rcode::NOERROR());
+    message_d.addQuestion(qu_in_a_www);
+    message_d.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
     EXPECT_EQ(ResponseClassifier::INVTYPE,
-        ResponseClassifier::classify(qu_in_a_www, message_d));
+        ResponseClassifier::classify(qu_in_a_www, message_d,
+                                     cname_target, cname_count));
 }
 
 // Check what happens if we have multiple RRsets in the answer.
@@ -391,60 +424,65 @@ TEST_F(ResponseClassifierTest, SingleAnswer) {
 TEST_F(ResponseClassifierTest, MultipleAnswerRRsets) {
 
     // All the same QNAME but different types is only valid on an ANY query.
-    MessagePtr message_a(new Message(Message::RENDER));
-    message_a->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_a->setOpcode(Opcode::QUERY());
-    message_a->setRcode(Rcode::NOERROR());
-    message_a->addQuestion(qu_in_any_www);
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    Message message_a(Message::RENDER);
+    message_a.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_a.setOpcode(Opcode::QUERY());
+    message_a.setRcode(Rcode::NOERROR());
+    message_a.addQuestion(qu_in_any_www);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
     EXPECT_EQ(ResponseClassifier::ANSWER,
-        ResponseClassifier::classify(qu_in_any_www, message_a));
+        ResponseClassifier::classify(qu_in_any_www, message_a,
+                                     cname_target, cname_count));
 
     // On another type of query, it results in an EXTRADATA error
-    MessagePtr message_b(new Message(Message::RENDER));
-    message_b->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_b->setOpcode(Opcode::QUERY());
-    message_b->setRcode(Rcode::NOERROR());
-    message_b->addQuestion(qu_in_a_www);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    Message message_b(Message::RENDER);
+    message_b.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_b.setOpcode(Opcode::QUERY());
+    message_b.setRcode(Rcode::NOERROR());
+    message_b.addQuestion(qu_in_a_www);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
     EXPECT_EQ(ResponseClassifier::EXTRADATA,
-        ResponseClassifier::classify(qu_in_a_www, message_b));
+        ResponseClassifier::classify(qu_in_a_www, message_b,
+                                     cname_target, cname_count));
 
     // Same QNAME on an ANY query is not valid with mixed classes
-    MessagePtr message_c(new Message(Message::RENDER));
-    message_c->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_c->setOpcode(Opcode::QUERY());
-    message_c->setRcode(Rcode::NOERROR());
-    message_c->addQuestion(qu_in_any_www);
-    message_c->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
-    message_c->addRRset(Message::SECTION_ANSWER, rrs_hs_txt_www);
+    Message message_c(Message::RENDER);
+    message_c.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_c.setOpcode(Opcode::QUERY());
+    message_c.setRcode(Rcode::NOERROR());
+    message_c.addQuestion(qu_in_any_www);
+    message_c.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_c.addRRset(Message::SECTION_ANSWER, rrs_hs_txt_www);
     EXPECT_EQ(ResponseClassifier::MULTICLASS,
-        ResponseClassifier::classify(qu_in_any_www, message_c));
+        ResponseClassifier::classify(qu_in_any_www, message_c,
+                                     cname_target, cname_count));
 
     // Mixed QNAME is not valid unless QNAME requested is a CNAME.
-    MessagePtr message_d(new Message(Message::RENDER));
-    message_d->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_d->setOpcode(Opcode::QUERY());
-    message_d->setRcode(Rcode::NOERROR());
-    message_d->addQuestion(qu_in_a_www);
-    message_d->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
-    message_d->addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
+    Message message_d(Message::RENDER);
+    message_d.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_d.setOpcode(Opcode::QUERY());
+    message_d.setRcode(Rcode::NOERROR());
+    message_d.addQuestion(qu_in_a_www);
+    message_d.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_d.addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
     EXPECT_EQ(ResponseClassifier::EXTRADATA,
-        ResponseClassifier::classify(qu_in_a_www, message_d));
+        ResponseClassifier::classify(qu_in_a_www, message_d,
+                                     cname_target, cname_count));
 
     // Mixed QNAME is not valid when the query is an ANY.
-    MessagePtr message_e(new Message(Message::RENDER));
-    message_e->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_e->setOpcode(Opcode::QUERY());
-    message_e->setRcode(Rcode::NOERROR());
-    message_e->addQuestion(qu_in_any_www);
-    message_e->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
-    message_e->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
-    message_e->addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
+    Message message_e(Message::RENDER);
+    message_e.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_e.setOpcode(Opcode::QUERY());
+    message_e.setRcode(Rcode::NOERROR());
+    message_e.addQuestion(qu_in_any_www);
+    message_e.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_e.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    message_e.addRRset(Message::SECTION_ANSWER, rrs_in_a_mail);
     EXPECT_EQ(ResponseClassifier::EXTRADATA,
-        ResponseClassifier::classify(qu_in_any_www, message_e));
+        ResponseClassifier::classify(qu_in_any_www, message_e,
+                                     cname_target, cname_count));
 }
 
 // CNAME chain is CNAME if it terminates in a CNAME, answer if it
@@ -452,43 +490,48 @@ TEST_F(ResponseClassifierTest, MultipleAnswerRRsets) {
 TEST_F(ResponseClassifierTest, CNAMEChain) {
 
     // Answer contains a single CNAME
-    MessagePtr message_a(new Message(Message::RENDER));
-    message_a->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_a->setOpcode(Opcode::QUERY());
-    message_a->setRcode(Rcode::NOERROR());
-    message_a->addQuestion(qu_in_a_www2);
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_cname_www2);
+    Message message_a(Message::RENDER);
+    message_a.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_a.setOpcode(Opcode::QUERY());
+    message_a.setRcode(Rcode::NOERROR());
+    message_a.addQuestion(qu_in_a_www2);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_cname_www2);
     EXPECT_EQ(ResponseClassifier::CNAME,
-        ResponseClassifier::classify(qu_in_a_www2, message_a));
+        ResponseClassifier::classify(qu_in_a_www2, message_a,
+                                     cname_target, cname_count));
 
     // Add a CNAME for www1, and it should still return a CNAME
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
     EXPECT_EQ(ResponseClassifier::CNAME,
-        ResponseClassifier::classify(qu_in_a_www2, message_a));
+        ResponseClassifier::classify(qu_in_a_www2, message_a,
+                                     cname_target, cname_count));
 
     // Add the A record for www and it should be an answer
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_a_www);
     EXPECT_EQ(ResponseClassifier::ANSWERCNAME,
-        ResponseClassifier::classify(qu_in_a_www2, message_a));
+        ResponseClassifier::classify(qu_in_a_www2, message_a,
+                                     cname_target, cname_count));
 
     // Adding an unrelated TXT record should result in EXTRADATA
-    message_a->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    message_a.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
     EXPECT_EQ(ResponseClassifier::EXTRADATA,
-        ResponseClassifier::classify(qu_in_a_www2, message_a));
+        ResponseClassifier::classify(qu_in_a_www2, message_a,
+                                     cname_target, cname_count));
 
     // Recreate the chain, but this time end with a TXT RR and not the A
     // record.  This should return INVTYPE.
-    MessagePtr message_b(new Message(Message::RENDER));
-    message_b->setHeaderFlag(Message::HEADERFLAG_QR);
-    message_b->setOpcode(Opcode::QUERY());
-    message_b->setRcode(Rcode::NOERROR());
-    message_b->addQuestion(qu_in_a_www2);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_cname_www2);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
-    message_b->addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
+    Message message_b(Message::RENDER);
+    message_b.setHeaderFlag(Message::HEADERFLAG_QR);
+    message_b.setOpcode(Opcode::QUERY());
+    message_b.setRcode(Rcode::NOERROR());
+    message_b.addQuestion(qu_in_a_www2);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_cname_www2);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_cname_www1);
+    message_b.addRRset(Message::SECTION_ANSWER, rrs_in_txt_www);
 
     EXPECT_EQ(ResponseClassifier::INVTYPE,
-        ResponseClassifier::classify(qu_in_a_www2, message_b));
+        ResponseClassifier::classify(qu_in_a_www2, message_b,
+                                     cname_target, cname_count));
 }
 
 } // Anonymous namespace