Browse Source

Merge branch 'trac1894'

Conflicts:
	src/lib/datasrc/database.cc
Mukund Sivaraman 11 years ago
parent
commit
fe7ae0a7b3

+ 8 - 0
src/bin/auth/tests/query_unittest.cc

@@ -22,6 +22,7 @@
 #include <dns/message.h>
 #include <dns/master_loader.h>
 #include <dns/name.h>
+#include <dns/labelsequence.h>
 #include <dns/nsec3hash.h>
 #include <dns/opcode.h>
 #include <dns/rcode.h>
@@ -245,6 +246,13 @@ public:
         isc_throw(isc::Unexpected, "unexpected name for NSEC3 test: "
                   << name);
     }
+    virtual string calculate(const LabelSequence& ls) const {
+        assert(ls.isAbsolute());
+        // This is not very optimal, but it's only going to be used in
+        // tests.
+        const Name name(ls.toText());
+        return (calculate(name));
+    }
     virtual bool match(const rdata::generic::NSEC3PARAM&) const {
         return (true);
     }

+ 7 - 4
src/lib/datasrc/database.cc

@@ -23,6 +23,7 @@
 
 #include <exceptions/exceptions.h>
 #include <dns/name.h>
+#include <dns/labelsequence.h>
 #include <dns/rrclass.h>
 #include <dns/rrttl.h>
 #include <dns/rrset.h>
