Browse Source

[2218] Add a NSEC3Hash::create() variant to take hash params as args

Mukund Sivaraman 12 years ago
parent
commit
56a18b6b8d
3 changed files with 96 additions and 52 deletions
  1. 51 42
      src/lib/dns/nsec3hash.cc
  2. 18 9
      src/lib/dns/nsec3hash.h
  3. 27 1
      src/lib/dns/tests/nsec3hash_unittest.cc

+ 51 - 42
src/lib/dns/nsec3hash.cc

@@ -40,17 +40,6 @@ using namespace isc::dns::rdata;
 
 namespace {
 
-inline void
-iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
-            const uint8_t* salt, size_t saltlen,
-            uint8_t output[SHA1_HASHSIZE])
-{
-    SHA1Reset(ctx);
-    SHA1Input(ctx, input, inlength);
-    SHA1Input(ctx, salt, saltlen); // this works whether saltlen == or > 0
-    SHA1Result(ctx, output);
-}
-
 /// \brief A derived class of \c NSEC3Hash that implements the standard hash
 /// calculation specified in RFC5155.
 ///
@@ -70,12 +59,13 @@ public:
     NSEC3HashRFC5155(uint8_t algorithm, uint16_t iterations,
                      const vector<uint8_t>& salt) :
         algorithm_(algorithm), iterations_(iterations),
-        salt_(salt)
+        salt_(salt), digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
     {
         if (algorithm_ != NSEC3_HASH_SHA1) {
             isc_throw(UnknownNSEC3HashAlgorithm, "Unknown NSEC3 algorithm: " <<
                       static_cast<unsigned int>(algorithm_));
         }
+        SHA1Reset(&sha1_ctx_);
     }
 
     virtual std::string calculate(const Name& name) const;
@@ -89,11 +79,47 @@ private:
     const uint8_t algorithm_;
     const uint16_t iterations_;
     const vector<uint8_t> salt_;
+
+    // The following members are placeholder of work place and don't hold
+    // any state over multiple calls so can be mutable without breaking
+    // constness.
+    mutable SHA1Context sha1_ctx_;
+    mutable vector<uint8_t> digest_;
+    mutable OutputBuffer obuf_;
 };
 
+inline void
+iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
+            const uint8_t* salt, size_t saltlen,
+            uint8_t output[SHA1_HASHSIZE])
+{
+    SHA1Reset(ctx);
+    SHA1Input(ctx, input, inlength);
+    SHA1Input(ctx, salt, saltlen); // this works whether saltlen == or > 0
+    SHA1Result(ctx, output);
+}
+
 string
 NSEC3HashRFC5155::calculate(const Name& name) const {
-    return (NSEC3Hash::calculate(name, iterations_, &salt_[0], salt_.size()));
+    // We first need to normalize the name by converting all upper case
+    // characters in the labels to lower ones.
+    obuf_.clear();
+    Name name_copy(name);
+    name_copy.downcase();
+    name_copy.toWire(obuf_);
+
+    const uint8_t saltlen = salt_.size();
+    const uint8_t* const salt = (saltlen > 0) ? &salt_[0] : NULL;
+    uint8_t* const digest = &digest_[0];
+    assert(digest_.size() == SHA1_HASHSIZE);
+
+    iterateSHA1(&sha1_ctx_, static_cast<const uint8_t*>(obuf_.getData()),
+                obuf_.getLength(), salt, saltlen, digest);
+    for (unsigned int n = 0; n < iterations_; ++n) {
+        iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE, salt, saltlen, digest);
+    }
+
+    return (encodeBase32Hex(digest_));
 }
 
 bool
@@ -149,6 +175,12 @@ NSEC3Hash::create(const generic::NSEC3& nsec3) {
 }
 
 NSEC3Hash*
+NSEC3Hash::create(uint8_t algorithm, uint16_t iterations,
+                  const vector<uint8_t>& salt) {
+    return (getNSEC3HashCreator()->create(algorithm, iterations, salt));
+}
+
+NSEC3Hash*
 DefaultNSEC3HashCreator::create(const generic::NSEC3PARAM& param) const {
     return (new NSEC3HashRFC5155(param.getHashalg(), param.getIterations(),
                                  param.getSalt()));
@@ -160,39 +192,16 @@ DefaultNSEC3HashCreator::create(const generic::NSEC3& nsec3) const {
                                  nsec3.getSalt()));
 }
 
+NSEC3Hash*
+DefaultNSEC3HashCreator::create(uint8_t algorithm, uint16_t iterations,
+                                const vector<uint8_t>& salt) const {
+    return (new NSEC3HashRFC5155(algorithm, iterations, salt));
+}
+
 void
 setNSEC3HashCreator(const NSEC3HashCreator* new_creator) {
     creator = new_creator;
 }
 
-std::string
-NSEC3Hash::calculate(const Name& name,
-                     const uint16_t iterations,
-                     const uint8_t* salt,
-                     size_t salt_len)
-{
-    // We first need to normalize the name by converting all upper case
-    // characters in the labels to lower ones.
-    OutputBuffer obuf(Name::MAX_WIRE);
-    Name name_copy(name);
-    name_copy.downcase();
-    name_copy.toWire(obuf);
-
-    const uint8_t* const salt_buf = (salt_len > 0) ? salt : NULL;
-    std::vector<uint8_t> digest(SHA1_HASHSIZE);
-    uint8_t* const digest_buf = &digest[0];
-
-    SHA1Context sha1_ctx;
-    iterateSHA1(&sha1_ctx, static_cast<const uint8_t*>(obuf.getData()),
-                obuf.getLength(), salt_buf, salt_len, digest_buf);
-    for (unsigned int n = 0; n < iterations; ++n) {
-        iterateSHA1(&sha1_ctx, digest_buf, SHA1_HASHSIZE,
-                    salt_buf, salt_len,
-                    digest_buf);
-    }
-
-    return (encodeBase32Hex(digest));
-}
-
 } // namespace dns
 } // namespace isc

