Browse Source

commit the lastest change for message entry

zhanglikun 14 years ago
parent
commit
42beef26ba

+ 1 - 0
src/lib/cache/TODO

@@ -3,4 +3,5 @@
 * Once LRU hash table is implemented, it should be used by message/rrset cache.  
 * Once the hash/lrulist related files in /lib/nsas is moved to seperated
   folder, the code of recursor cache has to be updated.
+* Set proper AD flags once DNSSEC is supported by the cache.
 

+ 2 - 1
src/lib/cache/message_cache.h

@@ -20,6 +20,7 @@
 #include <map>
 #include <string>
 #include <boost/shared_ptr.hpp>
+#include <boost/noncopyable.hpp>
 #include <dns/message.h>
 #include "message_entry.h"
 #include <nsas/hash_table.h>
@@ -34,7 +35,7 @@ class RRsetCache;
 /// The object of MessageCache represents the cache for class-specific 
 /// messages.
 ///
-class MessageCache {
+class MessageCache: public boost::noncopyable {
 public:
     /// \param cache_size The size of message cache.
     MessageCache(boost::shared_ptr<RRsetCache> rrset_cache_,

+ 63 - 38
src/lib/cache/message_entry.cc

@@ -17,7 +17,6 @@
 #include <limits>
 #include <dns/message.h>
 #include <nsas/nsas_entry.h>
-#include <nsas/fetchable.h>
 #include "message_entry.h"
 #include "rrset_cache.h"
 
@@ -33,14 +32,53 @@ MessageEntry::MessageEntry(const isc::dns::Message& msg,
                            boost::shared_ptr<RRsetCache> rrset_cache):
     rrset_cache_(rrset_cache),
     headerflag_aa_(false),
-    headerflag_tc_(false),
-    headerflag_ad_(false)
+    headerflag_tc_(false)
 {
     initMessageEntry(msg);
     entry_name_ = genCacheEntryName(query_name_, query_type_);
+    hash_key_ptr_ = new HashKey(entry_name_, RRClass(query_class_));
 }
     
 bool
+MessageEntry::getRRsetEntries(vector<RRsetEntryPtr>& rrset_entry_vec, 
+                              const time_t time_now)
+{
+    uint16_t entry_count = answer_count_ + authority_count_ + additional_count_;
+    for (int index = 0; index < entry_count; ++index) {
+        RRsetEntryPtr rrset_entry = rrset_cache_->lookup(rrsets_[index].name_, 
+                                                        rrsets_[index].type_);
+        if (time_now < rrset_entry->getExpireTime()) {
+            rrset_entry_vec.push_back(rrset_entry);
+        } else {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+void
+MessageEntry::addRRset(isc::dns::Message& message,
+                       const vector<RRsetEntryPtr> rrset_entry_vec,
+                       isc::dns::Message::Section section,
+                       bool dnssec_need) {
+    uint16_t start_index = 0;
+    uint16_t end_index = answer_count_;
+
+    if (section == Message::SECTION_AUTHORITY) {
+        start_index = answer_count_;
+        end_index = answer_count_ + authority_count_;
+    } else if (section == Message::SECTION_ADDITIONAL) {
+        start_index = answer_count_ + authority_count_;
+        end_index = start_index + additional_count_;
+    }
+
+    for(uint16_t index = start_index; index < end_index; ++index) {
+        message.addRRset(section, rrset_entry_vec[index]->getRRset(), dnssec_need);
+    }
+}
+
+bool
 MessageEntry::genMessage(const time_t& time_now,
                          isc::dns::Message& msg)
 {
@@ -48,31 +86,24 @@ MessageEntry::genMessage(const time_t& time_now,
         // The message entry has expired.
         return false;
     } else {
-        // We don't need to add question section, since it has 
-        // been included in the message.
-        ConstEDNSPtr edns(msg.getEDNS());
-        bool dnssec_need = edns;
-        uint16_t index = 0;
-        // Add answer section's rrsets.
-        for(index = 0; index < answer_count_; index++) {
-            msg.addRRset(Message::SECTION_ANSWER, 
-                         rrsets_[index]->getRRset(), dnssec_need);
-        }
-        
-        // Add authority section's rrsets.
-        uint16_t end = answer_count_ + authority_count_;
-        for(index = answer_count_; index < end; index++) {
-            msg.addRRset(Message::SECTION_AUTHORITY, 
-                         rrsets_[index]->getRRset(), dnssec_need);
+        // Before do any generation, we should check if some rrset
+        // has expired, if it is, return false.
+        vector<RRsetEntryPtr> rrset_entry_vec;
+        if (false == getRRsetEntries(rrset_entry_vec, time_now)) {
+            return false;
         }
 
-        // Add additional section's rrsets.
-        index = end;
-        end = end + additional_count_;
-        for(; index < end; index++) {
-            msg.addRRset(Message::SECTION_ADDITIONAL, 
-                         rrsets_[index]->getRRset(), dnssec_need);
-        }
+        // Begin message generation. We don't need to add question 
+        // section, since it has been included in the message.
+        // Set cached header flags.
+        msg.setHeaderFlag(Message::HEADERFLAG_AA, headerflag_aa_);
+        msg.setHeaderFlag(Message::HEADERFLAG_TC, headerflag_tc_);
+
+        bool dnssec_need = msg.getEDNS().get();
+        addRRset(msg, rrset_entry_vec, Message::SECTION_ANSWER, dnssec_need);
+        addRRset(msg, rrset_entry_vec, Message::SECTION_AUTHORITY, dnssec_need);
+        addRRset(msg, rrset_entry_vec, Message::SECTION_ADDITIONAL, dnssec_need);
+
         return true;
     }
 }
@@ -148,7 +179,7 @@ MessageEntry::getRRsetTrustLevel(const Message& message,
 }
 
 void
-MessageEntry::parseRRset(const isc::dns::Message& msg,
+MessageEntry::parseSection(const isc::dns::Message& msg,
                          const Message::Section& section,
                          uint32_t& smaller_ttl, 
                          uint16_t& rrset_count)
@@ -164,8 +195,8 @@ MessageEntry::parseRRset(const isc::dns::Message& msg,
         RRsetPtr rrset_ptr = *iter;
         RRsetTrustLevel level = getRRsetTrustLevel(msg, rrset_ptr, section);
         RRsetEntryPtr rrset_entry = rrset_cache_->update(*rrset_ptr, level);
-        rrsets_.push_back(rrset_entry);
-        
+        rrsets_.push_back(RRsetRef(rrset_ptr->getName(), rrset_ptr->getType()));
+
         uint32_t rrset_ttl = rrset_entry->getTTL();
         if (smaller_ttl > rrset_ttl) {
             smaller_ttl = rrset_ttl;
@@ -182,7 +213,6 @@ MessageEntry::initMessageEntry(const isc::dns::Message& msg) {
     //TODO better way to cache the header flags?
     headerflag_aa_ = msg.getHeaderFlag(Message::HEADERFLAG_AA);
     headerflag_tc_ = msg.getHeaderFlag(Message::HEADERFLAG_TC);
-    headerflag_ad_ = msg.getHeaderFlag(Message::HEADERFLAG_AD);
 
     // We only cache the first question in question section.
     // TODO, do we need to support muptiple questions?
@@ -193,18 +223,13 @@ MessageEntry::initMessageEntry(const isc::dns::Message& msg) {
     query_class_ = (*iter)->getClass().getCode();
     
     uint32_t min_ttl = MAX_UINT32;
-    parseRRset(msg, Message::SECTION_ANSWER, min_ttl, answer_count_);
-    parseRRset(msg, Message::SECTION_AUTHORITY, min_ttl, authority_count_);
-    parseRRset(msg, Message::SECTION_ADDITIONAL, min_ttl, additional_count_);
+    parseSection(msg, Message::SECTION_ANSWER, min_ttl, answer_count_);
+    parseSection(msg, Message::SECTION_AUTHORITY, min_ttl, authority_count_);
+    parseSection(msg, Message::SECTION_ADDITIONAL, min_ttl, additional_count_);
 
     expire_time_ = time(NULL) + min_ttl;
 }
 
-HashKey
-MessageEntry::hashKey() const {
-    return HashKey(entry_name_, RRClass(query_class_));
-}
-
 } // namespace cache
 } // namespace isc
 

+ 46 - 9
src/lib/cache/message_entry.h

@@ -20,8 +20,8 @@
 #include <vector>
 #include <dns/message.h>
 #include <dns/rrset.h>
+#include <boost/noncopyable.hpp>
 #include <nsas/nsas_entry.h>
-#include <nsas/fetchable.h>
 #include "rrset_entry.h"
 
 
@@ -33,10 +33,23 @@ namespace cache {
 class RRsetEntry;
 class RRsetCache;
 
+/// \brief Information to refer an RRset.
+/// There is no class information here, since the rrsets
+/// are cached in the class-specific rrset cache.
+struct RRsetRef{
+    RRsetRef(const isc::dns::Name& name, const isc::dns::RRType& type):
+            name_(name), type_(type)
+    {}
+
+    isc::dns::Name name_; // Name of rrset.
+    isc::dns::RRType type_; // Type of rrset. 
+};
+
 /// \brief Message Entry
 /// The object of MessageEntry represents one response message
 /// answered to the recursor client. 
-class MessageEntry : public NsasEntry<MessageEntry>
+class MessageEntry : public NsasEntry<MessageEntry>,
+                     public boost::noncopyable
 {
 public:
 
@@ -63,7 +76,9 @@ public:
     
     /// \brief Get the hash key of the message entry.
     /// \return return hash key
-    virtual HashKey hashKey() const;
+    virtual HashKey hashKey() const {
+        return *hash_key_ptr_;
+    }
 
 protected:
     /// \brief Initialize the message entry with dns message.
@@ -79,7 +94,7 @@ protected:
     /// \param rrset_count set the rrset count of the section.
     /// (TODO for Message, getRRsetCount() should be one interface provided 
     //  by Message.)
-    void parseRRset(const isc::dns::Message& msg,
+    void parseSection(const isc::dns::Message& msg,
                     const isc::dns::Message::Section& section,
                     uint32_t& smaller_ttl,
                     uint16_t& rrset_count);
@@ -98,10 +113,36 @@ protected:
     RRsetTrustLevel getRRsetTrustLevel(const isc::dns::Message& message,
                                const isc::dns::RRsetPtr rrset,
                                const isc::dns::Message::Section& section);
+
+    /// \brief Add rrset to one section of message.
+    /// \param dnssec_need need dnssec records or not.
+    /// \param message The message to add rrsets.
+    /// \param rrset_entry_vec vector for rrset entries in
+    ///        different sections.
+    void addRRset(isc::dns::Message& message,
+                  const std::vector<RRsetEntryPtr> rrset_entry_vec,
+                  isc::dns::Message::Section section,
+                  bool dnssec_need);
+
+    /// \brief Get the all the rrset entries for the message entry.
+    /// \param rrset_entry_vec vector of rrset entries
+    /// \param time_now the time of now. Used to compare with rrset
+    ///        entry's expire time.
+    /// \return return false if any rrset entry has expired, or else,
+    ///         return false.
+    bool getRRsetEntries(std::vector<RRsetEntryPtr>& rrset_entry_vec, 
+                         const time_t time_now); 
     //@}
-private:
+protected:
+    /// \note Make the variable be protected for easy test.
     time_t expire_time_;  // Expiration time of the message.
+
+private:
     std::string entry_name_; // The name for this entry(name + type)
+    HashKey* hash_key_ptr_;  // the key for messag entry in hash table.
+
+    std::vector<RRsetRef> rrsets_;
+    boost::shared_ptr<RRsetCache> rrset_cache_;
 
     std::string query_name_; // query name of the message.
     uint16_t query_class_; // query class of the message.
@@ -112,13 +153,9 @@ private:
     uint16_t authority_count_; // rrset count in authority section.
     uint16_t additional_count_; // rrset count in addition section.
 
-    std::vector<boost::shared_ptr<RRsetEntry> > rrsets_;
-    boost::shared_ptr<RRsetCache> rrset_cache_;
-
     //TODO, there should be a better way to cache these header flags
     bool headerflag_aa_; // Whether AA bit is set.
     bool headerflag_tc_; // Whether TC bit is set.
-    bool headerflag_ad_; // Whether AD bit is set.
 };
     
 typedef boost::shared_ptr<MessageEntry> MessageEntryPtr;

+ 63 - 5
src/lib/cache/tests/message_entry_unittest.cc

@@ -43,12 +43,12 @@ public:
     {}
 
     /// \brief Wrap the protected function so that it can be tested.   
-    void parseRRsetForTest(const Message& msg,
+    void parseSectionForTest(const Message& msg,
                            const Message::Section& section,
                            uint32_t& smaller_ttl, 
                            uint16_t& rrset_count)
     {
-        parseRRset(msg, section, smaller_ttl, rrset_count);
+        parseSection(msg, section, smaller_ttl, rrset_count);
     }
 
     RRsetTrustLevel getRRsetTrustLevelForTest(const Message& message,
@@ -58,6 +58,14 @@ public:
         return getRRsetTrustLevel(message, rrset, section);
     }
 
+    bool getRRsetEntriesForTest(vector<RRsetEntryPtr> vec, time_t now) {
+        return getRRsetEntries(vec, now);
+    }
+
+    time_t getExpireTime() {
+        return expire_time_;
+    }
+
 };
 
 class MessageEntryTest: public testing::Test {
@@ -82,17 +90,17 @@ TEST_F(MessageEntryTest, testParseRRset) {
     DerivedMessageEntry message_entry(message_parse, rrset_cache_);
     uint32_t ttl = MAX_UINT32;
     uint16_t rrset_count = 0;
-    message_entry.parseRRsetForTest(message_parse, Message::SECTION_ANSWER, ttl, rrset_count);
+    message_entry.parseSectionForTest(message_parse, Message::SECTION_ANSWER, ttl, rrset_count);
     EXPECT_EQ(ttl, 21600);
     EXPECT_EQ(rrset_count, 1);
 
     ttl = MAX_UINT32;
-    message_entry.parseRRsetForTest(message_parse, Message::SECTION_AUTHORITY, ttl, rrset_count);
+    message_entry.parseSectionForTest(message_parse, Message::SECTION_AUTHORITY, ttl, rrset_count);
     EXPECT_EQ(ttl, 21600);
     EXPECT_EQ(rrset_count, 1);
 
     ttl = MAX_UINT32;
-    message_entry.parseRRsetForTest(message_parse, Message::SECTION_ADDITIONAL, ttl, rrset_count);
+    message_entry.parseSectionForTest(message_parse, Message::SECTION_ADDITIONAL, ttl, rrset_count);
     EXPECT_EQ(ttl, 10800);
     EXPECT_EQ(rrset_count, 5);
 }
@@ -181,10 +189,60 @@ TEST_F(MessageEntryTest, testGetRRsetTrustLevel_DNAME) {
     EXPECT_EQ(level, RRSET_TRUST_ANSWER_NONAA);
 }
 
+// We only test the expire_time of the message entry.
+// The test for genMessage() will make sure whether InitMessageEntry()
+// is right
 TEST_F(MessageEntryTest, testInitMessageEntry) {
     messageFromFile(message_parse, "message_fromWire3");
     DerivedMessageEntry message_entry(message_parse, rrset_cache_);
+    time_t expire_time = message_entry.getExpireTime();
+    // 1 second should be enough to do the compare
+    EXPECT_TRUE((time(NULL) + 10801) > expire_time);
+}
+
+TEST_F(MessageEntryTest, testGetRRsetEntries) {
+    messageFromFile(message_parse, "message_fromWire3");
+    DerivedMessageEntry message_entry(message_parse, rrset_cache_);
+    vector<RRsetEntryPtr> vec;
+    
+    // the time is bigger than the smallest expire time of 
+    // the rrset in message.
+    time_t expire_time = time(NULL) + 10802;
+    EXPECT_FALSE(message_entry.getRRsetEntriesForTest(vec, expire_time));
 }
 
+static int
+section_rrset_count(Message& msg, Message::Section section) {
+    int count = 0;
+    for (RRsetIterator rrset_iter = msg.beginSection(section);
+         rrset_iter != msg.endSection(section); 
+         ++rrset_iter) {
+        ++count;
+    }
+
+    return count;
+}
+
+TEST_F(MessageEntryTest, testGenMessage) {
+    messageFromFile(message_parse, "message_fromWire3");
+    DerivedMessageEntry message_entry(message_parse, rrset_cache_);
+    time_t expire_time = message_entry.getExpireTime();
+    
+    Message msg(Message::RENDER);
+    EXPECT_FALSE(message_entry.genMessage(expire_time + 2, msg));
+    message_entry.genMessage(time(NULL), msg);
+    // Check whether the generated message is same with cached one.
+    
+    EXPECT_TRUE(msg.getHeaderFlag(Message::HEADERFLAG_AA));
+    EXPECT_FALSE(msg.getHeaderFlag(Message::HEADERFLAG_TC));
+    EXPECT_EQ(1, section_rrset_count(msg, Message::SECTION_ANSWER)); 
+    EXPECT_EQ(1, section_rrset_count(msg, Message::SECTION_AUTHORITY)); 
+    EXPECT_EQ(5, section_rrset_count(msg, Message::SECTION_ADDITIONAL)); 
+
+    // Check the rrset in answer section.
+    EXPECT_EQ(1, msg.getRRCount(Message::SECTION_ANSWER));
+    EXPECT_EQ(5, msg.getRRCount(Message::SECTION_AUTHORITY));
+    EXPECT_EQ(7, msg.getRRCount(Message::SECTION_ADDITIONAL));
+}
 
 }   // namespace

BIN
src/lib/xfr/.libs/libxfr_python.a