@@ -1108,12 +1109,14 @@ DatabaseClient::Finder::findNSEC3(const Name& name, bool recursive) {
     // This will be set to the one covering the query name
     ConstRRsetPtr covering_proof;
 
+    LabelSequence name_ls(name);
     // We keep stripping the leftmost label until we find something.
     // In case it is recursive, we'll exit the loop at the first iteration.
-    for (unsigned labels = qlabels; labels >= olabels; --labels) {
-        const string hash(calculator->calculate(labels == qlabels ? name :
-                                                name.split(qlabels - labels,
-                                                           labels)));
+    for (unsigned int labels = qlabels; labels >= olabels;
+         --labels, name_ls.stripLeft(1))
+    {
+        const std::string hash = calculator->calculate(name_ls);
+
         // Get the exact match for the name.
         LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_DATABASE_FINDNSEC3_TRYHASH).
             arg(name).arg(labels).arg(hash);

+ 5 - 5
src/lib/datasrc/memory/zone_finder.cc

@@ -903,7 +903,7 @@ InMemoryZoneFinder::findNSEC3(const isc::dns::Name& name, bool recursive) {
     uint8_t labels_buf[LabelSequence::MAX_SERIALIZED_LENGTH];
     const LabelSequence origin_ls(zone_data_.getOriginNode()->
                                   getAbsoluteLabels(labels_buf));
-    const LabelSequence name_ls(name);
+    LabelSequence name_ls(name);
 
     if (!zone_data_.isNSEC3Signed()) {
         isc_throw(DataSourceError,
@@ -959,10 +959,10 @@ InMemoryZoneFinder::findNSEC3(const isc::dns::Name& name, bool recursive) {
     // Examine all names from the query name to the origin name, stripping
     // the deepest label one by one, until we find a name that has a matching
     // NSEC3 hash.
-    for (unsigned int labels = qlabels; labels >= olabels; --labels) {
-        const Name& hname = (labels == qlabels ?
-                             name : name.split(qlabels - labels, labels));
-        const std::string hlabel = hash->calculate(hname);
+    for (unsigned int labels = qlabels; labels >= olabels;
+         --labels, name_ls.stripLeft(1))
+    {
+        const std::string hlabel = hash->calculate(name_ls);
 
         LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEMORY_FINDNSEC3_TRYHASH).
             arg(name).arg(labels).arg(hlabel);

+ 8 - 0
src/lib/datasrc/tests/faked_nsec3.cc

@@ -15,6 +15,7 @@
 #include "faked_nsec3.h"
 
 #include <dns/name.h>
+#include <dns/labelsequence.h>
 #include <testutils/dnsmessage_test.h>
 
 #include <map>
@@ -87,6 +88,13 @@ public:
         isc_throw(isc::Unexpected, "unexpected name for NSEC3 test: "
                   << name);
     }
+    virtual string calculate(const LabelSequence& ls) const {
+        assert(ls.isAbsolute());
+        // This is not very optimal, but it's only going to be used in
+        // tests.
+        const Name name(ls.toText());
+        return (calculate(name));
+    }
     virtual bool match(const rdata::generic::NSEC3PARAM&) const {
         return (true);
     }

+ 45 - 7
src/lib/dns/nsec3hash.cc

@@ -29,8 +29,10 @@
 #include <util/hash/sha1.h>
 
 #include <dns/name.h>
+#include <dns/labelsequence.h>
 #include <dns/nsec3hash.h>
 #include <dns/rdataclass.h>
+#include <dns/name_internal.h>
 
 using namespace std;
 using namespace isc::util;
@@ -84,6 +86,7 @@ public:
     }
 
     virtual std::string calculate(const Name& name) const;
+    virtual std::string calculate(const LabelSequence& ls) const;
 
     virtual bool match(const generic::NSEC3& nsec3) const;
     virtual bool match(const generic::NSEC3PARAM& nsec3param) const;
@@ -91,6 +94,8 @@ public:
                const vector<uint8_t>& salt) const;
 
 private:
+    std::string calculateForWiredata(const uint8_t* data, size_t length) const;
+
     const uint8_t algorithm_;
     const uint16_t iterations_;
     uint8_t* salt_data_;
@@ -116,19 +121,33 @@ iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
 }
 
 string
-NSEC3HashRFC5155::calculate(const Name& name) const {
+NSEC3HashRFC5155::calculateForWiredata(const uint8_t* data,
+                                       size_t length) const
+{
     // We first need to normalize the name by converting all upper case
     // characters in the labels to lower ones.
-    obuf_.clear();
-    Name name_copy(name);
-    name_copy.downcase();
-    name_copy.toWire(obuf_);
+
+    uint8_t name_buf[256];
+    assert(length < sizeof (name_buf));
+
+    const uint8_t *p1 = data;
+    uint8_t *p2 = name_buf;
+    while (*p1 != 0) {
+        char len = *p1;
+
+        *p2++ = *p1++;
+        while (len--) {
+            *p2++ = isc::dns::name::internal::maptolower[*p1++];
+        }
+    }
+
+    *p2 = *p1;
 
     uint8_t* const digest = &digest_[0];
     assert(digest_.size() == SHA1_HASHSIZE);
 
-    iterateSHA1(&sha1_ctx_, static_cast<const uint8_t*>(obuf_.getData()),
-                obuf_.getLength(), salt_data_, salt_length_, digest);
+    iterateSHA1(&sha1_ctx_, name_buf, length,
+                salt_data_, salt_length_, digest);
     for (unsigned int n = 0; n < iterations_; ++n) {
         iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE,
                     salt_data_, salt_length_, digest);
@@ -137,6 +156,25 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
     return (encodeBase32Hex(digest_));
 }
 
+string
+NSEC3HashRFC5155::calculate(const Name& name) const {
+    obuf_.clear();
+    name.toWire(obuf_);
+
+    return (calculateForWiredata(static_cast<const uint8_t*>(obuf_.getData()),
+                                 obuf_.getLength()));
+}
+
+string
+NSEC3HashRFC5155::calculate(const LabelSequence& ls) const {
+    assert(ls.isAbsolute());
+
+    size_t length;
+    const uint8_t* data = ls.getData(&length);
+
+    return (calculateForWiredata(data, length));
+}
+
 bool
 NSEC3HashRFC5155::match(uint8_t algorithm, uint16_t iterations,
                         const vector<uint8_t>& salt) const

+ 17 - 2
src/lib/dns/nsec3hash.h

@@ -23,6 +23,7 @@
 namespace isc {
 namespace dns {
 class Name;
+class LabelSequence;
 
 namespace rdata {
 namespace generic {
@@ -129,19 +130,33 @@ public:
     /// \brief The destructor.
     virtual ~NSEC3Hash() {}
 
-    /// \brief Calculate the NSEC3 hash.
+    /// \brief Calculate the NSEC3 hash (Name variant).
     ///
     /// This method calculates the NSEC3 hash value for the given \c name
     /// with the hash parameters (algorithm, iterations and salt) given at
     /// construction, and returns the value as a base32hex-encoded string
     /// (without containing any white spaces).  All US-ASCII letters in the
-    /// string will be upper cased.
+    /// string will be lower cased.
     ///
     /// \param name The domain name for which the hash value is to be
     /// calculated.
     /// \return Base32hex-encoded string of the hash value.
     virtual std::string calculate(const Name& name) const = 0;
 
+    /// \brief Calculate the NSEC3 hash (LabelSequence variant).
+    ///
+    /// This method calculates the NSEC3 hash value for the given
+    /// absolute LabelSequence \c ls with the hash parameters
+    /// (algorithm, iterations and salt) given at construction, and
+    /// returns the value as a base32hex-encoded string (without
+    /// containing any white spaces).  All US-ASCII letters in the
+    /// string will be lower cased.
+    ///
+    /// \param ls The absolute label sequence for which the hash value
+    /// is to be calculated.
+    /// \return Base32hex-encoded string of the hash value.
+    virtual std::string calculate(const LabelSequence& ls) const = 0;
+
     /// \brief Match given NSEC3 parameters with that of the hash.
     ///
     /// This method compares NSEC3 parameters used for hash calculation

+ 32 - 1
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -19,6 +19,7 @@
 #include <boost/scoped_ptr.hpp>
 
 #include <dns/nsec3hash.h>
+#include <dns/labelsequence.h>
 #include <dns/rdataclass.h>
 #include <util/encode/hex.h>
 
@@ -92,6 +93,18 @@ calculateCheck(NSEC3Hash& hash) {
     // Check case-insensitiveness
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
               hash.calculate(Name("EXAMPLE")));
+
+    // Repeat for the LabelSequence variant.
+
+    // A couple of normal cases from the RFC5155 example.
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              hash.calculate(LabelSequence(Name("example"))));
+    EXPECT_EQ("35MTHGPGCU1QG68FAB165KLNSNK3DPVL",
+              hash.calculate(LabelSequence(Name("a.example"))));
+
+    // Check case-insensitiveness
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              hash.calculate(LabelSequence(Name("EXAMPLE"))));
 }
 
 TEST_F(NSEC3HashTest, calculate) {
@@ -113,13 +126,16 @@ TEST_F(NSEC3HashTest, calculate) {
     EXPECT_EQ("CK0POJMG874LJREF7EFN8430QVIT8BSM",
               NSEC3HashPtr(NSEC3Hash::create(generic::NSEC3PARAM("1 0 0 -")))
               ->calculate(Name("com")));
+    EXPECT_EQ("CK0POJMG874LJREF7EFN8430QVIT8BSM",
+              NSEC3HashPtr(NSEC3Hash::create(generic::NSEC3PARAM("1 0 0 -")))
+              ->calculate(LabelSequence(Name("com"))));
 
     // Using unusually large iterations, something larger than the 8-bit range.
     // (expected hash value generated by BIND 9's dnssec-signzone)
     EXPECT_EQ("COG6A52MJ96MNMV3QUCAGGCO0RHCC2Q3",
               NSEC3HashPtr(NSEC3Hash::create(
                                generic::NSEC3PARAM("1 0 256 AABBCCDD")))
-              ->calculate(Name("example.org")));
+              ->calculate(LabelSequence(Name("example.org"))));
 }
 
 // Common checks for match cases
@@ -169,6 +185,9 @@ class TestNSEC3Hash : public NSEC3Hash {
     virtual string calculate(const Name&) const {
         return ("00000000000000000000000000000000");
     }
+    virtual string calculate(const LabelSequence&) const {
+        return ("00000000000000000000000000000000");
+    }
     virtual bool match(const generic::NSEC3PARAM&) const {
         return (true);
     }
@@ -207,6 +226,8 @@ TEST_F(NSEC3HashTest, setCreator) {
     // Re-check an existing case using the default creator/hash implementation
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              test_hash->calculate(LabelSequence(Name("example"))));
 
     // Replace the creator, and confirm the hash values are faked
     TestNSEC3HashCreator test_creator;
@@ -215,12 +236,16 @@ TEST_F(NSEC3HashTest, setCreator) {
     test_hash.reset(NSEC3Hash::create(generic::NSEC3PARAM("1 0 12 aabbccdd")));
     EXPECT_EQ("00000000000000000000000000000000",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("00000000000000000000000000000000",
+              test_hash->calculate(LabelSequence(Name("example"))));
     // Same for hash from NSEC3 RDATA
     test_hash.reset(NSEC3Hash::create(generic::NSEC3
                                       ("1 0 12 aabbccdd " +
                                        string(nsec3_common))));
     EXPECT_EQ("00000000000000000000000000000000",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("00000000000000000000000000000000",
+              test_hash->calculate(LabelSequence(Name("example"))));
 
     // If we set a special flag big (0x80) on creation, it will act like the
     // default creator.
@@ -228,17 +253,23 @@ TEST_F(NSEC3HashTest, setCreator) {
                                           "1 128 12 aabbccdd")));
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              test_hash->calculate(LabelSequence(Name("example"))));
     test_hash.reset(NSEC3Hash::create(generic::NSEC3
                                       ("1 128 12 aabbccdd " +
                                        string(nsec3_common))));
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              test_hash->calculate(LabelSequence(Name("example"))));
 
     // Reset the creator to default, and confirm that
     setNSEC3HashCreator(NULL);
     test_hash.reset(NSEC3Hash::create(generic::NSEC3PARAM("1 0 12 aabbccdd")));
     EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
               test_hash->calculate(Name("example")));
+    EXPECT_EQ("0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+              test_hash->calculate(LabelSequence(Name("example"))));
 }
 
 } // end namespace