Browse Source

[2218] Update NSEC3Hash::create() to accept uint8_t buffer instead of vector

Mukund Sivaraman 12 years ago
parent
commit
09ccd7c1ec

+ 3 - 3
src/lib/datasrc/tests/faked_nsec3.cc

@@ -105,9 +105,9 @@ NSEC3Hash* TestNSEC3HashCreator::create(const rdata::generic::NSEC3&) const {
     return (new TestNSEC3Hash);
 }
 
-NSEC3Hash*
-TestNSEC3HashCreator::create(uint8_t, uint16_t,
-                             const vector<uint8_t>&) const {
+NSEC3Hash* TestNSEC3HashCreator::create(uint8_t, uint16_t,
+                                        const uint8_t*, size_t) const
+{
     return (new TestNSEC3Hash);
 }
 

+ 2 - 2
src/lib/datasrc/tests/faked_nsec3.h

@@ -63,8 +63,8 @@ public:
         const;
     virtual isc::dns::NSEC3Hash* create(const isc::dns::rdata::generic::NSEC3&)
         const;
-    virtual isc::dns::NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
-					const std::vector<uint8_t>& salt)
+    virtual isc::dns::NSEC3Hash* create(uint8_t, uint16_t,
+                                        const uint8_t*, size_t)
         const;
 };
 

+ 39 - 15
src/lib/dns/nsec3hash.cc

@@ -57,17 +57,34 @@ private:
 
 public:
     NSEC3HashRFC5155(uint8_t algorithm, uint16_t iterations,
-                     const vector<uint8_t>& salt) :
+                     const uint8_t* salt_data, size_t salt_length) :
         algorithm_(algorithm), iterations_(iterations),
-        salt_(salt), digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
+        salt_data_(NULL), salt_length_(salt_length),
+        digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
     {
         if (algorithm_ != NSEC3_HASH_SHA1) {
             isc_throw(UnknownNSEC3HashAlgorithm, "Unknown NSEC3 algorithm: " <<
                       static_cast<unsigned int>(algorithm_));
         }
+
+        if (salt_length > 0) {
+            salt_data_ = static_cast<uint8_t*>(std::malloc(salt_length));
+            if (salt_data_ == NULL) {
+                throw std::bad_alloc();
+            }
+            std::memcpy(salt_data_, salt_data, salt_length);
+        }
+
         SHA1Reset(&sha1_ctx_);
     }
 
+    ~NSEC3HashRFC5155()
+    {
+        if (salt_data_ != NULL) {
+            free(salt_data_);
+        }
+    }
+
     virtual std::string calculate(const Name& name) const;
 
     virtual bool match(const generic::NSEC3& nsec3) const;
@@ -78,7 +95,8 @@ public:
 private:
     const uint8_t algorithm_;
     const uint16_t iterations_;
-    const vector<uint8_t> salt_;
+    uint8_t* salt_data_;
+    const size_t salt_length_;
 
     // The following members are placeholder of work place and don't hold
     // any state over multiple calls so can be mutable without breaking
@@ -108,15 +126,15 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
     name_copy.downcase();
     name_copy.toWire(obuf_);
 
-    const uint8_t saltlen = salt_.size();
-    const uint8_t* const salt = (saltlen > 0) ? &salt_[0] : NULL;
+    const uint8_t* const salt = (salt_length_ > 0) ? salt_data_ : NULL;
     uint8_t* const digest = &digest_[0];
     assert(digest_.size() == SHA1_HASHSIZE);
 
     iterateSHA1(&sha1_ctx_, static_cast<const uint8_t*>(obuf_.getData()),
-                obuf_.getLength(), salt, saltlen, digest);
+                obuf_.getLength(), salt, salt_length_, digest);
     for (unsigned int n = 0; n < iterations_; ++n) {
-        iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE, salt, saltlen, digest);
+        iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE,
+                    salt, salt_length_, digest);
     }
 
     return (encodeBase32Hex(digest_));
@@ -127,8 +145,9 @@ NSEC3HashRFC5155::match(uint8_t algorithm, uint16_t iterations,
                         const vector<uint8_t>& salt) const
 {
     return (algorithm_ == algorithm && iterations_ == iterations &&
-            salt_.size() == salt.size() &&
-            (salt_.empty() || memcmp(&salt_[0], &salt[0], salt_.size()) == 0));
+            salt_length_ == salt.size() &&
+            ((salt_length_ == 0) ||
+             memcmp(salt_data_, &salt[0], salt_length_) == 0));
 }
 
 bool