+ 18 - 9
src/lib/dns/nsec3hash.h

@@ -16,6 +16,7 @@
 #define __NSEC3HASH_H 1
 
 #include <string>
+#include <vector>
 #include <stdint.h>
 #include <exceptions/exceptions.h>
 
@@ -115,6 +116,13 @@ public:
     /// for hash calculation from an NSEC3 RDATA object.
     static NSEC3Hash* create(const rdata::generic::NSEC3& nsec3);
 
+    /// \brief Factory method of NSECHash from args.
+    ///
+    /// This is similar to the other version, but uses the arguments
+    /// passed as the parameters for hash calculation.
+    static NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+			     const std::vector<uint8_t>& salt);
+
     /// \brief The destructor.
     virtual ~NSEC3Hash() {}
 
@@ -131,15 +139,6 @@ public:
     /// \return Base32hex-encoded string of the hash value.
     virtual std::string calculate(const Name& name) const = 0;
 
-    /// \brief Calculate the NSEC3 SHA-1 hash.
-    ///
-    /// This method calculates the NSEC3 hash value for the given
-    /// \c name and hash parameters. It assumes the SHA-1 algorithm.
-    static std::string calculate(const Name& name,
-                                 const uint16_t iterations,
-                                 const uint8_t* salt,
-                                 size_t salt_len);
-
     /// \brief Match given NSEC3 parameters with that of the hash.
     ///
     /// This method compares NSEC3 parameters used for hash calculation
@@ -219,6 +218,14 @@ public:
     /// <code>NSEC3Hash::create(const rdata::generic::NSEC3& param)</code>
     virtual NSEC3Hash* create(const rdata::generic::NSEC3& nsec3)
         const = 0;
+
+    /// \brief Factory method of NSECHash from args.
+    ///
+    /// See
+    /// <code>NSEC3Hash::create(const rdata::generic::NSEC3& param)</code>
+    virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+			      const std::vector<uint8_t>& salt)
+        const = 0;
 };
 
 /// \brief The default NSEC3Hash creator.
@@ -234,6 +241,8 @@ class DefaultNSEC3HashCreator : public NSEC3HashCreator {
 public:
     virtual NSEC3Hash* create(const rdata::generic::NSEC3PARAM& param) const;
     virtual NSEC3Hash* create(const rdata::generic::NSEC3& nsec3) const;
+    virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+			      const std::vector<uint8_t>& salt) const;
 };
 
 /// \brief The registrar of \c NSEC3HashCreator.

+ 27 - 1
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -20,11 +20,14 @@
 
 #include <dns/nsec3hash.h>
 #include <dns/rdataclass.h>
+#include <util/encode/hex.h>
 
 using boost::scoped_ptr;
 using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::rdata;
+using namespace isc::util;
+using namespace isc::util::encode;
 
 namespace {
 typedef scoped_ptr<NSEC3Hash> NSEC3HashPtr;
@@ -39,7 +42,12 @@ protected:
         test_hash_nsec3(NSEC3Hash::create(generic::NSEC3
                                           ("1 0 12 aabbccdd " +
                                            string(nsec3_common))))
-    {}
+    {
+        std::string salt_hex("aabbccdd");
+        std::vector<uint8_t> salt;
+        decodeHex(salt_hex, salt);
+        test_hash_args.reset(NSEC3Hash::create(1, 12, salt));
+    }
 
     ~NSEC3HashTest() {
         // Make sure we reset the hash creator to the default
@@ -53,6 +61,9 @@ protected:
 
     // Similar to test_hash, but created from NSEC3 RR.
     NSEC3HashPtr test_hash_nsec3;
+
+    // Similar to test_hash, but created from passed args.
+    NSEC3HashPtr test_hash_args;
 };
 
 TEST_F(NSEC3HashTest, unknownAlgorithm) {
@@ -65,6 +76,12 @@ TEST_F(NSEC3HashTest, unknownAlgorithm) {
                          generic::NSEC3("2 0 12 aabbccdd " +
                                         string(nsec3_common)))),
                      UnknownNSEC3HashAlgorithm);
+
+    std::string salt_hex("aabbccdd");
+    std::vector<uint8_t> salt;
+    decodeHex(salt_hex, salt);
+    EXPECT_THROW(NSEC3HashPtr(NSEC3Hash::create(2, 12, salt)),
+                 UnknownNSEC3HashAlgorithm);
 }
 
 // Common checks for NSEC3 hash calculation
@@ -90,6 +107,10 @@ TEST_F(NSEC3HashTest, calculate) {
         SCOPED_TRACE("calculate check with NSEC3 based hash");
         calculateCheck(*test_hash_nsec3);
     }
+    {
+        SCOPED_TRACE("calculate check with args based hash");
+        calculateCheck(*test_hash_args);
+    }
 
     // Some boundary cases: 0-iteration and empty salt.  Borrowed from the
     // .com zone data.
@@ -177,6 +198,11 @@ public:
         }
         return (new TestNSEC3Hash);
     }
+    virtual NSEC3Hash* create(uint8_t, uint16_t,
+                              const std::vector<uint8_t>&) const {
+        isc_throw(isc::Unexpected,
+                  "This method is not implemented here.");
+    }
 private:
     DefaultNSEC3HashCreator default_creator_;
 };