Browse Source

[2218] Add facility to setup a fake NSEC3Calculate function for use in tests

Mukund Sivaraman 12 years ago
parent
commit
350fa8a599

+ 50 - 60
src/lib/datasrc/memory/tests/zone_finder_unittest.cc

@@ -61,59 +61,57 @@ const char* const xyw_hash = "2vptu5timamqttgl4luu9kg21e0aor3s";
 // For zzz.example.org.
 const char* const zzz_hash = "R53BQ7CC2UVMUBFU5OCMM6PERS9TK9EN";
 
-// A simple faked NSEC3 hash calculator with a dedicated creator for it.
-//
-// This is used in some NSEC3-related tests below.
-// Also see NOTE at inclusion of "../../tests/faked_nsec3.h"
-class TestNSEC3HashCreator : public NSEC3HashCreator {
-    class TestNSEC3Hash : public NSEC3Hash {
-    private:
-        typedef map<Name, string> NSEC3HashMap;
-        typedef NSEC3HashMap::value_type NSEC3HashPair;
-        NSEC3HashMap map_;
-    public:
-        TestNSEC3Hash() {
-            // Build pre-defined hash
-            map_[Name("example.org")] = apex_hash;
-            map_[Name("www.example.org")] = "2S9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("xxx.example.org")] = "Q09MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("yyy.example.org")] = "0A9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
-            map_[Name("x.y.w.example.org")] =
-                "2VPTU5TIMAMQTTGL4LUU9KG21E0AOR3S";
-            map_[Name("y.w.example.org")] = "K8UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
-            map_[Name("w.example.org")] = w_hash;
-            map_[Name("zzz.example.org")] = zzz_hash;
-            map_[Name("smallest.example.org")] =
-                "00000000000000000000000000000000";
-            map_[Name("largest.example.org")] =
-                "UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU";
-        }
-        virtual string calculate(const Name& name) const {
-            const NSEC3HashMap::const_iterator found = map_.find(name);
-            if (found != map_.end()) {
-                return (found->second);
-            }
-            isc_throw(isc::Unexpected, "unexpected name for NSEC3 test: "
-                      << name);
-        }
-        virtual bool match(const generic::NSEC3PARAM&) const {
-            return (true);
-        }
-        virtual bool match(const generic::NSEC3&) const {
-            return (true);
-        }
-    };
+typedef map<Name, string> NSEC3HashMap;
+typedef NSEC3HashMap::value_type NSEC3HashPair;
+NSEC3HashMap nsec3_hash_map;
+
+// A faked NSEC3 hash calculator for convenience. Tests that need to use
+// the faked hashed values should call setFakeNSEC3Calculate() on the
+// MyZoneFinder object at the beginning of the test (at least before
+// adding any NSEC3/NSEC3PARAM RR).
+std::string
+fakeNSEC3Calculate(const Name& name,
+                   const uint16_t,
+                   const uint8_t*,
+                   size_t) {
+    const NSEC3HashMap::const_iterator found = nsec3_hash_map.find(name);
+    if (found != nsec3_hash_map.end()) {
+        return (found->second);
+    }
+
+    isc_throw(isc::Unexpected,
+              "unexpected name for NSEC3 test: " << name);
+}
 
