Parcourir la source

[2282merge] Merge branch 'trac2218' into trac2282merge

with fixing Conflicts:
	src/lib/datasrc/memory/zone_finder.cc
	src/lib/datasrc/memory/zone_finder.h
JINMEI Tatuya il y a 12 ans
Parent
commit
8660a77ac7

+ 63 - 37
src/lib/datasrc/memory/domaintree.h

@@ -373,7 +373,6 @@ private:
         }
     }
 
-public:
     /// \brief returns if the node is a subtree's root node
     ///
     /// This method takes a node and returns \c true if it is the root
@@ -391,6 +390,7 @@ public:
     /// This method never throws an exception.
     const DomainTreeNode<T>* getSubTreeRoot() const;
 
+public:
     /// \brief returns the parent of the root of its subtree
     ///
     /// This method takes a node and returns the parent of the root of
@@ -401,14 +401,6 @@ public:
     /// This method never throws an exception.
     const DomainTreeNode<T>* getUpperNode() const;
 
-    /// \brief returns the largest node of this node's subtree
-    ///
-    /// This method takes a node and returns the largest node in its
-    /// subtree.
-    ///
-    /// This method never throws an exception.
-    const DomainTreeNode<T>* getLargestInSubTree() const;
-
     /// \brief return the next node which is bigger than current node
     /// in the same subtree
     ///
@@ -578,17 +570,6 @@ DomainTreeNode<T>::getUpperNode() const {
 }
 
 template <typename T>
-const DomainTreeNode<T>*
-DomainTreeNode<T>::getLargestInSubTree() const {
-    const DomainTreeNode<T>* sroot = getSubTreeRoot();
-    while (sroot->getRight() != NULL) {
-        sroot = sroot->getRight();
-    }
-
-    return (sroot);
-}
-
-template <typename T>
 isc::dns::LabelSequence
 DomainTreeNode<T>::getAbsoluteLabels(
     uint8_t buf[isc::dns::LabelSequence::MAX_SERIALIZED_LENGTH]) const
@@ -720,14 +701,26 @@ public:
     /// The default constructor.
     ///
     /// \exception None
-    DomainTreeNodeChain() : node_count_(0), last_compared_(NULL),
+    DomainTreeNodeChain() : level_count_(0), last_compared_(NULL),
                         // XXX: meaningless initial values:
                         last_comparison_(0, 0,
                                          isc::dns::NameComparisonResult::EQUAL)
     {}
 
+    /// \brief Copy constructor.
+    ///
+    /// \exception None
+    DomainTreeNodeChain(const DomainTreeNodeChain<T>& other) :
+        level_count_(other.level_count_),
+        last_compared_(other.last_compared_),
+        last_comparison_(other.last_comparison_)
+    {
+        for (size_t i = 0; i < level_count_; i++) {
+	    nodes_[i] = other.nodes_[i];
+        }
+    }
+
 private:
-    DomainTreeNodeChain(const DomainTreeNodeChain<T>&);
     DomainTreeNodeChain<T>& operator=(const DomainTreeNodeChain<T>&);
     //@}
 
@@ -739,7 +732,7 @@ public:
     ///
     /// \exception None
     void clear() {
-        node_count_ = 0;
+        level_count_ = 0;
         last_compared_ = NULL;
     }
 
@@ -780,7 +773,7 @@ public:
     /// chain, 0 will be returned.
     ///
     /// \exception None
-    unsigned int getLevelCount() const { return (node_count_); }
+    size_t getLevelCount() const { return (level_count_); }
 
     /// \brief return the absolute name for the node which this
     /// \c DomainTreeNodeChain currently refers to.
@@ -798,11 +791,11 @@ public:
 
         const DomainTreeNode<T>* top_node = top();
         isc::dns::Name absolute_name = top_node->getName();
-        int node_count = node_count_ - 1;
-        while (node_count > 0) {
-            top_node = nodes_[node_count - 1];
+        size_t level = level_count_ - 1;
+        while (level > 0) {
+            top_node = nodes_[level - 1];
             absolute_name = absolute_name.concatenate(top_node->getName());
-            --node_count;
+            --level;
         }
         return (absolute_name);
     }
@@ -816,7 +809,7 @@ private:
     /// \brief return whether node chain has node in it.
     ///
     /// \exception None
-    bool isEmpty() const { return (node_count_ == 0); }
+    bool isEmpty() const { return (level_count_ == 0); }
 
     /// \brief return the top node for the node chain
     ///
@@ -826,7 +819,7 @@ private:
     /// \exception None
     const DomainTreeNode<T>* top() const {
         assert(!isEmpty());
-        return (nodes_[node_count_ - 1]);
+        return (nodes_[level_count_ - 1]);
     }
 
     /// \brief pop the top node from the node chain
@@ -837,7 +830,7 @@ private:
     /// \exception None
     void pop() {
         assert(!isEmpty());
-        --node_count_;
+        --level_count_;
     }
 
     /// \brief add the node into the node chain
@@ -848,8 +841,8 @@ private:
     ///
     /// \exception None
     void push(const DomainTreeNode<T>* node) {
-        assert(node_count_ < RBT_MAX_LEVEL);
-        nodes_[node_count_++] = node;
+        assert(level_count_ < RBT_MAX_LEVEL);
+        nodes_[level_count_++] = node;
     }
 
 private:
@@ -858,7 +851,7 @@ private:
     // it's also equal to the possible maximum level.
     const static int RBT_MAX_LEVEL = isc::dns::Name::MAX_LABELS;
 
-    int node_count_;
+    size_t level_count_;
     const DomainTreeNode<T>* nodes_[RBT_MAX_LEVEL];
     const DomainTreeNode<T>* last_compared_;
     isc::dns::NameComparisonResult last_comparison_;
@@ -1313,12 +1306,20 @@ public:
     const DomainTreeNode<T>*
     previousNode(DomainTreeNodeChain<T>& node_path) const;
 
+    /// \brief return the largest node in the tree of trees.
+    ///
+    /// \throw none
+    ///
+    /// \return A \c DomainTreeNode that is the largest node in the
+    /// tree. If there are no nodes, then \c NULL is returned.
+    const DomainTreeNode<T>* largestNode() const;
+
     /// \brief Get the total number of nodes in the tree
     ///
     /// It includes nodes internally created as a result of adding a domain
     /// name that is a subdomain of an existing node of the tree.
     /// This function is mainly intended to be used for debugging.
-    int getNodeCount() const { return (node_count_); }
+    uint32_t getNodeCount() const { return (node_count_); }
 
     /// \name Debug function
     //@{
@@ -1450,8 +1451,15 @@ private:
     //@}
 
     typename DomainTreeNode<T>::DomainTreeNodePtr root_;
-    /// the node count of current tree
-    unsigned int node_count_;
+
+    /// the node count of current tree.
+    ///
+    /// Note: uint32_t may look awkward, but we intentionally choose it so
+    /// that needsReturnEmptyNode_ below won't make cause extra padding
+    /// in 64-bit machines (and we can minimize the total size of this class).
+    /// 2^32 - 1 should be a reasonable max of possible number of nodes.
+    uint32_t node_count_;
+
     /// search policy for domaintree
     const bool needsReturnEmptyNode_;
 };
@@ -1758,6 +1766,24 @@ DomainTree<T>::previousNode(DomainTreeNodeChain<T>& node_path) const {
 }
 
 template <typename T>