@@ -176,27 +195,32 @@ NSEC3Hash::create(const generic::NSEC3& nsec3) {
 
 NSEC3Hash*
 NSEC3Hash::create(uint8_t algorithm, uint16_t iterations,
-                  const vector<uint8_t>& salt) {
-    return (getNSEC3HashCreator()->create(algorithm, iterations, salt));
+                  const uint8_t* salt_data, size_t salt_length) {
+    return (getNSEC3HashCreator()->create(algorithm, iterations,
+                                          salt_data, salt_length));
 }
 
 NSEC3Hash*
 DefaultNSEC3HashCreator::create(const generic::NSEC3PARAM& param) const {
+    const vector<uint8_t>& salt = param.getSalt();
     return (new NSEC3HashRFC5155(param.getHashalg(), param.getIterations(),
-                                 param.getSalt()));
+                                 &salt[0], salt.size()));
 }
 
 NSEC3Hash*
 DefaultNSEC3HashCreator::create(const generic::NSEC3& nsec3) const {
+    const vector<uint8_t>& salt = nsec3.getSalt();
     return (new NSEC3HashRFC5155(nsec3.getHashalg(), nsec3.getIterations(),
-                                 nsec3.getSalt()));
+                                 &salt[0], salt.size()));
 }
 
 NSEC3Hash*
 DefaultNSEC3HashCreator::create(uint8_t algorithm, uint16_t iterations,
-                                const vector<uint8_t>& salt) const
+                                const uint8_t* salt_data,
+                                size_t salt_length) const
 {
-    return (new NSEC3HashRFC5155(algorithm, iterations, salt));
+    return (new NSEC3HashRFC5155(algorithm, iterations,
+                                 salt_data, salt_length));
 }
 
 void

+ 19 - 7
src/lib/dns/nsec3hash.h

@@ -118,10 +118,13 @@ public:
 
     /// \brief Factory method of NSECHash from args.
     ///
-    /// This is similar to the other version, but uses the arguments
-    /// passed as the parameters for hash calculation.
+    /// \param algorithm the NSEC3 algorithm to use; currently only 1
+    ///                  (SHA-1) is supported
+    /// \param iterations the number of iterations
+    /// \param salt_data the salt data as a byte array
+    /// \param salt_data_length the length of the salt data
     static NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
-			     const std::vector<uint8_t>& salt);
+                             const uint8_t* salt_data, size_t salt_length);
 
     /// \brief The destructor.
     virtual ~NSEC3Hash() {}
@@ -175,7 +178,7 @@ public:
 /// would be an experimental extension for a newer hash algorithm or
 /// implementation.
 ///
-/// The two main methods named \c create() correspond to the static factory
+/// The three main methods named \c create() correspond to the static factory
 /// methods of \c NSEC3Hash of the same name.
 ///
 /// By default, the library uses the \c DefaultNSEC3HashCreator creator.
@@ -222,9 +225,17 @@ public:
     /// \brief Factory method of NSECHash from args.
     ///
     /// See
-    /// <code>NSEC3Hash::create(const rdata::generic::NSEC3& param)</code>
+    /// <code>NSEC3Hash::create(uint8_t algorithm, uint16_t iterations,
+    ///                         const uint8_t* salt_data,
+    ///                         size_t salt_length)</code>
+    ///
+    /// \param algorithm the NSEC3 algorithm to use; currently only 1
+    ///                  (SHA-1) is supported
+    /// \param iterations the number of iterations
+    /// \param salt_data the salt data as a byte array
+    /// \param salt_data_length the length of the salt data
     virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
-			      const std::vector<uint8_t>& salt)
+                              const uint8_t* salt_data, size_t salt_length)
         const = 0;
 };
 
@@ -242,7 +253,8 @@ public:
     virtual NSEC3Hash* create(const rdata::generic::NSEC3PARAM& param) const;
     virtual NSEC3Hash* create(const rdata::generic::NSEC3& nsec3) const;
     virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
-			      const std::vector<uint8_t>& salt) const;
+                              const uint8_t* salt_data,
+                              size_t salt_length) const;
 };
 
 /// \brief The registrar of \c NSEC3HashCreator.

+ 5 - 9
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -43,10 +43,8 @@ protected:
                                           ("1 0 12 aabbccdd " +
                                            string(nsec3_common))))
     {
-        std::string salt_hex("aabbccdd");
-        std::vector<uint8_t> salt;
-        decodeHex(salt_hex, salt);
-        test_hash_args.reset(NSEC3Hash::create(1, 12, salt));
+        const uint8_t salt[] = {0xaa, 0xbb, 0xcc, 0xdd};
+        test_hash_args.reset(NSEC3Hash::create(1, 12, salt, sizeof(salt)));
     }
 
     ~NSEC3HashTest() {
@@ -77,10 +75,8 @@ TEST_F(NSEC3HashTest, unknownAlgorithm) {
                                         string(nsec3_common)))),
                      UnknownNSEC3HashAlgorithm);
 
-    std::string salt_hex("aabbccdd");
-    std::vector<uint8_t> salt;
-    decodeHex(salt_hex, salt);
-    EXPECT_THROW(NSEC3HashPtr(NSEC3Hash::create(2, 12, salt)),
+    const uint8_t salt[] = {0xaa, 0xbb, 0xcc, 0xdd};
+    EXPECT_THROW(NSEC3HashPtr(NSEC3Hash::create(2, 12, salt, sizeof(salt))),
                  UnknownNSEC3HashAlgorithm);
 }
 
@@ -199,7 +195,7 @@ public:
         return (new TestNSEC3Hash);
     }
     virtual NSEC3Hash* create(uint8_t, uint16_t,
-                              const std::vector<uint8_t>&) const {
+                              const uint8_t*, size_t) const {
         isc_throw(isc::Unexpected,
                   "This method is not implemented here.");
     }