+class MyZoneFinder : public memory::InMemoryZoneFinder {
+private:
 public:
-    virtual NSEC3Hash* create(const generic::NSEC3PARAM&) const {
-        return (new TestNSEC3Hash);
+    MyZoneFinder(const ZoneData& zone_data,
+                 const isc::dns::RRClass& rrclass) :
+         memory::InMemoryZoneFinder(zone_data, rrclass)
+    {
+        // Build pre-defined hash
+        nsec3_hash_map.clear();
+        nsec3_hash_map[Name("example.org")] = apex_hash;
+        nsec3_hash_map[Name("www.example.org")] = "2S9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
+        nsec3_hash_map[Name("xxx.example.org")] = "Q09MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
+        nsec3_hash_map[Name("yyy.example.org")] = "0A9MHAVEQVM6T7VBL5LOP2U3T2RP3TOM";
+        nsec3_hash_map[Name("x.y.w.example.org")] =
+            "2VPTU5TIMAMQTTGL4LUU9KG21E0AOR3S";
+        nsec3_hash_map[Name("y.w.example.org")] = "K8UDEMVP1J2F7EG6JEBPS17VP3N8I58H";
+        nsec3_hash_map[Name("w.example.org")] = w_hash;
+        nsec3_hash_map[Name("zzz.example.org")] = zzz_hash;
+        nsec3_hash_map[Name("smallest.example.org")] =
+            "00000000000000000000000000000000";
+        nsec3_hash_map[Name("largest.example.org")] =
+            "UUUUUUUUUUUUUUUUUUUUUUUUUUUUUUUU";
     }
-    virtual NSEC3Hash* create(const generic::NSEC3&) const {
-        return (new TestNSEC3Hash);
+
+    void setFakeNSEC3Calculate() {
+        nsec3_calculate_ = fakeNSEC3Calculate;
     }
 };
 
-
 /// \brief expensive rrset converter
 ///
 /// converts any specialized rrset (which may not have implemented some
@@ -351,7 +349,7 @@ public:
     // The zone finder to torture by tests
     MemorySegmentTest mem_sgmt_;
     memory::ZoneData* zone_data_;
-    memory::InMemoryZoneFinder zone_finder_;
+    MyZoneFinder zone_finder_;
     isc::datasrc::memory::RdataEncoder encoder_;
 
     // Placeholder for storing RRsets to be checked with rrsetsCheck()
@@ -400,12 +398,6 @@ public:
     RRsetPtr rr_ns_nsec_;
     RRsetPtr rr_wild_nsec_;
 
-    // A faked NSEC3 hash calculator for convenience.
-    // Tests that need to use the faked hashed values should call
-    // setNSEC3HashCreator() with a pointer to this variable at the beginning
-    // of the test (at least before adding any NSEC3/NSEC3PARAM RR).
-    TestNSEC3HashCreator nsec3_hash_creator_;
-
     /**
      * \brief Test one find query to the zone finder.
      *
@@ -1458,10 +1450,9 @@ TEST_F(InMemoryZoneFinderTest, cancelWildcardNSEC) {
 }
 
 
-// DISABLED: nsec3 will be re-added in #2118
-TEST_F(InMemoryZoneFinderTest, DISABLED_findNSEC3) {
+TEST_F(InMemoryZoneFinderTest, findNSEC3) {
     // Set up the faked hash calculator.
-    setNSEC3HashCreator(&nsec3_hash_creator_);
+    zone_finder_.setFakeNSEC3Calculate();
 
     // Add a few NSEC3 records:
     // apex (example.org.): hash=0P..
@@ -1484,10 +1475,9 @@ TEST_F(InMemoryZoneFinderTest, DISABLED_findNSEC3) {
     performNSEC3Test(zone_finder_);
 }
 
-// DISABLED: NSEC3 will be re-added in #2218
-TEST_F(InMemoryZoneFinderTest, DISABLED_findNSEC3ForBadZone) {
+TEST_F(InMemoryZoneFinderTest, findNSEC3ForBadZone) {
     // Set up the faked hash calculator.
-    setNSEC3HashCreator(&nsec3_hash_creator_);
+    zone_finder_.setFakeNSEC3Calculate();
 
     // If the zone has nothing about NSEC3 (neither NSEC3 or NSEC3PARAM),
     // findNSEC3() should be rejected.

+ 44 - 44
src/lib/datasrc/memory/zone_finder.cc

@@ -266,45 +266,6 @@ getNSECForNXRRSET(const ZoneData& zone_data,
     return (NULL);
 }
 
-inline void
-iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
-            const uint8_t* salt, size_t saltlen,
-            uint8_t output[SHA1_HASHSIZE])
-{
-    SHA1Reset(ctx);
-    SHA1Input(ctx, input, inlength);
-    SHA1Input(ctx, salt, saltlen); // this works whether saltlen == or > 0
-    SHA1Result(ctx, output);
-}
-
-std::string
-NSEC3Calculate(const Name& name,
-               const uint16_t iterations,
-               const uint8_t* salt,
-               size_t salt_len) {
-    // We first need to normalize the name by converting all upper case
-    // characters in the labels to lower ones.
-    OutputBuffer obuf(Name::MAX_WIRE);
-    Name name_copy(name);
-    name_copy.downcase();
-    name_copy.toWire(obuf);
-
-    const uint8_t* const salt_buf = (salt_len > 0) ? salt : NULL;
-    std::vector<uint8_t> digest(SHA1_HASHSIZE);
-    uint8_t* const digest_buf = &digest[0];
-
-    SHA1Context sha1_ctx;
-    iterateSHA1(&sha1_ctx, static_cast<const uint8_t*>(obuf.getData()),
-                obuf.getLength(), salt_buf, salt_len, digest_buf);
-    for (unsigned int n = 0; n < iterations; ++n) {
-        iterateSHA1(&sha1_ctx, digest_buf, SHA1_HASHSIZE,
-                    salt_buf, salt_len,
-                    digest_buf);
-    }
-
-    return (encodeBase32Hex(digest));
-}
-
 // Structure to hold result data of the findNode() call
 class FindNodeResult {
 public:
@@ -492,6 +453,45 @@ FindNodeResult findNode(const ZoneData& zone_data,
 
 } // end anonymous namespace
 
+inline void
+iterateSHA1(SHA1Context* ctx, const uint8_t* input, size_t inlength,
+            const uint8_t* salt, size_t saltlen,
+            uint8_t output[SHA1_HASHSIZE])
+{
+    SHA1Reset(ctx);
+    SHA1Input(ctx, input, inlength);
+    SHA1Input(ctx, salt, saltlen); // this works whether saltlen == or > 0
+    SHA1Result(ctx, output);
+}
+
+std::string
+InMemoryZoneFinderNSEC3Calculate(const Name& name,
+                                 const uint16_t iterations,
+                                 const uint8_t* salt,
+                                 size_t salt_len) {
+    // We first need to normalize the name by converting all upper case
+    // characters in the labels to lower ones.
+    OutputBuffer obuf(Name::MAX_WIRE);
+    Name name_copy(name);
+    name_copy.downcase();
+    name_copy.toWire(obuf);
+
+    const uint8_t* const salt_buf = (salt_len > 0) ? salt : NULL;
+    std::vector<uint8_t> digest(SHA1_HASHSIZE);
+    uint8_t* const digest_buf = &digest[0];
+
+    SHA1Context sha1_ctx;
+    iterateSHA1(&sha1_ctx, static_cast<const uint8_t*>(obuf.getData()),
+                obuf.getLength(), salt_buf, salt_len, digest_buf);
+    for (unsigned int n = 0; n < iterations; ++n) {
+        iterateSHA1(&sha1_ctx, digest_buf, SHA1_HASHSIZE,
+                    salt_buf, salt_len,
+                    digest_buf);
+    }
+
+    return (encodeBase32Hex(digest));
+}
+
 // Specialization of the ZoneFinder::Context for the in-memory finder.
 class InMemoryZoneFinder::Context : public ZoneFinder::Context {
 public:
@@ -664,11 +664,11 @@ InMemoryZoneFinder::findNSEC3(const isc::dns::Name& name, bool recursive) {
     // NSEC3 hash.
     for (unsigned int labels = qlabels; labels >= olabels; --labels) {
         const std::string hlabel =
-             NSEC3Calculate((labels == qlabels ?
-                             name : name.split(qlabels - labels, labels)),
-                            nsec3_data->iterations,
-                            nsec3_data->getSaltData(),
-                            nsec3_data->getSaltLen());
+             (nsec3_calculate_)((labels == qlabels ?
+                                 name : name.split(qlabels - labels, labels)),
+                                nsec3_data->iterations,
+                                nsec3_data->getSaltData(),
+                                nsec3_data->getSaltLen());
 
         LOG_DEBUG(logger, DBG_TRACE_BASIC, DATASRC_MEM_FINDNSEC3_TRYHASH).
             arg(name).arg(labels).arg(hlabel);

+ 15 - 2
src/lib/datasrc/memory/zone_finder.h

@@ -48,6 +48,12 @@ public:
     const ZoneNode* const found_node;
 };
 
+std::string
+InMemoryZoneFinderNSEC3Calculate(const isc::dns::Name& name,
+                                 const uint16_t iterations,
+                                 const uint8_t* salt,
+                                 size_t salt_len);
+
 /// A derived zone finder class intended to be used with the memory data
 /// source, using ZoneData for its contents.
 class InMemoryZoneFinder : boost::noncopyable, public ZoneFinder {
@@ -66,7 +72,8 @@ public:
     InMemoryZoneFinder(const ZoneData& zone_data,
                        const isc::dns::RRClass& rrclass) :
         zone_data_(zone_data),
-        rrclass_(rrclass)
+        rrclass_(rrclass),
+        nsec3_calculate_(InMemoryZoneFinderNSEC3Calculate)
     {}
 
     /// \brief Find an RRset in the datasource
@@ -101,7 +108,6 @@ public:
         return rrclass_;
     }
 
-
 private:
     /// \brief In-memory version of finder context.
     ///
@@ -119,6 +125,13 @@ private:
 
     const ZoneData& zone_data_;
     const isc::dns::RRClass& rrclass_;
+
+protected:
+    typedef std::string (NSEC3CalculateFn) (const isc::dns::Name& name,
+                                            const uint16_t iterations,
+                                            const uint8_t* salt,
+                                            size_t salt_len);
+    NSEC3CalculateFn* nsec3_calculate_;
 };
 
 } // namespace memory