+const DomainTreeNode<T>*
+DomainTree<T>::largestNode() const {
+    const DomainTreeNode<T>* node = root_.get();
+    while (node != NULL) {
+        // We go right first, then down.
+        if (node->getRight() != NULL) {
+            node = node->getRight();
+        } else if (node->getDown() != NULL) {
+            node = node->getDown();
+        } else {
+	    break;
+	}
+    }
+
+    return (node);
+}
+
+template <typename T>
 typename DomainTree<T>::Result
 DomainTree<T>::insert(util::MemorySegment& mem_sgmt,
                       const isc::dns::Name& target_name,

+ 19 - 16
src/lib/datasrc/memory/memory_client.cc

@@ -324,6 +324,7 @@ public:
         if (nsec3_data == NULL) {
             nsec3_data = NSEC3Data::create(mem_sgmt_, nsec3_rdata);
             zone_data.setNSEC3Data(nsec3_data);
+            zone_data.setSigned(true);
         } else {
             size_t salt_len = nsec3_data->getSaltLen();
             const uint8_t* salt_data = nsec3_data->getSaltData();
@@ -331,7 +332,12 @@ public:
 
             if ((nsec3_rdata.getHashalg() != nsec3_data->hashalg) ||
                 (nsec3_rdata.getIterations() != nsec3_data->iterations) ||
-                (salt_data_2.size() != salt_len) ||
+                (salt_data_2.size() != salt_len)) {
+                isc_throw(AddError,
+                          "NSEC3 with inconsistent parameters: " <<
+                          rrset->toText());
+            }
+            if ((salt_len > 0) &&
                 (std::memcmp(&salt_data_2[0], salt_data, salt_len) != 0)) {
                 isc_throw(AddError,
                           "NSEC3 with inconsistent parameters: " <<
@@ -339,20 +345,10 @@ public:
             }
         }
 
-        // Make just the NSEC3 hash label uppercase, and insert the
-        // entire name into the NSEC3Data ZoneTree.
-        string fst_label = rrset->getName().split(0, 1).toText(true);
-        transform(fst_label.begin(), fst_label.end(), fst_label.begin(),
-                  ::toupper);
-        const string rest = rrset->getName().split(1).toText(true);
-
         ZoneNode* node;
-        nsec3_data->insertName(mem_sgmt_, Name(fst_label + "." + rest), &node);
+        nsec3_data->insertName(mem_sgmt_, rrset->getName(), &node);
 
         RdataEncoder encoder;
-
-        // We assume that rrsig has already been checked to match rrset
-        // by the caller.
         RdataSet* set = RdataSet::create(mem_sgmt_, encoder, rrset, rrsig);
         RdataSet* old_set = node->setData(set);
         if (old_set != NULL) {
@@ -417,6 +413,7 @@ public:
                 if (nsec3_data == NULL) {
                     nsec3_data = NSEC3Data::create(mem_sgmt_, param);
                     zone_data.setNSEC3Data(nsec3_data);
+                    zone_data.setSigned(true);
                 } else {
                     size_t salt_len = nsec3_data->getSaltLen();
                     const uint8_t* salt_data = nsec3_data->getSaltData();
@@ -424,7 +421,13 @@ public:
 
                     if ((param.getHashalg() != nsec3_data->hashalg) ||
                         (param.getIterations() != nsec3_data->iterations) ||
-                        (salt_data_2.size() != salt_len) ||
+                        (salt_data_2.size() != salt_len)) {
+                        isc_throw(AddError,
+                                  "NSEC3PARAM with inconsistent parameters: "
+                                  << rrset->toText());
+                    }
+
+                    if ((salt_len > 0) &&
                         (std::memcmp(&salt_data_2[0],
                                      salt_data, salt_len) != 0)) {
                         isc_throw(AddError,
@@ -702,10 +705,10 @@ InMemoryClient::findZone(const isc::dns::Name& zone_name) const {
     return (DataSourceClient::FindResult(result.code, finder));
 }
 
-isc::datasrc::memory::ZoneTable::FindResult
-InMemoryClient::findZone2(const isc::dns::Name& zone_name) const {
+const ZoneData*
+InMemoryClient::findZoneData(const isc::dns::Name& zone_name) {
     ZoneTable::FindResult result(impl_->zone_table_->findZone(zone_name));
-    return (result);
+    return (result.zone_data);
 }
 
 result::Result

+ 8 - 6
src/lib/datasrc/memory/memory_client.h

@@ -20,7 +20,7 @@
 #include <datasrc/iterator.h>
 #include <datasrc/client.h>
 #include <datasrc/memory/zone_table.h>
-#include <datasrc/zonetable.h>
+#include <datasrc/memory/zone_data.h>
 
 #include <string>
 
@@ -209,12 +209,14 @@ public:
     virtual isc::datasrc::DataSourceClient::FindResult
     findZone(const isc::dns::Name& name) const;
 
-    /// Returns a \c ZoneTable result that best matches the given name.
+    /// Returns a \c ZoneData in the result that best matches the given
+    /// name.
     ///
-    /// This derived version of the method never throws an exception.
-    /// For other details see \c DataSourceClient::findZone().
-    virtual isc::datasrc::memory::ZoneTable::FindResult
-    findZone2(const isc::dns::Name& name) const;
+    /// This is mainly intended for use in unit tests and should not be
+    /// used in other code.
+    ///
+    /// \throws none
+    const ZoneData* findZoneData(const isc::dns::Name& name);
 
     /// \brief Implementation of the getIterator method
     virtual isc::datasrc::ZoneIteratorPtr

+ 14 - 0
src/lib/datasrc/memory/tests/testdata/example.org-nsec3-empty-salt.zone

@@ -0,0 +1,14 @@
+example.org.				      86400 IN SOA	ns.example.org. ns.example.org. 2012092602 7200 3600 2592000 1200
+example.org.				      86400 IN RRSIG	SOA 7 2 86400 20120301040838 20120131040838 19562 example.org. Jt9wCRLS5TQxZH0IBqrM9uMGD453rIoxYopfM9AjjRZfEx+HGlBpOZeR pGN7yLcN+URnicOD0ydLHiakaBODiZyNoYCKYG5d2ZOhL+026REnDKNM 0m5T3X3sczP+l55An/GITheTdrKt3Y1Ouc2yKI8ro8JjOxV/a4nGDWjK x9A=
+example.org.				      86400 IN NS	ns.example.org.
+example.org.				      86400 IN RRSIG	NS 7 2 86400 20120301040838 20120131040838 19562 example.org. gYXL3xK4IFdJU6TtiVuzqDBb2MeA8xB3AKtHlJGFTfTRNHyuej0ZGovx TeUYsLYmoiGYaJG66iD1tYYFq0qdj0xWq+LEa53ACtKvYf9IIwK4ijJs k0g6xCNavc6/qPppymDhN7MvoFVkW59uJa0HPWlsIIuRlEAr7xyt85vq yoA=
+example.org.				      86400 IN DNSKEY	256 3 7 AwEAAbrBkKf2gmGtG4aogZY4svIZCrOLLZlQzVHwz7WxJdTR8iEnvz/x Q/jReDroS5CBZWvzwLlhPIpsJAojx0oj0RvfJNsz3+6LN8q7x9u6+86B 85CYjTk3dcFOebgkF4fXr7/kkOX+ZY94Zk0Z1+pUC3eY4gkKcyME/Uxm O18PBTeB
+example.org.				      86400 IN RRSIG	DNSKEY 7 2 86400 20120301040838 20120131040838 19562 example.org. d0eLF8JqNHaGuBSX0ashU5c1O/wyWU43UUsKGrMQIoBDiJ588MWQOnas rwvW6vdkLNqRqCsP/B4epV/EtLL0tBsk5SHkTYbNo80gGrBufQ6YrWRr Ile8Z+h+MR4y9DybbjmuNKqaO4uQMg/X6+4HqRAKx1lmZMTcrcVeOwDM ZA4=
+example.org.				      0	IN NSEC3PARAM	1 0 10 -
+example.org.				      0	IN RRSIG	NSEC3PARAM 7 2 0 20120301040838 20120131040838 19562 example.org. Ggs5MiQDlXXt22Fz9DNg3Ujc0T6MBfumlRkd8/enBbJwLmqw2QXAzDEk pjUeGstCEHKzxJDJstavGoCpTDJgoV4Fd9szooMx69rzPrq9tdoM46jG xZHqw+Pv2fkRGC6aP7ZX1r3Qnpwpk47AQgATftbO4G6KcMcO8JoKE47x XLM=
+ns.example.org.				      86400 IN A	192.0.2.1
+ns.example.org.				      86400 IN RRSIG	A 7 3 86400 20120301040838 20120131040838 19562 example.org. dOH+Dxib8VcGnjLrKILsqDhS1wki6BWk1dZwpOGUGHyLWcLNW8ygWY2o r29jPhHtaFCNWpn46JebgnXDPRiQjaY3dQqL8rcf2QX1e3+Cicw1OSrs S0sUDE5DmMNEac+ZCUQ0quCStZTCldl05hlspV2RS92TpApnoOK0nXMp Uak=
+09GM5T42SMIMT7R8DF6RTG80SFMS1NLU.example.org. 1200 IN NSEC3	1 0 10 - RKOF8QMFRB5F2V9EJHFBVB2JPVSA0DJD A RRSIG
+09GM5T42SMIMT7R8DF6RTG80SFMS1NLU.example.org. 1200 IN RRSIG	NSEC3 7 3 1200 20120301040838 20120131040838 19562 example.org. EdwMeepLf//lV+KpCAN+213Scv1rrZyj4i2OwoCP4XxxS3CWGSuvYuKO yfZc8wKRcrD/4YG6nZVXE0s5O8NahjBJmDIyVt4WkfZ6QthxGg8ggLVv cD3dFksPyiKHf/WrTOZPSsxvN5m/i1Ey6+YWS01Gf3WDCMWDauC7Nmh3 CTM=
+RKOF8QMFRB5F2V9EJHFBVB2JPVSA0DJD.example.org. 1200 IN NSEC3	1 0 10 - 09GM5T42SMIMT7R8DF6RTG80SFMS1NLU NS SOA RRSIG DNSKEY NSEC3PARAM
+RKOF8QMFRB5F2V9EJHFBVB2JPVSA0DJD.example.org. 1200 IN RRSIG	NSEC3 7 3 1200 20120301040838 20120131040838 19562 example.org. j7d8GL4YqX035FBcPPsEcSWHjWcKdlQMHLL4TB67xVNFnl4SEFQCp4OO AtPap5tkKakwgWxoQVN9XjnqrBz+oQhofDkB3aTatAjIIkcwcnrm3AYQ rTI3E03ySiRwuCPKVmHOLUV2cG6O4xzcmP+MYZcvPTS8V3F5LlaU22i7 A3E=

+ 92 - 101
src/lib/datasrc/memory/zone_finder.cc

@@ -23,13 +23,11 @@
 #include <dns/name.h>
 #include <dns/rrset.h>
 #include <dns/rrtype.h>
-
-#include <util/buffer.h>
-#include <util/encode/base32hex.h>
-#include <util/hash/sha1.h>
+#include <dns/nsec3hash.h>
 
 #include <datasrc/logger.h>
 
+#include <boost/scoped_ptr.hpp>
 #include <boost/bind.hpp>
 
 #include <algorithm>
@@ -38,9 +36,6 @@
 using namespace isc::dns;
 using namespace isc::datasrc::memory;
 using namespace isc::datasrc;
-using namespace isc::util;
-using namespace isc::util::encode;
-using namespace isc::util::hash;
 
 namespace isc {
 namespace datasrc {
@@ -192,6 +187,27 @@ bool cutCallback(const ZoneNode& node, FindState* state) {
     return (false);
 }
 
+/// Creates a NSEC3 ConstRRsetPtr for the given ZoneNode inside the
+/// NSEC3 tree, for the given RRClass.
+///
+/// It asserts that the node contains data (RdataSet) and is of type
+/// NSEC3.
+///
+/// \param node The ZoneNode inside the NSEC3 tree
+/// \param rrclass The RRClass as passed by the client
+ConstRRsetPtr
+createNSEC3RRset(const ZoneNode* node, const RRClass& rrclass) {
+     const RdataSet* rdataset = node->getData();
+     // Only NSEC3 ZoneNodes are allowed to be passed to this method. We
+     // assert that these have data, and also are of type NSEC3.
+     assert(rdataset != NULL);
+     assert(rdataset->type == RRType::NSEC3());
+
+    // Create the RRset.  Note the DNSSEC flag: NSEC3 implies DNSSEC.
+    return (createTreeNodeRRset(node, rdataset, rrclass,
+                                ZoneFinder::FIND_DNSSEC));
+}
+
 // convenience function to fill in the final details
 //
 // Set up ZoneFinderResultContext object as a return value of find(),
@@ -504,44 +520,6 @@ FindNodeResult findNode(const ZoneData& zone_data,
 
 } // end anonymous namespace
 
-inline void
-iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
-            const uint8_t* salt, size_t saltlen,
-            uint8_t output[SHA1_HASHSIZE])
-{
-    SHA1Reset(ctx);
-    SHA1Input(ctx, input, inlength);
-    SHA1Input(ctx, salt, saltlen); // this works whether saltlen == or > 0
-    SHA1Result(ctx, output);
-}
-
-std::string
-InMemoryZoneFinderNSEC3Calculate(const Name& name,
-                                 const uint16_t iterations,
-                                 const uint8_t* salt,
-                                 size_t salt_len) {
-    // We first need to normalize the name by converting all upper case
-    // characters in the labels to lower ones.
-    OutputBuffer obuf(Name::MAX_WIRE);
-    Name name_copy(name);
-    name_copy.downcase();
-    name_copy.toWire(obuf);
-
-    const uint8_t* const salt_buf = (salt_len > 0) ? salt : NULL;
-    std::vector<uint8_t> digest(SHA1_HASHSIZE);
-    uint8_t* const digest_buf = &digest[0];
-
-    SHA1Context sha1_ctx;
-    iterateSHA1(&sha1_ctx, static_cast<const uint8_t*>(obuf.getData()),
-                obuf.getLength(), salt_buf, salt_len, digest_buf);
-    for (unsigned int n = 0; n < iterations; ++n) {
-        iterateSHA1(&sha1_ctx, digest_buf, SHA1_HASHSIZE,
-                    salt_buf, salt_len,
-                    digest_buf);
-    }
-
-    return (encodeBase32Hex(digest));
-}
 
 /// \brief Specialization of the ZoneFinder::Context for the in-memory finder.
 ///
@@ -827,60 +805,91 @@ InMemoryZoneFinder::findNSEC3(const isc::dns::Name& name, bool recursive) {
     LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_FINDNSEC3).arg(name).
         arg(recursive ? "recursive" : "non-recursive");
 
+    uint8_t labels_buf[LabelSequence::MAX_SERIALIZED_LENGTH];
+    const LabelSequence origin_ls(zone_data_.getOriginNode()->
+                                  getAbsoluteLabels(labels_buf));
+    const LabelSequence name_ls(name);
+
     if (!zone_data_.isNSEC3Signed()) {
         isc_throw(DataSourceError,
                   "findNSEC3 attempt for non NSEC3 signed zone: " <<
-                  getOrigin() << "/" << getClass());
+                  origin_ls << "/" << getClass());
     }
 
-    const NameComparisonResult cmp_result = name.compare(getOrigin());
+    const NSEC3Data* nsec3_data = zone_data_.getNSEC3Data();
+    // This would be a programming mistake, as ZoneData::isNSEC3Signed()
+    // should check this.
+    assert(nsec3_data != NULL);
+
+    const ZoneTree& tree = nsec3_data->getNSEC3Tree();
+    if (tree.getNodeCount() == 0) {
+        isc_throw(DataSourceError,
+                  "findNSEC3 attempt but zone has no NSEC3 RRs: " <<
+                  origin_ls << "/" << getClass());
+    }
+
+    const NameComparisonResult cmp_result = name_ls.compare(origin_ls);
     if (cmp_result.getRelation() != NameComparisonResult::EQUAL &&
         cmp_result.getRelation() != NameComparisonResult::SUBDOMAIN) {
         isc_throw(OutOfZone, "findNSEC3 attempt for out-of-zone name: "
-                  << name << ", zone: " << getOrigin() << "/"
+                  << name_ls << ", zone: " << origin_ls << "/"
                   << getClass());
     }
 
     // Convenient shortcuts
-    const ZoneFinder::FindOptions options =
-        ZoneFinder::FIND_DNSSEC; // NSEC3 implies DNSSEC
-    const unsigned int olabels = getOrigin().getLabelCount();
+    const unsigned int olabels = origin_ls.getLabelCount();
     const unsigned int qlabels = name.getLabelCount();
-    const NSEC3Data* nsec3_data = zone_data_.getNSEC3Data();
+    // placeholder of the next closer proof
+    const ZoneNode* covering_node(NULL);
+
+    // Now we'll first look up the origin node and initialize orig_chain
+    // with it.
+    ZoneChain orig_chain;
+    const ZoneNode* node(NULL);
+    ZoneTree::Result result =
+         tree.find<void*>(origin_ls, &node, orig_chain, NULL, NULL);
+    if (result != ZoneTree::EXACTMATCH) {
+        // If the origin node doesn't exist, simply fail.
+        isc_throw(DataSourceError,
+                  "findNSEC3 attempt but zone has no NSEC3 RRs: " <<
+                  origin_ls << "/" << getClass());
+    }
+
+    const boost::scoped_ptr<NSEC3Hash> hash
+        (NSEC3Hash::create(nsec3_data->hashalg,
+                           nsec3_data->iterations,
+                           nsec3_data->getSaltData(),
+                           nsec3_data->getSaltLen()));
 
-    const ZoneNode* covering_node(NULL); // placeholder of the next closer proof
     // Examine all names from the query name to the origin name, stripping
     // the deepest label one by one, until we find a name that has a matching
     // NSEC3 hash.
     for (unsigned int labels = qlabels; labels >= olabels; --labels) {
-        const std::string hlabel = (nsec3_calculate_)
-            ((labels == qlabels ?
-              name : name.split(qlabels - labels, labels)),
-             nsec3_data->iterations,
-             nsec3_data->getSaltData(),
-             nsec3_data->getSaltLen());
+        const Name& hname = (labels == qlabels ?
+                             name : name.split(qlabels - labels, labels));
+        const std::string hlabel = hash->calculate(hname);
 
         LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_FINDNSEC3_TRYHASH).
             arg(name).arg(labels).arg(hlabel);
 
-        const ZoneTree& tree = nsec3_data->getNSEC3Tree();
-
-        ZoneNode* node(NULL);
-        ZoneChain chain;
+        node = NULL;
+        ZoneChain chain(orig_chain);
 
-        ZoneTree::Result result =
-            tree.find(Name(hlabel + "." + getOrigin().toText()), &node, chain);
+        // Now, make a label sequence relative to the origin.
+        const Name hlabel_name(hlabel);
+        LabelSequence hlabel_ls(hlabel_name);
+        // Remove trailing '.' making it relative
+        hlabel_ls.stripRight(1);
 
+        // Find hlabel relative to the orig_chain.
+        result = tree.find<void*>(hlabel_ls, &node, chain, NULL, NULL);
         if (result == ZoneTree::EXACTMATCH) {
             // We found an exact match.
-            RdataSet* set = node->getData();
-            ConstRRsetPtr closest = createTreeNodeRRset(node, set, getClass(),
-                                                        options);
-            ConstRRsetPtr next =
-                createTreeNodeRRset(covering_node,
-                                    (covering_node != NULL ?
-                                     covering_node->getData() : NULL),
-                                    getClass(), options);
+            ConstRRsetPtr closest = createNSEC3RRset(node, getClass());
+            ConstRRsetPtr next;
+            if (covering_node != NULL) {
+                next = createNSEC3RRset(covering_node, getClass());
+            }
 
             LOG_DEBUG(logger, DBG_TRACE_BASIC,
                       DATASRC_MEM_FINDNSEC3_MATCH).arg(name).arg(labels).
@@ -888,37 +897,19 @@ InMemoryZoneFinder::findNSEC3(const isc::dns::Name& name, bool recursive) {
 
             return (FindNSEC3Result(true, labels, closest, next));
         } else {
-            const NameComparisonResult& last_cmp =
-                chain.getLastComparisonResult();
-            const ZoneNode* last_node = chain.getLastComparedNode();
-            assert(last_cmp.getOrder() != 0);
-
-            // find() finished in between one of these and last_node:
-            const ZoneNode* previous_node = last_node->predecessor();
-            const ZoneNode* next_node = last_node->successor();
-
-            // If the given hash is larger than the largest stored hash or
-            // the first label doesn't match the target, identify the "previous"
-            // hash value and remember it as the candidate next closer proof.
-            if (((last_cmp.getOrder() < 0) && (previous_node == NULL)) ||
-                ((last_cmp.getOrder() > 0) && (next_node == NULL))) {
-                covering_node = last_node->getLargestInSubTree();
-            } else {
-                // Otherwise, H(found_entry-1) < given_hash < H(found_entry).
-                // The covering proof is the first one (and it's valid
-                // because found is neither begin nor end)
-                covering_node = previous_node;
+            while ((covering_node = tree.previousNode(chain)) != NULL &&
+                   covering_node->isEmpty()) {
+                ;
+            }
+            if (covering_node == NULL) {
+                covering_node = tree.largestNode();
             }
 
             if (!recursive) {   // in non recursive mode, we are done.
-                ConstRRsetPtr closest =
-                    createTreeNodeRRset(covering_node,
-                                        (covering_node != NULL ?
-                                         covering_node->getData() :
-                                         NULL),
-                                        getClass(), options);
-
-                if (closest) {
+                ConstRRsetPtr closest;
+                if (covering_node != NULL) {
+                    closest = createNSEC3RRset(covering_node, getClass());
+
                     LOG_DEBUG(logger, DBG_TRACE_BASIC,
                               DATASRC_MEM_FINDNSEC3_COVER).
                         arg(name).arg(*closest);

+ 3 - 15
src/lib/datasrc/memory/zone_finder.h

@@ -23,6 +23,8 @@
 #include <dns/rrset.h>
 #include <dns/rrtype.h>
 
+#include <string>
+
 namespace isc {
 namespace datasrc {
 namespace memory {
@@ -31,12 +33,6 @@ namespace internal {
 class ZoneFinderResultContext;
 }
 
-std::string
-InMemoryZoneFinderNSEC3Calculate(const isc::dns::Name& name,
-                                 const uint16_t iterations,
-                                 const uint8_t* salt,
-                                 size_t salt_len);
-
 /// A derived zone finder class intended to be used with the memory data
 /// source, using ZoneData for its contents.
 class InMemoryZoneFinder : boost::noncopyable, public ZoneFinder {
@@ -55,8 +51,7 @@ public:
     InMemoryZoneFinder(const ZoneData& zone_data,
                        const isc::dns::RRClass& rrclass) :
         zone_data_(zone_data),
-        rrclass_(rrclass),
-        nsec3_calculate_(InMemoryZoneFinderNSEC3Calculate)
+        rrclass_(rrclass)
     {}
 
     /// \brief Find an RRset in the datasource
@@ -106,13 +101,6 @@ private:
 
     const ZoneData& zone_data_;
     const isc::dns::RRClass rrclass_;
-
-protected:
-    typedef std::string (NSEC3CalculateFn) (const isc::dns::Name& name,
-                                            const uint16_t iterations,
-                                            const uint8_t* salt,
-                                            size_t salt_len);
-    NSEC3CalculateFn* nsec3_calculate_;
 };
 
 } // namespace memory

+ 17 - 0
src/lib/datasrc/tests/faked_nsec3.cc

@@ -67,6 +67,17 @@ public:
             "00000000000000000000000000000000";
         map_[Name("largest.example.org")] =
             "UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU";
+
+        // These are used by the findNSEC3Walk test.
+        map_[Name("n0.example.org")] = "00000000000000000000000000000000";
+        map_[Name("n1.example.org")] = "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
+        map_[Name("n2.example.org")] = "02UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
+        map_[Name("n3.example.org")] = "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
+        map_[Name("n4.example.org")] = "11111111111111111111111111111111";
+        map_[Name("n5.example.org")] = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR";
+        map_[Name("n6.example.org")] = "44444444444444444444444444444444";
+        map_[Name("n7.example.org")] = "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN";
+        map_[Name("n8.example.org")] = "ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ";
     }
     virtual string calculate(const Name& name) const {
         const NSEC3HashMap::const_iterator found = map_.find(name);
@@ -94,6 +105,12 @@ NSEC3Hash* TestNSEC3HashCreator::create(const rdata::generic::NSEC3&) const {
     return (new TestNSEC3Hash);
 }
 
+NSEC3Hash* TestNSEC3HashCreator::create(uint8_t, uint16_t,
+                                        const uint8_t*, size_t) const
+{
+    return (new TestNSEC3Hash);
+}
+
 void
 findNSEC3Check(bool expected_matched, uint8_t expected_labels,
                const string& expected_closest,

+ 4 - 0
src/lib/datasrc/tests/faked_nsec3.h

@@ -57,11 +57,15 @@ class TestNSEC3HashCreator : public isc::dns::NSEC3HashCreator {
 private:
     class TestNSEC3Hash;
 public:
+    TestNSEC3HashCreator() {}
     virtual isc::dns::NSEC3Hash* create(const
                                         isc::dns::rdata::generic::NSEC3PARAM&)
         const;
     virtual isc::dns::NSEC3Hash* create(const isc::dns::rdata::generic::NSEC3&)
         const;
+    virtual isc::dns::NSEC3Hash* create(uint8_t, uint16_t,
+                                        const uint8_t*, size_t)
+        const;
 };
 
 // Check the result against expected values. It directly calls EXPECT_ macros

+ 37 - 43
src/lib/datasrc/tests/memory/domaintree_unittest.cc

@@ -257,7 +257,7 @@ TEST_F(DomainTreeTest, subTreeRoot) {
     // "g.h" is not a subtree root
     EXPECT_EQ(TestDomainTree::EXACTMATCH,
               dtree_expose_empty_node.find(Name("g.h"), &dtnode));
-    EXPECT_FALSE(dtnode->isSubTreeRoot());
+    EXPECT_FALSE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 
     // fission the node "g.h"
     EXPECT_EQ(TestDomainTree::ALREADYEXISTS,
@@ -266,12 +266,12 @@ TEST_F(DomainTreeTest, subTreeRoot) {
 
     // the node "h" (h.down_ -> "g") should not be a subtree root. "g"
     // should be a subtree root.
-    EXPECT_FALSE(dtnode->isSubTreeRoot());
+    EXPECT_FALSE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 
     // "g.h" should be a subtree root now.
     EXPECT_EQ(TestDomainTree::EXACTMATCH,
               dtree_expose_empty_node.find(Name("g.h"), &dtnode));
-    EXPECT_TRUE(dtnode->isSubTreeRoot());
+    EXPECT_TRUE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 }
 
 TEST_F(DomainTreeTest, additionalNodeFission) {
@@ -287,7 +287,7 @@ TEST_F(DomainTreeTest, additionalNodeFission) {
     // "t.0" is not a subtree root
     EXPECT_EQ(TestDomainTree::EXACTMATCH,
               dtree_expose_empty_node.find(Name("t.0"), &dtnode));
-    EXPECT_FALSE(dtnode->isSubTreeRoot());
+    EXPECT_FALSE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 
     // fission the node "t.0"
     EXPECT_EQ(TestDomainTree::ALREADYEXISTS,
@@ -296,12 +296,12 @@ TEST_F(DomainTreeTest, additionalNodeFission) {
 
     // the node "0" ("0".down_ -> "t") should not be a subtree root. "t"
     // should be a subtree root.
-    EXPECT_FALSE(dtnode->isSubTreeRoot());
+    EXPECT_FALSE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 
     // "t.0" should be a subtree root now.
     EXPECT_EQ(TestDomainTree::EXACTMATCH,
               dtree_expose_empty_node.find(Name("t.0"), &dtnode));
-    EXPECT_TRUE(dtnode->isSubTreeRoot());
+    EXPECT_TRUE(dtnode->getFlag(TestDomainTreeNode::FLAG_SUBTREE_ROOT));
 }
 
 TEST_F(DomainTreeTest, findName) {
@@ -596,6 +596,10 @@ TEST_F(DomainTreeTest, chainLevel) {
     // by default there should be no level in the chain.
     EXPECT_EQ(0, chain.getLevelCount());
 
+    // Copy should be consistent
+    TestDomainTreeNodeChain chain2(chain);
+    EXPECT_EQ(chain.getLevelCount(), chain2.getLevelCount());
+
     // insert one node to the tree and find it.  there should be exactly
     // one level in the chain.
     TreeHolder tree_holder(mem_sgmt_, TestDomainTree::create(mem_sgmt_, true));
@@ -607,6 +611,11 @@ TEST_F(DomainTreeTest, chainLevel) {
               tree.find(node_name, &cdtnode, chain));
     EXPECT_EQ(1, chain.getLevelCount());
 
+    // Copy should be consistent
+    TestDomainTreeNodeChain chain3(chain);
+    EXPECT_EQ(chain.getLevelCount(), chain3.getLevelCount());
+    EXPECT_EQ(chain.getAbsoluteName(), chain3.getAbsoluteName());
+
     // Check the name of the found node (should have '.' as both non-absolute
     // and absolute name
     EXPECT_EQ(".", cdtnode->getLabels().toText());
@@ -687,16 +696,6 @@ const char* const upper_node_names[] = {
     "w.y.d.e.f", "w.y.d.e.f", "d.e.f", "z.d.e.f",
     ".", "g.h", "g.h"};
 
-const char* const subtree_root_node_names[] = {
-    "b", "b", "b", "b", "w.y.d.e.f", "w.y.d.e.f", "p.w.y.d.e.f",
-    "p.w.y.d.e.f", "p.w.y.d.e.f", "w.y.d.e.f", "j.z.d.e.f",
-    "b", "i.g.h", "i.g.h"};
-
-const char* const largest_node_names[] = {
-    "g.h", "g.h", "g.h", "g.h", "z.d.e.f", "z.d.e.f", "q.w.y.d.e.f",
-    "q.w.y.d.e.f", "q.w.y.d.e.f", "z.d.e.f", "j.z.d.e.f",
-    "g.h", "k.g.h", "k.g.h"};
-
 TEST_F(DomainTreeTest, getUpperNode) {
     TestDomainTreeNodeChain node_path;
     const TestDomainTreeNode* node = NULL;
@@ -726,6 +725,16 @@ TEST_F(DomainTreeTest, getUpperNode) {
     EXPECT_EQ(static_cast<void*>(NULL), node);
 }
 
+
+#if 0
+// Disabled and kept still, for use in case we make getSubTreeRoot() a
+// public function again.
+
+const char* const subtree_root_node_names[] = {
+    "b", "b", "b", "b", "w.y.d.e.f", "w.y.d.e.f", "p.w.y.d.e.f",
+    "p.w.y.d.e.f", "p.w.y.d.e.f", "w.y.d.e.f", "j.z.d.e.f",
+    "b", "i.g.h", "i.g.h"};
+
 TEST_F(DomainTreeTest, getSubTreeRoot) {
     TestDomainTreeNodeChain node_path;
     const TestDomainTreeNode* node = NULL;
@@ -755,34 +764,8 @@ TEST_F(DomainTreeTest, getSubTreeRoot) {
     EXPECT_EQ(static_cast<void*>(NULL), node);
 }
 
-TEST_F(DomainTreeTest, getLargestInSubTree) {
-    TestDomainTreeNodeChain node_path;
-    const TestDomainTreeNode* node = NULL;
-    EXPECT_EQ(TestDomainTree::EXACTMATCH,
-              dtree_expose_empty_node.find(Name(names[0]),
-                                            &node,
-                                            node_path));
-    for (int i = 0; i < name_count; ++i) {
-        EXPECT_NE(static_cast<void*>(NULL), node);
-
-        const TestDomainTreeNode* largest_node = node->getLargestInSubTree();
-        if (largest_node_names[i] != NULL) {
-            const TestDomainTreeNode* largest_node2 = NULL;
-            EXPECT_EQ(TestDomainTree::EXACTMATCH,
-                dtree_expose_empty_node.find(Name(largest_node_names[i]),
-                                             &largest_node2));
-            EXPECT_NE(static_cast<void*>(NULL), largest_node2);
-            EXPECT_EQ(largest_node, largest_node2);
-        } else {
-            EXPECT_EQ(static_cast<void*>(NULL), largest_node);
-        }
-
-        node = dtree_expose_empty_node.nextNode(node_path);
-    }
+#endif // disabled getSubTreeRoot()
 
-    // We should have reached the end of the tree.
-    EXPECT_EQ(static_cast<void*>(NULL), node);
-}
 
 TEST_F(DomainTreeTest, nextNode) {
     TestDomainTreeNodeChain node_path;
@@ -1005,6 +988,17 @@ TEST_F(DomainTreeTest, previousNode) {
     }
 }
 
+TEST_F(DomainTreeTest, largestNode) {
+    cdtnode = dtree.largestNode();
+    EXPECT_EQ(Name("k"), cdtnode->getName());
+
+    // Check for largest node in an empty tree.
+    TreeHolder empty_tree_holder
+        (mem_sgmt_, TestDomainTree::create(mem_sgmt_));
+    TestDomainTree& empty_tree(*empty_tree_holder.get());
+    EXPECT_EQ(static_cast<void*>(NULL), empty_tree.largestNode());
+}
+
 TEST_F(DomainTreeTest, nextNodeError) {
     // Empty chain for nextNode() is invalid.
     TestDomainTreeNodeChain chain;

+ 43 - 27
src/lib/datasrc/tests/memory/memory_client_unittest.cc

@@ -200,6 +200,11 @@ TEST_F(MemoryClientTest, load) {
     // should not result in any exceptions.
     client_->load(Name("example.org"),
                   TEST_DATA_DIR "/example.org.zone");
+    const ZoneData* zone_data =
+        client_->findZoneData(Name("example.org"));
+    ASSERT_NE(static_cast<const ZoneData*>(NULL), zone_data);
+    EXPECT_FALSE(zone_data->isSigned());
+    EXPECT_FALSE(zone_data->isNSEC3Signed());
 }
 
 TEST_F(MemoryClientTest, loadFromIterator) {
@@ -266,11 +271,33 @@ TEST_F(MemoryClientTest, loadMemoryAllocationFailures) {
 TEST_F(MemoryClientTest, loadNSEC3Signed) {
     client_->load(Name("example.org"),
                   TEST_DATA_DIR "/example.org-nsec3-signed.zone");
+    const ZoneData* zone_data =
+        client_->findZoneData(Name("example.org"));
+    ASSERT_NE(static_cast<const ZoneData*>(NULL), zone_data);
+    EXPECT_TRUE(zone_data->isSigned());
+    EXPECT_TRUE(zone_data->isNSEC3Signed());
+}
+
+TEST_F(MemoryClientTest, loadNSEC3EmptySalt) {
+    // Load NSEC3 with empty ("-") salt. This should not throw or crash
+    // or anything.
+    client_->load(Name("example.org"),
+                  TEST_DATA_DIR "/example.org-nsec3-empty-salt.zone");
+    const ZoneData* zone_data =
+        client_->findZoneData(Name("example.org"));
+    ASSERT_NE(static_cast<const ZoneData*>(NULL), zone_data);
+    EXPECT_TRUE(zone_data->isSigned());
+    EXPECT_TRUE(zone_data->isNSEC3Signed());
 }
 
 TEST_F(MemoryClientTest, loadNSEC3SignedNoParam) {
     client_->load(Name("example.org"),
                   TEST_DATA_DIR "/example.org-nsec3-signed-no-param.zone");
+    const ZoneData* zone_data =
+        client_->findZoneData(Name("example.org"));
+    ASSERT_NE(static_cast<const ZoneData*>(NULL), zone_data);
+    EXPECT_TRUE(zone_data->isSigned());
+    EXPECT_TRUE(zone_data->isNSEC3Signed());
 }
 
 TEST_F(MemoryClientTest, loadReloadZone) {
@@ -288,14 +315,12 @@ TEST_F(MemoryClientTest, loadReloadZone) {
                   client_->getFileName(Name("example.org")));
     EXPECT_EQ(1, client_->getZoneCount());
 
-    isc::datasrc::memory::ZoneTable::FindResult
-        result(client_->findZone2(Name("example.org")));
-    EXPECT_EQ(result::SUCCESS, result.code);
-    EXPECT_NE(static_cast<ZoneData*>(NULL),
-              result.zone_data);
+    const ZoneData* zone_data =
+        client_->findZoneData(Name("example.org"));
+    EXPECT_NE(static_cast<const ZoneData*>(NULL), zone_data);
 
     /* Check SOA */
-    const ZoneNode* node = result.zone_data->getOriginNode();
+    const ZoneNode* node = zone_data->getOriginNode();
     EXPECT_NE(static_cast<const ZoneNode*>(NULL), node);
 
     const RdataSet* set = node->getData();
@@ -306,7 +331,7 @@ TEST_F(MemoryClientTest, loadReloadZone) {
     EXPECT_EQ(static_cast<const RdataSet*>(NULL), set);
 
     /* Check ns1.example.org */
-    const ZoneTree& tree = result.zone_data->getZoneTree();
+    const ZoneTree& tree = zone_data->getZoneTree();
     ZoneTree::Result zresult(tree.find(Name("ns1.example.org"), &node));
     EXPECT_NE(ZoneTree::EXACTMATCH, zresult);
 
@@ -316,14 +341,11 @@ TEST_F(MemoryClientTest, loadReloadZone) {
                   TEST_DATA_DIR "/example.org-rrsigs.zone");
     EXPECT_EQ(1, client_->getZoneCount());
 
-    isc::datasrc::memory::ZoneTable::FindResult
-        result2(client_->findZone2(Name("example.org")));
-    EXPECT_EQ(result::SUCCESS, result2.code);
-    EXPECT_NE(static_cast<ZoneData*>(NULL),
-              result2.zone_data);
+    zone_data = client_->findZoneData(Name("example.org"));
+    EXPECT_NE(static_cast<const ZoneData*>(NULL), zone_data);
 
     /* Check SOA */
-    node = result2.zone_data->getOriginNode();
+    node = zone_data->getOriginNode();
     EXPECT_NE(static_cast<const ZoneNode*>(NULL), node);
 
     set = node->getData();
@@ -334,7 +356,7 @@ TEST_F(MemoryClientTest, loadReloadZone) {
     EXPECT_EQ(static_cast<const RdataSet*>(NULL), set);
 
     /* Check ns1.example.org */
-    const ZoneTree& tree2 = result2.zone_data->getZoneTree();
+    const ZoneTree& tree2 = zone_data->getZoneTree();
     ZoneTree::Result zresult2(tree2.find(Name("ns1.example.org"), &node));
     EXPECT_EQ(ZoneTree::EXACTMATCH, zresult2);
     EXPECT_NE(static_cast<const ZoneNode*>(NULL), node);
@@ -702,24 +724,18 @@ TEST_F(MemoryClientTest, add) {
     EXPECT_EQ(ConstRRsetPtr(), iterator->getNextRRset());
 }
 
-TEST_F(MemoryClientTest, findZone2) {
+TEST_F(MemoryClientTest, findZoneData) {
     client_->load(Name("example.org"),
                   TEST_DATA_DIR "/example.org-rrsigs.zone");
 
-    isc::datasrc::memory::ZoneTable::FindResult
-        result(client_->findZone2(Name("example.com")));
-    EXPECT_EQ(result::NOTFOUND, result.code);
-    EXPECT_EQ(static_cast<ZoneData*>(NULL),
-              result.zone_data);
+    const ZoneData* zone_data = client_->findZoneData(Name("example.com"));
+    EXPECT_EQ(static_cast<const ZoneData*>(NULL), zone_data);
 
-    isc::datasrc::memory::ZoneTable::FindResult
-        result2(client_->findZone2(Name("example.org")));
-    EXPECT_EQ(result::SUCCESS, result2.code);
-    EXPECT_NE(static_cast<ZoneData*>(NULL),
-              result2.zone_data);
+    zone_data = client_->findZoneData(Name("example.org"));
+    EXPECT_NE(static_cast<const ZoneData*>(NULL), zone_data);
 
     /* Check SOA */
-    const ZoneNode* node = result2.zone_data->getOriginNode();
+    const ZoneNode* node = zone_data->getOriginNode();
     EXPECT_NE(static_cast<const ZoneNode*>(NULL), node);
 
     const RdataSet* set = node->getData();
@@ -730,7 +746,7 @@ TEST_F(MemoryClientTest, findZone2) {
     EXPECT_EQ(static_cast<const RdataSet*>(NULL), set);
 
     /* Check ns1.example.org */
-    const ZoneTree& tree = result2.zone_data->getZoneTree();
+    const ZoneTree& tree = zone_data->getZoneTree();
     ZoneTree::Result result3(tree.find(Name("ns1.example.org"), &node));
     EXPECT_EQ(ZoneTree::EXACTMATCH, result3);
     EXPECT_NE(static_cast<const ZoneNode*>(NULL), node);

+ 1 - 0
src/lib/datasrc/tests/memory/testdata/Makefile.am

@@ -19,6 +19,7 @@ EXTRA_DIST += example.org-multiple-dname.zone
 EXTRA_DIST += example.org-multiple-nsec3.zone
 EXTRA_DIST += example.org-multiple-nsec3param.zone
 EXTRA_DIST += example.org-multiple.zone
+EXTRA_DIST += example.org-nsec3-empty-salt.zone
 EXTRA_DIST += example.org-nsec3-fewer-labels.zone example.org-nsec3-more-labels.zone
 EXTRA_DIST += example.org-nsec3-signed-no-param.zone
 EXTRA_DIST += example.org-nsec3-signed.zone

+ 182 - 113
src/lib/datasrc/tests/memory/zone_finder_unittest.cc

@@ -31,6 +31,8 @@
 
 #include <gtest/gtest.h>
 
+#include <string>
+
 using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::rdata;
@@ -46,72 +48,6 @@ namespace {
 using result::SUCCESS;
 using result::EXIST;
 
-// Some faked NSEC3 hash values commonly used in tests and the faked NSEC3Hash
-// object.
-//
-// For apex (example.org)
-const char* const apex_hash = "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-const char* const apex_hash_lower = "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom";
-// For ns1.example.org
-const char* const ns1_hash = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR";
-// For w.example.org
-const char* const w_hash = "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
-// For x.y.w.example.org (lower-cased)
-const char* const xyw_hash = "2vptu5timamqttgl4luu9kg21e0aor3s";
-// For zzz.example.org.
-const char* const zzz_hash = "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN";
-
-typedef map<Name, string> NSEC3HashMap;
-typedef NSEC3HashMap::value_type NSEC3HashPair;
-NSEC3HashMap nsec3_hash_map;
-
-// A faked NSEC3 hash calculator for convenience. Tests that need to use
-// the faked hashed values should call setFakeNSEC3Calculate() on the
-// MyZoneFinder object at the beginning of the test (at least before
-// adding any NSEC3/NSEC3PARAM RR).
-std::string
-fakeNSEC3Calculate(const Name& name,
-                   const uint16_t,
-                   const uint8_t*,
-                   size_t) {
-    const NSEC3HashMap::const_iterator found = nsec3_hash_map.find(name);
-    if (found != nsec3_hash_map.end()) {
-        return (found->second);
-    }
-
-    isc_throw(isc::Unexpected,
-              "unexpected name for NSEC3 test: " << name);
-}
-
-class MyZoneFinder : public memory::InMemoryZoneFinder {
-private:
-public:
-    MyZoneFinder(const ZoneData& zone_data,
-                 const isc::dns::RRClass& rrclass) :
-         memory::InMemoryZoneFinder(zone_data, rrclass)
-    {
-        // Build pre-defined hash
-        nsec3_hash_map.clear();
-        nsec3_hash_map[Name("example.org")] = apex_hash;
-        nsec3_hash_map[Name("www.example.org")] = "2S9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-        nsec3_hash_map[Name("xxx.example.org")] = "Q09MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-        nsec3_hash_map[Name("yyy.example.org")] = "0A9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-        nsec3_hash_map[Name("x.y.w.example.org")] =
-            "2VPTU5TIMAMQTTGL4LUU9KG21E0AOR3S";
-        nsec3_hash_map[Name("y.w.example.org")] = "K8UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
-        nsec3_hash_map[Name("w.example.org")] = w_hash;
-        nsec3_hash_map[Name("zzz.example.org")] = zzz_hash;
-        nsec3_hash_map[Name("smallest.example.org")] =
-            "00000000000000000000000000000000";
-        nsec3_hash_map[Name("largest.example.org")] =
-            "UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU";
-    }
-
-    void setFakeNSEC3Calculate() {
-        nsec3_calculate_ = fakeNSEC3Calculate;
-    }
-};
-
 /// \brief expensive rrset converter
 ///
 /// converts any specialized rrset (which may not have implemented some
@@ -255,7 +191,6 @@ public:
     }
 
     // NSEC3-specific call for 'loading' data
-    // This needs to be updated and checked when implementing #2118
     void addZoneDataNSEC3(const ConstRRsetPtr rrset) {
         assert(rrset->getType() == RRType::NSEC3());
 
@@ -268,13 +203,19 @@ public:
              nsec3_data = NSEC3Data::create(mem_sgmt_, nsec3_rdata);
              zone_data_->setNSEC3Data(nsec3_data);
         } else {
-             size_t salt_len = nsec3_data->getSaltLen();
+             const size_t salt_len = nsec3_data->getSaltLen();
              const uint8_t* salt_data = nsec3_data->getSaltData();
              const vector<uint8_t>& salt_data_2 = nsec3_rdata.getSalt();
 
              if ((nsec3_rdata.getHashalg() != nsec3_data->hashalg) ||
                  (nsec3_rdata.getIterations() != nsec3_data->iterations) ||
-                 (salt_data_2.size() != salt_len) ||
+                 (salt_data_2.size() != salt_len)) {
+                  isc_throw(isc::Unexpected,
+                            "NSEC3 with inconsistent parameters: " <<
+                            rrset->toText());
+             }
+
+             if ((salt_len > 0) &&
                  (std::memcmp(&salt_data_2[0], salt_data, salt_len) != 0)) {
                   isc_throw(isc::Unexpected,
                             "NSEC3 with inconsistent parameters: " <<
@@ -282,23 +223,14 @@ public:
              }
         }
 
-        // Make just the NSEC3 hash label uppercase, and insert the
-        // entire name into the NSEC3Data ZoneTree.
-        string fst_label = rrset->getName().split(0, 1).toText(true);
-        transform(fst_label.begin(), fst_label.end(), fst_label.begin(),
-                  ::toupper);
-        const string rest = rrset->getName().split(1).toText(true);
-
-        ZoneNode *node;
-        nsec3_data->insertName(mem_sgmt_, Name(fst_label + "." + rest), &node);
-
-        // We assume that rrsig has already been checked to match rrset
-        // by the caller.
-        RdataSet *set = RdataSet::create(mem_sgmt_, encoder_,
-                                         rrset, ConstRRsetPtr());
-        RdataSet *old_set = node->setData(set);
-        if (old_set != NULL) {
-             RdataSet::destroy(mem_sgmt_, class_, old_set);
+        ZoneNode* node;
+        nsec3_data->insertName(mem_sgmt_, rrset->getName(), &node);
+
+        RdataSet* rdset = RdataSet::create(mem_sgmt_, encoder_,
+                                           rrset, ConstRRsetPtr());
+        RdataSet* old_rdset = node->setData(rdset);
+        if (old_rdset != NULL) {
+             RdataSet::destroy(mem_sgmt_, class_, old_rdset);
         }
         zone_data_->setSigned(true);
     }
@@ -344,6 +276,44 @@ public:
             }
             name = name.split(1);
         }
+
+        // If we've added NSEC3PARAM at zone origin, set up NSEC3
+        // specific data or check consistency with already set up
+        // parameters.
+        if (rrset->getType() == RRType::NSEC3PARAM() &&
+            rrset->getName() == origin_) {
+            // We know rrset has exactly one RDATA
+            const generic::NSEC3PARAM& param =
+                dynamic_cast<const generic::NSEC3PARAM&>
+                 (rrset->getRdataIterator()->getCurrent());
+
+            NSEC3Data* nsec3_data = zone_data_->getNSEC3Data();
+            if (nsec3_data == NULL) {
+                nsec3_data = NSEC3Data::create(mem_sgmt_, param);
+                zone_data_->setNSEC3Data(nsec3_data);
+                zone_data_->setSigned(true);
+            } else {
+                size_t salt_len = nsec3_data->getSaltLen();
+                const uint8_t* salt_data = nsec3_data->getSaltData();
+                const vector<uint8_t>& salt_data_2 = param.getSalt();
+
+                if ((param.getHashalg() != nsec3_data->hashalg) ||
+                    (param.getIterations() != nsec3_data->iterations) ||
+                    (salt_data_2.size() != salt_len)) {
+                     isc_throw(isc::Unexpected,
+                               "NSEC3PARAM with inconsistent parameters: "
+                               << rrset->toText());
+                }
+
+                if ((salt_len > 0) &&
+                    (std::memcmp(&salt_data_2[0],
+                                 salt_data, salt_len) != 0)) {
+                     isc_throw(isc::Unexpected,
+                               "NSEC3PARAM with inconsistent parameters: "
+                               << rrset->toText());
+                }
+            }
+        }
     }
 
     // Some data to test with
@@ -352,7 +322,7 @@ public:
     // The zone finder to torture by tests
     MemorySegmentTest mem_sgmt_;
     memory::ZoneData* zone_data_;
-    MyZoneFinder zone_finder_;
+    memory::InMemoryZoneFinder zone_finder_;
     isc::datasrc::memory::RdataEncoder encoder_;
 
     // Placeholder for storing RRsets to be checked with rrsetsCheck()
@@ -1483,34 +1453,10 @@ TEST_F(InMemoryZoneFinderTest, cancelWildcardNSEC) {
 }
 
 
-TEST_F(InMemoryZoneFinderTest, findNSEC3) {
-    // Set up the faked hash calculator.
-    zone_finder_.setFakeNSEC3Calculate();
-
-    // Add a few NSEC3 records:
-    // apex (example.org.): hash=0P..
-    // ns1.example.org:     hash=2T..
-    // w.example.org:       hash=01..
-    // zzz.example.org:     hash=R5..
-    const string apex_nsec3_text = string(apex_hash) + ".example.org." +
-        string(nsec3_common);
-    addZoneData(textToRRset(apex_nsec3_text));
-    const string ns1_nsec3_text = string(ns1_hash) + ".example.org." +
-        string(nsec3_common);
-    addZoneData(textToRRset(ns1_nsec3_text));
-    const string w_nsec3_text = string(w_hash) + ".example.org." +
-        string(nsec3_common);
-    addZoneData(textToRRset(w_nsec3_text));
-    const string zzz_nsec3_text = string(zzz_hash) + ".example.org." +
-        string(nsec3_common);
-    addZoneData(textToRRset(zzz_nsec3_text));
-
-    performNSEC3Test(zone_finder_);
-}
-
 TEST_F(InMemoryZoneFinderTest, findNSEC3ForBadZone) {
     // Set up the faked hash calculator.
-    zone_finder_.setFakeNSEC3Calculate();
+    const TestNSEC3HashCreator creator;
+    setNSEC3HashCreator(&creator);
 
     // If the zone has nothing about NSEC3 (neither NSEC3 or NSEC3PARAM),
     // findNSEC3() should be rejected.
@@ -1532,4 +1478,127 @@ TEST_F(InMemoryZoneFinderTest, findNSEC3ForBadZone) {
                  DataSourceError);
 }
 
+/// \brief NSEC3 specific tests fixture for the InMemoryZoneFinder class
+class InMemoryZoneFinderNSEC3Test : public InMemoryZoneFinderTest {
+public:
+    InMemoryZoneFinderNSEC3Test() {
+        // Set up the faked hash calculator.
+        setNSEC3HashCreator(&creator_);
+
+        // Add a few NSEC3 records:
+        // apex (example.org.): hash=0P..
+        // ns1.example.org:     hash=2T..
+        // w.example.org:       hash=01..
+        // zzz.example.org:     hash=R5..
+        const string apex_nsec3_text = string(apex_hash) + ".example.org." +
+            string(nsec3_common);
+        addZoneData(textToRRset(apex_nsec3_text));
+        const string ns1_nsec3_text = string(ns1_hash) + ".example.org." +
+            string(nsec3_common);
+        addZoneData(textToRRset(ns1_nsec3_text));
+        const string w_nsec3_text = string(w_hash) + ".example.org." +
+            string(nsec3_common);
+        addZoneData(textToRRset(w_nsec3_text));
+        const string zzz_nsec3_text = string(zzz_hash) + ".example.org." +
+            string(nsec3_common);
+        addZoneData(textToRRset(zzz_nsec3_text));
+    }
+
+private:
+    const TestNSEC3HashCreator creator_;
+};
+
+TEST_F(InMemoryZoneFinderNSEC3Test, findNSEC3) {
+    performNSEC3Test(zone_finder_);
+}
+
+struct TestData {
+     // String for the name passed to findNSEC3() (concatenated with
+     // "example.org.")
+     const char* const name;
+     // Should recursive findNSEC3() be performed?
+     const bool recursive;
+     // The following are members of the FindNSEC3Result returned by
+     // findNSEC3(). The proofs are given as char*, which are converted
+     // to Name objects and checked against getName() on the returned
+     // ConstRRsetPtr. If any of these is NULL, then it's expected that
+     // ConstRRsetPtr() will be returned.
+     const bool matched;
+     const uint8_t closest_labels;
+     const char* const closest_proof;
+     const char* const next_proof;
+};
+
+const TestData nsec3_data[] = {
+     // ==== These are non-recursive tests.
+     {"n0", false, false, 4, "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN", NULL},
+     {"n1", false,  true, 4, "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H", NULL},
+     {"n2", false, false, 4, "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H", NULL},
+     {"n3", false,  true, 4, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM", NULL},
+     {"n4", false, false, 4, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM", NULL},
+     {"n5", false,  true, 4, "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR", NULL},
+     {"n6", false, false, 4, "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR", NULL},
+     {"n7", false,  true, 4, "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN", NULL},
+     {"n8", false, false, 4, "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN", NULL},
+
+     // ==== These are recursive tests.
+     {"n0",  true,  true, 3, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+         "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN"},
+     {"n1",  true,  true, 4, "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H", NULL},
+     {"n2",  true,  true, 3, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+         "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H"},
+     {"n3",  true,  true, 4, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM", NULL},
+     {"n4",  true,  true, 3, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+         "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM"},
+     {"n5",  true,  true, 4, "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR", NULL},
+     {"n6",  true,  true, 3, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+         "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR"},
+     {"n7",  true,  true, 4, "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN", NULL},
+     {"n8",  true,  true, 3, "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM",
+         "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN"}
+};
+
+const size_t data_count(sizeof(nsec3_data) / sizeof(*nsec3_data));
+
+TEST_F(InMemoryZoneFinderNSEC3Test, findNSEC3Walk) {
+    // This test basically uses nsec3_data[] declared above along with
+    // the fake hash setup to walk the NSEC3 tree. The names and fake
+    // hash calculation is specially setup so that the tree search
+    // terminates at specific locations in the tree. We findNSEC3() on
+    // each of the nsec3_data[], which is setup such that the hash
+    // results in the search terminating on either side of each node of
+    // the NSEC3 tree. This way, we check what result is returned in
+    // every search termination case in the NSEC3 tree.
+
+    const Name origin("example.org");
+    for (size_t i = 0; i < data_count; ++i) {
+        const Name name = Name(nsec3_data[i].name).concatenate(origin);
+
+        SCOPED_TRACE(name.toText() + (nsec3_data[i].recursive ?
+                                      ", recursive" :
+                                      ", non-recursive"));
+
+        const ZoneFinder::FindNSEC3Result result =
+            zone_finder_.findNSEC3(name, nsec3_data[i].recursive);
+
+        EXPECT_EQ(nsec3_data[i].matched, result.matched);
+        EXPECT_EQ(nsec3_data[i].closest_labels, result.closest_labels);
+
+        if (nsec3_data[i].closest_proof != NULL) {
+            ASSERT_TRUE(result.closest_proof);
+            EXPECT_EQ(Name(nsec3_data[i].closest_proof).concatenate(origin),
+                      result.closest_proof->getName());
+        } else {
+            EXPECT_FALSE(result.closest_proof);
+        }
+
+        if (nsec3_data[i].next_proof != NULL) {
+            ASSERT_TRUE(result.next_proof);
+            EXPECT_EQ(Name(nsec3_data[i].next_proof).concatenate(origin),
+                      result.next_proof->getName());
+        } else {
+            EXPECT_FALSE(result.next_proof);
+        }
+    }
+}
 }

+ 0 - 66
src/lib/datasrc/tests/memory_datasrc_unittest.cc

@@ -293,72 +293,6 @@ setRRset(RRsetPtr rrset, vector<RRsetPtr*>::iterator& it) {
     ++it;
 }
 
-// Some faked NSEC3 hash values commonly used in tests and the faked NSEC3Hash
-// object.
-//
-// For apex (example.org)
-const char* const apex_hash = "0P9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-const char* const apex_hash_lower = "0p9mhaveqvm6t7vbl5lop2u3t2rp3tom";
-// For ns1.example.org
-const char* const ns1_hash = "2T7B4G4VSA5SMI47K61MV5BV1A22BOJR";
-// For w.example.org
-const char* const w_hash = "01UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
-// For x.y.w.example.org (lower-cased)
-const char* const xyw_hash = "2vptu5timamqttgl4luu9kg21e0aor3s";
-// For zzz.example.org.
-const char* const zzz_hash = "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN";
-
-// A simple faked NSEC3 hash calculator with a dedicated creator for it.
-//
-// This is used in some NSEC3-related tests below.
-class TestNSEC3HashCreator : public NSEC3HashCreator {
-    class TestNSEC3Hash : public NSEC3Hash {
-    private:
-        typedef map<Name, string> NSEC3HashMap;
-        typedef NSEC3HashMap::value_type NSEC3HashPair;
-        NSEC3HashMap map_;
-    public:
-        TestNSEC3Hash() {
-            // Build pre-defined hash
-            map_[Name("example.org")] = apex_hash;
-            map_[Name("www.example.org")] = "2S9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("xxx.example.org")] = "Q09MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("yyy.example.org")] = "0A9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("x.y.w.example.org")] =
-                "2VPTU5TIMAMQTTGL4LUU9KG21E0AOR3S";
-            map_[Name("y.w.example.org")] = "K8UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
-            map_[Name("w.example.org")] = w_hash;
-            map_[Name("zzz.example.org")] = zzz_hash;
-            map_[Name("smallest.example.org")] =
-                "00000000000000000000000000000000";
-            map_[Name("largest.example.org")] =
-                "UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU";
-        }
-        virtual string calculate(const Name& name) const {
-            const NSEC3HashMap::const_iterator found = map_.find(name);
-            if (found != map_.end()) {
-                return (found->second);
-            }
-            isc_throw(isc::Unexpected, "unexpected name for NSEC3 test: "
-                      << name);
-        }
-        virtual bool match(const generic::NSEC3PARAM&) const {
-            return (true);
-        }
-        virtual bool match(const generic::NSEC3&) const {
-            return (true);
-        }
-    };
-
-public:
-    virtual NSEC3Hash* create(const generic::NSEC3PARAM&) const {
-        return (new TestNSEC3Hash);
-    }
-    virtual NSEC3Hash* create(const generic::NSEC3&) const {
-        return (new TestNSEC3Hash);
-    }
-};
-
 /// \brief Test fixture for the InMemoryZoneFinder class
 class InMemoryZoneFinderTest : public ::testing::Test {
     // A straightforward pair of textual RR(set) and a RRsetPtr variable

+ 47 - 11
src/lib/dns/nsec3hash.cc

@@ -16,6 +16,7 @@
 
 #include <cassert>
 #include <cstring>
+#include <cstdlib>
 #include <string>
 #include <vector>
 
@@ -57,17 +58,31 @@ private:
 
 public:
     NSEC3HashRFC5155(uint8_t algorithm, uint16_t iterations,
-                     const vector<uint8_t>& salt) :
+                     const uint8_t* salt_data, size_t salt_length) :
         algorithm_(algorithm), iterations_(iterations),
-        salt_(salt), digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
+        salt_data_(NULL), salt_length_(salt_length),
+        digest_(SHA1_HASHSIZE), obuf_(Name::MAX_WIRE)
     {
         if (algorithm_ != NSEC3_HASH_SHA1) {
             isc_throw(UnknownNSEC3HashAlgorithm, "Unknown NSEC3 algorithm: " <<
                       static_cast<unsigned int>(algorithm_));
         }
+
+        if (salt_length > 0) {
+            salt_data_ = static_cast<uint8_t*>(std::malloc(salt_length));
+            if (salt_data_ == NULL) {
+                throw std::bad_alloc();
+            }
+            std::memcpy(salt_data_, salt_data, salt_length);
+        }
+
         SHA1Reset(&sha1_ctx_);
     }
 
+    ~NSEC3HashRFC5155() {
+        std::free(salt_data_);
+    }
+
     virtual std::string calculate(const Name& name) const;
 
     virtual bool match(const generic::NSEC3& nsec3) const;
@@ -78,7 +93,8 @@ public:
 private:
     const uint8_t algorithm_;
     const uint16_t iterations_;
-    const vector<uint8_t> salt_;
+    uint8_t* salt_data_;
+    const size_t salt_length_;
 
     // The following members are placeholder of work place and don't hold
     // any state over multiple calls so can be mutable without breaking
@@ -108,15 +124,14 @@ NSEC3HashRFC5155::calculate(const Name& name) const {
     name_copy.downcase();
     name_copy.toWire(obuf_);
 
-    const uint8_t saltlen = salt_.size();
-    const uint8_t* const salt = (saltlen > 0) ? &salt_[0] : NULL;
     uint8_t* const digest = &digest_[0];
     assert(digest_.size() == SHA1_HASHSIZE);
 
     iterateSHA1(&sha1_ctx_, static_cast<const uint8_t*>(obuf_.getData()),
-                obuf_.getLength(), salt, saltlen, digest);
+                obuf_.getLength(), salt_data_, salt_length_, digest);
     for (unsigned int n = 0; n < iterations_; ++n) {
-        iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE, salt, saltlen, digest);
+        iterateSHA1(&sha1_ctx_, digest, SHA1_HASHSIZE,
+                    salt_data_, salt_length_, digest);
     }
 
     return (encodeBase32Hex(digest_));
@@ -127,8 +142,9 @@ NSEC3HashRFC5155::match(uint8_t algorithm, uint16_t iterations,
                         const vector<uint8_t>& salt) const
 {
     return (algorithm_ == algorithm && iterations_ == iterations &&
-            salt_.size() == salt.size() &&
-            (salt_.empty() || memcmp(&salt_[0], &salt[0], salt_.size()) == 0));
+            salt_length_ == salt.size() &&
+            ((salt_length_ == 0) ||
+             memcmp(salt_data_, &salt[0], salt_length_) == 0));
 }
 
 bool
@@ -175,15 +191,35 @@ NSEC3Hash::create(const generic::NSEC3& nsec3) {
 }
 
 NSEC3Hash*
+NSEC3Hash::create(uint8_t algorithm, uint16_t iterations,
+                  const uint8_t* salt_data, size_t salt_length) {
+    return (getNSEC3HashCreator()->create(algorithm, iterations,
+                                          salt_data, salt_length));
+}
+
+NSEC3Hash*
 DefaultNSEC3HashCreator::create(const generic::NSEC3PARAM& param) const {
+    const vector<uint8_t>& salt = param.getSalt();
     return (new NSEC3HashRFC5155(param.getHashalg(), param.getIterations(),
-                                 param.getSalt()));
+                                 salt.empty() ? NULL : &salt[0],
+                                 salt.size()));
 }
 
 NSEC3Hash*
 DefaultNSEC3HashCreator::create(const generic::NSEC3& nsec3) const {
+    const vector<uint8_t>& salt = nsec3.getSalt();
     return (new NSEC3HashRFC5155(nsec3.getHashalg(), nsec3.getIterations(),
-                                 nsec3.getSalt()));
+                                 salt.empty() ? NULL : &salt[0],
+                                 salt.size()));
+}
+
+NSEC3Hash*
+DefaultNSEC3HashCreator::create(uint8_t algorithm, uint16_t iterations,
+                                const uint8_t* salt_data,
+                                size_t salt_length) const
+{
+    return (new NSEC3HashRFC5155(algorithm, iterations,
+                                 salt_data, salt_length));
 }
 
 void

+ 32 - 2
src/lib/dns/nsec3hash.h

@@ -16,7 +16,8 @@
 #define __NSEC3HASH_H 1
 
 #include <string>
-
+#include <vector>
+#include <stdint.h>
 #include <exceptions/exceptions.h>
 
 namespace isc {
@@ -115,6 +116,16 @@ public:
     /// for hash calculation from an NSEC3 RDATA object.
     static NSEC3Hash* create(const rdata::generic::NSEC3& nsec3);
 
+    /// \brief Factory method of NSECHash from args.
+    ///
+    /// \param algorithm the NSEC3 algorithm to use; currently only 1
+    ///                  (SHA-1) is supported
+    /// \param iterations the number of iterations
+    /// \param salt_data the salt data as a byte array
+    /// \param salt_data_length the length of the salt data
+    static NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+                             const uint8_t* salt_data, size_t salt_length);
+
     /// \brief The destructor.
     virtual ~NSEC3Hash() {}
 
@@ -167,7 +178,7 @@ public:
 /// would be an experimental extension for a newer hash algorithm or
 /// implementation.
 ///
-/// The two main methods named \c create() correspond to the static factory
+/// The three main methods named \c create() correspond to the static factory
 /// methods of \c NSEC3Hash of the same name.
 ///
 /// By default, the library uses the \c DefaultNSEC3HashCreator creator.
@@ -210,6 +221,22 @@ public:
     /// <code>NSEC3Hash::create(const rdata::generic::NSEC3& param)</code>
     virtual NSEC3Hash* create(const rdata::generic::NSEC3& nsec3)
         const = 0;
+
+    /// \brief Factory method of NSECHash from args.
+    ///
+    /// See
+    /// <code>NSEC3Hash::create(uint8_t algorithm, uint16_t iterations,
+    ///                         const uint8_t* salt_data,
+    ///                         size_t salt_length)</code>
+    ///
+    /// \param algorithm the NSEC3 algorithm to use; currently only 1
+    ///                  (SHA-1) is supported
+    /// \param iterations the number of iterations
+    /// \param salt_data the salt data as a byte array
+    /// \param salt_data_length the length of the salt data
+    virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+                              const uint8_t* salt_data, size_t salt_length)
+        const = 0;
 };
 
 /// \brief The default NSEC3Hash creator.
@@ -225,6 +252,9 @@ class DefaultNSEC3HashCreator : public NSEC3HashCreator {
 public:
     virtual NSEC3Hash* create(const rdata::generic::NSEC3PARAM& param) const;
     virtual NSEC3Hash* create(const rdata::generic::NSEC3& nsec3) const;
+    virtual NSEC3Hash* create(uint8_t algorithm, uint16_t iterations,
+                              const uint8_t* salt_data,
+                              size_t salt_length) const;
 };
 
 /// \brief The registrar of \c NSEC3HashCreator.

+ 23 - 1
src/lib/dns/tests/nsec3hash_unittest.cc

@@ -20,11 +20,14 @@
 
 #include <dns/nsec3hash.h>
 #include <dns/rdataclass.h>
+#include <util/encode/hex.h>
 
 using boost::scoped_ptr;
 using namespace std;
 using namespace isc::dns;
 using namespace isc::dns::rdata;
+using namespace isc::util;
+using namespace isc::util::encode;
 
 namespace {
 typedef scoped_ptr<NSEC3Hash> NSEC3HashPtr;
@@ -39,7 +42,10 @@ protected:
         test_hash_nsec3(NSEC3Hash::create(generic::NSEC3
                                           ("1 0 12 aabbccdd " +
                                            string(nsec3_common))))
-    {}
+    {
+        const uint8_t salt[] = {0xaa, 0xbb, 0xcc, 0xdd};
+        test_hash_args.reset(NSEC3Hash::create(1, 12, salt, sizeof(salt)));
+    }
 
     ~NSEC3HashTest() {
         // Make sure we reset the hash creator to the default
@@ -53,6 +59,9 @@ protected:
 
     // Similar to test_hash, but created from NSEC3 RR.
     NSEC3HashPtr test_hash_nsec3;
+
+    // Similar to test_hash, but created from passed args.
+    NSEC3HashPtr test_hash_args;
 };
 
 TEST_F(NSEC3HashTest, unknownAlgorithm) {
@@ -65,6 +74,10 @@ TEST_F(NSEC3HashTest, unknownAlgorithm) {
                          generic::NSEC3("2 0 12 aabbccdd " +
                                         string(nsec3_common)))),
                      UnknownNSEC3HashAlgorithm);
+
+    const uint8_t salt[] = {0xaa, 0xbb, 0xcc, 0xdd};
+    EXPECT_THROW(NSEC3HashPtr(NSEC3Hash::create(2, 12, salt, sizeof(salt))),
+                 UnknownNSEC3HashAlgorithm);
 }
 
 // Common checks for NSEC3 hash calculation
@@ -90,6 +103,10 @@ TEST_F(NSEC3HashTest, calculate) {
         SCOPED_TRACE("calculate check with NSEC3 based hash");
         calculateCheck(*test_hash_nsec3);
     }
+    {
+        SCOPED_TRACE("calculate check with args based hash");
+        calculateCheck(*test_hash_args);
+    }
 
     // Some boundary cases: 0-iteration and empty salt.  Borrowed from the
     // .com zone data.
@@ -177,6 +194,11 @@ public:
         }
         return (new TestNSEC3Hash);
     }
+    virtual NSEC3Hash* create(uint8_t, uint16_t,
+                              const uint8_t*, size_t) const {
+        isc_throw(isc::Unexpected,
+                  "This method is not implemented here.");
+    }
 private:
     DefaultNSEC3HashCreator default_creator_;
 };